Files
smart-project-pricing/后端服务/app/middleware/security.py
2026-01-31 21:33:06 +08:00

268 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""安全中间件
实现安全相关功能:
- 请求验证
- 速率限制
- 安全头设置
- 敏感数据保护
遵循瑞小美系统技术栈标准
"""
import logging
import time
from collections import defaultdict
from typing import Callable, Dict, Optional
import re
from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiter that guards the API against abuse.

    Request timestamps are kept in a dictionary on this middleware instance,
    keyed by client identifier and rate-limit bucket. A request is rejected
    with HTTP 429 once the bucket already holds the configured number of
    timestamps inside the rule's time window.

    NOTE(review): counters live in this instance's memory only; presumably each
    worker process keeps its own counts -- confirm against the deployment.
    """

    # Rate-limit rules: path pattern -> max `requests` per `window` seconds.
    # `*` in a pattern stands for exactly one path segment.
    RATE_LIMITS = {
        "default": {"requests": 100, "window": 60},  # default: 100 requests/minute
        "/api/v1/projects/*/generate-pricing": {"requests": 10, "window": 60},  # AI endpoint: 10/minute
        "/api/v1/projects/*/market-analysis": {"requests": 20, "window": 60},  # analysis endpoint: 20/minute
    }

    # Minimum number of seconds between two sweeps of idle buckets.
    _SWEEP_INTERVAL = 60

    def __init__(self, app, enabled: bool = True):
        """Wrap *app*; when *enabled* is false every request passes through."""
        super().__init__(app)
        self.enabled = enabled
        self._requests: Dict[str, list] = defaultdict(list)
        self._last_sweep = time.monotonic()

    def _get_client_id(self, request: Request) -> str:
        """Return the identifier used to group requests from one client.

        The first entry of ``X-Forwarded-For`` is preferred (reverse-proxy
        setups); otherwise the socket peer address is used, or ``"unknown"``
        when the request carries no client information.

        NOTE(review): ``X-Forwarded-For`` is supplied by the caller unless a
        trusted proxy overwrites it, so a client can pick its own identifier.
        """
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            first_hop = forwarded_for.split(",")[0].strip()
            # An empty first element (e.g. ", 10.0.0.1") must not become a
            # shared "" identifier; fall back to the peer address instead.
            if first_hop:
                return first_hop
        return request.client.host if request.client else "unknown"

    def _match_rule(self, path: str) -> tuple:
        """Return ``(bucket, limit_config)`` for *path*.

        For a path covered by a wildcard rule the bucket is the rule pattern,
        so every concrete path under that rule shares one counter. Paths that
        fall back to the default rule keep one bucket per concrete path.
        """
        for pattern, limit in self.RATE_LIMITS.items():
            if pattern == "default":
                continue
            # Escape the literal pieces and anchor the whole path so that
            # sibling routes sharing a prefix are not caught by the rule; a
            # single trailing slash is tolerated.
            regex = "[^/]+".join(re.escape(piece) for piece in pattern.split("*")) + "/?"
            if re.fullmatch(regex, path):
                return pattern, limit
        return path, self.RATE_LIMITS["default"]

    def _get_rate_limit(self, path: str) -> Dict:
        """Return the rate-limit configuration that applies to *path*."""
        return self._match_rule(path)[1]

    def _sweep(self, now: float) -> None:
        """Drop buckets with no timestamp inside the longest configured window.

        Without this, a bucket that is never requested again would stay in
        memory for the lifetime of the process.
        """
        if now - self._last_sweep < self._SWEEP_INTERVAL:
            return
        self._last_sweep = now
        horizon = max(config["window"] for config in self.RATE_LIMITS.values())
        # Timestamps are appended in increasing order, so the last is the newest.
        idle_keys = [
            key
            for key, stamps in self._requests.items()
            if not stamps or now - stamps[-1] >= horizon
        ]
        for key in idle_keys:
            del self._requests[key]

    def _is_rate_limited(self, client_id: str, path: str) -> bool:
        """Record a request and report whether it exceeds the limit.

        Returns ``True`` (without recording the request) when the bucket is
        already full, ``False`` after recording it otherwise.
        """
        bucket, limit_config = self._match_rule(path)
        max_requests = limit_config["requests"]
        window = limit_config["window"]
        # Monotonic clock: wall-clock adjustments must not distort the window.
        now = time.monotonic()
        self._sweep(now)
        key = f"{client_id}:{bucket}"
        # Keep only the timestamps that are still inside the window.
        recent = [stamp for stamp in self._requests.get(key, []) if now - stamp < window]
        if len(recent) >= max_requests:
            self._requests[key] = recent
            return True
        recent.append(now)
        self._requests[key] = recent
        return False

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Reject over-limit requests with HTTP 429, otherwise forward them."""
        if not self.enabled:
            return await call_next(request)
        client_id = self._get_client_id(request)
        path = request.url.path
        if self._is_rate_limited(client_id, path):
            logger.warning("Rate limit exceeded: %s -> %s", client_id, path)
            return JSONResponse(
                status_code=429,
                content={
                    "code": 40001,
                    "message": "请求过于频繁,请稍后再试",
                    "data": None
                }
            )
        return await call_next(request)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Middleware that stamps security-related HTTP headers on every response."""

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the downstream handler, then set the hardening headers."""
        response = await call_next(request)
        hardening_headers = (
            # Legacy browser XSS filter.
            # NOTE(review): this header is deprecated in current browsers --
            # confirm it is still wanted.
            ("X-XSS-Protection", "1; mode=block"),
            # Disable MIME type sniffing.
            ("X-Content-Type-Options", "nosniff"),
            # Clickjacking protection: never render inside a frame.
            ("X-Frame-Options", "DENY"),
            # Baseline content security policy.
            ("Content-Security-Policy", "default-src 'self'"),
        )
        for header_name, header_value in hardening_headers:
            response.headers[header_name] = header_value
        # Strict transport security (HTTPS deployments only) is left disabled:
        # response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        return response
class InputSanitizer:
    """Heuristic helpers for spotting and neutralising SQL-injection / XSS input.

    The checks are pattern based and only indicate that a value *looks*
    suspicious; they are not a substitute for parameterised queries or
    context-aware output encoding.
    """

    # Tokens treated as SQL-injection indicators. Alphanumeric entries are
    # matched as whole words, the remaining (symbol) entries as substrings.
    SQL_KEYWORDS = [
        "SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "UNION",
        "OR", "AND", "--", "/*", "*/", "EXEC", "EXECUTE",
    ]

    # Regular expressions (matched case-insensitively) treated as XSS indicators.
    XSS_PATTERNS = [
        r"<script.*?>",
        r"javascript:",
        r"on\w+\s*=",
        r"eval\s*\(",
    ]

    @classmethod
    def check_sql_injection(cls, value: str) -> bool:
        """Return True when *value* contains one of ``SQL_KEYWORDS``.

        Word keywords must appear as standalone words, so ordinary text such
        as "world" or "Android" is no longer reported because it happens to
        contain "OR" or "AND". Comment markers (``--``, ``/*``, ``*/``) still
        match anywhere in the value.
        """
        upper_value = value.upper()
        for keyword in cls.SQL_KEYWORDS:
            if keyword.isalnum():
                if re.search(rf"\b{re.escape(keyword)}\b", upper_value):
                    return True
            elif keyword in upper_value:
                return True
        return False

    @classmethod
    def check_xss(cls, value: str) -> bool:
        """Return True when *value* matches any of ``XSS_PATTERNS``."""
        return any(
            re.search(pattern, value, re.IGNORECASE) for pattern in cls.XSS_PATTERNS
        )

    @classmethod
    def sanitize(cls, value: str) -> str:
        """Strip HTML tags from *value* and HTML-escape what remains.

        Anything of the form ``<...>`` is removed first; afterwards ``&``,
        ``<``, ``>``, ``"`` and ``'`` are replaced by their HTML entities.
        """
        import html  # local import: this block must not rely on file-level edits

        without_tags = re.sub(r"<[^>]+>", "", value)
        # html.escape(quote=True) performs the same five replacements, in the
        # same order, as the previous hand-written chain.
        return html.escape(without_tags, quote=True)
def validate_request_body(data: dict, max_depth: int = 10) -> None:
    """Validate the nesting depth of a request body and flag suspicious strings.

    The structure is walked recursively through dict values and list items.
    String values that look like SQL injection or XSS are only logged as
    warnings; they do not cause the request to be rejected.

    Args:
        data: Request payload to inspect.
        max_depth: Maximum allowed nesting depth; the top-level object is
            depth 0.

    Raises:
        HTTPException: 400 when any value sits deeper than ``max_depth``.
    """

    def check_depth(obj, depth=0):
        if depth > max_depth:
            raise HTTPException(
                status_code=400,
                detail={"code": 10001, "message": "请求数据嵌套层级过深"}
            )
        if isinstance(obj, dict):
            for value in obj.values():
                check_depth(value, depth + 1)
        elif isinstance(obj, list):
            for item in obj:
                check_depth(item, depth + 1)
        elif isinstance(obj, str):
            # The value is attacker-controlled: log its repr() so embedded
            # newlines or control characters cannot forge extra log lines.
            if InputSanitizer.check_sql_injection(obj):
                logger.warning("Potential SQL injection detected: %r", obj[:100])
            if InputSanitizer.check_xss(obj):
                logger.warning("Potential XSS detected: %r", obj[:100])

    check_depth(data)
class AuditLogger:
    """Writes audit log entries for sensitive operations."""

    # Operations that must be audited: (HTTP method, path pattern) -> label.
    # `*` in a pattern stands for exactly one path segment.
    AUDIT_OPERATIONS = {
        ("POST", "/api/v1/pricing-plans"): "创建定价方案",
        ("PUT", "/api/v1/pricing-plans/*"): "更新定价方案",
        ("DELETE", "/api/v1/pricing-plans/*"): "删除定价方案",
        ("POST", "/api/v1/projects/*/generate-pricing"): "生成 AI 定价建议",
    }

    @classmethod
    def should_audit(cls, method: str, path: str) -> Optional[str]:
        """Return the audit label for *method* + *path*, or None if not audited.

        The whole path has to match the pattern (one optional trailing slash
        is tolerated), so nested routes such as
        ``/api/v1/pricing-plans/5/export`` are not reported under the label of
        a shorter pattern that merely shares their prefix.
        """
        for (audited_method, pattern), description in cls.AUDIT_OPERATIONS.items():
            if audited_method != method:
                continue
            # Escape the literal pieces so only `*` acts as a wildcard.
            regex = "[^/]+".join(re.escape(piece) for piece in pattern.split("*")) + "/?"
            if re.fullmatch(regex, path):
                return description
        return None

    @classmethod
    def log(
        cls,
        operation: str,
        user_id: Optional[int],
        request: "Request",
        response_code: int,
        details: Optional[Dict] = None
    ):
        """Emit one ``AUDIT:`` line at INFO level describing the operation.

        Args:
            operation: Label of the audited operation.
            user_id: Identifier of the acting user, if known.
            request: Incoming request; its method, path and client host are
                recorded.
            response_code: Status code returned for the request.
            details: Optional extra data; an empty dict is logged when omitted.
        """
        log_data = {
            "operation": operation,
            "user_id": user_id,
            "method": request.method,
            "path": str(request.url.path),
            "client_ip": request.client.host if request.client else "unknown",
            "response_code": response_code,
            "details": details or {},
        }
        logger.info("AUDIT: %s", log_data)