""" SQL 执行器 API - 用于内部服务调用 支持执行查询和写入操作的 SQL 语句 """ import json from typing import Any, Dict, List, Optional, Union from datetime import datetime, date from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.engine.result import Result import structlog from app.core.deps import get_current_user, get_db try: from app.core.simple_auth import get_current_user_simple except ImportError: get_current_user_simple = None from app.core.config import settings from app.models.user import User from app.schemas.base import ResponseModel logger = structlog.get_logger(__name__) router = APIRouter(tags=["SQL Executor"]) class SQLExecutorRequest: """SQL执行请求模型""" def __init__(self, sql: str, params: Optional[Dict[str, Any]] = None): self.sql = sql self.params = params or {} class DateTimeEncoder(json.JSONEncoder): """处理日期时间对象的 JSON 编码器""" def default(self, obj): if isinstance(obj, (datetime, date)): return obj.isoformat() return super().default(obj) def serialize_row(row: Any) -> Union[Dict[str, Any], Any]: """序列化数据库行结果""" if hasattr(row, '_mapping'): # 处理 SQLAlchemy Row 对象 return dict(row._mapping) elif hasattr(row, '__dict__'): # 处理 ORM 对象 return {k: v for k, v in row.__dict__.items() if not k.startswith('_')} else: # 处理单值结果 return row @router.post("/execute", response_model=ResponseModel) async def execute_sql( request: Dict[str, Any], current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db) ) -> ResponseModel: """ 执行 SQL 语句 Args: request: 包含 sql 和可选的 params 字段 - sql: SQL 语句 - params: 参数字典(可选) Returns: 执行结果,包括: - 查询操作:返回数据行 - 写入操作:返回影响的行数 安全说明: - 需要用户身份验证 - 所有操作都会记录日志 - 建议在生产环境中限制可执行的 SQL 类型 """ try: # 提取参数 sql = request.get('sql', '').strip() params = request.get('params', {}) if not sql: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="SQL 语句不能为空" ) # 记录 SQL 执行日志 logger.info( "sql_execution_request", user_id=current_user.id, username=current_user.username, sql_type=sql.split()[0].upper() if sql else "UNKNOWN", sql_length=len(sql), has_params=bool(params) ) # 判断 SQL 类型 sql_upper = sql.upper().strip() is_select = sql_upper.startswith('SELECT') is_show = sql_upper.startswith('SHOW') is_describe = sql_upper.startswith(('DESCRIBE', 'DESC')) is_query = is_select or is_show or is_describe # 执行 SQL try: result = await db.execute(text(sql), params) if is_query: # 查询操作 rows = result.fetchall() columns = list(result.keys()) if result.keys() else [] # 序列化结果 data = [] for row in rows: serialized_row = serialize_row(row) if isinstance(serialized_row, dict): data.append(serialized_row) else: # 单列结果 data.append({columns[0] if columns else 'value': serialized_row}) # 使用自定义编码器处理日期时间 response_data = { "type": "query", "columns": columns, "rows": json.loads(json.dumps(data, cls=DateTimeEncoder)), "row_count": len(data) } logger.info( "sql_query_success", user_id=current_user.id, row_count=len(data), column_count=len(columns) ) else: # 写入操作 await db.commit() affected_rows = result.rowcount response_data = { "type": "execute", "affected_rows": affected_rows, "success": True } logger.info( "sql_execute_success", user_id=current_user.id, affected_rows=affected_rows ) return ResponseModel( code=200, message="SQL 执行成功", data=response_data ) except Exception as e: # 回滚事务 await db.rollback() logger.error( "sql_execution_error", user_id=current_user.id, sql_type=sql.split()[0].upper() if sql else "UNKNOWN", error=str(e), exc_info=True ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"SQL 执行失败: {str(e)}" ) except HTTPException: raise except Exception as e: logger.error( "sql_executor_error", user_id=current_user.id, error=str(e), exc_info=True ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"处理请求时发生错误: {str(e)}" ) @router.post("/validate", response_model=ResponseModel) async def validate_sql( request: Dict[str, Any], current_user: User = Depends(get_current_user) ) -> ResponseModel: """ 验证 SQL 语句的语法(不执行) Args: request: 包含 sql 字段的请求 Returns: 验证结果 """ try: sql = request.get('sql', '').strip() if not sql: return ResponseModel( code=400, message="SQL 语句不能为空", data={"valid": False, "error": "SQL 语句不能为空"} ) # 基本的 SQL 验证 sql_upper = sql.upper().strip() # 检查危险操作(可根据需要调整) dangerous_keywords = ['DROP', 'TRUNCATE', 'DELETE FROM', 'UPDATE'] warnings = [] for keyword in dangerous_keywords: if keyword in sql_upper: warnings.append(f"包含危险操作: {keyword}") return ResponseModel( code=200, message="SQL 验证完成", data={ "valid": True, "warnings": warnings, "sql_type": sql_upper.split()[0] if sql_upper else "UNKNOWN" } ) except Exception as e: logger.error( "sql_validation_error", user_id=current_user.id, error=str(e) ) return ResponseModel( code=500, message="SQL 验证失败", data={"valid": False, "error": str(e)} ) @router.get("/tables", response_model=ResponseModel) async def get_tables( current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db) ) -> ResponseModel: """ 获取数据库中的所有表 Returns: 数据库表列表 """ try: result = await db.execute(text("SHOW TABLES")) tables = [row[0] for row in result.fetchall()] return ResponseModel( code=200, message="获取表列表成功", data={ "tables": tables, "count": len(tables) } ) except Exception as e: logger.error( "get_tables_error", user_id=current_user.id, error=str(e) ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取表列表失败: {str(e)}" ) @router.get("/table/{table_name}/schema", response_model=ResponseModel) async def get_table_schema( table_name: str, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db) ) -> ResponseModel: """ 获取指定表的结构信息 Args: table_name: 表名 Returns: 表结构信息 """ try: # MySQL 的 DESCRIBE 不支持参数化,需要直接拼接 # 但为了安全,先验证表名 if not table_name.replace('_', '').isalnum(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="无效的表名" ) result = await db.execute(text(f"DESCRIBE {table_name}")) columns = [] for row in result.fetchall(): columns.append({ "field": row[0], "type": row[1], "null": row[2], "key": row[3], "default": row[4], "extra": row[5] }) return ResponseModel( code=200, message="获取表结构成功", data={ "table_name": table_name, "columns": columns, "column_count": len(columns) } ) except Exception as e: logger.error( "get_table_schema_error", user_id=current_user.id, table_name=table_name, error=str(e) ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"获取表结构失败: {str(e)}" ) # 简化认证版本的端点(如果启用) if get_current_user_simple: @router.post("/execute-simple", response_model=ResponseModel) async def execute_sql_simple( request: Dict[str, Any], current_user: User = Depends(get_current_user_simple), db: AsyncSession = Depends(get_db) ) -> ResponseModel: """ 执行 SQL 语句(简化认证版本) 支持 API Key 和 Token 两种认证方式,专为内部服务设计。 """ return await execute_sql(request, current_user, db)