feat: 初始化考培练系统项目
- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
This commit is contained in:
61
backend/app/services/ai/coze/__init__.py
Normal file
61
backend/app/services/ai/coze/__init__.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Coze AI 服务模块
|
||||
"""
|
||||
|
||||
from .client import get_coze_client, get_auth_manager, get_bot_config, get_workspace_id
|
||||
from .service import get_coze_service, CozeService
|
||||
from .models import (
|
||||
SessionType,
|
||||
MessageRole,
|
||||
ContentType,
|
||||
StreamEventType,
|
||||
CozeSession,
|
||||
CozeMessage,
|
||||
StreamEvent,
|
||||
CreateSessionRequest,
|
||||
CreateSessionResponse,
|
||||
SendMessageRequest,
|
||||
EndSessionRequest,
|
||||
EndSessionResponse,
|
||||
)
|
||||
from .exceptions import (
|
||||
CozeException,
|
||||
CozeAuthError,
|
||||
CozeAPIError,
|
||||
CozeRateLimitError,
|
||||
CozeTimeoutError,
|
||||
CozeStreamError,
|
||||
map_coze_error_to_exception,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Client
|
||||
"get_coze_client",
|
||||
"get_auth_manager",
|
||||
"get_bot_config",
|
||||
"get_workspace_id",
|
||||
# Service
|
||||
"get_coze_service",
|
||||
"CozeService",
|
||||
# Models
|
||||
"SessionType",
|
||||
"MessageRole",
|
||||
"ContentType",
|
||||
"StreamEventType",
|
||||
"CozeSession",
|
||||
"CozeMessage",
|
||||
"StreamEvent",
|
||||
"CreateSessionRequest",
|
||||
"CreateSessionResponse",
|
||||
"SendMessageRequest",
|
||||
"EndSessionRequest",
|
||||
"EndSessionResponse",
|
||||
# Exceptions
|
||||
"CozeException",
|
||||
"CozeAuthError",
|
||||
"CozeAPIError",
|
||||
"CozeRateLimitError",
|
||||
"CozeTimeoutError",
|
||||
"CozeStreamError",
|
||||
"map_coze_error_to_exception",
|
||||
]
|
||||
203
backend/app/services/ai/coze/client.py
Normal file
203
backend/app/services/ai/coze/client.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Coze AI 客户端管理
|
||||
负责管理 Coze API 的认证和客户端实例
|
||||
"""
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from cozepy import Coze, TokenAuth, JWTAuth, COZE_CN_BASE_URL
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CozeAuthManager:
|
||||
"""Coze 认证管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self._client: Optional[Coze] = None
|
||||
|
||||
def _create_pat_auth(self) -> TokenAuth:
|
||||
"""创建个人访问令牌认证"""
|
||||
if not self.settings.COZE_API_TOKEN:
|
||||
raise ValueError("COZE_API_TOKEN 未配置")
|
||||
|
||||
return TokenAuth(token=self.settings.COZE_API_TOKEN)
|
||||
|
||||
def _create_oauth_auth(self) -> JWTAuth:
|
||||
"""创建 OAuth 认证"""
|
||||
if not all(
|
||||
[
|
||||
self.settings.COZE_OAUTH_CLIENT_ID,
|
||||
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||||
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
|
||||
]
|
||||
):
|
||||
raise ValueError("OAuth 配置不完整")
|
||||
|
||||
# 读取私钥
|
||||
private_key_path = Path(self.settings.COZE_OAUTH_PRIVATE_KEY_PATH)
|
||||
if not private_key_path.exists():
|
||||
raise FileNotFoundError(f"私钥文件不存在: {private_key_path}")
|
||||
|
||||
with open(private_key_path, "r") as f:
|
||||
private_key = f.read()
|
||||
|
||||
try:
|
||||
return JWTAuth(
|
||||
client_id=self.settings.COZE_OAUTH_CLIENT_ID,
|
||||
private_key=private_key,
|
||||
public_key_id=self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||||
base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL, # 使用中国区API
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建 OAuth 认证失败: {e}")
|
||||
raise
|
||||
|
||||
def get_client(self, force_new: bool = False) -> Coze:
|
||||
"""
|
||||
获取 Coze 客户端实例
|
||||
|
||||
Args:
|
||||
force_new: 是否强制创建新客户端(用于长时间运行的请求,避免token过期)
|
||||
|
||||
认证优先级:
|
||||
1. OAuth(推荐):配置完整时使用,自动刷新token
|
||||
2. PAT:仅当OAuth未配置时使用(注意:PAT会过期)
|
||||
"""
|
||||
if self._client is not None and not force_new:
|
||||
return self._client
|
||||
|
||||
auth = None
|
||||
auth_type = None
|
||||
|
||||
# 检查 OAuth 配置是否完整
|
||||
oauth_configured = all([
|
||||
self.settings.COZE_OAUTH_CLIENT_ID,
|
||||
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||||
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
|
||||
])
|
||||
|
||||
if oauth_configured:
|
||||
# OAuth 配置完整,必须使用 OAuth(不fallback到PAT)
|
||||
try:
|
||||
auth = self._create_oauth_auth()
|
||||
auth_type = "OAuth"
|
||||
logger.info("使用 OAuth 认证")
|
||||
except Exception as e:
|
||||
# OAuth 配置完整但创建失败,直接抛出异常(不fallback到可能过期的PAT)
|
||||
logger.error(f"OAuth 认证创建失败: {e}")
|
||||
raise ValueError(f"OAuth 认证失败,请检查私钥文件和配置: {e}")
|
||||
else:
|
||||
# OAuth 未配置,使用 PAT
|
||||
if self.settings.COZE_API_TOKEN:
|
||||
auth = self._create_pat_auth()
|
||||
auth_type = "PAT"
|
||||
logger.warning("使用 PAT 认证(注意:PAT会过期,建议配置OAuth)")
|
||||
else:
|
||||
raise ValueError("Coze 认证未配置:需要配置 OAuth 或 PAT Token")
|
||||
|
||||
# 创建客户端
|
||||
client = Coze(
|
||||
auth=auth, base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL
|
||||
)
|
||||
|
||||
logger.debug(f"Coze客户端创建成功,认证方式: {auth_type}, force_new: {force_new}")
|
||||
|
||||
# 只有非强制创建时才缓存
|
||||
if not force_new:
|
||||
self._client = client
|
||||
|
||||
return client
|
||||
|
||||
def reset(self):
|
||||
"""重置客户端实例"""
|
||||
self._client = None
|
||||
|
||||
def get_oauth_token(self) -> str:
|
||||
"""
|
||||
获取OAuth JWT Token用于前端直连
|
||||
|
||||
Returns:
|
||||
JWT token字符串
|
||||
"""
|
||||
if not all([
|
||||
self.settings.COZE_OAUTH_CLIENT_ID,
|
||||
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||||
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
|
||||
]):
|
||||
raise ValueError("OAuth 配置不完整")
|
||||
|
||||
# 读取私钥
|
||||
private_key_path = Path(self.settings.COZE_OAUTH_PRIVATE_KEY_PATH)
|
||||
if not private_key_path.exists():
|
||||
raise FileNotFoundError(f"私钥文件不存在: {private_key_path}")
|
||||
|
||||
with open(private_key_path, "r") as f:
|
||||
private_key = f.read()
|
||||
|
||||
# 创建JWTAuth实例(必须指定中国区base_url)
|
||||
jwt_auth = JWTAuth(
|
||||
client_id=self.settings.COZE_OAUTH_CLIENT_ID,
|
||||
private_key=private_key,
|
||||
public_key_id=self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||||
base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL, # 使用中国区API
|
||||
)
|
||||
|
||||
# 获取token(JWTAuth内部会自动生成)
|
||||
# JWTAuth.token属性返回已签名的JWT
|
||||
return jwt_auth.token
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_auth_manager() -> CozeAuthManager:
|
||||
"""获取认证管理器单例"""
|
||||
return CozeAuthManager()
|
||||
|
||||
|
||||
def get_coze_client(force_new: bool = False) -> Coze:
|
||||
"""
|
||||
获取 Coze 客户端
|
||||
|
||||
Args:
|
||||
force_new: 是否强制创建新客户端(用于工作流等长时间运行的请求)
|
||||
"""
|
||||
return get_auth_manager().get_client(force_new=force_new)
|
||||
|
||||
|
||||
def get_workspace_id() -> str:
|
||||
"""获取工作空间 ID"""
|
||||
settings = get_settings()
|
||||
if not settings.COZE_WORKSPACE_ID:
|
||||
raise ValueError("COZE_WORKSPACE_ID 未配置")
|
||||
return settings.COZE_WORKSPACE_ID
|
||||
|
||||
|
||||
def get_bot_config(session_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
根据会话类型获取 Bot 配置
|
||||
|
||||
Args:
|
||||
session_type: 会话类型 (course_chat 或 training)
|
||||
|
||||
Returns:
|
||||
包含 bot_id 等配置的字典
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
if session_type == "course_chat":
|
||||
bot_id = settings.COZE_CHAT_BOT_ID
|
||||
if not bot_id:
|
||||
raise ValueError("COZE_CHAT_BOT_ID 未配置")
|
||||
elif session_type == "training":
|
||||
bot_id = settings.COZE_TRAINING_BOT_ID
|
||||
if not bot_id:
|
||||
raise ValueError("COZE_TRAINING_BOT_ID 未配置")
|
||||
else:
|
||||
raise ValueError(f"不支持的会话类型: {session_type}")
|
||||
|
||||
return {"bot_id": bot_id, "workspace_id": settings.COZE_WORKSPACE_ID}
|
||||
44
backend/app/services/ai/coze/client_backup.py
Normal file
44
backend/app/services/ai/coze/client_backup.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Coze客户端(临时模拟,等Agent-Coze实现后替换)"""
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CozeClient:
|
||||
"""
|
||||
Coze客户端模拟类
|
||||
TODO: 等Agent-Coze模块实现后,这个类将被真实的Coze网关客户端替换
|
||||
"""
|
||||
|
||||
async def create_conversation(
|
||||
self, bot_id: str, user_id: str, meta_data: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""创建会话(模拟)"""
|
||||
logger.info(f"模拟创建Coze会话: bot_id={bot_id}, user_id={user_id}")
|
||||
|
||||
# 返回模拟的会话信息
|
||||
return {
|
||||
"conversation_id": f"mock_conversation_{user_id}_{bot_id[:8]}",
|
||||
"bot_id": bot_id,
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
async def send_message(
|
||||
self, conversation_id: str, content: str, message_type: str = "text"
|
||||
) -> Dict[str, Any]:
|
||||
"""发送消息(模拟)"""
|
||||
logger.info(f"模拟发送消息到会话 {conversation_id}: {content[:50]}...")
|
||||
|
||||
# 返回模拟的消息响应
|
||||
return {
|
||||
"message_id": f"mock_msg_{conversation_id[:8]}",
|
||||
"content": f"这是对'{content[:30]}...'的模拟回复",
|
||||
"role": "assistant",
|
||||
}
|
||||
|
||||
async def end_conversation(self, conversation_id: str) -> Dict[str, Any]:
|
||||
"""结束会话(模拟)"""
|
||||
logger.info(f"模拟结束会话: {conversation_id}")
|
||||
|
||||
return {"status": "completed", "conversation_id": conversation_id}
|
||||
101
backend/app/services/ai/coze/exceptions.py
Normal file
101
backend/app/services/ai/coze/exceptions.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Coze 服务异常定义
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class CozeException(Exception):
|
||||
"""Coze 服务基础异常"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: Optional[str] = None,
|
||||
status_code: Optional[int] = None,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.status_code = status_code
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class CozeAuthError(CozeException):
|
||||
"""认证异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CozeAPIError(CozeException):
|
||||
"""API 调用异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CozeRateLimitError(CozeException):
|
||||
"""速率限制异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CozeTimeoutError(CozeException):
|
||||
"""超时异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CozeStreamError(CozeException):
|
||||
"""流式响应异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def map_coze_error_to_exception(error: Exception) -> CozeException:
|
||||
"""
|
||||
将 Coze SDK 错误映射为统一异常
|
||||
|
||||
Args:
|
||||
error: 原始异常
|
||||
|
||||
Returns:
|
||||
CozeException: 映射后的异常
|
||||
"""
|
||||
error_message = str(error)
|
||||
|
||||
# 根据错误消息判断错误类型
|
||||
if (
|
||||
"authentication" in error_message.lower()
|
||||
or "unauthorized" in error_message.lower()
|
||||
):
|
||||
return CozeAuthError(
|
||||
message="Coze 认证失败",
|
||||
code="COZE_AUTH_ERROR",
|
||||
status_code=401,
|
||||
details={"original_error": error_message},
|
||||
)
|
||||
|
||||
if "rate limit" in error_message.lower():
|
||||
return CozeRateLimitError(
|
||||
message="Coze API 速率限制",
|
||||
code="COZE_RATE_LIMIT",
|
||||
status_code=429,
|
||||
details={"original_error": error_message},
|
||||
)
|
||||
|
||||
if "timeout" in error_message.lower():
|
||||
return CozeTimeoutError(
|
||||
message="Coze API 调用超时",
|
||||
code="COZE_TIMEOUT",
|
||||
status_code=504,
|
||||
details={"original_error": error_message},
|
||||
)
|
||||
|
||||
# 默认映射为 API 错误
|
||||
return CozeAPIError(
|
||||
message="Coze API 调用失败",
|
||||
code="COZE_API_ERROR",
|
||||
status_code=500,
|
||||
details={"original_error": error_message},
|
||||
)
|
||||
136
backend/app/services/ai/coze/models.py
Normal file
136
backend/app/services/ai/coze/models.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Coze 服务数据模型
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Literal
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SessionType(str, Enum):
|
||||
"""会话类型"""
|
||||
|
||||
COURSE_CHAT = "course_chat" # 课程对话
|
||||
TRAINING = "training" # 陪练会话
|
||||
EXAM = "exam" # 考试会话
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
"""消息角色"""
|
||||
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class ContentType(str, Enum):
|
||||
"""内容类型"""
|
||||
|
||||
TEXT = "text"
|
||||
CARD = "card"
|
||||
IMAGE = "image"
|
||||
FILE = "file"
|
||||
|
||||
|
||||
class StreamEventType(str, Enum):
|
||||
"""流式事件类型"""
|
||||
|
||||
MESSAGE_START = "conversation.message.start"
|
||||
MESSAGE_DELTA = "conversation.message.delta"
|
||||
MESSAGE_COMPLETED = "conversation.message.completed"
|
||||
ERROR = "error"
|
||||
DONE = "done"
|
||||
|
||||
|
||||
class CozeSession(BaseModel):
|
||||
"""Coze 会话模型"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
conversation_id: str = Field(..., description="Coze对话ID")
|
||||
session_type: SessionType = Field(..., description="会话类型")
|
||||
user_id: str = Field(..., description="用户ID")
|
||||
bot_id: str = Field(..., description="Bot ID")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
|
||||
ended_at: Optional[datetime] = Field(None, description="结束时间")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class CozeMessage(BaseModel):
|
||||
"""Coze 消息模型"""
|
||||
|
||||
message_id: str = Field(..., description="消息ID")
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
role: MessageRole = Field(..., description="消息角色")
|
||||
content: str = Field(..., description="消息内容")
|
||||
content_type: ContentType = Field(ContentType.TEXT, description="内容类型")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class StreamEvent(BaseModel):
|
||||
"""流式事件模型"""
|
||||
|
||||
event: StreamEventType = Field(..., description="事件类型")
|
||||
data: Dict[str, Any] = Field(..., description="事件数据")
|
||||
message_id: Optional[str] = Field(None, description="消息ID")
|
||||
content: Optional[str] = Field(None, description="内容")
|
||||
content_type: Optional[ContentType] = Field(None, description="内容类型")
|
||||
role: Optional[MessageRole] = Field(None, description="角色")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""创建会话请求"""
|
||||
|
||||
session_type: SessionType = Field(..., description="会话类型")
|
||||
user_id: str = Field(..., description="用户ID")
|
||||
course_id: Optional[str] = Field(None, description="课程ID (课程对话时必需)")
|
||||
training_topic: Optional[str] = Field(None, description="陪练主题 (陪练时可选)")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""创建会话响应"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
conversation_id: str = Field(..., description="Coze对话ID")
|
||||
bot_id: str = Field(..., description="Bot ID")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
"""发送消息请求"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
content: str = Field(..., description="消息内容")
|
||||
file_ids: List[str] = Field(default_factory=list, description="附件ID列表")
|
||||
stream: bool = Field(True, description="是否流式响应")
|
||||
|
||||
|
||||
class EndSessionRequest(BaseModel):
|
||||
"""结束会话请求"""
|
||||
|
||||
reason: Optional[str] = Field(None, description="结束原因")
|
||||
feedback: Optional[Dict[str, Any]] = Field(None, description="用户反馈")
|
||||
|
||||
|
||||
class EndSessionResponse(BaseModel):
|
||||
"""结束会话响应"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
ended_at: datetime = Field(..., description="结束时间")
|
||||
duration_seconds: int = Field(..., description="会话时长(秒)")
|
||||
message_count: int = Field(..., description="消息数量")
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
335
backend/app/services/ai/coze/service.py
Normal file
335
backend/app/services/ai/coze/service.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Coze 服务层实现
|
||||
处理会话管理、消息发送、流式响应等核心功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import AsyncIterator, Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from cozepy import ChatEventType, Message, MessageContentType
|
||||
|
||||
from .client import get_coze_client, get_bot_config, get_workspace_id
|
||||
from .models import (
|
||||
CozeSession,
|
||||
CozeMessage,
|
||||
StreamEvent,
|
||||
SessionType,
|
||||
MessageRole,
|
||||
ContentType,
|
||||
StreamEventType,
|
||||
CreateSessionRequest,
|
||||
CreateSessionResponse,
|
||||
SendMessageRequest,
|
||||
EndSessionRequest,
|
||||
EndSessionResponse,
|
||||
)
|
||||
from .exceptions import (
|
||||
CozeAPIError,
|
||||
CozeStreamError,
|
||||
CozeTimeoutError,
|
||||
map_coze_error_to_exception,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CozeService:
|
||||
"""Coze 服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self.client = get_coze_client()
|
||||
self.bot_config = get_bot_config()
|
||||
self.workspace_id = get_workspace_id()
|
||||
|
||||
# 内存中的会话存储(生产环境应使用 Redis)
|
||||
self._sessions: Dict[str, CozeSession] = {}
|
||||
self._messages: Dict[str, List[CozeMessage]] = {}
|
||||
|
||||
async def create_session(
|
||||
self, request: CreateSessionRequest
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
创建新会话
|
||||
|
||||
Args:
|
||||
request: 创建会话请求
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: 会话信息
|
||||
"""
|
||||
try:
|
||||
# 根据会话类型选择 Bot
|
||||
bot_id = self._get_bot_id_by_type(request.session_type)
|
||||
|
||||
# 创建 Coze 对话
|
||||
conversation = await asyncio.to_thread(
|
||||
self.client.conversations.create, bot_id=bot_id
|
||||
)
|
||||
|
||||
# 创建本地会话记录
|
||||
session = CozeSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
conversation_id=conversation.id,
|
||||
session_type=request.session_type,
|
||||
user_id=request.user_id,
|
||||
bot_id=bot_id,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
|
||||
# 保存会话
|
||||
self._sessions[session.session_id] = session
|
||||
self._messages[session.session_id] = []
|
||||
|
||||
logger.info(
|
||||
f"创建会话成功",
|
||||
extra={
|
||||
"session_id": session.session_id,
|
||||
"conversation_id": conversation.id,
|
||||
"session_type": request.session_type.value,
|
||||
"user_id": request.user_id,
|
||||
},
|
||||
)
|
||||
|
||||
return CreateSessionResponse(
|
||||
session_id=session.session_id,
|
||||
conversation_id=session.conversation_id,
|
||||
bot_id=session.bot_id,
|
||||
created_at=session.created_at,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建会话失败: {e}", exc_info=True)
|
||||
raise map_coze_error_to_exception(e)
|
||||
|
||||
async def send_message(
|
||||
self, request: SendMessageRequest
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""
|
||||
发送消息并处理流式响应
|
||||
|
||||
Args:
|
||||
request: 发送消息请求
|
||||
|
||||
Yields:
|
||||
StreamEvent: 流式事件
|
||||
"""
|
||||
session = self._get_session(request.session_id)
|
||||
if not session:
|
||||
raise CozeAPIError(f"会话不存在: {request.session_id}")
|
||||
|
||||
# 记录用户消息
|
||||
user_message = CozeMessage(
|
||||
message_id=str(uuid.uuid4()),
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content=request.content,
|
||||
)
|
||||
self._messages[session.session_id].append(user_message)
|
||||
|
||||
try:
|
||||
# 构建消息历史
|
||||
messages = self._build_message_history(session.session_id)
|
||||
|
||||
# 调用 Coze API
|
||||
stream = await asyncio.to_thread(
|
||||
self.client.chat.stream,
|
||||
bot_id=session.bot_id,
|
||||
conversation_id=session.conversation_id,
|
||||
additional_messages=messages,
|
||||
auto_save_history=True,
|
||||
)
|
||||
|
||||
# 处理流式响应
|
||||
async for event in self._process_stream(stream, session.session_id):
|
||||
yield event
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"消息发送超时: session_id={request.session_id}")
|
||||
raise CozeTimeoutError("消息处理超时")
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}", exc_info=True)
|
||||
raise map_coze_error_to_exception(e)
|
||||
|
||||
async def end_session(
|
||||
self, session_id: str, request: EndSessionRequest
|
||||
) -> EndSessionResponse:
|
||||
"""
|
||||
结束会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
request: 结束会话请求
|
||||
|
||||
Returns:
|
||||
EndSessionResponse: 结束会话响应
|
||||
"""
|
||||
session = self._get_session(session_id)
|
||||
if not session:
|
||||
raise CozeAPIError(f"会话不存在: {session_id}")
|
||||
|
||||
# 更新会话状态
|
||||
session.ended_at = datetime.now()
|
||||
|
||||
# 计算会话统计
|
||||
duration_seconds = int((session.ended_at - session.created_at).total_seconds())
|
||||
message_count = len(self._messages.get(session_id, []))
|
||||
|
||||
# 记录结束原因和反馈
|
||||
if request.reason:
|
||||
session.metadata["end_reason"] = request.reason
|
||||
if request.feedback:
|
||||
session.metadata["feedback"] = request.feedback
|
||||
|
||||
logger.info(
|
||||
f"会话结束",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"duration_seconds": duration_seconds,
|
||||
"message_count": message_count,
|
||||
"reason": request.reason,
|
||||
},
|
||||
)
|
||||
|
||||
return EndSessionResponse(
|
||||
session_id=session_id,
|
||||
ended_at=session.ended_at,
|
||||
duration_seconds=duration_seconds,
|
||||
message_count=message_count,
|
||||
)
|
||||
|
||||
async def get_session_messages(
|
||||
self, session_id: str, limit: int = 50, offset: int = 0
|
||||
) -> List[CozeMessage]:
|
||||
"""获取会话消息历史"""
|
||||
messages = self._messages.get(session_id, [])
|
||||
return messages[offset : offset + limit]
|
||||
|
||||
def _get_bot_id_by_type(self, session_type: SessionType) -> str:
|
||||
"""根据会话类型获取 Bot ID"""
|
||||
mapping = {
|
||||
SessionType.COURSE_CHAT: self.bot_config["course_chat"],
|
||||
SessionType.TRAINING: self.bot_config["training"],
|
||||
SessionType.EXAM: self.bot_config["exam"],
|
||||
}
|
||||
return mapping.get(session_type, self.bot_config["training"])
|
||||
|
||||
def _get_session(self, session_id: str) -> Optional[CozeSession]:
|
||||
"""获取会话"""
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def _build_message_history(self, session_id: str) -> List[Message]:
|
||||
"""构建消息历史"""
|
||||
messages = self._messages.get(session_id, [])
|
||||
history = []
|
||||
|
||||
for msg in messages[-10:]: # 只发送最近10条消息作为上下文
|
||||
history.append(
|
||||
Message(
|
||||
role=msg.role.value,
|
||||
content=msg.content,
|
||||
content_type=MessageContentType.TEXT,
|
||||
)
|
||||
)
|
||||
|
||||
return history
|
||||
|
||||
async def _process_stream(
|
||||
self, stream, session_id: str
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""处理流式响应"""
|
||||
assistant_message_id = str(uuid.uuid4())
|
||||
accumulated_content = []
|
||||
content_type = ContentType.TEXT
|
||||
|
||||
try:
|
||||
for event in stream:
|
||||
if event.event == ChatEventType.CONVERSATION_MESSAGE_DELTA:
|
||||
# 消息片段
|
||||
content = event.message.content
|
||||
accumulated_content.append(content)
|
||||
|
||||
# 检测卡片类型
|
||||
if (
|
||||
hasattr(event.message, "content_type")
|
||||
and event.message.content_type == "card"
|
||||
):
|
||||
content_type = ContentType.CARD
|
||||
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.MESSAGE_DELTA,
|
||||
data={
|
||||
"conversation_id": event.conversation_id,
|
||||
"message_id": assistant_message_id,
|
||||
"content": content,
|
||||
"content_type": content_type.value,
|
||||
},
|
||||
message_id=assistant_message_id,
|
||||
content=content,
|
||||
content_type=content_type,
|
||||
role=MessageRole.ASSISTANT,
|
||||
)
|
||||
|
||||
elif event.event == ChatEventType.CONVERSATION_MESSAGE_COMPLETED:
|
||||
# 消息完成
|
||||
full_content = "".join(accumulated_content)
|
||||
|
||||
# 保存助手消息
|
||||
assistant_message = CozeMessage(
|
||||
message_id=assistant_message_id,
|
||||
session_id=session_id,
|
||||
role=MessageRole.ASSISTANT,
|
||||
content=full_content,
|
||||
content_type=content_type,
|
||||
)
|
||||
self._messages[session_id].append(assistant_message)
|
||||
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.MESSAGE_COMPLETED,
|
||||
data={
|
||||
"conversation_id": event.conversation_id,
|
||||
"message_id": assistant_message_id,
|
||||
"content": full_content,
|
||||
"content_type": content_type.value,
|
||||
"usage": getattr(event, "usage", {}),
|
||||
},
|
||||
message_id=assistant_message_id,
|
||||
content=full_content,
|
||||
content_type=content_type,
|
||||
role=MessageRole.ASSISTANT,
|
||||
)
|
||||
|
||||
elif event.event == ChatEventType.ERROR:
|
||||
# 错误事件
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.ERROR,
|
||||
data={"error": str(event)},
|
||||
error=str(event),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式处理错误: {e}", exc_info=True)
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.ERROR, data={"error": str(e)}, error=str(e)
|
||||
)
|
||||
finally:
|
||||
# 发送结束事件
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.DONE, data={"session_id": session_id}
|
||||
)
|
||||
|
||||
|
||||
# 全局服务实例
|
||||
_service: Optional[CozeService] = None
|
||||
|
||||
|
||||
def get_coze_service() -> CozeService:
|
||||
"""获取 Coze 服务单例"""
|
||||
global _service
|
||||
if _service is None:
|
||||
_service = CozeService()
|
||||
return _service
|
||||
Reference in New Issue
Block a user