feat: 初始化考培练系统项目
- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
This commit is contained in:
335
backend/app/services/ai/coze/service.py
Normal file
335
backend/app/services/ai/coze/service.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Coze 服务层实现
|
||||
处理会话管理、消息发送、流式响应等核心功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import AsyncIterator, Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from cozepy import ChatEventType, Message, MessageContentType
|
||||
|
||||
from .client import get_coze_client, get_bot_config, get_workspace_id
|
||||
from .models import (
|
||||
CozeSession,
|
||||
CozeMessage,
|
||||
StreamEvent,
|
||||
SessionType,
|
||||
MessageRole,
|
||||
ContentType,
|
||||
StreamEventType,
|
||||
CreateSessionRequest,
|
||||
CreateSessionResponse,
|
||||
SendMessageRequest,
|
||||
EndSessionRequest,
|
||||
EndSessionResponse,
|
||||
)
|
||||
from .exceptions import (
|
||||
CozeAPIError,
|
||||
CozeStreamError,
|
||||
CozeTimeoutError,
|
||||
map_coze_error_to_exception,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CozeService:
|
||||
"""Coze 服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self.client = get_coze_client()
|
||||
self.bot_config = get_bot_config()
|
||||
self.workspace_id = get_workspace_id()
|
||||
|
||||
# 内存中的会话存储(生产环境应使用 Redis)
|
||||
self._sessions: Dict[str, CozeSession] = {}
|
||||
self._messages: Dict[str, List[CozeMessage]] = {}
|
||||
|
||||
async def create_session(
|
||||
self, request: CreateSessionRequest
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
创建新会话
|
||||
|
||||
Args:
|
||||
request: 创建会话请求
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: 会话信息
|
||||
"""
|
||||
try:
|
||||
# 根据会话类型选择 Bot
|
||||
bot_id = self._get_bot_id_by_type(request.session_type)
|
||||
|
||||
# 创建 Coze 对话
|
||||
conversation = await asyncio.to_thread(
|
||||
self.client.conversations.create, bot_id=bot_id
|
||||
)
|
||||
|
||||
# 创建本地会话记录
|
||||
session = CozeSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
conversation_id=conversation.id,
|
||||
session_type=request.session_type,
|
||||
user_id=request.user_id,
|
||||
bot_id=bot_id,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
|
||||
# 保存会话
|
||||
self._sessions[session.session_id] = session
|
||||
self._messages[session.session_id] = []
|
||||
|
||||
logger.info(
|
||||
f"创建会话成功",
|
||||
extra={
|
||||
"session_id": session.session_id,
|
||||
"conversation_id": conversation.id,
|
||||
"session_type": request.session_type.value,
|
||||
"user_id": request.user_id,
|
||||
},
|
||||
)
|
||||
|
||||
return CreateSessionResponse(
|
||||
session_id=session.session_id,
|
||||
conversation_id=session.conversation_id,
|
||||
bot_id=session.bot_id,
|
||||
created_at=session.created_at,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建会话失败: {e}", exc_info=True)
|
||||
raise map_coze_error_to_exception(e)
|
||||
|
||||
async def send_message(
|
||||
self, request: SendMessageRequest
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""
|
||||
发送消息并处理流式响应
|
||||
|
||||
Args:
|
||||
request: 发送消息请求
|
||||
|
||||
Yields:
|
||||
StreamEvent: 流式事件
|
||||
"""
|
||||
session = self._get_session(request.session_id)
|
||||
if not session:
|
||||
raise CozeAPIError(f"会话不存在: {request.session_id}")
|
||||
|
||||
# 记录用户消息
|
||||
user_message = CozeMessage(
|
||||
message_id=str(uuid.uuid4()),
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content=request.content,
|
||||
)
|
||||
self._messages[session.session_id].append(user_message)
|
||||
|
||||
try:
|
||||
# 构建消息历史
|
||||
messages = self._build_message_history(session.session_id)
|
||||
|
||||
# 调用 Coze API
|
||||
stream = await asyncio.to_thread(
|
||||
self.client.chat.stream,
|
||||
bot_id=session.bot_id,
|
||||
conversation_id=session.conversation_id,
|
||||
additional_messages=messages,
|
||||
auto_save_history=True,
|
||||
)
|
||||
|
||||
# 处理流式响应
|
||||
async for event in self._process_stream(stream, session.session_id):
|
||||
yield event
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"消息发送超时: session_id={request.session_id}")
|
||||
raise CozeTimeoutError("消息处理超时")
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}", exc_info=True)
|
||||
raise map_coze_error_to_exception(e)
|
||||
|
||||
async def end_session(
|
||||
self, session_id: str, request: EndSessionRequest
|
||||
) -> EndSessionResponse:
|
||||
"""
|
||||
结束会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
request: 结束会话请求
|
||||
|
||||
Returns:
|
||||
EndSessionResponse: 结束会话响应
|
||||
"""
|
||||
session = self._get_session(session_id)
|
||||
if not session:
|
||||
raise CozeAPIError(f"会话不存在: {session_id}")
|
||||
|
||||
# 更新会话状态
|
||||
session.ended_at = datetime.now()
|
||||
|
||||
# 计算会话统计
|
||||
duration_seconds = int((session.ended_at - session.created_at).total_seconds())
|
||||
message_count = len(self._messages.get(session_id, []))
|
||||
|
||||
# 记录结束原因和反馈
|
||||
if request.reason:
|
||||
session.metadata["end_reason"] = request.reason
|
||||
if request.feedback:
|
||||
session.metadata["feedback"] = request.feedback
|
||||
|
||||
logger.info(
|
||||
f"会话结束",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"duration_seconds": duration_seconds,
|
||||
"message_count": message_count,
|
||||
"reason": request.reason,
|
||||
},
|
||||
)
|
||||
|
||||
return EndSessionResponse(
|
||||
session_id=session_id,
|
||||
ended_at=session.ended_at,
|
||||
duration_seconds=duration_seconds,
|
||||
message_count=message_count,
|
||||
)
|
||||
|
||||
async def get_session_messages(
|
||||
self, session_id: str, limit: int = 50, offset: int = 0
|
||||
) -> List[CozeMessage]:
|
||||
"""获取会话消息历史"""
|
||||
messages = self._messages.get(session_id, [])
|
||||
return messages[offset : offset + limit]
|
||||
|
||||
def _get_bot_id_by_type(self, session_type: SessionType) -> str:
|
||||
"""根据会话类型获取 Bot ID"""
|
||||
mapping = {
|
||||
SessionType.COURSE_CHAT: self.bot_config["course_chat"],
|
||||
SessionType.TRAINING: self.bot_config["training"],
|
||||
SessionType.EXAM: self.bot_config["exam"],
|
||||
}
|
||||
return mapping.get(session_type, self.bot_config["training"])
|
||||
|
||||
def _get_session(self, session_id: str) -> Optional[CozeSession]:
|
||||
"""获取会话"""
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def _build_message_history(self, session_id: str) -> List[Message]:
|
||||
"""构建消息历史"""
|
||||
messages = self._messages.get(session_id, [])
|
||||
history = []
|
||||
|
||||
for msg in messages[-10:]: # 只发送最近10条消息作为上下文
|
||||
history.append(
|
||||
Message(
|
||||
role=msg.role.value,
|
||||
content=msg.content,
|
||||
content_type=MessageContentType.TEXT,
|
||||
)
|
||||
)
|
||||
|
||||
return history
|
||||
|
||||
async def _process_stream(
|
||||
self, stream, session_id: str
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""处理流式响应"""
|
||||
assistant_message_id = str(uuid.uuid4())
|
||||
accumulated_content = []
|
||||
content_type = ContentType.TEXT
|
||||
|
||||
try:
|
||||
for event in stream:
|
||||
if event.event == ChatEventType.CONVERSATION_MESSAGE_DELTA:
|
||||
# 消息片段
|
||||
content = event.message.content
|
||||
accumulated_content.append(content)
|
||||
|
||||
# 检测卡片类型
|
||||
if (
|
||||
hasattr(event.message, "content_type")
|
||||
and event.message.content_type == "card"
|
||||
):
|
||||
content_type = ContentType.CARD
|
||||
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.MESSAGE_DELTA,
|
||||
data={
|
||||
"conversation_id": event.conversation_id,
|
||||
"message_id": assistant_message_id,
|
||||
"content": content,
|
||||
"content_type": content_type.value,
|
||||
},
|
||||
message_id=assistant_message_id,
|
||||
content=content,
|
||||
content_type=content_type,
|
||||
role=MessageRole.ASSISTANT,
|
||||
)
|
||||
|
||||
elif event.event == ChatEventType.CONVERSATION_MESSAGE_COMPLETED:
|
||||
# 消息完成
|
||||
full_content = "".join(accumulated_content)
|
||||
|
||||
# 保存助手消息
|
||||
assistant_message = CozeMessage(
|
||||
message_id=assistant_message_id,
|
||||
session_id=session_id,
|
||||
role=MessageRole.ASSISTANT,
|
||||
content=full_content,
|
||||
content_type=content_type,
|
||||
)
|
||||
self._messages[session_id].append(assistant_message)
|
||||
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.MESSAGE_COMPLETED,
|
||||
data={
|
||||
"conversation_id": event.conversation_id,
|
||||
"message_id": assistant_message_id,
|
||||
"content": full_content,
|
||||
"content_type": content_type.value,
|
||||
"usage": getattr(event, "usage", {}),
|
||||
},
|
||||
message_id=assistant_message_id,
|
||||
content=full_content,
|
||||
content_type=content_type,
|
||||
role=MessageRole.ASSISTANT,
|
||||
)
|
||||
|
||||
elif event.event == ChatEventType.ERROR:
|
||||
# 错误事件
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.ERROR,
|
||||
data={"error": str(event)},
|
||||
error=str(event),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式处理错误: {e}", exc_info=True)
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.ERROR, data={"error": str(e)}, error=str(e)
|
||||
)
|
||||
finally:
|
||||
# 发送结束事件
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.DONE, data={"session_id": session_id}
|
||||
)
|
||||
|
||||
|
||||
# 全局服务实例
|
||||
_service: Optional[CozeService] = None
|
||||
|
||||
|
||||
def get_coze_service() -> CozeService:
|
||||
"""获取 Coze 服务单例"""
|
||||
global _service
|
||||
if _service is None:
|
||||
_service = CozeService()
|
||||
return _service
|
||||
Reference in New Issue
Block a user