Some checks failed
continuous-integration/drone/push Build is failing
- 移除 __init__ 中对 get_bot_config() 的无参数调用 - 改为在需要时根据 session_type 动态获取 bot_config - 修复 _get_bot_id_by_type 方法使用正确的配置获取方式
342 lines
11 KiB
Python
342 lines
11 KiB
Python
"""
|
||
Coze 服务层实现
|
||
处理会话管理、消息发送、流式响应等核心功能
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import uuid
|
||
from typing import AsyncIterator, Dict, Any, List, Optional
|
||
from datetime import datetime
|
||
|
||
from cozepy import ChatEventType, Message, MessageContentType
|
||
|
||
from .client import get_coze_client, get_bot_config, get_workspace_id
|
||
from .models import (
|
||
CozeSession,
|
||
CozeMessage,
|
||
StreamEvent,
|
||
SessionType,
|
||
MessageRole,
|
||
ContentType,
|
||
StreamEventType,
|
||
CreateSessionRequest,
|
||
CreateSessionResponse,
|
||
SendMessageRequest,
|
||
EndSessionRequest,
|
||
EndSessionResponse,
|
||
)
|
||
from .exceptions import (
|
||
CozeAPIError,
|
||
CozeStreamError,
|
||
CozeTimeoutError,
|
||
map_coze_error_to_exception,
|
||
)
|
||
|
||
|
||
# Module-level logger for the Coze service layer.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
||
|
||
class CozeService:
    """Coze service layer.

    Creates Coze conversations, sends user messages and relays the streamed
    assistant replies, and ends sessions. Session and message records are
    kept in in-process dictionaries keyed by the local ``session_id``.
    """

    def __init__(self):
        self.client = get_coze_client()
        self.workspace_id = get_workspace_id()

        # In-memory session store (production should use Redis instead).
        self._sessions: Dict[str, CozeSession] = {}
        # Per-session message log, in the order messages were recorded.
        self._messages: Dict[str, List[CozeMessage]] = {}

    def _get_bot_config(self, session_type: str) -> Dict[str, Any]:
        """Return the bot configuration for a configuration-level type key.

        Args:
            session_type: Configuration key such as ``"course_chat"`` or
                ``"training"`` (see ``_get_bot_id_by_type``).
        """
        return get_bot_config(session_type)

    async def create_session(
        self, request: CreateSessionRequest
    ) -> CreateSessionResponse:
        """Create a Coze conversation and a matching local session record.

        Args:
            request: Session creation request (session type, user id,
                metadata).

        Returns:
            CreateSessionResponse: identifiers and creation time of the new
            session.

        Raises:
            Whatever ``map_coze_error_to_exception`` maps the underlying
            failure to.
        """
        try:
            # Pick the bot that serves this session type.
            bot_id = self._get_bot_id_by_type(request.session_type)

            # The SDK call is synchronous; run it in a worker thread.
            conversation = await asyncio.to_thread(
                self.client.conversations.create, bot_id=bot_id
            )

            session = CozeSession(
                session_id=str(uuid.uuid4()),
                conversation_id=conversation.id,
                session_type=request.session_type,
                user_id=request.user_id,
                bot_id=bot_id,
                metadata=request.metadata,
            )

            self._sessions[session.session_id] = session
            self._messages[session.session_id] = []

            logger.info(
                "创建会话成功",
                extra={
                    "session_id": session.session_id,
                    "conversation_id": conversation.id,
                    "session_type": request.session_type.value,
                    "user_id": request.user_id,
                },
            )

            return CreateSessionResponse(
                session_id=session.session_id,
                conversation_id=session.conversation_id,
                bot_id=session.bot_id,
                created_at=session.created_at,
            )

        except Exception as e:
            logger.error(f"创建会话失败: {e}", exc_info=True)
            raise map_coze_error_to_exception(e)

    async def send_message(
        self, request: SendMessageRequest
    ) -> AsyncIterator[StreamEvent]:
        """Send a user message and stream the assistant's reply.

        Args:
            request: Send-message request (target ``session_id`` and
                ``content``).

        Yields:
            StreamEvent: delta / completed / error events, ending with DONE.

        Raises:
            CozeAPIError: if the session does not exist.
            CozeTimeoutError: if an ``asyncio.TimeoutError`` escapes.
            Other failures are converted by ``map_coze_error_to_exception``.
        """
        session = self._get_session(request.session_id)
        if not session:
            raise CozeAPIError(f"会话不存在: {request.session_id}")

        # Record the user's message locally before calling Coze.
        user_message = CozeMessage(
            message_id=str(uuid.uuid4()),
            session_id=session.session_id,
            role=MessageRole.USER,
            content=request.content,
        )
        self._messages[session.session_id].append(user_message)

        try:
            # The Coze conversation already holds earlier turns, and
            # auto_save_history=True appends whatever we send to it. Sending
            # only the new user message avoids re-appending (duplicating)
            # prior turns on every call.
            messages = self._build_message_history(
                session.session_id, max_messages=1
            )

            stream = await asyncio.to_thread(
                self.client.chat.stream,
                bot_id=session.bot_id,
                # chat.stream requires the end-user id; Coze expects a string.
                user_id=str(session.user_id),
                conversation_id=session.conversation_id,
                additional_messages=messages,
                auto_save_history=True,
            )

            async for event in self._process_stream(stream, session.session_id):
                yield event

        except asyncio.TimeoutError:
            logger.error(f"消息发送超时: session_id={request.session_id}")
            raise CozeTimeoutError("消息处理超时")
        except Exception as e:
            logger.error(f"发送消息失败: {e}", exc_info=True)
            raise map_coze_error_to_exception(e)

    async def end_session(
        self, session_id: str, request: EndSessionRequest
    ) -> EndSessionResponse:
        """Mark a session as ended and report its statistics.

        Args:
            session_id: Local session id.
            request: End-session request (optional reason and feedback).

        Returns:
            EndSessionResponse: end time, duration in whole seconds, and the
            number of locally recorded messages.

        Raises:
            CozeAPIError: if the session does not exist.
        """
        session = self._get_session(session_id)
        if not session:
            raise CozeAPIError(f"会话不存在: {session_id}")

        session.ended_at = datetime.now()

        duration_seconds = int((session.ended_at - session.created_at).total_seconds())
        message_count = len(self._messages.get(session_id, []))

        if request.reason or request.feedback:
            # NOTE(review): metadata is copied from the creation request and
            # may be None; make sure there is a dict to write into.
            if session.metadata is None:
                session.metadata = {}
            if request.reason:
                session.metadata["end_reason"] = request.reason
            if request.feedback:
                session.metadata["feedback"] = request.feedback

        logger.info(
            "会话结束",
            extra={
                "session_id": session_id,
                "duration_seconds": duration_seconds,
                "message_count": message_count,
                "reason": request.reason,
            },
        )

        return EndSessionResponse(
            session_id=session_id,
            ended_at=session.ended_at,
            duration_seconds=duration_seconds,
            message_count=message_count,
        )

    async def get_session_messages(
        self, session_id: str, limit: int = 50, offset: int = 0
    ) -> List[CozeMessage]:
        """Return a page of a session's recorded messages, oldest first.

        Args:
            session_id: Local session id; unknown ids yield an empty list.
            limit: Maximum number of messages to return (negative -> 0).
            offset: Number of leading messages to skip (negative -> 0).
        """
        # Clamp so negative values cannot turn into from-the-end slicing.
        offset = max(offset, 0)
        limit = max(limit, 0)
        messages = self._messages.get(session_id, [])
        return messages[offset : offset + limit]

    def _get_bot_id_by_type(self, session_type: SessionType) -> str:
        """Return the bot ID that serves ``session_type``.

        Session types without an explicit mapping use the ``"training"`` bot.
        """
        # Map the SessionType enum onto configuration keys; exams reuse the
        # training bot.
        type_mapping = {
            SessionType.COURSE_CHAT: "course_chat",
            SessionType.TRAINING: "training",
            SessionType.EXAM: "training",
        }
        config_type = type_mapping.get(session_type, "training")
        return self._get_bot_config(config_type)["bot_id"]

    def _get_session(self, session_id: str) -> Optional[CozeSession]:
        """Return the local session record, or None if unknown."""
        return self._sessions.get(session_id)

    def _build_message_history(
        self, session_id: str, max_messages: int = 10
    ) -> List[Message]:
        """Convert the most recent local messages into Coze ``Message`` objects.

        Args:
            session_id: Local session id.
            max_messages: How many of the newest messages to include; values
                <= 0 yield an empty list.

        Returns:
            Messages in chronological order, all sent as text content.
        """
        if max_messages <= 0:
            # Guard: a slice of [-0:] would return the entire history.
            return []
        recent = self._messages.get(session_id, [])[-max_messages:]
        return [
            Message(
                role=msg.role.value,
                content=msg.content,
                content_type=MessageContentType.TEXT,
            )
            for msg in recent
        ]

    @staticmethod
    def _event_conversation_id(event, fallback: Optional[str] = None) -> Optional[str]:
        """Best-effort lookup of the conversation id carried by a chat event.

        Checks the event itself, then ``event.message``, then ``event.chat``
        and returns ``fallback`` when none of them provides one.
        NOTE(review): in cozepy the id appears to live on ``message``/``chat``
        rather than on the event; confirm against the pinned SDK version.
        """
        for source in (
            event,
            getattr(event, "message", None),
            getattr(event, "chat", None),
        ):
            conversation_id = getattr(source, "conversation_id", None)
            if conversation_id:
                return conversation_id
        return fallback

    async def _process_stream(
        self, stream, session_id: str
    ) -> AsyncIterator[StreamEvent]:
        """Translate a Coze chat event stream into ``StreamEvent`` objects.

        Args:
            stream: Iterable of chat events returned by ``client.chat.stream``.
                The SDK is synchronous (it is called via ``asyncio.to_thread``
                in this class), so each event is fetched in a worker thread.
            session_id: Local session whose message log receives completed
                assistant messages.

        Yields:
            MESSAGE_DELTA for each fragment, MESSAGE_COMPLETED when an
            assistant message finishes, ERROR for Coze error events or
            processing failures, and finally DONE (unless the consumer closes
            the generator early).
        """
        session = self._sessions.get(session_id)
        default_conversation_id = session.conversation_id if session else None

        end_of_stream = object()
        events = iter(stream)
        assistant_message_id = str(uuid.uuid4())
        accumulated_content: List[str] = []
        content_type = ContentType.TEXT

        try:
            while True:
                # Pull the next event in a worker thread so a slow reply does
                # not stall the event loop; the sentinel marks exhaustion.
                event = await asyncio.to_thread(next, events, end_of_stream)
                if event is end_of_stream:
                    break

                conversation_id = self._event_conversation_id(
                    event, default_conversation_id
                )

                if event.event == ChatEventType.CONVERSATION_MESSAGE_DELTA:
                    content = event.message.content
                    accumulated_content.append(content)

                    # Once any fragment is a card, treat the message as a card.
                    if getattr(event.message, "content_type", None) == "card":
                        content_type = ContentType.CARD

                    yield StreamEvent(
                        event=StreamEventType.MESSAGE_DELTA,
                        data={
                            "conversation_id": conversation_id,
                            "message_id": assistant_message_id,
                            "content": content,
                            "content_type": content_type.value,
                        },
                        message_id=assistant_message_id,
                        content=content,
                        content_type=content_type,
                        role=MessageRole.ASSISTANT,
                    )

                elif event.event == ChatEventType.CONVERSATION_MESSAGE_COMPLETED:
                    if not accumulated_content:
                        # A completion with no preceding deltas (presumably
                        # follow-up suggestions or verbose/tool messages) is
                        # not a streamed answer; saving it would store the
                        # answer text again as a duplicate message.
                        continue

                    full_content = "".join(accumulated_content)

                    assistant_message = CozeMessage(
                        message_id=assistant_message_id,
                        session_id=session_id,
                        role=MessageRole.ASSISTANT,
                        content=full_content,
                        content_type=content_type,
                    )
                    self._messages[session_id].append(assistant_message)

                    yield StreamEvent(
                        event=StreamEventType.MESSAGE_COMPLETED,
                        data={
                            "conversation_id": conversation_id,
                            "message_id": assistant_message_id,
                            "content": full_content,
                            "content_type": content_type.value,
                            "usage": getattr(event, "usage", {}),
                        },
                        message_id=assistant_message_id,
                        content=full_content,
                        content_type=content_type,
                        role=MessageRole.ASSISTANT,
                    )

                    # Start a fresh buffer and id in case another answer
                    # message follows in the same reply.
                    assistant_message_id = str(uuid.uuid4())
                    accumulated_content = []
                    content_type = ContentType.TEXT

                elif event.event == ChatEventType.ERROR:
                    yield StreamEvent(
                        event=StreamEventType.ERROR,
                        data={"error": str(event)},
                        error=str(event),
                    )

        except Exception as e:
            logger.error(f"流式处理错误: {e}", exc_info=True)
            yield StreamEvent(
                event=StreamEventType.ERROR, data={"error": str(e)}, error=str(e)
            )

        # Emitted after the try rather than in `finally`: yielding inside
        # `finally` while the consumer closes the generator raises
        # "RuntimeError: async generator ignored GeneratorExit".
        yield StreamEvent(
            event=StreamEventType.DONE, data={"session_id": session_id}
        )
|
||
|
||
|
||
# Process-wide service instance; stays None until get_coze_service() creates it.
_service: Optional[CozeService] = None
|
||
|
||
|
||
def get_coze_service() -> CozeService:
    """Return the shared CozeService, constructing it on the first call."""
    global _service
    if _service is not None:
        return _service
    _service = CozeService()
    return _service
|