- 添加 RateLimitMiddleware 限流中间件 (200请求/分钟) - 优化 Content-Type 错误返回 400 而非 500 - 添加 JSON 解析错误处理 - 统一 HTTP 异常处理格式
This commit is contained in:
@@ -3,14 +3,96 @@
|
|||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Callable
|
from typing import Callable, Dict
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from fastapi import Request, Response
|
from fastapi import Request, Response
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
from app.core.logger import logger
|
from app.core.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""
|
||||||
|
API 限流中间件
|
||||||
|
|
||||||
|
基于IP地址进行限流,防止恶意请求攻击
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app, requests_per_minute: int = 60, burst_limit: int = 100):
|
||||||
|
super().__init__(app)
|
||||||
|
self.requests_per_minute = requests_per_minute
|
||||||
|
self.burst_limit = burst_limit # 突发请求限制
|
||||||
|
self.request_counts: Dict[str, list] = defaultdict(list)
|
||||||
|
|
||||||
|
def _get_client_ip(self, request: Request) -> str:
|
||||||
|
"""获取客户端真实IP"""
|
||||||
|
# 优先从代理头获取
|
||||||
|
forwarded = request.headers.get("X-Forwarded-For")
|
||||||
|
if forwarded:
|
||||||
|
return forwarded.split(",")[0].strip()
|
||||||
|
real_ip = request.headers.get("X-Real-IP")
|
||||||
|
if real_ip:
|
||||||
|
return real_ip
|
||||||
|
return request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
def _clean_old_requests(self, ip: str, window_start: datetime):
|
||||||
|
"""清理窗口外的请求记录"""
|
||||||
|
self.request_counts[ip] = [
|
||||||
|
t for t in self.request_counts[ip]
|
||||||
|
if t > window_start
|
||||||
|
]
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
# 跳过健康检查和静态文件
|
||||||
|
if request.url.path in ["/health", "/docs", "/openapi.json", "/redoc"]:
|
||||||
|
return await call_next(request)
|
||||||
|
if request.url.path.startswith("/static/"):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
client_ip = self._get_client_ip(request)
|
||||||
|
now = datetime.now()
|
||||||
|
window_start = now - timedelta(minutes=1)
|
||||||
|
|
||||||
|
# 清理过期记录
|
||||||
|
self._clean_old_requests(client_ip, window_start)
|
||||||
|
|
||||||
|
# 检查请求数
|
||||||
|
request_count = len(self.request_counts[client_ip])
|
||||||
|
|
||||||
|
if request_count >= self.burst_limit:
|
||||||
|
logger.warning(
|
||||||
|
"请求被限流",
|
||||||
|
client_ip=client_ip,
|
||||||
|
request_count=request_count,
|
||||||
|
path=request.url.path,
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={
|
||||||
|
"code": 429,
|
||||||
|
"message": "请求过于频繁,请稍后再试",
|
||||||
|
"retry_after": 60,
|
||||||
|
},
|
||||||
|
headers={"Retry-After": "60"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录本次请求
|
||||||
|
self.request_counts[client_ip].append(now)
|
||||||
|
|
||||||
|
# 如果接近限制,添加警告头
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
remaining = self.burst_limit - len(self.request_counts[client_ip])
|
||||||
|
response.headers["X-RateLimit-Limit"] = str(self.burst_limit)
|
||||||
|
response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
|
||||||
|
response.headers["X-RateLimit-Reset"] = str(int((window_start + timedelta(minutes=1)).timestamp()))
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||||
"""请求ID中间件"""
|
"""请求ID中间件"""
|
||||||
|
|
||||||
|
|||||||
@@ -97,6 +97,14 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 添加限流中间件
|
||||||
|
from app.core.middleware import RateLimitMiddleware
|
||||||
|
app.add_middleware(
|
||||||
|
RateLimitMiddleware,
|
||||||
|
requests_per_minute=120, # 每分钟最大请求数
|
||||||
|
burst_limit=200, # 突发请求限制
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# 健康检查端点
|
# 健康检查端点
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
@@ -140,16 +148,60 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
|
|||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=422,
|
status_code=422,
|
||||||
content={
|
content={
|
||||||
|
"code": 422,
|
||||||
|
"message": "请求参数验证失败",
|
||||||
"detail": exc.errors(),
|
"detail": exc.errors(),
|
||||||
"body": exc.body if hasattr(exc, 'body') else None,
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# JSON 解析错误处理
|
||||||
|
from json import JSONDecodeError
|
||||||
|
@app.exception_handler(JSONDecodeError)
|
||||||
|
async def json_decode_exception_handler(request: Request, exc: JSONDecodeError):
|
||||||
|
"""处理 JSON 解析错误"""
|
||||||
|
logger.warning(f"JSON解析错误 [{request.method} {request.url.path}]: {exc}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={
|
||||||
|
"code": 400,
|
||||||
|
"message": "请求体格式错误,需要有效的 JSON",
|
||||||
|
"detail": str(exc),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# HTTP 异常处理
|
||||||
|
from fastapi import HTTPException
|
||||||
|
@app.exception_handler(HTTPException)
|
||||||
|
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||||
|
"""处理 HTTP 异常"""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content={
|
||||||
|
"code": exc.status_code,
|
||||||
|
"message": exc.detail,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 全局异常处理
|
# 全局异常处理
|
||||||
@app.exception_handler(Exception)
|
@app.exception_handler(Exception)
|
||||||
async def global_exception_handler(request, exc):
|
async def global_exception_handler(request: Request, exc: Exception):
|
||||||
"""全局异常处理"""
|
"""全局异常处理"""
|
||||||
|
error_msg = str(exc)
|
||||||
|
|
||||||
|
# 检查是否是 Content-Type 相关错误
|
||||||
|
if "Expecting value" in error_msg or "JSON" in error_msg.upper():
|
||||||
|
logger.warning(f"请求体解析错误 [{request.method} {request.url.path}]: {error_msg}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={
|
||||||
|
"code": 400,
|
||||||
|
"message": "请求体格式错误,请使用 application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
logger.error(f"未处理的异常: {exc}", exc_info=True)
|
logger.error(f"未处理的异常: {exc}", exc_info=True)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
|
|||||||
Reference in New Issue
Block a user