From 52dccaab796582619631fe6dbbe09baa4116bf24 Mon Sep 17 00:00:00 2001 From: yuliang_guo Date: Sat, 31 Jan 2026 10:50:27 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0API=E9=99=90=E6=B5=81?= =?UTF-8?q?=E5=92=8C=E4=BC=98=E5=8C=96=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 RateLimitMiddleware 限流中间件 (200请求/分钟) - 优化 Content-Type 错误返回 400 而非 500 - 添加 JSON 解析错误处理 - 统一 HTTP 异常处理格式 --- backend/app/core/middleware.py | 84 +++++++++++++++++++++++++++++++++- backend/app/main.py | 56 ++++++++++++++++++++++- 2 files changed, 137 insertions(+), 3 deletions(-) diff --git a/backend/app/core/middleware.py b/backend/app/core/middleware.py index 9a4232b..3b5ac0d 100644 --- a/backend/app/core/middleware.py +++ b/backend/app/core/middleware.py @@ -3,14 +3,96 @@ """ import time import uuid -from typing import Callable +from typing import Callable, Dict +from collections import defaultdict +from datetime import datetime, timedelta from fastapi import Request, Response +from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from app.core.logger import logger +class RateLimitMiddleware(BaseHTTPMiddleware): + """ + API 限流中间件 + + 基于IP地址进行限流,防止恶意请求攻击 + """ + + def __init__(self, app, requests_per_minute: int = 60, burst_limit: int = 100): + super().__init__(app) + self.requests_per_minute = requests_per_minute + self.burst_limit = burst_limit # 突发请求限制 + self.request_counts: Dict[str, list] = defaultdict(list) + + def _get_client_ip(self, request: Request) -> str: + """获取客户端真实IP""" + # 优先从代理头获取 + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + real_ip = request.headers.get("X-Real-IP") + if real_ip: + return real_ip + return request.client.host if request.client else "unknown" + + def _clean_old_requests(self, ip: str, window_start: datetime): + """清理窗口外的请求记录""" + self.request_counts[ip] = [ + t for t in self.request_counts[ip] + if t > window_start + ] + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + # 跳过健康检查和静态文件 + if request.url.path in ["/health", "/docs", "/openapi.json", "/redoc"]: + return await call_next(request) + if request.url.path.startswith("/static/"): + return await call_next(request) + + client_ip = self._get_client_ip(request) + now = datetime.now() + window_start = now - timedelta(minutes=1) + + # 清理过期记录 + self._clean_old_requests(client_ip, window_start) + + # 检查请求数 + request_count = len(self.request_counts[client_ip]) + + if request_count >= self.burst_limit: + logger.warning( + "请求被限流", + client_ip=client_ip, + request_count=request_count, + path=request.url.path, + ) + return JSONResponse( + status_code=429, + content={ + "code": 429, + "message": "请求过于频繁,请稍后再试", + "retry_after": 60, + }, + headers={"Retry-After": "60"} + ) + + # 记录本次请求 + self.request_counts[client_ip].append(now) + + # 如果接近限制,添加警告头 + response = await call_next(request) + + remaining = self.burst_limit - len(self.request_counts[client_ip]) + response.headers["X-RateLimit-Limit"] = str(self.burst_limit) + response.headers["X-RateLimit-Remaining"] = str(max(0, remaining)) + response.headers["X-RateLimit-Reset"] = str(int((window_start + timedelta(minutes=1)).timestamp())) + + return response + + class RequestIDMiddleware(BaseHTTPMiddleware): """请求ID中间件""" diff --git a/backend/app/main.py b/backend/app/main.py index 3cd56f8..7f95c5f 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -97,6 +97,14 @@ app.add_middleware( allow_headers=["*"], ) +# 添加限流中间件 +from app.core.middleware import RateLimitMiddleware +app.add_middleware( + RateLimitMiddleware, + requests_per_minute=120, # 每分钟最大请求数 + burst_limit=200, # 突发请求限制 +) + # 健康检查端点 @app.get("/health") @@ -140,16 +148,60 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE return JSONResponse( status_code=422, content={ + "code": 422, + "message": "请求参数验证失败", "detail": exc.errors(), - "body": exc.body if hasattr(exc, 'body') else None, + }, + ) + + +# JSON 解析错误处理 +from json import JSONDecodeError +@app.exception_handler(JSONDecodeError) +async def json_decode_exception_handler(request: Request, exc: JSONDecodeError): + """处理 JSON 解析错误""" + logger.warning(f"JSON解析错误 [{request.method} {request.url.path}]: {exc}") + return JSONResponse( + status_code=400, + content={ + "code": 400, + "message": "请求体格式错误,需要有效的 JSON", + "detail": str(exc), + }, + ) + + +# HTTP 异常处理 +from fastapi import HTTPException +@app.exception_handler(HTTPException) +async def http_exception_handler(request: Request, exc: HTTPException): + """处理 HTTP 异常""" + return JSONResponse( + status_code=exc.status_code, + content={ + "code": exc.status_code, + "message": exc.detail, }, ) # 全局异常处理 @app.exception_handler(Exception) -async def global_exception_handler(request, exc): +async def global_exception_handler(request: Request, exc: Exception): """全局异常处理""" + error_msg = str(exc) + + # 检查是否是 Content-Type 相关错误 + if "Expecting value" in error_msg or "JSON" in error_msg.upper(): + logger.warning(f"请求体解析错误 [{request.method} {request.url.path}]: {error_msg}") + return JSONResponse( + status_code=400, + content={ + "code": 400, + "message": "请求体格式错误,请使用 application/json", + }, + ) + logger.error(f"未处理的异常: {exc}", exc_info=True) return JSONResponse( status_code=500,