""" 中间件定义 """ import time import uuid from typing import Callable, Dict from collections import defaultdict from datetime import datetime, timedelta from fastapi import Request, Response from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from app.core.logger import logger class RateLimitMiddleware(BaseHTTPMiddleware): """ API 限流中间件 基于IP地址进行限流,防止恶意请求攻击 """ def __init__(self, app, requests_per_minute: int = 60, burst_limit: int = 100): super().__init__(app) self.requests_per_minute = requests_per_minute self.burst_limit = burst_limit # 突发请求限制 self.request_counts: Dict[str, list] = defaultdict(list) def _get_client_ip(self, request: Request) -> str: """获取客户端真实IP""" # 优先从代理头获取 forwarded = request.headers.get("X-Forwarded-For") if forwarded: return forwarded.split(",")[0].strip() real_ip = request.headers.get("X-Real-IP") if real_ip: return real_ip return request.client.host if request.client else "unknown" def _clean_old_requests(self, ip: str, window_start: datetime): """清理窗口外的请求记录""" self.request_counts[ip] = [ t for t in self.request_counts[ip] if t > window_start ] async def dispatch(self, request: Request, call_next: Callable) -> Response: # 跳过健康检查和静态文件 if request.url.path in ["/health", "/docs", "/openapi.json", "/redoc"]: return await call_next(request) if request.url.path.startswith("/static/"): return await call_next(request) client_ip = self._get_client_ip(request) now = datetime.now() window_start = now - timedelta(minutes=1) # 清理过期记录 self._clean_old_requests(client_ip, window_start) # 检查请求数 request_count = len(self.request_counts[client_ip]) if request_count >= self.burst_limit: logger.warning( "请求被限流", client_ip=client_ip, request_count=request_count, path=request.url.path, ) return JSONResponse( status_code=429, content={ "code": 429, "message": "请求过于频繁,请稍后再试", "retry_after": 60, }, headers={"Retry-After": "60"} ) # 记录本次请求 self.request_counts[client_ip].append(now) # 如果接近限制,添加警告头 response = await call_next(request) remaining = self.burst_limit - len(self.request_counts[client_ip]) response.headers["X-RateLimit-Limit"] = str(self.burst_limit) response.headers["X-RateLimit-Remaining"] = str(max(0, remaining)) response.headers["X-RateLimit-Reset"] = str(int((window_start + timedelta(minutes=1)).timestamp())) return response class SecurityHeadersMiddleware(BaseHTTPMiddleware): """ 安全响应头中间件 添加各种安全相关的 HTTP 响应头 """ async def dispatch(self, request: Request, call_next: Callable) -> Response: response = await call_next(request) # 防止 MIME 类型嗅探 response.headers["X-Content-Type-Options"] = "nosniff" # 防止点击劫持 response.headers["X-Frame-Options"] = "DENY" # XSS 过滤器(现代浏览器已弃用,但仍有一些旧浏览器支持) response.headers["X-XSS-Protection"] = "1; mode=block" # 引用策略 response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" # 权限策略(禁用一些敏感功能) response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()" # 缓存控制(API 响应不应被缓存) if request.url.path.startswith("/api/"): response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate" response.headers["Pragma"] = "no-cache" return response class RequestIDMiddleware(BaseHTTPMiddleware): """请求ID中间件""" async def dispatch(self, request: Request, call_next: Callable) -> Response: # 生成请求ID request_id = str(uuid.uuid4()) # 将请求ID添加到request状态 request.state.request_id = request_id # 记录请求开始 start_time = time.time() # 处理请求 response = await call_next(request) # 计算处理时间 process_time = time.time() - start_time # 添加响应头 response.headers["X-Request-ID"] = request_id response.headers["X-Process-Time"] = str(process_time) # 记录请求日志 logger.info( "HTTP请求", method=request.method, url=str(request.url), status_code=response.status_code, process_time=process_time, request_id=request_id, ) return response class GlobalContextMiddleware(BaseHTTPMiddleware): """全局上下文中间件""" async def dispatch(self, request: Request, call_next: Callable) -> Response: # 设置追踪ID(用于分布式追踪) trace_id = request.headers.get("X-Trace-ID", str(uuid.uuid4())) request.state.trace_id = trace_id # 处理请求 response = await call_next(request) # 添加追踪ID到响应头 response.headers["X-Trace-ID"] = trace_id return response