""" TraceID 追踪中间件 为每个请求生成唯一的 TraceID,用于日志追踪和问题排查。 功能: - 自动生成 TraceID(或从请求头获取) - 注入到响应头 X-Trace-ID - 提供上下文变量供日志使用 - 支持请求耗时统计 """ import time import uuid import logging from contextvars import ContextVar from typing import Optional from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response logger = logging.getLogger(__name__) # 上下文变量存储当前请求的 TraceID _trace_id_var: ContextVar[Optional[str]] = ContextVar("trace_id", default=None) # 请求头名称 TRACE_ID_HEADER = "X-Trace-ID" REQUEST_ID_HEADER = "X-Request-ID" def get_trace_id() -> str: """获取当前请求的 TraceID""" trace_id = _trace_id_var.get() return trace_id if trace_id else "no-trace" def set_trace_id(trace_id: str) -> None: """设置当前请求的 TraceID""" _trace_id_var.set(trace_id) def generate_trace_id() -> str: """生成新的 TraceID,格式: 时间戳-随机8位""" timestamp = int(time.time()) random_part = uuid.uuid4().hex[:8] return f"{timestamp}-{random_part}" class TraceMiddleware(BaseHTTPMiddleware): """TraceID 追踪中间件""" def __init__(self, app, log_requests: bool = True): super().__init__(app) self.log_requests = log_requests async def dispatch(self, request: Request, call_next) -> Response: # 从请求头获取 TraceID,或生成新的 trace_id = ( request.headers.get(TRACE_ID_HEADER) or request.headers.get(REQUEST_ID_HEADER) or generate_trace_id() ) set_trace_id(trace_id) start_time = time.time() method = request.method path = request.url.path if self.log_requests: logger.info(f"[{trace_id}] --> {method} {path}") try: response = await call_next(request) duration_ms = int((time.time() - start_time) * 1000) response.headers[TRACE_ID_HEADER] = trace_id response.headers["X-Response-Time"] = f"{duration_ms}ms" if self.log_requests: logger.info(f"[{trace_id}] <-- {response.status_code} ({duration_ms}ms)") return response except Exception as e: duration_ms = int((time.time() - start_time) * 1000) logger.error(f"[{trace_id}] !!! 请求异常: {e} ({duration_ms}ms)") raise class TraceLogFilter(logging.Filter): """日志过滤器:自动添加 TraceID""" def filter(self, record): record.trace_id = get_trace_id() return True def setup_logging(level: int = logging.INFO, include_trace: bool = True): """配置日志格式""" if include_trace: format_str = "%(asctime)s [%(trace_id)s] %(levelname)s %(name)s: %(message)s" else: format_str = "%(asctime)s %(levelname)s %(name)s: %(message)s" handler = logging.StreamHandler() handler.setFormatter(logging.Formatter(format_str, datefmt="%Y-%m-%d %H:%M:%S")) if include_trace: handler.addFilter(TraceLogFilter()) root_logger = logging.getLogger() root_logger.setLevel(level) root_logger.handlers = [handler]