Some checks failed
continuous-integration/drone/push Build is failing
- 添加 RateLimitMiddleware 限流中间件 (200请求/分钟) - 优化 Content-Type 错误返回 400 而非 500 - 添加 JSON 解析错误处理 - 统一 HTTP 异常处理格式
147 lines
4.6 KiB
Python
147 lines
4.6 KiB
Python
"""
|
||
中间件定义
|
||
"""
|
||
import time
|
||
import uuid
|
||
from typing import Callable, Dict
|
||
from collections import defaultdict
|
||
from datetime import datetime, timedelta
|
||
|
||
from fastapi import Request, Response
|
||
from fastapi.responses import JSONResponse
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
|
||
from app.core.logger import logger
|
||
|
||
|
||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate-limit incoming API requests per client IP.

    The timestamps of each client's requests from the last minute are kept
    in memory. Once a client has ``burst_limit`` recorded requests inside
    that window, further requests are answered with HTTP 429 until older
    entries age out.

    The counters live in this middleware instance only (a plain dict), so
    they are not shared between separate processes.
    """

    # Paths that are never rate limited (health check and API docs).
    _EXEMPT_PATHS = frozenset({"/health", "/docs", "/openapi.json", "/redoc"})

    # Length of the sliding window used for counting requests.
    _WINDOW = timedelta(minutes=1)

    def __init__(self, app, requests_per_minute: int = 60, burst_limit: int = 100):
        """Create the middleware.

        Args:
            app: The ASGI application being wrapped.
            requests_per_minute: Stored on the instance. NOTE(review): the
                code in this class never reads it; ``burst_limit`` is the
                only threshold that is enforced.
            burst_limit: Maximum number of requests recorded per client
                within the one-minute window before 429 is returned.
        """
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.burst_limit = burst_limit
        # Maps client IP -> list of datetimes of its recent requests.
        self.request_counts: Dict[str, list] = defaultdict(list)
        # Time of the last full sweep that dropped inactive clients.
        self._last_sweep = datetime.now()

    def _get_client_ip(self, request: Request) -> str:
        """Return the key used to identify the client.

        Proxy headers take priority over the socket peer address.

        NOTE(review): ``X-Forwarded-For`` / ``X-Real-IP`` are supplied by
        the caller unless a trusted proxy overwrites them; confirm the
        deployment always sits behind such a proxy, otherwise a client can
        pick its own rate-limit key.
        """
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            first_hop = forwarded.split(",")[0].strip()
            # A blank first entry (e.g. ",10.0.0.1") is not a usable key;
            # fall through to the other sources instead of returning "".
            if first_hop:
                return first_hop

        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip

        return request.client.host if request.client else "unknown"

    def _clean_old_requests(self, ip: str, window_start: datetime):
        """Drop timestamps for ``ip`` that are not newer than ``window_start``."""
        self.request_counts[ip] = [
            stamp for stamp in self.request_counts[ip] if stamp > window_start
        ]

    def _purge_stale_clients(self, window_start: datetime) -> None:
        """Remove every client that has no timestamp inside the window.

        Without this, a client that stops sending requests would keep its
        dict entry forever and the mapping would grow without bound.
        """
        stale = [
            ip
            for ip, stamps in self.request_counts.items()
            if not any(stamp > window_start for stamp in stamps)
        ]
        for ip in stale:
            del self.request_counts[ip]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Apply the rate limit to ``request`` and forward it if allowed.

        Returns:
            The downstream response with ``X-RateLimit-*`` headers added,
            or a 429 JSON response when the client is over the limit.
        """
        path = request.url.path
        if path in self._EXEMPT_PATHS or path.startswith("/static/"):
            return await call_next(request)

        client_ip = self._get_client_ip(request)
        now = datetime.now()
        window_start = now - self._WINDOW

        # At most once per window, sweep out clients that went quiet so the
        # mapping does not accumulate one entry per IP ever seen.
        if now - self._last_sweep >= self._WINDOW:
            self._purge_stale_clients(window_start)
            self._last_sweep = now

        self._clean_old_requests(client_ip, window_start)
        request_count = len(self.request_counts[client_ip])

        if request_count >= self.burst_limit:
            logger.warning(
                "请求被限流",
                client_ip=client_ip,
                request_count=request_count,
                path=request.url.path,
            )
            return JSONResponse(
                status_code=429,
                content={
                    "code": 429,
                    "message": "请求过于频繁,请稍后再试",
                    "retry_after": 60,
                },
                headers={"Retry-After": "60"},
            )

        # Record this request before handing it downstream.
        self.request_counts[client_ip].append(now)

        response = await call_next(request)

        # Use .get() so that an entry purged while the request was being
        # processed is not re-created as an empty list by the defaultdict.
        recorded = self.request_counts.get(client_ip, [])
        remaining = self.burst_limit - len(recorded)
        # The oldest recorded request is the first to leave the window, so
        # that is the moment capacity starts to free up again.
        reset_at = min(recorded, default=now) + self._WINDOW

        response.headers["X-RateLimit-Limit"] = str(self.burst_limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
        response.headers["X-RateLimit-Reset"] = str(int(reset_at.timestamp()))

        return response
|
||
|
||
|
||
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Tag every request with a fresh ID and log its outcome and duration.

    The ID is stored on ``request.state.request_id`` and returned in the
    ``X-Request-ID`` response header; the handling time in seconds is
    returned in ``X-Process-Time``.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process ``request``, then annotate and log the response.

        If ``call_next`` raises, the exception propagates unchanged and
        neither the headers nor the log entry are produced.
        """
        request_id = str(uuid.uuid4())
        # Expose the ID to downstream handlers through the request state.
        request.state.request_id = request_id

        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or go negative) when the system clock is adjusted.
        started = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - started

        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = str(process_time)

        logger.info(
            "HTTP请求",
            method=request.method,
            url=str(request.url),
            status_code=response.status_code,
            process_time=process_time,
            request_id=request_id,
        )

        return response
|
||
|
||
|
||
class GlobalContextMiddleware(BaseHTTPMiddleware):
    """Attach a trace ID to each request for distributed tracing.

    The ID is taken from the incoming ``X-Trace-ID`` header when one is
    supplied, otherwise a new UUID is generated. It is stored on
    ``request.state.trace_id`` and echoed in the ``X-Trace-ID`` response
    header.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Resolve the trace ID, process ``request`` and tag the response."""
        # Treat a missing, empty or whitespace-only header as absent so the
        # request never ends up with a blank trace ID; the UUID is only
        # generated when it is actually needed.
        incoming = (request.headers.get("X-Trace-ID") or "").strip()
        trace_id = incoming or str(uuid.uuid4())
        request.state.trace_id = trace_id

        response = await call_next(request)

        response.headers["X-Trace-ID"] = trace_id

        return response
|