feat: 添加API限流和优化错误处理
Some checks failed
continuous-integration/drone/push Build is failing

- 添加 RateLimitMiddleware 限流中间件 (200请求/分钟)
- 优化 Content-Type 错误返回 400 而非 500
- 添加 JSON 解析错误处理
- 统一 HTTP 异常处理格式
This commit is contained in:
yuliang_guo
2026-01-31 10:50:27 +08:00
parent d59a4355a5
commit 52dccaab79
2 changed files with 137 additions and 3 deletions

View File

@@ -3,14 +3,96 @@
"""
import time
import uuid
from typing import Callable
from typing import Callable, Dict
from collections import defaultdict
from datetime import datetime, timedelta
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.logger import logger
class RateLimitMiddleware(BaseHTTPMiddleware):
"""
API 限流中间件
基于IP地址进行限流防止恶意请求攻击
"""
def __init__(self, app, requests_per_minute: int = 60, burst_limit: int = 100):
super().__init__(app)
self.requests_per_minute = requests_per_minute
self.burst_limit = burst_limit # 突发请求限制
self.request_counts: Dict[str, list] = defaultdict(list)
def _get_client_ip(self, request: Request) -> str:
"""获取客户端真实IP"""
# 优先从代理头获取
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
return request.client.host if request.client else "unknown"
def _clean_old_requests(self, ip: str, window_start: datetime):
"""清理窗口外的请求记录"""
self.request_counts[ip] = [
t for t in self.request_counts[ip]
if t > window_start
]
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 跳过健康检查和静态文件
if request.url.path in ["/health", "/docs", "/openapi.json", "/redoc"]:
return await call_next(request)
if request.url.path.startswith("/static/"):
return await call_next(request)
client_ip = self._get_client_ip(request)
now = datetime.now()
window_start = now - timedelta(minutes=1)
# 清理过期记录
self._clean_old_requests(client_ip, window_start)
# 检查请求数
request_count = len(self.request_counts[client_ip])
if request_count >= self.burst_limit:
logger.warning(
"请求被限流",
client_ip=client_ip,
request_count=request_count,
path=request.url.path,
)
return JSONResponse(
status_code=429,
content={
"code": 429,
"message": "请求过于频繁,请稍后再试",
"retry_after": 60,
},
headers={"Retry-After": "60"}
)
# 记录本次请求
self.request_counts[client_ip].append(now)
# 如果接近限制,添加警告头
response = await call_next(request)
remaining = self.burst_limit - len(self.request_counts[client_ip])
response.headers["X-RateLimit-Limit"] = str(self.burst_limit)
response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
response.headers["X-RateLimit-Reset"] = str(int((window_start + timedelta(minutes=1)).timestamp()))
return response
class RequestIDMiddleware(BaseHTTPMiddleware):
"""请求ID中间件"""