Files
012-kaopeilian/backend/app/api/v1/coze_gateway.py
yuliang_guo 41a2f7944a
All checks were successful
continuous-integration/drone/push Build is passing
fix: 修复flake8 lint检查错误
- 删除废弃的 admin_positions_backup.py 备份文件
- 修复 courses.py 缺失的 select 导入
- 修复 coze_gateway.py 异常变量作用域问题
- 修复 scheduler_service.py 无用的 global 声明
- 添加 TYPE_CHECKING 导入解决模型前向引用警告
2026-01-31 17:43:39 +08:00

280 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Coze 网关 API 路由
提供课程对话和陪练功能的统一接口
"""
import json
import logging
from typing import Dict, Any

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sse_starlette.sse import EventSourceResponse

from app.services.ai.coze import (
    get_coze_service,
    CreateSessionRequest,
    SendMessageRequest,
    EndSessionRequest,
    SessionType,
    CozeException,
    StreamEventType,
)
# Module-level logger, plus the router that groups every Coze gateway endpoint
# under the "coze-gateway" OpenAPI tag.
logger = logging.getLogger(name=__name__)
router = APIRouter(tags=["coze-gateway"])
# TODO: resolve the current user through the real auth dependency.
async def get_current_user():
    """Return the current user (temporary stub).

    NOTE(review): every request is attributed to the same hard-coded test
    user until this is wired to the Auth module, so endpoints depending on
    it are effectively unauthenticated.
    """
    stub_user = dict(user_id="test-user-123", username="test_user")
    return stub_user
@router.post("/course-chat/sessions")
async def create_course_chat_session(course_id: str, user=Depends(get_current_user)):
    """
    Create a course-chat session for the current user.

    - **course_id**: ID of the course the conversation is about
    """
    try:
        coze = get_coze_service()
        uid = user["user_id"]
        session_meta = {"username": user["username"], "course_id": course_id}
        session_request = CreateSessionRequest(
            session_type=SessionType.COURSE_CHAT,
            user_id=uid,
            course_id=course_id,
            metadata=session_meta,
        )
        created = await coze.create_session(session_request)
        return dict(code=200, message="success", data=created.dict())
    except CozeException as exc:
        logger.error(f"创建课程对话会话失败: {exc}")
        # Propagate the service's own status/code so clients see the real cause.
        failure = {"code": exc.code, "message": exc.message, "details": exc.details}
        raise HTTPException(
            status_code=exc.status_code if exc.status_code else 500,
            detail=failure,
        )
    except Exception as exc:
        logger.error(f"未知错误: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=dict(code="INTERNAL_ERROR", message="服务器内部错误"),
        )
@router.post("/training/sessions")
async def create_training_session(
    training_topic: str = None, user=Depends(get_current_user)
):
    """
    Create a practice (training) session for the current user.

    - **training_topic**: optional practice topic
    """
    try:
        coze = get_coze_service()
        uid = user["user_id"]
        session_meta = {"username": user["username"], "training_topic": training_topic}
        session_request = CreateSessionRequest(
            session_type=SessionType.TRAINING,
            user_id=uid,
            training_topic=training_topic,
            metadata=session_meta,
        )
        created = await coze.create_session(session_request)
        return dict(code=200, message="success", data=created.dict())
    except CozeException as exc:
        logger.error(f"创建陪练会话失败: {exc}")
        # Propagate the service's own status/code so clients see the real cause.
        failure = {"code": exc.code, "message": exc.message, "details": exc.details}
        raise HTTPException(
            status_code=exc.status_code if exc.status_code else 500,
            detail=failure,
        )
    except Exception as exc:
        logger.error(f"未知错误: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=dict(code="INTERNAL_ERROR", message="服务器内部错误"),
        )
@router.post("/training/sessions/{session_id}/end")
async def end_training_session(
    session_id: str, request: EndSessionRequest, user=Depends(get_current_user)
):
    """
    End a practice (training) session.

    - **session_id**: ID of the session to end

    NOTE(review): ``user`` is resolved but never used here, so this endpoint
    does not check that the caller owns ``session_id`` — confirm whether the
    service layer enforces that.
    """
    try:
        ended = await get_coze_service().end_session(session_id, request)
        payload = ended.dict()
        return dict(code=200, message="success", data=payload)
    except CozeException as exc:
        logger.error(f"结束会话失败: {exc}")
        # Propagate the service's own status/code so clients see the real cause.
        failure = {"code": exc.code, "message": exc.message, "details": exc.details}
        raise HTTPException(
            status_code=exc.status_code if exc.status_code else 500,
            detail=failure,
        )
    except Exception as exc:
        logger.error(f"未知错误: {exc}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=dict(code="INTERNAL_ERROR", message="服务器内部错误"),
        )
def _sse_event(event_name: str, payload: Dict[str, Any]) -> Dict[str, str]:
    """Build one sse-starlette event whose ``data`` field is a JSON string.

    sse-starlette writes ``str(data)`` onto the wire, so yielding a raw dict
    would send a Python repr (single quotes, ``None``) that clients cannot
    ``JSON.parse``. ``default=str`` is deliberate: an unexpected value type in
    ``details``/``usage``/``error`` must not abort the stream mid-response.
    """
    return {
        "event": event_name,
        "data": json.dumps(payload, ensure_ascii=False, default=str),
    }


def _content_type_value(content_type: Any) -> str:
    """Return the enum value of a content type, falling back to ``"text"``."""
    return content_type.value if content_type else "text"


def _coze_error_payload(err: CozeException) -> Dict[str, Any]:
    """Extract the client-facing fields of a ``CozeException``."""
    return {"code": err.code, "message": err.message, "details": err.details}


def _internal_error_payload() -> Dict[str, Any]:
    """Return the generic payload used for unexpected failures."""
    return {"code": "INTERNAL_ERROR", "message": "服务器内部错误"}


def _single_error_stream(payload: Dict[str, Any]) -> EventSourceResponse:
    """Return an SSE response that emits exactly one ``error`` event."""

    async def error_generator():
        yield _sse_event("error", payload)

    return EventSourceResponse(error_generator())


async def _stream_events(service: Any, request: SendMessageRequest):
    """Translate Coze stream events into SSE events.

    This generator runs after the HTTP response has started, so exceptions
    raised while iterating ``service.send_message`` can no longer reach the
    endpoint's ``except`` clauses. They are reported as an in-band ``error``
    event instead of silently dropping the connection.
    """
    try:
        async for event in service.send_message(request):
            if event.event == StreamEventType.MESSAGE_DELTA:
                yield _sse_event(
                    "message",
                    {
                        "type": "delta",
                        "content": event.content,
                        "content_type": _content_type_value(event.content_type),
                        "message_id": event.message_id,
                    },
                )
            elif event.event == StreamEventType.MESSAGE_COMPLETED:
                yield _sse_event(
                    "message",
                    {
                        "type": "completed",
                        "content": event.content,
                        "content_type": _content_type_value(event.content_type),
                        "message_id": event.message_id,
                        # `or {}` guards against an event that carries no data dict.
                        "usage": (event.data or {}).get("usage", {}),
                    },
                )
            elif event.event == StreamEventType.ERROR:
                yield _sse_event("error", {"error": event.error})
            elif event.event == StreamEventType.DONE:
                yield _sse_event(
                    "done", {"session_id": (event.data or {}).get("session_id")}
                )
    except CozeException as coze_err:
        logger.error(f"发送消息失败: {coze_err}")
        yield _sse_event("error", _coze_error_payload(coze_err))
    except Exception as e:
        logger.error(f"未知错误: {e}", exc_info=True)
        yield _sse_event("error", _internal_error_payload())


async def _collect_reply(service: Any, request: SendMessageRequest) -> Dict[str, Any]:
    """Drain the Coze stream up to the completed message and wrap it in the envelope."""
    full_content = ""
    content_type = None
    message_id = None
    async for event in service.send_message(request):
        # NOTE(review): ERROR events are ignored here, so a failed generation is
        # reported as success with empty content — confirm this is intended.
        if event.event == StreamEventType.MESSAGE_COMPLETED:
            full_content = event.content
            content_type = event.content_type
            message_id = event.message_id
            break
    return {
        "code": 200,
        "message": "success",
        "data": {
            "message_id": message_id,
            "content": full_content,
            "content_type": _content_type_value(content_type),
            "role": "assistant",
        },
    }


@router.post("/chat/messages")
async def send_message(request: SendMessageRequest, user=Depends(get_current_user)):
    """
    Send a message to a session, optionally streaming the reply.

    - **session_id**: session ID
    - **content**: message text
    - **stream**: stream the reply as SSE (default True)

    Failures before streaming starts become a single SSE ``error`` event in
    stream mode, or an ``HTTPException`` in non-stream mode. Failures during
    streaming are reported by ``_stream_events`` as an ``error`` event.
    """
    try:
        service = get_coze_service()
        if request.stream:
            return EventSourceResponse(_stream_events(service, request))
        return await _collect_reply(service, request)
    except CozeException as coze_err:
        logger.error(f"发送消息失败: {coze_err}")
        # Build the payload eagerly for BOTH branches: `coze_err` is unbound once
        # this clause exits, so a lazily-run generator must not reference it.
        payload = _coze_error_payload(coze_err)
        if request.stream:
            return _single_error_stream(payload)
        raise HTTPException(
            status_code=coze_err.status_code or 500,
            detail=payload,
        )
    except Exception as e:
        logger.error(f"未知错误: {e}", exc_info=True)
        if request.stream:
            return _single_error_stream(_internal_error_payload())
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=_internal_error_payload(),
        )
@router.get("/sessions/{session_id}/messages")
async def get_session_messages(
    session_id: str, limit: int = 50, offset: int = 0, user=Depends(get_current_user)
):
    """
    Return one page of a session's message history.

    - **session_id**: session ID
    - **limit**: maximum number of messages to return
    - **offset**: number of messages to skip

    ``total`` in the response is ``len(messages)`` for this call, i.e. the size
    of the returned page, not necessarily the size of the whole history.
    """
    try:
        service = get_coze_service()
        messages = await service.get_session_messages(session_id, limit, offset)
        return {
            "code": 200,
            "message": "success",
            "data": {
                "messages": [msg.dict() for msg in messages],
                "total": len(messages),
                "limit": limit,
                "offset": offset,
            },
        }
    except CozeException as e:
        # Mirror the sibling endpoints: keep the service's own status/code
        # instead of flattening every Coze error into a 500.
        logger.error(f"获取消息历史失败: {e}")
        raise HTTPException(
            status_code=e.status_code or 500,
            detail={"code": e.code, "message": e.message, "details": e.details},
        )
    except Exception as e:
        logger.error(f"获取消息历史失败: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"},
        )