All checks were successful
continuous-integration/drone/push Build is passing
Co-authored-by: Cursor <cursoragent@cursor.com>
1249 lines
44 KiB
Python
1249 lines
44 KiB
Python
"""
|
||
陪练功能API
|
||
"""
|
||
from typing import Optional
|
||
import json
|
||
from datetime import datetime, timedelta
|
||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||
from fastapi.responses import StreamingResponse
|
||
from sqlalchemy import select, func, or_
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from cozepy import ChatEventType
|
||
from cozepy.exception import CozeError, CozeAPIError
|
||
|
||
from app.core.deps import get_db, get_current_user
|
||
from app.models.user import User
|
||
from app.models.practice import PracticeScene, PracticeSession, PracticeDialogue, PracticeReport
|
||
from app.schemas.practice import (
|
||
PracticeSceneResponse,
|
||
PracticeSceneCreate,
|
||
PracticeSceneUpdate,
|
||
StartPracticeRequest,
|
||
InterruptPracticeRequest,
|
||
ConversationsResponse,
|
||
ExtractSceneRequest,
|
||
ExtractSceneResponse,
|
||
ExtractedSceneData,
|
||
PracticeSessionCreate,
|
||
PracticeSessionResponse,
|
||
SaveDialogueRequest,
|
||
PracticeDialogueResponse,
|
||
PracticeReportResponse,
|
||
PracticeAnalysisResult
|
||
)
|
||
from app.schemas.base import ResponseModel, PaginatedResponse
|
||
from app.services.coze_service import get_coze_service, CozeService
|
||
from app.services.ai.coze.client import get_auth_manager
|
||
from app.services.ai.practice_analysis_service import practice_analysis_service
|
||
import logging
|
||
|
||
# Module-level logger shared by every endpoint in this file.
logger = logging.getLogger(__name__)

# Router that collects the practice endpoints of this module; presumably it is
# mounted under a URL prefix where it is included — TODO confirm.
router = APIRouter()
|
||
|
||
|
||
@router.get("/coze-token")
async def get_coze_token(
    current_user: User = Depends(get_current_user)
):
    """
    Hand out a Coze OAuth token so the browser can open its own WebSocket.

    Voice practice talks to Coze directly from the frontend, but the private
    key must never leave the server; the frontend therefore asks this endpoint
    for a token instead. ``current_user`` is only required for authentication.

    Raises:
        HTTPException: 500 if obtaining the token fails for any reason.
    """
    try:
        oauth_token = get_auth_manager().get_oauth_token()
        return ResponseModel(code=200, message="Token获取成功", data={"token": oauth_token})
    except Exception as e:
        logger.error(f"获取Coze Token失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取Token失败: {str(e)}")
|
||
|
||
|
||
@router.get("/scenes", response_model=ResponseModel[PaginatedResponse[PracticeSceneResponse]])
async def get_practice_scenes(
    page: int = Query(1, ge=1, description="页码"),
    size: int = Query(20, ge=1, le=100, description="每页数量"),
    type: Optional[str] = Query(None, description="场景类型筛选"),
    difficulty: Optional[str] = Query(None, description="难度筛选"),
    search: Optional[str] = Query(None, description="关键词搜索(名称、描述)"),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    List the practice scenes available to users.

    Only scenes with ``status == "active"`` that are not soft-deleted are
    returned, newest first, with pagination, optional filters and a keyword
    search.

    Args:
        page: 1-based page number.
        size: Page size (1-100).
        type: Optional exact match on ``PracticeScene.type``. (It shadows the
            builtin, but the name is part of the public query-string API.)
        difficulty: Optional exact match on ``PracticeScene.difficulty``.
        search: Optional substring matched against name or description.
            ``%`` and ``_`` in the input are matched literally.

    Returns:
        ResponseModel wrapping a PaginatedResponse of scenes.
    """
    query = select(PracticeScene).where(
        PracticeScene.is_deleted == False,
        PracticeScene.status == "active"
    )

    if type:
        query = query.where(PracticeScene.type == type)

    if difficulty:
        query = query.where(PracticeScene.difficulty == difficulty)

    if search:
        # autoescape=True escapes LIKE wildcards in the user's text so a search
        # for "50%" or "a_b" matches literally instead of acting as a pattern.
        query = query.where(
            or_(
                PracticeScene.name.contains(search, autoescape=True),
                PracticeScene.description.contains(search, autoescape=True)
            )
        )

    # Total is computed over the filtered (unpaginated) query.
    count_query = select(func.count()).select_from(query.subquery())
    total = await db.scalar(count_query) or 0

    # id is a tie-breaker: without it, rows sharing the same created_at can
    # shift between pages from one request to the next.
    page_query = (
        query
        .order_by(PracticeScene.created_at.desc(), PracticeScene.id.desc())
        .offset((page - 1) * size)
        .limit(size)
    )
    scenes = list((await db.scalars(page_query)).all())

    logger.info(
        f"用户{current_user.id}查询陪练场景列表,"
        f"类型={type}, 难度={difficulty}, 搜索={search}, "
        f"返回{len(scenes)}条记录"
    )

    return ResponseModel(
        code=200,
        message="success",
        data=PaginatedResponse(
            items=scenes,
            total=total,
            page=page,
            page_size=size,
            # Ceiling division; yields 0 pages when total is 0.
            pages=(total + size - 1) // size
        )
    )
|
||
|
||
|
||
@router.get("/scenes/{scene_id}", response_model=ResponseModel[PracticeSceneResponse])
async def get_practice_scene_detail(
    scene_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return the full record of one practice scene.

    The scene must be active and not soft-deleted.

    Raises:
        HTTPException: 404 if no such active, undeleted scene exists.
    """
    lookup = select(PracticeScene).where(
        PracticeScene.id == scene_id,
        PracticeScene.is_deleted == False,
        PracticeScene.status == "active"
    )
    scene = (await db.execute(lookup)).scalar_one_or_none()

    if scene is None:
        logger.warning(f"用户{current_user.id}查询场景{scene_id}不存在或已禁用")
        raise HTTPException(status_code=404, detail="场景不存在或已禁用")

    logger.info(f"用户{current_user.id}查询场景{scene_id}详情")
    return ResponseModel(code=200, message="success", data=scene)
|
||
|
||
|
||
@router.post("/scenes", response_model=ResponseModel[PracticeSceneResponse])
async def create_practice_scene(
    scene_data: PracticeSceneCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Create a practice scene.

    Only callers whose role is ``admin`` or ``manager`` may create scenes. The
    caller is recorded as both ``created_by`` and ``updated_by``.

    Raises:
        HTTPException: 403 when the caller lacks the required role.
    """
    allowed_roles = ("admin", "manager")
    if current_user.role not in allowed_roles:
        raise HTTPException(status_code=403, detail="无权限创建陪练场景")

    author_id = current_user.id
    new_scene = PracticeScene(
        **scene_data.model_dump(),
        created_by=author_id,
        updated_by=author_id
    )

    db.add(new_scene)
    await db.commit()
    # Reload the committed row (e.g. its generated id) before reading it.
    await db.refresh(new_scene)

    logger.info(f"用户{current_user.id}创建陪练场景: {new_scene.name} (ID: {new_scene.id})")
    return ResponseModel(code=200, message="场景创建成功", data=new_scene)
|
||
|
||
|
||
@router.put("/scenes/{scene_id}", response_model=ResponseModel[PracticeSceneResponse])
async def update_practice_scene(
    scene_id: int,
    scene_data: PracticeSceneUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Partially update a practice scene.

    Only fields explicitly sent in the request body are written
    (``exclude_unset=True``). Restricted to ``admin`` / ``manager`` roles;
    soft-deleted scenes are treated as missing.

    Raises:
        HTTPException: 403 without the required role; 404 if the scene does
            not exist or is soft-deleted.
    """
    if current_user.role not in ("admin", "manager"):
        raise HTTPException(status_code=403, detail="无权限更新陪练场景")

    lookup = select(PracticeScene).where(
        PracticeScene.id == scene_id,
        PracticeScene.is_deleted == False
    )
    scene = (await db.execute(lookup)).scalar_one_or_none()
    if scene is None:
        raise HTTPException(status_code=404, detail="场景不存在")

    changes = scene_data.model_dump(exclude_unset=True)
    for attr_name in changes:
        setattr(scene, attr_name, changes[attr_name])
    scene.updated_by = current_user.id

    await db.commit()
    await db.refresh(scene)

    logger.info(f"用户{current_user.id}更新陪练场景: {scene.name} (ID: {scene.id})")
    return ResponseModel(code=200, message="场景更新成功", data=scene)
|
||
|
||
|
||
@router.delete("/scenes/{scene_id}", response_model=ResponseModel)
async def delete_practice_scene(
    scene_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Soft-delete a practice scene (sets ``is_deleted``; the row is kept).

    Restricted to ``admin`` / ``manager`` roles.

    Raises:
        HTTPException: 403 without the required role; 404 if the scene does
            not exist or is already deleted.
    """
    if current_user.role not in ["admin", "manager"]:
        raise HTTPException(status_code=403, detail="无权限删除陪练场景")

    result = await db.execute(
        select(PracticeScene).where(
            PracticeScene.id == scene_id,
            PracticeScene.is_deleted == False
        )
    )
    scene = result.scalar_one_or_none()

    if not scene:
        raise HTTPException(status_code=404, detail="场景不存在")

    scene.is_deleted = True
    scene.updated_by = current_user.id

    # Capture what the log needs before committing. Unlike create/update, this
    # path does not refresh after commit, and reading a possibly expired
    # attribute would require an implicit reload, which AsyncSession cannot do.
    scene_name = scene.name

    await db.commit()

    logger.info(f"用户{current_user.id}删除陪练场景: {scene_name} (ID: {scene_id})")

    return ResponseModel(
        code=200,
        message="场景删除成功",
        data={"scene_id": scene_id}
    )
|
||
|
||
|
||
@router.post("/start")
async def start_practice(
    request: StartPracticeRequest,
    current_user: User = Depends(get_current_user),
    coze_service: CozeService = Depends(get_coze_service)
):
    """
    Send one practice turn to Coze and relay the reply as Server-Sent Events.

    Message selection:
      - ``is_first`` true: build the full scene prompt via
        ``CozeService.build_scene_prompt`` from the scene fields plus the
        user's message.
      - ``is_first`` false: send only the user's message.
    ``conversation_id`` is forwarded so Coze keeps the dialogue context.

    SSE events produced: ``conversation.chat.created``, ``message.delta``,
    ``message.completed``, ``conversation.completed``, ``error``, and a final
    ``done`` frame when the event loop finishes without an exception.
    """
    logger.info(
        f"用户{current_user.id}开始陪练对话,"
        f"场景={request.scene_name}, "
        f"is_first={request.is_first}, "
        f"conversation_id={request.conversation_id}"
    )

    if not request.is_first:
        message = request.user_message
        logger.debug(f"用户消息: {message}")
    else:
        message = coze_service.build_scene_prompt(
            scene_name=request.scene_name,
            scene_background=request.scene_background,
            scene_ai_role=request.scene_ai_role,
            scene_objectives=request.scene_objectives,
            scene_keywords=request.scene_keywords,
            scene_description=request.scene_description,
            user_message=request.user_message
        )
        logger.debug(f"场景提示词已构建,长度={len(message)}字符")

    def sse_frame(event_name, payload):
        """Render one SSE frame whose data line is ``payload`` encoded as JSON."""
        return f"event: {event_name}\ndata: {json.dumps(payload)}\n\n"

    def generate_stream():
        """Translate Coze streaming chat events into SSE frames."""
        try:
            coze_events = coze_service.create_stream_chat(
                user_id=str(current_user.id),
                message=message,
                conversation_id=request.conversation_id
            )

            for event in coze_events:
                event_type = event.event

                if event_type == ChatEventType.CONVERSATION_CHAT_CREATED:
                    # Keep the caller's conversation_id when continuing a chat;
                    # otherwise use the id Coze assigned to the new one.
                    conversation_id = request.conversation_id or event.chat.conversation_id
                    yield sse_frame(
                        "conversation.chat.created",
                        {"conversation_id": conversation_id, "chat_id": event.chat.id}
                    )
                    logger.debug(f"对话已创建/续接: conversation_id={conversation_id}, 来源={'请求参数' if request.conversation_id else 'Coze创建'}")

                elif event_type == ChatEventType.CONVERSATION_MESSAGE_DELTA:
                    # Incremental text that drives the typing effect.
                    yield sse_frame("message.delta", {"content": event.message.content})

                elif event_type == ChatEventType.CONVERSATION_MESSAGE_COMPLETED:
                    # Empty payload: the frontend has already assembled the text from deltas.
                    yield sse_frame("message.completed", {})
                    logger.info("消息已完成")

                elif event_type == ChatEventType.CONVERSATION_CHAT_COMPLETED:
                    # Usage may be missing; absent counters are reported as 0.
                    usage = getattr(event.chat, 'usage', None)
                    usage_stats = {
                        counter: (getattr(usage, counter, 0) if usage else 0)
                        for counter in ("token_count", "input_count", "output_count")
                    }
                    yield sse_frame("conversation.completed", usage_stats)
                    logger.info(f"对话已完成,Token用量={usage_stats['token_count']}")
                    break

                elif event_type == ChatEventType.CONVERSATION_CHAT_FAILED:
                    last_error = event.chat.last_error
                    error_msg = str(last_error) if last_error else "对话失败"
                    yield sse_frame("error", {"error": error_msg})
                    logger.error(f"对话失败: {error_msg}")
                    break

            yield "event: done\ndata: [DONE]\n\n"
            logger.info("SSE流结束")

        except (CozeError, CozeAPIError) as e:
            logger.error(f"Coze API错误: {e}", exc_info=True)
            yield sse_frame("error", {"error": f"对话失败: {str(e)}"})
        except Exception as e:
            logger.error(f"陪练对话异常: {e}", exc_info=True)
            yield sse_frame("error", {"error": f"系统错误: {str(e)}"})

    return StreamingResponse(
        generate_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"  # disable Nginx response buffering
        }
    )
|
||
|
||
|
||
@router.post("/interrupt", response_model=ResponseModel)
async def interrupt_practice(
    request: InterruptPracticeRequest,
    current_user: User = Depends(get_current_user),
    coze_service: CozeService = Depends(get_coze_service)
):
    """
    Interrupt an in-progress practice chat on Coze.

    Calls ``CozeService.cancel_chat`` for the given conversation/chat pair.

    Raises:
        HTTPException: 500 if the Coze call fails or anything else goes wrong.
    """
    from fastapi.concurrency import run_in_threadpool

    logger.info(
        f"用户{current_user.id}中断对话,"
        f"conversation_id={request.conversation_id}, "
        f"chat_id={request.chat_id}"
    )

    try:
        # cancel_chat is called without await, i.e. synchronously. Running it in
        # the threadpool keeps any network round-trip it makes from blocking
        # the event loop (and therefore every other request) meanwhile.
        await run_in_threadpool(
            coze_service.cancel_chat,
            conversation_id=request.conversation_id,
            chat_id=request.chat_id
        )

        return ResponseModel(
            code=200,
            message="对话已中断",
            data={
                "conversation_id": request.conversation_id,
                "chat_id": request.chat_id
            }
        )
    except (CozeError, CozeAPIError) as e:
        logger.error(f"中断对话失败: {e}")
        raise HTTPException(status_code=500, detail=f"中断对话失败: {str(e)}")
    except Exception as e:
        logger.error(f"中断对话异常: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"系统错误: {str(e)}")
|
||
|
||
|
||
@router.post("/conversation/create", response_model=ResponseModel)
async def create_practice_conversation(
    current_user: User = Depends(get_current_user),
    coze_service: CozeService = Depends(get_coze_service)
):
    """
    Create a new Coze conversation for practice.

    A conversation must exist before later turns can continue it; the
    returned ``conversation_id`` is what the client sends on those turns.

    Raises:
        HTTPException: 500 if the Coze call fails or anything else goes wrong.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        # conversations.create() is called without await, i.e. synchronously;
        # run it in the threadpool so the Coze round-trip does not block the
        # event loop for other requests.
        conversation = await run_in_threadpool(coze_service.client.conversations.create)
        conversation_id = conversation.id

        logger.info(f"用户{current_user.id}创建陪练对话,conversation_id={conversation_id}")

        return ResponseModel(
            code=200,
            message="对话创建成功",
            data={"conversation_id": conversation_id}
        )
    except (CozeError, CozeAPIError) as e:
        logger.error(f"创建对话失败: {e}")
        raise HTTPException(status_code=500, detail=f"创建对话失败: {str(e)}")
    except Exception as e:
        logger.error(f"创建对话异常: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"系统错误: {str(e)}")
|
||
|
||
|
||
@router.get("/conversations", response_model=ResponseModel[ConversationsResponse])
async def get_conversations(
    page: int = Query(1, ge=1, description="页码"),
    size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: User = Depends(get_current_user)
):
    """
    List the user's conversation history on Coze (not implemented yet).

    Always answers HTTP 501 for now. ``page`` and ``size`` are accepted so the
    query contract is already in place.

    Note: voice practice does not go through this backend; the frontend talks
    to the Coze WebSocket directly.
    """
    # TODO: implement conversation listing (planned for phase four).
    not_implemented = HTTPException(status_code=501, detail="对话列表功能正在开发中")
    logger.info(f"用户{current_user.id}查询对话列表")
    raise not_implemented
|
||
|
||
|
||
@router.post("/extract-scene", response_model=ResponseModel[ExtractSceneResponse])
async def extract_scene(
    request: ExtractSceneRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Generate a practice scene from a course using the native Python AI service.

    Steps:
      1. Verify the course exists.
      2. Call ``practice_scene_service.prepare_practice_knowledge``, which
         gathers the course's knowledge points and asks the AI for a scene.
      3. Map the generated scene onto ``ExtractedSceneData``.

    Raises:
        HTTPException: 404 if the course does not exist; 400 if the course has
            no knowledge points; 500 for any other generation failure.
    """
    from app.models.course import Course
    from app.services.ai import practice_scene_service

    course = await db.get(Course, request.course_id)
    if not course:
        logger.warning(f"课程不存在: course_id={request.course_id}")
        raise HTTPException(status_code=404, detail="课程不存在")

    logger.info(f"用户{current_user.id}开始提取课程{request.course_id}的陪练场景")

    result = await practice_scene_service.prepare_practice_knowledge(
        db=db,
        course_id=request.course_id
    )

    if not result.success:
        # result.error may be None; the substring checks need a str, otherwise
        # a TypeError would replace the real failure reason.
        error_text = result.error or ""
        # "No knowledge points" is a user-fixable condition -> 400 with guidance.
        if "没有可用的知识点" in error_text or "没有知识点" in error_text:
            raise HTTPException(
                status_code=400,
                detail="该课程尚未添加知识点,无法生成陪练场景。请先在课程管理中上传资料并分析知识点。"
            )
        raise HTTPException(status_code=500, detail=f"场景提取失败: {result.error}")

    scene = result.scene
    scene_data = ExtractedSceneData(
        name=scene.name,
        description=scene.description,
        type=scene.type,
        difficulty=scene.difficulty,
        background=scene.background,
        ai_role=scene.ai_role,
        objectives=scene.objectives,
        keywords=scene.keywords
    )

    logger.info(
        f"场景提取成功: {scene.name}, course_id={request.course_id}, "
        f"provider={result.ai_provider}, tokens={result.ai_tokens}"
    )

    return ResponseModel(
        code=200,
        message="场景提取成功",
        data=ExtractSceneResponse(
            scene=scene_data,
            # No external workflow is involved; these ids carry provider/latency
            # and course information in the fields the response schema expects.
            workflow_run_id=f"{result.ai_provider}_{result.ai_latency_ms}ms",
            task_id=f"native_{request.course_id}"
        )
    )
|
||
|
||
|
||
# ==================== 陪练会话管理API ====================
|
||
|
||
@router.post("/sessions/create", response_model=ResponseModel[PracticeSessionResponse])
async def create_practice_session(
    request: PracticeSessionCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Create a practice session record when the user starts practising.

    The new session is owned by the current user, starts now and has status
    ``in_progress``.

    Raises:
        HTTPException: 500 if the record cannot be created.
    """
    import secrets

    try:
        now = datetime.now()
        # The previous id ("PS" + last 6 digits of the ms timestamp) repeated
        # every 1000 seconds, so unrelated sessions could share an id, and
        # dialogues/reports are keyed by it. Use the full ms timestamp plus 16
        # random bits (19 chars, "PS" prefix kept).
        # NOTE(review): confirm the session_id column is at least 19 chars wide.
        session_id = f"PS{int(now.timestamp() * 1000)}{secrets.token_hex(2)}"

        session = PracticeSession(
            session_id=session_id,
            user_id=current_user.id,
            scene_id=request.scene_id,
            scene_name=request.scene_name,
            scene_type=request.scene_type,
            conversation_id=request.conversation_id,
            start_time=now,
            status="in_progress"
        )

        db.add(session)
        await db.commit()
        await db.refresh(session)

        logger.info(f"创建陪练会话: session_id={session_id}, user_id={current_user.id}, scene={request.scene_name}")

        return ResponseModel(
            code=200,
            message="会话创建成功",
            data=session
        )

    except Exception as e:
        # Leave the request's DB session usable after a failed flush/commit.
        await db.rollback()
        logger.error(f"创建会话失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"创建会话失败: {str(e)}")
|
||
|
||
|
||
@router.post("/dialogues/save", response_model=ResponseModel)
async def save_dialogue(
    request: SaveDialogueRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Persist one dialogue line (user or AI) of a practice session.

    The target session must belong to the current user and not be
    soft-deleted, so a user cannot append lines to someone else's session.

    Raises:
        HTTPException: 404 if the session is not found for this user;
            500 if saving fails.
    """
    try:
        # Ownership check; limit(1) tolerates duplicate session_ids that older
        # id generation could produce.
        owner_check = await db.execute(
            select(PracticeSession.id).where(
                PracticeSession.session_id == request.session_id,
                PracticeSession.user_id == current_user.id,
                PracticeSession.is_deleted == False
            ).limit(1)
        )
        if owner_check.scalar() is None:
            raise HTTPException(status_code=404, detail="会话不存在")

        dialogue = PracticeDialogue(
            session_id=request.session_id,
            speaker=request.speaker,
            content=request.content,
            timestamp=datetime.now(),
            sequence=request.sequence
        )

        db.add(dialogue)
        await db.commit()

        logger.debug(f"保存对话: session_id={request.session_id}, speaker={request.speaker}, seq={request.sequence}")

        return ResponseModel(
            code=200,
            message="对话保存成功",
            data={"session_id": request.session_id, "sequence": request.sequence}
        )

    except HTTPException:
        # Keep intended 4xx responses instead of turning them into 500s.
        raise
    except Exception as e:
        await db.rollback()
        logger.error(f"保存对话失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"保存对话失败: {str(e)}")
|
||
|
||
|
||
def _practice_session_to_response(session: PracticeSession) -> PracticeSessionResponse:
    """
    Copy a loaded PracticeSession row into a PracticeSessionResponse.

    Returning a plain schema object instead of the ORM instance avoids
    attribute reloads on an expired/detached instance during serialisation.
    Null ``scene_name`` / ``duration_seconds`` / ``turns`` become "" / 0 / 0.
    """
    return PracticeSessionResponse(
        id=session.id,
        session_id=session.session_id,
        user_id=session.user_id,
        scene_id=session.scene_id,
        scene_name=session.scene_name or "",
        scene_type=session.scene_type,
        conversation_id=session.conversation_id,
        start_time=session.start_time,
        end_time=session.end_time,
        duration_seconds=session.duration_seconds or 0,
        turns=session.turns or 0,
        status=session.status,
        created_at=session.created_at
    )


async def _award_practice_rewards(db: AsyncSession, user_id: int, session: PracticeSession) -> None:
    """
    Grant practice experience and check for newly unlocked badges (best effort).

    Commits on success. On any failure the error is logged and the partial
    reward changes are rolled back. The session end itself was already
    committed by the caller and is unaffected.
    """
    try:
        from app.services.level_service import LevelService
        from app.services.badge_service import BadgeService

        level_service = LevelService(db)
        badge_service = BadgeService(db)

        await level_service.add_practice_exp(user_id=user_id, session_id=session.id)
        await badge_service.check_and_award_badges(user_id)

        await db.commit()
    except Exception as e:
        logger.warning(f"练习经验值/奖章处理失败: {str(e)}", exc_info=True)
        # After a failed flush the DB session refuses further work until it is
        # rolled back; without this the caller's refresh would fail too.
        try:
            await db.rollback()
        except Exception:
            logger.warning("练习经验值/奖章回滚失败", exc_info=True)


@router.post("/sessions/{session_id}/end", response_model=ResponseModel[PracticeSessionResponse])
async def end_practice_session(
    session_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    End a practice session owned by the current user.

    Sets ``end_time``, ``duration_seconds``, ``turns`` (number of saved
    dialogue rows) and ``status="completed"``, then grants practice rewards.
    Calling it again for an already completed session returns the stored data
    without moving ``end_time`` or granting rewards a second time.

    Raises:
        HTTPException: 404 if the session does not exist, belongs to another
            user or is soft-deleted; 500 on unexpected errors.
    """
    try:
        result = await db.execute(
            select(PracticeSession).where(
                PracticeSession.session_id == session_id,
                PracticeSession.user_id == current_user.id,
                PracticeSession.is_deleted == False
            )
        )
        session = result.scalar_one_or_none()

        if not session:
            raise HTTPException(status_code=404, detail="会话不存在")

        if session.status == "completed":
            # Repeated end request (e.g. button + page unload): idempotent.
            logger.info(f"会话已结束,跳过重复处理: session_id={session_id}")
            return ResponseModel(
                code=200,
                message="会话已结束",
                data=_practice_session_to_response(session)
            )

        result = await db.execute(
            select(func.count(PracticeDialogue.id)).where(
                PracticeDialogue.session_id == session_id
            )
        )
        dialogue_count = result.scalar() or 0

        session.end_time = datetime.now()
        session.duration_seconds = int((session.end_time - session.start_time).total_seconds())
        session.turns = dialogue_count
        session.status = "completed"

        await db.commit()
        await db.refresh(session)

        logger.info(f"结束陪练会话: session_id={session_id}, 时长={session.duration_seconds}秒, 轮次={session.turns}")

        await _award_practice_rewards(db, current_user.id, session)
        # The reward step committed or rolled back; either may expire `session`.
        await db.refresh(session)

        return ResponseModel(
            code=200,
            message="会话已结束",
            data=_practice_session_to_response(session)
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"结束会话失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"结束会话失败: {str(e)}")
|
||
|
||
|
||
@router.post("/sessions/{session_id}/analyze", response_model=ResponseModel)
async def analyze_practice_session(
    session_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Generate (or regenerate) the AI analysis report for a practice session.

    Uses the native Python ``practice_analysis_service``. An existing report
    for the session is overwritten; otherwise a new one is created.

    Raises:
        HTTPException: 404 if the session is not the caller's or is deleted;
            400 if fewer than two dialogue lines exist; 500 if analysis or
            saving fails.
    """
    try:
        result = await db.execute(
            select(PracticeSession).where(
                PracticeSession.session_id == session_id,
                PracticeSession.user_id == current_user.id,
                PracticeSession.is_deleted == False
            )
        )
        session = result.scalar_one_or_none()

        if not session:
            raise HTTPException(status_code=404, detail="会话不存在")

        result = await db.execute(
            select(PracticeDialogue).where(
                PracticeDialogue.session_id == session_id
            ).order_by(PracticeDialogue.sequence)
        )
        dialogues = result.scalars().all()

        if len(dialogues) < 2:
            raise HTTPException(status_code=400, detail="对话数量太少,无法生成分析报告")

        dialogue_history = [
            {
                "speaker": d.speaker,
                "content": d.content,
                "timestamp": d.timestamp.isoformat()
            }
            for d in dialogues
        ]

        logger.info(f"开始分析陪练会话: session_id={session_id}, 对话数={len(dialogue_history)}")

        v2_result = await practice_analysis_service.analyze(dialogue_history, db=db)

        if not v2_result.success:
            raise HTTPException(status_code=500, detail=f"分析失败: {v2_result.error}")

        analysis_data = v2_result.to_dict()

        logger.info(
            f"陪练分析完成 - total_score: {v2_result.total_score}, "
            f"provider: {v2_result.ai_provider}, latency: {v2_result.ai_latency_ms}ms"
        )

        analysis_result = analysis_data.get("analysis", {})

        # Column values shared by the update and create paths.
        report_fields = {
            "total_score": analysis_result.get("total_score"),
            "score_breakdown": analysis_result.get("score_breakdown"),
            "ability_dimensions": analysis_result.get("ability_dimensions"),
            # Per-line annotations are stored in the report's dialogue_review column.
            "dialogue_review": analysis_result.get("dialogue_annotations"),
            "suggestions": analysis_result.get("suggestions"),
            "workflow_run_id": f"{v2_result.ai_provider}_{v2_result.ai_latency_ms}ms",
        }

        existing_report = await db.execute(
            select(PracticeReport).where(PracticeReport.session_id == session_id)
        )
        report = existing_report.scalar_one_or_none()

        if report:
            for column, value in report_fields.items():
                setattr(report, column, value)
            logger.info(f"更新现有分析报告: session_id={session_id}")
        else:
            report = PracticeReport(session_id=session_id, task_id=None, **report_fields)
            db.add(report)

        await db.commit()
        # Reload after commit, as the other endpoints here do, so the reads
        # below never touch expired attributes (AsyncSession cannot lazy-load).
        await db.refresh(report)

        logger.info(f"分析报告已保存: session_id={session_id}, total_score={report.total_score}")

        return ResponseModel(
            code=200,
            message="分析报告生成成功",
            data={
                "session_id": session_id,
                "total_score": report.total_score,
                "workflow_run_id": report.workflow_run_id
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"生成分析报告失败: {e}, session_id={session_id}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"生成分析报告失败: {str(e)}")
|
||
|
||
|
||
def _load_report_json_field(value, field_name: str, default):
    """
    Normalise a report column that may hold decoded JSON or raw JSON text.

    Args:
        value: Column value (list/dict, JSON string, or empty/None).
        field_name: Column name, used only in the warning log.
        default: Returned for empty values and undecodable text.

    Returns:
        The decoded value, ``value`` unchanged if it is not a string, or
        ``default``.
    """
    if not value:
        return default
    if isinstance(value, str):
        try:
            return json.loads(value) or default
        except json.JSONDecodeError:
            logger.warning(f"无法解析{field_name} JSON: {value}")
            return default
    return value


def _named_scores_as_list(data, extra_key: str):
    """
    Convert a ``{name: score}`` mapping into ``[{"name", "score", extra_key}]``.

    Any other value (normally already a list) is returned unchanged.
    """
    if isinstance(data, dict):
        return [{"name": k, "score": int(v), extra_key: ""} for k, v in data.items()]
    return data


def _build_dialogue_review(dialogues, annotations, start_time):
    """
    Merge stored dialogue rows with the AI's per-line annotations.

    Args:
        dialogues: PracticeDialogue rows, ordered by sequence.
        annotations: List of dicts carrying ``sequence``, ``tags`` and
            ``comment``; non-dict entries and entries without a sequence are
            ignored, as is a non-list value.
        start_time: Session start, used for each line's MM:SS offset.

    Returns:
        One dict per row with speaker/time/content/tags/comment.
    """
    if not isinstance(annotations, (list, tuple)):
        annotations = []

    by_sequence = {}
    for annotation in annotations:
        # `is not None` so an annotation for sequence 0 is not dropped.
        if isinstance(annotation, dict) and annotation.get("sequence") is not None:
            by_sequence[annotation["sequence"]] = annotation

    review = []
    for dialogue in dialogues:
        # Clamp at 0 so a timestamp slightly before start_time doesn't render as "-1:59".
        offset = max(0, int((dialogue.timestamp - start_time).total_seconds()))
        annotation = by_sequence.get(dialogue.sequence, {})
        review.append({
            "speaker": "顾问" if dialogue.speaker == "user" else "客户",
            "time": f"{offset // 60:02d}:{offset % 60:02d}",
            "content": dialogue.content,
            "tags": annotation.get("tags", []),
            "comment": annotation.get("comment", "")
        })
    return review


async def _generate_practice_report(db: AsyncSession, session_id: str) -> PracticeReport:
    """
    Analyse a session's stored dialogues with AI and persist a new report.

    Mirrors /sessions/{id}/analyze: the same dialogue-history shape is sent,
    and results are read from the nested ``analysis`` dict. The top-level keys
    this path used before are kept as fallbacks.

    Raises:
        HTTPException: 404 if the session has no dialogues; 500 if the
            analysis fails.
    """
    result = await db.execute(
        select(PracticeDialogue).where(
            PracticeDialogue.session_id == session_id
        ).order_by(PracticeDialogue.sequence)
    )
    dialogue_list = result.scalars().all()

    if not dialogue_list:
        raise HTTPException(status_code=404, detail="没有对话记录,无法生成报告")

    # NOTE(review): "role" is kept alongside the /analyze format in case the
    # service reads it — confirm which key it uses and drop the other.
    dialogue_history = [
        {
            "speaker": d.speaker,
            "role": "user" if d.speaker == "user" else "assistant",
            "content": d.content,
            "timestamp": d.timestamp.isoformat()
        }
        for d in dialogue_list
    ]

    analysis_result = await practice_analysis_service.analyze(dialogue_history, db=db)
    if not analysis_result.success:
        raise HTTPException(status_code=500, detail=f"分析失败: {analysis_result.error}")

    analysis_data = analysis_result.to_dict()
    analysis = analysis_data.get("analysis") or {}

    def pick(key, legacy_key, default):
        """Prefer the nested analysis value, then the legacy top-level key."""
        value = analysis.get(key)
        return value if value is not None else analysis_data.get(legacy_key, default)

    report = PracticeReport(
        session_id=session_id,
        total_score=pick("total_score", "overall_score", 0),
        score_breakdown=pick("score_breakdown", "score_breakdown", []),
        ability_dimensions=pick("ability_dimensions", "ability_dimensions", []),
        dialogue_review=pick("dialogue_annotations", "dialogue_review", []),
        suggestions=pick("suggestions", "suggestions", []),
        summary=pick("summary", "summary", ""),
        raw_response=json.dumps(analysis_data, ensure_ascii=False),
        workflow_run_id=f"{analysis_result.ai_provider}_{analysis_result.ai_latency_ms}ms"
    )
    db.add(report)
    await db.commit()
    await db.refresh(report)
    logger.info(f"报告自动生成成功: session_id={session_id}, 总分={report.total_score}")
    return report


@router.get("/reports/{session_id}", response_model=ResponseModel[PracticeReportResponse])
async def get_practice_report(
    session_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return the analysis report of a practice session, generating it on demand.

    Stored dialogue rows are merged with the AI's per-line annotations to
    produce the full dialogue review.

    Raises:
        HTTPException: 404 if the session is not the caller's / is deleted, or
            has no dialogues when a report must be generated; 500 otherwise.
    """
    # Note: no function-local `import json` here. A local import would make
    # `json` a local name for the entire function and raise UnboundLocalError
    # whenever the report already exists.
    try:
        result = await db.execute(
            select(PracticeSession).where(
                PracticeSession.session_id == session_id,
                PracticeSession.user_id == current_user.id,
                PracticeSession.is_deleted == False
            )
        )
        session = result.scalar_one_or_none()

        if not session:
            raise HTTPException(status_code=404, detail="会话不存在")

        result = await db.execute(
            select(PracticeReport).where(PracticeReport.session_id == session_id)
        )
        report = result.scalar_one_or_none()

        if not report:
            logger.info(f"报告不存在,自动生成: session_id={session_id}")
            report = await _generate_practice_report(db, session_id)
            # The commit above may have expired `session`; reload before use.
            await db.refresh(session)

        # (Re)load dialogues after any commit so their attributes are populated.
        result = await db.execute(
            select(PracticeDialogue).where(
                PracticeDialogue.session_id == session_id
            ).order_by(PracticeDialogue.sequence)
        )
        dialogues = result.scalars().all()

        # dialogue_review stores the AI annotations (sequence, tags, comment).
        annotations = _load_report_json_field(report.dialogue_review, "dialogue_review", [])
        dialogue_review = _build_dialogue_review(dialogues, annotations, session.start_time)

        # Older reports may store these as {name: score} mappings or JSON text.
        score_breakdown_data = _named_scores_as_list(
            _load_report_json_field(report.score_breakdown, "score_breakdown", []),
            "description"
        )
        ability_dimensions_data = _named_scores_as_list(
            _load_report_json_field(report.ability_dimensions, "ability_dimensions", []),
            "feedback"
        )
        suggestions_data = _load_report_json_field(report.suggestions, "suggestions", [])

        analysis = PracticeAnalysisResult(
            total_score=report.total_score,
            score_breakdown=score_breakdown_data,
            ability_dimensions=ability_dimensions_data,
            dialogue_review=dialogue_review,
            suggestions=suggestions_data
        )

        response_data = PracticeReportResponse(
            session_info=session,
            analysis=analysis
        )

        logger.info(f"获取分析报告: session_id={session_id}, total_score={report.total_score}, 对话数={len(dialogue_review)}")

        return ResponseModel(
            code=200,
            message="success",
            data=response_data
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取报告失败: {e}, session_id={session_id}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取报告失败: {str(e)}")
|
||
|
||
|
||
# ==================== 陪练记录查询API ====================
|
||
|
||
def _parse_bare_date(value: str) -> Optional[datetime]:
    """Return midnight of ``value`` if it is a bare ``YYYY-MM-DD`` date, else ``None``.

    Used to recognise date-only filter values so an end date can be made
    inclusive of the whole day; any other format yields ``None``.
    """
    try:
        return datetime.strptime(value, "%Y-%m-%d")
    except ValueError:
        return None


def _score_to_result_level(total_score: Optional[int]) -> str:
    """Map a report score to the result level shown in the session list.

    ``None`` (session has no report) and scores below 70 both map to
    ``"needs_improvement"``.
    """
    if total_score is None:
        return "needs_improvement"
    if total_score >= 90:
        return "excellent"
    if total_score >= 80:
        return "good"
    if total_score >= 70:
        return "average"
    return "needs_improvement"


@router.get("/sessions/list", response_model=ResponseModel[PaginatedResponse])
async def get_practice_sessions_list(
    page: int = Query(1, ge=1, description="页码"),
    size: int = Query(20, ge=1, le=100, description="每页数量"),
    keyword: Optional[str] = Query(None, description="关键词搜索"),
    scene_type: Optional[str] = Query(None, description="场景类型"),
    start_date: Optional[str] = Query(None, description="开始日期"),
    end_date: Optional[str] = Query(None, description="结束日期"),
    min_score: Optional[int] = Query(None, ge=0, le=100, description="最低分数"),
    max_score: Optional[int] = Query(None, ge=0, le=100, description="最高分数"),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    List the current user's completed, non-deleted practice sessions (paginated).

    Supports keyword search (scene name or session id), scene-type filtering,
    a start-time range filter and a report-score range filter. Results are
    ordered by start time, newest first.

    Args:
        page: 1-based page number.
        size: page size (1-100).
        keyword: substring matched literally against scene_name / session_id.
        scene_type: exact scene type to filter on.
        start_date: lower bound (inclusive) on the session start time.
        end_date: upper bound on the session start time; a bare ``YYYY-MM-DD``
            value includes the whole day.
        min_score / max_score: inclusive bounds on the report's total_score.

    Returns:
        ResponseModel wrapping a PaginatedResponse whose items are dicts with
        session_id, scene_name, scene_type, start_time, duration_seconds,
        turns, total_score (None when no report exists) and result level.

    Raises:
        HTTPException: 500 on any unexpected error.
    """
    try:
        # Outer join so sessions without a report are still listed (score None).
        query = select(
            PracticeSession,
            PracticeReport.total_score
        ).outerjoin(
            PracticeReport,
            PracticeSession.session_id == PracticeReport.session_id
        ).where(
            PracticeSession.user_id == current_user.id,
            PracticeSession.is_deleted == False,
            PracticeSession.status == "completed"  # only completed sessions
        )

        if keyword:
            # autoescape: "%" and "_" typed by the user are matched literally,
            # not treated as LIKE wildcards.
            query = query.where(
                or_(
                    PracticeSession.scene_name.contains(keyword, autoescape=True),
                    PracticeSession.session_id.contains(keyword, autoescape=True)
                )
            )

        if scene_type:
            query = query.where(PracticeSession.scene_type == scene_type)

        if start_date:
            query = query.where(PracticeSession.start_time >= start_date)
        if end_date:
            end_day = _parse_bare_date(end_date)
            if end_day is not None:
                # A bare date means "through the end of that day". Comparing the
                # datetime column with '<= YYYY-MM-DD' stops at midnight and
                # drops sessions started later on the end date.
                query = query.where(
                    PracticeSession.start_time < end_day + timedelta(days=1)
                )
            else:
                query = query.where(PracticeSession.start_time <= end_date)

        if min_score is not None:
            query = query.where(PracticeReport.total_score >= min_score)
        if max_score is not None:
            query = query.where(PracticeReport.total_score <= max_score)

        # Count the filtered rows; ORDER BY is irrelevant for counting.
        count_query = select(func.count()).select_from(query.subquery())
        total = await db.scalar(count_query) or 0

        # Newest first, then paginate.
        results = await db.execute(
            query.order_by(PracticeSession.start_time.desc())
            .offset((page - 1) * size)
            .limit(size)
        )

        items = [
            {
                "session_id": session.session_id,
                "scene_name": session.scene_name,
                "scene_type": session.scene_type,
                "start_time": session.start_time,
                "duration_seconds": session.duration_seconds,
                "turns": session.turns,
                "total_score": total_score,
                "result": _score_to_result_level(total_score)
            }
            for session, total_score in results
        ]

        logger.info(f"查询陪练记录: user_id={current_user.id}, 返回{len(items)}条记录")

        return ResponseModel(
            code=200,
            message="success",
            data=PaginatedResponse(
                items=items,
                total=total,
                page=page,
                page_size=size,
                pages=(total + size - 1) // size
            )
        )

    except Exception as e:
        logger.error(f"查询陪练记录失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}")
|
||
|
||
|
||
@router.get("/stats", response_model=ResponseModel)
async def get_practice_stats(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return practice statistics for the current user.

    Only the user's completed, non-deleted sessions are counted. The same
    filter is used for every figure so that the figures agree with each
    other and with the session list.

    Returns:
        ResponseModel whose data dict has:
            total_count: number of completed sessions.
            avg_score: average report total_score, rounded to 1 decimal
                (0 when there are no reports).
            total_duration_hours: summed duration_seconds in hours, rounded
                to 1 decimal.
            month_improvement: placeholder value (see NOTE below).

    Raises:
        HTTPException: 500 on any unexpected error.
    """
    try:
        # One shared filter so count, duration and average cover the same sessions.
        session_filters = (
            PracticeSession.user_id == current_user.id,
            PracticeSession.is_deleted == False,
            PracticeSession.status == "completed",
        )

        # Total session count and total duration.
        result = await db.execute(
            select(
                func.count(PracticeSession.id).label('total_count'),
                func.sum(PracticeSession.duration_seconds).label('total_duration')
            ).where(*session_filters)
        )
        stats = result.first()

        total_count = stats.total_count or 0
        # SUM() can come back as Decimal (e.g. on MySQL); normalise to float.
        total_duration = float(stats.total_duration or 0)
        total_duration_hours = round(total_duration / 3600, 1)

        # Average report score over the same set of sessions.
        result = await db.execute(
            select(func.avg(PracticeReport.total_score)).where(
                PracticeReport.session_id.in_(
                    select(PracticeSession.session_id).where(*session_filters)
                )
            )
        )
        avg_raw = result.scalar()
        avg_score = round(float(avg_raw), 1) if avg_raw else 0

        # NOTE(review): month-over-month improvement is not implemented yet.
        # This is a fixed placeholder, not a computed value.
        # TODO: compare this month's average score with last month's.
        month_improvement = 15

        logger.info(f"查询陪练统计: user_id={current_user.id}, total={total_count}, avg={avg_score}")

        return ResponseModel(
            code=200,
            message="success",
            data={
                "total_count": total_count,
                "avg_score": avg_score,
                "total_duration_hours": total_duration_hours,
                "month_improvement": month_improvement
            }
        )

    except Exception as e:
        logger.error(f"查询统计数据失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"查询失败: {str(e)}")
|