- 添加 RateLimitMiddleware 限流中间件 (200请求/分钟) - 优化 Content-Type 错误返回 400 而非 500 - 添加 JSON 解析错误处理 - 统一 HTTP 异常处理格式
This commit is contained in:
@@ -3,14 +3,96 @@
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
from typing import Callable
|
||||
from typing import Callable, Dict
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.core.logger import logger
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
API 限流中间件
|
||||
|
||||
基于IP地址进行限流,防止恶意请求攻击
|
||||
"""
|
||||
|
||||
def __init__(self, app, requests_per_minute: int = 60, burst_limit: int = 100):
|
||||
super().__init__(app)
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.burst_limit = burst_limit # 突发请求限制
|
||||
self.request_counts: Dict[str, list] = defaultdict(list)
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""获取客户端真实IP"""
|
||||
# 优先从代理头获取
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
def _clean_old_requests(self, ip: str, window_start: datetime):
|
||||
"""清理窗口外的请求记录"""
|
||||
self.request_counts[ip] = [
|
||||
t for t in self.request_counts[ip]
|
||||
if t > window_start
|
||||
]
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 跳过健康检查和静态文件
|
||||
if request.url.path in ["/health", "/docs", "/openapi.json", "/redoc"]:
|
||||
return await call_next(request)
|
||||
if request.url.path.startswith("/static/"):
|
||||
return await call_next(request)
|
||||
|
||||
client_ip = self._get_client_ip(request)
|
||||
now = datetime.now()
|
||||
window_start = now - timedelta(minutes=1)
|
||||
|
||||
# 清理过期记录
|
||||
self._clean_old_requests(client_ip, window_start)
|
||||
|
||||
# 检查请求数
|
||||
request_count = len(self.request_counts[client_ip])
|
||||
|
||||
if request_count >= self.burst_limit:
|
||||
logger.warning(
|
||||
"请求被限流",
|
||||
client_ip=client_ip,
|
||||
request_count=request_count,
|
||||
path=request.url.path,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"code": 429,
|
||||
"message": "请求过于频繁,请稍后再试",
|
||||
"retry_after": 60,
|
||||
},
|
||||
headers={"Retry-After": "60"}
|
||||
)
|
||||
|
||||
# 记录本次请求
|
||||
self.request_counts[client_ip].append(now)
|
||||
|
||||
# 如果接近限制,添加警告头
|
||||
response = await call_next(request)
|
||||
|
||||
remaining = self.burst_limit - len(self.request_counts[client_ip])
|
||||
response.headers["X-RateLimit-Limit"] = str(self.burst_limit)
|
||||
response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
|
||||
response.headers["X-RateLimit-Reset"] = str(int((window_start + timedelta(minutes=1)).timestamp()))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""请求ID中间件"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user