- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
204 lines
6.6 KiB
Python
204 lines
6.6 KiB
Python
"""
|
||
Coze AI 客户端管理
|
||
负责管理 Coze API 的认证和客户端实例
|
||
"""
|
||
from functools import lru_cache
|
||
from typing import Optional, Dict, Any
|
||
import logging
|
||
from pathlib import Path
|
||
|
||
from cozepy import Coze, TokenAuth, JWTAuth, COZE_CN_BASE_URL
|
||
|
||
from app.core.config import get_settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class CozeAuthManager:
|
||
"""Coze 认证管理器"""
|
||
|
||
def __init__(self):
|
||
self.settings = get_settings()
|
||
self._client: Optional[Coze] = None
|
||
|
||
def _create_pat_auth(self) -> TokenAuth:
|
||
"""创建个人访问令牌认证"""
|
||
if not self.settings.COZE_API_TOKEN:
|
||
raise ValueError("COZE_API_TOKEN 未配置")
|
||
|
||
return TokenAuth(token=self.settings.COZE_API_TOKEN)
|
||
|
||
def _create_oauth_auth(self) -> JWTAuth:
|
||
"""创建 OAuth 认证"""
|
||
if not all(
|
||
[
|
||
self.settings.COZE_OAUTH_CLIENT_ID,
|
||
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
|
||
]
|
||
):
|
||
raise ValueError("OAuth 配置不完整")
|
||
|
||
# 读取私钥
|
||
private_key_path = Path(self.settings.COZE_OAUTH_PRIVATE_KEY_PATH)
|
||
if not private_key_path.exists():
|
||
raise FileNotFoundError(f"私钥文件不存在: {private_key_path}")
|
||
|
||
with open(private_key_path, "r") as f:
|
||
private_key = f.read()
|
||
|
||
try:
|
||
return JWTAuth(
|
||
client_id=self.settings.COZE_OAUTH_CLIENT_ID,
|
||
private_key=private_key,
|
||
public_key_id=self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||
base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL, # 使用中国区API
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"创建 OAuth 认证失败: {e}")
|
||
raise
|
||
|
||
def get_client(self, force_new: bool = False) -> Coze:
|
||
"""
|
||
获取 Coze 客户端实例
|
||
|
||
Args:
|
||
force_new: 是否强制创建新客户端(用于长时间运行的请求,避免token过期)
|
||
|
||
认证优先级:
|
||
1. OAuth(推荐):配置完整时使用,自动刷新token
|
||
2. PAT:仅当OAuth未配置时使用(注意:PAT会过期)
|
||
"""
|
||
if self._client is not None and not force_new:
|
||
return self._client
|
||
|
||
auth = None
|
||
auth_type = None
|
||
|
||
# 检查 OAuth 配置是否完整
|
||
oauth_configured = all([
|
||
self.settings.COZE_OAUTH_CLIENT_ID,
|
||
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
|
||
])
|
||
|
||
if oauth_configured:
|
||
# OAuth 配置完整,必须使用 OAuth(不fallback到PAT)
|
||
try:
|
||
auth = self._create_oauth_auth()
|
||
auth_type = "OAuth"
|
||
logger.info("使用 OAuth 认证")
|
||
except Exception as e:
|
||
# OAuth 配置完整但创建失败,直接抛出异常(不fallback到可能过期的PAT)
|
||
logger.error(f"OAuth 认证创建失败: {e}")
|
||
raise ValueError(f"OAuth 认证失败,请检查私钥文件和配置: {e}")
|
||
else:
|
||
# OAuth 未配置,使用 PAT
|
||
if self.settings.COZE_API_TOKEN:
|
||
auth = self._create_pat_auth()
|
||
auth_type = "PAT"
|
||
logger.warning("使用 PAT 认证(注意:PAT会过期,建议配置OAuth)")
|
||
else:
|
||
raise ValueError("Coze 认证未配置:需要配置 OAuth 或 PAT Token")
|
||
|
||
# 创建客户端
|
||
client = Coze(
|
||
auth=auth, base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL
|
||
)
|
||
|
||
logger.debug(f"Coze客户端创建成功,认证方式: {auth_type}, force_new: {force_new}")
|
||
|
||
# 只有非强制创建时才缓存
|
||
if not force_new:
|
||
self._client = client
|
||
|
||
return client
|
||
|
||
def reset(self):
|
||
"""重置客户端实例"""
|
||
self._client = None
|
||
|
||
def get_oauth_token(self) -> str:
|
||
"""
|
||
获取OAuth JWT Token用于前端直连
|
||
|
||
Returns:
|
||
JWT token字符串
|
||
"""
|
||
if not all([
|
||
self.settings.COZE_OAUTH_CLIENT_ID,
|
||
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
|
||
]):
|
||
raise ValueError("OAuth 配置不完整")
|
||
|
||
# 读取私钥
|
||
private_key_path = Path(self.settings.COZE_OAUTH_PRIVATE_KEY_PATH)
|
||
if not private_key_path.exists():
|
||
raise FileNotFoundError(f"私钥文件不存在: {private_key_path}")
|
||
|
||
with open(private_key_path, "r") as f:
|
||
private_key = f.read()
|
||
|
||
# 创建JWTAuth实例(必须指定中国区base_url)
|
||
jwt_auth = JWTAuth(
|
||
client_id=self.settings.COZE_OAUTH_CLIENT_ID,
|
||
private_key=private_key,
|
||
public_key_id=self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||
base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL, # 使用中国区API
|
||
)
|
||
|
||
# 获取token(JWTAuth内部会自动生成)
|
||
# JWTAuth.token属性返回已签名的JWT
|
||
return jwt_auth.token
|
||
|
||
|
||
@lru_cache()
|
||
def get_auth_manager() -> CozeAuthManager:
|
||
"""获取认证管理器单例"""
|
||
return CozeAuthManager()
|
||
|
||
|
||
def get_coze_client(force_new: bool = False) -> Coze:
|
||
"""
|
||
获取 Coze 客户端
|
||
|
||
Args:
|
||
force_new: 是否强制创建新客户端(用于工作流等长时间运行的请求)
|
||
"""
|
||
return get_auth_manager().get_client(force_new=force_new)
|
||
|
||
|
||
def get_workspace_id() -> str:
|
||
"""获取工作空间 ID"""
|
||
settings = get_settings()
|
||
if not settings.COZE_WORKSPACE_ID:
|
||
raise ValueError("COZE_WORKSPACE_ID 未配置")
|
||
return settings.COZE_WORKSPACE_ID
|
||
|
||
|
||
def get_bot_config(session_type: str) -> Dict[str, Any]:
|
||
"""
|
||
根据会话类型获取 Bot 配置
|
||
|
||
Args:
|
||
session_type: 会话类型 (course_chat 或 training)
|
||
|
||
Returns:
|
||
包含 bot_id 等配置的字典
|
||
"""
|
||
settings = get_settings()
|
||
|
||
if session_type == "course_chat":
|
||
bot_id = settings.COZE_CHAT_BOT_ID
|
||
if not bot_id:
|
||
raise ValueError("COZE_CHAT_BOT_ID 未配置")
|
||
elif session_type == "training":
|
||
bot_id = settings.COZE_TRAINING_BOT_ID
|
||
if not bot_id:
|
||
raise ValueError("COZE_TRAINING_BOT_ID 未配置")
|
||
else:
|
||
raise ValueError(f"不支持的会话类型: {session_type}")
|
||
|
||
return {"bot_id": bot_id, "workspace_id": settings.COZE_WORKSPACE_ID}
|