"""Coze service layer.

Handles session management, message sending, and streaming responses.
"""

import asyncio
import json
import logging
import uuid
from typing import AsyncIterator, Dict, Any, List, Optional
from datetime import datetime

from cozepy import ChatEventType, Message, MessageContentType

from .client import get_coze_client, get_bot_config, get_workspace_id
from .models import (
    CozeSession,
    CozeMessage,
    StreamEvent,
    SessionType,
    MessageRole,
    ContentType,
    StreamEventType,
    CreateSessionRequest,
    CreateSessionResponse,
    SendMessageRequest,
    EndSessionRequest,
    EndSessionResponse,
)
from .exceptions import (
    CozeAPIError,
    CozeStreamError,
    CozeTimeoutError,
    map_coze_error_to_exception,
)

logger = logging.getLogger(__name__)

# Sentinel handed to ``next`` as its default so that exhaustion of the
# upstream stream is reported as a value instead of ``StopIteration``
# (which cannot be propagated through an asyncio future).
_STREAM_END = object()


class CozeService:
    """Service object wrapping the Coze client with local session bookkeeping."""

    def __init__(self):
        self.client = get_coze_client()
        self.workspace_id = get_workspace_id()
        # In-memory session storage (production should use Redis).
        self._sessions: Dict[str, CozeSession] = {}
        self._messages: Dict[str, List[CozeMessage]] = {}

    def _get_bot_config(self, session_type: str) -> Dict[str, Any]:
        """Return the bot configuration for the given session type string."""
        return get_bot_config(session_type)

    async def create_session(
        self, request: CreateSessionRequest
    ) -> CreateSessionResponse:
        """Create a new session backed by a new Coze conversation.

        Args:
            request: Session creation request (session type, user id, metadata).

        Returns:
            CreateSessionResponse: Identifiers and creation time of the session.

        Raises:
            Exception: Whatever ``map_coze_error_to_exception`` produces for the
                underlying failure; the original error is chained as the cause.
        """
        try:
            # Pick the bot according to the session type.
            bot_id = self._get_bot_id_by_type(request.session_type)

            # Create the Coze conversation; the client call is synchronous,
            # so it is run in a worker thread.
            conversation = await asyncio.to_thread(
                self.client.conversations.create, bot_id=bot_id
            )

            # Create the local session record.
            session = CozeSession(
                session_id=str(uuid.uuid4()),
                conversation_id=conversation.id,
                session_type=request.session_type,
                user_id=request.user_id,
                bot_id=bot_id,
                metadata=request.metadata,
            )

            # Store the session together with an empty message list.
            self._sessions[session.session_id] = session
            self._messages[session.session_id] = []

            logger.info(
                "创建会话成功",
                extra={
                    "session_id": session.session_id,
                    "conversation_id": conversation.id,
                    "session_type": request.session_type.value,
                    "user_id": request.user_id,
                },
            )

            return CreateSessionResponse(
                session_id=session.session_id,
                conversation_id=session.conversation_id,
                bot_id=session.bot_id,
                created_at=session.created_at,
            )
        except Exception as exc:
            logger.error(f"创建会话失败: {exc}", exc_info=True)
            raise map_coze_error_to_exception(exc) from exc

    async def send_message(
        self, request: SendMessageRequest
    ) -> AsyncIterator[StreamEvent]:
        """Send a user message and relay the streamed response.

        This is an async generator, so nothing runs (and nothing is raised)
        until the caller starts iterating.

        Args:
            request: Message request carrying the session id and content.

        Yields:
            StreamEvent: Events produced by ``_process_stream``.

        Raises:
            CozeAPIError: If the session id is unknown.
            CozeTimeoutError: If an ``asyncio.TimeoutError`` surfaces.
            Exception: Whatever ``map_coze_error_to_exception`` produces for
                any other failure; the original error is chained as the cause.
        """
        session = self._get_session(request.session_id)
        if not session:
            raise CozeAPIError(f"会话不存在: {request.session_id}")

        # Record the user message locally before calling the API.
        user_message = CozeMessage(
            message_id=str(uuid.uuid4()),
            session_id=session.session_id,
            role=MessageRole.USER,
            content=request.content,
        )
        self._messages[session.session_id].append(user_message)

        try:
            # Build the message history (includes the message just recorded).
            # NOTE(review): history is resent as additional_messages while
            # auto_save_history=True and a conversation_id is supplied; this
            # may duplicate context on the Coze side - confirm intended.
            messages = self._build_message_history(session.session_id)

            # Call the Coze API in a worker thread (synchronous client).
            stream = await asyncio.to_thread(
                self.client.chat.stream,
                bot_id=session.bot_id,
                conversation_id=session.conversation_id,
                additional_messages=messages,
                auto_save_history=True,
            )

            # Relay the processed stream events to the caller.
            async for event in self._process_stream(stream, session.session_id):
                yield event

        except asyncio.TimeoutError as exc:
            logger.error(f"消息发送超时: session_id={request.session_id}")
            raise CozeTimeoutError("消息处理超时") from exc
        except Exception as exc:
            logger.error(f"发送消息失败: {exc}", exc_info=True)
            raise map_coze_error_to_exception(exc) from exc

    async def end_session(
        self, session_id: str, request: EndSessionRequest
    ) -> EndSessionResponse:
        """Mark a session as ended and return its statistics.

        Args:
            session_id: Local session id.
            request: End-session request with optional reason and feedback.

        Returns:
            EndSessionResponse: End time, duration in seconds, message count.

        Raises:
            CozeAPIError: If the session id is unknown.
        """
        session = self._get_session(session_id)
        if not session:
            raise CozeAPIError(f"会话不存在: {session_id}")

        # Update the session state.
        session.ended_at = datetime.now()

        # Compute session statistics.
        duration_seconds = int((session.ended_at - session.created_at).total_seconds())
        message_count = len(self._messages.get(session_id, []))

        # Record the end reason and feedback. The metadata comes straight from
        # the creation request, so guard against it being None before writing.
        if (request.reason or request.feedback) and session.metadata is None:
            session.metadata = {}
        if request.reason:
            session.metadata["end_reason"] = request.reason
        if request.feedback:
            session.metadata["feedback"] = request.feedback

        logger.info(
            "会话结束",
            extra={
                "session_id": session_id,
                "duration_seconds": duration_seconds,
                "message_count": message_count,
                "reason": request.reason,
            },
        )

        return EndSessionResponse(
            session_id=session_id,
            ended_at=session.ended_at,
            duration_seconds=duration_seconds,
            message_count=message_count,
        )

    async def get_session_messages(
        self, session_id: str, limit: int = 50, offset: int = 0
    ) -> List[CozeMessage]:
        """Return a page of the locally stored messages for a session.

        Args:
            session_id: Local session id; unknown ids yield an empty list.
            limit: Maximum number of messages to return (negative treated as 0).
            offset: Number of messages to skip (negative treated as 0).

        Returns:
            The requested slice of the stored message list.
        """
        messages = self._messages.get(session_id, [])
        # Clamp so negative values cannot turn into from-the-end slicing.
        start = max(offset, 0)
        count = max(limit, 0)
        return messages[start : start + count]

    def _get_bot_id_by_type(self, session_type: SessionType) -> str:
        """Return the bot id configured for the given session type."""
        # Map the SessionType enum onto configuration keys.
        type_mapping = {
            SessionType.COURSE_CHAT: "course_chat",
            SessionType.TRAINING: "training",
            SessionType.EXAM: "training",  # exams reuse the training bot
        }

        # Unmapped session types fall back to the training bot.
        config_type = type_mapping.get(session_type, "training")
        bot_config = get_bot_config(config_type)
        return bot_config["bot_id"]

    def _get_session(self, session_id: str) -> Optional[CozeSession]:
        """Return the stored session, or None if the id is unknown."""
        return self._sessions.get(session_id)

    def _build_message_history(self, session_id: str) -> List[Message]:
        """Convert the most recent stored messages into Coze ``Message`` objects."""
        messages = self._messages.get(session_id, [])

        # Only the 10 most recent messages are sent as context.
        return [
            Message(
                role=msg.role.value,
                content=msg.content,
                content_type=MessageContentType.TEXT,
            )
            for msg in messages[-10:]
        ]

    async def _process_stream(
        self, stream, session_id: str
    ) -> AsyncIterator[StreamEvent]:
        """Translate upstream chat events into ``StreamEvent`` objects.

        Delta events are forwarded and accumulated; a completed event stores
        the accumulated assistant message and emits MESSAGE_COMPLETED; error
        events and exceptions are emitted as ERROR events. A DONE event is
        emitted last unless the generator is closed or cancelled early.

        Args:
            stream: Iterable of chat events returned by the client.
            session_id: Local session id the assistant message is stored under.

        Yields:
            StreamEvent: MESSAGE_DELTA, MESSAGE_COMPLETED, ERROR and DONE events.
        """
        assistant_message_id = str(uuid.uuid4())
        accumulated_content: List[str] = []
        content_type = ContentType.TEXT

        try:
            # The stream comes from the synchronous client, so advancing it is
            # assumed to block on network reads; pull each event in a worker
            # thread to keep the event loop responsive.
            iterator = iter(stream)
            while True:
                event = await asyncio.to_thread(next, iterator, _STREAM_END)
                if event is _STREAM_END:
                    break

                if event.event == ChatEventType.CONVERSATION_MESSAGE_DELTA:
                    # Message fragment.
                    content = event.message.content
                    if content:
                        # Skip None/empty so the final join cannot fail.
                        accumulated_content.append(content)

                    # Detect card content.
                    if getattr(event.message, "content_type", None) == "card":
                        content_type = ContentType.CARD

                    yield StreamEvent(
                        event=StreamEventType.MESSAGE_DELTA,
                        data={
                            "conversation_id": event.conversation_id,
                            "message_id": assistant_message_id,
                            "content": content,
                            "content_type": content_type.value,
                        },
                        message_id=assistant_message_id,
                        content=content,
                        content_type=content_type,
                        role=MessageRole.ASSISTANT,
                    )

                elif event.event == ChatEventType.CONVERSATION_MESSAGE_COMPLETED:
                    # Message completed.
                    # NOTE(review): every completed event stores a message with
                    # the same id and the full accumulated text; if the
                    # upstream emits several completed events per chat this
                    # records duplicates - confirm against the Coze event docs.
                    full_content = "".join(accumulated_content)

                    # Persist the assistant message locally.
                    assistant_message = CozeMessage(
                        message_id=assistant_message_id,
                        session_id=session_id,
                        role=MessageRole.ASSISTANT,
                        content=full_content,
                        content_type=content_type,
                    )
                    self._messages[session_id].append(assistant_message)

                    yield StreamEvent(
                        event=StreamEventType.MESSAGE_COMPLETED,
                        data={
                            "conversation_id": event.conversation_id,
                            "message_id": assistant_message_id,
                            "content": full_content,
                            "content_type": content_type.value,
                            "usage": getattr(event, "usage", {}),
                        },
                        message_id=assistant_message_id,
                        content=full_content,
                        content_type=content_type,
                        role=MessageRole.ASSISTANT,
                    )

                elif event.event == ChatEventType.ERROR:
                    # Error event reported by the upstream.
                    yield StreamEvent(
                        event=StreamEventType.ERROR,
                        data={"error": str(event)},
                        error=str(event),
                    )

        except Exception as exc:
            logger.error(f"流式处理错误: {exc}", exc_info=True)
            yield StreamEvent(
                event=StreamEventType.ERROR, data={"error": str(exc)}, error=str(exc)
            )

        # Emit the terminating event. This is deliberately outside a
        # ``finally`` clause: yielding while the generator is being closed
        # raises "async generator ignored GeneratorExit".
        yield StreamEvent(
            event=StreamEventType.DONE, data={"session_id": session_id}
        )


# Module-level service instance, created lazily.
_service: Optional[CozeService] = None


def get_coze_service() -> CozeService:
    """Return the shared ``CozeService`` instance, creating it on first use."""
    global _service
    if _service is None:
        _service = CozeService()
    return _service