Initial commit: 智能项目定价模型
This commit is contained in:
121
后端服务/app/middleware/cache.py
Normal file
121
后端服务/app/middleware/cache.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""响应缓存中间件
|
||||
|
||||
对 GET 请求的响应进行缓存,减少数据库查询
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Callable, Optional, Set
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.services.cache_service import get_cache, CacheNamespace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponseCacheMiddleware(BaseHTTPMiddleware):
|
||||
"""响应缓存中间件
|
||||
|
||||
对符合条件的 GET 请求进行响应缓存
|
||||
"""
|
||||
|
||||
# 需要缓存的路径前缀和 TTL 配置
|
||||
CACHE_CONFIG = {
|
||||
"/api/v1/categories": {"ttl": 300, "namespace": CacheNamespace.CATEGORIES},
|
||||
"/api/v1/materials": {"ttl": 300, "namespace": CacheNamespace.MATERIALS},
|
||||
"/api/v1/equipments": {"ttl": 300, "namespace": CacheNamespace.EQUIPMENTS},
|
||||
"/api/v1/staff-levels": {"ttl": 300, "namespace": CacheNamespace.STAFF_LEVELS},
|
||||
}
|
||||
|
||||
# 不缓存的路径(精确匹配)
|
||||
NO_CACHE_PATHS: Set[str] = {
|
||||
"/health",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
}
|
||||
|
||||
def _should_cache(self, request: Request) -> Optional[dict]:
|
||||
"""判断是否应该缓存"""
|
||||
# 只缓存 GET 请求
|
||||
if request.method != "GET":
|
||||
return None
|
||||
|
||||
path = request.url.path
|
||||
|
||||
# 排除不缓存的路径
|
||||
if path in self.NO_CACHE_PATHS:
|
||||
return None
|
||||
|
||||
# 检查是否在缓存配置中
|
||||
for prefix, config in self.CACHE_CONFIG.items():
|
||||
if path.startswith(prefix):
|
||||
return config
|
||||
|
||||
return None
|
||||
|
||||
def _generate_cache_key(self, request: Request) -> str:
|
||||
"""生成缓存键"""
|
||||
# 包含路径和查询参数
|
||||
key_parts = [
|
||||
request.method,
|
||||
request.url.path,
|
||||
str(sorted(request.query_params.items())),
|
||||
]
|
||||
key_str = "|".join(key_parts)
|
||||
return hashlib.md5(key_str.encode()).hexdigest()
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
cache_config = self._should_cache(request)
|
||||
|
||||
if not cache_config:
|
||||
return await call_next(request)
|
||||
|
||||
# 生成缓存键
|
||||
cache_key = self._generate_cache_key(request)
|
||||
cache = get_cache(cache_config["namespace"])
|
||||
|
||||
# 尝试从缓存获取
|
||||
cached_data = cache.get(cache_key)
|
||||
if cached_data is not None:
|
||||
logger.debug(f"Cache hit: {request.url.path}")
|
||||
response = Response(
|
||||
content=cached_data["content"],
|
||||
status_code=cached_data["status_code"],
|
||||
headers=dict(cached_data["headers"]),
|
||||
media_type="application/json",
|
||||
)
|
||||
response.headers["X-Cache"] = "HIT"
|
||||
return response
|
||||
|
||||
# 执行请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 只缓存成功的响应
|
||||
if response.status_code == 200:
|
||||
# 读取响应体
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk
|
||||
|
||||
# 保存到缓存
|
||||
cache_data = {
|
||||
"content": body,
|
||||
"status_code": response.status_code,
|
||||
"headers": dict(response.headers),
|
||||
}
|
||||
cache.set(cache_key, cache_data, cache_config["ttl"])
|
||||
|
||||
# 重新构建响应
|
||||
response = Response(
|
||||
content=body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type="application/json",
|
||||
)
|
||||
response.headers["X-Cache"] = "MISS"
|
||||
|
||||
return response
|
||||
Reference in New Issue
Block a user