Files
012-kaopeilian/backend/app/services/ai/coze/service.py
yuliang_guo 82f8e6596c
Some checks failed
continuous-integration/drone/push Build is failing
fix: 修复CozeService初始化时get_bot_config缺少参数的问题
- 移除 __init__ 中对 get_bot_config() 的无参数调用
- 改为在需要时根据 session_type 动态获取 bot_config
- 修复 _get_bot_id_by_type 方法使用正确的配置获取方式
2026-02-02 13:21:00 +08:00

342 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Coze 服务层实现
处理会话管理、消息发送、流式响应等核心功能
"""
import asyncio
import json
import logging
import uuid
from typing import AsyncIterator, Dict, Any, List, Optional
from datetime import datetime
from cozepy import ChatEventType, Message, MessageContentType
from .client import get_coze_client, get_bot_config, get_workspace_id
from .models import (
CozeSession,
CozeMessage,
StreamEvent,
SessionType,
MessageRole,
ContentType,
StreamEventType,
CreateSessionRequest,
CreateSessionResponse,
SendMessageRequest,
EndSessionRequest,
EndSessionResponse,
)
from .exceptions import (
CozeAPIError,
CozeStreamError,
CozeTimeoutError,
map_coze_error_to_exception,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class CozeService:
    """Coze service: session management, message sending and streamed replies.

    Sessions and their message logs live in process memory (``_sessions`` /
    ``_messages``); they are not persisted and not shared between processes.
    """

    # Number of locally recorded messages sent to Coze as
    # ``additional_messages`` on each turn. Chats are sent with the session's
    # ``conversation_id`` and ``auto_save_history=True``, so earlier turns are
    # already stored in the Coze conversation; re-sending them duplicated the
    # history on every turn. Only the newest (user) message is sent.
    _MESSAGES_SENT_PER_TURN = 1

    def __init__(self):
        self.client = get_coze_client()
        self.workspace_id = get_workspace_id()
        # In-memory session store (production should use Redis instead).
        self._sessions: Dict[str, CozeSession] = {}
        self._messages: Dict[str, List[CozeMessage]] = {}

    def _get_bot_config(self, session_type: str) -> Dict[str, Any]:
        """Return the bot configuration for a config key such as "training"."""
        return get_bot_config(session_type)

    async def create_session(
        self, request: CreateSessionRequest
    ) -> CreateSessionResponse:
        """Create a Coze conversation and register a matching local session.

        Args:
            request: Session creation request (session type, user, metadata).

        Returns:
            CreateSessionResponse: Identifiers and creation time of the session.

        Raises:
            The exception produced by ``map_coze_error_to_exception`` for any
            failure.
        """
        try:
            # Pick the bot that serves this session type.
            bot_id = self._get_bot_id_by_type(request.session_type)

            # The cozepy client is synchronous; keep it off the event loop.
            conversation = await asyncio.to_thread(
                self.client.conversations.create, bot_id=bot_id
            )

            session = CozeSession(
                session_id=str(uuid.uuid4()),
                conversation_id=conversation.id,
                session_type=request.session_type,
                user_id=request.user_id,
                bot_id=bot_id,
                metadata=request.metadata,
            )

            self._sessions[session.session_id] = session
            self._messages[session.session_id] = []

            logger.info(
                "创建会话成功",
                extra={
                    "session_id": session.session_id,
                    "conversation_id": conversation.id,
                    "session_type": request.session_type.value,
                    "user_id": request.user_id,
                },
            )

            return CreateSessionResponse(
                session_id=session.session_id,
                conversation_id=session.conversation_id,
                bot_id=session.bot_id,
                created_at=session.created_at,
            )

        except Exception as e:
            logger.error(f"创建会话失败: {e}", exc_info=True)
            raise map_coze_error_to_exception(e)

    async def send_message(
        self, request: SendMessageRequest
    ) -> AsyncIterator[StreamEvent]:
        """Send a user message and relay the bot's streamed reply.

        Args:
            request: Message request (target session id and text content).

        Yields:
            StreamEvent: Delta / completed / error events, then a final DONE.

        Raises:
            CozeAPIError: If the session does not exist (raised on first
                iteration, since this is an async generator).
            CozeTimeoutError: On ``asyncio.TimeoutError``.
            The exception produced by ``map_coze_error_to_exception`` for any
            other failure.
        """
        session = self._get_session(request.session_id)
        if not session:
            raise CozeAPIError(f"会话不存在: {request.session_id}")

        # Record the user's message locally.
        user_message = CozeMessage(
            message_id=str(uuid.uuid4()),
            session_id=session.session_id,
            role=MessageRole.USER,
            content=request.content,
        )
        self._messages[session.session_id].append(user_message)

        try:
            # Only the newest message is sent; earlier turns are already part
            # of the Coze conversation (see _MESSAGES_SENT_PER_TURN).
            messages = self._build_message_history(
                session.session_id, limit=self._MESSAGES_SENT_PER_TURN
            )

            stream = await asyncio.to_thread(
                self.client.chat.stream,
                bot_id=session.bot_id,
                conversation_id=session.conversation_id,
                additional_messages=messages,
                auto_save_history=True,
            )

            async for event in self._process_stream(stream, session.session_id):
                yield event

        except asyncio.TimeoutError:
            logger.error(f"消息发送超时: session_id={request.session_id}")
            raise CozeTimeoutError("消息处理超时")
        except Exception as e:
            logger.error(f"发送消息失败: {e}", exc_info=True)
            raise map_coze_error_to_exception(e)

    async def end_session(
        self, session_id: str, request: EndSessionRequest
    ) -> EndSessionResponse:
        """Mark a session as ended and return its statistics.

        Args:
            session_id: Id of the session to end.
            request: Optional end reason and feedback.

        Returns:
            EndSessionResponse: End time, duration in seconds, message count.

        Raises:
            CozeAPIError: If the session does not exist.
        """
        session = self._get_session(session_id)
        if not session:
            raise CozeAPIError(f"会话不存在: {session_id}")

        session.ended_at = datetime.now()

        duration_seconds = int((session.ended_at - session.created_at).total_seconds())
        message_count = len(self._messages.get(session_id, []))

        # Metadata comes from the creation request and may be unset; make
        # sure there is a dict to write into before recording reason/feedback.
        if (request.reason or request.feedback) and session.metadata is None:
            session.metadata = {}
        if request.reason:
            session.metadata["end_reason"] = request.reason
        if request.feedback:
            session.metadata["feedback"] = request.feedback

        logger.info(
            "会话结束",
            extra={
                "session_id": session_id,
                "duration_seconds": duration_seconds,
                "message_count": message_count,
                "reason": request.reason,
            },
        )

        return EndSessionResponse(
            session_id=session_id,
            ended_at=session.ended_at,
            duration_seconds=duration_seconds,
            message_count=message_count,
        )

    async def get_session_messages(
        self, session_id: str, limit: int = 50, offset: int = 0
    ) -> List[CozeMessage]:
        """Return up to ``limit`` recorded messages starting at ``offset``.

        A non-positive ``limit`` yields an empty list. A negative ``offset`` is
        treated as 0; previously it silently indexed from the end of the list.
        """
        if limit <= 0:
            return []
        offset = max(offset, 0)
        messages = self._messages.get(session_id, [])
        return messages[offset : offset + limit]

    def _get_bot_id_by_type(self, session_type: SessionType) -> str:
        """Map a SessionType to its configured bot id.

        Unknown types fall back to the "training" bot configuration.
        """
        type_mapping = {
            SessionType.COURSE_CHAT: "course_chat",
            SessionType.TRAINING: "training",
            SessionType.EXAM: "training",  # Exams use the training bot.
        }
        config_type = type_mapping.get(session_type, "training")
        return self._get_bot_config(config_type)["bot_id"]

    def _get_session(self, session_id: str) -> Optional[CozeSession]:
        """Return the session with this id, or None if it is unknown."""
        return self._sessions.get(session_id)

    def _build_message_history(
        self, session_id: str, limit: int = 10
    ) -> List[Message]:
        """Convert the last ``limit`` recorded messages into cozepy Messages.

        Args:
            session_id: Session whose local message log is used.
            limit: Maximum number of most-recent messages to include. Values of
                0 or less produce an empty list (a ``[-0:]`` slice would
                otherwise return everything).
        """
        if limit <= 0:
            return []
        recent = self._messages.get(session_id, [])[-limit:]
        return [
            Message(
                role=msg.role.value,
                content=msg.content,
                content_type=MessageContentType.TEXT,
            )
            for msg in recent
        ]

    @staticmethod
    async def _iter_stream_events(stream) -> AsyncIterator[Any]:
        """Yield events from a synchronous cozepy stream without blocking the loop.

        Every ``next()`` on the stream performs blocking network reads, so each
        step runs in a worker thread via ``asyncio.to_thread``.
        """
        iterator = iter(stream)
        exhausted = object()
        while True:
            event = await asyncio.to_thread(next, iterator, exhausted)
            if event is exhausted:
                return
            yield event

    async def _process_stream(
        self, stream, session_id: str
    ) -> AsyncIterator[StreamEvent]:
        """Translate Coze chat events into StreamEvents and record the reply.

        Delta events are forwarded as MESSAGE_DELTA and accumulated. When the
        message completes, the full text is stored as an assistant message and
        forwarded as MESSAGE_COMPLETED. Errors, whether reported by Coze or
        raised while reading the stream, become ERROR events. A DONE event
        ends the stream unless the consumer closed the generator early.
        """
        assistant_message_id = str(uuid.uuid4())
        accumulated_content: List[str] = []
        content_type = ContentType.TEXT

        try:
            async for event in self._iter_stream_events(stream):
                if event.event == ChatEventType.CONVERSATION_MESSAGE_DELTA:
                    content = event.message.content
                    accumulated_content.append(content)

                    # Switch to card rendering once a card fragment appears.
                    if (
                        hasattr(event.message, "content_type")
                        and event.message.content_type == "card"
                    ):
                        content_type = ContentType.CARD

                    yield StreamEvent(
                        event=StreamEventType.MESSAGE_DELTA,
                        data={
                            "conversation_id": event.conversation_id,
                            "message_id": assistant_message_id,
                            "content": content,
                            "content_type": content_type.value,
                        },
                        message_id=assistant_message_id,
                        content=content,
                        content_type=content_type,
                        role=MessageRole.ASSISTANT,
                    )

                elif event.event == ChatEventType.CONVERSATION_MESSAGE_COMPLETED:
                    if not accumulated_content:
                        # Completion of a message that streamed no deltas
                        # (Coze also completes e.g. verbose / follow-up
                        # messages). Previously these re-saved and re-emitted
                        # the last answer under the same message_id.
                        continue

                    full_content = "".join(accumulated_content)

                    assistant_message = CozeMessage(
                        message_id=assistant_message_id,
                        session_id=session_id,
                        role=MessageRole.ASSISTANT,
                        content=full_content,
                        content_type=content_type,
                    )
                    self._messages[session_id].append(assistant_message)

                    yield StreamEvent(
                        event=StreamEventType.MESSAGE_COMPLETED,
                        data={
                            "conversation_id": event.conversation_id,
                            "message_id": assistant_message_id,
                            "content": full_content,
                            "content_type": content_type.value,
                            "usage": getattr(event, "usage", {}),
                        },
                        message_id=assistant_message_id,
                        content=full_content,
                        content_type=content_type,
                        role=MessageRole.ASSISTANT,
                    )

                    # Fresh state for any further message in this stream.
                    assistant_message_id = str(uuid.uuid4())
                    accumulated_content = []
                    content_type = ContentType.TEXT

                elif event.event == ChatEventType.ERROR:
                    yield StreamEvent(
                        event=StreamEventType.ERROR,
                        data={"error": str(event)},
                        error=str(event),
                    )

        except Exception as e:
            logger.error(f"流式处理错误: {e}", exc_info=True)
            yield StreamEvent(
                event=StreamEventType.ERROR, data={"error": str(e)}, error=str(e)
            )

        # Deliberately not in a ``finally``: yielding while the generator is
        # being closed (consumer disconnect -> GeneratorExit) raises
        # "RuntimeError: async generator ignored GeneratorExit".
        yield StreamEvent(
            event=StreamEventType.DONE, data={"session_id": session_id}
        )
# Process-wide service instance, created lazily by get_coze_service().
_service: Optional[CozeService] = None


def get_coze_service() -> CozeService:
    """Return the shared CozeService, constructing it on first use."""
    global _service
    if _service is not None:
        return _service
    _service = CozeService()
    return _service