"""Request logging middleware.

Automatically records every request to the ``platform_logs`` database table.
"""

import logging
import time
from datetime import datetime
from typing import Optional, Set, Tuple

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

from .trace import get_trace_id
from ..database import SessionLocal
from ..models.logs import PlatformLog

logger = logging.getLogger(__name__)


class RequestLoggerMiddleware(BaseHTTPMiddleware):
    """Middleware that stores one ``PlatformLog`` row per handled request.

    Requests whose path is in ``exclude_paths`` or starts with one of
    ``EXCLUDE_PREFIXES`` are passed through without being logged.

    Example:
        app.add_middleware(RequestLoggerMiddleware, app_code="000-platform")
    """

    # Paths that are not logged unless the caller supplies its own set.
    DEFAULT_EXCLUDE_PATHS: Set[str] = {
        "/",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/api/health",
        "/api/health/",
        "/favicon.ico",
    }

    # Path prefixes that are never logged (static files and similar).
    EXCLUDE_PREFIXES: Tuple[str, ...] = ("/static/", "/assets/", "/_next/")

    def __init__(
        self,
        app,
        app_code: str = "platform",
        exclude_paths: Optional[Set[str]] = None,
        log_request_body: bool = False,
        log_response_body: bool = False,
        max_body_length: int = 1000,
    ):
        """Initialise the middleware.

        Args:
            app: The ASGI application being wrapped.
            app_code: Application code written to every log row.
            exclude_paths: Exact paths that must not be logged. ``None``
                selects ``DEFAULT_EXCLUDE_PATHS``; an explicit empty set
                disables exact-path exclusion.
            log_request_body: Whether to record request bodies. Stored on the
                instance only; nothing in this class reads it yet.
            log_response_body: Whether to record response bodies. Stored on
                the instance only; nothing in this class reads it yet.
            max_body_length: Maximum recorded body length. Stored on the
                instance only; nothing in this class reads it yet.
        """
        super().__init__(app)
        self.app_code = app_code
        # Compare against None so an explicit empty set is honoured, and copy
        # so later mutation cannot leak into the shared class-level default.
        if exclude_paths is None:
            self.exclude_paths = set(self.DEFAULT_EXCLUDE_PATHS)
        else:
            self.exclude_paths = set(exclude_paths)
        self.log_request_body = log_request_body
        self.log_response_body = log_response_body
        self.max_body_length = max_body_length

    async def dispatch(self, request: Request, call_next) -> Response:
        """Run the downstream handler and persist a log row for the request.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream and
                returns the response.

        Returns:
            The response produced by ``call_next``.

        Raises:
            Exception: Anything raised by ``call_next`` is re-raised after
                the log row has been attempted. Failures while saving the
                log are logged and swallowed.
        """
        path = request.url.path
        if self._should_exclude(path):
            return await call_next(request)

        trace_id = get_trace_id()
        method = request.method
        # Monotonic clock: the measured duration cannot be skewed by
        # wall-clock adjustments.
        started = time.perf_counter()
        client_ip = self._get_client_ip(request)

        # The tenant id is read from the query string ("tid" wins).
        query = request.query_params
        tenant_id = query.get("tid") or query.get("tenant_id")

        error_message = None
        # Stays 500 when call_next raises before producing a response.
        status_code = 500
        try:
            response = await call_next(request)
            status_code = response.status_code
            return response
        except Exception as exc:
            error_message = str(exc)
            raise
        finally:
            duration_ms = int((time.perf_counter() - started) * 1000)
            # NOTE(review): _save_log is a plain synchronous call, so the
            # database write happens inline here before the response is
            # returned. Consider offloading it if this becomes a bottleneck.
            try:
                self._save_log(
                    trace_id=trace_id,
                    method=method,
                    path=path,
                    status_code=status_code,
                    duration_ms=duration_ms,
                    ip_address=client_ip,
                    tenant_id=tenant_id,
                    error_message=error_message,
                )
            except Exception as save_exc:
                # Logging is best effort: never let it break the request.
                logger.exception("Failed to save request log: %s", save_exc)

    def _should_exclude(self, path: str) -> bool:
        """Return True when ``path`` must not be logged.

        A path is excluded when it exactly matches an entry of
        ``self.exclude_paths`` or starts with one of ``EXCLUDE_PREFIXES``.
        """
        if path in self.exclude_paths:
            return True
        return path.startswith(self.EXCLUDE_PREFIXES)

    def _get_client_ip(self, request: Request) -> str:
        """Return the best available client IP for ``request``.

        Sources are tried in order: the first entry of ``X-Forwarded-For``,
        then ``X-Real-IP``, then the direct peer address. Blank header values
        are skipped. Returns ``"unknown"`` when none is available.

        NOTE(review): both headers are taken at face value; confirm that a
        trusted proxy sets them if the stored IP is used for anything
        security-relevant.
        """
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop

        real_ip = request.headers.get("X-Real-IP")
        if real_ip and real_ip.strip():
            return real_ip.strip()

        if request.client:
            return request.client.host
        return "unknown"

    def _save_log(
        self,
        trace_id: str,
        method: str,
        path: str,
        status_code: int,
        duration_ms: int,
        ip_address: str,
        tenant_id: Optional[str] = None,
        error_message: Optional[str] = None,
    ):
        """Insert one ``PlatformLog`` row describing a request.

        Uses its own ``SessionLocal()`` session. Errors raised while adding
        or committing the row are logged and rolled back, not propagated.

        Args:
            trace_id: Trace identifier of the request.
            method: HTTP method.
            path: Request path.
            status_code: Response status code; drives the log level
                (>= 500 "error", >= 400 "warn", otherwise "info").
            duration_ms: Handling time in milliseconds.
            ip_address: Client IP, stored under ``context["ip"]``.
            tenant_id: Raw tenant id; stored only if it parses as an integer.
            error_message: Exception text appended to the message, if any.
        """
        db = SessionLocal()
        try:
            # Only numeric tenant ids are stored; anything else becomes None.
            tenant_id_int = None
            if tenant_id:
                try:
                    tenant_id_int = int(tenant_id)
                except (ValueError, TypeError):
                    tenant_id_int = None

            if status_code >= 500:
                level = "error"
            elif status_code >= 400:
                level = "warn"
            else:
                level = "info"

            message = f"{method} {path}"
            if error_message:
                message += f" - {error_message}"

            log_entry = PlatformLog(
                log_type="request",
                level=level,
                app_code=self.app_code,
                tenant_id=tenant_id_int,
                trace_id=trace_id,
                message=message,
                path=path,
                method=method,
                status_code=status_code,
                duration_ms=duration_ms,
                log_time=datetime.now(),  # log_time must always be set
                context={"ip": ip_address},  # the IP is kept inside context
            )
            db.add(log_entry)
            db.commit()
        except Exception as exc:
            logger.exception("Database error saving log: %s", exc)
            db.rollback()
        finally:
            db.close()