Initial commit: 智能项目定价模型
This commit is contained in:
20
后端服务/app/middleware/__init__.py
Normal file
20
后端服务/app/middleware/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""中间件模块"""
|
||||
from .performance import PerformanceMiddleware
|
||||
from .cache import ResponseCacheMiddleware
|
||||
from .security import (
|
||||
RateLimitMiddleware,
|
||||
SecurityHeadersMiddleware,
|
||||
InputSanitizer,
|
||||
validate_request_body,
|
||||
AuditLogger,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PerformanceMiddleware",
|
||||
"ResponseCacheMiddleware",
|
||||
"RateLimitMiddleware",
|
||||
"SecurityHeadersMiddleware",
|
||||
"InputSanitizer",
|
||||
"validate_request_body",
|
||||
"AuditLogger",
|
||||
]
|
||||
121
后端服务/app/middleware/cache.py
Normal file
121
后端服务/app/middleware/cache.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""响应缓存中间件
|
||||
|
||||
对 GET 请求的响应进行缓存,减少数据库查询
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Callable, Optional, Set
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.services.cache_service import get_cache, CacheNamespace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponseCacheMiddleware(BaseHTTPMiddleware):
|
||||
"""响应缓存中间件
|
||||
|
||||
对符合条件的 GET 请求进行响应缓存
|
||||
"""
|
||||
|
||||
# 需要缓存的路径前缀和 TTL 配置
|
||||
CACHE_CONFIG = {
|
||||
"/api/v1/categories": {"ttl": 300, "namespace": CacheNamespace.CATEGORIES},
|
||||
"/api/v1/materials": {"ttl": 300, "namespace": CacheNamespace.MATERIALS},
|
||||
"/api/v1/equipments": {"ttl": 300, "namespace": CacheNamespace.EQUIPMENTS},
|
||||
"/api/v1/staff-levels": {"ttl": 300, "namespace": CacheNamespace.STAFF_LEVELS},
|
||||
}
|
||||
|
||||
# 不缓存的路径(精确匹配)
|
||||
NO_CACHE_PATHS: Set[str] = {
|
||||
"/health",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
}
|
||||
|
||||
def _should_cache(self, request: Request) -> Optional[dict]:
|
||||
"""判断是否应该缓存"""
|
||||
# 只缓存 GET 请求
|
||||
if request.method != "GET":
|
||||
return None
|
||||
|
||||
path = request.url.path
|
||||
|
||||
# 排除不缓存的路径
|
||||
if path in self.NO_CACHE_PATHS:
|
||||
return None
|
||||
|
||||
# 检查是否在缓存配置中
|
||||
for prefix, config in self.CACHE_CONFIG.items():
|
||||
if path.startswith(prefix):
|
||||
return config
|
||||
|
||||
return None
|
||||
|
||||
def _generate_cache_key(self, request: Request) -> str:
|
||||
"""生成缓存键"""
|
||||
# 包含路径和查询参数
|
||||
key_parts = [
|
||||
request.method,
|
||||
request.url.path,
|
||||
str(sorted(request.query_params.items())),
|
||||
]
|
||||
key_str = "|".join(key_parts)
|
||||
return hashlib.md5(key_str.encode()).hexdigest()
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
cache_config = self._should_cache(request)
|
||||
|
||||
if not cache_config:
|
||||
return await call_next(request)
|
||||
|
||||
# 生成缓存键
|
||||
cache_key = self._generate_cache_key(request)
|
||||
cache = get_cache(cache_config["namespace"])
|
||||
|
||||
# 尝试从缓存获取
|
||||
cached_data = cache.get(cache_key)
|
||||
if cached_data is not None:
|
||||
logger.debug(f"Cache hit: {request.url.path}")
|
||||
response = Response(
|
||||
content=cached_data["content"],
|
||||
status_code=cached_data["status_code"],
|
||||
headers=dict(cached_data["headers"]),
|
||||
media_type="application/json",
|
||||
)
|
||||
response.headers["X-Cache"] = "HIT"
|
||||
return response
|
||||
|
||||
# 执行请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 只缓存成功的响应
|
||||
if response.status_code == 200:
|
||||
# 读取响应体
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk
|
||||
|
||||
# 保存到缓存
|
||||
cache_data = {
|
||||
"content": body,
|
||||
"status_code": response.status_code,
|
||||
"headers": dict(response.headers),
|
||||
}
|
||||
cache.set(cache_key, cache_data, cache_config["ttl"])
|
||||
|
||||
# 重新构建响应
|
||||
response = Response(
|
||||
content=body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type="application/json",
|
||||
)
|
||||
response.headers["X-Cache"] = "MISS"
|
||||
|
||||
return response
|
||||
50
后端服务/app/middleware/performance.py
Normal file
50
后端服务/app/middleware/performance.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""性能监控中间件
|
||||
|
||||
记录请求响应时间,用于性能分析和优化
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PerformanceMiddleware(BaseHTTPMiddleware):
|
||||
"""性能监控中间件
|
||||
|
||||
记录每个请求的响应时间,并在响应头中添加 X-Response-Time
|
||||
"""
|
||||
|
||||
# 慢请求阈值(毫秒)
|
||||
SLOW_REQUEST_THRESHOLD = 1000
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
start_time = time.time()
|
||||
|
||||
# 执行请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 计算响应时间
|
||||
process_time = (time.time() - start_time) * 1000
|
||||
|
||||
# 添加响应头
|
||||
response.headers["X-Response-Time"] = f"{process_time:.2f}ms"
|
||||
|
||||
# 记录慢请求
|
||||
if process_time > self.SLOW_REQUEST_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Slow request: {request.method} {request.url.path} "
|
||||
f"took {process_time:.2f}ms"
|
||||
)
|
||||
|
||||
# 记录请求日志(开发环境)
|
||||
logger.debug(
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"{response.status_code} - {process_time:.2f}ms"
|
||||
)
|
||||
|
||||
return response
|
||||
267
后端服务/app/middleware/security.py
Normal file
267
后端服务/app/middleware/security.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""安全中间件
|
||||
|
||||
实现安全相关功能:
|
||||
- 请求验证
|
||||
- 速率限制
|
||||
- 安全头设置
|
||||
- 敏感数据保护
|
||||
|
||||
遵循瑞小美系统技术栈标准
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, Optional
|
||||
import re
|
||||
|
||||
from fastapi import Request, Response, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""速率限制中间件
|
||||
|
||||
防止 API 滥用,保护服务稳定性
|
||||
"""
|
||||
|
||||
# 速率限制配置
|
||||
RATE_LIMITS = {
|
||||
"default": {"requests": 100, "window": 60}, # 默认:100 次/分钟
|
||||
"/api/v1/projects/*/generate-pricing": {"requests": 10, "window": 60}, # AI 接口:10 次/分钟
|
||||
"/api/v1/projects/*/market-analysis": {"requests": 20, "window": 60}, # 分析接口:20 次/分钟
|
||||
}
|
||||
|
||||
def __init__(self, app, enabled: bool = True):
|
||||
super().__init__(app)
|
||||
self.enabled = enabled
|
||||
self._requests: Dict[str, list] = defaultdict(list)
|
||||
|
||||
def _get_client_id(self, request: Request) -> str:
|
||||
"""获取客户端标识"""
|
||||
# 优先使用 X-Forwarded-For(反向代理场景)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
# 使用客户端 IP
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
def _get_rate_limit(self, path: str) -> Dict:
|
||||
"""获取路径的速率限制配置"""
|
||||
for pattern, limit in self.RATE_LIMITS.items():
|
||||
if pattern == "default":
|
||||
continue
|
||||
# 简单的路径匹配(* 匹配任意字符)
|
||||
regex = pattern.replace("*", "[^/]+")
|
||||
if re.match(regex, path):
|
||||
return limit
|
||||
|
||||
return self.RATE_LIMITS["default"]
|
||||
|
||||
def _is_rate_limited(self, client_id: str, path: str) -> bool:
|
||||
"""检查是否超过速率限制"""
|
||||
limit_config = self._get_rate_limit(path)
|
||||
requests = limit_config["requests"]
|
||||
window = limit_config["window"]
|
||||
|
||||
now = time.time()
|
||||
key = f"{client_id}:{path}"
|
||||
|
||||
# 清理过期记录
|
||||
self._requests[key] = [t for t in self._requests[key] if now - t < window]
|
||||
|
||||
# 检查是否超限
|
||||
if len(self._requests[key]) >= requests:
|
||||
return True
|
||||
|
||||
# 记录请求
|
||||
self._requests[key].append(now)
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
client_id = self._get_client_id(request)
|
||||
path = request.url.path
|
||||
|
||||
if self._is_rate_limited(client_id, path):
|
||||
logger.warning(f"Rate limit exceeded: {client_id} -> {path}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"code": 40001,
|
||||
"message": "请求过于频繁,请稍后再试",
|
||||
"data": None
|
||||
}
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""安全响应头中间件
|
||||
|
||||
添加安全相关的 HTTP 响应头
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
response = await call_next(request)
|
||||
|
||||
# 防止 XSS 攻击
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
|
||||
# 防止 MIME 类型嗅探
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
|
||||
# 点击劫持保护
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
|
||||
# 内容安全策略(基础版)
|
||||
response.headers["Content-Security-Policy"] = "default-src 'self'"
|
||||
|
||||
# 严格传输安全(仅在 HTTPS 环境)
|
||||
# response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class InputSanitizer:
|
||||
"""输入清理工具
|
||||
|
||||
防止 SQL 注入、XSS 等攻击
|
||||
"""
|
||||
|
||||
# SQL 注入关键字
|
||||
SQL_KEYWORDS = [
|
||||
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "UNION",
|
||||
"OR", "AND", "--", "/*", "*/", "EXEC", "EXECUTE"
|
||||
]
|
||||
|
||||
# XSS 危险模式
|
||||
XSS_PATTERNS = [
|
||||
r"<script.*?>",
|
||||
r"javascript:",
|
||||
r"on\w+\s*=",
|
||||
r"eval\s*\(",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def check_sql_injection(cls, value: str) -> bool:
|
||||
"""检查是否包含 SQL 注入特征"""
|
||||
upper_value = value.upper()
|
||||
for keyword in cls.SQL_KEYWORDS:
|
||||
if keyword in upper_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def check_xss(cls, value: str) -> bool:
|
||||
"""检查是否包含 XSS 特征"""
|
||||
for pattern in cls.XSS_PATTERNS:
|
||||
if re.search(pattern, value, re.IGNORECASE):
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def sanitize(cls, value: str) -> str:
|
||||
"""清理输入值
|
||||
|
||||
移除潜在危险字符
|
||||
"""
|
||||
# 移除 HTML 标签
|
||||
value = re.sub(r'<[^>]+>', '', value)
|
||||
|
||||
# 转义特殊字符
|
||||
value = value.replace("&", "&")
|
||||
value = value.replace("<", "<")
|
||||
value = value.replace(">", ">")
|
||||
value = value.replace('"', """)
|
||||
value = value.replace("'", "'")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def validate_request_body(data: dict, max_depth: int = 10) -> None:
|
||||
"""验证请求体
|
||||
|
||||
检查数据深度和内容安全
|
||||
|
||||
Args:
|
||||
data: 请求数据
|
||||
max_depth: 最大嵌套深度
|
||||
|
||||
Raises:
|
||||
HTTPException: 验证失败
|
||||
"""
|
||||
def check_depth(obj, depth=0):
|
||||
if depth > max_depth:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": 10001, "message": "请求数据嵌套层级过深"}
|
||||
)
|
||||
|
||||
if isinstance(obj, dict):
|
||||
for value in obj.values():
|
||||
check_depth(value, depth + 1)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
check_depth(item, depth + 1)
|
||||
elif isinstance(obj, str):
|
||||
if InputSanitizer.check_sql_injection(obj):
|
||||
logger.warning(f"Potential SQL injection detected: {obj[:100]}")
|
||||
if InputSanitizer.check_xss(obj):
|
||||
logger.warning(f"Potential XSS detected: {obj[:100]}")
|
||||
|
||||
check_depth(data)
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""审计日志记录器
|
||||
|
||||
记录敏感操作的审计日志
|
||||
"""
|
||||
|
||||
# 需要审计的操作
|
||||
AUDIT_OPERATIONS = {
|
||||
("POST", "/api/v1/pricing-plans"): "创建定价方案",
|
||||
("PUT", "/api/v1/pricing-plans/*"): "更新定价方案",
|
||||
("DELETE", "/api/v1/pricing-plans/*"): "删除定价方案",
|
||||
("POST", "/api/v1/projects/*/generate-pricing"): "生成 AI 定价建议",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def should_audit(cls, method: str, path: str) -> Optional[str]:
|
||||
"""检查是否需要审计"""
|
||||
for (m, p), desc in cls.AUDIT_OPERATIONS.items():
|
||||
if m != method:
|
||||
continue
|
||||
regex = p.replace("*", "[^/]+")
|
||||
if re.match(regex, path):
|
||||
return desc
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def log(
|
||||
cls,
|
||||
operation: str,
|
||||
user_id: Optional[int],
|
||||
request: Request,
|
||||
response_code: int,
|
||||
details: Optional[Dict] = None
|
||||
):
|
||||
"""记录审计日志"""
|
||||
log_data = {
|
||||
"operation": operation,
|
||||
"user_id": user_id,
|
||||
"method": request.method,
|
||||
"path": str(request.url.path),
|
||||
"client_ip": request.client.host if request.client else "unknown",
|
||||
"response_code": response_code,
|
||||
"details": details or {},
|
||||
}
|
||||
logger.info(f"AUDIT: {log_data}")
|
||||
Reference in New Issue
Block a user