268 lines
8.0 KiB
Python
268 lines
8.0 KiB
Python
"""安全中间件
|
||
|
||
实现安全相关功能:
|
||
- 请求验证
|
||
- 速率限制
|
||
- 安全头设置
|
||
- 敏感数据保护
|
||
|
||
遵循瑞小美系统技术栈标准
|
||
"""
|
||
|
||
import logging
|
||
import time
|
||
from collections import defaultdict
|
||
from typing import Callable, Dict, Optional
|
||
import re
|
||
|
||
from fastapi import Request, Response, HTTPException
|
||
from fastapi.responses import JSONResponse
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||
"""速率限制中间件
|
||
|
||
防止 API 滥用,保护服务稳定性
|
||
"""
|
||
|
||
# 速率限制配置
|
||
RATE_LIMITS = {
|
||
"default": {"requests": 100, "window": 60}, # 默认:100 次/分钟
|
||
"/api/v1/projects/*/generate-pricing": {"requests": 10, "window": 60}, # AI 接口:10 次/分钟
|
||
"/api/v1/projects/*/market-analysis": {"requests": 20, "window": 60}, # 分析接口:20 次/分钟
|
||
}
|
||
|
||
def __init__(self, app, enabled: bool = True):
|
||
super().__init__(app)
|
||
self.enabled = enabled
|
||
self._requests: Dict[str, list] = defaultdict(list)
|
||
|
||
def _get_client_id(self, request: Request) -> str:
|
||
"""获取客户端标识"""
|
||
# 优先使用 X-Forwarded-For(反向代理场景)
|
||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||
if forwarded_for:
|
||
return forwarded_for.split(",")[0].strip()
|
||
|
||
# 使用客户端 IP
|
||
return request.client.host if request.client else "unknown"
|
||
|
||
def _get_rate_limit(self, path: str) -> Dict:
|
||
"""获取路径的速率限制配置"""
|
||
for pattern, limit in self.RATE_LIMITS.items():
|
||
if pattern == "default":
|
||
continue
|
||
# 简单的路径匹配(* 匹配任意字符)
|
||
regex = pattern.replace("*", "[^/]+")
|
||
if re.match(regex, path):
|
||
return limit
|
||
|
||
return self.RATE_LIMITS["default"]
|
||
|
||
def _is_rate_limited(self, client_id: str, path: str) -> bool:
|
||
"""检查是否超过速率限制"""
|
||
limit_config = self._get_rate_limit(path)
|
||
requests = limit_config["requests"]
|
||
window = limit_config["window"]
|
||
|
||
now = time.time()
|
||
key = f"{client_id}:{path}"
|
||
|
||
# 清理过期记录
|
||
self._requests[key] = [t for t in self._requests[key] if now - t < window]
|
||
|
||
# 检查是否超限
|
||
if len(self._requests[key]) >= requests:
|
||
return True
|
||
|
||
# 记录请求
|
||
self._requests[key].append(now)
|
||
return False
|
||
|
||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||
if not self.enabled:
|
||
return await call_next(request)
|
||
|
||
client_id = self._get_client_id(request)
|
||
path = request.url.path
|
||
|
||
if self._is_rate_limited(client_id, path):
|
||
logger.warning(f"Rate limit exceeded: {client_id} -> {path}")
|
||
return JSONResponse(
|
||
status_code=429,
|
||
content={
|
||
"code": 40001,
|
||
"message": "请求过于频繁,请稍后再试",
|
||
"data": None
|
||
}
|
||
)
|
||
|
||
return await call_next(request)
|
||
|
||
|
||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||
"""安全响应头中间件
|
||
|
||
添加安全相关的 HTTP 响应头
|
||
"""
|
||
|
||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||
response = await call_next(request)
|
||
|
||
# 防止 XSS 攻击
|
||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||
|
||
# 防止 MIME 类型嗅探
|
||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||
|
||
# 点击劫持保护
|
||
response.headers["X-Frame-Options"] = "DENY"
|
||
|
||
# 内容安全策略(基础版)
|
||
response.headers["Content-Security-Policy"] = "default-src 'self'"
|
||
|
||
# 严格传输安全(仅在 HTTPS 环境)
|
||
# response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||
|
||
return response
|
||
|
||
|
||
class InputSanitizer:
|
||
"""输入清理工具
|
||
|
||
防止 SQL 注入、XSS 等攻击
|
||
"""
|
||
|
||
# SQL 注入关键字
|
||
SQL_KEYWORDS = [
|
||
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "UNION",
|
||
"OR", "AND", "--", "/*", "*/", "EXEC", "EXECUTE"
|
||
]
|
||
|
||
# XSS 危险模式
|
||
XSS_PATTERNS = [
|
||
r"<script.*?>",
|
||
r"javascript:",
|
||
r"on\w+\s*=",
|
||
r"eval\s*\(",
|
||
]
|
||
|
||
@classmethod
|
||
def check_sql_injection(cls, value: str) -> bool:
|
||
"""检查是否包含 SQL 注入特征"""
|
||
upper_value = value.upper()
|
||
for keyword in cls.SQL_KEYWORDS:
|
||
if keyword in upper_value:
|
||
return True
|
||
return False
|
||
|
||
@classmethod
|
||
def check_xss(cls, value: str) -> bool:
|
||
"""检查是否包含 XSS 特征"""
|
||
for pattern in cls.XSS_PATTERNS:
|
||
if re.search(pattern, value, re.IGNORECASE):
|
||
return True
|
||
return False
|
||
|
||
@classmethod
|
||
def sanitize(cls, value: str) -> str:
|
||
"""清理输入值
|
||
|
||
移除潜在危险字符
|
||
"""
|
||
# 移除 HTML 标签
|
||
value = re.sub(r'<[^>]+>', '', value)
|
||
|
||
# 转义特殊字符
|
||
value = value.replace("&", "&")
|
||
value = value.replace("<", "<")
|
||
value = value.replace(">", ">")
|
||
value = value.replace('"', """)
|
||
value = value.replace("'", "'")
|
||
|
||
return value
|
||
|
||
|
||
def validate_request_body(data: dict, max_depth: int = 10) -> None:
|
||
"""验证请求体
|
||
|
||
检查数据深度和内容安全
|
||
|
||
Args:
|
||
data: 请求数据
|
||
max_depth: 最大嵌套深度
|
||
|
||
Raises:
|
||
HTTPException: 验证失败
|
||
"""
|
||
def check_depth(obj, depth=0):
|
||
if depth > max_depth:
|
||
raise HTTPException(
|
||
status_code=400,
|
||
detail={"code": 10001, "message": "请求数据嵌套层级过深"}
|
||
)
|
||
|
||
if isinstance(obj, dict):
|
||
for value in obj.values():
|
||
check_depth(value, depth + 1)
|
||
elif isinstance(obj, list):
|
||
for item in obj:
|
||
check_depth(item, depth + 1)
|
||
elif isinstance(obj, str):
|
||
if InputSanitizer.check_sql_injection(obj):
|
||
logger.warning(f"Potential SQL injection detected: {obj[:100]}")
|
||
if InputSanitizer.check_xss(obj):
|
||
logger.warning(f"Potential XSS detected: {obj[:100]}")
|
||
|
||
check_depth(data)
|
||
|
||
|
||
class AuditLogger:
|
||
"""审计日志记录器
|
||
|
||
记录敏感操作的审计日志
|
||
"""
|
||
|
||
# 需要审计的操作
|
||
AUDIT_OPERATIONS = {
|
||
("POST", "/api/v1/pricing-plans"): "创建定价方案",
|
||
("PUT", "/api/v1/pricing-plans/*"): "更新定价方案",
|
||
("DELETE", "/api/v1/pricing-plans/*"): "删除定价方案",
|
||
("POST", "/api/v1/projects/*/generate-pricing"): "生成 AI 定价建议",
|
||
}
|
||
|
||
@classmethod
|
||
def should_audit(cls, method: str, path: str) -> Optional[str]:
|
||
"""检查是否需要审计"""
|
||
for (m, p), desc in cls.AUDIT_OPERATIONS.items():
|
||
if m != method:
|
||
continue
|
||
regex = p.replace("*", "[^/]+")
|
||
if re.match(regex, path):
|
||
return desc
|
||
return None
|
||
|
||
@classmethod
|
||
def log(
|
||
cls,
|
||
operation: str,
|
||
user_id: Optional[int],
|
||
request: Request,
|
||
response_code: int,
|
||
details: Optional[Dict] = None
|
||
):
|
||
"""记录审计日志"""
|
||
log_data = {
|
||
"operation": operation,
|
||
"user_id": user_id,
|
||
"method": request.method,
|
||
"path": str(request.url.path),
|
||
"client_ip": request.client.host if request.client else "unknown",
|
||
"response_code": response_code,
|
||
"details": details or {},
|
||
}
|
||
logger.info(f"AUDIT: {log_data}")
|