"""AI 服务封装
|
||
|
||
遵循瑞小美 AI 接入规范:
|
||
- 通过 shared_backend.AIService 调用
|
||
- 初始化时传入 db_session(用于日志记录)
|
||
- 调用时传入 prompt_name(用于统计)
|
||
|
||
性能优化:
|
||
- 相同输入的响应缓存(减少 API 调用成本)
|
||
- 缓存键基于消息内容哈希
|
||
"""
|
||
|
||
import hashlib
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.services.cache_service import get_cache, CacheNamespace


class AIServiceWrapper:
    """AI service wrapper.

    Wraps calls to shared_backend.AIService and exposes a uniform
    interface to the business layer, with optional response caching
    to reduce API calls.
    """

    # Default cache TTL (1 hour) - AI responses rarely need to be fresher.
    DEFAULT_CACHE_TTL = 3600

    def __init__(self, db_session: AsyncSession, enable_cache: bool = True):
        """Initialize the AI service wrapper.

        Args:
            db_session: Database session, used to log AI calls.
            enable_cache: Whether to enable response caching.
        """
        self.db_session = db_session
        self.module_code = settings.AI_MODULE_CODE
        self.enable_cache = enable_cache
        self._ai_service = None
        self._cache = get_cache(
            CacheNamespace.AI_RESPONSES, maxsize=100, ttl=self.DEFAULT_CACHE_TTL
        )

    def _generate_cache_key(
        self,
        messages: List[Dict[str, str]],
        prompt_name: str,
        model: Optional[str] = None,
    ) -> str:
        """Generate a cache key.

        Derives a unique key from the message content and call parameters.
        """
        key_data = {
            "messages": messages,
            "prompt_name": prompt_name,
            "model": model or "default",
        }
        key_str = json.dumps(key_data, sort_keys=True, ensure_ascii=False)
        return f"ai:{hashlib.sha256(key_str.encode()).hexdigest()[:16]}"

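    # A minimal sketch of the key's behavior: because json.dumps is called
    # with sort_keys=True, the key is insensitive to dict key order, so
    # logically identical requests map to the same cache entry.
    # (`session` below is a hypothetical AsyncSession from the caller.)
    #
    #   w = AIServiceWrapper(session)
    #   k1 = w._generate_cache_key([{"role": "user", "content": "hi"}], "pricing")
    #   k2 = w._generate_cache_key([{"content": "hi", "role": "user"}], "pricing")
    #   assert k1 == k2
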
    async def _get_service(self):
        """Get the AIService instance (lazily loaded)."""
        if self._ai_service is None:
            try:
                from shared_backend.services.ai_service import AIService
                self._ai_service = AIService(
                    module_code=self.module_code,
                    db_session=self.db_session,
                )
            except ImportError:
                # shared_backend may be absent in development environments;
                # fall back to the mock implementation.
                self._ai_service = MockAIService(self.module_code)
        return self._ai_service

    async def chat(
        self,
        messages: List[Dict[str, str]],
        prompt_name: str,
        model: Optional[str] = None,
        use_cache: bool = True,
        cache_ttl: Optional[int] = None,
        **kwargs,
    ):
        """Call the AI chat endpoint (with cache support).

        Args:
            messages: List of chat messages.
            prompt_name: Prompt name (required; used for usage statistics).
            model: Model name; defaults to the configured model.
            use_cache: Whether to consult the response cache.
            cache_ttl: Cache TTL in seconds; defaults to 3600.
            **kwargs: Additional parameters passed through to the service.

        Returns:
            An AIResponse object.
        """
        caching = self.enable_cache and use_cache
        # Compute the key once; it is reused for both lookup and store.
        cache_key = self._generate_cache_key(messages, prompt_name, model) if caching else None

        # Check the cache first.
        if caching:
            cached_response = self._cache.get(cache_key)
            if cached_response is not None:
                # Return the cached response, flagged as such.
                cached_response.from_cache = True
                return cached_response

        # Call the AI service.
        service = await self._get_service()
        response = await service.chat(
            messages=messages,
            prompt_name=prompt_name,
            model=model,
            **kwargs,
        )

        # Store the response in the cache.
        if caching and response is not None:
            response.from_cache = False
            self._cache.set(cache_key, response, cache_ttl or self.DEFAULT_CACHE_TTL)

        return response

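    # A minimal usage sketch (hypothetical call site; `session` comes from the
    # caller's request scope). The second identical call is served from cache:
    #
    #   ai = AIServiceWrapper(session)
    #   msgs = [{"role": "user", "content": "Price this product"}]
    #   first = await ai.chat(msgs, prompt_name="pricing_advice")
    #   again = await ai.chat(msgs, prompt_name="pricing_advice")
    #   assert first.from_cache is False and again.from_cache is True
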
    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        prompt_name: str,
        model: Optional[str] = None,
        **kwargs,
    ):
        """Call the AI chat streaming endpoint.

        Note: the streaming endpoint does not use the cache.

        Args:
            messages: List of chat messages.
            prompt_name: Prompt name (required; used for usage statistics).
            model: Model name.
            **kwargs: Additional parameters passed through to the service.

        Yields:
            Response chunks.
        """
        service = await self._get_service()
        async for chunk in service.chat_stream(
            messages=messages,
            prompt_name=prompt_name,
            model=model,
            **kwargs,
        ):
            yield chunk

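    # A minimal consumption sketch (hypothetical call site; send_to_client is
    # a placeholder, e.g. an SSE or WebSocket writer, not part of this module):
    #
    #   async for chunk in ai.chat_stream(msgs, prompt_name="pricing_advice"):
    #       await send_to_client(chunk)
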
    def clear_cache(self, prompt_name: Optional[str] = None):
        """Clear the AI response cache.

        Args:
            prompt_name: Accepted for future per-prompt clearing; currently
                the whole cache is cleared regardless of this value.
        """
        # Simple implementation: clear the entire cache.
        # A finer-grained implementation could filter entries by prompt_name.
        self._cache.clear()

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        return self._cache.stats()


@dataclass
class MockAIResponse:
    """Mock AI response."""
    content: str = "This is a mock response for development and testing. Real deployments use the actual AI service."
    model: str = "mock-model"
    provider: str = "mock"
    input_tokens: int = 100
    output_tokens: int = 50
    total_tokens: int = 150
    cost: float = 0.0
    latency_ms: int = 100
    raw_response: Optional[dict] = None
    images: Optional[list] = None
    annotations: Optional[dict] = None
    from_cache: bool = False


class MockAIService:
    """Mock AI service (for development environments)."""

    def __init__(self, module_code: str):
        self.module_code = module_code

    async def chat(self, messages, prompt_name, **kwargs):
        """Mock chat endpoint."""
        # Extract the first user message from the input.
        # NOTE: user_message is not yet interpolated into the mock content.
        user_message = ""
        for msg in messages:
            if msg.get("role") == "user":
                user_message = msg.get("content", "")[:100]
                break

        response = MockAIResponse(
            content="""## Mock AI Analysis Report

Based on the data you provided, here are the results:

### Pricing Recommendations
- **Recommended price**: based on cost and market analysis, price within a reasonable range
- **Traffic-driver strategy**: suited to attracting new customers; set a lower price
- **Profit-driver strategy**: suited to day-to-day business; set a moderate price
- **Premium strategy**: suited to high-end customers; a higher price is viable

### Risk Notes
- Keep a close eye on market trends
- Review your pricing strategy regularly

*Note: this is a mock response from the development environment; real deployments use the actual AI service.*
""",
            model="mock-model",
            provider="mock",
            input_tokens=len(str(messages)),
            output_tokens=200,
            total_tokens=len(str(messages)) + 200,
            cost=0.0,
            latency_ms=50,
        )
        return response

    async def chat_stream(self, messages, prompt_name, **kwargs):
        """Mock streaming endpoint."""
        chunks = [
            "## Mock AI Analysis Report\n\n",
            "Based on the data you provided, here are the results:\n\n",
            "### Pricing Recommendations\n",
            "- Recommended price: within a reasonable range\n",
            "- Traffic-driver strategy: suited to attracting new customers\n",
            "- Profit-driver strategy: suited to day-to-day business\n\n",
            "*This is a mock response; real deployments use the actual AI service.*",
        ]
        for chunk in chunks:
            yield chunk


async def get_ai_service(db_session: AsyncSession, enable_cache: bool = True) -> AIServiceWrapper:
    """Get an AI service instance (for dependency injection).

    Args:
        db_session: Database session.
        enable_cache: Whether to enable response caching.

    Returns:
        An AIServiceWrapper instance.
    """
    return AIServiceWrapper(db_session, enable_cache=enable_cache)
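
# A minimal FastAPI wiring sketch (an assumption: FastAPI is inferred from the
# dependency-injection note above; `router`, `get_db`, and PricingRequest are
# hypothetical names, not defined in this module):
#
#   from fastapi import Depends
#
#   @router.post("/pricing/analyze")
#   async def analyze(req: PricingRequest, db: AsyncSession = Depends(get_db)):
#       ai = await get_ai_service(db)
#       return await ai.chat(req.to_messages(), prompt_name="pricing_advice")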