feat: 新增告警、成本、配额、微信模块及缓存服务
All checks were successful
continuous-integration/drone/push Build is passing

- 新增告警模块 (alerts): 告警规则配置与触发
- 新增成本管理模块 (cost): 成本统计与分析
- 新增配额模块 (quota): 配额管理与限制
- 新增微信模块 (wechat): 微信相关功能接口
- 新增缓存服务 (cache): Redis 缓存封装
- 新增请求日志中间件 (request_logger)
- 新增异常处理和链路追踪中间件
- 更新 dashboard 前端展示
- 更新 SDK stats_client 功能
This commit is contained in:
111
2026-01-24 16:53:47 +08:00
parent eab2533c36
commit 6c6c48cf71
29 changed files with 4607 additions and 41 deletions

View File

@@ -0,0 +1,19 @@
"""
中间件模块
提供:
- TraceID 追踪
- 统一异常处理
- 请求日志记录
"""
from .trace import TraceMiddleware, get_trace_id, set_trace_id
from .exception_handler import setup_exception_handlers
from .request_logger import RequestLoggerMiddleware
__all__ = [
"TraceMiddleware",
"get_trace_id",
"set_trace_id",
"setup_exception_handlers",
"RequestLoggerMiddleware"
]

View File

@@ -0,0 +1,128 @@
"""
统一异常处理
捕获所有异常,返回统一格式的错误响应,包含 TraceID。
"""
import logging
import traceback
from typing import Union
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from .trace import get_trace_id
logger = logging.getLogger(__name__)
class ErrorCode:
"""错误码常量"""
BAD_REQUEST = "BAD_REQUEST"
UNAUTHORIZED = "UNAUTHORIZED"
FORBIDDEN = "FORBIDDEN"
NOT_FOUND = "NOT_FOUND"
VALIDATION_ERROR = "VALIDATION_ERROR"
RATE_LIMITED = "RATE_LIMITED"
INTERNAL_ERROR = "INTERNAL_ERROR"
SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
GATEWAY_ERROR = "GATEWAY_ERROR"
STATUS_TO_ERROR_CODE = {
400: ErrorCode.BAD_REQUEST,
401: ErrorCode.UNAUTHORIZED,
403: ErrorCode.FORBIDDEN,
404: ErrorCode.NOT_FOUND,
422: ErrorCode.VALIDATION_ERROR,
429: ErrorCode.RATE_LIMITED,
500: ErrorCode.INTERNAL_ERROR,
502: ErrorCode.GATEWAY_ERROR,
503: ErrorCode.SERVICE_UNAVAILABLE,
}
def create_error_response(
status_code: int,
code: str,
message: str,
trace_id: str = None,
details: dict = None
) -> JSONResponse:
"""创建统一格式的错误响应"""
if trace_id is None:
trace_id = get_trace_id()
error_body = {
"code": code,
"message": message,
"trace_id": trace_id
}
if details:
error_body["details"] = details
return JSONResponse(
status_code=status_code,
content={"success": False, "error": error_body},
headers={"X-Trace-ID": trace_id}
)
async def http_exception_handler(request: Request, exc: Union[HTTPException, StarletteHTTPException]):
"""处理 HTTP 异常"""
trace_id = get_trace_id()
status_code = exc.status_code
error_code = STATUS_TO_ERROR_CODE.get(status_code, ErrorCode.INTERNAL_ERROR)
message = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
logger.warning(f"[{trace_id}] HTTP {status_code}: {message}")
return create_error_response(
status_code=status_code,
code=error_code,
message=message,
trace_id=trace_id
)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""处理请求验证错误"""
trace_id = get_trace_id()
errors = exc.errors()
error_messages = [f"{'.'.join(str(l) for l in e['loc'])}: {e['msg']}" for e in errors]
logger.warning(f"[{trace_id}] 验证错误: {error_messages}")
return create_error_response(
status_code=422,
code=ErrorCode.VALIDATION_ERROR,
message="请求参数验证失败",
trace_id=trace_id,
details={"validation_errors": error_messages}
)
async def generic_exception_handler(request: Request, exc: Exception):
"""处理所有未捕获的异常"""
trace_id = get_trace_id()
logger.error(f"[{trace_id}] 未捕获异常: {type(exc).__name__}: {exc}")
logger.error(f"[{trace_id}] 堆栈:\n{traceback.format_exc()}")
return create_error_response(
status_code=500,
code=ErrorCode.INTERNAL_ERROR,
message="服务器内部错误,请稍后重试",
trace_id=trace_id
)
def setup_exception_handlers(app: FastAPI):
"""配置 FastAPI 应用的异常处理器"""
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(StarletteHTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(Exception, generic_exception_handler)
logger.info("异常处理器已配置")

View File

@@ -0,0 +1,190 @@
"""
请求日志中间件
自动将所有请求记录到数据库 platform_logs 表
"""
import time
import logging
from typing import Optional, Set
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from .trace import get_trace_id
from ..database import SessionLocal
from ..models.logs import PlatformLog
logger = logging.getLogger(__name__)
class RequestLoggerMiddleware(BaseHTTPMiddleware):
"""请求日志中间件
自动记录所有请求到数据库,便于后续查询和分析
使用示例:
app.add_middleware(RequestLoggerMiddleware, app_code="000-platform")
"""
# 默认排除的路径(不记录这些请求)
DEFAULT_EXCLUDE_PATHS: Set[str] = {
"/",
"/docs",
"/redoc",
"/openapi.json",
"/api/health",
"/api/health/",
"/favicon.ico",
}
def __init__(
self,
app,
app_code: str = "platform",
exclude_paths: Optional[Set[str]] = None,
log_request_body: bool = False,
log_response_body: bool = False,
max_body_length: int = 1000
):
"""初始化中间件
Args:
app: FastAPI应用
app_code: 应用代码,记录到日志中
exclude_paths: 排除的路径集合,这些路径不记录日志
log_request_body: 是否记录请求体
log_response_body: 是否记录响应体
max_body_length: 记录体的最大长度
"""
super().__init__(app)
self.app_code = app_code
self.exclude_paths = exclude_paths or self.DEFAULT_EXCLUDE_PATHS
self.log_request_body = log_request_body
self.log_response_body = log_response_body
self.max_body_length = max_body_length
async def dispatch(self, request: Request, call_next) -> Response:
path = request.url.path
# 检查是否排除
if self._should_exclude(path):
return await call_next(request)
trace_id = get_trace_id()
method = request.method
start_time = time.time()
# 获取客户端IP
client_ip = self._get_client_ip(request)
# 获取租户ID从查询参数
tenant_id = request.query_params.get("tid") or request.query_params.get("tenant_id")
# 执行请求
response = None
error_message = None
status_code = 500
try:
response = await call_next(request)
status_code = response.status_code
except Exception as e:
error_message = str(e)
raise
finally:
duration_ms = int((time.time() - start_time) * 1000)
# 异步写入数据库(不阻塞响应)
try:
self._save_log(
trace_id=trace_id,
method=method,
path=path,
status_code=status_code,
duration_ms=duration_ms,
ip_address=client_ip,
tenant_id=tenant_id,
error_message=error_message
)
except Exception as e:
logger.error(f"Failed to save request log: {e}")
return response
def _should_exclude(self, path: str) -> bool:
"""检查路径是否应排除"""
# 精确匹配
if path in self.exclude_paths:
return True
# 前缀匹配(静态文件等)
exclude_prefixes = ["/static/", "/assets/", "/_next/"]
for prefix in exclude_prefixes:
if path.startswith(prefix):
return True
return False
def _get_client_ip(self, request: Request) -> str:
"""获取客户端真实IP"""
# 优先从代理头获取
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# 直连IP
if request.client:
return request.client.host
return "unknown"
def _save_log(
self,
trace_id: str,
method: str,
path: str,
status_code: int,
duration_ms: int,
ip_address: str,
tenant_id: Optional[str] = None,
error_message: Optional[str] = None
):
"""保存日志到数据库"""
from datetime import datetime
# 使用独立的数据库会话
db = SessionLocal()
try:
# 转换 tenant_id 为整数(如果是数字字符串)
tenant_id_int = None
if tenant_id:
try:
tenant_id_int = int(tenant_id)
except (ValueError, TypeError):
tenant_id_int = None
log_entry = PlatformLog(
log_type="request",
level="error" if status_code >= 500 else ("warn" if status_code >= 400 else "info"),
app_code=self.app_code,
tenant_id=tenant_id_int,
trace_id=trace_id,
message=f"{method} {path}" + (f" - {error_message}" if error_message else ""),
path=path,
method=method,
status_code=status_code,
duration_ms=duration_ms,
log_time=datetime.now(), # 必须设置 log_time
context={"ip": ip_address} # ip_address 放到 context 中
)
db.add(log_entry)
db.commit()
except Exception as e:
logger.error(f"Database error saving log: {e}")
db.rollback()
finally:
db.close()

View File

@@ -0,0 +1,114 @@
"""
TraceID 追踪中间件
为每个请求生成唯一的 TraceID用于日志追踪和问题排查。
功能:
- 自动生成 TraceID或从请求头获取
- 注入到响应头 X-Trace-ID
- 提供上下文变量供日志使用
- 支持请求耗时统计
"""
import time
import uuid
import logging
from contextvars import ContextVar
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
logger = logging.getLogger(__name__)
# 上下文变量存储当前请求的 TraceID
_trace_id_var: ContextVar[Optional[str]] = ContextVar("trace_id", default=None)
# 请求头名称
TRACE_ID_HEADER = "X-Trace-ID"
REQUEST_ID_HEADER = "X-Request-ID"
def get_trace_id() -> str:
"""获取当前请求的 TraceID"""
trace_id = _trace_id_var.get()
return trace_id if trace_id else "no-trace"
def set_trace_id(trace_id: str) -> None:
"""设置当前请求的 TraceID"""
_trace_id_var.set(trace_id)
def generate_trace_id() -> str:
"""生成新的 TraceID格式: 时间戳-随机8位"""
timestamp = int(time.time())
random_part = uuid.uuid4().hex[:8]
return f"{timestamp}-{random_part}"
class TraceMiddleware(BaseHTTPMiddleware):
"""TraceID 追踪中间件"""
def __init__(self, app, log_requests: bool = True):
super().__init__(app)
self.log_requests = log_requests
async def dispatch(self, request: Request, call_next) -> Response:
# 从请求头获取 TraceID或生成新的
trace_id = (
request.headers.get(TRACE_ID_HEADER) or
request.headers.get(REQUEST_ID_HEADER) or
generate_trace_id()
)
set_trace_id(trace_id)
start_time = time.time()
method = request.method
path = request.url.path
if self.log_requests:
logger.info(f"[{trace_id}] --> {method} {path}")
try:
response = await call_next(request)
duration_ms = int((time.time() - start_time) * 1000)
response.headers[TRACE_ID_HEADER] = trace_id
response.headers["X-Response-Time"] = f"{duration_ms}ms"
if self.log_requests:
logger.info(f"[{trace_id}] <-- {response.status_code} ({duration_ms}ms)")
return response
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.error(f"[{trace_id}] !!! 请求异常: {e} ({duration_ms}ms)")
raise
class TraceLogFilter(logging.Filter):
"""日志过滤器:自动添加 TraceID"""
def filter(self, record):
record.trace_id = get_trace_id()
return True
def setup_logging(level: int = logging.INFO, include_trace: bool = True):
"""配置日志格式"""
if include_trace:
format_str = "%(asctime)s [%(trace_id)s] %(levelname)s %(name)s: %(message)s"
else:
format_str = "%(asctime)s %(levelname)s %(name)s: %(message)s"
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(format_str, datefmt="%Y-%m-%d %H:%M:%S"))
if include_trace:
handler.addFilter(TraceLogFilter())
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.handlers = [handler]