- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
508 lines
16 KiB
Python
508 lines
16 KiB
Python
"""陪练模块API路由"""
|
||
import logging
|
||
from typing import List, Optional
|
||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.deps import get_db, get_current_user, require_admin
|
||
from app.schemas.base import ResponseModel
|
||
from app.schemas.training import (
|
||
TrainingSceneCreate,
|
||
TrainingSceneUpdate,
|
||
TrainingSceneResponse,
|
||
TrainingSessionResponse,
|
||
TrainingMessageResponse,
|
||
TrainingReportResponse,
|
||
StartTrainingRequest,
|
||
StartTrainingResponse,
|
||
EndTrainingRequest,
|
||
EndTrainingResponse,
|
||
TrainingSceneListQuery,
|
||
TrainingSessionListQuery,
|
||
PaginatedResponse,
|
||
)
|
||
from app.services.training_service import (
|
||
TrainingSceneService,
|
||
TrainingSessionService,
|
||
TrainingMessageService,
|
||
TrainingReportService,
|
||
)
|
||
from app.models.training import TrainingSceneStatus, TrainingSessionStatus
|
||
|
||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Every route declared in this file is registered under the /training prefix.
router = APIRouter(prefix="/training", tags=["陪练模块"])

# Service instances shared by the route handlers in this module.
scene_service = TrainingSceneService()
session_service = TrainingSessionService()
message_service = TrainingMessageService()
report_service = TrainingReportService()
|
||
|
||
|
||
# ========== 陪练场景管理 ==========
|
||
|
||
|
||
@router.get(
    "/scenes", response_model=ResponseModel[PaginatedResponse[TrainingSceneResponse]]
)
async def get_training_scenes(
    category: Optional[str] = Query(None, description="场景分类"),
    status: Optional[TrainingSceneStatus] = Query(None, description="场景状态"),
    is_public: Optional[bool] = Query(None, description="是否公开"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Return a paginated list of active training scenes.

    - Filters by category and by the public flag.
    - Paginates with ``page`` / ``page_size``.

    Raises:
        HTTPException: 500 when the listing or the count query fails.
    """
    # The ``status`` query parameter shadows the module-level ``fastapi.status``
    # import inside this function, so the HTTP status constants are reached
    # through a local alias. Without it, the error handler below would raise
    # AttributeError instead of returning the intended 500 response.
    from fastapi import status as http_status

    # NOTE(review): ``status`` and ``search`` are accepted as query parameters
    # but are not forwarded to the service or applied to the count query here.
    try:
        # Translate the 1-based page number into a row offset.
        skip = (page - 1) * page_size

        # User level (TODO: fetch it from the User service).
        user_level = 1

        scenes = await scene_service.get_active_scenes(
            db,
            category=category,
            is_public=is_public,
            user_level=user_level,
            skip=skip,
            limit=page_size,
        )

        # Count the matching rows for the pagination metadata.
        from sqlalchemy import select, func, and_
        from app.models.training import TrainingScene

        # NOTE(review): this count does not take ``user_level`` into account;
        # if get_active_scenes filters on it, ``total`` may differ from the
        # number of reachable items — confirm against the service.
        count_query = (
            select(func.count())
            .select_from(TrainingScene)
            .where(
                and_(
                    TrainingScene.status == TrainingSceneStatus.ACTIVE,
                    TrainingScene.is_deleted == False,  # noqa: E712 (SQL expression)
                )
            )
        )

        if category:
            count_query = count_query.where(TrainingScene.category == category)
        if is_public is not None:
            count_query = count_query.where(TrainingScene.is_public == is_public)

        result = await db.execute(count_query)
        total = result.scalar_one()

        # Ceiling division: number of pages needed to hold ``total`` rows.
        pages = (total + page_size - 1) // page_size

        return ResponseModel(
            data=PaginatedResponse(
                items=scenes, total=total, page=page, page_size=page_size, pages=pages
            ),
            message="获取陪练场景列表成功",
        )

    except HTTPException:
        # Let deliberate HTTP errors pass through unchanged, as the
        # session endpoints in this module do.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"获取陪练场景列表失败: {e}")
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取陪练场景列表失败",
        ) from e
|
||
|
||
|
||
@router.get("/scenes/{scene_id}", response_model=ResponseModel[TrainingSceneResponse])
async def get_training_scene(
    scene_id: int,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Return the details of a single training scene.

    Raises:
        HTTPException: 404 when the scene is missing or flagged as deleted;
            403 when the scene is not public and the caller's role is not
            "admin".
    """
    found = await scene_service.get(db, scene_id)

    # A missing row and a soft-deleted row are reported identically.
    unavailable = not found or found.is_deleted
    if unavailable:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="陪练场景不存在",
        )

    # Access rule: public scenes are open to everyone, others only to admins.
    if not (found.is_public or current_user.get("role") == "admin"):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权访问此场景",
        )

    return ResponseModel(data=found, message="获取陪练场景成功")
|
||
|
||
|
||
@router.post("/scenes", response_model=ResponseModel[TrainingSceneResponse])
async def create_training_scene(
    scene_in: TrainingSceneCreate,
    current_user: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """
    Create a training scene (administrators only).

    - Access is gated by the ``require_admin`` dependency.
    - Per the original notes, a new scene starts in draft state; that default
      is decided by the service layer and is not visible in this handler.

    Raises:
        HTTPException: any HTTP error raised by the service is passed through;
            every other failure is reported as 500.
    """
    try:
        scene = await scene_service.create_scene(
            db, scene_in=scene_in, created_by=current_user["id"]
        )

        logger.info(f"管理员 {current_user['id']} 创建了陪练场景: {scene.id}")

        return ResponseModel(data=scene, message="创建陪练场景成功")

    except HTTPException:
        # Keep deliberate HTTP errors (e.g. validation failures raised by the
        # service) intact instead of masking them as 500, consistent with
        # start_training / end_training.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"创建陪练场景失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="创建陪练场景失败",
        ) from e
|
||
|
||
|
||
@router.put("/scenes/{scene_id}", response_model=ResponseModel[TrainingSceneResponse])
async def update_training_scene(
    scene_id: int,
    scene_in: TrainingSceneUpdate,
    current_user: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """
    Update a training scene (administrators only).

    Raises:
        HTTPException: 404 when the service returns no scene for ``scene_id``.
    """
    admin_id = current_user["id"]

    updated = await scene_service.update_scene(
        db,
        scene_id=scene_id,
        scene_in=scene_in,
        updated_by=admin_id,
    )

    # A falsy result from the service is treated as "scene not found".
    if not updated:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="陪练场景不存在",
        )

    logger.info(f"管理员 {admin_id} 更新了陪练场景: {scene_id}")

    return ResponseModel(data=updated, message="更新陪练场景成功")
|
||
|
||
|
||
@router.delete("/scenes/{scene_id}", response_model=ResponseModel[bool])
async def delete_training_scene(
    scene_id: int,
    current_user: dict = Depends(require_admin),
    db: AsyncSession = Depends(get_db),
):
    """
    Delete a training scene via the service's ``soft_delete`` (administrators only).

    Raises:
        HTTPException: 404 when ``soft_delete`` reports failure.
    """
    deleted = await scene_service.soft_delete(db, id=scene_id)

    if deleted:
        logger.info(f"管理员 {current_user['id']} 删除了陪练场景: {scene_id}")
        return ResponseModel(data=True, message="删除陪练场景成功")

    # A falsy result from the service is reported as "scene not found".
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="陪练场景不存在",
    )
|
||
|
||
|
||
# ========== 陪练会话管理 ==========
|
||
|
||
|
||
@router.post("/sessions", response_model=ResponseModel[StartTrainingResponse])
async def start_training(
    request: StartTrainingRequest,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Start a training session for the authenticated user.

    The work is delegated to ``session_service.start_training``. Per the
    original notes, the service creates the session record, initialises the
    Coze conversation when a bot is configured, and returns session info plus
    a WebSocket address when supported — none of that is visible in this
    handler.

    Raises:
        HTTPException: HTTP errors from the service pass through unchanged;
            any other failure is reported as 500.
    """
    try:
        user_id = current_user["id"]
        started = await session_service.start_training(
            db,
            request=request,
            user_id=user_id,
        )
        logger.info(f"用户 {user_id} 开始陪练会话: {started.session_id}")
        return ResponseModel(data=started, message="开始陪练成功")
    except HTTPException:
        # Deliberate HTTP errors are forwarded as-is.
        raise
    except Exception as e:
        logger.error(f"开始陪练失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="开始陪练失败",
        )
|
||
|
||
|
||
@router.post(
    "/sessions/{session_id}/end",
    response_model=ResponseModel[EndTrainingResponse],
)
async def end_training(
    session_id: int,
    request: EndTrainingRequest,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    End a training session.

    The work is delegated to ``session_service.end_training``, which receives
    the caller's user id. Per the original notes, the caller must be the
    session's creator, the session status is updated, and a report may
    optionally be generated — those rules live in the service, not here.

    Raises:
        HTTPException: HTTP errors from the service pass through unchanged;
            any other failure is reported as 500.
    """
    try:
        user_id = current_user["id"]
        ended = await session_service.end_training(
            db,
            session_id=session_id,
            request=request,
            user_id=user_id,
        )
        logger.info(f"用户 {user_id} 结束陪练会话: {session_id}")
        return ResponseModel(data=ended, message="结束陪练成功")
    except HTTPException:
        # Deliberate HTTP errors are forwarded as-is.
        raise
    except Exception as e:
        logger.error(f"结束陪练失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="结束陪练失败",
        )
|
||
|
||
|
||
@router.get(
    "/sessions",
    response_model=ResponseModel[PaginatedResponse[TrainingSessionResponse]],
)
async def get_training_sessions(
    scene_id: Optional[int] = Query(None, description="场景ID"),
    status: Optional[TrainingSessionStatus] = Query(None, description="会话状态"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Return a paginated list of the current user's training sessions.

    Optional ``scene_id`` and ``status`` filters are forwarded to the service
    and mirrored in the count query.

    Raises:
        HTTPException: 500 when the listing or the count query fails.
    """
    # The ``status`` query parameter shadows the module-level ``fastapi.status``
    # import inside this function, so the HTTP status constants are reached
    # through a local alias. Without it, the error handler below would raise
    # AttributeError instead of returning the intended 500 response.
    from fastapi import status as http_status

    try:
        # Translate the 1-based page number into a row offset.
        skip = (page - 1) * page_size

        sessions = await session_service.get_user_sessions(
            db,
            user_id=current_user["id"],
            scene_id=scene_id,
            status=status,
            skip=skip,
            limit=page_size,
        )

        # Count the matching rows for the pagination metadata.
        from sqlalchemy import select, func
        from app.models.training import TrainingSession

        count_query = (
            select(func.count())
            .select_from(TrainingSession)
            .where(TrainingSession.user_id == current_user["id"])
        )

        if scene_id:
            count_query = count_query.where(TrainingSession.scene_id == scene_id)
        if status:
            count_query = count_query.where(TrainingSession.status == status)

        result = await db.execute(count_query)
        total = result.scalar_one()

        # Ceiling division: number of pages needed to hold ``total`` rows.
        pages = (total + page_size - 1) // page_size

        # Load the related scene for each session (one refresh per row).
        for session in sessions:
            await db.refresh(session, ["scene"])

        return ResponseModel(
            data=PaginatedResponse(
                items=sessions, total=total, page=page, page_size=page_size, pages=pages
            ),
            message="获取陪练会话列表成功",
        )

    except HTTPException:
        # Let deliberate HTTP errors pass through unchanged, as
        # start_training / end_training do.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"获取陪练会话列表失败: {e}")
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取陪练会话列表失败",
        ) from e
|
||
|
||
|
||
@router.get(
    "/sessions/{session_id}",
    response_model=ResponseModel[TrainingSessionResponse],
)
async def get_training_session(
    session_id: int,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Return the details of one training session, including its message count.

    Raises:
        HTTPException: 404 when the session does not exist; 403 when the
            caller is neither the session's owner nor an "admin".
    """
    record = await session_service.get(db, session_id)
    if not record:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="陪练会话不存在",
        )

    # Access rule: only the owner or an admin may read the session.
    allowed = (
        record.user_id == current_user["id"]
        or current_user.get("role") == "admin"
    )
    if not allowed:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权访问此会话",
        )

    # Load the related scene onto the session object.
    await db.refresh(record, ["scene"])

    # The message count is the length of the list returned by the service.
    session_messages = await message_service.get_session_messages(
        db, session_id=session_id
    )
    record.message_count = len(session_messages)

    return ResponseModel(data=record, message="获取陪练会话成功")
|
||
|
||
|
||
# ========== 消息管理 ==========
|
||
|
||
|
||
@router.get(
    "/sessions/{session_id}/messages",
    response_model=ResponseModel[List[TrainingMessageResponse]],
)
async def get_training_messages(
    session_id: int,
    skip: int = Query(0, ge=0, description="跳过数量"),
    limit: int = Query(100, ge=1, le=500, description="返回数量"),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Return the messages of a training session, windowed by ``skip`` / ``limit``.

    Raises:
        HTTPException: 404 when the session does not exist; 403 when the
            caller is neither the session's owner nor an "admin".
    """
    # The session is fetched first so access can be checked before any
    # message is read.
    owner_session = await session_service.get(db, session_id)
    if not owner_session:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="陪练会话不存在",
        )

    allowed = (
        owner_session.user_id == current_user["id"]
        or current_user.get("role") == "admin"
    )
    if not allowed:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权访问此会话消息",
        )

    page_of_messages = await message_service.get_session_messages(
        db,
        session_id=session_id,
        skip=skip,
        limit=limit,
    )
    return ResponseModel(data=page_of_messages, message="获取消息列表成功")
|
||
|
||
|
||
# ========== 报告管理 ==========
|
||
|
||
|
||
@router.get(
    "/reports", response_model=ResponseModel[PaginatedResponse[TrainingReportResponse]]
)
async def get_training_reports(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Return a paginated list of the current user's training reports.

    Each report has its related session loaded and, when a session exists,
    that session's scene as well.

    Raises:
        HTTPException: HTTP errors raised while loading pass through
            unchanged; any other failure is reported as 500.
    """
    try:
        # Translate the 1-based page number into a row offset.
        skip = (page - 1) * page_size

        reports = await report_service.get_user_reports(
            db, user_id=current_user["id"], skip=skip, limit=page_size
        )

        # Count the user's reports for the pagination metadata.
        from sqlalchemy import select, func
        from app.models.training import TrainingReport

        count_query = (
            select(func.count())
            .select_from(TrainingReport)
            .where(TrainingReport.user_id == current_user["id"])
        )

        result = await db.execute(count_query)
        total = result.scalar_one()

        # Ceiling division: number of pages needed to hold ``total`` rows.
        pages = (total + page_size - 1) // page_size

        # Load the related session, then its scene, for every report
        # (one or two refreshes per row).
        for report in reports:
            await db.refresh(report, ["session"])
            if report.session:
                await db.refresh(report.session, ["scene"])

        return ResponseModel(
            data=PaginatedResponse(
                items=reports, total=total, page=page, page_size=page_size, pages=pages
            ),
            message="获取陪练报告列表成功",
        )

    except HTTPException:
        # Let deliberate HTTP errors pass through unchanged, as
        # start_training / end_training do.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"获取陪练报告列表失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取陪练报告列表失败",
        ) from e
|
||
|
||
|
||
@router.get(
    "/reports/{report_id}",
    response_model=ResponseModel[TrainingReportResponse],
)
async def get_training_report(
    report_id: int,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Return the details of one training report.

    Raises:
        HTTPException: 404 when the report does not exist; 403 when the
            caller is neither the report's owner nor an "admin".
    """
    found = await report_service.get(db, report_id)
    if not found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="陪练报告不存在",
        )

    # Access rule: only the owner or an admin may read the report.
    allowed = (
        found.user_id == current_user["id"]
        or current_user.get("role") == "admin"
    )
    if not allowed:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权访问此报告",
        )

    # Load the related session and, when present, that session's scene.
    await db.refresh(found, ["session"])
    linked_session = found.session
    if linked_session:
        await db.refresh(linked_session, ["scene"])

    return ResponseModel(data=found, message="获取陪练报告成功")
|
||
|
||
|
||
@router.get(
    "/sessions/{session_id}/report",
    response_model=ResponseModel[TrainingReportResponse],
)
async def get_session_report(
    session_id: int,
    current_user: dict = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Return the training report that belongs to a given session.

    Raises:
        HTTPException: 404 when the session does not exist or has no report;
            403 when the caller is neither the session's owner nor an "admin".
    """
    # The session is fetched first so access can be checked before the
    # report is looked up.
    owner_session = await session_service.get(db, session_id)
    if not owner_session:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="陪练会话不存在",
        )

    allowed = (
        owner_session.user_id == current_user["id"]
        or current_user.get("role") == "admin"
    )
    if not allowed:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权访问此会话报告",
        )

    session_report = await report_service.get_by_session(db, session_id=session_id)
    if not session_report:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="该会话暂无报告",
        )

    # Load the related session and, when present, that session's scene.
    await db.refresh(session_report, ["session"])
    linked_session = session_report.session
    if linked_session:
        await db.refresh(linked_session, ["scene"])

    return ResponseModel(data=session_report, message="获取会话报告成功")
|