Files
000-platform/backend/app/middleware/request_logger.py
111 6c6c48cf71
All checks were successful
continuous-integration/drone/push Build is passing
feat: 新增告警、成本、配额、微信模块及缓存服务
- 新增告警模块 (alerts): 告警规则配置与触发
- 新增成本管理模块 (cost): 成本统计与分析
- 新增配额模块 (quota): 配额管理与限制
- 新增微信模块 (wechat): 微信相关功能接口
- 新增缓存服务 (cache): Redis 缓存封装
- 新增请求日志中间件 (request_logger)
- 新增异常处理和链路追踪中间件
- 更新 dashboard 前端展示
- 更新 SDK stats_client 功能
2026-01-24 16:53:47 +08:00

191 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
请求日志中间件
自动将所有请求记录到数据库 platform_logs 表
"""
import time
import logging
from typing import Optional, Set
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from .trace import get_trace_id
from ..database import SessionLocal
from ..models.logs import PlatformLog
logger = logging.getLogger(__name__)
class RequestLoggerMiddleware(BaseHTTPMiddleware):
"""请求日志中间件
自动记录所有请求到数据库,便于后续查询和分析
使用示例:
app.add_middleware(RequestLoggerMiddleware, app_code="000-platform")
"""
# 默认排除的路径(不记录这些请求)
DEFAULT_EXCLUDE_PATHS: Set[str] = {
"/",
"/docs",
"/redoc",
"/openapi.json",
"/api/health",
"/api/health/",
"/favicon.ico",
}
def __init__(
self,
app,
app_code: str = "platform",
exclude_paths: Optional[Set[str]] = None,
log_request_body: bool = False,
log_response_body: bool = False,
max_body_length: int = 1000
):
"""初始化中间件
Args:
app: FastAPI应用
app_code: 应用代码,记录到日志中
exclude_paths: 排除的路径集合,这些路径不记录日志
log_request_body: 是否记录请求体
log_response_body: 是否记录响应体
max_body_length: 记录体的最大长度
"""
super().__init__(app)
self.app_code = app_code
self.exclude_paths = exclude_paths or self.DEFAULT_EXCLUDE_PATHS
self.log_request_body = log_request_body
self.log_response_body = log_response_body
self.max_body_length = max_body_length
async def dispatch(self, request: Request, call_next) -> Response:
path = request.url.path
# 检查是否排除
if self._should_exclude(path):
return await call_next(request)
trace_id = get_trace_id()
method = request.method
start_time = time.time()
# 获取客户端IP
client_ip = self._get_client_ip(request)
# 获取租户ID从查询参数
tenant_id = request.query_params.get("tid") or request.query_params.get("tenant_id")
# 执行请求
response = None
error_message = None
status_code = 500
try:
response = await call_next(request)
status_code = response.status_code
except Exception as e:
error_message = str(e)
raise
finally:
duration_ms = int((time.time() - start_time) * 1000)
# 异步写入数据库(不阻塞响应)
try:
self._save_log(
trace_id=trace_id,
method=method,
path=path,
status_code=status_code,
duration_ms=duration_ms,
ip_address=client_ip,
tenant_id=tenant_id,
error_message=error_message
)
except Exception as e:
logger.error(f"Failed to save request log: {e}")
return response
def _should_exclude(self, path: str) -> bool:
"""检查路径是否应排除"""
# 精确匹配
if path in self.exclude_paths:
return True
# 前缀匹配(静态文件等)
exclude_prefixes = ["/static/", "/assets/", "/_next/"]
for prefix in exclude_prefixes:
if path.startswith(prefix):
return True
return False
def _get_client_ip(self, request: Request) -> str:
"""获取客户端真实IP"""
# 优先从代理头获取
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# 直连IP
if request.client:
return request.client.host
return "unknown"
def _save_log(
self,
trace_id: str,
method: str,
path: str,
status_code: int,
duration_ms: int,
ip_address: str,
tenant_id: Optional[str] = None,
error_message: Optional[str] = None
):
"""保存日志到数据库"""
from datetime import datetime
# 使用独立的数据库会话
db = SessionLocal()
try:
# 转换 tenant_id 为整数(如果是数字字符串)
tenant_id_int = None
if tenant_id:
try:
tenant_id_int = int(tenant_id)
except (ValueError, TypeError):
tenant_id_int = None
log_entry = PlatformLog(
log_type="request",
level="error" if status_code >= 500 else ("warn" if status_code >= 400 else "info"),
app_code=self.app_code,
tenant_id=tenant_id_int,
trace_id=trace_id,
message=f"{method} {path}" + (f" - {error_message}" if error_message else ""),
path=path,
method=method,
status_code=status_code,
duration_ms=duration_ms,
log_time=datetime.now(), # 必须设置 log_time
context={"ip": ip_address} # ip_address 放到 context 中
)
db.add(log_entry)
db.commit()
except Exception as e:
logger.error(f"Database error saving log: {e}")
db.rollback()
finally:
db.close()