All checks were successful
continuous-integration/drone/push Build is passing
- 新增告警模块 (alerts): 告警规则配置与触发 - 新增成本管理模块 (cost): 成本统计与分析 - 新增配额模块 (quota): 配额管理与限制 - 新增微信模块 (wechat): 微信相关功能接口 - 新增缓存服务 (cache): Redis 缓存封装 - 新增请求日志中间件 (request_logger) - 新增异常处理和链路追踪中间件 - 更新 dashboard 前端展示 - 更新 SDK stats_client 功能
191 lines
6.1 KiB
Python
191 lines
6.1 KiB
Python
"""
|
||
请求日志中间件
|
||
|
||
自动将所有请求记录到数据库 platform_logs 表
|
||
"""
|
||
import time
|
||
import logging
|
||
from typing import Optional, Set
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from starlette.requests import Request
|
||
from starlette.responses import Response
|
||
|
||
from .trace import get_trace_id
|
||
from ..database import SessionLocal
|
||
from ..models.logs import PlatformLog
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class RequestLoggerMiddleware(BaseHTTPMiddleware):
|
||
"""请求日志中间件
|
||
|
||
自动记录所有请求到数据库,便于后续查询和分析
|
||
|
||
使用示例:
|
||
app.add_middleware(RequestLoggerMiddleware, app_code="000-platform")
|
||
"""
|
||
|
||
# 默认排除的路径(不记录这些请求)
|
||
DEFAULT_EXCLUDE_PATHS: Set[str] = {
|
||
"/",
|
||
"/docs",
|
||
"/redoc",
|
||
"/openapi.json",
|
||
"/api/health",
|
||
"/api/health/",
|
||
"/favicon.ico",
|
||
}
|
||
|
||
def __init__(
|
||
self,
|
||
app,
|
||
app_code: str = "platform",
|
||
exclude_paths: Optional[Set[str]] = None,
|
||
log_request_body: bool = False,
|
||
log_response_body: bool = False,
|
||
max_body_length: int = 1000
|
||
):
|
||
"""初始化中间件
|
||
|
||
Args:
|
||
app: FastAPI应用
|
||
app_code: 应用代码,记录到日志中
|
||
exclude_paths: 排除的路径集合,这些路径不记录日志
|
||
log_request_body: 是否记录请求体
|
||
log_response_body: 是否记录响应体
|
||
max_body_length: 记录体的最大长度
|
||
"""
|
||
super().__init__(app)
|
||
self.app_code = app_code
|
||
self.exclude_paths = exclude_paths or self.DEFAULT_EXCLUDE_PATHS
|
||
self.log_request_body = log_request_body
|
||
self.log_response_body = log_response_body
|
||
self.max_body_length = max_body_length
|
||
|
||
async def dispatch(self, request: Request, call_next) -> Response:
|
||
path = request.url.path
|
||
|
||
# 检查是否排除
|
||
if self._should_exclude(path):
|
||
return await call_next(request)
|
||
|
||
trace_id = get_trace_id()
|
||
method = request.method
|
||
start_time = time.time()
|
||
|
||
# 获取客户端IP
|
||
client_ip = self._get_client_ip(request)
|
||
|
||
# 获取租户ID(从查询参数)
|
||
tenant_id = request.query_params.get("tid") or request.query_params.get("tenant_id")
|
||
|
||
# 执行请求
|
||
response = None
|
||
error_message = None
|
||
status_code = 500
|
||
|
||
try:
|
||
response = await call_next(request)
|
||
status_code = response.status_code
|
||
except Exception as e:
|
||
error_message = str(e)
|
||
raise
|
||
finally:
|
||
duration_ms = int((time.time() - start_time) * 1000)
|
||
|
||
# 异步写入数据库(不阻塞响应)
|
||
try:
|
||
self._save_log(
|
||
trace_id=trace_id,
|
||
method=method,
|
||
path=path,
|
||
status_code=status_code,
|
||
duration_ms=duration_ms,
|
||
ip_address=client_ip,
|
||
tenant_id=tenant_id,
|
||
error_message=error_message
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"Failed to save request log: {e}")
|
||
|
||
return response
|
||
|
||
def _should_exclude(self, path: str) -> bool:
|
||
"""检查路径是否应排除"""
|
||
# 精确匹配
|
||
if path in self.exclude_paths:
|
||
return True
|
||
|
||
# 前缀匹配(静态文件等)
|
||
exclude_prefixes = ["/static/", "/assets/", "/_next/"]
|
||
for prefix in exclude_prefixes:
|
||
if path.startswith(prefix):
|
||
return True
|
||
|
||
return False
|
||
|
||
def _get_client_ip(self, request: Request) -> str:
|
||
"""获取客户端真实IP"""
|
||
# 优先从代理头获取
|
||
forwarded = request.headers.get("X-Forwarded-For")
|
||
if forwarded:
|
||
return forwarded.split(",")[0].strip()
|
||
|
||
real_ip = request.headers.get("X-Real-IP")
|
||
if real_ip:
|
||
return real_ip
|
||
|
||
# 直连IP
|
||
if request.client:
|
||
return request.client.host
|
||
|
||
return "unknown"
|
||
|
||
def _save_log(
|
||
self,
|
||
trace_id: str,
|
||
method: str,
|
||
path: str,
|
||
status_code: int,
|
||
duration_ms: int,
|
||
ip_address: str,
|
||
tenant_id: Optional[str] = None,
|
||
error_message: Optional[str] = None
|
||
):
|
||
"""保存日志到数据库"""
|
||
from datetime import datetime
|
||
|
||
# 使用独立的数据库会话
|
||
db = SessionLocal()
|
||
try:
|
||
# 转换 tenant_id 为整数(如果是数字字符串)
|
||
tenant_id_int = None
|
||
if tenant_id:
|
||
try:
|
||
tenant_id_int = int(tenant_id)
|
||
except (ValueError, TypeError):
|
||
tenant_id_int = None
|
||
|
||
log_entry = PlatformLog(
|
||
log_type="request",
|
||
level="error" if status_code >= 500 else ("warn" if status_code >= 400 else "info"),
|
||
app_code=self.app_code,
|
||
tenant_id=tenant_id_int,
|
||
trace_id=trace_id,
|
||
message=f"{method} {path}" + (f" - {error_message}" if error_message else ""),
|
||
path=path,
|
||
method=method,
|
||
status_code=status_code,
|
||
duration_ms=duration_ms,
|
||
log_time=datetime.now(), # 必须设置 log_time
|
||
context={"ip": ip_address} # ip_address 放到 context 中
|
||
)
|
||
db.add(log_entry)
|
||
db.commit()
|
||
except Exception as e:
|
||
logger.error(f"Database error saving log: {e}")
|
||
db.rollback()
|
||
finally:
|
||
db.close()
|