"""陪练模块测试""" import pytest from httpx import AsyncClient from sqlalchemy.ext.asyncio import AsyncSession from app.models.training import TrainingScene, TrainingSession, TrainingSceneStatus from app.services.training_service import TrainingSceneService, TrainingSessionService class TestTrainingSceneAPI: """陪练场景API测试""" @pytest.mark.asyncio async def test_get_training_scenes(self, client: AsyncClient, auth_headers: dict): """测试获取陪练场景列表""" response = await client.get( "/api/v1/training/scenes", headers=auth_headers ) assert response.status_code == 200 data = response.json() assert data["code"] == 200 assert "data" in data assert "items" in data["data"] assert "total" in data["data"] assert "page" in data["data"] assert "page_size" in data["data"] @pytest.mark.asyncio async def test_create_training_scene_admin_only( self, client: AsyncClient, auth_headers: dict, admin_auth_headers: dict ): """测试创建陪练场景(需要管理员权限)""" scene_data = { "name": "面试训练", "description": "模拟面试场景,提升面试技巧", "category": "面试", "ai_config": { "bot_id": "test_bot_id", "prompt": "你是一位专业的面试官" }, "is_public": True } # 普通用户无权限 response = await client.post( "/api/v1/training/scenes", json=scene_data, headers=auth_headers ) assert response.status_code == 403 # 管理员可以创建 # 注意:这里需要mock管理员权限检查 # 在实际测试中,需要正确设置依赖覆盖 @pytest.mark.asyncio async def test_get_training_scene_detail( self, client: AsyncClient, auth_headers: dict, db_session: AsyncSession ): """测试获取陪练场景详情""" # 创建测试场景 scene_service = TrainingSceneService() scene = await scene_service.create( db_session, obj_in={ "name": "测试场景", "category": "测试", "status": TrainingSceneStatus.ACTIVE, "is_public": True }, created_by=1, updated_by=1 ) # 获取场景详情 response = await client.get( f"/api/v1/training/scenes/{scene.id}", headers=auth_headers ) assert response.status_code == 200 data = response.json() assert data["code"] == 200 assert data["data"]["id"] == scene.id assert data["data"]["name"] == "测试场景" @pytest.mark.asyncio async def test_get_nonexistent_scene(self, client: 
AsyncClient, auth_headers: dict): """测试获取不存在的场景""" response = await client.get( "/api/v1/training/scenes/99999", headers=auth_headers ) assert response.status_code == 404 class TestTrainingSessionAPI: """陪练会话API测试""" @pytest.mark.asyncio async def test_start_training( self, client: AsyncClient, auth_headers: dict, db_session: AsyncSession ): """测试开始陪练""" # 创建测试场景 scene_service = TrainingSceneService() scene = await scene_service.create( db_session, obj_in={ "name": "测试陪练场景", "category": "测试", "status": TrainingSceneStatus.ACTIVE, "is_public": True }, created_by=1, updated_by=1 ) # 开始陪练 response = await client.post( "/api/v1/training/sessions", json={ "scene_id": scene.id, "config": {"key": "value"} }, headers=auth_headers ) assert response.status_code == 200 data = response.json() assert data["code"] == 200 assert "session_id" in data["data"] assert data["data"]["scene"]["id"] == scene.id @pytest.mark.asyncio async def test_end_training( self, client: AsyncClient, auth_headers: dict, db_session: AsyncSession ): """测试结束陪练""" # 创建测试场景和会话 scene_service = TrainingSceneService() scene = await scene_service.create( db_session, obj_in={ "name": "测试场景", "category": "测试", "status": TrainingSceneStatus.ACTIVE, "is_public": True }, created_by=1, updated_by=1 ) session_service = TrainingSessionService() session = await session_service.create( db_session, obj_in={ "scene_id": scene.id, "session_config": {} }, user_id=1, created_by=1 ) # 结束陪练 response = await client.post( f"/api/v1/training/sessions/{session.id}/end", json={"generate_report": True}, headers=auth_headers ) assert response.status_code == 200 data = response.json() assert data["code"] == 200 assert data["data"]["session"]["status"] == "completed" @pytest.mark.asyncio async def test_get_user_sessions(self, client: AsyncClient, auth_headers: dict): """测试获取用户会话列表""" response = await client.get( "/api/v1/training/sessions", headers=auth_headers ) assert response.status_code == 200 data = response.json() assert 
data["code"] == 200 assert "items" in data["data"] assert isinstance(data["data"]["items"], list) @pytest.mark.asyncio async def test_get_session_messages( self, client: AsyncClient, auth_headers: dict, db_session: AsyncSession ): """测试获取会话消息""" # 创建测试数据 scene_service = TrainingSceneService() scene = await scene_service.create( db_session, obj_in={ "name": "测试场景", "category": "测试", "status": TrainingSceneStatus.ACTIVE, "is_public": True }, created_by=1, updated_by=1 ) session_service = TrainingSessionService() session = await session_service.create( db_session, obj_in={ "scene_id": scene.id, "session_config": {} }, user_id=1, created_by=1 ) # 获取消息 response = await client.get( f"/api/v1/training/sessions/{session.id}/messages", headers=auth_headers ) assert response.status_code == 200 data = response.json() assert data["code"] == 200 assert isinstance(data["data"], list) class TestTrainingReportAPI: """陪练报告API测试""" @pytest.mark.asyncio async def test_get_user_reports(self, client: AsyncClient, auth_headers: dict): """测试获取用户报告列表""" response = await client.get( "/api/v1/training/reports", headers=auth_headers ) assert response.status_code == 200 data = response.json() assert data["code"] == 200 assert "items" in data["data"] assert isinstance(data["data"]["items"], list) @pytest.mark.asyncio async def test_get_report_by_session( self, client: AsyncClient, auth_headers: dict, db_session: AsyncSession ): """测试根据会话ID获取报告""" # 创建测试数据 scene_service = TrainingSceneService() scene = await scene_service.create( db_session, obj_in={ "name": "测试场景", "category": "测试", "status": TrainingSceneStatus.ACTIVE, "is_public": True }, created_by=1, updated_by=1 ) session_service = TrainingSessionService() session = await session_service.create( db_session, obj_in={ "scene_id": scene.id, "session_config": {} }, user_id=1, created_by=1 ) # 获取报告(会话还没有报告) response = await client.get( f"/api/v1/training/sessions/{session.id}/report", headers=auth_headers ) assert response.status_code == 404 
class TestTrainingService:
    """Service-layer tests for training scenes and sessions (no HTTP involved)."""

    @pytest.mark.asyncio
    async def test_scene_service_crud(self, db_session: AsyncSession):
        """Create, read, update and soft-delete a scene via ``TrainingSceneService``."""
        service = TrainingSceneService()

        # Create.
        created = await service.create_scene(
            db_session,
            scene_in={
                "name": "演讲训练",
                "description": "提升演讲能力",
                "category": "演讲",
                "status": TrainingSceneStatus.ACTIVE,
            },
            created_by=1,
        )
        assert created.id is not None
        assert created.name == "演讲训练"
        scene_id = created.id

        # Read.
        fetched = await service.get(db_session, scene_id)
        assert fetched is not None
        assert fetched.id == scene_id

        # Update.
        patched = await service.update_scene(
            db_session,
            scene_id=scene_id,
            scene_in={"description": "提升公众演讲能力"},
            updated_by=1,
        )
        assert patched is not None
        assert patched.description == "提升公众演讲能力"

        # Soft delete. The test expects ``get`` to still return the row,
        # flagged with ``is_deleted``.
        removed = await service.soft_delete(db_session, id=scene_id)
        assert removed is True

        tombstone = await service.get(db_session, scene_id)
        assert tombstone.is_deleted is True

    @pytest.mark.asyncio
    async def test_session_lifecycle(self, db_session: AsyncSession):
        """Start and then end a session through ``TrainingSessionService``."""
        scene = await TrainingSceneService().create(
            db_session,
            obj_in={
                "name": "测试场景",
                "category": "测试",
                "status": TrainingSceneStatus.ACTIVE,
                "is_public": True,
            },
            created_by=1,
            updated_by=1,
        )

        sessions = TrainingSessionService()

        # Start.
        started = await sessions.start_training(
            db_session,
            request={"scene_id": scene.id},
            user_id=1,
        )
        assert started.session_id is not None

        # End, asking for a report.
        ended = await sessions.end_training(
            db_session,
            session_id=started.session_id,
            request={"generate_report": True},
            user_id=1,
        )
        assert ended.session.status == "completed"
        assert ended.session.duration_seconds is not None

        # The report's contents are checked only when one is returned.
        # NOTE(review): a report was requested, yet its absence does not fail
        # the test -- presumably generation may be skipped in the test
        # environment; confirm whether it should be asserted unconditionally.
        report = ended.report
        if report:
            assert report.overall_score > 0
            assert len(report.strengths) > 0
            assert len(report.suggestions) > 0