All checks were successful
continuous-integration/drone/push Build is passing
安全修复: - 创建 UserSelfUpdate schema,禁止用户修改自己的 role 和 is_active - /users/me 端点现在使用 UserSelfUpdate 而非 UserUpdate 安全增强: - 添加 SecurityHeadersMiddleware 中间件 - X-Content-Type-Options: nosniff - X-Frame-Options: DENY - X-XSS-Protection: 1; mode=block - Referrer-Policy: strict-origin-when-cross-origin - Permissions-Policy: 禁用敏感功能 - Cache-Control: API响应不缓存
180 lines
5.8 KiB
Python
180 lines
5.8 KiB
Python
"""
|
||
中间件定义
|
||
"""
|
||
import time
|
||
import uuid
|
||
from typing import Callable, Dict
|
||
from collections import defaultdict
|
||
from datetime import datetime, timedelta
|
||
|
||
from fastapi import Request, Response
|
||
from fastapi.responses import JSONResponse
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
|
||
from app.core.logger import logger
|
||
|
||
|
||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    API rate-limiting middleware.

    Throttles requests per client IP with a sliding one-minute window to
    mitigate abusive request floods. A client that already has
    ``burst_limit`` recorded requests inside the last 60 seconds receives an
    HTTP 429 response.

    NOTE(review): ``requests_per_minute`` is stored but NOT enforced; only
    ``burst_limit`` caps the per-window count. It is kept for interface
    compatibility — confirm the intended semantics before wiring it in.

    Counters live in this instance's in-memory dict, so they are not shared
    between separate processes.
    """

    # Length of the sliding window, in seconds.
    _WINDOW_SECONDS = 60.0

    # Exact paths that are never rate limited (health check and API docs).
    _EXEMPT_PATHS = frozenset({"/health", "/docs", "/openapi.json", "/redoc"})

    def __init__(
        self,
        app,
        requests_per_minute: int = 60,
        burst_limit: int = 100,
        trust_proxy_headers: bool = True,
        cleanup_interval: float = 60.0,
    ):
        """
        Args:
            app: The downstream ASGI application.
            requests_per_minute: Stored for compatibility; not enforced
                (see the class docstring).
            burst_limit: Maximum number of requests per client IP inside the
                sliding window.
            trust_proxy_headers: When True, the client IP is taken from
                ``X-Forwarded-For`` / ``X-Real-IP``. Only safe when the app is
                reachable exclusively through a proxy that sets these
                headers; otherwise clients can spoof them to evade the limit.
            cleanup_interval: Minimum number of seconds between sweeps that
                evict idle IPs from memory.
        """
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.burst_limit = burst_limit  # burst request limit
        self.trust_proxy_headers = trust_proxy_headers
        self.cleanup_interval = cleanup_interval
        # client IP -> time.monotonic() timestamps of requests inside the
        # current window, oldest first.
        self.request_counts: Dict[str, list] = defaultdict(list)
        self._last_sweep = time.monotonic()

    def _get_client_ip(self, request: Request) -> str:
        """Return the client IP used as the rate-limit key."""
        if self.trust_proxy_headers:
            # Prefer proxy-supplied headers. The left-most X-Forwarded-For
            # entry is the claimed originating client.
            forwarded = request.headers.get("X-Forwarded-For")
            if forwarded:
                first_hop = forwarded.split(",")[0].strip()
                # An empty first hop (e.g. ", 1.2.3.4") would lump unrelated
                # clients under the key ""; fall through instead.
                if first_hop:
                    return first_hop
            real_ip = request.headers.get("X-Real-IP")
            if real_ip:
                return real_ip
        return request.client.host if request.client else "unknown"

    def _clean_old_requests(self, ip: str, window_start: float) -> None:
        """Drop *ip*'s timestamps at or before *window_start*; forget idle IPs."""
        timestamps = self.request_counts.get(ip)
        if timestamps is None:
            return
        recent = [t for t in timestamps if t > window_start]
        if recent:
            self.request_counts[ip] = recent
        else:
            # Do not keep empty lists around: one per distinct IP would
            # otherwise accumulate forever.
            del self.request_counts[ip]

    def _sweep_idle_ips(self, now: float) -> None:
        """Evict IPs with no request inside the window, at most once per interval."""
        if now - self._last_sweep < self.cleanup_interval:
            return
        self._last_sweep = now
        window_start = now - self._WINDOW_SECONDS
        # Lists are appended in monotonic order, so the last item is the newest.
        idle = [
            ip
            for ip, timestamps in self.request_counts.items()
            if not timestamps or timestamps[-1] <= window_start
        ]
        for ip in idle:
            del self.request_counts[ip]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        path = request.url.path
        # Skip health checks, API docs and static files.
        if path in self._EXEMPT_PATHS or path.startswith("/static/"):
            return await call_next(request)

        client_ip = self._get_client_ip(request)
        # Monotonic clock: immune to wall-clock jumps (NTP, DST) that would
        # otherwise leave stale entries "in the future" and block an IP.
        now = time.monotonic()
        window_start = now - self._WINDOW_SECONDS

        # Purge expired records (globally, periodically; and for this IP).
        self._sweep_idle_ips(now)
        self._clean_old_requests(client_ip, window_start)

        # Check the request count without creating an entry for the IP.
        request_count = len(self.request_counts.get(client_ip, ()))

        if request_count >= self.burst_limit:
            logger.warning(
                "请求被限流",
                client_ip=client_ip,
                request_count=request_count,
                path=request.url.path,
            )
            return JSONResponse(
                status_code=429,
                content={
                    "code": 429,
                    "message": "请求过于频繁,请稍后再试",
                    "retry_after": 60,
                },
                headers={"Retry-After": "60"}
            )

        # Record this request before handing it downstream.
        self.request_counts[client_ip].append(now)

        response = await call_next(request)

        # Re-read after the await: concurrent requests from the same IP (or a
        # sweep) may have pruned or replaced the list in the meantime.
        timestamps = self.request_counts.get(client_ip, [])
        remaining = self.burst_limit - len(timestamps)
        # In a sliding window a slot frees up when the oldest recorded request
        # expires; convert that monotonic instant into a Unix timestamp.
        oldest = timestamps[0] if timestamps else now
        seconds_until_reset = max(
            0.0, oldest + self._WINDOW_SECONDS - time.monotonic()
        )
        reset_at = time.time() + seconds_until_reset

        response.headers["X-RateLimit-Limit"] = str(self.burst_limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
        response.headers["X-RateLimit-Reset"] = str(int(reset_at))

        return response
|
||
|
||
|
||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Security response-header middleware.

    Stamps a fixed set of hardening headers onto every response and, for
    request paths under ``/api/``, additionally disables caching. Any value
    a handler already set for one of these headers is overwritten.
    """

    # (header, value) pairs applied to every response, in this order.
    _ALWAYS_HEADERS = (
        # Stop browsers from MIME-sniffing the declared content type.
        ("X-Content-Type-Options", "nosniff"),
        # Forbid framing entirely to prevent clickjacking.
        ("X-Frame-Options", "DENY"),
        # Legacy XSS filter (deprecated in modern browsers, still honoured by
        # some older ones).
        ("X-XSS-Protection", "1; mode=block"),
        # Referrer policy.
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        # Permissions policy: disable geolocation, microphone and camera.
        ("Permissions-Policy", "geolocation=(), microphone=(), camera=()"),
    )

    # Extra pairs for API responses, which must not be cached.
    _NO_CACHE_HEADERS = (
        ("Cache-Control", "no-store, no-cache, must-revalidate"),
        ("Pragma", "no-cache"),
    )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        response = await call_next(request)

        pending = list(self._ALWAYS_HEADERS)
        if request.url.path.startswith("/api/"):
            pending.extend(self._NO_CACHE_HEADERS)

        target = response.headers
        for header_name, header_value in pending:
            target[header_name] = header_value

        return response
|
||
|
||
|
||
class RequestIDMiddleware(BaseHTTPMiddleware):
    """
    Request-ID middleware.

    Assigns every request a fresh UUID4, exposed as ``request.state.request_id``
    and as the ``X-Request-ID`` response header. It also measures handling time,
    reported in seconds via ``X-Process-Time``, and emits one info-level log
    entry per completed request.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Always generate a new ID; an incoming X-Request-ID header is ignored.
        request_id = str(uuid.uuid4())

        # Make the ID available to downstream handlers.
        request.state.request_id = request_id

        # perf_counter is monotonic and high-resolution. time.time() follows
        # the wall clock and can jump, producing wrong or negative durations.
        start_time = time.perf_counter()

        response = await call_next(request)

        process_time = time.perf_counter() - start_time

        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = str(process_time)

        # Access log for the request.
        logger.info(
            "HTTP请求",
            method=request.method,
            url=str(request.url),
            status_code=response.status_code,
            process_time=process_time,
            request_id=request_id,
        )

        return response
|
||
|
||
|
||
class GlobalContextMiddleware(BaseHTTPMiddleware):
    """
    Global-context middleware.

    Propagates a trace ID for distributed tracing. It reuses the incoming
    ``X-Trace-ID`` header when that header is non-empty, and otherwise
    generates a UUID4. The value is stored on ``request.state.trace_id`` and
    echoed back in the ``X-Trace-ID`` response header.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # `or` instead of a get() default: the UUID is generated only when
        # needed. An empty header value is also replaced instead of being
        # propagated as a blank trace ID.
        trace_id = request.headers.get("X-Trace-ID") or str(uuid.uuid4())
        request.state.trace_id = trace_id

        response = await call_next(request)

        # Echo the trace ID so callers can correlate the response.
        response.headers["X-Trace-ID"] = trace_id

        return response
|