""" Coze 网关 API 路由 提供课程对话和陪练功能的统一接口 """ import logging from typing import Dict, Any from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import StreamingResponse from sse_starlette.sse import EventSourceResponse from app.services.ai.coze import ( get_coze_service, CreateSessionRequest, SendMessageRequest, EndSessionRequest, SessionType, CozeException, StreamEventType, ) logger = logging.getLogger(__name__) router = APIRouter(tags=["coze-gateway"]) # TODO: 依赖注入获取当前用户 async def get_current_user(): """获取当前登录用户(临时实现)""" # 实际应该从 Auth 模块获取 return {"user_id": "test-user-123", "username": "test_user"} @router.post("/course-chat/sessions") async def create_course_chat_session(course_id: str, user=Depends(get_current_user)): """ 创建课程对话会话 - **course_id**: 课程ID """ try: service = get_coze_service() request = CreateSessionRequest( session_type=SessionType.COURSE_CHAT, user_id=user["user_id"], course_id=course_id, metadata={"username": user["username"], "course_id": course_id}, ) response = await service.create_session(request) return {"code": 200, "message": "success", "data": response.dict()} except CozeException as e: logger.error(f"创建课程对话会话失败: {e}") raise HTTPException( status_code=e.status_code or 500, detail={"code": e.code, "message": e.message, "details": e.details}, ) except Exception as e: logger.error(f"未知错误: {e}", exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"}, ) @router.post("/training/sessions") async def create_training_session( training_topic: str = None, user=Depends(get_current_user) ): """ 创建陪练会话 - **training_topic**: 陪练主题(可选) """ try: service = get_coze_service() request = CreateSessionRequest( session_type=SessionType.TRAINING, user_id=user["user_id"], training_topic=training_topic, metadata={"username": user["username"], "training_topic": training_topic}, ) response = await service.create_session(request) return {"code": 200, "message": "success", "data": response.dict()} except CozeException as e: logger.error(f"创建陪练会话失败: {e}") raise HTTPException( status_code=e.status_code or 500, detail={"code": e.code, "message": e.message, "details": e.details}, ) except Exception as e: logger.error(f"未知错误: {e}", exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"}, ) @router.post("/training/sessions/{session_id}/end") async def end_training_session( session_id: str, request: EndSessionRequest, user=Depends(get_current_user) ): """ 结束陪练会话 - **session_id**: 会话ID """ try: service = get_coze_service() response = await service.end_session(session_id, request) return {"code": 200, "message": "success", "data": response.dict()} except CozeException as e: logger.error(f"结束会话失败: {e}") raise HTTPException( status_code=e.status_code or 500, detail={"code": e.code, "message": e.message, "details": e.details}, ) except Exception as e: logger.error(f"未知错误: {e}", exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"}, ) @router.post("/chat/messages") async def send_message(request: SendMessageRequest, user=Depends(get_current_user)): """ 发送消息(支持流式响应) - **session_id**: 会话ID - **content**: 消息内容 - **stream**: 是否流式响应(默认True) """ try: service = get_coze_service() if request.stream: # 流式响应 async def event_generator(): async for event in service.send_message(request): # 转换为 SSE 格式 if event.event == StreamEventType.MESSAGE_DELTA: yield { "event": "message", "data": { "type": "delta", "content": event.content, "content_type": event.content_type.value, "message_id": event.message_id, }, } elif event.event == StreamEventType.MESSAGE_COMPLETED: yield { "event": "message", "data": { "type": "completed", "content": event.content, "content_type": event.content_type.value, "message_id": event.message_id, "usage": event.data.get("usage", {}), }, } elif event.event == StreamEventType.ERROR: yield {"event": "error", "data": {"error": event.error}} elif event.event == StreamEventType.DONE: yield { "event": "done", "data": {"session_id": event.data.get("session_id")}, } return EventSourceResponse(event_generator()) else: # 非流式响应(收集完整响应) full_content = "" content_type = None message_id = None async for event in service.send_message(request): if event.event == StreamEventType.MESSAGE_COMPLETED: full_content = event.content content_type = event.content_type message_id = event.message_id break return { "code": 200, "message": "success", "data": { "message_id": message_id, "content": full_content, "content_type": content_type.value if content_type else "text", "role": "assistant", }, } except CozeException as e: logger.error(f"发送消息失败: {e}") if request.stream: # 流式响应的错误处理 async def error_generator(): yield { "event": "error", "data": { "code": e.code, "message": e.message, "details": e.details, }, } return EventSourceResponse(error_generator()) else: raise HTTPException( status_code=e.status_code or 500, detail={"code": e.code, "message": e.message, "details": e.details}, ) except Exception as e: logger.error(f"未知错误: {e}", exc_info=True) if request.stream: async def error_generator(): yield { "event": "error", "data": {"code": "INTERNAL_ERROR", "message": "服务器内部错误"}, } return EventSourceResponse(error_generator()) else: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"}, ) @router.get("/sessions/{session_id}/messages") async def get_session_messages( session_id: str, limit: int = 50, offset: int = 0, user=Depends(get_current_user) ): """ 获取会话消息历史 - **session_id**: 会话ID - **limit**: 返回消息数量限制 - **offset**: 偏移量 """ try: service = get_coze_service() messages = await service.get_session_messages(session_id, limit, offset) return { "code": 200, "message": "success", "data": { "messages": [msg.dict() for msg in messages], "total": len(messages), "limit": limit, "offset": offset, }, } except Exception as e: logger.error(f"获取消息历史失败: {e}", exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"}, )