Initial commit: 智能项目定价模型

This commit is contained in:
kuzma
2026-01-31 21:33:06 +08:00
commit ef0824303f
174 changed files with 31705 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
"""中间件模块"""
from .performance import PerformanceMiddleware
from .cache import ResponseCacheMiddleware
from .security import (
RateLimitMiddleware,
SecurityHeadersMiddleware,
InputSanitizer,
validate_request_body,
AuditLogger,
)
__all__ = [
"PerformanceMiddleware",
"ResponseCacheMiddleware",
"RateLimitMiddleware",
"SecurityHeadersMiddleware",
"InputSanitizer",
"validate_request_body",
"AuditLogger",
]

View File

@@ -0,0 +1,121 @@
"""响应缓存中间件
对 GET 请求的响应进行缓存,减少数据库查询
"""
import hashlib
import json
import logging
from typing import Callable, Optional, Set
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.services.cache_service import get_cache, CacheNamespace
logger = logging.getLogger(__name__)
class ResponseCacheMiddleware(BaseHTTPMiddleware):
"""响应缓存中间件
对符合条件的 GET 请求进行响应缓存
"""
# 需要缓存的路径前缀和 TTL 配置
CACHE_CONFIG = {
"/api/v1/categories": {"ttl": 300, "namespace": CacheNamespace.CATEGORIES},
"/api/v1/materials": {"ttl": 300, "namespace": CacheNamespace.MATERIALS},
"/api/v1/equipments": {"ttl": 300, "namespace": CacheNamespace.EQUIPMENTS},
"/api/v1/staff-levels": {"ttl": 300, "namespace": CacheNamespace.STAFF_LEVELS},
}
# 不缓存的路径(精确匹配)
NO_CACHE_PATHS: Set[str] = {
"/health",
"/docs",
"/redoc",
"/openapi.json",
}
def _should_cache(self, request: Request) -> Optional[dict]:
"""判断是否应该缓存"""
# 只缓存 GET 请求
if request.method != "GET":
return None
path = request.url.path
# 排除不缓存的路径
if path in self.NO_CACHE_PATHS:
return None
# 检查是否在缓存配置中
for prefix, config in self.CACHE_CONFIG.items():
if path.startswith(prefix):
return config
return None
def _generate_cache_key(self, request: Request) -> str:
"""生成缓存键"""
# 包含路径和查询参数
key_parts = [
request.method,
request.url.path,
str(sorted(request.query_params.items())),
]
key_str = "|".join(key_parts)
return hashlib.md5(key_str.encode()).hexdigest()
async def dispatch(self, request: Request, call_next: Callable) -> Response:
cache_config = self._should_cache(request)
if not cache_config:
return await call_next(request)
# 生成缓存键
cache_key = self._generate_cache_key(request)
cache = get_cache(cache_config["namespace"])
# 尝试从缓存获取
cached_data = cache.get(cache_key)
if cached_data is not None:
logger.debug(f"Cache hit: {request.url.path}")
response = Response(
content=cached_data["content"],
status_code=cached_data["status_code"],
headers=dict(cached_data["headers"]),
media_type="application/json",
)
response.headers["X-Cache"] = "HIT"
return response
# 执行请求
response = await call_next(request)
# 只缓存成功的响应
if response.status_code == 200:
# 读取响应体
body = b""
async for chunk in response.body_iterator:
body += chunk
# 保存到缓存
cache_data = {
"content": body,
"status_code": response.status_code,
"headers": dict(response.headers),
}
cache.set(cache_key, cache_data, cache_config["ttl"])
# 重新构建响应
response = Response(
content=body,
status_code=response.status_code,
headers=dict(response.headers),
media_type="application/json",
)
response.headers["X-Cache"] = "MISS"
return response

View File

@@ -0,0 +1,50 @@
"""性能监控中间件
记录请求响应时间,用于性能分析和优化
"""
import time
import logging
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
class PerformanceMiddleware(BaseHTTPMiddleware):
"""性能监控中间件
记录每个请求的响应时间,并在响应头中添加 X-Response-Time
"""
# 慢请求阈值(毫秒)
SLOW_REQUEST_THRESHOLD = 1000
async def dispatch(self, request: Request, call_next: Callable) -> Response:
start_time = time.time()
# 执行请求
response = await call_next(request)
# 计算响应时间
process_time = (time.time() - start_time) * 1000
# 添加响应头
response.headers["X-Response-Time"] = f"{process_time:.2f}ms"
# 记录慢请求
if process_time > self.SLOW_REQUEST_THRESHOLD:
logger.warning(
f"Slow request: {request.method} {request.url.path} "
f"took {process_time:.2f}ms"
)
# 记录请求日志(开发环境)
logger.debug(
f"{request.method} {request.url.path} - "
f"{response.status_code} - {process_time:.2f}ms"
)
return response

View File

@@ -0,0 +1,267 @@
"""安全中间件
实现安全相关功能:
- 请求验证
- 速率限制
- 安全头设置
- 敏感数据保护
遵循瑞小美系统技术栈标准
"""
import logging
import time
from collections import defaultdict
from typing import Callable, Dict, Optional
import re
from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""速率限制中间件
防止 API 滥用,保护服务稳定性
"""
# 速率限制配置
RATE_LIMITS = {
"default": {"requests": 100, "window": 60}, # 默认100 次/分钟
"/api/v1/projects/*/generate-pricing": {"requests": 10, "window": 60}, # AI 接口10 次/分钟
"/api/v1/projects/*/market-analysis": {"requests": 20, "window": 60}, # 分析接口20 次/分钟
}
def __init__(self, app, enabled: bool = True):
super().__init__(app)
self.enabled = enabled
self._requests: Dict[str, list] = defaultdict(list)
def _get_client_id(self, request: Request) -> str:
"""获取客户端标识"""
# 优先使用 X-Forwarded-For反向代理场景
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
# 使用客户端 IP
return request.client.host if request.client else "unknown"
def _get_rate_limit(self, path: str) -> Dict:
"""获取路径的速率限制配置"""
for pattern, limit in self.RATE_LIMITS.items():
if pattern == "default":
continue
# 简单的路径匹配(* 匹配任意字符)
regex = pattern.replace("*", "[^/]+")
if re.match(regex, path):
return limit
return self.RATE_LIMITS["default"]
def _is_rate_limited(self, client_id: str, path: str) -> bool:
"""检查是否超过速率限制"""
limit_config = self._get_rate_limit(path)
requests = limit_config["requests"]
window = limit_config["window"]
now = time.time()
key = f"{client_id}:{path}"
# 清理过期记录
self._requests[key] = [t for t in self._requests[key] if now - t < window]
# 检查是否超限
if len(self._requests[key]) >= requests:
return True
# 记录请求
self._requests[key].append(now)
return False
async def dispatch(self, request: Request, call_next: Callable) -> Response:
if not self.enabled:
return await call_next(request)
client_id = self._get_client_id(request)
path = request.url.path
if self._is_rate_limited(client_id, path):
logger.warning(f"Rate limit exceeded: {client_id} -> {path}")
return JSONResponse(
status_code=429,
content={
"code": 40001,
"message": "请求过于频繁,请稍后再试",
"data": None
}
)
return await call_next(request)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""安全响应头中间件
添加安全相关的 HTTP 响应头
"""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
response = await call_next(request)
# 防止 XSS 攻击
response.headers["X-XSS-Protection"] = "1; mode=block"
# 防止 MIME 类型嗅探
response.headers["X-Content-Type-Options"] = "nosniff"
# 点击劫持保护
response.headers["X-Frame-Options"] = "DENY"
# 内容安全策略(基础版)
response.headers["Content-Security-Policy"] = "default-src 'self'"
# 严格传输安全(仅在 HTTPS 环境)
# response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
return response
class InputSanitizer:
"""输入清理工具
防止 SQL 注入、XSS 等攻击
"""
# SQL 注入关键字
SQL_KEYWORDS = [
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "UNION",
"OR", "AND", "--", "/*", "*/", "EXEC", "EXECUTE"
]
# XSS 危险模式
XSS_PATTERNS = [
r"<script.*?>",
r"javascript:",
r"on\w+\s*=",
r"eval\s*\(",
]
@classmethod
def check_sql_injection(cls, value: str) -> bool:
"""检查是否包含 SQL 注入特征"""
upper_value = value.upper()
for keyword in cls.SQL_KEYWORDS:
if keyword in upper_value:
return True
return False
@classmethod
def check_xss(cls, value: str) -> bool:
"""检查是否包含 XSS 特征"""
for pattern in cls.XSS_PATTERNS:
if re.search(pattern, value, re.IGNORECASE):
return True
return False
@classmethod
def sanitize(cls, value: str) -> str:
"""清理输入值
移除潜在危险字符
"""
# 移除 HTML 标签
value = re.sub(r'<[^>]+>', '', value)
# 转义特殊字符
value = value.replace("&", "&amp;")
value = value.replace("<", "&lt;")
value = value.replace(">", "&gt;")
value = value.replace('"', "&quot;")
value = value.replace("'", "&#x27;")
return value
def validate_request_body(data: dict, max_depth: int = 10) -> None:
"""验证请求体
检查数据深度和内容安全
Args:
data: 请求数据
max_depth: 最大嵌套深度
Raises:
HTTPException: 验证失败
"""
def check_depth(obj, depth=0):
if depth > max_depth:
raise HTTPException(
status_code=400,
detail={"code": 10001, "message": "请求数据嵌套层级过深"}
)
if isinstance(obj, dict):
for value in obj.values():
check_depth(value, depth + 1)
elif isinstance(obj, list):
for item in obj:
check_depth(item, depth + 1)
elif isinstance(obj, str):
if InputSanitizer.check_sql_injection(obj):
logger.warning(f"Potential SQL injection detected: {obj[:100]}")
if InputSanitizer.check_xss(obj):
logger.warning(f"Potential XSS detected: {obj[:100]}")
check_depth(data)
class AuditLogger:
"""审计日志记录器
记录敏感操作的审计日志
"""
# 需要审计的操作
AUDIT_OPERATIONS = {
("POST", "/api/v1/pricing-plans"): "创建定价方案",
("PUT", "/api/v1/pricing-plans/*"): "更新定价方案",
("DELETE", "/api/v1/pricing-plans/*"): "删除定价方案",
("POST", "/api/v1/projects/*/generate-pricing"): "生成 AI 定价建议",
}
@classmethod
def should_audit(cls, method: str, path: str) -> Optional[str]:
"""检查是否需要审计"""
for (m, p), desc in cls.AUDIT_OPERATIONS.items():
if m != method:
continue
regex = p.replace("*", "[^/]+")
if re.match(regex, path):
return desc
return None
@classmethod
def log(
cls,
operation: str,
user_id: Optional[int],
request: Request,
response_code: int,
details: Optional[Dict] = None
):
"""记录审计日志"""
log_data = {
"operation": operation,
"user_id": user_id,
"method": request.method,
"path": str(request.url.path),
"client_ip": request.client.host if request.client else "unknown",
"response_code": response_code,
"details": details or {},
}
logger.info(f"AUDIT: {log_data}")