"""Training (陪练) module API routes.

Endpoints for training scenes (read for users, write for admins), training
sessions, session messages and training reports.
"""

import logging
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.deps import get_db, get_current_user, require_admin
from app.schemas.base import ResponseModel
from app.schemas.training import (
    TrainingSceneCreate,
    TrainingSceneUpdate,
    TrainingSceneResponse,
    TrainingSessionResponse,
    TrainingMessageResponse,
    TrainingReportResponse,
    StartTrainingRequest,
    StartTrainingResponse,
    EndTrainingRequest,
    EndTrainingResponse,
    TrainingSceneListQuery,
    TrainingSessionListQuery,
    PaginatedResponse,
)
from app.services.training_service import (
    TrainingSceneService,
    TrainingSessionService,
    TrainingMessageService,
    TrainingReportService,
)
from app.models.training import (
    TrainingReport,
    TrainingScene,
    TrainingSceneStatus,
    TrainingSession,
    TrainingSessionStatus,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/training", tags=["陪练模块"])

# Module-level service instances shared by all handlers.
scene_service = TrainingSceneService()
session_service = TrainingSessionService()
message_service = TrainingMessageService()
report_service = TrainingReportService()


# ========== Internal helpers ==========


def _is_admin(current_user: dict) -> bool:
    """Return True when the authenticated user's ``role`` is ``"admin"``."""
    return current_user.get("role") == "admin"


def _ensure_owner_or_admin(owner_id, current_user: dict, detail: str) -> None:
    """Raise 403 with *detail* unless the user owns the resource or is an admin."""
    if owner_id != current_user["id"] and not _is_admin(current_user):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)


def _total_pages(total: int, page_size: int) -> int:
    """Return the number of pages needed for *total* items (ceiling division)."""
    return (total + page_size - 1) // page_size


async def _count(db: AsyncSession, count_query) -> int:
    """Execute a ``SELECT count(...)`` query and return its single scalar."""
    result = await db.execute(count_query)
    return result.scalar_one()


async def _load_report_relations(db: AsyncSession, report) -> None:
    """Refresh ``report.session`` and, if it is set, ``report.session.scene``."""
    await db.refresh(report, ["session"])
    if report.session:
        await db.refresh(report.session, ["scene"])


def _end_response_to_dict(response, session_id: int) -> dict:
    """Normalise the result of ``session_service.end_training`` into a dict.

    A dict is returned as-is (the caller mutates it). Pydantic-style models are
    dumped via ``model_dump()`` (v2) or ``dict()`` (v1) so their fields are not
    lost. Anything else falls back to ``{"session_id": session_id}``.
    """
    if isinstance(response, dict):
        return response
    if hasattr(response, "model_dump"):
        return response.model_dump()
    if hasattr(response, "dict"):
        return response.dict()
    return {"session_id": session_id}


# ========== Training scene management ==========


@router.get(
    "/scenes", response_model=ResponseModel[PaginatedResponse[TrainingSceneResponse]]
)
async def get_training_scenes(
    category: Optional[str] = Query(None, description="场景分类"),
    # Named ``scene_status`` (query name stays ``status`` via alias) so it does
    # not shadow the ``fastapi.status`` module used in the error handler below.
    scene_status: Optional[TrainingSceneStatus] = Query(
        None, alias="status", description="场景状态"
    ),
    is_public: Optional[bool] = Query(None, description="是否公开"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List training scenes with pagination.

    Filters by ``category`` and ``is_public``; returns a paginated envelope.

    NOTE(review): ``status`` and ``search`` are accepted but not applied to the
    list or count queries below. The count is fixed to ACTIVE, non-deleted
    scenes and does not apply ``user_level``, so ``total`` may differ from what
    ``get_active_scenes`` returns — confirm against TrainingSceneService.

    Raises:
        HTTPException: 500 on any unexpected failure.
    """
    try:
        skip = (page - 1) * page_size

        # TODO: fetch the real user level from the User service.
        user_level = 1

        scenes = await scene_service.get_active_scenes(
            db,
            category=category,
            is_public=is_public,
            user_level=user_level,
            skip=skip,
            limit=page_size,
        )

        count_query = (
            select(func.count())
            .select_from(TrainingScene)
            .where(
                and_(
                    TrainingScene.status == TrainingSceneStatus.ACTIVE,
                    TrainingScene.is_deleted == False,  # noqa: E712 (SQL expression)
                )
            )
        )
        if category:
            count_query = count_query.where(TrainingScene.category == category)
        if is_public is not None:
            count_query = count_query.where(TrainingScene.is_public == is_public)

        total = await _count(db, count_query)

        return ResponseModel(
            data=PaginatedResponse(
                items=scenes,
                total=total,
                page=page,
                page_size=page_size,
                pages=_total_pages(total, page_size),
            ),
            message="获取陪练场景列表成功",
        )
    except Exception as e:
        logger.exception(f"获取陪练场景列表失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取陪练场景列表失败"
        ) from e


@router.get("/scenes/{scene_id}", response_model=ResponseModel[TrainingSceneResponse])
async def get_training_scene(
    scene_id: int,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one training scene.

    Raises:
        HTTPException: 404 if missing or soft-deleted; 403 if the scene is not
            public and the user is not an admin.
    """
    scene = await scene_service.get(db, scene_id)
    if not scene or scene.is_deleted:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练场景不存在")

    # Non-public scenes are visible to admins only.
    if not scene.is_public and not _is_admin(current_user):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此场景")

    return ResponseModel(data=scene, message="获取陪练场景成功")


@router.post("/scenes", response_model=ResponseModel[TrainingSceneResponse])
async def create_training_scene(
    scene_in: TrainingSceneCreate,
    current_user: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Create a training scene (admin only, enforced by ``require_admin``).

    Raises:
        HTTPException: 500 on any unexpected failure.
    """
    try:
        scene = await scene_service.create_scene(
            db, scene_in=scene_in, created_by=current_user["id"]
        )
        logger.info(f"管理员 {current_user['id']} 创建了陪练场景: {scene.id}")
        return ResponseModel(data=scene, message="创建陪练场景成功")
    except Exception as e:
        logger.exception(f"创建陪练场景失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建陪练场景失败"
        ) from e


@router.put("/scenes/{scene_id}", response_model=ResponseModel[TrainingSceneResponse])
async def update_training_scene(
    scene_id: int,
    scene_in: TrainingSceneUpdate,
    current_user: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Update a training scene (admin only).

    Raises:
        HTTPException: 404 if the service reports no such scene.
    """
    scene = await scene_service.update_scene(
        db, scene_id=scene_id, scene_in=scene_in, updated_by=current_user["id"]
    )
    if not scene:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练场景不存在")

    logger.info(f"管理员 {current_user['id']} 更新了陪练场景: {scene_id}")
    return ResponseModel(data=scene, message="更新陪练场景成功")


@router.delete("/scenes/{scene_id}", response_model=ResponseModel[bool])
async def delete_training_scene(
    scene_id: int,
    current_user: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """Soft-delete a training scene (admin only).

    Raises:
        HTTPException: 404 if ``soft_delete`` reports failure.
    """
    success = await scene_service.soft_delete(db, id=scene_id)
    if not success:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练场景不存在")

    logger.info(f"管理员 {current_user['id']} 删除了陪练场景: {scene_id}")
    return ResponseModel(data=True, message="删除陪练场景成功")


# ========== Training session management ==========


@router.post("/sessions", response_model=ResponseModel[StartTrainingResponse])
async def start_training(
    request: StartTrainingRequest,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start a training session for the current user via ``session_service``.

    Raises:
        HTTPException: re-raised from the service as-is; 500 on any other error.
    """
    try:
        response = await session_service.start_training(
            db, request=request, user_id=current_user["id"]
        )
        logger.info(f"用户 {current_user['id']} 开始陪练会话: {response.session_id}")
        return ResponseModel(data=response, message="开始陪练成功")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"开始陪练失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="开始陪练失败"
        ) from e


@router.post(
    "/sessions/{session_id}/end", response_model=ResponseModel[EndTrainingResponse]
)
async def end_training(
    session_id: int,
    request: EndTrainingRequest,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """End a training session and award experience / badges.

    The service result is converted to a dict (see ``_end_response_to_dict``)
    and augmented with ``exp_result`` and ``new_badges``. Experience and badge
    processing is best-effort: failures are logged and do not fail the request.

    Raises:
        HTTPException: re-raised from the service as-is; 500 on any other error.
    """
    try:
        response = await session_service.end_training(
            db, session_id=session_id, request=request, user_id=current_user["id"]
        )
        logger.info(f"用户 {current_user['id']} 结束陪练会话: {session_id}")

        # Previously a model result was replaced by {"session_id": ...}, which
        # discarded its fields and always yielded score=None.
        result_data = _end_response_to_dict(response, session_id)

        exp_result = None
        new_badges = []
        try:
            # Imported lazily, as in the original module.
            from app.services.level_service import LevelService
            from app.services.badge_service import BadgeService

            level_service = LevelService(db)
            badge_service = BadgeService(db)

            # Score is taken from the end-training result when it provides one.
            score = result_data.get("total_score")

            exp_result = await level_service.add_training_exp(
                user_id=current_user["id"], session_id=session_id, score=score
            )
            new_badges = await badge_service.check_and_award_badges(current_user["id"])

            await db.commit()
        except Exception as e:
            logger.warning(f"陪练经验值/奖章处理失败: {str(e)}", exc_info=True)

        result_data["exp_result"] = exp_result
        result_data["new_badges"] = new_badges

        return ResponseModel(data=result_data, message="结束陪练成功")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"结束陪练失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="结束陪练失败"
        ) from e


@router.get(
    "/sessions",
    response_model=ResponseModel[PaginatedResponse[TrainingSessionResponse]],
)
async def get_training_sessions(
    scene_id: Optional[int] = Query(None, description="场景ID"),
    # Named ``session_status`` (query name stays ``status`` via alias) so it
    # does not shadow the ``fastapi.status`` module used in the error handler.
    session_status: Optional[TrainingSessionStatus] = Query(
        None, alias="status", description="会话状态"
    ),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's training sessions with pagination.

    Raises:
        HTTPException: 500 on any unexpected failure.
    """
    try:
        skip = (page - 1) * page_size
        sessions = await session_service.get_user_sessions(
            db,
            user_id=current_user["id"],
            scene_id=scene_id,
            status=session_status,
            skip=skip,
            limit=page_size,
        )

        count_query = (
            select(func.count())
            .select_from(TrainingSession)
            .where(TrainingSession.user_id == current_user["id"])
        )
        if scene_id:
            count_query = count_query.where(TrainingSession.scene_id == scene_id)
        if session_status:
            count_query = count_query.where(TrainingSession.status == session_status)

        total = await _count(db, count_query)

        # Load each session's related scene (one refresh per session).
        for session in sessions:
            await db.refresh(session, ["scene"])

        return ResponseModel(
            data=PaginatedResponse(
                items=sessions,
                total=total,
                page=page,
                page_size=page_size,
                pages=_total_pages(total, page_size),
            ),
            message="获取陪练会话列表成功",
        )
    except Exception as e:
        logger.exception(f"获取陪练会话列表失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取陪练会话列表失败"
        ) from e


@router.get(
    "/sessions/{session_id}", response_model=ResponseModel[TrainingSessionResponse]
)
async def get_training_session(
    session_id: int,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one training session with its scene and a message count.

    Raises:
        HTTPException: 404 if missing; 403 unless owner or admin.
    """
    session = await session_service.get(db, session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练会话不存在")

    _ensure_owner_or_admin(session.user_id, current_user, "无权访问此会话")

    await db.refresh(session, ["scene"])

    # NOTE(review): this loads messages just to count them; if
    # get_session_messages applies a default limit, the count is capped at it.
    # A COUNT query would be cheaper and exact — confirm the service defaults.
    messages = await message_service.get_session_messages(db, session_id=session_id)
    session.message_count = len(messages)

    return ResponseModel(data=session, message="获取陪练会话成功")


# ========== Message management ==========


@router.get(
    "/sessions/{session_id}/messages",
    response_model=ResponseModel[List[TrainingMessageResponse]],
)
async def get_training_messages(
    session_id: int,
    skip: int = Query(0, ge=0, description="跳过数量"),
    limit: int = Query(100, ge=1, le=500, description="返回数量"),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a window of messages for a training session.

    Raises:
        HTTPException: 404 if the session is missing; 403 unless owner or admin.
    """
    session = await session_service.get(db, session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练会话不存在")

    _ensure_owner_or_admin(session.user_id, current_user, "无权访问此会话消息")

    messages = await message_service.get_session_messages(
        db, session_id=session_id, skip=skip, limit=limit
    )

    return ResponseModel(data=messages, message="获取消息列表成功")


# ========== Report management ==========


@router.get(
    "/reports", response_model=ResponseModel[PaginatedResponse[TrainingReportResponse]]
)
async def get_training_reports(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the current user's training reports with pagination.

    Raises:
        HTTPException: 500 on any unexpected failure.
    """
    try:
        skip = (page - 1) * page_size
        reports = await report_service.get_user_reports(
            db, user_id=current_user["id"], skip=skip, limit=page_size
        )

        count_query = (
            select(func.count())
            .select_from(TrainingReport)
            .where(TrainingReport.user_id == current_user["id"])
        )
        total = await _count(db, count_query)

        # Load each report's session and that session's scene.
        for report in reports:
            await _load_report_relations(db, report)

        return ResponseModel(
            data=PaginatedResponse(
                items=reports,
                total=total,
                page=page,
                page_size=page_size,
                pages=_total_pages(total, page_size),
            ),
            message="获取陪练报告列表成功",
        )
    except Exception as e:
        logger.exception(f"获取陪练报告列表失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取陪练报告列表失败"
        ) from e


@router.get(
    "/reports/{report_id}", response_model=ResponseModel[TrainingReportResponse]
)
async def get_training_report(
    report_id: int,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one training report with its session and scene.

    Raises:
        HTTPException: 404 if missing; 403 unless owner or admin.
    """
    report = await report_service.get(db, report_id)
    if not report:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练报告不存在")

    _ensure_owner_or_admin(report.user_id, current_user, "无权访问此报告")

    await _load_report_relations(db, report)

    return ResponseModel(data=report, message="获取陪练报告成功")


@router.get(
    "/sessions/{session_id}/report",
    response_model=ResponseModel[TrainingReportResponse],
)
async def get_session_report(
    session_id: int,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the training report for a given session.

    Raises:
        HTTPException: 404 if the session or its report is missing; 403 unless
            the user owns the session or is an admin.
    """
    session = await session_service.get(db, session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练会话不存在")

    _ensure_owner_or_admin(session.user_id, current_user, "无权访问此会话报告")

    report = await report_service.get_by_session(db, session_id=session_id)
    if not report:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="该会话暂无报告")

    await _load_report_relations(db, report)

    return ResponseModel(data=report, message="获取会话报告成功")