"""陪练服务层""" import logging from typing import List, Optional, Dict, Any from datetime import datetime from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, and_, or_, func from fastapi import HTTPException, status from app.models.training import ( TrainingScene, TrainingSession, TrainingMessage, TrainingReport, TrainingSceneStatus, TrainingSessionStatus, MessageRole, MessageType, ) from app.schemas.training import ( TrainingSceneCreate, TrainingSceneUpdate, TrainingSessionCreate, TrainingSessionUpdate, TrainingMessageCreate, TrainingReportCreate, StartTrainingRequest, StartTrainingResponse, EndTrainingRequest, EndTrainingResponse, ) from app.services.base_service import BaseService # from app.services.ai.coze.client import CozeClient from app.core.config import get_settings logger = logging.getLogger(__name__) settings = get_settings() class TrainingSceneService(BaseService[TrainingScene]): """陪练场景服务""" def __init__(self): super().__init__(TrainingScene) async def get_active_scenes( self, db: AsyncSession, *, category: Optional[str] = None, is_public: Optional[bool] = None, user_level: Optional[int] = None, skip: int = 0, limit: int = 20, ) -> List[TrainingScene]: """获取激活的陪练场景列表""" query = select(self.model).where( and_( self.model.status == TrainingSceneStatus.ACTIVE, self.model.is_deleted == False, ) ) if category: query = query.where(self.model.category == category) if is_public is not None: query = query.where(self.model.is_public == is_public) if user_level is not None: query = query.where( or_( self.model.required_level == None, self.model.required_level <= user_level, ) ) return await self.get_multi(db, skip=skip, limit=limit, query=query) async def create_scene( self, db: AsyncSession, *, scene_in: TrainingSceneCreate, created_by: int ) -> TrainingScene: """创建陪练场景""" return await self.create( db, obj_in=scene_in, created_by=created_by, updated_by=created_by ) async def update_scene( self, db: AsyncSession, *, scene_id: int, scene_in: TrainingSceneUpdate, updated_by: int, ) -> Optional[TrainingScene]: """更新陪练场景""" scene = await self.get(db, scene_id) if not scene or scene.is_deleted: return None scene.updated_by = updated_by return await self.update(db, db_obj=scene, obj_in=scene_in) class TrainingSessionService(BaseService[TrainingSession]): """陪练会话服务""" def __init__(self): super().__init__(TrainingSession) self.scene_service = TrainingSceneService() self.message_service = TrainingMessageService() self.report_service = TrainingReportService() # TODO: 等Coze网关模块实现后替换 self._coze_client = None @property def coze_client(self): """延迟初始化Coze客户端""" if self._coze_client is None: try: # from app.services.ai.coze.client import CozeClient # self._coze_client = CozeClient() logger.warning("Coze客户端暂未实现,使用模拟模式") self._coze_client = None except ImportError: logger.warning("Coze客户端未实现,使用模拟模式") return self._coze_client async def start_training( self, db: AsyncSession, *, request: StartTrainingRequest, user_id: int ) -> StartTrainingResponse: """开始陪练会话""" # 验证场景 scene = await self.scene_service.get(db, request.scene_id) if not scene or scene.is_deleted or scene.status != TrainingSceneStatus.ACTIVE: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="陪练场景不存在或未激活" ) # 检查用户等级 # TODO: 从User服务获取用户等级 user_level = 1 # 临时模拟 if scene.required_level and user_level < scene.required_level: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户等级不足") # 创建会话 session_data = TrainingSessionCreate( scene_id=request.scene_id, session_config=request.config ) session = await self.create( db, obj_in=session_data, user_id=user_id, created_by=user_id ) # 初始化Coze会话 coze_conversation_id = None if self.coze_client and scene.ai_config: try: bot_id = scene.ai_config.get("bot_id", settings.coze_training_bot_id) if bot_id: # 创建Coze会话 coze_result = await self.coze_client.create_conversation( bot_id=bot_id, user_id=str(user_id), meta_data={ "scene_id": scene.id, "scene_name": scene.name, "session_id": session.id, }, ) coze_conversation_id = coze_result.get("conversation_id") # 更新会话的Coze ID session.coze_conversation_id = coze_conversation_id await db.commit() except Exception as e: logger.error(f"创建Coze会话失败: {e}") # 加载场景信息 await db.refresh(session, ["scene"]) # 构造WebSocket URL(如果需要) websocket_url = None if coze_conversation_id: websocket_url = f"ws://localhost:8000/ws/v1/training/{session.id}" return StartTrainingResponse( session_id=session.id, coze_conversation_id=coze_conversation_id, scene=scene, websocket_url=websocket_url, ) async def end_training( self, db: AsyncSession, *, session_id: int, request: EndTrainingRequest, user_id: int, ) -> EndTrainingResponse: """结束陪练会话""" # 获取会话 session = await self.get(db, session_id) if not session or session.user_id != user_id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="会话不存在") if session.status in [ TrainingSessionStatus.COMPLETED, TrainingSessionStatus.CANCELLED, ]: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="会话已结束") # 计算持续时间 end_time = datetime.now() duration_seconds = int((end_time - session.start_time).total_seconds()) # 更新会话状态 update_data = TrainingSessionUpdate( status=TrainingSessionStatus.COMPLETED, end_time=end_time, duration_seconds=duration_seconds, ) session = await self.update(db, db_obj=session, obj_in=update_data) # 生成报告 report = None if request.generate_report: report = await self._generate_report( db, session_id=session_id, user_id=user_id ) # 加载关联数据 await db.refresh(session, ["scene"]) if report: await db.refresh(report, ["session"]) return EndTrainingResponse(session=session, report=report) async def get_user_sessions( self, db: AsyncSession, *, user_id: int, scene_id: Optional[int] = None, status: Optional[TrainingSessionStatus] = None, skip: int = 0, limit: int = 20, ) -> List[TrainingSession]: """获取用户的陪练会话列表""" query = select(self.model).where(self.model.user_id == user_id) if scene_id: query = query.where(self.model.scene_id == scene_id) if status: query = query.where(self.model.status == status) query = query.order_by(self.model.created_at.desc()) return await self.get_multi(db, skip=skip, limit=limit, query=query) async def _generate_report( self, db: AsyncSession, *, session_id: int, user_id: int ) -> Optional[TrainingReport]: """生成陪练报告(内部方法)""" # 获取会话消息 messages = await self.message_service.get_session_messages( db, session_id=session_id ) # TODO: 调用AI分析服务生成报告 # 这里先生成模拟报告 report_data = TrainingReportCreate( session_id=session_id, user_id=user_id, overall_score=85.5, dimension_scores={"表达能力": 88.0, "逻辑思维": 85.0, "专业知识": 82.0, "应变能力": 87.0}, strengths=["表达清晰,语言流畅", "能够快速理解问题并作出回应", "展现了良好的专业素养"], weaknesses=["部分专业术语使用不够准确", "回答有时过于冗长,需要更加精炼"], suggestions=["加强专业知识的学习,特别是术语的准确使用", "练习更加简洁有力的表达方式", "增加实际案例的积累,丰富回答内容"], detailed_analysis="整体表现良好,展现了扎实的基础知识和良好的沟通能力...", statistics={ "total_messages": len(messages), "user_messages": len( [m for m in messages if m.role == MessageRole.USER] ), "avg_response_time": 2.5, "total_words": 1500, }, ) return await self.report_service.create( db, obj_in=report_data, created_by=user_id ) class TrainingMessageService(BaseService[TrainingMessage]): """陪练消息服务""" def __init__(self): super().__init__(TrainingMessage) async def create_message( self, db: AsyncSession, *, message_in: TrainingMessageCreate ) -> TrainingMessage: """创建消息""" return await self.create(db, obj_in=message_in) async def get_session_messages( self, db: AsyncSession, *, session_id: int, skip: int = 0, limit: int = 100 ) -> List[TrainingMessage]: """获取会话的所有消息""" query = ( select(self.model) .where(self.model.session_id == session_id) .order_by(self.model.created_at) ) return await self.get_multi(db, skip=skip, limit=limit, query=query) async def save_voice_message( self, db: AsyncSession, *, session_id: int, role: MessageRole, content: str, voice_url: str, voice_duration: float, metadata: Optional[Dict[str, Any]] = None, ) -> TrainingMessage: """保存语音消息""" message_data = TrainingMessageCreate( session_id=session_id, role=role, type=MessageType.VOICE, content=content, voice_url=voice_url, voice_duration=voice_duration, metadata=metadata, ) return await self.create(db, obj_in=message_data) class TrainingReportService(BaseService[TrainingReport]): """陪练报告服务""" def __init__(self): super().__init__(TrainingReport) async def get_by_session( self, db: AsyncSession, *, session_id: int ) -> Optional[TrainingReport]: """根据会话ID获取报告""" result = await db.execute( select(self.model).where(self.model.session_id == session_id) ) return result.scalar_one_or_none() async def get_user_reports( self, db: AsyncSession, *, user_id: int, skip: int = 0, limit: int = 20 ) -> List[TrainingReport]: """获取用户的所有报告""" query = ( select(self.model) .where(self.model.user_id == user_id) .order_by(self.model.created_at.desc()) ) return await self.get_multi(db, skip=skip, limit=limit, query=query)