All checks were successful
continuous-integration/drone/push Build is passing
- 所有SQL执行器端点改用 require_admin 权限校验
- /sql/execute - 执行SQL
- /sql/validate - 验证SQL
- /sql/tables - 获取表列表
- /sql/table/{name}/schema - 获取表结构
370 lines
11 KiB
Python
370 lines
11 KiB
Python
"""
|
|
SQL 执行器 API - 用于内部服务调用
|
|
支持执行查询和写入操作的 SQL 语句
|
|
"""
|
|
import base64
import json
import re
import uuid
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import Any, Dict, List, Optional, Union
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.engine.result import Result
|
|
import structlog
|
|
|
|
from app.core.deps import get_current_user, get_db, require_admin
|
|
try:
|
|
from app.core.simple_auth import get_current_user_simple
|
|
except ImportError:
|
|
get_current_user_simple = None
|
|
from app.core.config import settings
|
|
from app.models.user import User
|
|
from app.schemas.base import ResponseModel
|
|
|
|
# Module-level structlog logger; the endpoints below emit their audit and
# error events through it.
logger = structlog.get_logger(__name__)

# Router holding the SQL executor endpoints. It is created without a path
# prefix, so the mount prefix is presumably applied where the router is
# included — TODO confirm.
router = APIRouter(tags=["SQL Executor"])
|
|
|
|
|
|
class SQLExecutorRequest:
    """Plain holder for one SQL execution request.

    Attributes:
        sql: the SQL statement text, stored as given.
        params: bind parameters; any falsy input (``None``, ``{}``) is
            replaced by a fresh empty dict.
    """

    def __init__(self, sql: str, params: Optional[Dict[str, Any]] = None):
        self.sql = sql
        # Never share a caller-supplied empty/None value; always own a dict.
        self.params = params if params else {}
|
|
|
|
|
|
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder for the non-JSON value types database rows can contain.

    The standard encoder only understands str/int/float/bool/None/list/dict;
    anything else makes ``json.dumps`` raise ``TypeError``. Rows fetched from
    the database may also hold temporal, decimal, UUID and binary values, so
    they are converted here:

        - ``datetime`` / ``date`` / ``time`` -> ISO-8601 string (``isoformat()``)
        - ``timedelta``                      -> ``str(value)``, e.g. ``"1:02:03"``
        - ``Decimal``                        -> ``str(value)`` (exact; a float
                                                 could lose precision)
        - ``uuid.UUID``                      -> canonical hyphenated string
        - ``bytes`` / ``bytearray`` / ``memoryview`` -> base64 (ASCII) string

    Any other type is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        # datetime is a subclass of date; isoformat() works for all three.
        if isinstance(obj, (datetime, date, time)):
            return obj.isoformat()
        if isinstance(obj, timedelta):
            return str(obj)
        if isinstance(obj, (Decimal, uuid.UUID)):
            return str(obj)
        if isinstance(obj, (bytes, bytearray, memoryview)):
            return base64.b64encode(bytes(obj)).decode('ascii')
        return super().default(obj)
|
|
|
|
|
|
def serialize_row(row: Any) -> Union[Dict[str, Any], Any]:
    """Turn one database result item into a dict where possible.

    - Objects exposing ``_mapping`` (SQLAlchemy ``Row``) become
      ``dict(row._mapping)``.
    - Other objects with an attribute ``__dict__`` (e.g. ORM instances)
      become a dict of their attributes whose names do not start with ``_``.
    - Anything else (plain scalars) is returned unchanged.
    """
    if hasattr(row, '_mapping'):
        return dict(row._mapping)

    # Scalars and other objects without an attribute dict pass through.
    if not hasattr(row, '__dict__'):
        return row

    public_attrs = {}
    for attr_name, attr_value in row.__dict__.items():
        if attr_name.startswith('_'):
            # Skip private/internal state such as ORM bookkeeping.
            continue
        public_attrs[attr_name] = attr_value
    return public_attrs
|
|
|
|
|
|
# Leading keywords of statements this endpoint treats as read-only; such
# statements are never committed. ('DESC' already covers 'DESCRIBE'.)
_READ_ONLY_SQL_PREFIXES = ('SELECT', 'SHOW', 'DESCRIBE', 'DESC')


def _extract_sql(request: Dict[str, Any]) -> str:
    """Return the stripped ``sql`` field of *request*, or raise HTTP 400.

    A missing key, ``None`` and a blank string are all reported as empty.
    Any other non-string value is rejected here instead of failing later
    with an ``AttributeError`` (which would surface as an HTTP 500).
    """
    raw_sql = request.get('sql')
    if raw_sql is None:
        raw_sql = ''
    if not isinstance(raw_sql, str):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="SQL 语句必须是字符串"
        )
    sql = raw_sql.strip()
    if not sql:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="SQL 语句不能为空"
        )
    return sql


def _serialize_rows(rows: List[Any], columns: List[str]) -> List[Any]:
    """Convert fetched rows into JSON-safe dicts keyed by column name."""
    data = []
    for row in rows:
        serialized_row = serialize_row(row)
        if isinstance(serialized_row, dict):
            data.append(serialized_row)
        else:
            # Scalar result: key it by the first column name (or 'value').
            data.append({columns[0] if columns else 'value': serialized_row})
    # Round-trip through JSON so values DateTimeEncoder knows how to convert
    # (e.g. datetimes) become plain JSON types.
    return json.loads(json.dumps(data, cls=DateTimeEncoder))


@router.post("/execute", response_model=ResponseModel)
async def execute_sql(
    request: Dict[str, Any],
    current_user: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db)
) -> ResponseModel:
    """
    Execute one SQL statement.

    Args:
        request: dict with
            - sql: the SQL statement (required, non-empty string)
            - params: bind parameters (optional), passed to ``db.execute``
        current_user: caller resolved by the ``require_admin`` dependency
        db: database session

    Returns:
        ResponseModel whose ``data`` is either
            - ``{"type": "query", "columns", "rows", "row_count"}`` when the
              statement produced a result set (SELECT, SHOW, DESCRIBE, but
              also e.g. ``WITH ... SELECT`` or ``EXPLAIN``), or
            - ``{"type": "execute", "affected_rows", "success"}`` otherwise.

    Raises:
        HTTPException 400: ``sql`` is missing, empty or not a string.
        HTTPException 500: execution or serialization failed (the transaction
            is rolled back), or any other unexpected error.

    Security:
        - requires admin privileges
        - every request is logged
        - consider restricting executable statement types in production

    Statements not starting with SELECT/SHOW/DESC(RIBE) are committed.
    """
    try:
        sql = _extract_sql(request)
        params = request.get('params', {})
        sql_type = sql.split()[0].upper()

        # Audit log for every execution request.
        logger.info(
            "sql_execution_request",
            user_id=current_user.id,
            username=current_user.username,
            sql_type=sql_type,
            sql_length=len(sql),
            has_params=bool(params)
        )

        is_read_only = sql.upper().startswith(_READ_ONLY_SQL_PREFIXES)

        try:
            result = await db.execute(text(sql), params)

            # Decide by what was actually returned rather than by the leading
            # keyword alone: CTEs, EXPLAIN, comment-prefixed queries and
            # DML ... RETURNING also yield rows, while SELECT ... INTO does
            # not. Fall back to the keyword check if ``returns_rows`` is absent.
            returns_rows = getattr(result, 'returns_rows', is_read_only)

            if returns_rows:
                rows = result.fetchall()
                columns = list(result.keys())
                # Serialize before committing so a serialization failure
                # still rolls the transaction back.
                data = _serialize_rows(rows, columns)
                if not is_read_only:
                    await db.commit()

                response_data = {
                    "type": "query",
                    "columns": columns,
                    "rows": data,
                    "row_count": len(data)
                }

                logger.info(
                    "sql_query_success",
                    user_id=current_user.id,
                    row_count=len(data),
                    column_count=len(columns)
                )
            else:
                # Capture the row count before the transaction is ended.
                affected_rows = result.rowcount
                if not is_read_only:
                    await db.commit()

                response_data = {
                    "type": "execute",
                    "affected_rows": affected_rows,
                    "success": True
                }

                logger.info(
                    "sql_execute_success",
                    user_id=current_user.id,
                    affected_rows=affected_rows
                )

            return ResponseModel(
                code=200,
                message="SQL 执行成功",
                data=response_data
            )

        except Exception as e:
            # Undo any partial work before reporting the failure.
            await db.rollback()
            logger.error(
                "sql_execution_error",
                user_id=current_user.id,
                sql_type=sql_type,
                error=str(e),
                exc_info=True
            )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"SQL 执行失败: {str(e)}"
            )

    except HTTPException:
        # Propagate deliberate HTTP errors (400 validation, 500 above) as-is.
        raise
    except Exception as e:
        logger.error(
            "sql_executor_error",
            user_id=current_user.id,
            error=str(e),
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"处理请求时发生错误: {str(e)}"
        )
|
|
|
|
|
|
# Keywords reported as dangerous by /validate, each with a whole-word pattern
# (matched against the upper-cased SQL). Word boundaries stop identifiers
# such as ``updated_at`` or ``last_update`` from being flagged, and
# ``DELETE``/``FROM`` may be separated by any whitespace, including newlines.
_DANGEROUS_SQL_KEYWORDS = (
    ('DROP', re.compile(r'\bDROP\b')),
    ('TRUNCATE', re.compile(r'\bTRUNCATE\b')),
    ('DELETE FROM', re.compile(r'\bDELETE\s+FROM\b')),
    ('UPDATE', re.compile(r'\bUPDATE\b')),
)


@router.post("/validate", response_model=ResponseModel)
async def validate_sql(
    request: Dict[str, Any],
    current_user: User = Depends(require_admin)
) -> ResponseModel:
    """
    Pre-check an SQL statement without executing it.

    The statement is NOT parsed, so no syntax checking happens: any non-empty
    string is reported as ``valid: True``. The check only lists dangerous
    keywords as warnings and reports the leading keyword as ``sql_type``.

    Permission: requires admin privileges.

    Args:
        request: dict with an ``sql`` field.

    Returns:
        ResponseModel with
            - code 200 and ``{"valid": True, "warnings": [...], "sql_type": ...}``,
            - code 400 and ``{"valid": False, "error": ...}`` when ``sql`` is
              missing, empty or not a string, or
            - code 500 and ``{"valid": False, "error": ...}`` on an unexpected error.
    """
    try:
        raw_sql = request.get('sql')
        if raw_sql is None:
            raw_sql = ''
        if not isinstance(raw_sql, str):
            return ResponseModel(
                code=400,
                message="SQL 语句必须是字符串",
                data={"valid": False, "error": "SQL 语句必须是字符串"}
            )

        sql = raw_sql.strip()
        if not sql:
            return ResponseModel(
                code=400,
                message="SQL 语句不能为空",
                data={"valid": False, "error": "SQL 语句不能为空"}
            )

        sql_upper = sql.upper()

        warnings = [
            f"包含危险操作: {keyword}"
            for keyword, pattern in _DANGEROUS_SQL_KEYWORDS
            if pattern.search(sql_upper)
        ]

        return ResponseModel(
            code=200,
            message="SQL 验证完成",
            data={
                "valid": True,
                "warnings": warnings,
                "sql_type": sql_upper.split()[0] if sql_upper else "UNKNOWN"
            }
        )

    except Exception as e:
        logger.error(
            "sql_validation_error",
            user_id=current_user.id,
            error=str(e)
        )
        return ResponseModel(
            code=500,
            message="SQL 验证失败",
            data={"valid": False, "error": str(e)}
        )
|
|
|
|
|
|
@router.get("/tables", response_model=ResponseModel)
async def get_tables(
    current_user: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db)
) -> ResponseModel:
    """
    List the tables of the current database.

    Permission: requires admin privileges.

    Runs ``SHOW TABLES`` and collects the first column of every row.

    Returns:
        ResponseModel with ``data = {"tables": [...], "count": <len>}``.

    Raises:
        HTTPException 500: the query (or building the response) failed.
    """
    try:
        show_result = await db.execute(text("SHOW TABLES"))
        # scalars() yields the first column of each row, i.e. the table name.
        table_names = show_result.scalars().all()

        payload = {
            "tables": table_names,
            "count": len(table_names)
        }
        return ResponseModel(
            code=200,
            message="获取表列表成功",
            data=payload
        )

    except Exception as exc:
        logger.error(
            "get_tables_error",
            user_id=current_user.id,
            error=str(exc)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取表列表失败: {str(exc)}"
        )
|
|
|
|
|
|
@router.get("/table/{table_name}/schema", response_model=ResponseModel)
async def get_table_schema(
    table_name: str,
    current_user: User = Depends(require_admin),
    db: AsyncSession = Depends(get_db)
) -> ResponseModel:
    """
    Return the column layout of one table (via MySQL ``DESCRIBE``).

    Permission: requires admin privileges.

    Args:
        table_name: table name; only characters accepted by ``str.isalnum``
            plus ``_`` are allowed.

    Returns:
        ResponseModel with ``data = {"table_name", "columns", "column_count"}``.
        Each column is ``{"field", "type", "null", "key", "default", "extra"}``,
        taken positionally from the six DESCRIBE output columns.

    Raises:
        HTTPException 400: the table name contains disallowed characters.
        HTTPException 500: the DESCRIBE query failed.
    """
    # DESCRIBE cannot take the table name as a bound parameter, so it is
    # interpolated. Validate it first. This is done outside the try below
    # so the 400 is not swallowed and re-raised as a 500.
    if not table_name.replace('_', '').isalnum():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="无效的表名"
        )

    try:
        # The validated name cannot contain a backtick, so quoting it is safe.
        # Quoting also makes names that are reserved words (e.g. `order`) work.
        result = await db.execute(text(f"DESCRIBE `{table_name}`"))

        columns = [
            {
                "field": row[0],
                "type": row[1],
                "null": row[2],
                "key": row[3],
                "default": row[4],
                "extra": row[5]
            }
            for row in result.fetchall()
        ]

        return ResponseModel(
            code=200,
            message="获取表结构成功",
            data={
                "table_name": table_name,
                "columns": columns,
                "column_count": len(columns)
            }
        )

    except Exception as e:
        logger.error(
            "get_table_schema_error",
            user_id=current_user.id,
            table_name=table_name,
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取表结构失败: {str(e)}"
        )
|
|
|
|
|
|
# Simplified-authentication variant of /execute. It is registered only when
# app.core.simple_auth could be imported (get_current_user_simple is truthy).
if get_current_user_simple:

    @router.post("/execute-simple", response_model=ResponseModel)
    async def execute_sql_simple(
        request: Dict[str, Any],
        current_user: User = Depends(get_current_user_simple),
        db: AsyncSession = Depends(get_db)
    ) -> ResponseModel:
        """
        Execute an SQL statement (simplified-authentication variant).

        Supports both API key and token authentication and is intended for
        internal services. Delegates entirely to ``execute_sql``, passing the
        caller resolved by ``get_current_user_simple``.

        NOTE(review): this route authenticates with ``get_current_user_simple``
        instead of ``require_admin``, so it skips the admin check that the
        other SQL executor routes enforce. Confirm that
        ``get_current_user_simple`` only resolves privileged callers.
        Otherwise any caller it accepts can run arbitrary SQL.
        """
        delegated_response = await execute_sql(
            request=request,
            current_user=current_user,
            db=db,
        )
        return delegated_response
|