All checks were successful
continuous-integration/drone/push Build is passing
- 新增告警模块 (alerts): 告警规则配置与触发 - 新增成本管理模块 (cost): 成本统计与分析 - 新增配额模块 (quota): 配额管理与限制 - 新增微信模块 (wechat): 微信相关功能接口 - 新增缓存服务 (cache): Redis 缓存封装 - 新增请求日志中间件 (request_logger) - 新增异常处理和链路追踪中间件 - 更新 dashboard 前端展示 - 更新 SDK stats_client 功能
115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
"""
|
||
TraceID 追踪中间件
|
||
|
||
为每个请求生成唯一的 TraceID,用于日志追踪和问题排查。
|
||
|
||
功能:
|
||
- 自动生成 TraceID(或从请求头获取)
|
||
- 注入到响应头 X-Trace-ID
|
||
- 提供上下文变量供日志使用
|
||
- 支持请求耗时统计
|
||
"""
|
||
import time
|
||
import uuid
|
||
import logging
|
||
from contextvars import ContextVar
|
||
from typing import Optional
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from starlette.requests import Request
|
||
from starlette.responses import Response
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 上下文变量存储当前请求的 TraceID
|
||
_trace_id_var: ContextVar[Optional[str]] = ContextVar("trace_id", default=None)
|
||
|
||
# 请求头名称
|
||
TRACE_ID_HEADER = "X-Trace-ID"
|
||
REQUEST_ID_HEADER = "X-Request-ID"
|
||
|
||
|
||
def get_trace_id() -> str:
|
||
"""获取当前请求的 TraceID"""
|
||
trace_id = _trace_id_var.get()
|
||
return trace_id if trace_id else "no-trace"
|
||
|
||
|
||
def set_trace_id(trace_id: str) -> None:
|
||
"""设置当前请求的 TraceID"""
|
||
_trace_id_var.set(trace_id)
|
||
|
||
|
||
def generate_trace_id() -> str:
|
||
"""生成新的 TraceID,格式: 时间戳-随机8位"""
|
||
timestamp = int(time.time())
|
||
random_part = uuid.uuid4().hex[:8]
|
||
return f"{timestamp}-{random_part}"
|
||
|
||
|
||
class TraceMiddleware(BaseHTTPMiddleware):
|
||
"""TraceID 追踪中间件"""
|
||
|
||
def __init__(self, app, log_requests: bool = True):
|
||
super().__init__(app)
|
||
self.log_requests = log_requests
|
||
|
||
async def dispatch(self, request: Request, call_next) -> Response:
|
||
# 从请求头获取 TraceID,或生成新的
|
||
trace_id = (
|
||
request.headers.get(TRACE_ID_HEADER) or
|
||
request.headers.get(REQUEST_ID_HEADER) or
|
||
generate_trace_id()
|
||
)
|
||
|
||
set_trace_id(trace_id)
|
||
|
||
start_time = time.time()
|
||
method = request.method
|
||
path = request.url.path
|
||
|
||
if self.log_requests:
|
||
logger.info(f"[{trace_id}] --> {method} {path}")
|
||
|
||
try:
|
||
response = await call_next(request)
|
||
duration_ms = int((time.time() - start_time) * 1000)
|
||
|
||
response.headers[TRACE_ID_HEADER] = trace_id
|
||
response.headers["X-Response-Time"] = f"{duration_ms}ms"
|
||
|
||
if self.log_requests:
|
||
logger.info(f"[{trace_id}] <-- {response.status_code} ({duration_ms}ms)")
|
||
|
||
return response
|
||
|
||
except Exception as e:
|
||
duration_ms = int((time.time() - start_time) * 1000)
|
||
logger.error(f"[{trace_id}] !!! 请求异常: {e} ({duration_ms}ms)")
|
||
raise
|
||
|
||
|
||
class TraceLogFilter(logging.Filter):
|
||
"""日志过滤器:自动添加 TraceID"""
|
||
|
||
def filter(self, record):
|
||
record.trace_id = get_trace_id()
|
||
return True
|
||
|
||
|
||
def setup_logging(level: int = logging.INFO, include_trace: bool = True):
|
||
"""配置日志格式"""
|
||
if include_trace:
|
||
format_str = "%(asctime)s [%(trace_id)s] %(levelname)s %(name)s: %(message)s"
|
||
else:
|
||
format_str = "%(asctime)s %(levelname)s %(name)s: %(message)s"
|
||
|
||
handler = logging.StreamHandler()
|
||
handler.setFormatter(logging.Formatter(format_str, datefmt="%Y-%m-%d %H:%M:%S"))
|
||
|
||
if include_trace:
|
||
handler.addFilter(TraceLogFilter())
|
||
|
||
root_logger = logging.getLogger()
|
||
root_logger.setLevel(level)
|
||
root_logger.handlers = [handler]
|