- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
373 lines
12 KiB
Python
373 lines
12 KiB
Python
"""陪练服务层"""
|
||
import logging
|
||
from typing import List, Optional, Dict, Any
|
||
from datetime import datetime
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, and_, or_, func
|
||
from fastapi import HTTPException, status
|
||
|
||
from app.models.training import (
|
||
TrainingScene,
|
||
TrainingSession,
|
||
TrainingMessage,
|
||
TrainingReport,
|
||
TrainingSceneStatus,
|
||
TrainingSessionStatus,
|
||
MessageRole,
|
||
MessageType,
|
||
)
|
||
from app.schemas.training import (
|
||
TrainingSceneCreate,
|
||
TrainingSceneUpdate,
|
||
TrainingSessionCreate,
|
||
TrainingSessionUpdate,
|
||
TrainingMessageCreate,
|
||
TrainingReportCreate,
|
||
StartTrainingRequest,
|
||
StartTrainingResponse,
|
||
EndTrainingRequest,
|
||
EndTrainingResponse,
|
||
)
|
||
from app.services.base_service import BaseService
|
||
|
||
# from app.services.ai.coze.client import CozeClient
|
||
from app.core.config import get_settings
|
||
|
||
# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
# Application settings, loaded once when this module is imported.
settings = get_settings()
|
||
|
||
|
||
class TrainingSceneService(BaseService[TrainingScene]):
    """Service for training scenes: listing, creating and updating them."""

    def __init__(self):
        super().__init__(TrainingScene)

    async def get_active_scenes(
        self,
        db: AsyncSession,
        *,
        category: Optional[str] = None,
        is_public: Optional[bool] = None,
        user_level: Optional[int] = None,
        skip: int = 0,
        limit: int = 20,
    ) -> List[TrainingScene]:
        """Return a page of active, non-deleted training scenes.

        Args:
            db: Async database session.
            category: When truthy, only scenes in this category are returned.
            is_public: When not None, only scenes whose ``is_public`` equals it.
            user_level: When not None, only scenes with no level requirement or
                a ``required_level`` less than or equal to this value.
            skip: Number of rows to skip (forwarded to ``get_multi``).
            limit: Maximum number of rows (forwarded to ``get_multi``).

        Returns:
            The scenes returned by ``BaseService.get_multi`` for the built query.
        """
        scene = self.model
        filters = [
            scene.status == TrainingSceneStatus.ACTIVE,
            # SQL expression, not a Python comparison -- `is` cannot be used here.
            scene.is_deleted == False,  # noqa: E712
        ]
        if category:
            filters.append(scene.category == category)
        if is_public is not None:
            filters.append(scene.is_public == is_public)
        if user_level is not None:
            # A NULL required_level means the scene has no level restriction.
            filters.append(
                or_(
                    scene.required_level.is_(None),
                    scene.required_level <= user_level,
                )
            )

        query = select(scene).where(and_(*filters))
        return await self.get_multi(db, skip=skip, limit=limit, query=query)

    async def create_scene(
        self, db: AsyncSession, *, scene_in: TrainingSceneCreate, created_by: int
    ) -> TrainingScene:
        """Create a scene, recording ``created_by`` as both creator and last updater."""
        audit_fields = {"created_by": created_by, "updated_by": created_by}
        return await self.create(db, obj_in=scene_in, **audit_fields)

    async def update_scene(
        self,
        db: AsyncSession,
        *,
        scene_id: int,
        scene_in: TrainingSceneUpdate,
        updated_by: int,
    ) -> Optional[TrainingScene]:
        """Apply ``scene_in`` to an existing scene.

        Returns:
            The updated scene, or None when the scene does not exist or is
            soft-deleted.
        """
        existing = await self.get(db, scene_id)
        if not existing or existing.is_deleted:
            return None

        existing.updated_by = updated_by
        return await self.update(db, db_obj=existing, obj_in=scene_in)
|
||
|
||
|
||
class TrainingSessionService(BaseService[TrainingSession]):
    """Training session service: start/end sessions, list them, build reports."""

    def __init__(self):
        super().__init__(TrainingSession)
        self.scene_service = TrainingSceneService()
        self.message_service = TrainingMessageService()
        self.report_service = TrainingReportService()
        # TODO: replace once the Coze gateway module is implemented.
        self._coze_client = None
        # Set after the first initialisation attempt so the "mock mode"
        # warning is logged once rather than on every property access.
        self._coze_client_initialized = False

    @property
    def coze_client(self):
        """Lazily initialise and return the Coze client.

        The real client import is still commented out, so this currently
        always returns None (mock mode).
        """
        if self._coze_client is None and not self._coze_client_initialized:
            self._coze_client_initialized = True
            try:
                # from app.services.ai.coze.client import CozeClient
                # self._coze_client = CozeClient()
                logger.warning("Coze客户端暂未实现,使用模拟模式")
                self._coze_client = None
            except ImportError:
                logger.warning("Coze客户端未实现,使用模拟模式")
        return self._coze_client

    async def start_training(
        self, db: AsyncSession, *, request: StartTrainingRequest, user_id: int
    ) -> StartTrainingResponse:
        """Start a training session for ``user_id`` on ``request.scene_id``.

        Raises:
            HTTPException: 404 if the scene is missing, soft-deleted or not
                active; 403 if the user's level is below the scene's
                ``required_level``.
        """
        # Validate the scene.
        scene = await self.scene_service.get(db, request.scene_id)
        if not scene or scene.is_deleted or scene.status != TrainingSceneStatus.ACTIVE:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="陪练场景不存在或未激活"
            )

        # TODO: fetch the real user level from the User service.
        user_level = 1  # temporary placeholder
        if scene.required_level and user_level < scene.required_level:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户等级不足")

        # Create the session row.
        session_data = TrainingSessionCreate(
            scene_id=request.scene_id, session_config=request.config
        )
        session = await self.create(
            db, obj_in=session_data, user_id=user_id, created_by=user_id
        )

        coze_conversation_id = await self._init_coze_conversation(
            db, session=session, scene=scene, user_id=user_id
        )

        # Load the scene relationship onto the session.
        await db.refresh(session, ["scene"])

        websocket_url = None
        if coze_conversation_id:
            # NOTE(review): host/port are hard-coded; presumably this should
            # come from settings in non-local deployments -- TODO confirm.
            websocket_url = f"ws://localhost:8000/ws/v1/training/{session.id}"

        return StartTrainingResponse(
            session_id=session.id,
            coze_conversation_id=coze_conversation_id,
            scene=scene,
            websocket_url=websocket_url,
        )

    async def _init_coze_conversation(
        self,
        db: AsyncSession,
        *,
        session: TrainingSession,
        scene: TrainingScene,
        user_id: int,
    ) -> Optional[str]:
        """Create a Coze conversation for ``session`` and store its id (best effort).

        Returns the conversation id obtained from Coze, or None when no client
        or ``ai_config`` is available, no bot id is configured, or Coze fails.
        Failures are logged rather than raised because the training session
        has already been created.
        """
        client = self.coze_client
        if not client or not scene.ai_config:
            return None

        conversation_id = None
        try:
            # `or` (not a .get default) so an empty/None bot_id in the scene
            # config still falls back to the globally configured bot.
            bot_id = scene.ai_config.get("bot_id") or settings.coze_training_bot_id
            if bot_id:
                coze_result = await client.create_conversation(
                    bot_id=bot_id,
                    user_id=str(user_id),
                    meta_data={
                        "scene_id": scene.id,
                        "scene_name": scene.name,
                        "session_id": session.id,
                    },
                )
                conversation_id = coze_result.get("conversation_id")

                # Persist the Coze conversation id on the session.
                session.coze_conversation_id = conversation_id
                await db.commit()
        except Exception as e:
            logger.exception("创建Coze会话失败: %s", e)
        return conversation_id

    async def end_training(
        self,
        db: AsyncSession,
        *,
        session_id: int,
        request: EndTrainingRequest,
        user_id: int,
    ) -> EndTrainingResponse:
        """Mark a session as completed and optionally generate its report.

        Raises:
            HTTPException: 404 if the session does not exist or belongs to a
                different user; 400 if it is already completed or cancelled.
        """
        session = await self.get(db, session_id)
        if not session or session.user_id != user_id:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="会话不存在")

        if session.status in (
            TrainingSessionStatus.COMPLETED,
            TrainingSessionStatus.CANCELLED,
        ):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="会话已结束")

        # Duration; a missing start time yields 0 instead of a TypeError.
        # NOTE(review): uses naive local time -- assumes start_time is stored
        # the same way; confirm against the model.
        end_time = datetime.now()
        if session.start_time is not None:
            duration_seconds = int((end_time - session.start_time).total_seconds())
        else:
            duration_seconds = 0

        update_data = TrainingSessionUpdate(
            status=TrainingSessionStatus.COMPLETED,
            end_time=end_time,
            duration_seconds=duration_seconds,
        )
        session = await self.update(db, db_obj=session, obj_in=update_data)

        report = None
        if request.generate_report:
            report = await self._generate_report(
                db, session_id=session_id, user_id=user_id
            )

        # Load relationships needed by the response.
        await db.refresh(session, ["scene"])
        if report:
            await db.refresh(report, ["session"])

        return EndTrainingResponse(session=session, report=report)

    async def get_user_sessions(
        self,
        db: AsyncSession,
        *,
        user_id: int,
        scene_id: Optional[int] = None,
        status: Optional[TrainingSessionStatus] = None,
        skip: int = 0,
        limit: int = 20,
    ) -> List[TrainingSession]:
        """Return a page of ``user_id``'s sessions, newest first.

        ``scene_id`` and ``status`` filter only when truthy.
        """
        query = select(self.model).where(self.model.user_id == user_id)

        if scene_id:
            query = query.where(self.model.scene_id == scene_id)

        if status:
            query = query.where(self.model.status == status)

        query = query.order_by(self.model.created_at.desc())

        return await self.get_multi(db, skip=skip, limit=limit, query=query)

    async def _generate_report(
        self, db: AsyncSession, *, session_id: int, user_id: int
    ) -> Optional[TrainingReport]:
        """Create and return a report for ``session_id`` (internal).

        Scores and feedback are placeholder values until the AI analysis
        service exists. Only the message counts are computed from data.
        """
        # Count messages per role in SQL. get_session_messages() paginates
        # (default limit=100), so counting its result would cap the
        # statistics at 100 messages for long sessions.
        role_counts = (
            await db.execute(
                select(TrainingMessage.role, func.count())
                .where(TrainingMessage.session_id == session_id)
                .group_by(TrainingMessage.role)
            )
        ).all()
        total_messages = sum(count for _, count in role_counts)
        user_messages = sum(
            count for role, count in role_counts if role == MessageRole.USER
        )

        # TODO: call the AI analysis service; a mock report is produced for now.
        report_data = TrainingReportCreate(
            session_id=session_id,
            user_id=user_id,
            overall_score=85.5,
            dimension_scores={"表达能力": 88.0, "逻辑思维": 85.0, "专业知识": 82.0, "应变能力": 87.0},
            strengths=["表达清晰,语言流畅", "能够快速理解问题并作出回应", "展现了良好的专业素养"],
            weaknesses=["部分专业术语使用不够准确", "回答有时过于冗长,需要更加精炼"],
            suggestions=["加强专业知识的学习,特别是术语的准确使用", "练习更加简洁有力的表达方式", "增加实际案例的积累,丰富回答内容"],
            detailed_analysis="整体表现良好,展现了扎实的基础知识和良好的沟通能力...",
            statistics={
                "total_messages": total_messages,
                "user_messages": user_messages,
                "avg_response_time": 2.5,
                "total_words": 1500,
            },
        )

        return await self.report_service.create(
            db, obj_in=report_data, created_by=user_id
        )
|
||
|
||
|
||
class TrainingMessageService(BaseService[TrainingMessage]):
    """Training message service: persists and reads session messages."""

    def __init__(self):
        super().__init__(TrainingMessage)

    async def create_message(
        self, db: AsyncSession, *, message_in: TrainingMessageCreate
    ) -> TrainingMessage:
        """Create a message from ``message_in``."""
        return await self.create(db, obj_in=message_in)

    async def get_session_messages(
        self, db: AsyncSession, *, session_id: int, skip: int = 0, limit: int = 100
    ) -> List[TrainingMessage]:
        """Return one page (``skip``/``limit``) of a session's messages, oldest first.

        Note this is paginated: by default at most 100 messages are returned.
        """
        query = (
            select(self.model)
            .where(self.model.session_id == session_id)
            # created_at alone can tie (e.g. a user turn and the AI reply saved
            # within the same timestamp tick), which makes transcript order
            # and pagination nondeterministic; id breaks the tie.
            # NOTE(review): assumes an autoincrement `id` primary key -- confirm.
            .order_by(self.model.created_at, self.model.id)
        )

        return await self.get_multi(db, skip=skip, limit=limit, query=query)

    async def save_voice_message(
        self,
        db: AsyncSession,
        *,
        session_id: int,
        role: MessageRole,
        content: str,
        voice_url: str,
        voice_duration: float,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> TrainingMessage:
        """Persist a voice message (``MessageType.VOICE``) for ``session_id``."""
        message_data = TrainingMessageCreate(
            session_id=session_id,
            role=role,
            type=MessageType.VOICE,
            content=content,
            voice_url=voice_url,
            voice_duration=voice_duration,
            metadata=metadata,
        )

        return await self.create(db, obj_in=message_data)
|
||
|
||
|
||
class TrainingReportService(BaseService[TrainingReport]):
    """Training report service: looks up reports by session and by user."""

    def __init__(self):
        super().__init__(TrainingReport)

    async def get_by_session(
        self, db: AsyncSession, *, session_id: int
    ) -> Optional[TrainingReport]:
        """Return the newest report for ``session_id``, or None if there is none.

        Uses ``first()`` instead of ``scalar_one_or_none()`` so that a session
        that ended up with several reports returns the latest one instead of
        raising ``MultipleResultsFound``.
        """
        result = await db.execute(
            select(self.model)
            .where(self.model.session_id == session_id)
            .order_by(self.model.created_at.desc(), self.model.id.desc())
            .limit(1)
        )
        return result.scalars().first()

    async def get_user_reports(
        self, db: AsyncSession, *, user_id: int, skip: int = 0, limit: int = 20
    ) -> List[TrainingReport]:
        """Return a page of ``user_id``'s reports, newest first."""
        query = (
            select(self.model)
            .where(self.model.user_id == user_id)
            # id as tie-breaker keeps pagination stable when created_at ties.
            .order_by(self.model.created_at.desc(), self.model.id.desc())
        )

        return await self.get_multi(db, skip=skip, limit=limit, query=query)
|