feat: 添加API限流和优化错误处理
Some checks failed
continuous-integration/drone/push Build is failing

- 添加 RateLimitMiddleware 限流中间件 (200请求/分钟)
- 优化 Content-Type 错误返回 400 而非 500
- 添加 JSON 解析错误处理
- 统一 HTTP 异常处理格式
This commit is contained in:
yuliang_guo
2026-01-31 10:50:27 +08:00
parent d59a4355a5
commit 52dccaab79
2 changed files with 137 additions and 3 deletions

View File

@@ -3,14 +3,96 @@
"""
import time
import uuid
from typing import Callable
from typing import Callable, Dict
from collections import defaultdict
from datetime import datetime, timedelta
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.logger import logger
class RateLimitMiddleware(BaseHTTPMiddleware):
"""
API 限流中间件
基于IP地址进行限流防止恶意请求攻击
"""
def __init__(self, app, requests_per_minute: int = 60, burst_limit: int = 100):
super().__init__(app)
self.requests_per_minute = requests_per_minute
self.burst_limit = burst_limit # 突发请求限制
self.request_counts: Dict[str, list] = defaultdict(list)
def _get_client_ip(self, request: Request) -> str:
"""获取客户端真实IP"""
# 优先从代理头获取
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
return request.client.host if request.client else "unknown"
def _clean_old_requests(self, ip: str, window_start: datetime):
"""清理窗口外的请求记录"""
self.request_counts[ip] = [
t for t in self.request_counts[ip]
if t > window_start
]
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 跳过健康检查和静态文件
if request.url.path in ["/health", "/docs", "/openapi.json", "/redoc"]:
return await call_next(request)
if request.url.path.startswith("/static/"):
return await call_next(request)
client_ip = self._get_client_ip(request)
now = datetime.now()
window_start = now - timedelta(minutes=1)
# 清理过期记录
self._clean_old_requests(client_ip, window_start)
# 检查请求数
request_count = len(self.request_counts[client_ip])
if request_count >= self.burst_limit:
logger.warning(
"请求被限流",
client_ip=client_ip,
request_count=request_count,
path=request.url.path,
)
return JSONResponse(
status_code=429,
content={
"code": 429,
"message": "请求过于频繁,请稍后再试",
"retry_after": 60,
},
headers={"Retry-After": "60"}
)
# 记录本次请求
self.request_counts[client_ip].append(now)
# 如果接近限制,添加警告头
response = await call_next(request)
remaining = self.burst_limit - len(self.request_counts[client_ip])
response.headers["X-RateLimit-Limit"] = str(self.burst_limit)
response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
response.headers["X-RateLimit-Reset"] = str(int((window_start + timedelta(minutes=1)).timestamp()))
return response
class RequestIDMiddleware(BaseHTTPMiddleware):
"""请求ID中间件"""

View File

@@ -97,6 +97,14 @@ app.add_middleware(
allow_headers=["*"],
)
# 添加限流中间件
from app.core.middleware import RateLimitMiddleware
app.add_middleware(
RateLimitMiddleware,
requests_per_minute=120, # 每分钟最大请求数
burst_limit=200, # 突发请求限制
)
# 健康检查端点
@app.get("/health")
@@ -140,16 +148,60 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
return JSONResponse(
status_code=422,
content={
"code": 422,
"message": "请求参数验证失败",
"detail": exc.errors(),
"body": exc.body if hasattr(exc, 'body') else None,
},
)
# JSON 解析错误处理
from json import JSONDecodeError
@app.exception_handler(JSONDecodeError)
async def json_decode_exception_handler(request: Request, exc: JSONDecodeError):
"""处理 JSON 解析错误"""
logger.warning(f"JSON解析错误 [{request.method} {request.url.path}]: {exc}")
return JSONResponse(
status_code=400,
content={
"code": 400,
"message": "请求体格式错误,需要有效的 JSON",
"detail": str(exc),
},
)
# HTTP 异常处理
from fastapi import HTTPException
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
"""处理 HTTP 异常"""
return JSONResponse(
status_code=exc.status_code,
content={
"code": exc.status_code,
"message": exc.detail,
},
)
# 全局异常处理
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
async def global_exception_handler(request: Request, exc: Exception):
"""全局异常处理"""
error_msg = str(exc)
# 检查是否是 Content-Type 相关错误
if "Expecting value" in error_msg or "JSON" in error_msg.upper():
logger.warning(f"请求体解析错误 [{request.method} {request.url.path}]: {error_msg}")
return JSONResponse(
status_code=400,
content={
"code": 400,
"message": "请求体格式错误,请使用 application/json",
},
)
logger.error(f"未处理的异常: {exc}", exc_info=True)
return JSONResponse(
status_code=500,