Files
000-platform/backend/app/middleware/trace.py
111 6c6c48cf71
All checks were successful
continuous-integration/drone/push Build is passing
feat: 新增告警、成本、配额、微信模块及缓存服务
- 新增告警模块 (alerts): 告警规则配置与触发
- 新增成本管理模块 (cost): 成本统计与分析
- 新增配额模块 (quota): 配额管理与限制
- 新增微信模块 (wechat): 微信相关功能接口
- 新增缓存服务 (cache): Redis 缓存封装
- 新增请求日志中间件 (request_logger)
- 新增异常处理和链路追踪中间件
- 更新 dashboard 前端展示
- 更新 SDK stats_client 功能
2026-01-24 16:53:47 +08:00

115 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
TraceID 追踪中间件
为每个请求生成唯一的 TraceID用于日志追踪和问题排查。
功能:
- 自动生成 TraceID或从请求头获取
- 注入到响应头 X-Trace-ID
- 提供上下文变量供日志使用
- 支持请求耗时统计
"""
import time
import uuid
import logging
from contextvars import ContextVar
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
logger = logging.getLogger(__name__)
# 上下文变量存储当前请求的 TraceID
_trace_id_var: ContextVar[Optional[str]] = ContextVar("trace_id", default=None)
# 请求头名称
TRACE_ID_HEADER = "X-Trace-ID"
REQUEST_ID_HEADER = "X-Request-ID"
def get_trace_id() -> str:
"""获取当前请求的 TraceID"""
trace_id = _trace_id_var.get()
return trace_id if trace_id else "no-trace"
def set_trace_id(trace_id: str) -> None:
"""设置当前请求的 TraceID"""
_trace_id_var.set(trace_id)
def generate_trace_id() -> str:
"""生成新的 TraceID格式: 时间戳-随机8位"""
timestamp = int(time.time())
random_part = uuid.uuid4().hex[:8]
return f"{timestamp}-{random_part}"
class TraceMiddleware(BaseHTTPMiddleware):
"""TraceID 追踪中间件"""
def __init__(self, app, log_requests: bool = True):
super().__init__(app)
self.log_requests = log_requests
async def dispatch(self, request: Request, call_next) -> Response:
# 从请求头获取 TraceID或生成新的
trace_id = (
request.headers.get(TRACE_ID_HEADER) or
request.headers.get(REQUEST_ID_HEADER) or
generate_trace_id()
)
set_trace_id(trace_id)
start_time = time.time()
method = request.method
path = request.url.path
if self.log_requests:
logger.info(f"[{trace_id}] --> {method} {path}")
try:
response = await call_next(request)
duration_ms = int((time.time() - start_time) * 1000)
response.headers[TRACE_ID_HEADER] = trace_id
response.headers["X-Response-Time"] = f"{duration_ms}ms"
if self.log_requests:
logger.info(f"[{trace_id}] <-- {response.status_code} ({duration_ms}ms)")
return response
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.error(f"[{trace_id}] !!! 请求异常: {e} ({duration_ms}ms)")
raise
class TraceLogFilter(logging.Filter):
"""日志过滤器:自动添加 TraceID"""
def filter(self, record):
record.trace_id = get_trace_id()
return True
def setup_logging(level: int = logging.INFO, include_trace: bool = True):
"""配置日志格式"""
if include_trace:
format_str = "%(asctime)s [%(trace_id)s] %(levelname)s %(name)s: %(message)s"
else:
format_str = "%(asctime)s %(levelname)s %(name)s: %(message)s"
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(format_str, datefmt="%Y-%m-%d %H:%M:%S"))
if include_trace:
handler.addFilter(TraceLogFilter())
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.handlers = [handler]