Files
012-kaopeilian/backend/app/core/middleware.py
yuliang_guo 79b55cfd12
All checks were successful
continuous-integration/drone/push Build is passing
fix: 修复权限提升漏洞和添加安全头
安全修复:
- 创建 UserSelfUpdate schema,禁止用户修改自己的 role 和 is_active
- /users/me 端点现在使用 UserSelfUpdate 而非 UserUpdate

安全增强:
- 添加 SecurityHeadersMiddleware 中间件
- X-Content-Type-Options: nosniff
- X-Frame-Options: DENY
- X-XSS-Protection: 1; mode=block
- Referrer-Policy: strict-origin-when-cross-origin
- Permissions-Policy: 禁用敏感功能
- Cache-Control: API响应不缓存
2026-01-31 10:57:41 +08:00

180 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
中间件定义
"""
import time
import uuid
from typing import Callable, Dict
from collections import defaultdict
from datetime import datetime, timedelta
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.logger import logger
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    API rate-limiting middleware.

    Keeps an in-memory list of request timestamps per client IP and answers
    with HTTP 429 once ``burst_limit`` requests have been recorded for that
    IP inside the last minute.

    NOTE(review): ``requests_per_minute`` is stored on the instance but
    ``dispatch`` only enforces ``burst_limit`` -- confirm which limit is
    intended. The counters live on this middleware instance, so they are
    presumably not shared between worker processes.
    """

    # Length of the sliding window over which requests are counted.
    _WINDOW = timedelta(minutes=1)
    # Exact paths that bypass rate limiting (health check and API docs).
    _EXEMPT_PATHS = frozenset({"/health", "/docs", "/openapi.json", "/redoc"})

    def __init__(self, app, requests_per_minute: int = 60, burst_limit: int = 100):
        """
        Args:
            app: The ASGI application being wrapped.
            requests_per_minute: Stored for callers; not consulted by ``dispatch``.
            burst_limit: Maximum requests per client IP inside one window.
        """
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        self.burst_limit = burst_limit  # burst request limit
        self.request_counts: Dict[str, list] = defaultdict(list)
        # When idle clients were last swept out of ``request_counts``.
        self._last_purge = datetime.now()

    def _get_client_ip(self, request: Request) -> str:
        """Return the client IP, preferring proxy headers over the socket peer.

        NOTE(review): ``X-Forwarded-For`` and ``X-Real-IP`` are supplied by the
        client unless a trusted reverse proxy overwrites them. Confirm the
        deployment does so; otherwise the limit can be evaded by spoofing.
        """
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            # The left-most entry is the originating client. Ignore it when it
            # is blank so malformed headers do not all share the "" bucket.
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop
        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            return real_ip
        return request.client.host if request.client else "unknown"

    def _clean_old_requests(self, ip: str, window_start: datetime):
        """Drop timestamps for ``ip`` that are not newer than ``window_start``.

        An IP left with no timestamps is removed from ``request_counts``
        entirely, and ``.get`` is used so that merely looking up an unknown IP
        does not create an empty bucket in the defaultdict.
        """
        recent = [t for t in self.request_counts.get(ip, ()) if t > window_start]
        if recent:
            self.request_counts[ip] = recent
        else:
            self.request_counts.pop(ip, None)

    def _purge_idle_clients(self, window_start: datetime) -> None:
        """Remove every client whose timestamps all fall outside the window.

        Without this sweep an IP that stops sending requests would keep its
        bucket forever, because per-IP cleanup only runs on that IP's next
        request.
        """
        idle = [
            ip
            for ip, stamps in self.request_counts.items()
            if all(t <= window_start for t in stamps)
        ]
        for ip in idle:
            del self.request_counts[ip]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Enforce the per-IP limit, then forward the request.

        Returns:
            A 429 ``JSONResponse`` when the client is over the limit; otherwise
            the downstream response with ``X-RateLimit-*`` headers added.
        """
        # Skip the health check, API docs and static files.
        path = request.url.path
        if path in self._EXEMPT_PATHS or path.startswith("/static/"):
            return await call_next(request)

        client_ip = self._get_client_ip(request)
        now = datetime.now()
        window_start = now - self._WINDOW

        # Sweep idle clients at most once per window to bound memory use.
        if now - self._last_purge >= self._WINDOW:
            self._purge_idle_clients(window_start)
            self._last_purge = now

        # Discard this client's expired records.
        self._clean_old_requests(client_ip, window_start)

        # Check the request count.
        request_count = len(self.request_counts.get(client_ip, ()))
        if request_count >= self.burst_limit:
            logger.warning(
                "请求被限流",
                client_ip=client_ip,
                request_count=request_count,
                path=request.url.path,
            )
            return JSONResponse(
                status_code=429,
                content={
                    "code": 429,
                    "message": "请求过于频繁,请稍后再试",
                    "retry_after": 60,
                },
                headers={"Retry-After": "60"},
            )

        # Record this request.
        self.request_counts[client_ip].append(now)

        response = await call_next(request)

        # Re-read the bucket without creating one: other requests may have
        # pruned or removed it while this request was being processed.
        recorded = self.request_counts.get(client_ip, [])
        remaining = self.burst_limit - len(recorded)
        # A slot frees up when the oldest recorded request leaves the window.
        # Timestamps are appended in arrival order, so index 0 is the oldest.
        oldest = recorded[0] if recorded else now
        reset_at = oldest + self._WINDOW
        response.headers["X-RateLimit-Limit"] = str(self.burst_limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
        response.headers["X-RateLimit-Reset"] = str(int(reset_at.timestamp()))
        return response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Security response-header middleware.

    Sets a fixed group of security-related HTTP headers on every response and,
    for paths under ``/api/``, headers that disable caching.
    """

    # Applied to every response, in this order.
    _ALWAYS = (
        # Prevent MIME type sniffing.
        ("X-Content-Type-Options", "nosniff"),
        # Prevent clickjacking.
        ("X-Frame-Options", "DENY"),
        # XSS filter (deprecated in modern browsers, still read by some old ones).
        ("X-XSS-Protection", "1; mode=block"),
        # Referrer policy.
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        # Permissions policy (turns off some sensitive features).
        ("Permissions-Policy", "geolocation=(), microphone=(), camera=()"),
    )
    # Applied only to API responses, which should not be cached.
    _API_ONLY = (
        ("Cache-Control", "no-store, no-cache, must-revalidate"),
        ("Pragma", "no-cache"),
    )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Forward the request, then add the security headers to the response."""
        response = await call_next(request)
        headers = response.headers
        for name, value in self._ALWAYS:
            headers[name] = value
        if request.url.path.startswith("/api/"):
            for name, value in self._API_ONLY:
                headers[name] = value
        return response
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Request ID middleware.

    Assigns a fresh UUID4 to every request, exposes it as
    ``request.state.request_id``, returns it in the ``X-Request-ID`` response
    header together with the handling time in ``X-Process-Time`` (seconds),
    and logs one line per completed request.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Tag the request with an ID, time the downstream call and log it."""
        # Generate the request ID and attach it to the request state.
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        # Use a monotonic clock: wall-clock time can jump (NTP, DST) and would
        # then yield wrong or even negative durations.
        started = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - started

        # Add the response headers.
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = str(process_time)

        # Log the request.
        logger.info(
            "HTTP请求",
            method=request.method,
            url=str(request.url),
            status_code=response.status_code,
            process_time=process_time,
            request_id=request_id,
        )
        return response
class GlobalContextMiddleware(BaseHTTPMiddleware):
    """Global context middleware.

    Carries a trace ID through the request: reuses the incoming ``X-Trace-ID``
    header when present, otherwise generates a UUID4, stores it on
    ``request.state.trace_id`` and echoes it in the response header.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Attach a trace ID to the request and mirror it on the response."""
        # Take the caller's trace ID if one was sent; otherwise mint a new one.
        incoming = request.headers.get("X-Trace-ID")
        if incoming is None:
            incoming = str(uuid.uuid4())
        request.state.trace_id = incoming

        # Handle the request, then return the trace ID to the caller.
        outgoing = await call_next(request)
        outgoing.headers["X-Trace-ID"] = incoming
        return outgoing