"""Response-cache middleware.

Caches responses to GET requests to reduce database queries.
"""
import hashlib
import json
import logging
from typing import Callable, Optional, Set

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from app.services.cache_service import get_cache, CacheNamespace

logger = logging.getLogger(__name__)

class ResponseCacheMiddleware(BaseHTTPMiddleware):
    """Response-cache middleware.

    Caches the bodies of eligible GET requests (see ``CACHE_CONFIG``) so that
    repeated reads of reference data avoid hitting the database.  Cache hits
    carry an ``X-Cache: HIT`` header; freshly cached responses carry
    ``X-Cache: MISS``.

    NOTE(review): the cache key is derived only from method, path and query
    string.  Responses that vary per user (e.g. by Authorization header or
    cookies) must not be added to ``CACHE_CONFIG``, or one user's data could
    be served to another -- confirm the configured endpoints are truly shared
    reference data.
    """

    # Path prefixes to cache, with per-prefix TTL (seconds) and cache namespace.
    CACHE_CONFIG = {
        "/api/v1/categories": {"ttl": 300, "namespace": CacheNamespace.CATEGORIES},
        "/api/v1/materials": {"ttl": 300, "namespace": CacheNamespace.MATERIALS},
        "/api/v1/equipments": {"ttl": 300, "namespace": CacheNamespace.EQUIPMENTS},
        "/api/v1/staff-levels": {"ttl": 300, "namespace": CacheNamespace.STAFF_LEVELS},
    }

    # Paths that are never cached (exact match).
    NO_CACHE_PATHS: Set[str] = {
        "/health",
        "/docs",
        "/redoc",
        "/openapi.json",
    }

    def _should_cache(self, request: Request) -> Optional[dict]:
        """Return the cache config for *request*, or ``None`` if not cacheable.

        Only GET requests whose path starts with a ``CACHE_CONFIG`` prefix and
        is not listed in ``NO_CACHE_PATHS`` are eligible.
        """
        # Only GET requests are cached.
        if request.method != "GET":
            return None

        path = request.url.path

        # Explicitly excluded paths are never cached.
        if path in self.NO_CACHE_PATHS:
            return None

        # First matching prefix wins.
        for prefix, config in self.CACHE_CONFIG.items():
            if path.startswith(prefix):
                return config

        return None

    def _generate_cache_key(self, request: Request) -> str:
        """Build a deterministic cache key from method, path and query params.

        Query parameters are sorted so ``?a=1&b=2`` and ``?b=2&a=1`` map to
        the same key.  MD5 is used purely as a fast, non-security fingerprint.
        """
        key_parts = [
            request.method,
            request.url.path,
            str(sorted(request.query_params.items())),
        ]
        key_str = "|".join(key_parts)
        return hashlib.md5(key_str.encode()).hexdigest()

    @staticmethod
    def _build_response(content: bytes, status_code: int, headers: dict,
                        cache_state: str) -> Response:
        """Assemble a JSON response and stamp it with the ``X-Cache`` header."""
        response = Response(
            content=content,
            status_code=status_code,
            headers=dict(headers),
            media_type="application/json",
        )
        response.headers["X-Cache"] = cache_state
        return response

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Serve from cache when possible; otherwise forward and cache the result."""
        cache_config = self._should_cache(request)
        if not cache_config:
            return await call_next(request)

        cache_key = self._generate_cache_key(request)
        cache = get_cache(cache_config["namespace"])

        # Fast path: replay a previously cached body.
        cached_data = cache.get(cache_key)
        if cached_data is not None:
            # Lazy %-style args: no string formatting unless DEBUG is enabled.
            logger.debug("Cache hit: %s", request.url.path)
            return self._build_response(
                cached_data["content"],
                cached_data["status_code"],
                cached_data["headers"],
                "HIT",
            )

        response = await call_next(request)

        # Only successful responses are cached; everything else passes through.
        if response.status_code != 200:
            return response

        # Drain the streaming body once so it can be stored and replayed.
        # b"".join avoids the quadratic cost of repeated bytes concatenation.
        body = b"".join([chunk async for chunk in response.body_iterator])

        # NOTE(review): dict() collapses duplicate headers (e.g. multiple
        # Set-Cookie values).  Acceptable for the reference-data endpoints
        # configured above, but worth confirming if CACHE_CONFIG grows.
        cache.set(
            cache_key,
            {
                "content": body,
                "status_code": response.status_code,
                "headers": dict(response.headers),
            },
            cache_config["ttl"],
        )

        # The original body iterator is consumed; rebuild the response.
        return self._build_response(
            body, response.status_code, dict(response.headers), "MISS"
        )