Files
012-kaopeilian/backend/app/core/middleware.py
yuliang_guo 52dccaab79
Some checks failed
continuous-integration/drone/push Build is failing
feat: 添加API限流和优化错误处理
- 添加 RateLimitMiddleware 限流中间件 (200请求/分钟)
- 优化 Content-Type 错误返回 400 而非 500
- 添加 JSON 解析错误处理
- 统一 HTTP 异常处理格式
2026-01-31 10:50:27 +08:00

147 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
中间件定义
"""
import math
import time
import uuid
from collections import defaultdict, deque
from datetime import datetime, timedelta
from typing import Callable, Dict

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app.core.logger import logger
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    API rate-limiting middleware.

    Limits requests per client IP to protect the API from abusive request
    floods, using a sliding window of ``WINDOW_SECONDS`` (one minute).

    Each client IP owns a deque of monotonic timestamps (oldest first) of the
    requests it made inside the current window. A request is rejected with
    HTTP 429 once that deque already holds ``burst_limit`` entries.

    State lives in this middleware instance's memory, so each worker process
    enforces its own independent limit.
    """

    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60.0
    # Exact paths that are never rate limited (health check and API docs).
    EXEMPT_PATHS = frozenset({"/health", "/docs", "/openapi.json", "/redoc"})
    # Path prefix that is never rate limited (static files).
    EXEMPT_PREFIX = "/static/"
    # Minimum interval, in seconds, between sweeps that evict idle clients.
    SWEEP_INTERVAL_SECONDS = 60.0

    def __init__(
        self,
        app,
        requests_per_minute: int = 60,
        burst_limit: int = 100,
        *,
        trust_forwarded_headers: bool = True,
    ):
        """
        Args:
            app: The ASGI application being wrapped.
            requests_per_minute: Stored on the instance only.
                NOTE(review): this value is NOT enforced. The limit actually
                applied per window is ``burst_limit``. Confirm the intended
                semantics before relying on it.
            burst_limit: Maximum number of requests one client IP may make
                within the sliding one-minute window.
            trust_forwarded_headers: When True (the default, matching the
                previous behavior), the client IP is read from
                ``X-Forwarded-For`` / ``X-Real-IP``. These headers are
                client-controlled unless a trusted reverse proxy overwrites
                them. Pass False when the service is reachable directly;
                otherwise clients can dodge the limit by forging them.
        """
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.burst_limit = burst_limit  # burst limit actually enforced per window
        self.trust_forwarded_headers = trust_forwarded_headers
        # client IP -> monotonic timestamps of its requests in the window.
        # Empty buckets are deleted so idle IPs do not accumulate forever.
        self.request_counts: Dict[str, deque] = {}
        self._last_sweep = time.monotonic()

    def _get_client_ip(self, request: Request) -> str:
        """Return the client IP used as the rate-limit key."""
        if self.trust_forwarded_headers:
            # Proxy headers take priority; the first X-Forwarded-For hop is
            # the originating client as reported by the proxy chain.
            forwarded = request.headers.get("X-Forwarded-For")
            if forwarded:
                first_hop = forwarded.split(",")[0].strip()
                if first_hop:
                    return first_hop
            real_ip = request.headers.get("X-Real-IP")
            if real_ip:
                return real_ip.strip()
        return request.client.host if request.client else "unknown"

    def _clean_old_requests(self, ip: str, window_start: float) -> None:
        """Drop ``ip``'s timestamps at or before ``window_start``.

        ``window_start`` is a ``time.monotonic()`` value. If the bucket ends
        up empty it is removed from ``request_counts`` altogether.
        """
        bucket = self.request_counts.get(ip)
        if bucket is None:
            return
        # Timestamps are appended in increasing order, so expired entries are
        # always at the left end.
        while bucket and bucket[0] <= window_start:
            bucket.popleft()
        if not bucket:
            del self.request_counts[ip]

    def _sweep_idle_clients(self, now: float) -> None:
        """Evict expired timestamps and empty buckets for ALL clients.

        Runs at most once per ``SWEEP_INTERVAL_SECONDS``. Without it, IPs
        that never come back would keep their bucket in memory forever.
        """
        if now - self._last_sweep < self.SWEEP_INTERVAL_SECONDS:
            return
        self._last_sweep = now
        window_start = now - self.WINDOW_SECONDS
        # Iterate over a snapshot because _clean_old_requests deletes keys.
        for ip in list(self.request_counts):
            self._clean_old_requests(ip, window_start)

    def _seconds_until_slot(self, bucket, now: float) -> float:
        """Return seconds until the oldest request in ``bucket`` leaves the window.

        ``bucket`` is a deque of monotonic timestamps, or None/empty. In that
        case a full window length is returned.
        """
        if not bucket:
            return self.WINDOW_SECONDS
        return max(0.0, bucket[0] + self.WINDOW_SECONDS - now)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Reject the request with 429 if its client IP is over the limit.

        Otherwise record the request, forward it, and attach X-RateLimit-*
        headers to the response.
        """
        path = request.url.path
        # Health checks, API docs and static files are never rate limited.
        if path in self.EXEMPT_PATHS or path.startswith(self.EXEMPT_PREFIX):
            return await call_next(request)

        client_ip = self._get_client_ip(request)
        # Monotonic clock: unaffected by wall-clock jumps (NTP, DST on naive
        # local time) that could otherwise wipe or freeze the window.
        now = time.monotonic()
        window_start = now - self.WINDOW_SECONDS

        self._sweep_idle_clients(now)
        self._clean_old_requests(client_ip, window_start)

        bucket = self.request_counts.get(client_ip)
        request_count = len(bucket) if bucket else 0
        if request_count >= self.burst_limit:
            # Tell the client when a slot actually frees up (at least 1 s).
            retry_after = max(1, math.ceil(self._seconds_until_slot(bucket, now)))
            logger.warning(
                "请求被限流",
                client_ip=client_ip,
                request_count=request_count,
                path=path,
            )
            return JSONResponse(
                status_code=429,
                content={
                    "code": 429,
                    "message": "请求过于频繁,请稍后再试",
                    "retry_after": retry_after,
                },
                headers={"Retry-After": str(retry_after)},
            )

        # Record this request. There is no await between the limit check
        # above and this append, so other coroutines cannot interleave there.
        self.request_counts.setdefault(client_ip, deque()).append(now)

        response = await call_next(request)

        # Always attach rate-limit headers. Re-read the bucket, because other
        # requests from the same IP may have been recorded, or the bucket
        # swept, while the handler was running.
        bucket = self.request_counts.get(client_ip)
        remaining = self.burst_limit - (len(bucket) if bucket else 0)
        # X-RateLimit-Reset: Unix time at which the oldest request in the
        # window expires.
        reset_at = time.time() + self._seconds_until_slot(bucket, time.monotonic())
        response.headers["X-RateLimit-Limit"] = str(self.burst_limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
        response.headers["X-RateLimit-Reset"] = str(math.ceil(reset_at))
        return response
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Request-ID middleware.

    Assigns every request a fresh UUID4 request ID and stores it on
    ``request.state.request_id`` for downstream handlers. It returns the ID in
    the ``X-Request-ID`` response header, together with the time spent in
    ``call_next`` (``X-Process-Time``, in seconds), and emits one structured
    log line per request.
    If ``call_next`` raises, no log line is written here and the exception
    propagates unchanged.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Generate a new ID for this request and expose it to handlers.
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        # perf_counter is monotonic and high-resolution. Unlike time.time(),
        # the measured duration cannot be skewed (or go negative) when the
        # system clock is adjusted mid-request.
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time

        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = str(process_time)

        logger.info(
            "HTTP请求",
            method=request.method,
            url=str(request.url),
            status_code=response.status_code,
            process_time=process_time,
            request_id=request_id,
        )
        return response
class GlobalContextMiddleware(BaseHTTPMiddleware):
    """Global request-context middleware.

    Propagates a distributed-tracing ID. The caller's ``X-Trace-ID`` header
    is reused when it is present and non-blank; otherwise a UUID4 is
    generated. The ID is stored on ``request.state.trace_id`` and echoed back
    in the ``X-Trace-ID`` response header.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # A missing OR blank header is treated as absent, so every request
        # ends up with a usable trace ID. The UUID is only generated when it
        # is actually needed.
        incoming = request.headers.get("X-Trace-ID", "").strip()
        trace_id = incoming or str(uuid.uuid4())
        request.state.trace_id = trace_id

        response = await call_next(request)

        # Echo the trace ID so clients and upstream services can correlate logs.
        response.headers["X-Trace-ID"] = trace_id
        return response