234 lines
6.5 KiB
Python
234 lines
6.5 KiB
Python
"""缓存服务
|
||
|
||
实现简单的内存缓存和 LRU 缓存策略
|
||
用于优化频繁查询的数据
|
||
|
||
遵循瑞小美系统技术栈标准
|
||
"""
|
||
|
||
import asyncio
import hashlib
import inspect
import json
import threading
import time
from collections import OrderedDict
from datetime import datetime, timedelta
from functools import lru_cache, wraps
from typing import Any, Callable, Optional, Dict
|
||
|
||
|
||
class TTLCache:
    """Cache whose entries expire after a time-to-live (TTL).

    Thread-safe: every public method runs under an internal ``threading.Lock``.
    Entries are kept in an ``OrderedDict`` ordered from least to most recently
    used; when the number of entries exceeds ``maxsize`` the least recently
    used ones are evicted first.

    Expiry is measured with ``time.monotonic()`` rather than the wall clock, so
    clock adjustments (NTP corrections, DST shifts of naive local time) cannot
    make entries expire early or live too long.
    """

    def __init__(self, maxsize: int = 1000, ttl: int = 300):
        """
        Args:
            maxsize: maximum number of entries kept.
            ttl: default time-to-live in seconds, used when ``set`` is called
                without an explicit ``ttl``.
        """
        self.maxsize = maxsize
        self.ttl = ttl
        # key -> (value, expire_at); expire_at is a time.monotonic() timestamp.
        # Insertion/access order doubles as the LRU order (oldest first).
        self._cache: OrderedDict = OrderedDict()
        self._lock = threading.Lock()

    def _is_expired(self, expire_at: float) -> bool:
        """Return True once the monotonic clock has reached ``expire_at``."""
        # ">=" so that a zero TTL means "already expired" on the next read.
        return time.monotonic() >= expire_at

    def get(self, key: str) -> Optional[Any]:
        """Return the value cached under ``key``, or None if absent or expired.

        An expired entry is removed when it is read; a hit marks the entry as
        most recently used.
        """
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None

            value, expire_at = entry
            if self._is_expired(expire_at):
                del self._cache[key]
                return None

            # Mark as most recently used (LRU bookkeeping).
            self._cache.move_to_end(key)
            return value

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Store ``value`` under ``key``.

        Args:
            key: cache key.
            value: value to store.
            ttl: lifetime in seconds. None uses the instance default; 0 or a
                negative value makes the entry expire immediately.
        """
        # Explicit None check: ``ttl or self.ttl`` would turn ttl=0 into the default.
        lifetime = self.ttl if ttl is None else ttl
        with self._lock:
            expire_at = time.monotonic() + lifetime

            # Re-assigning an existing OrderedDict key keeps its old position,
            # so move it to the end first to mark it as most recently used.
            if key in self._cache:
                self._cache.move_to_end(key)

            self._cache[key] = (value, expire_at)

            # Evict least recently used entries while over capacity.
            while len(self._cache) > self.maxsize:
                self._cache.popitem(last=False)

    def delete(self, key: str) -> bool:
        """Remove ``key``; return True if it was present, else False."""
        with self._lock:
            if key in self._cache:
                del self._cache[key]
                return True
            return False

    def clear(self) -> None:
        """Remove every entry."""
        with self._lock:
            self._cache.clear()

    def cleanup(self) -> int:
        """Remove all expired entries and return how many were removed."""
        with self._lock:
            expired_keys = [
                key for key, (_, expire_at) in self._cache.items()
                if self._is_expired(expire_at)
            ]
            for key in expired_keys:
                del self._cache[key]
            return len(expired_keys)

    def stats(self) -> Dict[str, Any]:
        """Return the current size plus the configured ``maxsize`` and ``ttl``."""
        with self._lock:
            return {
                "size": len(self._cache),
                "maxsize": self.maxsize,
                "ttl": self.ttl,
            }
|
||
|
||
|
||
# Global registry of cache instances, keyed by namespace.
_cache_instances: Dict[str, TTLCache] = {}
# Serializes creation of new namespaces so concurrent first calls end up
# sharing one instance instead of overwriting each other's.
_cache_instances_lock = threading.Lock()


def get_cache(namespace: str = "default", maxsize: int = 1000, ttl: int = 300) -> TTLCache:
    """Return the cache for ``namespace``, creating it on first use.

    ``maxsize`` and ``ttl`` are only applied when the namespace is created; if
    it already exists, the existing instance is returned unchanged.

    Args:
        namespace: cache namespace.
        maxsize: maximum number of entries (used only on creation).
        ttl: default time-to-live in seconds (used only on creation).

    Returns:
        The ``TTLCache`` registered for ``namespace``.
    """
    # Fast path: the namespace usually exists already, so skip the lock.
    cache = _cache_instances.get(namespace)
    if cache is None:
        with _cache_instances_lock:
            # Re-check under the lock: another thread may have created it meanwhile.
            cache = _cache_instances.get(namespace)
            if cache is None:
                cache = TTLCache(maxsize=maxsize, ttl=ttl)
                _cache_instances[namespace] = cache
    return cache
|
||
|
||
|
||
def cache_key(*args, **kwargs) -> str:
    """Build a hex digest that identifies a call's arguments.

    Each argument is represented by its type name and ``str()`` value; keyword
    arguments are sorted by name, so keyword order does not matter. The
    components are serialized with ``repr`` so that separators inside values
    are escaped. As a result, ``("a|b",)`` vs ``("a", "b")``, ``1`` vs ``"1"``
    and ``"x=1"`` vs ``x=1`` all produce different keys.

    Returns:
        32-character MD5 hex digest (used as a cache key, not for security).
    """
    positional = tuple((type(arg).__qualname__, str(arg)) for arg in args)
    keyword = tuple(
        (name, type(value).__qualname__, str(value))
        for name, value in sorted(kwargs.items())
    )
    # repr() quotes and escapes every string (including lone surrogates), so the
    # serialized form is unambiguous and always UTF-8 encodable.
    key_str = repr((positional, keyword))
    return hashlib.md5(key_str.encode()).hexdigest()
|
||
|
||
|
||
def cached(
    namespace: str = "default",
    ttl: int = 300,
    key_prefix: str = "",
):
    """Decorator that caches a function's return value in a ``TTLCache``.

    Works for both ``async def`` and regular functions. The key is built from
    ``cache_key`` over the call's positional and keyword arguments. ``None``
    results are never stored (a ``None`` from the cache is treated as a miss),
    so such calls always re-run the function. Cached objects are returned
    as-is, so callers should not mutate them.

    Args:
        namespace: cache namespace, resolved with ``get_cache`` on each call.
        ttl: time-to-live in seconds for entries stored by this decorator.
        key_prefix: prefix for the cache keys. When empty, the decorated
            function's ``module.qualname`` is used so that different functions
            sharing a namespace cannot read each other's entries.

    Example:
        @cached(namespace="projects", ttl=60, key_prefix="project_detail")
        async def get_project(project_id: int):
            ...
    """
    def decorator(func: Callable):
        # Fall back to the function's identity: with a shared empty prefix, two
        # functions called with equal arguments would share one cache key.
        qualname = getattr(func, "__qualname__", getattr(func, "__name__", repr(func)))
        prefix = key_prefix or f"{getattr(func, '__module__', '')}.{qualname}"

        def _make_key(args, kwargs) -> str:
            """Combine the prefix with the argument digest."""
            return f"{prefix}:{cache_key(*args, **kwargs)}"

        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated since Python 3.14.
        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                cache = get_cache(namespace, ttl=ttl)
                key = _make_key(args, kwargs)

                cached_value = cache.get(key)
                if cached_value is not None:
                    return cached_value

                result = await func(*args, **kwargs)
                if result is not None:
                    cache.set(key, result, ttl)
                return result

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            cache = get_cache(namespace, ttl=ttl)
            key = _make_key(args, kwargs)

            cached_value = cache.get(key)
            if cached_value is not None:
                return cached_value

            result = func(*args, **kwargs)
            if result is not None:
                cache.set(key, result, ttl)
            return result

        return sync_wrapper

    return decorator
|
||
|
||
|
||
def invalidate_cache(namespace: str, key_prefix: str = "") -> None:
    """Drop every entry stored in ``namespace``.

    Note: ``key_prefix`` is accepted but not used — the whole namespace is
    cleared regardless of its value. The cache is obtained through
    ``get_cache``, so a namespace that does not exist yet is created (empty).
    """
    get_cache(namespace).clear()
|
||
|
||
|
||
# Predefined cache namespaces.
class CacheNamespace:
    """String constants naming the cache namespaces used by the application."""

    # Base data
    CATEGORIES: str = "categories"
    MATERIALS: str = "materials"
    EQUIPMENTS: str = "equipments"
    STAFF_LEVELS: str = "staff_levels"
    # Business data
    PROJECTS: str = "projects"
    MARKET_ANALYSIS: str = "market_analysis"
    # AI responses
    AI_RESPONSES: str = "ai_responses"
|
||
|
||
|
||
# Pre-create the commonly used caches.
def init_caches():
    """Create the predefined namespaces with their sizes and default TTLs.

    ``get_cache`` only creates a namespace that does not exist yet, so calling
    this again leaves existing caches and their settings untouched.
    """
    # (namespace, maxsize, ttl in seconds), created in this order.
    presets = (
        # Base data caches (longer TTL).
        (CacheNamespace.CATEGORIES, 100, 600),
        (CacheNamespace.MATERIALS, 500, 600),
        (CacheNamespace.EQUIPMENTS, 200, 600),
        (CacheNamespace.STAFF_LEVELS, 50, 600),
        # Business data caches (shorter TTL).
        (CacheNamespace.PROJECTS, 500, 300),
        (CacheNamespace.MARKET_ANALYSIS, 200, 180),
        # AI response cache (longer TTL, to reduce API calls).
        (CacheNamespace.AI_RESPONSES, 100, 3600),
    )
    for name, size, seconds in presets:
        get_cache(name, maxsize=size, ttl=seconds)


# Create the caches when this module is imported.
init_caches()
|