"""安全中间件 实现安全相关功能: - 请求验证 - 速率限制 - 安全头设置 - 敏感数据保护 遵循瑞小美系统技术栈标准 """ import logging import time from collections import defaultdict from typing import Callable, Dict, Optional import re from fastapi import Request, Response, HTTPException from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware logger = logging.getLogger(__name__) class RateLimitMiddleware(BaseHTTPMiddleware): """速率限制中间件 防止 API 滥用,保护服务稳定性 """ # 速率限制配置 RATE_LIMITS = { "default": {"requests": 100, "window": 60}, # 默认:100 次/分钟 "/api/v1/projects/*/generate-pricing": {"requests": 10, "window": 60}, # AI 接口:10 次/分钟 "/api/v1/projects/*/market-analysis": {"requests": 20, "window": 60}, # 分析接口:20 次/分钟 } def __init__(self, app, enabled: bool = True): super().__init__(app) self.enabled = enabled self._requests: Dict[str, list] = defaultdict(list) def _get_client_id(self, request: Request) -> str: """获取客户端标识""" # 优先使用 X-Forwarded-For(反向代理场景) forwarded_for = request.headers.get("X-Forwarded-For") if forwarded_for: return forwarded_for.split(",")[0].strip() # 使用客户端 IP return request.client.host if request.client else "unknown" def _get_rate_limit(self, path: str) -> Dict: """获取路径的速率限制配置""" for pattern, limit in self.RATE_LIMITS.items(): if pattern == "default": continue # 简单的路径匹配(* 匹配任意字符) regex = pattern.replace("*", "[^/]+") if re.match(regex, path): return limit return self.RATE_LIMITS["default"] def _is_rate_limited(self, client_id: str, path: str) -> bool: """检查是否超过速率限制""" limit_config = self._get_rate_limit(path) requests = limit_config["requests"] window = limit_config["window"] now = time.time() key = f"{client_id}:{path}" # 清理过期记录 self._requests[key] = [t for t in self._requests[key] if now - t < window] # 检查是否超限 if len(self._requests[key]) >= requests: return True # 记录请求 self._requests[key].append(now) return False async def dispatch(self, request: Request, call_next: Callable) -> Response: if not self.enabled: return await call_next(request) client_id = self._get_client_id(request) path = request.url.path if self._is_rate_limited(client_id, path): logger.warning(f"Rate limit exceeded: {client_id} -> {path}") return JSONResponse( status_code=429, content={ "code": 40001, "message": "请求过于频繁,请稍后再试", "data": None } ) return await call_next(request) class SecurityHeadersMiddleware(BaseHTTPMiddleware): """安全响应头中间件 添加安全相关的 HTTP 响应头 """ async def dispatch(self, request: Request, call_next: Callable) -> Response: response = await call_next(request) # 防止 XSS 攻击 response.headers["X-XSS-Protection"] = "1; mode=block" # 防止 MIME 类型嗅探 response.headers["X-Content-Type-Options"] = "nosniff" # 点击劫持保护 response.headers["X-Frame-Options"] = "DENY" # 内容安全策略(基础版) response.headers["Content-Security-Policy"] = "default-src 'self'" # 严格传输安全(仅在 HTTPS 环境) # response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" return response class InputSanitizer: """输入清理工具 防止 SQL 注入、XSS 等攻击 """ # SQL 注入关键字 SQL_KEYWORDS = [ "SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "UNION", "OR", "AND", "--", "/*", "*/", "EXEC", "EXECUTE" ] # XSS 危险模式 XSS_PATTERNS = [ r"", r"javascript:", r"on\w+\s*=", r"eval\s*\(", ] @classmethod def check_sql_injection(cls, value: str) -> bool: """检查是否包含 SQL 注入特征""" upper_value = value.upper() for keyword in cls.SQL_KEYWORDS: if keyword in upper_value: return True return False @classmethod def check_xss(cls, value: str) -> bool: """检查是否包含 XSS 特征""" for pattern in cls.XSS_PATTERNS: if re.search(pattern, value, re.IGNORECASE): return True return False @classmethod def sanitize(cls, value: str) -> str: """清理输入值 移除潜在危险字符 """ # 移除 HTML 标签 value = re.sub(r'<[^>]+>', '', value) # 转义特殊字符 value = value.replace("&", "&") value = value.replace("<", "<") value = value.replace(">", ">") value = value.replace('"', """) value = value.replace("'", "'") return value def validate_request_body(data: dict, max_depth: int = 10) -> None: """验证请求体 检查数据深度和内容安全 Args: data: 请求数据 max_depth: 最大嵌套深度 Raises: HTTPException: 验证失败 """ def check_depth(obj, depth=0): if depth > max_depth: raise HTTPException( status_code=400, detail={"code": 10001, "message": "请求数据嵌套层级过深"} ) if isinstance(obj, dict): for value in obj.values(): check_depth(value, depth + 1) elif isinstance(obj, list): for item in obj: check_depth(item, depth + 1) elif isinstance(obj, str): if InputSanitizer.check_sql_injection(obj): logger.warning(f"Potential SQL injection detected: {obj[:100]}") if InputSanitizer.check_xss(obj): logger.warning(f"Potential XSS detected: {obj[:100]}") check_depth(data) class AuditLogger: """审计日志记录器 记录敏感操作的审计日志 """ # 需要审计的操作 AUDIT_OPERATIONS = { ("POST", "/api/v1/pricing-plans"): "创建定价方案", ("PUT", "/api/v1/pricing-plans/*"): "更新定价方案", ("DELETE", "/api/v1/pricing-plans/*"): "删除定价方案", ("POST", "/api/v1/projects/*/generate-pricing"): "生成 AI 定价建议", } @classmethod def should_audit(cls, method: str, path: str) -> Optional[str]: """检查是否需要审计""" for (m, p), desc in cls.AUDIT_OPERATIONS.items(): if m != method: continue regex = p.replace("*", "[^/]+") if re.match(regex, path): return desc return None @classmethod def log( cls, operation: str, user_id: Optional[int], request: Request, response_code: int, details: Optional[Dict] = None ): """记录审计日志""" log_data = { "operation": operation, "user_id": user_id, "method": request.method, "path": str(request.url.path), "client_ip": request.client.host if request.client else "unknown", "response_code": response_code, "details": details or {}, } logger.info(f"AUDIT: {log_data}")