Files
012-kaopeilian/backend/app/services/ai/coze/client.py
111 998211c483 feat: 初始化考培练系统项目
- 从服务器拉取完整代码
- 按框架规范整理项目结构
- 配置 Drone CI 测试环境部署
- 包含后端(FastAPI)、前端(Vue3)、管理端

技术栈: Vue3 + TypeScript + FastAPI + MySQL
2026-01-24 19:33:28 +08:00

204 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Coze AI 客户端管理
负责管理 Coze API 的认证和客户端实例
"""
from functools import lru_cache
from typing import Optional, Dict, Any
import logging
from pathlib import Path
from cozepy import Coze, TokenAuth, JWTAuth, COZE_CN_BASE_URL
from app.core.config import get_settings
logger = logging.getLogger(__name__)
class CozeAuthManager:
"""Coze 认证管理器"""
def __init__(self):
self.settings = get_settings()
self._client: Optional[Coze] = None
def _create_pat_auth(self) -> TokenAuth:
"""创建个人访问令牌认证"""
if not self.settings.COZE_API_TOKEN:
raise ValueError("COZE_API_TOKEN 未配置")
return TokenAuth(token=self.settings.COZE_API_TOKEN)
def _create_oauth_auth(self) -> JWTAuth:
"""创建 OAuth 认证"""
if not all(
[
self.settings.COZE_OAUTH_CLIENT_ID,
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
]
):
raise ValueError("OAuth 配置不完整")
# 读取私钥
private_key_path = Path(self.settings.COZE_OAUTH_PRIVATE_KEY_PATH)
if not private_key_path.exists():
raise FileNotFoundError(f"私钥文件不存在: {private_key_path}")
with open(private_key_path, "r") as f:
private_key = f.read()
try:
return JWTAuth(
client_id=self.settings.COZE_OAUTH_CLIENT_ID,
private_key=private_key,
public_key_id=self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL, # 使用中国区API
)
except Exception as e:
logger.error(f"创建 OAuth 认证失败: {e}")
raise
def get_client(self, force_new: bool = False) -> Coze:
"""
获取 Coze 客户端实例
Args:
force_new: 是否强制创建新客户端用于长时间运行的请求避免token过期
认证优先级:
1. OAuth推荐配置完整时使用自动刷新token
2. PAT仅当OAuth未配置时使用注意PAT会过期
"""
if self._client is not None and not force_new:
return self._client
auth = None
auth_type = None
# 检查 OAuth 配置是否完整
oauth_configured = all([
self.settings.COZE_OAUTH_CLIENT_ID,
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
])
if oauth_configured:
# OAuth 配置完整,必须使用 OAuth不fallback到PAT
try:
auth = self._create_oauth_auth()
auth_type = "OAuth"
logger.info("使用 OAuth 认证")
except Exception as e:
# OAuth 配置完整但创建失败直接抛出异常不fallback到可能过期的PAT
logger.error(f"OAuth 认证创建失败: {e}")
raise ValueError(f"OAuth 认证失败,请检查私钥文件和配置: {e}")
else:
# OAuth 未配置,使用 PAT
if self.settings.COZE_API_TOKEN:
auth = self._create_pat_auth()
auth_type = "PAT"
logger.warning("使用 PAT 认证注意PAT会过期建议配置OAuth")
else:
raise ValueError("Coze 认证未配置:需要配置 OAuth 或 PAT Token")
# 创建客户端
client = Coze(
auth=auth, base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL
)
logger.debug(f"Coze客户端创建成功认证方式: {auth_type}, force_new: {force_new}")
# 只有非强制创建时才缓存
if not force_new:
self._client = client
return client
def reset(self):
"""重置客户端实例"""
self._client = None
def get_oauth_token(self) -> str:
"""
获取OAuth JWT Token用于前端直连
Returns:
JWT token字符串
"""
if not all([
self.settings.COZE_OAUTH_CLIENT_ID,
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
]):
raise ValueError("OAuth 配置不完整")
# 读取私钥
private_key_path = Path(self.settings.COZE_OAUTH_PRIVATE_KEY_PATH)
if not private_key_path.exists():
raise FileNotFoundError(f"私钥文件不存在: {private_key_path}")
with open(private_key_path, "r") as f:
private_key = f.read()
# 创建JWTAuth实例必须指定中国区base_url
jwt_auth = JWTAuth(
client_id=self.settings.COZE_OAUTH_CLIENT_ID,
private_key=private_key,
public_key_id=self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL, # 使用中国区API
)
# 获取tokenJWTAuth内部会自动生成
# JWTAuth.token属性返回已签名的JWT
return jwt_auth.token
@lru_cache()
def get_auth_manager() -> CozeAuthManager:
"""获取认证管理器单例"""
return CozeAuthManager()
def get_coze_client(force_new: bool = False) -> Coze:
"""
获取 Coze 客户端
Args:
force_new: 是否强制创建新客户端(用于工作流等长时间运行的请求)
"""
return get_auth_manager().get_client(force_new=force_new)
def get_workspace_id() -> str:
"""获取工作空间 ID"""
settings = get_settings()
if not settings.COZE_WORKSPACE_ID:
raise ValueError("COZE_WORKSPACE_ID 未配置")
return settings.COZE_WORKSPACE_ID
def get_bot_config(session_type: str) -> Dict[str, Any]:
"""
根据会话类型获取 Bot 配置
Args:
session_type: 会话类型 (course_chat 或 training)
Returns:
包含 bot_id 等配置的字典
"""
settings = get_settings()
if session_type == "course_chat":
bot_id = settings.COZE_CHAT_BOT_ID
if not bot_id:
raise ValueError("COZE_CHAT_BOT_ID 未配置")
elif session_type == "training":
bot_id = settings.COZE_TRAINING_BOT_ID
if not bot_id:
raise ValueError("COZE_TRAINING_BOT_ID 未配置")
else:
raise ValueError(f"不支持的会话类型: {session_type}")
return {"bot_id": bot_id, "workspace_id": settings.COZE_WORKSPACE_ID}