Files
012-kaopeilian/backend/app/api/v1/sql_executor.py
111 998211c483 feat: 初始化考培练系统项目
- 从服务器拉取完整代码
- 按框架规范整理项目结构
- 配置 Drone CI 测试环境部署
- 包含后端(FastAPI)、前端(Vue3)、管理端

技术栈: Vue3 + TypeScript + FastAPI + MySQL
2026-01-24 19:33:28 +08:00

364 lines
11 KiB
Python

"""
SQL 执行器 API - 用于内部服务调用
支持执行查询和写入操作的 SQL 语句
"""
import json
from typing import Any, Dict, List, Optional, Union
from datetime import datetime, date
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.engine.result import Result
import structlog
from app.core.deps import get_current_user, get_db
try:
from app.core.simple_auth import get_current_user_simple
except ImportError:
get_current_user_simple = None
from app.core.config import settings
from app.models.user import User
from app.schemas.base import ResponseModel
# Module-level structlog logger, named after this module.
logger = structlog.get_logger(__name__)
# Router on which the endpoints of this module are registered (OpenAPI tag "SQL Executor").
router = APIRouter(tags=["SQL Executor"])
class SQLExecutorRequest:
    """Plain container describing one SQL execution request.

    Attributes:
        sql: The SQL statement text, stored exactly as given.
        params: Bind parameters for the statement; an empty dict is stored
            when no (or a falsy) ``params`` value is supplied.
    """

    def __init__(self, sql: str, params: Optional[Dict[str, Any]] = None):
        # Fall back to a fresh dict so instances never share a default.
        bind_params = params if params else {}
        self.sql = sql
        self.params = bind_params
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder for value types that commonly appear in database rows.

    The stock ``json`` encoder rejects date/time and decimal values, so a
    single such column would make serialization of a whole result set fail.
    This encoder converts:

    - ``datetime`` / ``date`` / ``time`` -> ISO 8601 string (``isoformat()``)
    - ``Decimal`` -> its exact string form (lossless, unlike ``float``)

    Any other unsupported type still raises ``TypeError`` via the base class.
    """

    def default(self, obj):
        """Return a JSON-serializable replacement for ``obj``.

        Args:
            obj: A value the base encoder could not serialize.

        Returns:
            A string representation for the supported types listed on the class.

        Raises:
            TypeError: If ``obj`` is not one of the supported types.
        """
        # Imported here so the block only relies on names it brings in itself;
        # the module-level imports provide just ``datetime`` and ``date``.
        from datetime import time
        from decimal import Decimal

        if isinstance(obj, (datetime, date, time)):
            return obj.isoformat()
        if isinstance(obj, Decimal):
            # str() keeps every digit; converting to float could lose precision.
            return str(obj)
        return super().default(obj)
def serialize_row(row: Any) -> Union[Dict[str, Any], Any]:
    """Turn a database result row into a plain Python value.

    Args:
        row: A result item. Three shapes are recognised, checked in order:
            an object exposing ``_mapping`` (SQLAlchemy ``Row``), an object
            with a ``__dict__`` (e.g. an ORM instance), or anything else.

    Returns:
        - ``dict(row._mapping)`` when ``_mapping`` is present;
        - a dict of the object's attributes whose names do not start with
          an underscore when only ``__dict__`` is present;
        - ``row`` itself otherwise (single scalar value).
    """
    # SQLAlchemy Row objects expose their columns through ``_mapping``.
    if hasattr(row, '_mapping'):
        return dict(row._mapping)

    # Scalars and other objects without instance attributes pass through.
    if not hasattr(row, '__dict__'):
        return row

    # ORM-style object: keep public attributes only.
    public_attrs = {}
    for attr_name, attr_value in row.__dict__.items():
        if attr_name.startswith('_'):
            continue
        public_attrs[attr_name] = attr_value
    return public_attrs
@router.post("/execute", response_model=ResponseModel)
async def execute_sql(
    request: Dict[str, Any],
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
) -> ResponseModel:
    """
    Execute a single SQL statement on behalf of an authenticated user.

    Args:
        request: Request body containing
            - sql: the SQL statement (required, non-empty string)
            - params: bind parameters (optional)
        current_user: The authenticated user; its id and username are logged.
        db: Async session the statement is executed on.

    Returns:
        ResponseModel whose ``data`` is one of
        - ``{"type": "query", "columns", "rows", "row_count"}`` when the
          statement produced a result set;
        - ``{"type": "execute", "affected_rows", "success"}`` otherwise.

    Raises:
        HTTPException: 400 when ``sql`` is missing, empty or not a string;
            500 when executing the statement or handling the request fails
            (the transaction is rolled back on execution failure).

    Security notes:
        - Requires an authenticated user.
        - Every request is logged.
        - NOTE(review): no restriction on statement type is applied here;
          consider limiting the allowed SQL in production.
    """
    try:
        # Validate the payload up front so malformed input is a 400, not a 500.
        raw_sql = request.get('sql', '')
        if raw_sql is None:
            raw_sql = ''
        if not isinstance(raw_sql, str):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="SQL 语句必须是字符串"
            )
        sql = raw_sql.strip()
        if not sql:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="SQL 语句不能为空"
            )
        params = request.get('params') or {}

        sql_type = sql.split()[0].upper()
        logger.info(
            "sql_execution_request",
            user_id=current_user.id,
            username=current_user.username,
            sql_type=sql_type,
            sql_length=len(sql),
            has_params=bool(params)
        )

        # Statements with these prefixes are treated as read-only, so no commit
        # is issued for them. Everything else is committed, including
        # row-returning statements such as ``WITH ... UPDATE`` or ``CALL``.
        is_read_only = sql.upper().startswith(('SELECT', 'SHOW', 'DESCRIBE', 'DESC'))

        try:
            result = await db.execute(text(sql), params)

            # Ask the result whether it carries rows instead of guessing from
            # the statement prefix: WITH / EXPLAIN / CALL return rows too, and
            # e.g. ``SELECT ... INTO @var`` returns none. Fall back to the
            # prefix heuristic if the result object lacks the attribute.
            returns_rows = getattr(result, 'returns_rows', is_read_only)

            if returns_rows:
                rows = result.fetchall()
                columns = list(result.keys())
                data = []
                for row in rows:
                    serialized_row = serialize_row(row)
                    if isinstance(serialized_row, dict):
                        data.append(serialized_row)
                    else:
                        # Scalar result: wrap it under the first column name.
                        data.append({columns[0] if columns else 'value': serialized_row})
                response_data = {
                    "type": "query",
                    "columns": columns,
                    # Round-trip through JSON so date/time values become strings.
                    "rows": json.loads(json.dumps(data, cls=DateTimeEncoder)),
                    "row_count": len(data)
                }
            else:
                # Read the row count before committing, while the result is fresh.
                affected_rows = result.rowcount
                response_data = {
                    "type": "execute",
                    "affected_rows": affected_rows,
                    "success": True
                }

            if not is_read_only:
                await db.commit()

            if returns_rows:
                logger.info(
                    "sql_query_success",
                    user_id=current_user.id,
                    row_count=response_data["row_count"],
                    column_count=len(response_data["columns"])
                )
            else:
                logger.info(
                    "sql_execute_success",
                    user_id=current_user.id,
                    affected_rows=response_data["affected_rows"]
                )

            return ResponseModel(
                code=200,
                message="SQL 执行成功",
                data=response_data
            )
        except Exception as e:
            # Undo any partial work before reporting the failure.
            await db.rollback()
            logger.error(
                "sql_execution_error",
                user_id=current_user.id,
                sql_type=sql_type,
                error=str(e),
                exc_info=True
            )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"SQL 执行失败: {str(e)}"
            ) from e
    except HTTPException:
        # Already a well-formed HTTP error (validation or execution); pass through.
        raise
    except Exception as e:
        logger.error(
            "sql_executor_error",
            user_id=current_user.id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"处理请求时发生错误: {str(e)}"
        ) from e
@router.post("/validate", response_model=ResponseModel)
async def validate_sql(
    request: Dict[str, Any],
    current_user: User = Depends(get_current_user)
) -> ResponseModel:
    """
    Inspect a SQL statement without executing it.

    No real syntax parsing is performed: any non-empty statement is reported
    as ``valid`` and the check only produces warnings for keywords that
    indicate destructive operations.

    Args:
        request: Request body containing the ``sql`` field.
        current_user: The authenticated user; its id is logged on failure.

    Returns:
        ResponseModel with
        - code 400 and ``{"valid": False, "error"}`` when ``sql`` is empty;
        - code 200 and ``{"valid": True, "warnings", "sql_type"}`` otherwise;
        - code 500 and ``{"valid": False, "error"}`` if the check itself fails.
    """
    # Local import: ``re`` is not among the module-level imports.
    import re

    try:
        sql = request.get('sql', '').strip()
        if not sql:
            return ResponseModel(
                code=400,
                message="SQL 语句不能为空",
                data={"valid": False, "error": "SQL 语句不能为空"}
            )

        sql_upper = sql.upper()

        # Match whole keywords only, so identifiers such as ``updated_at`` or
        # ``dropdown`` do not trigger warnings, and allow any whitespace
        # (including newlines) between DELETE and FROM.
        dangerous_patterns = (
            ('DROP', r'\bDROP\b'),
            ('TRUNCATE', r'\bTRUNCATE\b'),
            ('DELETE FROM', r'\bDELETE\s+FROM\b'),
            ('UPDATE', r'\bUPDATE\b'),
        )
        warnings = [
            f"包含危险操作: {keyword}"
            for keyword, pattern in dangerous_patterns
            if re.search(pattern, sql_upper)
        ]

        return ResponseModel(
            code=200,
            message="SQL 验证完成",
            data={
                "valid": True,
                "warnings": warnings,
                "sql_type": sql_upper.split()[0] if sql_upper else "UNKNOWN"
            }
        )
    except Exception as e:
        logger.error(
            "sql_validation_error",
            user_id=current_user.id,
            error=str(e)
        )
        return ResponseModel(
            code=500,
            message="SQL 验证失败",
            data={"valid": False, "error": str(e)}
        )
@router.get("/tables", response_model=ResponseModel)
async def get_tables(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
) -> ResponseModel:
    """
    List the tables reported by ``SHOW TABLES``.

    Args:
        current_user: The authenticated user; its id is logged on failure.
        db: Async session the statement is executed on.

    Returns:
        ResponseModel whose ``data`` holds ``tables`` (first column of every
        returned row) and ``count`` (number of tables).

    Raises:
        HTTPException: 500 when the statement or response building fails.
    """
    try:
        show_result = await db.execute(text("SHOW TABLES"))

        # Each row carries the table name in its first column.
        table_names = []
        for record in show_result.fetchall():
            table_names.append(record[0])

        payload = {
            "tables": table_names,
            "count": len(table_names)
        }
        return ResponseModel(
            code=200,
            message="获取表列表成功",
            data=payload
        )
    except Exception as exc:
        logger.error(
            "get_tables_error",
            user_id=current_user.id,
            error=str(exc)
        )
        failure_detail = f"获取表列表失败: {str(exc)}"
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_detail
        )
@router.get("/table/{table_name}/schema", response_model=ResponseModel)
async def get_table_schema(
    table_name: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
) -> ResponseModel:
    """
    Return the column definitions of one table via ``DESCRIBE``.

    Args:
        table_name: Name of the table; may contain only alphanumeric
            characters and underscores.
        current_user: The authenticated user; its id is logged on failure.
        db: Async session the statement is executed on.

    Returns:
        ResponseModel whose ``data`` holds ``table_name``, ``columns`` (one
        dict per row with field/type/null/key/default/extra taken from the
        first six columns of the DESCRIBE output) and ``column_count``.

    Raises:
        HTTPException: 400 when ``table_name`` contains characters other than
            alphanumerics and underscores; 500 when the query fails.
    """
    try:
        # DESCRIBE cannot take the table name as a bind parameter, so the name
        # is interpolated into the statement. Validate it first.
        if not table_name.replace('_', '').isalnum():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="无效的表名"
            )

        # Backtick-quote the validated identifier so names that are reserved
        # words or consist only of digits are still accepted by MySQL. The
        # validation above guarantees the name contains no backtick.
        result = await db.execute(text(f"DESCRIBE `{table_name}`"))

        columns = [
            {
                "field": row[0],
                "type": row[1],
                "null": row[2],
                "key": row[3],
                "default": row[4],
                "extra": row[5]
            }
            for row in result.fetchall()
        ]

        return ResponseModel(
            code=200,
            message="获取表结构成功",
            data={
                "table_name": table_name,
                "columns": columns,
                "column_count": len(columns)
            }
        )
    except HTTPException:
        # Keep the 400 raised above instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.error(
            "get_table_schema_error",
            user_id=current_user.id,
            table_name=table_name,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取表结构失败: {str(e)}"
        ) from e
# Endpoint variant using the simplified auth dependency; it is registered only
# when that dependency could be imported.
if get_current_user_simple:
    @router.post("/execute-simple", response_model=ResponseModel)
    async def execute_sql_simple(
        request: Dict[str, Any],
        current_user: User = Depends(get_current_user_simple),
        db: AsyncSession = Depends(get_db)
    ) -> ResponseModel:
        """
        Execute a SQL statement using the simplified authentication dependency.

        Delegates directly to ``execute_sql`` with the same request, user and
        session, so arguments, return value and errors are those of
        ``execute_sql``.

        NOTE(review): the original note says this variant accepts both API key
        and token authentication and targets internal services; confirm
        against ``app.core.simple_auth``.
        """
        outcome = await execute_sql(
            request=request,
            current_user=current_user,
            db=db
        )
        return outcome