"""Response caching middleware.

Caches the responses of GET requests to reduce database queries.
"""
import hashlib
import json
import logging
from typing import Callable, Optional, Set

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from app.services.cache_service import get_cache, CacheNamespace

logger = logging.getLogger(__name__)


class ResponseCacheMiddleware(BaseHTTPMiddleware):
    """Response caching middleware.

    Caches the responses of eligible GET requests (paths configured in
    ``CACHE_CONFIG``) in the namespaced cache returned by ``get_cache``.
    Responses served from the cache carry ``X-Cache: HIT``; freshly cached
    responses carry ``X-Cache: MISS``.

    NOTE(review): the cache key is built from the method, path and query
    string only -- not from the caller's identity. This is only safe while
    every configured endpoint returns the same data for all users; confirm
    before adding per-user endpoints to ``CACHE_CONFIG``.
    """

    # Cached path prefixes -> TTL passed to cache.set (presumably seconds;
    # TODO confirm against cache_service) and the cache namespace to use.
    CACHE_CONFIG = {
        "/api/v1/categories": {"ttl": 300, "namespace": CacheNamespace.CATEGORIES},
        "/api/v1/materials": {"ttl": 300, "namespace": CacheNamespace.MATERIALS},
        "/api/v1/equipments": {"ttl": 300, "namespace": CacheNamespace.EQUIPMENTS},
        "/api/v1/staff-levels": {"ttl": 300, "namespace": CacheNamespace.STAFF_LEVELS},
    }

    # Paths that are never cached (exact match).
    NO_CACHE_PATHS: Set[str] = {
        "/health",
        "/docs",
        "/redoc",
        "/openapi.json",
    }

    def _should_cache(self, request: Request) -> Optional[dict]:
        """Return the cache config entry for *request*, or ``None`` to skip caching.

        Only GET requests are eligible. The path must not be in
        ``NO_CACHE_PATHS``. It must equal a ``CACHE_CONFIG`` prefix or lie
        beneath it (``prefix + "/..."``).
        """
        if request.method != "GET":
            return None

        path = request.url.path

        if path in self.NO_CACHE_PATHS:
            return None

        for prefix, config in self.CACHE_CONFIG.items():
            # Require a path-segment boundary so that, for example,
            # "/api/v1/materials-stats" is not cached under MATERIALS.
            if path == prefix or path.startswith(prefix + "/"):
                return config

        return None

    def _generate_cache_key(self, request: Request) -> str:
        """Build a cache key from the method, path and complete query string."""
        # multi_items() keeps every value of a repeated parameter. items()
        # keeps only the last one, which would make "?id=1&id=2" collide
        # with "?id=2". The sort uses the parameter name only. Because the
        # sort is stable, repeated values keep their original order, which
        # the endpoint may depend on.
        params = sorted(
            request.query_params.multi_items(), key=lambda item: item[0]
        )
        key_parts = [
            request.method,
            request.url.path,
            str(params),
        ]
        key_str = "|".join(key_parts)
        # MD5 only shortens the key; it is not used for security.
        return hashlib.md5(key_str.encode()).hexdigest()

    @staticmethod
    def _is_cacheable_response(response: Response) -> bool:
        """Return True if *response* may be stored and replayed to other clients."""
        if response.status_code != 200:
            return False
        # Replaying Set-Cookie would hand one client's cookies (e.g. a
        # session) to every later client that hits the same key.
        if "set-cookie" in response.headers:
            return False
        cache_control = response.headers.get("cache-control", "").lower()
        return not any(
            directive in cache_control for directive in ("no-store", "private")
        )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Serve eligible GET requests from the cache, caching successful misses.

        Requests not selected by ``_should_cache`` pass straight through.
        Cache errors are logged and treated as a miss, so a cache outage
        degrades to uncached responses instead of failing the request.
        """
        cache_config = self._should_cache(request)
        if not cache_config:
            return await call_next(request)

        cache_key = self._generate_cache_key(request)
        cache = get_cache(cache_config["namespace"])

        try:
            cached_data = cache.get(cache_key)
        except Exception:
            logger.warning(
                "Cache get failed: %s", request.url.path, exc_info=True
            )
            cached_data = None

        if cached_data is not None:
            logger.debug("Cache hit: %s", request.url.path)
            response = Response(
                content=cached_data["content"],
                status_code=cached_data["status_code"],
                headers=dict(cached_data["headers"]),
                media_type="application/json",
            )
            response.headers["X-Cache"] = "HIT"
            return response

        response = await call_next(request)

        # Leave non-cacheable responses completely untouched (body still
        # streamed, duplicate headers such as multiple Set-Cookie intact).
        if not self._is_cacheable_response(response):
            return response

        # The body can be iterated only once. Drain it here so it can feed
        # both the cache entry and the rebuilt response.
        body = b"".join([chunk async for chunk in response.body_iterator])

        cache_data = {
            "content": body,
            "status_code": response.status_code,
            "headers": dict(response.headers),
        }
        try:
            cache.set(cache_key, cache_data, cache_config["ttl"])
        except Exception:
            # The body is already consumed; still return it to the client.
            logger.warning(
                "Cache set failed: %s", request.url.path, exc_info=True
            )

        response = Response(
            content=body,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type="application/json",
        )
        response.headers["X-Cache"] = "MISS"
        return response