Files
012-kaopeilian/backend/tests/test_training.py
111 998211c483 feat: 初始化考培练系统项目
- 从服务器拉取完整代码
- 按框架规范整理项目结构
- 配置 Drone CI 测试环境部署
- 包含后端(FastAPI)、前端(Vue3)、管理端

技术栈: Vue3 + TypeScript + FastAPI + MySQL
2026-01-24 19:33:28 +08:00

400 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""陪练模块测试"""
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.training import TrainingScene, TrainingSession, TrainingSceneStatus
from app.services.training_service import TrainingSceneService, TrainingSessionService
class TestTrainingSceneAPI:
    """API tests for the training-scene endpoints."""

    @pytest.mark.asyncio
    async def test_get_training_scenes(self, client: AsyncClient, auth_headers: dict):
        """Listing training scenes returns a paginated response envelope."""
        resp = await client.get("/api/v1/training/scenes", headers=auth_headers)
        assert resp.status_code == 200

        body = resp.json()
        assert body["code"] == 200
        assert "data" in body

        # The payload must expose the list plus its pagination fields.
        page = body["data"]
        for expected_key in ("items", "total", "page", "page_size"):
            assert expected_key in page

    @pytest.mark.asyncio
    async def test_create_training_scene_admin_only(
        self,
        client: AsyncClient,
        auth_headers: dict,
        admin_auth_headers: dict
    ):
        """Creating a training scene is rejected for a non-admin user.

        ``admin_auth_headers`` is requested as a fixture but is not used yet:
        the admin success path is still unimplemented (see the note at the
        end of the test body).
        """
        ai_settings = {
            "bot_id": "test_bot_id",
            "prompt": "你是一位专业的面试官",
        }
        new_scene = {
            "name": "面试训练",
            "description": "模拟面试场景,提升面试技巧",
            "category": "面试",
            "ai_config": ai_settings,
            "is_public": True,
        }

        # A regular user has no permission to create scenes.
        resp = await client.post(
            "/api/v1/training/scenes", json=new_scene, headers=auth_headers
        )
        assert resp.status_code == 403

        # An admin should be able to create the scene.
        # NOTE: this requires mocking the admin permission check; in a real
        # test run the dependency overrides have to be configured correctly.

    @pytest.mark.asyncio
    async def test_get_training_scene_detail(
        self,
        client: AsyncClient,
        auth_headers: dict,
        db_session: AsyncSession
    ):
        """Fetching a scene by id returns that scene's id and name."""
        # Arrange: persist a scene through the service layer.
        service = TrainingSceneService()
        scene_fields = {
            "name": "测试场景",
            "category": "测试",
            "status": TrainingSceneStatus.ACTIVE,
            "is_public": True,
        }
        scene = await service.create(
            db_session, obj_in=scene_fields, created_by=1, updated_by=1
        )

        # Act: request the detail endpoint for the new scene.
        detail_url = f"/api/v1/training/scenes/{scene.id}"
        resp = await client.get(detail_url, headers=auth_headers)

        # Assert: the envelope is successful and echoes the stored scene.
        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 200
        returned = body["data"]
        assert returned["id"] == scene.id
        assert returned["name"] == "测试场景"

    @pytest.mark.asyncio
    async def test_get_nonexistent_scene(self, client: AsyncClient, auth_headers: dict):
        """Requesting a scene id that does not exist yields HTTP 404."""
        resp = await client.get("/api/v1/training/scenes/99999", headers=auth_headers)
        assert resp.status_code == 404
class TestTrainingSessionAPI:
    """API tests for the training-session endpoints."""

    @pytest.mark.asyncio
    async def test_start_training(
        self,
        client: AsyncClient,
        auth_headers: dict,
        db_session: AsyncSession
    ):
        """Starting a training run returns a session id and the chosen scene."""
        # Arrange: persist a scene to train against.
        scenes = TrainingSceneService()
        scene_fields = {
            "name": "测试陪练场景",
            "category": "测试",
            "status": TrainingSceneStatus.ACTIVE,
            "is_public": True,
        }
        scene = await scenes.create(
            db_session, obj_in=scene_fields, created_by=1, updated_by=1
        )

        # Act: start a training session for that scene.
        start_payload = {"scene_id": scene.id, "config": {"key": "value"}}
        resp = await client.post(
            "/api/v1/training/sessions", json=start_payload, headers=auth_headers
        )

        # Assert
        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 200
        result = body["data"]
        assert "session_id" in result
        assert result["scene"]["id"] == scene.id

    @pytest.mark.asyncio
    async def test_end_training(
        self,
        client: AsyncClient,
        auth_headers: dict,
        db_session: AsyncSession
    ):
        """Ending a session reports its status as ``completed``."""
        # Arrange: persist a scene and a session that belongs to it.
        scenes = TrainingSceneService()
        scene_fields = {
            "name": "测试场景",
            "category": "测试",
            "status": TrainingSceneStatus.ACTIVE,
            "is_public": True,
        }
        scene = await scenes.create(
            db_session, obj_in=scene_fields, created_by=1, updated_by=1
        )

        sessions = TrainingSessionService()
        session_fields = {"scene_id": scene.id, "session_config": {}}
        session = await sessions.create(
            db_session, obj_in=session_fields, user_id=1, created_by=1
        )

        # Act: end the session and ask for a report.
        end_url = f"/api/v1/training/sessions/{session.id}/end"
        resp = await client.post(
            end_url, json={"generate_report": True}, headers=auth_headers
        )

        # Assert
        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 200
        assert body["data"]["session"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_get_user_sessions(self, client: AsyncClient, auth_headers: dict):
        """Listing the user's sessions returns an ``items`` list."""
        resp = await client.get("/api/v1/training/sessions", headers=auth_headers)
        assert resp.status_code == 200

        body = resp.json()
        assert body["code"] == 200
        listing = body["data"]
        assert "items" in listing
        assert isinstance(listing["items"], list)

    @pytest.mark.asyncio
    async def test_get_session_messages(
        self,
        client: AsyncClient,
        auth_headers: dict,
        db_session: AsyncSession
    ):
        """Fetching a session's messages returns a list payload."""
        # Arrange: persist a scene and a session to query.
        scenes = TrainingSceneService()
        scene_fields = {
            "name": "测试场景",
            "category": "测试",
            "status": TrainingSceneStatus.ACTIVE,
            "is_public": True,
        }
        scene = await scenes.create(
            db_session, obj_in=scene_fields, created_by=1, updated_by=1
        )

        sessions = TrainingSessionService()
        session_fields = {"scene_id": scene.id, "session_config": {}}
        session = await sessions.create(
            db_session, obj_in=session_fields, user_id=1, created_by=1
        )

        # Act: read the message list of the session.
        messages_url = f"/api/v1/training/sessions/{session.id}/messages"
        resp = await client.get(messages_url, headers=auth_headers)

        # Assert
        assert resp.status_code == 200
        body = resp.json()
        assert body["code"] == 200
        assert isinstance(body["data"], list)
class TestTrainingReportAPI:
    """API tests for the training-report endpoints."""

    @pytest.mark.asyncio
    async def test_get_user_reports(self, client: AsyncClient, auth_headers: dict):
        """Listing the user's reports returns an ``items`` list."""
        resp = await client.get("/api/v1/training/reports", headers=auth_headers)
        assert resp.status_code == 200

        body = resp.json()
        assert body["code"] == 200
        listing = body["data"]
        assert "items" in listing
        assert isinstance(listing["items"], list)

    @pytest.mark.asyncio
    async def test_get_report_by_session(
        self,
        client: AsyncClient,
        auth_headers: dict,
        db_session: AsyncSession
    ):
        """Requesting the report of a session that has none yields HTTP 404."""
        # Arrange: persist a scene and a session; no report is created.
        scenes = TrainingSceneService()
        scene_fields = {
            "name": "测试场景",
            "category": "测试",
            "status": TrainingSceneStatus.ACTIVE,
            "is_public": True,
        }
        scene = await scenes.create(
            db_session, obj_in=scene_fields, created_by=1, updated_by=1
        )

        sessions = TrainingSessionService()
        session_fields = {"scene_id": scene.id, "session_config": {}}
        session = await sessions.create(
            db_session, obj_in=session_fields, user_id=1, created_by=1
        )

        # Act: fetch the report (the session does not have one yet).
        report_url = f"/api/v1/training/sessions/{session.id}/report"
        resp = await client.get(report_url, headers=auth_headers)

        # Assert
        assert resp.status_code == 404
class TestTrainingService:
    """Service-layer tests for the training module."""

    @pytest.mark.asyncio
    async def test_scene_service_crud(self, db_session: AsyncSession):
        """Exercise create, read, update and soft delete on the scene service."""
        service = TrainingSceneService()

        # Create
        new_scene_fields = {
            "name": "演讲训练",
            "description": "提升演讲能力",
            "category": "演讲",
            "status": TrainingSceneStatus.ACTIVE,
        }
        scene = await service.create_scene(
            db_session, scene_in=new_scene_fields, created_by=1
        )
        assert scene.id is not None
        assert scene.name == "演讲训练"

        # Read
        fetched = await service.get(db_session, scene.id)
        assert fetched is not None
        assert fetched.id == scene.id

        # Update
        changes = {"description": "提升公众演讲能力"}
        modified = await service.update_scene(
            db_session, scene_id=scene.id, scene_in=changes, updated_by=1
        )
        assert modified is not None
        assert modified.description == "提升公众演讲能力"

        # Soft delete
        was_deleted = await service.soft_delete(db_session, id=scene.id)
        assert was_deleted is True

        # Verify the soft delete: the row is still readable and is flagged.
        after_delete = await service.get(db_session, scene.id)
        assert after_delete.is_deleted is True

    @pytest.mark.asyncio
    async def test_session_lifecycle(self, db_session: AsyncSession):
        """Start a training session, end it, and check the resulting state."""
        # Create the scene the session will run against.
        scenes = TrainingSceneService()
        scene_fields = {
            "name": "测试场景",
            "category": "测试",
            "status": TrainingSceneStatus.ACTIVE,
            "is_public": True,
        }
        scene = await scenes.create(
            db_session, obj_in=scene_fields, created_by=1, updated_by=1
        )

        # Start the session.
        sessions = TrainingSessionService()
        started = await sessions.start_training(
            db_session, request={"scene_id": scene.id}, user_id=1
        )
        assert started.session_id is not None

        # End the session, asking for a report.
        ended = await sessions.end_training(
            db_session,
            session_id=started.session_id,
            request={"generate_report": True},
            user_id=1,
        )
        assert ended.session.status == "completed"
        assert ended.session.duration_seconds is not None

        # A report is expected; its contents are only checked when present.
        if ended.report:
            assert ended.report.overall_score > 0
            assert len(ended.report.strengths) > 0
            assert len(ended.report.suggestions) > 0