"""AI 服务封装 遵循瑞小美 AI 接入规范: - 通过 shared_backend.AIService 调用 - 初始化时传入 db_session(用于日志记录) - 调用时传入 prompt_name(用于统计) 性能优化: - 相同输入的响应缓存(减少 API 调用成本) - 缓存键基于消息内容哈希 """ import hashlib import json from typing import Optional, List, Dict, Any from dataclasses import dataclass from sqlalchemy.ext.asyncio import AsyncSession from app.config import settings from app.services.cache_service import get_cache, CacheNamespace class AIServiceWrapper: """AI 服务封装类 封装 shared_backend.AIService 的调用 提供统一的接口供业务层使用 支持响应缓存以减少 API 调用 """ # 默认缓存 TTL(1 小时)- AI 响应通常不会频繁变化 DEFAULT_CACHE_TTL = 3600 def __init__(self, db_session: AsyncSession, enable_cache: bool = True): """初始化 AI 服务 Args: db_session: 数据库会话,用于记录 AI 调用日志 enable_cache: 是否启用响应缓存 """ self.db_session = db_session self.module_code = settings.AI_MODULE_CODE self.enable_cache = enable_cache self._ai_service = None self._cache = get_cache(CacheNamespace.AI_RESPONSES, maxsize=100, ttl=self.DEFAULT_CACHE_TTL) def _generate_cache_key( self, messages: List[Dict[str, str]], prompt_name: str, model: Optional[str] = None, ) -> str: """生成缓存键 基于消息内容和参数生成唯一的缓存键 """ key_data = { "messages": messages, "prompt_name": prompt_name, "model": model or "default", } key_str = json.dumps(key_data, sort_keys=True, ensure_ascii=False) return f"ai:{hashlib.sha256(key_str.encode()).hexdigest()[:16]}" async def _get_service(self): """获取 AIService 实例(延迟加载)""" if self._ai_service is None: try: from shared_backend.services.ai_service import AIService self._ai_service = AIService( module_code=self.module_code, db_session=self.db_session, ) except ImportError: # 开发环境可能没有 shared_backend # 使用 Mock 实现 self._ai_service = MockAIService(self.module_code) return self._ai_service async def chat( self, messages: List[Dict[str, str]], prompt_name: str, model: Optional[str] = None, use_cache: bool = True, cache_ttl: Optional[int] = None, **kwargs, ): """调用 AI 聊天接口(带缓存支持) Args: messages: 消息列表 prompt_name: 提示词名称(必填,用于统计) model: 模型名称,默认使用配置的模型 use_cache: 是否使用缓存 cache_ttl: 缓存 TTL(秒),默认 3600 **kwargs: 其他参数 Returns: AIResponse 对象 """ # 检查缓存 if self.enable_cache and use_cache: cache_key = self._generate_cache_key(messages, prompt_name, model) cached_response = self._cache.get(cache_key) if cached_response is not None: # 返回缓存的响应(添加标记) cached_response.from_cache = True return cached_response # 调用 AI 服务 service = await self._get_service() response = await service.chat( messages=messages, prompt_name=prompt_name, model=model, **kwargs, ) # 存入缓存 if self.enable_cache and use_cache and response is not None: cache_key = self._generate_cache_key(messages, prompt_name, model) response.from_cache = False self._cache.set(cache_key, response, cache_ttl or self.DEFAULT_CACHE_TTL) return response async def chat_stream( self, messages: List[Dict[str, str]], prompt_name: str, model: Optional[str] = None, **kwargs, ): """调用 AI 聊天流式接口 注意:流式接口不使用缓存 Args: messages: 消息列表 prompt_name: 提示词名称(必填,用于统计) model: 模型名称 **kwargs: 其他参数 Yields: 响应片段 """ service = await self._get_service() async for chunk in service.chat_stream( messages=messages, prompt_name=prompt_name, model=model, **kwargs, ): yield chunk def clear_cache(self, prompt_name: Optional[str] = None): """清除 AI 响应缓存 Args: prompt_name: 指定提示词名称清除,None 则清除全部 """ # 简单实现:清除整个缓存 # 更精细的实现可以按 prompt_name 过滤 self._cache.clear() def get_cache_stats(self) -> Dict[str, Any]: """获取缓存统计信息""" return self._cache.stats() @dataclass class MockAIResponse: """Mock AI 响应""" content: str = "这是一个 Mock 响应,用于开发测试。实际部署时会使用真实的 AI 服务。" model: str = "mock-model" provider: str = "mock" input_tokens: int = 100 output_tokens: int = 50 total_tokens: int = 150 cost: float = 0.0 latency_ms: int = 100 raw_response: dict = None images: list = None annotations: dict = None from_cache: bool = False class MockAIService: """Mock AI 服务(开发环境使用)""" def __init__(self, module_code: str): self.module_code = module_code async def chat(self, messages, prompt_name, **kwargs): """Mock 聊天接口""" # 生成一个基于输入的简单响应 user_message = "" for msg in messages: if msg.get("role") == "user": user_message = msg.get("content", "")[:100] break response = MockAIResponse( content=f"""## Mock AI 分析报告 根据您提供的数据,以下是分析结果: ### 定价建议 - **推荐价格**: 根据成本和市场分析,建议定价在合理区间内 - **引流款策略**: 适合新客引流,建议价格较低 - **利润款策略**: 适合日常经营,建议价格适中 - **高端款策略**: 适合高端客群,可考虑较高定价 ### 风险提示 - 请密切关注市场动态 - 建议定期复核定价策略 *注意:这是开发测试环境的 Mock 响应,实际部署时会使用真实的 AI 服务。* """, model="mock-model", provider="mock", input_tokens=len(str(messages)), output_tokens=200, total_tokens=len(str(messages)) + 200, cost=0.0, latency_ms=50, ) return response async def chat_stream(self, messages, prompt_name, **kwargs): """Mock 流式接口""" chunks = [ "## Mock AI 分析报告\n\n", "根据您提供的数据,以下是分析结果:\n\n", "### 定价建议\n", "- 推荐价格:合理区间内\n", "- 引流款策略:适合新客引流\n", "- 利润款策略:适合日常经营\n\n", "*这是 Mock 响应,实际部署时会使用真实的 AI 服务。*", ] for chunk in chunks: yield chunk async def get_ai_service(db_session: AsyncSession, enable_cache: bool = True) -> AIServiceWrapper: """获取 AI 服务实例(依赖注入) Args: db_session: 数据库会话 enable_cache: 是否启用缓存 Returns: AIServiceWrapper 实例 """ return AIServiceWrapper(db_session, enable_cache=enable_cache)