- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
276 lines
9.1 KiB
Python
276 lines
9.1 KiB
Python
"""
|
||
Coze 网关 API 路由
|
||
提供课程对话和陪练功能的统一接口
|
||
"""
|
||
|
||
import logging
|
||
from typing import Dict, Any
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from fastapi.responses import StreamingResponse
|
||
from sse_starlette.sse import EventSourceResponse
|
||
|
||
from app.services.ai.coze import (
|
||
get_coze_service,
|
||
CreateSessionRequest,
|
||
SendMessageRequest,
|
||
EndSessionRequest,
|
||
SessionType,
|
||
CozeException,
|
||
StreamEventType,
|
||
)
|
||
|
||
|
||
# Module logger, named after this module.
logger = logging.getLogger(name=__name__)

# Router for the Coze gateway endpoints; the tag groups them in the OpenAPI docs.
_ROUTER_TAG = "coze-gateway"
router = APIRouter(tags=[_ROUTER_TAG])
|
||
|
||
|
||
# TODO: resolve the current user through dependency injection (Auth module).
async def get_current_user():
    """Return the logged-in user (temporary stub).

    Always returns the same hard-coded test identity; the real lookup should
    come from the Auth module.
    """
    # NOTE(review): every request is attributed to this single test account
    # until real authentication is wired in.
    placeholder_user = {"user_id": "test-user-123", "username": "test_user"}
    return placeholder_user
|
||
|
||
|
||
@router.post("/course-chat/sessions")
async def create_course_chat_session(course_id: str, user=Depends(get_current_user)):
    """
    Create a course-chat session for the current user.

    - **course_id**: ID of the course the conversation is about
    """
    try:
        coze = get_coze_service()
        uid = user["user_id"]
        session_request = CreateSessionRequest(
            session_type=SessionType.COURSE_CHAT,
            user_id=uid,
            course_id=course_id,
            metadata={"username": user["username"], "course_id": course_id},
        )
        created = await coze.create_session(session_request)
        payload = created.dict()
        return {"code": 200, "message": "success", "data": payload}
    except CozeException as exc:
        # Service-level failure: forward its status code (default 500) and error fields.
        logger.error(f"创建课程对话会话失败: {exc}")
        error_detail = {"code": exc.code, "message": exc.message, "details": exc.details}
        raise HTTPException(status_code=exc.status_code or 500, detail=error_detail)
    except Exception as exc:
        # Anything else is reported generically; the traceback only goes to the log.
        logger.error(f"未知错误: {exc}", exc_info=True)
        generic_detail = {"code": "INTERNAL_ERROR", "message": "服务器内部错误"}
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=generic_detail
        )
|
||
|
||
|
||
@router.post("/training/sessions")
async def create_training_session(
    training_topic: str = None, user=Depends(get_current_user)
):
    """
    Create a training (practice) session for the current user.

    - **training_topic**: topic of the training session (optional)
    """
    try:
        coze = get_coze_service()
        uid = user["user_id"]
        session_request = CreateSessionRequest(
            session_type=SessionType.TRAINING,
            user_id=uid,
            training_topic=training_topic,
            metadata={"username": user["username"], "training_topic": training_topic},
        )
        created = await coze.create_session(session_request)
        payload = created.dict()
        return {"code": 200, "message": "success", "data": payload}
    except CozeException as exc:
        # Service-level failure: forward its status code (default 500) and error fields.
        logger.error(f"创建陪练会话失败: {exc}")
        error_detail = {"code": exc.code, "message": exc.message, "details": exc.details}
        raise HTTPException(status_code=exc.status_code or 500, detail=error_detail)
    except Exception as exc:
        # Anything else is reported generically; the traceback only goes to the log.
        logger.error(f"未知错误: {exc}", exc_info=True)
        generic_detail = {"code": "INTERNAL_ERROR", "message": "服务器内部错误"}
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=generic_detail
        )
|
||
|
||
|
||
@router.post("/training/sessions/{session_id}/end")
async def end_training_session(
    session_id: str, request: EndSessionRequest, user=Depends(get_current_user)
):
    """
    End a training session.

    - **session_id**: ID of the session to end
    """
    # NOTE(review): `user` is resolved but never compared with the session's
    # owner here; confirm whether end_session enforces ownership itself.
    try:
        ended = await get_coze_service().end_session(session_id, request)
        payload = ended.dict()
        return {"code": 200, "message": "success", "data": payload}
    except CozeException as exc:
        # Service-level failure: forward its status code (default 500) and error fields.
        logger.error(f"结束会话失败: {exc}")
        error_detail = {"code": exc.code, "message": exc.message, "details": exc.details}
        raise HTTPException(status_code=exc.status_code or 500, detail=error_detail)
    except Exception as exc:
        # Anything else is reported generically; the traceback only goes to the log.
        logger.error(f"未知错误: {exc}", exc_info=True)
        generic_detail = {"code": "INTERNAL_ERROR", "message": "服务器内部错误"}
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=generic_detail
        )
|
||
|
||
|
||
def _sse_event(event_name: str, payload: Dict[str, Any]) -> Dict[str, str]:
    """Build an SSE event dict whose ``data`` field is a JSON string.

    sse-starlette renders a non-string ``data`` value with ``str()``, which
    for a dict yields a Python repr (single quotes, ``None``, ``True``) that
    clients cannot parse as JSON. Serializing here keeps the wire format
    valid JSON. ``default=str`` is deliberate: fields coming from the Coze
    service (e.g. ``error``, ``usage``) are not guaranteed to be JSON-native.
    """
    import json

    return {
        "event": event_name,
        "data": json.dumps(payload, ensure_ascii=False, default=str),
    }


def _content_type_value(content_type: Any) -> str:
    """Return the content-type enum's ``value``, or ``"text"`` when it is missing."""
    return content_type.value if content_type else "text"


def _stream_event_to_sse(event: Any):
    """Convert one Coze stream event into an SSE event dict.

    Returns None for event types that are not forwarded to the client.
    """
    if event.event == StreamEventType.MESSAGE_DELTA:
        return _sse_event(
            "message",
            {
                "type": "delta",
                "content": event.content,
                "content_type": _content_type_value(event.content_type),
                "message_id": event.message_id,
            },
        )
    if event.event == StreamEventType.MESSAGE_COMPLETED:
        return _sse_event(
            "message",
            {
                "type": "completed",
                "content": event.content,
                "content_type": _content_type_value(event.content_type),
                "message_id": event.message_id,
                "usage": (event.data or {}).get("usage", {}),
            },
        )
    if event.event == StreamEventType.ERROR:
        return _sse_event("error", {"error": event.error})
    if event.event == StreamEventType.DONE:
        return _sse_event("done", {"session_id": (event.data or {}).get("session_id")})
    return None


def _single_error_stream(payload: Dict[str, Any]):
    """Return an EventSourceResponse that emits exactly one ``error`` event.

    The generator closes over ``payload`` (a plain argument), never over an
    ``except ... as`` name: Python clears that name when the handler exits,
    which happens before the response body is iterated.
    """

    async def error_generator():
        yield _sse_event("error", payload)

    return EventSourceResponse(error_generator())


@router.post("/chat/messages")
async def send_message(request: SendMessageRequest, user=Depends(get_current_user)):
    """
    Send a message to a session (streaming supported).

    - **session_id**: session ID
    - **content**: message content
    - **stream**: whether to stream the reply as SSE (default True)
    """
    try:
        service = get_coze_service()

        if request.stream:

            async def event_generator():
                # This body runs after the endpoint has returned the response,
                # so the endpoint's own except clauses cannot see errors raised
                # while iterating; report them to the client as SSE events.
                try:
                    async for event in service.send_message(request):
                        sse_event = _stream_event_to_sse(event)
                        if sse_event is not None:
                            yield sse_event
                except CozeException as exc:
                    logger.error(f"发送消息失败: {exc}")
                    yield _sse_event(
                        "error",
                        {"code": exc.code, "message": exc.message, "details": exc.details},
                    )
                except Exception as exc:
                    logger.error(f"未知错误: {exc}", exc_info=True)
                    yield _sse_event(
                        "error", {"code": "INTERNAL_ERROR", "message": "服务器内部错误"}
                    )

            return EventSourceResponse(event_generator())

        # Non-streaming: consume the stream until the completed message arrives.
        full_content = ""
        content_type = None
        message_id = None

        async for event in service.send_message(request):
            if event.event == StreamEventType.ERROR:
                # Previously ignored, which returned "success" with empty content.
                raise HTTPException(
                    status_code=status.HTTP_502_BAD_GATEWAY,
                    detail={"code": "COZE_STREAM_ERROR", "message": str(event.error)},
                )
            if event.event == StreamEventType.MESSAGE_COMPLETED:
                full_content = event.content
                content_type = event.content_type
                message_id = event.message_id
                break

        return {
            "code": 200,
            "message": "success",
            "data": {
                "message_id": message_id,
                "content": full_content,
                "content_type": _content_type_value(content_type),
                "role": "assistant",
            },
        }

    except HTTPException:
        # Deliberate HTTP errors raised above must not be rewrapped as a generic 500.
        raise
    except CozeException as e:
        logger.error(f"发送消息失败: {e}")
        # Build the payload now: `e` is unbound once this handler exits.
        error_detail = {"code": e.code, "message": e.message, "details": e.details}
        if request.stream:
            return _single_error_stream(error_detail)
        raise HTTPException(status_code=e.status_code or 500, detail=error_detail)
    except Exception as e:
        logger.error(f"未知错误: {e}", exc_info=True)
        error_detail = {"code": "INTERNAL_ERROR", "message": "服务器内部错误"}
        if request.stream:
            return _single_error_stream(error_detail)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=error_detail
        )
|
||
|
||
|
||
@router.get("/sessions/{session_id}/messages")
async def get_session_messages(
    session_id: str, limit: int = 50, offset: int = 0, user=Depends(get_current_user)
):
    """
    Get the message history of a session.

    - **session_id**: session ID
    - **limit**: maximum number of messages to return
    - **offset**: number of messages to skip
    """
    try:
        service = get_coze_service()
        messages = await service.get_session_messages(session_id, limit, offset)

        return {
            "code": 200,
            "message": "success",
            "data": {
                "messages": [msg.dict() for msg in messages],
                # NOTE(review): this is the length of the returned page, not
                # the total number of messages stored for the session.
                "total": len(messages),
                "limit": limit,
                "offset": offset,
            },
        }

    except CozeException as e:
        # Forward the service's own status code and error fields, as every
        # other endpoint in this router does, instead of a blanket 500.
        logger.error(f"获取消息历史失败: {e}")
        raise HTTPException(
            status_code=e.status_code or 500,
            detail={"code": e.code, "message": e.message, "details": e.details},
        )
    except Exception as e:
        logger.error(f"获取消息历史失败: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"},
        )
|