Files
012-kaopeilian/backend/app/services/training_service.py
111 998211c483 feat: 初始化考培练系统项目
- 从服务器拉取完整代码
- 按框架规范整理项目结构
- 配置 Drone CI 测试环境部署
- 包含后端(FastAPI)、前端(Vue3)、管理端

技术栈: Vue3 + TypeScript + FastAPI + MySQL
2026-01-24 19:33:28 +08:00

373 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""陪练服务层"""
import logging
from typing import List, Optional, Dict, Any
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func
from fastapi import HTTPException, status
from app.models.training import (
TrainingScene,
TrainingSession,
TrainingMessage,
TrainingReport,
TrainingSceneStatus,
TrainingSessionStatus,
MessageRole,
MessageType,
)
from app.schemas.training import (
TrainingSceneCreate,
TrainingSceneUpdate,
TrainingSessionCreate,
TrainingSessionUpdate,
TrainingMessageCreate,
TrainingReportCreate,
StartTrainingRequest,
StartTrainingResponse,
EndTrainingRequest,
EndTrainingResponse,
)
from app.services.base_service import BaseService
# from app.services.ai.coze.client import CozeClient
from app.core.config import get_settings
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Application settings; this module reads `coze_training_bot_id` from it as the
# fallback bot id when starting a Coze conversation.
settings = get_settings()
class TrainingSceneService(BaseService[TrainingScene]):
    """Query and CRUD helpers for training scenes."""

    def __init__(self):
        super().__init__(TrainingScene)

    async def get_active_scenes(
        self,
        db: AsyncSession,
        *,
        category: Optional[str] = None,
        is_public: Optional[bool] = None,
        user_level: Optional[int] = None,
        skip: int = 0,
        limit: int = 20,
    ) -> List[TrainingScene]:
        """Return a page of ACTIVE, non-deleted training scenes.

        Args:
            db: async database session.
            category: when truthy, keep only scenes in this category.
            is_public: when not None, keep only scenes whose ``is_public``
                flag equals this value.
            user_level: when not None, keep scenes that have no
                ``required_level`` or one no greater than this level.
            skip: number of rows to skip (forwarded to ``get_multi``).
            limit: maximum number of rows (forwarded to ``get_multi``).

        Returns:
            The matching scenes as returned by ``get_multi``.
        """
        model = self.model
        # SQL comparisons, not Python truth tests.
        filters = [
            model.status == TrainingSceneStatus.ACTIVE,
            model.is_deleted == False,  # noqa: E712
        ]
        if category:
            filters.append(model.category == category)
        if is_public is not None:
            filters.append(model.is_public == is_public)
        if user_level is not None:
            # A scene with no level requirement is open to everyone.
            filters.append(
                or_(
                    model.required_level.is_(None),
                    model.required_level <= user_level,
                )
            )
        statement = select(model).where(and_(*filters))
        return await self.get_multi(db, skip=skip, limit=limit, query=statement)

    async def create_scene(
        self, db: AsyncSession, *, scene_in: TrainingSceneCreate, created_by: int
    ) -> TrainingScene:
        """Create a scene, recording ``created_by`` as both creator and last updater."""
        audit_fields = {"created_by": created_by, "updated_by": created_by}
        return await self.create(db, obj_in=scene_in, **audit_fields)

    async def update_scene(
        self,
        db: AsyncSession,
        *,
        scene_id: int,
        scene_in: TrainingSceneUpdate,
        updated_by: int,
    ) -> Optional[TrainingScene]:
        """Apply ``scene_in`` to an existing, non-deleted scene.

        Returns:
            The updated scene, or ``None`` when the scene does not exist or
            is soft-deleted.
        """
        existing = await self.get(db, scene_id)
        if existing and not existing.is_deleted:
            existing.updated_by = updated_by
            return await self.update(db, db_obj=existing, obj_in=scene_in)
        return None
class TrainingSessionService(BaseService[TrainingSession]):
    """Service for training sessions.

    Covers the session lifecycle: starting a session for an active scene,
    ending it (optionally generating a report) and listing a user's sessions.
    """

    def __init__(self):
        super().__init__(TrainingSession)
        self.scene_service = TrainingSceneService()
        self.message_service = TrainingMessageService()
        self.report_service = TrainingReportService()
        # TODO: replace once the Coze gateway module is implemented.
        self._coze_client = None
        # Set after the first initialisation attempt so the "mock mode"
        # warning is logged once instead of on every property access.
        self._coze_init_attempted = False

    @property
    def coze_client(self):
        """Lazily initialise and return the Coze client.

        The Coze client import is still commented out, so this currently
        always returns ``None`` (mock mode). Initialisation is attempted only
        once per service instance.
        """
        if self._coze_client is None and not self._coze_init_attempted:
            self._coze_init_attempted = True
            try:
                # from app.services.ai.coze.client import CozeClient
                # self._coze_client = CozeClient()
                logger.warning("Coze客户端暂未实现使用模拟模式")
                self._coze_client = None
            except ImportError:
                logger.warning("Coze客户端未实现使用模拟模式")
        return self._coze_client

    async def start_training(
        self, db: AsyncSession, *, request: StartTrainingRequest, user_id: int
    ) -> StartTrainingResponse:
        """Start a training session for ``user_id`` on ``request.scene_id``.

        Creates the session row and, when a Coze client and the scene's
        ``ai_config`` are available, tries to open a Coze conversation.
        Failure to open the conversation does not fail the request.

        Raises:
            HTTPException(404): the scene is missing, soft-deleted or not ACTIVE.
            HTTPException(403): the user's level is below ``scene.required_level``.
        """
        scene = await self.scene_service.get(db, request.scene_id)
        if not scene or scene.is_deleted or scene.status != TrainingSceneStatus.ACTIVE:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="陪练场景不存在或未激活"
            )

        # TODO: fetch the real user level from the User service.
        user_level = 1  # temporary placeholder
        if scene.required_level and user_level < scene.required_level:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户等级不足")

        session_data = TrainingSessionCreate(
            scene_id=request.scene_id, session_config=request.config
        )
        session = await self.create(
            db, obj_in=session_data, user_id=user_id, created_by=user_id
        )

        # Best-effort Coze conversation setup: errors are logged and the
        # session proceeds without a conversation id.
        coze_conversation_id = None
        if self.coze_client and scene.ai_config:
            try:
                # A scene-level bot id wins over the global default.
                bot_id = scene.ai_config.get("bot_id", settings.coze_training_bot_id)
                if bot_id:
                    coze_result = await self.coze_client.create_conversation(
                        bot_id=bot_id,
                        user_id=str(user_id),
                        meta_data={
                            "scene_id": scene.id,
                            "scene_name": scene.name,
                            "session_id": session.id,
                        },
                    )
                    coze_conversation_id = coze_result.get("conversation_id")
                    session.coze_conversation_id = coze_conversation_id
                    await db.commit()
            except Exception as e:
                # logger.exception keeps the traceback for diagnosis.
                logger.exception("创建Coze会话失败: %s", e)

        # Load the scene relationship for the response.
        await db.refresh(session, ["scene"])

        # A WebSocket endpoint is only offered when a Coze conversation exists.
        # NOTE(review): host/port are hard-coded to localhost; presumably this
        # should come from configuration in deployed environments.
        websocket_url = None
        if coze_conversation_id:
            websocket_url = f"ws://localhost:8000/ws/v1/training/{session.id}"

        return StartTrainingResponse(
            session_id=session.id,
            coze_conversation_id=coze_conversation_id,
            scene=scene,
            websocket_url=websocket_url,
        )

    async def end_training(
        self,
        db: AsyncSession,
        *,
        session_id: int,
        request: EndTrainingRequest,
        user_id: int,
    ) -> EndTrainingResponse:
        """End a session owned by ``user_id``.

        Marks the session COMPLETED, records ``end_time`` and the elapsed
        ``duration_seconds`` since ``start_time``, and generates a report when
        ``request.generate_report`` is set.

        Raises:
            HTTPException(404): the session is missing or owned by another user.
            HTTPException(400): the session is already COMPLETED or CANCELLED.
        """
        session = await self.get(db, session_id)
        # Another user's session is reported as missing rather than forbidden.
        if not session or session.user_id != user_id:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="会话不存在")

        if session.status in [
            TrainingSessionStatus.COMPLETED,
            TrainingSessionStatus.CANCELLED,
        ]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="会话已结束")

        # NOTE(review): datetime.now() is naive local time; this assumes
        # start_time is stored naive as well (mixing naive/aware raises
        # TypeError) — TODO confirm.
        end_time = datetime.now()
        duration_seconds = int((end_time - session.start_time).total_seconds())

        update_data = TrainingSessionUpdate(
            status=TrainingSessionStatus.COMPLETED,
            end_time=end_time,
            duration_seconds=duration_seconds,
        )
        session = await self.update(db, db_obj=session, obj_in=update_data)

        report = None
        if request.generate_report:
            report = await self._generate_report(
                db, session_id=session_id, user_id=user_id
            )

        # Load relationships used by the response.
        await db.refresh(session, ["scene"])
        if report:
            await db.refresh(report, ["session"])

        return EndTrainingResponse(session=session, report=report)

    async def get_user_sessions(
        self,
        db: AsyncSession,
        *,
        user_id: int,
        scene_id: Optional[int] = None,
        status: Optional[TrainingSessionStatus] = None,
        skip: int = 0,
        limit: int = 20,
    ) -> List[TrainingSession]:
        """Return a user's sessions, newest first, optionally filtered.

        Note: the ``status`` parameter shadows the module-level
        ``fastapi.status`` import inside this method. The name is kept
        because callers pass it by keyword.
        """
        query = select(self.model).where(self.model.user_id == user_id)
        if scene_id:
            query = query.where(self.model.scene_id == scene_id)
        if status:
            query = query.where(self.model.status == status)
        query = query.order_by(self.model.created_at.desc())
        return await self.get_multi(db, skip=skip, limit=limit, query=query)

    async def _generate_report(
        self, db: AsyncSession, *, session_id: int, user_id: int
    ) -> Optional[TrainingReport]:
        """Create and persist a report for ``session_id`` (internal).

        Scores and feedback are still mock values (see TODO); only the
        message statistics come from the database. They are counted in SQL
        rather than from ``message_service.get_session_messages``. That
        method's default ``limit=100`` capped the counts for longer sessions.
        """
        # One COUNT(*) per role for this session.
        role_counts = (
            await db.execute(
                select(TrainingMessage.role, func.count())
                .where(TrainingMessage.session_id == session_id)
                .group_by(TrainingMessage.role)
            )
        ).all()
        total_messages = sum(count for _role, count in role_counts)
        user_messages = sum(
            count for role, count in role_counts if role == MessageRole.USER
        )

        # TODO: call the AI analysis service to generate a real report.
        # Mock scores/feedback for now.
        report_data = TrainingReportCreate(
            session_id=session_id,
            user_id=user_id,
            overall_score=85.5,
            dimension_scores={"表达能力": 88.0, "逻辑思维": 85.0, "专业知识": 82.0, "应变能力": 87.0},
            strengths=["表达清晰,语言流畅", "能够快速理解问题并作出回应", "展现了良好的专业素养"],
            weaknesses=["部分专业术语使用不够准确", "回答有时过于冗长,需要更加精炼"],
            suggestions=["加强专业知识的学习,特别是术语的准确使用", "练习更加简洁有力的表达方式", "增加实际案例的积累,丰富回答内容"],
            detailed_analysis="整体表现良好,展现了扎实的基础知识和良好的沟通能力...",
            statistics={
                "total_messages": total_messages,
                "user_messages": user_messages,
                # Mock values until real analysis is available.
                "avg_response_time": 2.5,
                "total_words": 1500,
            },
        )

        return await self.report_service.create(
            db, obj_in=report_data, created_by=user_id
        )
class TrainingMessageService(BaseService[TrainingMessage]):
    """Persistence helpers for messages exchanged in a training session."""

    def __init__(self):
        super().__init__(TrainingMessage)

    async def create_message(
        self, db: AsyncSession, *, message_in: TrainingMessageCreate
    ) -> TrainingMessage:
        """Persist one message built from ``message_in``."""
        stored = await self.create(db, obj_in=message_in)
        return stored

    async def get_session_messages(
        self, db: AsyncSession, *, session_id: int, skip: int = 0, limit: int = 100
    ) -> List[TrainingMessage]:
        """Return messages of ``session_id`` oldest-first, paginated.

        Only ``limit`` rows (default 100) are returned after skipping ``skip``.
        """
        msg = self.model
        statement = select(msg).where(msg.session_id == session_id)
        statement = statement.order_by(msg.created_at)
        return await self.get_multi(db, skip=skip, limit=limit, query=statement)

    async def save_voice_message(
        self,
        db: AsyncSession,
        *,
        session_id: int,
        role: MessageRole,
        content: str,
        voice_url: str,
        voice_duration: float,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> TrainingMessage:
        """Persist a VOICE message with its audio URL and duration.

        Args:
            session_id: owning session id.
            role: who sent the message.
            content: text content stored alongside the audio.
            voice_url: location of the audio.
            voice_duration: audio duration as supplied by the caller.
            metadata: optional extra data stored with the message.
        """
        fields = dict(
            session_id=session_id,
            role=role,
            type=MessageType.VOICE,
            content=content,
            voice_url=voice_url,
            voice_duration=voice_duration,
            metadata=metadata,
        )
        return await self.create(db, obj_in=TrainingMessageCreate(**fields))
class TrainingReportService(BaseService[TrainingReport]):
    """Lookup helpers for training reports."""

    def __init__(self):
        super().__init__(TrainingReport)

    async def get_by_session(
        self, db: AsyncSession, *, session_id: int
    ) -> Optional[TrainingReport]:
        """Return the report for ``session_id``, or ``None`` if there is none.

        NOTE(review): ``scalar_one_or_none`` raises if more than one report
        exists for the session; this assumes at most one per session.
        """
        statement = select(self.model).where(self.model.session_id == session_id)
        rows = await db.execute(statement)
        return rows.scalar_one_or_none()

    async def get_user_reports(
        self, db: AsyncSession, *, user_id: int, skip: int = 0, limit: int = 20
    ) -> List[TrainingReport]:
        """Return a page of ``user_id``'s reports, newest first."""
        newest_first = self.model.created_at.desc()
        statement = (
            select(self.model)
            .where(self.model.user_id == user_id)
            .order_by(newest_first)
        )
        return await self.get_multi(db, skip=skip, limit=limit, query=statement)