Files
smart-project-pricing/后端服务/app/services/ai_service_wrapper.py
2026-01-31 21:33:06 +08:00

258 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI 服务封装
遵循瑞小美 AI 接入规范:
- 通过 shared_backend.AIService 调用
- 初始化时传入 db_session用于日志记录
- 调用时传入 prompt_name用于统计
性能优化:
- 相同输入的响应缓存(减少 API 调用成本)
- 缓存键基于消息内容哈希
"""
import hashlib
import json
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.services.cache_service import get_cache, CacheNamespace
class AIServiceWrapper:
"""AI 服务封装类
封装 shared_backend.AIService 的调用
提供统一的接口供业务层使用
支持响应缓存以减少 API 调用
"""
# 默认缓存 TTL1 小时)- AI 响应通常不会频繁变化
DEFAULT_CACHE_TTL = 3600
def __init__(self, db_session: AsyncSession, enable_cache: bool = True):
"""初始化 AI 服务
Args:
db_session: 数据库会话,用于记录 AI 调用日志
enable_cache: 是否启用响应缓存
"""
self.db_session = db_session
self.module_code = settings.AI_MODULE_CODE
self.enable_cache = enable_cache
self._ai_service = None
self._cache = get_cache(CacheNamespace.AI_RESPONSES, maxsize=100, ttl=self.DEFAULT_CACHE_TTL)
def _generate_cache_key(
self,
messages: List[Dict[str, str]],
prompt_name: str,
model: Optional[str] = None,
) -> str:
"""生成缓存键
基于消息内容和参数生成唯一的缓存键
"""
key_data = {
"messages": messages,
"prompt_name": prompt_name,
"model": model or "default",
}
key_str = json.dumps(key_data, sort_keys=True, ensure_ascii=False)
return f"ai:{hashlib.sha256(key_str.encode()).hexdigest()[:16]}"
async def _get_service(self):
"""获取 AIService 实例(延迟加载)"""
if self._ai_service is None:
try:
from shared_backend.services.ai_service import AIService
self._ai_service = AIService(
module_code=self.module_code,
db_session=self.db_session,
)
except ImportError:
# 开发环境可能没有 shared_backend
# 使用 Mock 实现
self._ai_service = MockAIService(self.module_code)
return self._ai_service
async def chat(
self,
messages: List[Dict[str, str]],
prompt_name: str,
model: Optional[str] = None,
use_cache: bool = True,
cache_ttl: Optional[int] = None,
**kwargs,
):
"""调用 AI 聊天接口(带缓存支持)
Args:
messages: 消息列表
prompt_name: 提示词名称(必填,用于统计)
model: 模型名称,默认使用配置的模型
use_cache: 是否使用缓存
cache_ttl: 缓存 TTL默认 3600
**kwargs: 其他参数
Returns:
AIResponse 对象
"""
# 检查缓存
if self.enable_cache and use_cache:
cache_key = self._generate_cache_key(messages, prompt_name, model)
cached_response = self._cache.get(cache_key)
if cached_response is not None:
# 返回缓存的响应(添加标记)
cached_response.from_cache = True
return cached_response
# 调用 AI 服务
service = await self._get_service()
response = await service.chat(
messages=messages,
prompt_name=prompt_name,
model=model,
**kwargs,
)
# 存入缓存
if self.enable_cache and use_cache and response is not None:
cache_key = self._generate_cache_key(messages, prompt_name, model)
response.from_cache = False
self._cache.set(cache_key, response, cache_ttl or self.DEFAULT_CACHE_TTL)
return response
async def chat_stream(
self,
messages: List[Dict[str, str]],
prompt_name: str,
model: Optional[str] = None,
**kwargs,
):
"""调用 AI 聊天流式接口
注意:流式接口不使用缓存
Args:
messages: 消息列表
prompt_name: 提示词名称(必填,用于统计)
model: 模型名称
**kwargs: 其他参数
Yields:
响应片段
"""
service = await self._get_service()
async for chunk in service.chat_stream(
messages=messages,
prompt_name=prompt_name,
model=model,
**kwargs,
):
yield chunk
def clear_cache(self, prompt_name: Optional[str] = None):
"""清除 AI 响应缓存
Args:
prompt_name: 指定提示词名称清除None 则清除全部
"""
# 简单实现:清除整个缓存
# 更精细的实现可以按 prompt_name 过滤
self._cache.clear()
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
return self._cache.stats()
@dataclass
class MockAIResponse:
"""Mock AI 响应"""
content: str = "这是一个 Mock 响应,用于开发测试。实际部署时会使用真实的 AI 服务。"
model: str = "mock-model"
provider: str = "mock"
input_tokens: int = 100
output_tokens: int = 50
total_tokens: int = 150
cost: float = 0.0
latency_ms: int = 100
raw_response: dict = None
images: list = None
annotations: dict = None
from_cache: bool = False
class MockAIService:
"""Mock AI 服务(开发环境使用)"""
def __init__(self, module_code: str):
self.module_code = module_code
async def chat(self, messages, prompt_name, **kwargs):
"""Mock 聊天接口"""
# 生成一个基于输入的简单响应
user_message = ""
for msg in messages:
if msg.get("role") == "user":
user_message = msg.get("content", "")[:100]
break
response = MockAIResponse(
content=f"""## Mock AI 分析报告
根据您提供的数据,以下是分析结果:
### 定价建议
- **推荐价格**: 根据成本和市场分析,建议定价在合理区间内
- **引流款策略**: 适合新客引流,建议价格较低
- **利润款策略**: 适合日常经营,建议价格适中
- **高端款策略**: 适合高端客群,可考虑较高定价
### 风险提示
- 请密切关注市场动态
- 建议定期复核定价策略
*注意:这是开发测试环境的 Mock 响应,实际部署时会使用真实的 AI 服务。*
""",
model="mock-model",
provider="mock",
input_tokens=len(str(messages)),
output_tokens=200,
total_tokens=len(str(messages)) + 200,
cost=0.0,
latency_ms=50,
)
return response
async def chat_stream(self, messages, prompt_name, **kwargs):
"""Mock 流式接口"""
chunks = [
"## Mock AI 分析报告\n\n",
"根据您提供的数据,以下是分析结果:\n\n",
"### 定价建议\n",
"- 推荐价格:合理区间内\n",
"- 引流款策略:适合新客引流\n",
"- 利润款策略:适合日常经营\n\n",
"*这是 Mock 响应,实际部署时会使用真实的 AI 服务。*",
]
for chunk in chunks:
yield chunk
async def get_ai_service(db_session: AsyncSession, enable_cache: bool = True) -> AIServiceWrapper:
"""获取 AI 服务实例(依赖注入)
Args:
db_session: 数据库会话
enable_cache: 是否启用缓存
Returns:
AIServiceWrapper 实例
"""
return AIServiceWrapper(db_session, enable_cache=enable_cache)