feat: initialize the KaoPeiLian (exam/training/practice) system project

- Pulled the complete codebase from the server
- Reorganized the project structure to follow the framework conventions
- Configured Drone CI deployment for the test environment
- Includes the backend (FastAPI), the frontend (Vue 3), and the admin console

Tech stack: Vue 3 + TypeScript + FastAPI + MySQL
1  backend/app/services/__init__.py  Normal file
@@ -0,0 +1 @@
"""业务逻辑服务包"""
272  backend/app/services/ability_assessment_service.py  Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
能力评估服务
|
||||
用于分析用户对话数据,生成能力评估报告和课程推荐
|
||||
|
||||
使用 Python 原生实现
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, List, Literal
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.ability import AbilityAssessment
|
||||
from app.services.ai import ability_analysis_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbilityAssessmentService:
|
||||
"""能力评估服务类"""
|
||||
|
||||
async def analyze_yanji_conversations(
|
||||
self,
|
||||
user_id: int,
|
||||
phone: str,
|
||||
db: AsyncSession,
|
||||
yanji_service,
|
||||
engine: Literal["v2"] = "v2"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
分析言迹对话并生成能力评估及课程推荐
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
phone: 用户手机号(用于获取言迹数据)
|
||||
db: 数据库会话
|
||||
yanji_service: 言迹服务实例
|
||||
engine: 引擎类型(v2=Python原生)
|
||||
|
||||
Returns:
|
||||
评估结果字典,包含:
|
||||
- assessment_id: 评估记录ID
|
||||
- total_score: 综合评分
|
||||
- dimensions: 能力维度列表
|
||||
- recommended_courses: 推荐课程列表
|
||||
- conversation_count: 分析的对话数量
|
||||
|
||||
Raises:
|
||||
ValueError: 未找到员工的录音记录
|
||||
Exception: API调用失败或其他错误
|
||||
"""
|
||||
logger.info(f"开始分析言迹对话: user_id={user_id}, phone={phone}, engine={engine}")
|
||||
|
||||
# 1. 获取员工对话数据(最多10条录音)
|
||||
conversations = await yanji_service.get_employee_conversations_for_analysis(
|
||||
phone=phone,
|
||||
limit=10
|
||||
)
|
||||
|
||||
if not conversations:
|
||||
logger.warning(f"未找到员工的录音记录: user_id={user_id}, phone={phone}")
|
||||
raise ValueError("未找到该员工的录音记录")
|
||||
|
||||
# 2. 合并所有对话历史
|
||||
all_dialogues = []
|
||||
for conv in conversations:
|
||||
all_dialogues.extend(conv['dialogue_history'])
|
||||
|
||||
logger.info(
|
||||
f"准备分析: user_id={user_id}, "
|
||||
f"对话数={len(conversations)}, "
|
||||
f"总轮次={len(all_dialogues)}"
|
||||
)
|
||||
|
||||
used_engine = "v2"
|
||||
|
||||
# 3. 调用 Python 原生能力分析服务
logger.info("调用原生能力分析服务")
|
||||
|
||||
# 将对话历史格式化为文本
|
||||
dialogue_text = self._format_dialogues_for_analysis(all_dialogues)
|
||||
|
||||
# 调用原生服务
|
||||
result = await ability_analysis_service.analyze(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
dialogue_history=dialogue_text
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
raise Exception(f"能力分析失败: {result.error}")
|
||||
|
||||
# 转换为兼容格式
|
||||
analysis_result = {
|
||||
"analysis": {
|
||||
"total_score": result.total_score,
|
||||
"ability_dimensions": [
|
||||
{"name": d.name, "score": d.score, "feedback": d.feedback}
|
||||
for d in result.ability_dimensions
|
||||
],
|
||||
"course_recommendations": [
|
||||
{
|
||||
"course_id": c.course_id,
|
||||
"course_name": c.course_name,
|
||||
"recommendation_reason": c.recommendation_reason,
|
||||
"priority": c.priority,
|
||||
"match_score": c.match_score,
|
||||
}
|
||||
for c in result.course_recommendations
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"能力分析完成 - total_score: {result.total_score}, "
|
||||
f"provider: {result.ai_provider}, latency: {result.ai_latency_ms}ms"
|
||||
)
|
||||
|
||||
# 4. 提取结果
|
||||
analysis = analysis_result.get('analysis', {})
|
||||
ability_dims = analysis.get('ability_dimensions', [])
|
||||
course_recs = analysis.get('course_recommendations', [])
|
||||
total_score = analysis.get('total_score')
|
||||
|
||||
logger.info(
|
||||
f"分析完成 (engine={used_engine}): total_score={total_score}, "
|
||||
f"dimensions={len(ability_dims)}, courses={len(course_recs)}"
|
||||
)
|
||||
|
||||
# 5. 保存能力评估记录到数据库
|
||||
assessment = AbilityAssessment(
|
||||
user_id=user_id,
|
||||
source_type='yanji_badge',
|
||||
source_id=','.join([str(c['audio_id']) for c in conversations]),
|
||||
total_score=total_score,
|
||||
ability_dimensions=ability_dims,
|
||||
recommended_courses=course_recs,
|
||||
conversation_count=len(conversations)
|
||||
)
|
||||
|
||||
db.add(assessment)
|
||||
await db.commit()
|
||||
await db.refresh(assessment)
|
||||
|
||||
logger.info(
|
||||
f"评估记录已保存: assessment_id={assessment.id}, "
|
||||
f"user_id={user_id}, total_score={total_score}"
|
||||
)
|
||||
|
||||
# 6. 返回评估结果
|
||||
return {
|
||||
"assessment_id": assessment.id,
|
||||
"total_score": total_score,
|
||||
"dimensions": ability_dims,
|
||||
"recommended_courses": course_recs,
|
||||
"conversation_count": len(conversations),
|
||||
"analyzed_at": assessment.analyzed_at,
|
||||
"engine": used_engine,
|
||||
}
|
||||
|
||||
def _format_dialogues_for_analysis(self, dialogues: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
将对话历史列表格式化为文本
|
||||
|
||||
Args:
|
||||
dialogues: 对话历史列表,每项包含 speaker, content 等字段
|
||||
|
||||
Returns:
|
||||
格式化后的对话文本
|
||||
"""
|
||||
lines = []
|
||||
for i, d in enumerate(dialogues, 1):
|
||||
speaker = d.get('speaker', 'unknown')
|
||||
content = d.get('content', '')
|
||||
|
||||
# 统一说话者标识
|
||||
if speaker in ['consultant', 'employee', 'user', '员工']:
|
||||
speaker_label = '员工'
|
||||
elif speaker in ['customer', 'client', '顾客', '客户']:
|
||||
speaker_label = '顾客'
|
||||
else:
|
||||
speaker_label = speaker
|
||||
|
||||
lines.append(f"[{i}] {speaker_label}: {content}")
|
||||
|
||||
return '\n'.join(lines)
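# A minimal sketch of what _format_dialogues_for_analysis produces, assuming each
# dialogue item carries the documented `speaker` and `content` keys:
#
#   dialogues = [
#       {"speaker": "employee", "content": "您好,欢迎光临"},
#       {"speaker": "customer", "content": "我想了解一下玻尿酸"},
#   ]
#   text = service._format_dialogues_for_analysis(dialogues)
#   # text == "[1] 员工: 您好,欢迎光临\n[2] 顾客: 我想了解一下玻尿酸"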
|
||||
|
||||
async def get_user_assessment_history(
|
||||
self,
|
||||
user_id: int,
|
||||
db: AsyncSession,
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户的能力评估历史记录
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
db: 数据库会话
|
||||
limit: 返回记录数量限制
|
||||
|
||||
Returns:
|
||||
评估历史记录列表
|
||||
"""
|
||||
stmt = (
|
||||
select(AbilityAssessment)
|
||||
.where(AbilityAssessment.user_id == user_id)
|
||||
.order_by(AbilityAssessment.analyzed_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
assessments = result.scalars().all()
|
||||
|
||||
history = []
|
||||
for assessment in assessments:
|
||||
history.append({
|
||||
"id": assessment.id,
|
||||
"source_type": assessment.source_type,
|
||||
"total_score": assessment.total_score,
|
||||
"ability_dimensions": assessment.ability_dimensions,
|
||||
"recommended_courses": assessment.recommended_courses,
|
||||
"conversation_count": assessment.conversation_count,
|
||||
"analyzed_at": assessment.analyzed_at.isoformat() if assessment.analyzed_at else None,
|
||||
"created_at": assessment.created_at.isoformat() if assessment.created_at else None
|
||||
})
|
||||
|
||||
logger.info(f"获取评估历史: user_id={user_id}, count={len(history)}")
|
||||
return history
|
||||
|
||||
async def get_assessment_detail(
|
||||
self,
|
||||
assessment_id: int,
|
||||
db: AsyncSession
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取单个评估记录的详细信息
|
||||
|
||||
Args:
|
||||
assessment_id: 评估记录ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
评估详细信息
|
||||
|
||||
Raises:
|
||||
ValueError: 评估记录不存在
|
||||
"""
|
||||
stmt = select(AbilityAssessment).where(AbilityAssessment.id == assessment_id)
|
||||
result = await db.execute(stmt)
|
||||
assessment = result.scalar_one_or_none()
|
||||
|
||||
if not assessment:
|
||||
raise ValueError(f"评估记录不存在: assessment_id={assessment_id}")
|
||||
|
||||
return {
|
||||
"id": assessment.id,
|
||||
"user_id": assessment.user_id,
|
||||
"source_type": assessment.source_type,
|
||||
"source_id": assessment.source_id,
|
||||
"total_score": assessment.total_score,
|
||||
"ability_dimensions": assessment.ability_dimensions,
|
||||
"recommended_courses": assessment.recommended_courses,
|
||||
"conversation_count": assessment.conversation_count,
|
||||
"analyzed_at": assessment.analyzed_at.isoformat() if assessment.analyzed_at else None,
|
||||
"created_at": assessment.created_at.isoformat() if assessment.created_at else None
|
||||
}
|
||||
|
||||
|
||||
def get_ability_assessment_service() -> AbilityAssessmentService:
|
||||
"""获取能力评估服务实例(依赖注入)"""
|
||||
return AbilityAssessmentService()
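# A minimal usage sketch for this factory; the endpoint path and router below are
# assumptions for illustration, not part of this module:
#
#   from fastapi import APIRouter, Depends
#
#   router = APIRouter()
#
#   @router.post("/ability/assess")
#   async def assess_ability(
#       user_id: int,
#       phone: str,
#       service: AbilityAssessmentService = Depends(get_ability_assessment_service),
#   ):
#       # db and yanji_service would come from the application's own dependencies
#       ...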
|
||||
151  backend/app/services/ai/__init__.py  Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
AI 服务模块
|
||||
|
||||
包含:
|
||||
- AIService: 本地 AI 服务(支持 4sapi + OpenRouter 降级)
|
||||
- LLM JSON Parser: 大模型 JSON 输出解析器
|
||||
- KnowledgeAnalysisServiceV2: 知识点分析服务(Python 原生实现)
|
||||
- ExamGeneratorService: 试题生成服务(Python 原生实现)
|
||||
- CourseChatServiceV2: 课程对话服务(Python 原生实现)
|
||||
- PracticeSceneService: 陪练场景准备服务(Python 原生实现)
|
||||
- AbilityAnalysisService: 智能工牌能力分析服务(Python 原生实现)
|
||||
- AnswerJudgeService: 答案判断服务(Python 原生实现)
|
||||
- PracticeAnalysisService: 陪练分析报告服务(Python 原生实现)
|
||||
"""
|
||||
|
||||
from .ai_service import (
|
||||
AIService,
|
||||
AIResponse,
|
||||
AIConfig,
|
||||
AIServiceError,
|
||||
AIProvider,
|
||||
DEFAULT_MODEL,
|
||||
MODEL_ANALYSIS,
|
||||
MODEL_CREATIVE,
|
||||
MODEL_IMAGE_GEN,
|
||||
quick_chat,
|
||||
)
|
||||
|
||||
from .llm_json_parser import (
|
||||
parse_llm_json,
|
||||
parse_with_fallback,
|
||||
safe_json_loads,
|
||||
clean_llm_output,
|
||||
diagnose_json_error,
|
||||
validate_json_schema,
|
||||
ParseResult,
|
||||
JSONParseError,
|
||||
JSONUnrecoverableError,
|
||||
)
|
||||
|
||||
from .knowledge_analysis_v2 import (
|
||||
KnowledgeAnalysisServiceV2,
|
||||
knowledge_analysis_service_v2,
|
||||
)
|
||||
|
||||
from .exam_generator_service import (
|
||||
ExamGeneratorService,
|
||||
ExamGeneratorConfig,
|
||||
exam_generator_service,
|
||||
generate_exam,
|
||||
)
|
||||
|
||||
from .course_chat_service import (
|
||||
CourseChatServiceV2,
|
||||
course_chat_service_v2,
|
||||
)
|
||||
|
||||
from .practice_scene_service import (
|
||||
PracticeSceneService,
|
||||
PracticeScene,
|
||||
PracticeSceneResult,
|
||||
practice_scene_service,
|
||||
prepare_practice_knowledge,
|
||||
)
|
||||
|
||||
from .ability_analysis_service import (
|
||||
AbilityAnalysisService,
|
||||
AbilityAnalysisResult,
|
||||
AbilityDimension,
|
||||
CourseRecommendation,
|
||||
ability_analysis_service,
|
||||
)
|
||||
|
||||
from .answer_judge_service import (
|
||||
AnswerJudgeService,
|
||||
JudgeResult,
|
||||
answer_judge_service,
|
||||
judge_answer,
|
||||
)
|
||||
|
||||
from .practice_analysis_service import (
|
||||
PracticeAnalysisService,
|
||||
PracticeAnalysisResult,
|
||||
ScoreBreakdownItem,
|
||||
AbilityDimensionItem,
|
||||
DialogueAnnotation,
|
||||
Suggestion,
|
||||
practice_analysis_service,
|
||||
analyze_practice_session,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# AI Service
|
||||
"AIService",
|
||||
"AIResponse",
|
||||
"AIConfig",
|
||||
"AIServiceError",
|
||||
"AIProvider",
|
||||
"DEFAULT_MODEL",
|
||||
"MODEL_ANALYSIS",
|
||||
"MODEL_CREATIVE",
|
||||
"MODEL_IMAGE_GEN",
|
||||
"quick_chat",
|
||||
# JSON Parser
|
||||
"parse_llm_json",
|
||||
"parse_with_fallback",
|
||||
"safe_json_loads",
|
||||
"clean_llm_output",
|
||||
"diagnose_json_error",
|
||||
"validate_json_schema",
|
||||
"ParseResult",
|
||||
"JSONParseError",
|
||||
"JSONUnrecoverableError",
|
||||
# Knowledge Analysis V2
|
||||
"KnowledgeAnalysisServiceV2",
|
||||
"knowledge_analysis_service_v2",
|
||||
# Exam Generator V2
|
||||
"ExamGeneratorService",
|
||||
"ExamGeneratorConfig",
|
||||
"exam_generator_service",
|
||||
"generate_exam",
|
||||
# Course Chat V2
|
||||
"CourseChatServiceV2",
|
||||
"course_chat_service_v2",
|
||||
# Practice Scene V2
|
||||
"PracticeSceneService",
|
||||
"PracticeScene",
|
||||
"PracticeSceneResult",
|
||||
"practice_scene_service",
|
||||
"prepare_practice_knowledge",
|
||||
# Ability Analysis V2
|
||||
"AbilityAnalysisService",
|
||||
"AbilityAnalysisResult",
|
||||
"AbilityDimension",
|
||||
"CourseRecommendation",
|
||||
"ability_analysis_service",
|
||||
# Answer Judge V2
|
||||
"AnswerJudgeService",
|
||||
"JudgeResult",
|
||||
"answer_judge_service",
|
||||
"judge_answer",
|
||||
# Practice Analysis V2
|
||||
"PracticeAnalysisService",
|
||||
"PracticeAnalysisResult",
|
||||
"ScoreBreakdownItem",
|
||||
"AbilityDimensionItem",
|
||||
"DialogueAnnotation",
|
||||
"Suggestion",
|
||||
"practice_analysis_service",
|
||||
"analyze_practice_session",
|
||||
]
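# A minimal sketch of how other modules are expected to consume this package,
# based on the exports listed above:
#
#   from app.services.ai import AIService, parse_llm_json, ability_analysis_service
#
#   ai = AIService(module_code="my_module")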
|
||||
479  backend/app/services/ai/ability_analysis_service.py  Normal file
@@ -0,0 +1,479 @@
|
||||
"""
|
||||
智能工牌能力分析与课程推荐服务 - Python 原生实现
|
||||
|
||||
功能:
|
||||
- 分析员工与顾客的对话记录
|
||||
- 评估多维度能力得分
|
||||
- 基于能力短板推荐课程
|
||||
|
||||
提供稳定可靠的能力分析和课程推荐能力。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import ExternalServiceError
|
||||
|
||||
from .ai_service import AIService, AIResponse
|
||||
from .llm_json_parser import parse_with_fallback, clean_llm_output
|
||||
from .prompts.ability_analysis_prompts import (
|
||||
SYSTEM_PROMPT,
|
||||
USER_PROMPT,
|
||||
ABILITY_ANALYSIS_SCHEMA,
|
||||
ABILITY_DIMENSIONS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ==================== 数据结构 ====================
|
||||
|
||||
@dataclass
|
||||
class AbilityDimension:
|
||||
"""能力维度评分"""
|
||||
name: str
|
||||
score: float
|
||||
feedback: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CourseRecommendation:
|
||||
"""课程推荐"""
|
||||
course_id: int
|
||||
course_name: str
|
||||
recommendation_reason: str
|
||||
priority: str # high, medium, low
|
||||
match_score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbilityAnalysisResult:
|
||||
"""能力分析结果"""
|
||||
success: bool
|
||||
total_score: float = 0.0
|
||||
ability_dimensions: List[AbilityDimension] = field(default_factory=list)
|
||||
course_recommendations: List[CourseRecommendation] = field(default_factory=list)
|
||||
ai_provider: str = ""
|
||||
ai_model: str = ""
|
||||
ai_tokens: int = 0
|
||||
ai_latency_ms: int = 0
|
||||
error: str = ""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"success": self.success,
|
||||
"total_score": self.total_score,
|
||||
"ability_dimensions": [
|
||||
{"name": d.name, "score": d.score, "feedback": d.feedback}
|
||||
for d in self.ability_dimensions
|
||||
],
|
||||
"course_recommendations": [
|
||||
{
|
||||
"course_id": c.course_id,
|
||||
"course_name": c.course_name,
|
||||
"recommendation_reason": c.recommendation_reason,
|
||||
"priority": c.priority,
|
||||
"match_score": c.match_score,
|
||||
}
|
||||
for c in self.course_recommendations
|
||||
],
|
||||
"ai_provider": self.ai_provider,
|
||||
"ai_model": self.ai_model,
|
||||
"ai_tokens": self.ai_tokens,
|
||||
"ai_latency_ms": self.ai_latency_ms,
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserPositionInfo:
|
||||
"""用户岗位信息"""
|
||||
position_id: int
|
||||
position_name: str
|
||||
code: str
|
||||
description: str
|
||||
skills: Optional[Dict[str, Any]]
|
||||
level: str
|
||||
status: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CourseInfo:
|
||||
"""课程信息"""
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
category: str
|
||||
tags: Optional[List[str]]
|
||||
difficulty_level: int
|
||||
duration_hours: float
|
||||
|
||||
|
||||
# ==================== 服务类 ====================
|
||||
|
||||
class AbilityAnalysisService:
|
||||
"""
|
||||
智能工牌能力分析服务
|
||||
|
||||
使用 Python 原生实现。
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
service = AbilityAnalysisService()
|
||||
result = await service.analyze(
|
||||
db=db_session,
|
||||
user_id=1,
|
||||
dialogue_history="顾客:你好,我想了解一下你们的服务..."
|
||||
)
|
||||
print(result.total_score)
|
||||
print(result.course_recommendations)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务"""
|
||||
self.ai_service = AIService(module_code="ability_analysis")
|
||||
|
||||
async def analyze(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
dialogue_history: str
|
||||
) -> AbilityAnalysisResult:
|
||||
"""
|
||||
分析员工能力并推荐课程
|
||||
|
||||
Args:
|
||||
db: 数据库会话(支持多租户,每个租户传入各自的会话)
|
||||
user_id: 用户ID
|
||||
dialogue_history: 对话记录
|
||||
|
||||
Returns:
|
||||
AbilityAnalysisResult 分析结果
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始能力分析 - user_id: {user_id}")
|
||||
|
||||
# 1. 验证输入
|
||||
if not dialogue_history or not dialogue_history.strip():
|
||||
return AbilityAnalysisResult(
|
||||
success=False,
|
||||
error="对话记录不能为空"
|
||||
)
|
||||
|
||||
# 2. 查询用户岗位信息
|
||||
user_positions = await self._get_user_positions(db, user_id)
|
||||
user_info_str = self._format_user_info(user_positions)
|
||||
|
||||
logger.info(f"用户岗位信息: {len(user_positions)} 个岗位")
|
||||
|
||||
# 3. 查询所有可选课程
|
||||
courses = await self._get_published_courses(db)
|
||||
courses_str = self._format_courses(courses)
|
||||
|
||||
logger.info(f"可选课程: {len(courses)} 门")
|
||||
|
||||
# 4. 调用 AI 分析
|
||||
ai_response = await self._call_ai_analysis(
|
||||
dialogue_history=dialogue_history,
|
||||
user_info=user_info_str,
|
||||
courses=courses_str
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"AI 分析完成 - provider: {ai_response.provider}, "
|
||||
f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
|
||||
)
|
||||
|
||||
# 5. 解析 JSON 结果
|
||||
analysis_data = self._parse_analysis_result(ai_response.content, courses)
|
||||
|
||||
# 6. 构建返回结果
|
||||
result = AbilityAnalysisResult(
|
||||
success=True,
|
||||
total_score=analysis_data.get("total_score", 0),
|
||||
ability_dimensions=[
|
||||
AbilityDimension(
|
||||
name=d.get("name", ""),
|
||||
score=d.get("score", 0),
|
||||
feedback=d.get("feedback", "")
|
||||
)
|
||||
for d in analysis_data.get("ability_dimensions", [])
|
||||
],
|
||||
course_recommendations=[
|
||||
CourseRecommendation(
|
||||
course_id=c.get("course_id", 0),
|
||||
course_name=c.get("course_name", ""),
|
||||
recommendation_reason=c.get("recommendation_reason", ""),
|
||||
priority=c.get("priority", "medium"),
|
||||
match_score=c.get("match_score", 0)
|
||||
)
|
||||
for c in analysis_data.get("course_recommendations", [])
|
||||
],
|
||||
ai_provider=ai_response.provider,
|
||||
ai_model=ai_response.model,
|
||||
ai_tokens=ai_response.total_tokens,
|
||||
ai_latency_ms=ai_response.latency_ms,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"能力分析完成 - user_id: {user_id}, total_score: {result.total_score}, "
|
||||
f"recommendations: {len(result.course_recommendations)}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"能力分析失败 - user_id: {user_id}, error: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return AbilityAnalysisResult(
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
async def _get_user_positions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: int
|
||||
) -> List[UserPositionInfo]:
|
||||
"""
|
||||
查询用户的岗位信息
|
||||
|
||||
获取用户基本信息
|
||||
"""
|
||||
query = text("""
|
||||
SELECT
|
||||
p.id as position_id,
|
||||
p.name as position_name,
|
||||
p.code,
|
||||
p.description,
|
||||
p.skills,
|
||||
p.level,
|
||||
p.status
|
||||
FROM positions p
|
||||
INNER JOIN position_members pm ON p.id = pm.position_id
|
||||
WHERE pm.user_id = :user_id
|
||||
AND pm.is_deleted = 0
|
||||
AND p.is_deleted = 0
|
||||
""")
|
||||
|
||||
result = await db.execute(query, {"user_id": user_id})
|
||||
rows = result.fetchall()
|
||||
|
||||
positions = []
|
||||
for row in rows:
|
||||
# 解析 skills JSON
|
||||
skills = None
|
||||
if row.skills:
|
||||
if isinstance(row.skills, str):
|
||||
try:
|
||||
skills = json.loads(row.skills)
|
||||
except json.JSONDecodeError:
|
||||
skills = None
|
||||
else:
|
||||
skills = row.skills
|
||||
|
||||
positions.append(UserPositionInfo(
|
||||
position_id=row.position_id,
|
||||
position_name=row.position_name,
|
||||
code=row.code or "",
|
||||
description=row.description or "",
|
||||
skills=skills,
|
||||
level=row.level or "",
|
||||
status=row.status or ""
|
||||
))
|
||||
|
||||
return positions
|
||||
|
||||
async def _get_published_courses(self, db: AsyncSession) -> List[CourseInfo]:
|
||||
"""
|
||||
查询所有已发布的课程
|
||||
|
||||
获取所有课程列表
|
||||
"""
|
||||
query = text("""
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
description,
|
||||
category,
|
||||
tags,
|
||||
difficulty_level,
|
||||
duration_hours
|
||||
FROM courses
|
||||
WHERE status = 'published'
|
||||
AND is_deleted = FALSE
|
||||
ORDER BY sort_order
|
||||
""")
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.fetchall()
|
||||
|
||||
courses = []
|
||||
for row in rows:
|
||||
# 解析 tags JSON
|
||||
tags = None
|
||||
if row.tags:
|
||||
if isinstance(row.tags, str):
|
||||
try:
|
||||
tags = json.loads(row.tags)
|
||||
except json.JSONDecodeError:
|
||||
tags = None
|
||||
else:
|
||||
tags = row.tags
|
||||
|
||||
courses.append(CourseInfo(
|
||||
id=row.id,
|
||||
name=row.name,
|
||||
description=row.description or "",
|
||||
category=row.category or "",
|
||||
tags=tags,
|
||||
difficulty_level=row.difficulty_level or 3,
|
||||
duration_hours=row.duration_hours or 0
|
||||
))
|
||||
|
||||
return courses
|
||||
|
||||
def _format_user_info(self, positions: List[UserPositionInfo]) -> str:
|
||||
"""格式化用户岗位信息为文本"""
|
||||
if not positions:
|
||||
return "暂无岗位信息"
|
||||
|
||||
lines = []
|
||||
for p in positions:
|
||||
info = f"- 岗位:{p.position_name}({p.code})"
|
||||
if p.level:
|
||||
info += f",级别:{p.level}"
|
||||
if p.description:
|
||||
info += f"\n 描述:{p.description}"
|
||||
if p.skills:
|
||||
skills_str = json.dumps(p.skills, ensure_ascii=False)
|
||||
info += f"\n 核心技能:{skills_str}"
|
||||
lines.append(info)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _format_courses(self, courses: List[CourseInfo]) -> str:
|
||||
"""格式化课程列表为文本"""
|
||||
if not courses:
|
||||
return "暂无可选课程"
|
||||
|
||||
lines = []
|
||||
for c in courses:
|
||||
info = f"- ID: {c.id}, 课程名称: {c.name}"
|
||||
if c.category:
|
||||
info += f", 分类: {c.category}"
|
||||
if c.difficulty_level:
|
||||
info += f", 难度: {c.difficulty_level}"
|
||||
if c.duration_hours:
|
||||
info += f", 时长: {c.duration_hours}小时"
|
||||
if c.description:
|
||||
# 截断过长的描述
|
||||
desc = c.description[:100] + "..." if len(c.description) > 100 else c.description
|
||||
info += f"\n 描述: {desc}"
|
||||
lines.append(info)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _call_ai_analysis(
|
||||
self,
|
||||
dialogue_history: str,
|
||||
user_info: str,
|
||||
courses: str
|
||||
) -> AIResponse:
|
||||
"""调用 AI 进行能力分析"""
|
||||
# 构建用户消息
|
||||
user_message = USER_PROMPT.format(
|
||||
dialogue_history=dialogue_history,
|
||||
user_info=user_info,
|
||||
courses=courses
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_message}
|
||||
]
|
||||
|
||||
# 调用 AI(自动支持 4sapi → OpenRouter 降级)
|
||||
response = await self.ai_service.chat(
|
||||
messages=messages,
|
||||
temperature=0.7, # 保持一定创意性
|
||||
prompt_name="ability_analysis"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _parse_analysis_result(
|
||||
self,
|
||||
ai_output: str,
|
||||
courses: List[CourseInfo]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
解析 AI 输出的分析结果 JSON
|
||||
|
||||
使用 LLM JSON Parser 进行多层兜底解析
|
||||
"""
|
||||
# 先清洗输出
|
||||
cleaned_output, rules = clean_llm_output(ai_output)
|
||||
if rules:
|
||||
logger.debug(f"AI 输出已清洗: {rules}")
|
||||
|
||||
# 使用带 Schema 校验的解析
|
||||
parsed = parse_with_fallback(
|
||||
cleaned_output,
|
||||
schema=ABILITY_ANALYSIS_SCHEMA,
|
||||
default={"analysis": {}},
|
||||
validate_schema=True,
|
||||
on_error="default"
|
||||
)
|
||||
|
||||
# 提取 analysis 部分
|
||||
analysis = parsed.get("analysis", {})
|
||||
|
||||
# 后处理:验证课程推荐的有效性
|
||||
valid_course_ids = {c.id for c in courses}
|
||||
valid_recommendations = []
|
||||
|
||||
for rec in analysis.get("course_recommendations", []):
|
||||
course_id = rec.get("course_id")
|
||||
if course_id in valid_course_ids:
|
||||
valid_recommendations.append(rec)
|
||||
else:
|
||||
logger.warning(f"推荐的课程ID不存在: {course_id}")
|
||||
|
||||
analysis["course_recommendations"] = valid_recommendations
|
||||
|
||||
# 确保能力维度完整
|
||||
existing_dims = {d.get("name") for d in analysis.get("ability_dimensions", [])}
|
||||
for dim_name in ABILITY_DIMENSIONS:
|
||||
if dim_name not in existing_dims:
|
||||
logger.warning(f"缺少能力维度: {dim_name},使用默认值")
|
||||
analysis.setdefault("ability_dimensions", []).append({
|
||||
"name": dim_name,
|
||||
"score": 70,
|
||||
"feedback": "暂无具体评价"
|
||||
})
|
||||
|
||||
return analysis
|
||||
|
||||
|
||||
# ==================== 全局实例 ====================
|
||||
|
||||
ability_analysis_service = AbilityAnalysisService()
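# A minimal calling sketch: analyze() reports failures through the result object
# (success=False plus error) instead of raising, so callers should check `success`:
#
#   result = await ability_analysis_service.analyze(
#       db=db, user_id=1, dialogue_history="顾客: ...\n员工: ..."
#   )
#   if not result.success:
#       raise RuntimeError(result.error)
#   payload = result.to_dict()  # JSON-serializable dict for API responses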
747  backend/app/services/ai/ai_service.py  Normal file
@@ -0,0 +1,747 @@
|
||||
"""
|
||||
本地 AI 服务 - 遵循瑞小美 AI 接入规范
|
||||
|
||||
功能:
|
||||
- 支持 4sapi.com(首选)和 OpenRouter(备选)自动降级
|
||||
- 统一的请求/响应格式
|
||||
- 调用日志记录
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIProvider(Enum):
|
||||
"""AI 服务商"""
|
||||
PRIMARY = "4sapi" # 首选:4sapi.com
|
||||
FALLBACK = "openrouter" # 备选:OpenRouter
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIResponse:
|
||||
"""AI 响应结果"""
|
||||
content: str # AI 回复内容
|
||||
model: str = "" # 使用的模型
|
||||
provider: str = "" # 实际使用的服务商
|
||||
input_tokens: int = 0 # 输入 token 数
|
||||
output_tokens: int = 0 # 输出 token 数
|
||||
total_tokens: int = 0 # 总 token 数
|
||||
cost: float = 0.0 # 费用(美元)
|
||||
latency_ms: int = 0 # 响应延迟(毫秒)
|
||||
raw_response: Dict[str, Any] = field(default_factory=dict) # 原始响应
|
||||
images: List[str] = field(default_factory=list) # 图像生成结果
|
||||
annotations: Dict[str, Any] = field(default_factory=dict) # PDF 解析注释
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIConfig:
|
||||
"""AI 服务配置"""
|
||||
primary_api_key: str # 通用 Key(Gemini/DeepSeek 等)
|
||||
anthropic_api_key: str = "" # Claude 专属 Key
|
||||
primary_base_url: str = "https://4sapi.com/v1"
|
||||
fallback_api_key: str = ""
|
||||
fallback_base_url: str = "https://openrouter.ai/api/v1"
|
||||
default_model: str = "claude-opus-4-5-20251101-thinking" # 默认使用最强模型
|
||||
timeout: float = 120.0
|
||||
max_retries: int = 2
|
||||
|
||||
|
||||
# Claude 模型列表(需要使用 anthropic_api_key)
|
||||
CLAUDE_MODELS = [
|
||||
"claude-opus-4-5-20251101-thinking",
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-3-opus",
|
||||
"claude-3-sonnet",
|
||||
"claude-3-haiku",
|
||||
]
|
||||
|
||||
|
||||
def is_claude_model(model: str) -> bool:
|
||||
"""判断是否为 Claude 模型"""
|
||||
model_lower = model.lower()
|
||||
return any(claude in model_lower for claude in ["claude", "anthropic"])
|
||||
|
||||
|
||||
# 模型名称映射:4sapi -> OpenRouter
|
||||
MODEL_MAPPING = {
|
||||
# 4sapi 使用简短名称,OpenRouter 使用完整路径
|
||||
"gemini-3-flash-preview": "google/gemini-3-flash-preview",
|
||||
"gemini-3-pro-preview": "google/gemini-3-pro-preview",
|
||||
"claude-opus-4-5-20251101-thinking": "anthropic/claude-opus-4.5",
|
||||
"gemini-2.5-flash-image-preview": "google/gemini-2.0-flash-exp:free",
|
||||
}
|
||||
|
||||
# 反向映射:OpenRouter -> 4sapi
|
||||
MODEL_MAPPING_REVERSE = {v: k for k, v in MODEL_MAPPING.items()}
|
||||
|
||||
|
||||
class AIServiceError(Exception):
|
||||
"""AI 服务错误"""
|
||||
def __init__(self, message: str, provider: str = "", status_code: int = 0):
|
||||
super().__init__(message)
|
||||
self.provider = provider
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class AIService:
|
||||
"""
|
||||
本地 AI 服务
|
||||
|
||||
遵循瑞小美 AI 接入规范:
|
||||
- 首选 4sapi.com,失败自动降级到 OpenRouter
|
||||
- 统一的响应格式
|
||||
- 自动模型名称转换
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
ai = AIService(module_code="knowledge_analysis")
|
||||
response = await ai.chat(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是助手"},
|
||||
{"role": "user", "content": "你好"}
|
||||
],
|
||||
prompt_name="greeting"
|
||||
)
|
||||
print(response.content)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module_code: str = "default",
|
||||
config: Optional[AIConfig] = None,
|
||||
db_session: Any = None
|
||||
):
|
||||
"""
|
||||
初始化 AI 服务
|
||||
|
||||
配置加载优先级(遵循瑞小美 AI 接入规范):
|
||||
1. 显式传入的 config 参数
|
||||
2. 数据库 ai_config 表(推荐)
|
||||
3. 环境变量(fallback)
|
||||
|
||||
Args:
|
||||
module_code: 模块标识,用于统计
|
||||
config: AI 配置,None 则从数据库/环境变量读取
|
||||
db_session: 数据库会话,用于记录调用日志和读取配置
|
||||
"""
|
||||
self.module_code = module_code
|
||||
self.db_session = db_session
|
||||
self.config = config or self._load_config(db_session)
|
||||
|
||||
logger.info(f"AIService 初始化: module={module_code}, primary={self.config.primary_base_url}")
|
||||
|
||||
def _load_config(self, db_session: Any) -> AIConfig:
|
||||
"""
|
||||
加载配置
|
||||
|
||||
配置加载优先级(遵循瑞小美 AI 接入规范):
|
||||
1. 管理库 tenant_configs 表(推荐,通过 DynamicConfig)
|
||||
2. 环境变量(fallback)
|
||||
|
||||
Args:
|
||||
db_session: 数据库会话(可选,用于日志记录)
|
||||
|
||||
Returns:
|
||||
AIConfig 配置对象
|
||||
"""
|
||||
# 优先从管理库加载(同步方式)
|
||||
try:
|
||||
config = self._load_config_from_admin_db()
|
||||
if config:
|
||||
logger.info("✅ AI 配置已从管理库(tenant_configs)加载")
|
||||
return config
|
||||
except Exception as e:
|
||||
logger.debug(f"从管理库加载 AI 配置失败: {e}")
|
||||
|
||||
# Fallback 到环境变量
|
||||
logger.info("AI 配置从环境变量加载")
|
||||
return self._load_config_from_env()
|
||||
|
||||
def _load_config_from_admin_db(self) -> Optional[AIConfig]:
|
||||
"""
|
||||
从管理库 tenant_configs 表加载配置
|
||||
|
||||
使用同步方式直接查询 kaopeilian_admin.tenant_configs 表
|
||||
|
||||
Returns:
|
||||
AIConfig 配置对象,如果无数据则返回 None
|
||||
"""
|
||||
import os
|
||||
|
||||
# 获取当前租户编码
|
||||
tenant_code = os.getenv("TENANT_CODE", "demo")
|
||||
|
||||
# 获取管理库连接信息
|
||||
admin_db_host = os.getenv("ADMIN_DB_HOST", "prod-mysql")
|
||||
admin_db_port = int(os.getenv("ADMIN_DB_PORT", "3306"))
|
||||
admin_db_user = os.getenv("ADMIN_DB_USER", "root")
|
||||
admin_db_password = os.getenv("ADMIN_DB_PASSWORD", "")
|
||||
admin_db_name = os.getenv("ADMIN_DB_NAME", "kaopeilian_admin")
|
||||
|
||||
if not admin_db_password:
|
||||
logger.debug("ADMIN_DB_PASSWORD 未配置,跳过管理库配置加载")
|
||||
return None
|
||||
|
||||
try:
|
||||
from sqlalchemy import create_engine, text
|
||||
import urllib.parse
|
||||
|
||||
# 构建连接 URL
|
||||
encoded_password = urllib.parse.quote_plus(admin_db_password)
|
||||
admin_db_url = f"mysql+pymysql://{admin_db_user}:{encoded_password}@{admin_db_host}:{admin_db_port}/{admin_db_name}?charset=utf8mb4"
|
||||
|
||||
engine = create_engine(admin_db_url, pool_pre_ping=True)
|
||||
|
||||
with engine.connect() as conn:
|
||||
# 1. 获取租户 ID
|
||||
result = conn.execute(
|
||||
text("SELECT id FROM tenants WHERE code = :code AND status = 'active'"),
|
||||
{"code": tenant_code}
|
||||
)
|
||||
row = result.fetchone()
|
||||
if not row:
|
||||
logger.debug(f"租户 {tenant_code} 不存在或未激活")
|
||||
engine.dispose()
|
||||
return None
|
||||
|
||||
tenant_id = row[0]
|
||||
|
||||
# 2. 获取 AI 配置
|
||||
result = conn.execute(
|
||||
text("""
|
||||
SELECT config_key, config_value
|
||||
FROM tenant_configs
|
||||
WHERE tenant_id = :tenant_id AND config_group = 'ai'
|
||||
"""),
|
||||
{"tenant_id": tenant_id}
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
engine.dispose()
|
||||
|
||||
if not rows:
|
||||
logger.debug(f"租户 {tenant_code} 无 AI 配置")
|
||||
return None
|
||||
|
||||
# 转换为字典
|
||||
config_dict = {row[0]: row[1] for row in rows}
|
||||
|
||||
# 检查必要的配置是否存在
|
||||
primary_key = config_dict.get("AI_PRIMARY_API_KEY", "")
|
||||
if not primary_key:
|
||||
logger.warning(f"租户 {tenant_code} 的 AI_PRIMARY_API_KEY 为空")
|
||||
return None
|
||||
|
||||
logger.info(f"✅ 从管理库加载租户 {tenant_code} 的 AI 配置成功")
|
||||
|
||||
return AIConfig(
|
||||
primary_api_key=primary_key,
|
||||
anthropic_api_key=config_dict.get("AI_ANTHROPIC_API_KEY", ""),
|
||||
primary_base_url=config_dict.get("AI_PRIMARY_BASE_URL", "https://4sapi.com/v1"),
|
||||
fallback_api_key=config_dict.get("AI_FALLBACK_API_KEY", ""),
|
||||
fallback_base_url=config_dict.get("AI_FALLBACK_BASE_URL", "https://openrouter.ai/api/v1"),
|
||||
default_model=config_dict.get("AI_DEFAULT_MODEL", "claude-opus-4-5-20251101-thinking"),
|
||||
timeout=float(config_dict.get("AI_TIMEOUT", "120")),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"从管理库读取 AI 配置异常: {e}")
|
||||
return None
|
||||
|
||||
def _load_config_from_env(self) -> AIConfig:
|
||||
"""
|
||||
从环境变量加载配置
|
||||
|
||||
⚠️ 强制要求(遵循瑞小美 AI 接入规范):
|
||||
- 禁止在代码中硬编码 API Key
|
||||
- 必须通过环境变量配置 Key
|
||||
|
||||
必须配置的环境变量:
|
||||
- AI_PRIMARY_API_KEY: 通用 Key(用于 Gemini/DeepSeek 等)
|
||||
- AI_ANTHROPIC_API_KEY: Claude 专属 Key
|
||||
"""
|
||||
import os
|
||||
|
||||
primary_api_key = os.getenv("AI_PRIMARY_API_KEY", "")
|
||||
anthropic_api_key = os.getenv("AI_ANTHROPIC_API_KEY", "")
|
||||
|
||||
# 检查必要的 Key 是否已配置
|
||||
if not primary_api_key:
|
||||
logger.warning("⚠️ AI_PRIMARY_API_KEY 未配置,AI 服务可能无法正常工作")
|
||||
if not anthropic_api_key:
|
||||
logger.warning("⚠️ AI_ANTHROPIC_API_KEY 未配置,Claude 模型调用将失败")
|
||||
|
||||
return AIConfig(
|
||||
# 通用 Key(Gemini/DeepSeek 等非 Anthropic 模型)
|
||||
primary_api_key=primary_api_key,
|
||||
# Claude 专属 Key
|
||||
anthropic_api_key=anthropic_api_key,
|
||||
primary_base_url=os.getenv("AI_PRIMARY_BASE_URL", "https://4sapi.com/v1"),
|
||||
fallback_api_key=os.getenv("AI_FALLBACK_API_KEY", ""),
|
||||
fallback_base_url=os.getenv("AI_FALLBACK_BASE_URL", "https://openrouter.ai/api/v1"),
|
||||
# 默认模型:遵循"优先最强"原则,使用 Claude Opus 4.5
|
||||
default_model=os.getenv("AI_DEFAULT_MODEL", "claude-opus-4-5-20251101-thinking"),
|
||||
timeout=float(os.getenv("AI_TIMEOUT", "120")),
|
||||
)
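# A minimal .env sketch for the variables read above (values are placeholders):
#
#   AI_PRIMARY_API_KEY=sk-xxxx           # general key (Gemini/DeepSeek, etc.)
#   AI_ANTHROPIC_API_KEY=sk-ant-xxxx     # Claude-only key
#   AI_PRIMARY_BASE_URL=https://4sapi.com/v1
#   AI_FALLBACK_API_KEY=sk-or-xxxx       # optional, enables the OpenRouter fallback
#   AI_FALLBACK_BASE_URL=https://openrouter.ai/api/v1
#   AI_DEFAULT_MODEL=claude-opus-4-5-20251101-thinking
#   AI_TIMEOUT=120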
|
||||
|
||||
def _convert_model_name(self, model: str, provider: AIProvider) -> str:
|
||||
"""
|
||||
转换模型名称以匹配服务商格式
|
||||
|
||||
Args:
|
||||
model: 原始模型名称
|
||||
provider: 目标服务商
|
||||
|
||||
Returns:
|
||||
转换后的模型名称
|
||||
"""
|
||||
if provider == AIProvider.FALLBACK:
|
||||
# 4sapi -> OpenRouter
|
||||
return MODEL_MAPPING.get(model, f"google/{model}" if "/" not in model else model)
|
||||
else:
|
||||
# OpenRouter -> 4sapi
|
||||
return MODEL_MAPPING_REVERSE.get(model, model.split("/")[-1] if "/" in model else model)
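# Expected behaviour of _convert_model_name, following MODEL_MAPPING above:
#
#   _convert_model_name("gemini-3-pro-preview", AIProvider.FALLBACK)
#       -> "google/gemini-3-pro-preview"
#   _convert_model_name("anthropic/claude-opus-4.5", AIProvider.PRIMARY)
#       -> "claude-opus-4-5-20251101-thinking"
#   # unmapped names fall back to adding or stripping the vendor prefix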
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
prompt_name: str = "default",
|
||||
**kwargs
|
||||
) -> AIResponse:
|
||||
"""
|
||||
文本聊天
|
||||
|
||||
Args:
|
||||
messages: 消息列表 [{"role": "system/user/assistant", "content": "..."}]
|
||||
model: 模型名称,None 使用默认模型
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大输出 token 数
|
||||
prompt_name: 提示词名称,用于统计
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
AIResponse 响应对象
|
||||
"""
|
||||
model = model or self.config.default_model
|
||||
|
||||
# 构建请求体
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if max_tokens:
|
||||
payload["max_tokens"] = max_tokens
|
||||
|
||||
# 首选服务商
|
||||
try:
|
||||
return await self._call_provider(
|
||||
provider=AIProvider.PRIMARY,
|
||||
endpoint="/chat/completions",
|
||||
payload=payload,
|
||||
prompt_name=prompt_name
|
||||
)
|
||||
except AIServiceError as e:
|
||||
logger.warning(f"首选服务商调用失败: {e}, 尝试降级到备选服务商")
|
||||
|
||||
# 如果没有备选 API Key,直接抛出异常
|
||||
if not self.config.fallback_api_key:
|
||||
raise
|
||||
|
||||
# 降级到备选服务商
|
||||
# 转换模型名称
|
||||
fallback_model = self._convert_model_name(model, AIProvider.FALLBACK)
|
||||
payload["model"] = fallback_model
|
||||
|
||||
return await self._call_provider(
|
||||
provider=AIProvider.FALLBACK,
|
||||
endpoint="/chat/completions",
|
||||
payload=payload,
|
||||
prompt_name=prompt_name
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
prompt_name: str = "default",
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
流式文本聊天
|
||||
|
||||
Args:
|
||||
messages: 消息列表 [{"role": "system/user/assistant", "content": "..."}]
|
||||
model: 模型名称,None 使用默认模型
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大输出 token 数
|
||||
prompt_name: 提示词名称,用于统计
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
str: 文本块(逐字返回)
|
||||
"""
|
||||
model = model or self.config.default_model
|
||||
|
||||
# 构建请求体
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"stream": True,
|
||||
}
|
||||
if max_tokens:
|
||||
payload["max_tokens"] = max_tokens
|
||||
|
||||
# 首选服务商
|
||||
try:
|
||||
async for chunk in self._call_provider_stream(
|
||||
provider=AIProvider.PRIMARY,
|
||||
endpoint="/chat/completions",
|
||||
payload=payload,
|
||||
prompt_name=prompt_name
|
||||
):
|
||||
yield chunk
|
||||
return
|
||||
except AIServiceError as e:
|
||||
logger.warning(f"首选服务商流式调用失败: {e}, 尝试降级到备选服务商")
|
||||
|
||||
# 如果没有备选 API Key,直接抛出异常
|
||||
if not self.config.fallback_api_key:
|
||||
raise
|
||||
|
||||
# 降级到备选服务商
|
||||
# 转换模型名称
|
||||
fallback_model = self._convert_model_name(model, AIProvider.FALLBACK)
|
||||
payload["model"] = fallback_model
|
||||
|
||||
async for chunk in self._call_provider_stream(
|
||||
provider=AIProvider.FALLBACK,
|
||||
endpoint="/chat/completions",
|
||||
payload=payload,
|
||||
prompt_name=prompt_name
|
||||
):
|
||||
yield chunk
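# A minimal consumption sketch for chat_stream (SSE/FastAPI wiring omitted):
#
#   ai = AIService(module_code="demo")
#   async for chunk in ai.chat_stream(
#       messages=[{"role": "user", "content": "你好"}],
#   ):
#       print(chunk, end="", flush=True)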
|
||||
|
||||
async def _call_provider_stream(
|
||||
self,
|
||||
provider: AIProvider,
|
||||
endpoint: str,
|
||||
payload: Dict[str, Any],
|
||||
prompt_name: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
流式调用指定服务商
|
||||
|
||||
Args:
|
||||
provider: 服务商
|
||||
endpoint: API 端点
|
||||
payload: 请求体
|
||||
prompt_name: 提示词名称
|
||||
|
||||
Yields:
|
||||
str: 文本块
|
||||
"""
|
||||
# 获取配置
|
||||
if provider == AIProvider.PRIMARY:
|
||||
base_url = self.config.primary_base_url
|
||||
# 根据模型选择 API Key:Claude 用专属 Key,其他用通用 Key
|
||||
model = payload.get("model", "")
|
||||
if is_claude_model(model) and self.config.anthropic_api_key:
|
||||
api_key = self.config.anthropic_api_key
|
||||
logger.debug(f"[Stream] 使用 Claude 专属 Key 调用模型: {model}")
|
||||
else:
|
||||
api_key = self.config.primary_api_key
|
||||
else:
|
||||
api_key = self.config.fallback_api_key
|
||||
base_url = self.config.fallback_base_url
|
||||
|
||||
url = f"{base_url.rstrip('/')}{endpoint}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# OpenRouter 需要额外的 header
|
||||
if provider == AIProvider.FALLBACK:
|
||||
headers["HTTP-Referer"] = "https://kaopeilian.ireborn.com.cn"
|
||||
headers["X-Title"] = "KaoPeiLian"
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
timeout = httpx.Timeout(self.config.timeout, connect=10.0)
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
logger.info(f"流式调用 AI 服务: provider={provider.value}, model={payload.get('model')}")
|
||||
|
||||
async with client.stream("POST", url, json=payload, headers=headers) as response:
|
||||
# 检查响应状态
|
||||
if response.status_code != 200:
|
||||
error_text = await response.aread()
|
||||
logger.error(f"AI 服务流式返回错误: status={response.status_code}, body={error_text[:500]}")
|
||||
raise AIServiceError(
|
||||
f"API 流式请求失败: HTTP {response.status_code}",
|
||||
provider=provider.value,
|
||||
status_code=response.status_code
|
||||
)
|
||||
|
||||
# 处理 SSE 流
|
||||
async for line in response.aiter_lines():
|
||||
if not line or not line.strip():
|
||||
continue
|
||||
|
||||
# 解析 SSE 数据行
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:] # 移除 "data: " 前缀
|
||||
|
||||
# 检查是否是结束标记
|
||||
if data_str.strip() == "[DONE]":
|
||||
logger.info(f"流式响应完成: provider={provider.value}")
|
||||
return
|
||||
|
||||
try:
|
||||
event_data = json.loads(data_str)
|
||||
|
||||
# 提取 delta 内容
|
||||
choices = event_data.get("choices", [])
|
||||
if choices:
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.debug(f"解析流式数据失败: {e} - 数据: {data_str[:100]}")
|
||||
continue
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
logger.info(f"流式调用完成: provider={provider.value}, latency={latency_ms}ms")
|
||||
|
||||
except httpx.TimeoutException:
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(f"AI 服务流式超时: provider={provider.value}, latency={latency_ms}ms")
|
||||
raise AIServiceError(f"流式请求超时({self.config.timeout}秒)", provider=provider.value)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"AI 服务流式网络错误: provider={provider.value}, error={e}")
|
||||
raise AIServiceError(f"流式网络错误: {e}", provider=provider.value)
|
||||
|
||||
async def _call_provider(
|
||||
self,
|
||||
provider: AIProvider,
|
||||
endpoint: str,
|
||||
payload: Dict[str, Any],
|
||||
prompt_name: str
|
||||
) -> AIResponse:
|
||||
"""
|
||||
调用指定服务商
|
||||
|
||||
Args:
|
||||
provider: 服务商
|
||||
endpoint: API 端点
|
||||
payload: 请求体
|
||||
prompt_name: 提示词名称
|
||||
|
||||
Returns:
|
||||
AIResponse 响应对象
|
||||
"""
|
||||
# 获取配置
|
||||
if provider == AIProvider.PRIMARY:
|
||||
base_url = self.config.primary_base_url
|
||||
# 根据模型选择 API Key:Claude 用专属 Key,其他用通用 Key
|
||||
model = payload.get("model", "")
|
||||
if is_claude_model(model) and self.config.anthropic_api_key:
|
||||
api_key = self.config.anthropic_api_key
|
||||
logger.debug(f"使用 Claude 专属 Key 调用模型: {model}")
|
||||
else:
|
||||
api_key = self.config.primary_api_key
|
||||
else:
|
||||
api_key = self.config.fallback_api_key
|
||||
base_url = self.config.fallback_base_url
|
||||
|
||||
url = f"{base_url.rstrip('/')}{endpoint}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# OpenRouter 需要额外的 header
|
||||
if provider == AIProvider.FALLBACK:
|
||||
headers["HTTP-Referer"] = "https://kaopeilian.ireborn.com.cn"
|
||||
headers["X-Title"] = "KaoPeiLian"
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.config.timeout) as client:
|
||||
logger.info(f"调用 AI 服务: provider={provider.value}, model={payload.get('model')}")
|
||||
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code != 200:
|
||||
error_text = response.text
|
||||
logger.error(f"AI 服务返回错误: status={response.status_code}, body={error_text[:500]}")
|
||||
raise AIServiceError(
|
||||
f"API 请求失败: HTTP {response.status_code}",
|
||||
provider=provider.value,
|
||||
status_code=response.status_code
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
# 解析响应
|
||||
ai_response = self._parse_response(data, provider, latency_ms)
|
||||
|
||||
# 记录日志
|
||||
logger.info(
|
||||
f"AI 调用成功: provider={provider.value}, model={ai_response.model}, "
|
||||
f"tokens={ai_response.total_tokens}, latency={latency_ms}ms"
|
||||
)
|
||||
|
||||
# 保存到数据库(如果有 session)
|
||||
await self._log_call(prompt_name, ai_response)
|
||||
|
||||
return ai_response
|
||||
|
||||
except httpx.TimeoutException:
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(f"AI 服务超时: provider={provider.value}, latency={latency_ms}ms")
|
||||
raise AIServiceError(f"请求超时({self.config.timeout}秒)", provider=provider.value)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"AI 服务网络错误: provider={provider.value}, error={e}")
|
||||
raise AIServiceError(f"网络错误: {e}", provider=provider.value)
|
||||
|
||||
def _parse_response(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
provider: AIProvider,
|
||||
latency_ms: int
|
||||
) -> AIResponse:
|
||||
"""解析 API 响应"""
|
||||
# 提取内容
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
raise AIServiceError("响应中没有 choices")
|
||||
|
||||
message = choices[0].get("message", {})
|
||||
content = message.get("content", "")
|
||||
|
||||
# 提取 usage
|
||||
usage = data.get("usage", {})
|
||||
input_tokens = usage.get("prompt_tokens", 0)
|
||||
output_tokens = usage.get("completion_tokens", 0)
|
||||
total_tokens = usage.get("total_tokens", input_tokens + output_tokens)
|
||||
|
||||
# 提取费用(如果有)
|
||||
cost = usage.get("total_cost", 0.0)
|
||||
|
||||
return AIResponse(
|
||||
content=content,
|
||||
model=data.get("model", ""),
|
||||
provider=provider.value,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
total_tokens=total_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
raw_response=data
|
||||
)
|
||||
|
||||
async def _log_call(self, prompt_name: str, response: AIResponse) -> None:
|
||||
"""记录调用日志到数据库"""
|
||||
if not self.db_session:
|
||||
return
|
||||
|
||||
try:
|
||||
# TODO: 实现调用日志记录
|
||||
# 可以参考 ai_call_logs 表结构
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"记录 AI 调用日志失败: {e}")
|
||||
|
||||
async def analyze_document(
|
||||
self,
|
||||
content: str,
|
||||
prompt: str,
|
||||
model: Optional[str] = None,
|
||||
prompt_name: str = "document_analysis"
|
||||
) -> AIResponse:
|
||||
"""
|
||||
分析文档内容
|
||||
|
||||
Args:
|
||||
content: 文档内容
|
||||
prompt: 分析提示词
|
||||
model: 模型名称
|
||||
prompt_name: 提示词名称
|
||||
|
||||
Returns:
|
||||
AIResponse 响应对象
|
||||
"""
|
||||
messages = [
|
||||
{"role": "user", "content": f"{prompt}\n\n文档内容:\n{content}"}
|
||||
]
|
||||
|
||||
return await self.chat(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=0.1, # 文档分析使用低温度
|
||||
prompt_name=prompt_name
|
||||
)
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def quick_chat(
|
||||
messages: List[Dict[str, str]],
|
||||
model: Optional[str] = None,
|
||||
module_code: str = "quick"
|
||||
) -> str:
|
||||
"""
|
||||
快速聊天,返回纯文本
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
module_code: 模块标识
|
||||
|
||||
Returns:
|
||||
AI 回复的文本内容
|
||||
"""
|
||||
ai = AIService(module_code=module_code)
|
||||
response = await ai.chat(messages, model=model)
|
||||
return response.content
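# A minimal usage sketch for quick_chat:
#
#   import asyncio
#
#   answer = asyncio.run(quick_chat(
#       messages=[{"role": "user", "content": "用一句话介绍玻尿酸"}],
#   ))
#   print(answer)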
|
||||
|
||||
|
||||
# 模型常量(遵循瑞小美 AI 接入规范)
|
||||
# 按优先级排序:首选 > 标准 > 快速
|
||||
MODEL_PRIMARY = "claude-opus-4-5-20251101-thinking" # 🥇 首选:所有任务首先尝试
|
||||
MODEL_STANDARD = "gemini-3-pro-preview" # 🥈 标准:Claude 失败后降级
|
||||
MODEL_FAST = "gemini-3-flash-preview" # 🥉 快速:最终保底
|
||||
MODEL_IMAGE = "gemini-2.5-flash-image-preview" # 🖼️ 图像生成专用
|
||||
MODEL_VIDEO = "veo3.1-pro" # 🎬 视频生成专用
|
||||
|
||||
# 兼容旧代码的别名
|
||||
DEFAULT_MODEL = MODEL_PRIMARY # 默认使用最强模型
|
||||
MODEL_ANALYSIS = MODEL_PRIMARY
|
||||
MODEL_CREATIVE = MODEL_STANDARD
|
||||
MODEL_IMAGE_GEN = MODEL_IMAGE
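# A caller-side sketch of the priority order documented above (Claude first, then
# Gemini Pro, then Gemini Flash). This model-level fallback is an illustration and
# is separate from the provider-level 4sapi -> OpenRouter fallback inside AIService:
#
#   async def chat_with_model_fallback(ai: AIService, messages):
#       for model in (MODEL_PRIMARY, MODEL_STANDARD, MODEL_FAST):
#           try:
#               return await ai.chat(messages, model=model)
#           except AIServiceError:
#               continue
#       raise AIServiceError("所有模型均调用失败")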
|
||||
|
||||
197  backend/app/services/ai/answer_judge_service.py  Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
答案判断服务 - Python 原生实现
|
||||
|
||||
功能:
|
||||
- 判断填空题与问答题的答案是否正确
|
||||
- 通过 AI 语义理解比对用户答案与标准答案
|
||||
|
||||
提供稳定可靠的答案判断能力。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from .ai_service import AIService, AIResponse
|
||||
from .prompts.answer_judge_prompts import (
|
||||
SYSTEM_PROMPT,
|
||||
USER_PROMPT,
|
||||
CORRECT_KEYWORDS,
|
||||
INCORRECT_KEYWORDS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JudgeResult:
|
||||
"""判断结果"""
|
||||
is_correct: bool
|
||||
raw_response: str
|
||||
ai_provider: str = ""
|
||||
ai_model: str = ""
|
||||
ai_tokens: int = 0
|
||||
ai_latency_ms: int = 0
|
||||
|
||||
|
||||
class AnswerJudgeService:
|
||||
"""
|
||||
答案判断服务
|
||||
|
||||
使用 Python 原生实现。
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
service = AnswerJudgeService()
|
||||
result = await service.judge(
|
||||
db=db_session, # 传入 db_session 用于记录调用日志
|
||||
question="玻尿酸的主要作用是什么?",
|
||||
correct_answer="补水保湿、填充塑形",
|
||||
user_answer="保湿和塑形",
|
||||
analysis="玻尿酸具有补水保湿和填充塑形两大功能"
|
||||
)
|
||||
print(result.is_correct) # True
|
||||
```
|
||||
"""
|
||||
|
||||
MODULE_CODE = "answer_judge"
|
||||
|
||||
async def judge(
|
||||
self,
|
||||
question: str,
|
||||
correct_answer: str,
|
||||
user_answer: str,
|
||||
analysis: str = "",
|
||||
db: Any = None # 数据库会话,用于记录 AI 调用日志
|
||||
) -> JudgeResult:
|
||||
"""
|
||||
判断答案是否正确
|
||||
|
||||
Args:
|
||||
question: 题目内容
|
||||
correct_answer: 标准答案
|
||||
user_answer: 用户答案
|
||||
analysis: 答案解析(可选)
|
||||
db: 数据库会话,用于记录调用日志(符合 AI 接入规范)
|
||||
|
||||
Returns:
|
||||
JudgeResult 判断结果
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"开始判断答案 - question: {question[:50]}..., "
|
||||
f"user_answer: {user_answer[:50]}..."
|
||||
)
|
||||
|
||||
# 创建 AIService 实例(传入 db_session 用于记录调用日志)
|
||||
ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
|
||||
|
||||
# 构建提示词
|
||||
user_prompt = USER_PROMPT.format(
|
||||
question=question,
|
||||
correct_answer=correct_answer,
|
||||
user_answer=user_answer,
|
||||
analysis=analysis or "无"
|
||||
)
|
||||
|
||||
# 调用 AI
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
ai_response = await ai_service.chat(
|
||||
messages=messages,
|
||||
temperature=0.1, # 低温度,确保输出稳定
|
||||
prompt_name="answer_judge"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"AI 判断完成 - provider: {ai_response.provider}, "
|
||||
f"response: {ai_response.content}, "
|
||||
f"latency: {ai_response.latency_ms}ms"
|
||||
)
|
||||
|
||||
# 解析 AI 输出
|
||||
is_correct = self._parse_judge_result(ai_response.content)
|
||||
|
||||
logger.info(f"答案判断结果: {is_correct}")
|
||||
|
||||
return JudgeResult(
|
||||
is_correct=is_correct,
|
||||
raw_response=ai_response.content,
|
||||
ai_provider=ai_response.provider,
|
||||
ai_model=ai_response.model,
|
||||
ai_tokens=ai_response.total_tokens,
|
||||
ai_latency_ms=ai_response.latency_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"答案判断失败: {e}", exc_info=True)
|
||||
# 出错时默认返回错误,保守处理
|
||||
return JudgeResult(
|
||||
is_correct=False,
|
||||
raw_response=f"判断失败: {e}",
|
||||
)
|
||||
|
||||
def _parse_judge_result(self, ai_output: str) -> bool:
|
||||
"""
|
||||
解析 AI 输出的判断结果
|
||||
|
||||
Args:
|
||||
ai_output: AI 返回的文本
|
||||
|
||||
Returns:
|
||||
bool: True 表示正确,False 表示错误
|
||||
"""
|
||||
# 清洗输出
|
||||
output = ai_output.strip().lower()
|
||||
|
||||
# 检查是否包含正确关键词
|
||||
for keyword in CORRECT_KEYWORDS:
|
||||
if keyword.lower() in output:
|
||||
return True
|
||||
|
||||
# 检查是否包含错误关键词
|
||||
for keyword in INCORRECT_KEYWORDS:
|
||||
if keyword.lower() in output:
|
||||
return False
|
||||
|
||||
# 无法识别时,默认返回错误(保守处理)
|
||||
logger.warning(f"无法解析判断结果,默认返回错误: {ai_output}")
|
||||
return False
|
||||
|
||||
|
||||
# ==================== 全局实例 ====================
|
||||
|
||||
answer_judge_service = AnswerJudgeService()
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
async def judge_answer(
|
||||
question: str,
|
||||
correct_answer: str,
|
||||
user_answer: str,
|
||||
analysis: str = ""
|
||||
) -> bool:
|
||||
"""
|
||||
便捷函数:判断答案是否正确
|
||||
|
||||
Args:
|
||||
question: 题目内容
|
||||
correct_answer: 标准答案
|
||||
user_answer: 用户答案
|
||||
analysis: 答案解析
|
||||
|
||||
Returns:
|
||||
bool: True 表示正确,False 表示错误
|
||||
"""
|
||||
result = await answer_judge_service.judge(
|
||||
question=question,
|
||||
correct_answer=correct_answer,
|
||||
user_answer=user_answer,
|
||||
analysis=analysis
|
||||
)
|
||||
return result.is_correct
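# A minimal usage sketch for the convenience function:
#
#   ok = await judge_answer(
#       question="玻尿酸的主要作用是什么?",
#       correct_answer="补水保湿、填充塑形",
#       user_answer="保湿和塑形",
#   )
#   # ok is True when the AI judges the answers semantically equivalent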
|
||||
|
||||
757  backend/app/services/ai/course_chat_service.py  Normal file
@@ -0,0 +1,757 @@
|
||||
"""
|
||||
课程对话服务 V2 - Python 原生实现
|
||||
|
||||
功能:
|
||||
- 查询课程知识点作为知识库
|
||||
- 调用 AI 进行对话
|
||||
- 支持流式输出
|
||||
- 多轮对话历史管理(Redis 缓存)
|
||||
|
||||
提供稳定可靠的课程对话能力。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import ExternalServiceError
|
||||
|
||||
from .ai_service import AIService
|
||||
from .prompts.course_chat_prompts import (
|
||||
SYSTEM_PROMPT,
|
||||
USER_PROMPT,
|
||||
KNOWLEDGE_ITEM_TEMPLATE,
|
||||
CONVERSATION_WINDOW_SIZE,
|
||||
CONVERSATION_TTL,
|
||||
MAX_KNOWLEDGE_POINTS,
|
||||
MAX_KNOWLEDGE_BASE_LENGTH,
|
||||
DEFAULT_CHAT_MODEL,
|
||||
DEFAULT_TEMPERATURE,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 会话索引 Redis key 前缀/后缀
|
||||
CONVERSATION_INDEX_PREFIX = "course_chat:user:"
|
||||
CONVERSATION_INDEX_SUFFIX = ":conversations"
|
||||
# 会话元数据 key 前缀
|
||||
CONVERSATION_META_PREFIX = "course_chat:meta:"
|
||||
# 会话索引过期时间(与会话数据一致)
|
||||
CONVERSATION_INDEX_TTL = CONVERSATION_TTL
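# Assumed Redis key layout, derived from the prefixes above and from
# CourseChatServiceV2.CONVERSATION_KEY_PREFIX defined below:
#
#   course_chat:conversation:{conversation_id}   -> message history
#   course_chat:user:{user_id}:conversations     -> per-user conversation index
#   course_chat:meta:{conversation_id}           -> conversation metadata
#
# The history and index keys expire after CONVERSATION_TTL seconds; the meta key is
# presumed to follow the same TTL.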
|
||||
|
||||
|
||||
class CourseChatServiceV2:
|
||||
"""
|
||||
课程对话服务 V2
|
||||
|
||||
使用 Python 原生实现。
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
service = CourseChatServiceV2()
|
||||
|
||||
# 非流式对话
|
||||
response = await service.chat(
|
||||
db=db_session,
|
||||
course_id=1,
|
||||
query="什么是玻尿酸?",
|
||||
user_id=1,
|
||||
conversation_id=None
|
||||
)
|
||||
|
||||
# 流式对话
|
||||
async for chunk in service.chat_stream(
|
||||
db=db_session,
|
||||
course_id=1,
|
||||
query="什么是玻尿酸?",
|
||||
user_id=1,
|
||||
conversation_id=None
|
||||
):
|
||||
print(chunk, end="", flush=True)
|
||||
```
|
||||
"""
|
||||
|
||||
# Redis key 前缀
|
||||
CONVERSATION_KEY_PREFIX = "course_chat:conversation:"
|
||||
# 模块标识
|
||||
MODULE_CODE = "course_chat"
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务(AIService 在方法中动态创建,以传入 db_session)"""
|
||||
pass
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
query: str,
|
||||
user_id: int,
|
||||
conversation_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
与课程对话(非流式)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
query: 用户问题
|
||||
user_id: 用户ID
|
||||
conversation_id: 会话ID(续接对话时传入)
|
||||
|
||||
Returns:
|
||||
包含 answer、conversation_id 等字段的字典
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"开始课程对话 V2 - course_id: {course_id}, user_id: {user_id}, "
|
||||
f"conversation_id: {conversation_id}"
|
||||
)
|
||||
|
||||
# 1. 获取课程知识点
|
||||
knowledge_base = await self._get_course_knowledge(db, course_id)
|
||||
|
||||
if not knowledge_base:
|
||||
logger.warning(f"课程 {course_id} 没有知识点,使用空知识库")
|
||||
knowledge_base = "(该课程暂无知识点内容)"
|
||||
|
||||
# 2. 获取或创建会话ID
|
||||
is_new_conversation = False
|
||||
if not conversation_id:
|
||||
conversation_id = self._generate_conversation_id(user_id, course_id)
|
||||
is_new_conversation = True
|
||||
logger.info(f"创建新会话: {conversation_id}")
|
||||
|
||||
# 3. 构建消息列表
|
||||
messages = await self._build_messages(
|
||||
knowledge_base=knowledge_base,
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
|
||||
# 4. 创建 AIService 并调用(传入 db_session 以记录调用日志)
|
||||
ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
|
||||
response = await ai_service.chat(
|
||||
messages=messages,
|
||||
model=DEFAULT_CHAT_MODEL,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
prompt_name="course_chat"
|
||||
)
|
||||
|
||||
answer = response.content
|
||||
|
||||
# 5. 保存对话历史
|
||||
await self._save_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
user_message=query,
|
||||
assistant_message=answer
|
||||
)
|
||||
|
||||
# 6. 更新会话索引
|
||||
if is_new_conversation:
|
||||
await self._add_to_conversation_index(user_id, conversation_id, course_id)
|
||||
else:
|
||||
await self._update_conversation_index(user_id, conversation_id)
|
||||
|
||||
logger.info(
|
||||
f"课程对话完成 - course_id: {course_id}, conversation_id: {conversation_id}, "
|
||||
f"provider: {response.provider}, tokens: {response.total_tokens}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"answer": answer,
|
||||
"conversation_id": conversation_id,
|
||||
"ai_provider": response.provider,
|
||||
"ai_model": response.model,
|
||||
"ai_tokens": response.total_tokens,
|
||||
"ai_latency_ms": response.latency_ms,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"课程对话失败 - course_id: {course_id}, user_id: {user_id}, error: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise ExternalServiceError(f"课程对话失败: {e}")
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
query: str,
|
||||
user_id: int,
|
||||
conversation_id: Optional[str] = None
|
||||
) -> AsyncGenerator[Tuple[str, Optional[str]], None]:
|
||||
"""
|
||||
与课程对话(流式输出)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
query: 用户问题
|
||||
user_id: 用户ID
|
||||
conversation_id: 会话ID(续接对话时传入)
|
||||
|
||||
Yields:
|
||||
Tuple[str, Optional[str]]: (事件类型, 数据)
|
||||
- ("conversation_started", conversation_id): 会话开始
|
||||
- ("chunk", text): 文本块
|
||||
- ("end", None): 结束
|
||||
- ("error", message): 错误
|
||||
"""
|
||||
full_answer = ""
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"开始流式课程对话 V2 - course_id: {course_id}, user_id: {user_id}, "
|
||||
f"conversation_id: {conversation_id}"
|
||||
)
|
||||
|
||||
# 1. 获取课程知识点
|
||||
knowledge_base = await self._get_course_knowledge(db, course_id)
|
||||
|
||||
if not knowledge_base:
|
||||
logger.warning(f"课程 {course_id} 没有知识点,使用空知识库")
|
||||
knowledge_base = "(该课程暂无知识点内容)"
|
||||
|
||||
# 2. 获取或创建会话ID
|
||||
is_new_conversation = False
|
||||
if not conversation_id:
|
||||
conversation_id = self._generate_conversation_id(user_id, course_id)
|
||||
is_new_conversation = True
|
||||
logger.info(f"创建新会话: {conversation_id}")
|
||||
|
||||
# 3. 发送会话开始事件(如果是新会话)
|
||||
if is_new_conversation:
|
||||
yield ("conversation_started", conversation_id)
|
||||
|
||||
# 4. 构建消息列表
|
||||
messages = await self._build_messages(
|
||||
knowledge_base=knowledge_base,
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
|
||||
# 5. 创建 AIService 并流式调用(传入 db_session 以记录调用日志)
|
||||
ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
|
||||
async for chunk in ai_service.chat_stream(
|
||||
messages=messages,
|
||||
model=DEFAULT_CHAT_MODEL,
|
||||
temperature=DEFAULT_TEMPERATURE,
|
||||
prompt_name="course_chat"
|
||||
):
|
||||
full_answer += chunk
|
||||
yield ("chunk", chunk)
|
||||
|
||||
# 6. 发送结束事件
|
||||
yield ("end", None)
|
||||
|
||||
# 7. 保存对话历史
|
||||
await self._save_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
user_message=query,
|
||||
assistant_message=full_answer
|
||||
)
|
||||
|
||||
# 8. 更新会话索引
|
||||
if is_new_conversation:
|
||||
await self._add_to_conversation_index(user_id, conversation_id, course_id)
|
||||
else:
|
||||
await self._update_conversation_index(user_id, conversation_id)
|
||||
|
||||
logger.info(
|
||||
f"流式课程对话完成 - course_id: {course_id}, conversation_id: {conversation_id}, "
|
||||
f"answer_length: {len(full_answer)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"流式课程对话失败 - course_id: {course_id}, user_id: {user_id}, error: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
yield ("error", str(e))
|
||||
|
||||
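# Minimal sketch (illustrative, not part of this class): adapting the (event, data)
# tuples yielded by chat_stream() into Server-Sent Events. The db/user/course values
# and the route wiring are assumptions for the example.
import json

async def sse_events(service: "CourseChatServiceV2", db, course_id: int, query: str, user_id: int):
    async for event, data in service.chat_stream(
        db=db, course_id=course_id, query=query, user_id=user_id
    ):
        payload = json.dumps({"event": event, "data": data}, ensure_ascii=False)
        yield f"data: {payload}\n\n"

# Usage (inside a FastAPI route, illustrative):
#   return StreamingResponse(
#       sse_events(course_chat_service_v2, db, 1, "什么是玻尿酸?", 1),
#       media_type="text/event-stream",
#   )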
async def _get_course_knowledge(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int
|
||||
) -> str:
|
||||
"""
|
||||
获取课程知识点,构建知识库文本
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
|
||||
Returns:
|
||||
知识库文本
|
||||
"""
|
||||
try:
|
||||
# 查询知识点(课程知识点查询)
|
||||
query = text("""
|
||||
SELECT kp.name, kp.description
|
||||
FROM knowledge_points kp
|
||||
INNER JOIN course_materials cm ON kp.material_id = cm.id
|
||||
WHERE kp.course_id = :course_id
|
||||
AND kp.is_deleted = 0
|
||||
AND cm.is_deleted = 0
|
||||
ORDER BY kp.id
|
||||
LIMIT :limit
|
||||
""")
|
||||
|
||||
result = await db.execute(
|
||||
query,
|
||||
{"course_id": course_id, "limit": MAX_KNOWLEDGE_POINTS}
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
if not rows:
|
||||
logger.warning(f"课程 {course_id} 没有关联的知识点")
|
||||
return ""
|
||||
|
||||
# 构建知识库文本
|
||||
knowledge_items = []
|
||||
total_length = 0
|
||||
|
||||
for row in rows:
|
||||
name = row[0] or ""
|
||||
description = row[1] or ""
|
||||
|
||||
item = KNOWLEDGE_ITEM_TEMPLATE.format(
|
||||
name=name,
|
||||
description=description
|
||||
)
|
||||
|
||||
# 检查是否超过长度限制
|
||||
if total_length + len(item) > MAX_KNOWLEDGE_BASE_LENGTH:
|
||||
logger.warning(
|
||||
f"知识库文本已达到最大长度限制 {MAX_KNOWLEDGE_BASE_LENGTH},"
|
||||
f"停止添加更多知识点"
|
||||
)
|
||||
break
|
||||
|
||||
knowledge_items.append(item)
|
||||
total_length += len(item)
|
||||
|
||||
knowledge_base = "\n".join(knowledge_items)
|
||||
|
||||
logger.info(
|
||||
f"获取课程知识点成功 - course_id: {course_id}, "
|
||||
f"count: {len(knowledge_items)}, length: {len(knowledge_base)}"
|
||||
)
|
||||
|
||||
return knowledge_base
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取课程知识点失败: {e}")
|
||||
raise
|
||||
|
||||
async def _build_messages(
|
||||
self,
|
||||
knowledge_base: str,
|
||||
query: str,
|
||||
user_id: int,
|
||||
conversation_id: str
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
构建消息列表(包含历史对话)
|
||||
|
||||
Args:
|
||||
knowledge_base: 知识库文本
|
||||
query: 当前用户问题
|
||||
user_id: 用户ID
|
||||
conversation_id: 会话ID
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# 1. 系统提示词
|
||||
system_content = SYSTEM_PROMPT.format(knowledge_base=knowledge_base)
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
|
||||
# 2. 获取历史对话
|
||||
history = await self._get_conversation_history(conversation_id)
|
||||
|
||||
# 限制历史窗口大小
|
||||
if len(history) > CONVERSATION_WINDOW_SIZE * 2:
|
||||
history = history[-(CONVERSATION_WINDOW_SIZE * 2):]
|
||||
|
||||
# 添加历史消息
|
||||
messages.extend(history)
|
||||
|
||||
# 3. 当前用户问题
|
||||
user_content = USER_PROMPT.format(query=query)
|
||||
messages.append({"role": "user", "content": user_content})
|
||||
|
||||
logger.debug(
|
||||
f"构建消息列表 - total: {len(messages)}, history: {len(history)}"
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
def _generate_conversation_id(self, user_id: int, course_id: int) -> str:
|
||||
"""生成会话ID"""
|
||||
unique_id = uuid.uuid4().hex[:8]
|
||||
return f"conv_{user_id}_{course_id}_{unique_id}"
|
||||
|
||||
async def _get_conversation_history(
|
||||
self,
|
||||
conversation_id: str
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
从 Redis 获取会话历史
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
|
||||
Returns:
|
||||
消息列表 [{"role": "user/assistant", "content": "..."}]
|
||||
"""
|
||||
try:
|
||||
from app.core.redis import get_redis_client
|
||||
|
||||
redis = get_redis_client()
|
||||
key = f"{self.CONVERSATION_KEY_PREFIX}{conversation_id}"
|
||||
|
||||
data = await redis.get(key)
|
||||
if not data:
|
||||
return []
|
||||
|
||||
history = json.loads(data)
|
||||
return history
|
||||
|
||||
except RuntimeError:
|
||||
# Redis 未初始化,返回空历史
|
||||
logger.warning("Redis 未初始化,无法获取会话历史")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"获取会话历史失败: {e}")
|
||||
return []
|
||||
|
||||
async def _save_conversation_history(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
assistant_message: str
|
||||
) -> None:
|
||||
"""
|
||||
保存对话历史到 Redis
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
user_message: 用户消息
|
||||
assistant_message: AI 回复
|
||||
"""
|
||||
try:
|
||||
from app.core.redis import get_redis_client
|
||||
|
||||
redis = get_redis_client()
|
||||
key = f"{self.CONVERSATION_KEY_PREFIX}{conversation_id}"
|
||||
|
||||
# 获取现有历史
|
||||
history = await self._get_conversation_history(conversation_id)
|
||||
|
||||
# 添加新消息
|
||||
history.append({"role": "user", "content": user_message})
|
||||
history.append({"role": "assistant", "content": assistant_message})
|
||||
|
||||
# 限制历史长度
|
||||
max_messages = CONVERSATION_WINDOW_SIZE * 2
|
||||
if len(history) > max_messages:
|
||||
history = history[-max_messages:]
|
||||
|
||||
# 保存到 Redis
|
||||
await redis.setex(
|
||||
key,
|
||||
CONVERSATION_TTL,
|
||||
json.dumps(history, ensure_ascii=False)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"保存会话历史成功 - conversation_id: {conversation_id}, "
|
||||
f"messages: {len(history)}"
|
||||
)
|
||||
|
||||
except RuntimeError:
|
||||
# Redis 未初始化,跳过保存
|
||||
logger.warning("Redis 未初始化,无法保存会话历史")
|
||||
except Exception as e:
|
||||
logger.warning(f"保存会话历史失败: {e}")
|
||||
|
||||
async def get_conversation_messages(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取会话的历史消息
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
user_id: 用户ID(用于权限验证)
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
# 验证会话ID是否属于该用户
|
||||
if not conversation_id.startswith(f"conv_{user_id}_"):
|
||||
logger.warning(
|
||||
f"用户 {user_id} 尝试访问不属于自己的会话: {conversation_id}"
|
||||
)
|
||||
return []
|
||||
|
||||
history = await self._get_conversation_history(conversation_id)
|
||||
|
||||
# 格式化返回数据
|
||||
messages = []
|
||||
for i, msg in enumerate(history):
|
||||
messages.append({
|
||||
"id": i,
|
||||
"role": msg["role"],
|
||||
"content": msg["content"],
|
||||
})
|
||||
|
||||
return messages
|
||||
|
||||
async def _add_to_conversation_index(
|
||||
self,
|
||||
user_id: int,
|
||||
conversation_id: str,
|
||||
course_id: int
|
||||
) -> None:
|
||||
"""
|
||||
将会话添加到用户索引
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
conversation_id: 会话ID
|
||||
course_id: 课程ID
|
||||
"""
|
||||
try:
|
||||
from app.core.redis import get_redis_client
|
||||
|
||||
redis = get_redis_client()
|
||||
|
||||
# 1. 添加到用户的会话索引(Sorted Set,score 为时间戳)
|
||||
index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
|
||||
timestamp = time.time()
|
||||
await redis.zadd(index_key, {conversation_id: timestamp})
|
||||
await redis.expire(index_key, CONVERSATION_INDEX_TTL)
|
||||
|
||||
# 2. 保存会话元数据
|
||||
meta_key = f"{CONVERSATION_META_PREFIX}{conversation_id}"
|
||||
meta_data = {
|
||||
"conversation_id": conversation_id,
|
||||
"user_id": user_id,
|
||||
"course_id": course_id,
|
||||
"created_at": timestamp,
|
||||
"updated_at": timestamp,
|
||||
}
|
||||
await redis.setex(
|
||||
meta_key,
|
||||
CONVERSATION_INDEX_TTL,
|
||||
json.dumps(meta_data, ensure_ascii=False)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"会话已添加到索引 - user_id: {user_id}, conversation_id: {conversation_id}"
|
||||
)
|
||||
|
||||
except RuntimeError:
|
||||
logger.warning("Redis 未初始化,无法添加会话索引")
|
||||
except Exception as e:
|
||||
logger.warning(f"添加会话索引失败: {e}")
|
||||
|
||||
async def _update_conversation_index(
|
||||
self,
|
||||
user_id: int,
|
||||
conversation_id: str
|
||||
) -> None:
|
||||
"""
|
||||
更新会话的最后活跃时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
conversation_id: 会话ID
|
||||
"""
|
||||
try:
|
||||
from app.core.redis import get_redis_client
|
||||
|
||||
redis = get_redis_client()
|
||||
|
||||
# 更新索引中的时间戳
|
||||
index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
|
||||
timestamp = time.time()
|
||||
await redis.zadd(index_key, {conversation_id: timestamp})
|
||||
await redis.expire(index_key, CONVERSATION_INDEX_TTL)
|
||||
|
||||
# 更新元数据中的 updated_at
|
||||
meta_key = f"{CONVERSATION_META_PREFIX}{conversation_id}"
|
||||
meta_data = await redis.get(meta_key)
|
||||
if meta_data:
|
||||
meta = json.loads(meta_data)
|
||||
meta["updated_at"] = timestamp
|
||||
await redis.setex(
|
||||
meta_key,
|
||||
CONVERSATION_INDEX_TTL,
|
||||
json.dumps(meta, ensure_ascii=False)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"会话索引已更新 - user_id: {user_id}, conversation_id: {conversation_id}"
|
||||
)
|
||||
|
||||
except RuntimeError:
|
||||
logger.warning("Redis 未初始化,无法更新会话索引")
|
||||
except Exception as e:
|
||||
logger.warning(f"更新会话索引失败: {e}")
|
||||
|
||||
async def list_user_conversations(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 20
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户的会话列表
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
会话列表,按更新时间倒序
|
||||
"""
|
||||
try:
|
||||
from app.core.redis import get_redis_client
|
||||
|
||||
redis = get_redis_client()
|
||||
|
||||
# 1. 从索引获取最近的会话ID列表(倒序)
|
||||
index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
|
||||
conversation_ids = await redis.zrevrange(index_key, 0, limit - 1)
|
||||
|
||||
if not conversation_ids:
|
||||
logger.debug(f"用户 {user_id} 没有会话记录")
|
||||
return []
|
||||
|
||||
# 2. 获取每个会话的元数据和最后消息
|
||||
conversations = []
|
||||
for conv_id in conversation_ids:
|
||||
# 确保是字符串
|
||||
if isinstance(conv_id, bytes):
|
||||
conv_id = conv_id.decode('utf-8')
|
||||
|
||||
# 获取元数据
|
||||
meta_key = f"{CONVERSATION_META_PREFIX}{conv_id}"
|
||||
meta_data = await redis.get(meta_key)
|
||||
|
||||
if meta_data:
|
||||
if isinstance(meta_data, bytes):
|
||||
meta_data = meta_data.decode('utf-8')
|
||||
meta = json.loads(meta_data)
|
||||
else:
|
||||
# 从 conversation_id 解析 course_id
|
||||
# 格式: conv_{user_id}_{course_id}_{uuid}
|
||||
parts = conv_id.split('_')
|
||||
course_id = int(parts[2]) if len(parts) >= 4 and parts[2].isdigit() else 0
|
||||
meta = {
|
||||
"conversation_id": conv_id,
|
||||
"user_id": user_id,
|
||||
"course_id": course_id,
|
||||
"created_at": time.time(),
|
||||
"updated_at": time.time(),
|
||||
}
|
||||
|
||||
# 获取最后一条消息作为预览
|
||||
history = await self._get_conversation_history(conv_id)
|
||||
last_message = ""
|
||||
if history:
|
||||
# 获取最后一条 assistant 消息
|
||||
for msg in reversed(history):
|
||||
if msg["role"] == "assistant":
|
||||
last_message = msg["content"][:100] # 截取前100字符
|
||||
if len(msg["content"]) > 100:
|
||||
last_message += "..."
|
||||
break
|
||||
|
||||
conversations.append({
|
||||
"id": conv_id,
|
||||
"course_id": meta.get("course_id"),
|
||||
"created_at": meta.get("created_at"),
|
||||
"updated_at": meta.get("updated_at"),
|
||||
"last_message": last_message,
|
||||
"message_count": len(history),
|
||||
})
|
||||
|
||||
logger.info(f"获取用户会话列表 - user_id: {user_id}, count: {len(conversations)}")
|
||||
return conversations
|
||||
|
||||
except RuntimeError:
|
||||
logger.warning("Redis 未初始化,无法获取会话列表")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"获取会话列表失败: {e}")
|
||||
return []
|
||||
|
||||
# 别名方法,供 API 层调用
|
||||
async def get_conversations(
|
||||
self,
|
||||
user_id: int,
|
||||
course_id: Optional[int] = None,
|
||||
limit: int = 20
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户的会话列表(别名方法)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
course_id: 课程ID(可选,用于过滤)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
会话列表
|
||||
"""
|
||||
conversations = await self.list_user_conversations(user_id, limit)
|
||||
|
||||
# 如果指定了 course_id,进行过滤
|
||||
if course_id is not None:
|
||||
conversations = [
|
||||
c for c in conversations
|
||||
if c.get("course_id") == course_id
|
||||
]
|
||||
|
||||
return conversations
|
||||
|
||||
async def get_messages(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_id: int,
|
||||
limit: int = 50
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取会话历史消息(别名方法)
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
user_id: 用户ID(用于权限验证)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
消息列表
|
||||
"""
|
||||
messages = await self.get_conversation_messages(conversation_id, user_id)
# 应用返回数量限制
return messages[:limit]
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
course_chat_service_v2 = CourseChatServiceV2()
|
||||
|
||||
61
backend/app/services/ai/coze/__init__.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Coze AI 服务模块
|
||||
"""
|
||||
|
||||
from .client import get_coze_client, get_auth_manager, get_bot_config, get_workspace_id
|
||||
from .service import get_coze_service, CozeService
|
||||
from .models import (
|
||||
SessionType,
|
||||
MessageRole,
|
||||
ContentType,
|
||||
StreamEventType,
|
||||
CozeSession,
|
||||
CozeMessage,
|
||||
StreamEvent,
|
||||
CreateSessionRequest,
|
||||
CreateSessionResponse,
|
||||
SendMessageRequest,
|
||||
EndSessionRequest,
|
||||
EndSessionResponse,
|
||||
)
|
||||
from .exceptions import (
|
||||
CozeException,
|
||||
CozeAuthError,
|
||||
CozeAPIError,
|
||||
CozeRateLimitError,
|
||||
CozeTimeoutError,
|
||||
CozeStreamError,
|
||||
map_coze_error_to_exception,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Client
|
||||
"get_coze_client",
|
||||
"get_auth_manager",
|
||||
"get_bot_config",
|
||||
"get_workspace_id",
|
||||
# Service
|
||||
"get_coze_service",
|
||||
"CozeService",
|
||||
# Models
|
||||
"SessionType",
|
||||
"MessageRole",
|
||||
"ContentType",
|
||||
"StreamEventType",
|
||||
"CozeSession",
|
||||
"CozeMessage",
|
||||
"StreamEvent",
|
||||
"CreateSessionRequest",
|
||||
"CreateSessionResponse",
|
||||
"SendMessageRequest",
|
||||
"EndSessionRequest",
|
||||
"EndSessionResponse",
|
||||
# Exceptions
|
||||
"CozeException",
|
||||
"CozeAuthError",
|
||||
"CozeAPIError",
|
||||
"CozeRateLimitError",
|
||||
"CozeTimeoutError",
|
||||
"CozeStreamError",
|
||||
"map_coze_error_to_exception",
|
||||
]
|
||||
203
backend/app/services/ai/coze/client.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Coze AI 客户端管理
|
||||
负责管理 Coze API 的认证和客户端实例
|
||||
"""
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Dict, Any
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from cozepy import Coze, TokenAuth, JWTAuth, COZE_CN_BASE_URL
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CozeAuthManager:
|
||||
"""Coze 认证管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = get_settings()
|
||||
self._client: Optional[Coze] = None
|
||||
|
||||
def _create_pat_auth(self) -> TokenAuth:
|
||||
"""创建个人访问令牌认证"""
|
||||
if not self.settings.COZE_API_TOKEN:
|
||||
raise ValueError("COZE_API_TOKEN 未配置")
|
||||
|
||||
return TokenAuth(token=self.settings.COZE_API_TOKEN)
|
||||
|
||||
def _create_oauth_auth(self) -> JWTAuth:
|
||||
"""创建 OAuth 认证"""
|
||||
if not all(
|
||||
[
|
||||
self.settings.COZE_OAUTH_CLIENT_ID,
|
||||
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||||
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
|
||||
]
|
||||
):
|
||||
raise ValueError("OAuth 配置不完整")
|
||||
|
||||
# 读取私钥
|
||||
private_key_path = Path(self.settings.COZE_OAUTH_PRIVATE_KEY_PATH)
|
||||
if not private_key_path.exists():
|
||||
raise FileNotFoundError(f"私钥文件不存在: {private_key_path}")
|
||||
|
||||
with open(private_key_path, "r") as f:
|
||||
private_key = f.read()
|
||||
|
||||
try:
|
||||
return JWTAuth(
|
||||
client_id=self.settings.COZE_OAUTH_CLIENT_ID,
|
||||
private_key=private_key,
|
||||
public_key_id=self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||||
base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL, # 使用中国区API
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建 OAuth 认证失败: {e}")
|
||||
raise
|
||||
|
||||
def get_client(self, force_new: bool = False) -> Coze:
|
||||
"""
|
||||
获取 Coze 客户端实例
|
||||
|
||||
Args:
|
||||
force_new: 是否强制创建新客户端(用于长时间运行的请求,避免token过期)
|
||||
|
||||
认证优先级:
|
||||
1. OAuth(推荐):配置完整时使用,自动刷新token
|
||||
2. PAT:仅当OAuth未配置时使用(注意:PAT会过期)
|
||||
"""
|
||||
if self._client is not None and not force_new:
|
||||
return self._client
|
||||
|
||||
auth = None
|
||||
auth_type = None
|
||||
|
||||
# 检查 OAuth 配置是否完整
|
||||
oauth_configured = all([
|
||||
self.settings.COZE_OAUTH_CLIENT_ID,
|
||||
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||||
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
|
||||
])
|
||||
|
||||
if oauth_configured:
|
||||
# OAuth 配置完整,必须使用 OAuth(不fallback到PAT)
|
||||
try:
|
||||
auth = self._create_oauth_auth()
|
||||
auth_type = "OAuth"
|
||||
logger.info("使用 OAuth 认证")
|
||||
except Exception as e:
|
||||
# OAuth 配置完整但创建失败,直接抛出异常(不fallback到可能过期的PAT)
|
||||
logger.error(f"OAuth 认证创建失败: {e}")
|
||||
raise ValueError(f"OAuth 认证失败,请检查私钥文件和配置: {e}")
|
||||
else:
|
||||
# OAuth 未配置,使用 PAT
|
||||
if self.settings.COZE_API_TOKEN:
|
||||
auth = self._create_pat_auth()
|
||||
auth_type = "PAT"
|
||||
logger.warning("使用 PAT 认证(注意:PAT会过期,建议配置OAuth)")
|
||||
else:
|
||||
raise ValueError("Coze 认证未配置:需要配置 OAuth 或 PAT Token")
|
||||
|
||||
# 创建客户端
|
||||
client = Coze(
|
||||
auth=auth, base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL
|
||||
)
|
||||
|
||||
logger.debug(f"Coze客户端创建成功,认证方式: {auth_type}, force_new: {force_new}")
|
||||
|
||||
# 只有非强制创建时才缓存
|
||||
if not force_new:
|
||||
self._client = client
|
||||
|
||||
return client
|
||||
|
||||
def reset(self):
|
||||
"""重置客户端实例"""
|
||||
self._client = None
|
||||
|
||||
def get_oauth_token(self) -> str:
|
||||
"""
|
||||
获取OAuth JWT Token用于前端直连
|
||||
|
||||
Returns:
|
||||
JWT token字符串
|
||||
"""
|
||||
if not all([
|
||||
self.settings.COZE_OAUTH_CLIENT_ID,
|
||||
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||||
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
|
||||
]):
|
||||
raise ValueError("OAuth 配置不完整")
|
||||
|
||||
# 读取私钥
|
||||
private_key_path = Path(self.settings.COZE_OAUTH_PRIVATE_KEY_PATH)
|
||||
if not private_key_path.exists():
|
||||
raise FileNotFoundError(f"私钥文件不存在: {private_key_path}")
|
||||
|
||||
with open(private_key_path, "r") as f:
|
||||
private_key = f.read()
|
||||
|
||||
# 创建JWTAuth实例(必须指定中国区base_url)
|
||||
jwt_auth = JWTAuth(
|
||||
client_id=self.settings.COZE_OAUTH_CLIENT_ID,
|
||||
private_key=private_key,
|
||||
public_key_id=self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
|
||||
base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL, # 使用中国区API
|
||||
)
|
||||
|
||||
# 获取token(JWTAuth内部会自动生成)
|
||||
# JWTAuth.token属性返回已签名的JWT
|
||||
return jwt_auth.token
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_auth_manager() -> CozeAuthManager:
|
||||
"""获取认证管理器单例"""
|
||||
return CozeAuthManager()
|
||||
|
||||
|
||||
def get_coze_client(force_new: bool = False) -> Coze:
|
||||
"""
|
||||
获取 Coze 客户端
|
||||
|
||||
Args:
|
||||
force_new: 是否强制创建新客户端(用于工作流等长时间运行的请求)
|
||||
"""
|
||||
return get_auth_manager().get_client(force_new=force_new)
|
||||
|
||||
|
||||
def get_workspace_id() -> str:
|
||||
"""获取工作空间 ID"""
|
||||
settings = get_settings()
|
||||
if not settings.COZE_WORKSPACE_ID:
|
||||
raise ValueError("COZE_WORKSPACE_ID 未配置")
|
||||
return settings.COZE_WORKSPACE_ID
|
||||
|
||||
|
||||
def get_bot_config(session_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
根据会话类型获取 Bot 配置
|
||||
|
||||
Args:
|
||||
session_type: 会话类型 (course_chat 或 training)
|
||||
|
||||
Returns:
|
||||
包含 bot_id 等配置的字典
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
if session_type == "course_chat":
|
||||
bot_id = settings.COZE_CHAT_BOT_ID
|
||||
if not bot_id:
|
||||
raise ValueError("COZE_CHAT_BOT_ID 未配置")
|
||||
elif session_type == "training":
|
||||
bot_id = settings.COZE_TRAINING_BOT_ID
|
||||
if not bot_id:
|
||||
raise ValueError("COZE_TRAINING_BOT_ID 未配置")
|
||||
else:
|
||||
raise ValueError(f"不支持的会话类型: {session_type}")
|
||||
|
||||
return {"bot_id": bot_id, "workspace_id": settings.COZE_WORKSPACE_ID}
|
||||
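# Illustrative usage sketch for this module (not part of the file): resolving the cached
# client and the course-chat bot. Assumes the relevant COZE_* settings are configured.
from app.services.ai.coze.client import get_coze_client, get_bot_config

def resolve_course_chat_bot():
    client = get_coze_client()              # OAuth preferred; PAT only when OAuth is not configured
    config = get_bot_config("course_chat")  # {"bot_id": ..., "workspace_id": ...}
    return client, config["bot_id"]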
44
backend/app/services/ai/coze/client_backup.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Coze客户端(临时模拟,等Agent-Coze实现后替换)"""
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CozeClient:
|
||||
"""
|
||||
Coze客户端模拟类
|
||||
TODO: 等Agent-Coze模块实现后,这个类将被真实的Coze网关客户端替换
|
||||
"""
|
||||
|
||||
async def create_conversation(
|
||||
self, bot_id: str, user_id: str, meta_data: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""创建会话(模拟)"""
|
||||
logger.info(f"模拟创建Coze会话: bot_id={bot_id}, user_id={user_id}")
|
||||
|
||||
# 返回模拟的会话信息
|
||||
return {
|
||||
"conversation_id": f"mock_conversation_{user_id}_{bot_id[:8]}",
|
||||
"bot_id": bot_id,
|
||||
"status": "active",
|
||||
}
|
||||
|
||||
async def send_message(
|
||||
self, conversation_id: str, content: str, message_type: str = "text"
|
||||
) -> Dict[str, Any]:
|
||||
"""发送消息(模拟)"""
|
||||
logger.info(f"模拟发送消息到会话 {conversation_id}: {content[:50]}...")
|
||||
|
||||
# 返回模拟的消息响应
|
||||
return {
|
||||
"message_id": f"mock_msg_{conversation_id[:8]}",
|
||||
"content": f"这是对'{content[:30]}...'的模拟回复",
|
||||
"role": "assistant",
|
||||
}
|
||||
|
||||
async def end_conversation(self, conversation_id: str) -> Dict[str, Any]:
|
||||
"""结束会话(模拟)"""
|
||||
logger.info(f"模拟结束会话: {conversation_id}")
|
||||
|
||||
return {"status": "completed", "conversation_id": conversation_id}
|
||||
101
backend/app/services/ai/coze/exceptions.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Coze 服务异常定义
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class CozeException(Exception):
|
||||
"""Coze 服务基础异常"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: Optional[str] = None,
|
||||
status_code: Optional[int] = None,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.status_code = status_code
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class CozeAuthError(CozeException):
|
||||
"""认证异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CozeAPIError(CozeException):
|
||||
"""API 调用异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CozeRateLimitError(CozeException):
|
||||
"""速率限制异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CozeTimeoutError(CozeException):
|
||||
"""超时异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CozeStreamError(CozeException):
|
||||
"""流式响应异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def map_coze_error_to_exception(error: Exception) -> CozeException:
|
||||
"""
|
||||
将 Coze SDK 错误映射为统一异常
|
||||
|
||||
Args:
|
||||
error: 原始异常
|
||||
|
||||
Returns:
|
||||
CozeException: 映射后的异常
|
||||
"""
|
||||
error_message = str(error)
|
||||
|
||||
# 根据错误消息判断错误类型
|
||||
if (
|
||||
"authentication" in error_message.lower()
|
||||
or "unauthorized" in error_message.lower()
|
||||
):
|
||||
return CozeAuthError(
|
||||
message="Coze 认证失败",
|
||||
code="COZE_AUTH_ERROR",
|
||||
status_code=401,
|
||||
details={"original_error": error_message},
|
||||
)
|
||||
|
||||
if "rate limit" in error_message.lower():
|
||||
return CozeRateLimitError(
|
||||
message="Coze API 速率限制",
|
||||
code="COZE_RATE_LIMIT",
|
||||
status_code=429,
|
||||
details={"original_error": error_message},
|
||||
)
|
||||
|
||||
if "timeout" in error_message.lower():
|
||||
return CozeTimeoutError(
|
||||
message="Coze API 调用超时",
|
||||
code="COZE_TIMEOUT",
|
||||
status_code=504,
|
||||
details={"original_error": error_message},
|
||||
)
|
||||
|
||||
# 默认映射为 API 错误
|
||||
return CozeAPIError(
|
||||
message="Coze API 调用失败",
|
||||
code="COZE_API_ERROR",
|
||||
status_code=500,
|
||||
details={"original_error": error_message},
|
||||
)
|
||||
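# Illustrative sketch: how a caller can route arbitrary SDK errors through the mapping
# helper above, so downstream handlers only deal with CozeException subclasses.
from app.services.ai.coze.exceptions import map_coze_error_to_exception

async def call_coze_safely(coro):
    try:
        return await coro
    except Exception as e:
        # re-raise as a typed Coze error; the original message is kept in details
        raise map_coze_error_to_exception(e) from e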
136
backend/app/services/ai/coze/models.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Coze 服务数据模型
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Literal
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SessionType(str, Enum):
|
||||
"""会话类型"""
|
||||
|
||||
COURSE_CHAT = "course_chat" # 课程对话
|
||||
TRAINING = "training" # 陪练会话
|
||||
EXAM = "exam" # 考试会话
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
"""消息角色"""
|
||||
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class ContentType(str, Enum):
|
||||
"""内容类型"""
|
||||
|
||||
TEXT = "text"
|
||||
CARD = "card"
|
||||
IMAGE = "image"
|
||||
FILE = "file"
|
||||
|
||||
|
||||
class StreamEventType(str, Enum):
|
||||
"""流式事件类型"""
|
||||
|
||||
MESSAGE_START = "conversation.message.start"
|
||||
MESSAGE_DELTA = "conversation.message.delta"
|
||||
MESSAGE_COMPLETED = "conversation.message.completed"
|
||||
ERROR = "error"
|
||||
DONE = "done"
|
||||
|
||||
|
||||
class CozeSession(BaseModel):
|
||||
"""Coze 会话模型"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
conversation_id: str = Field(..., description="Coze对话ID")
|
||||
session_type: SessionType = Field(..., description="会话类型")
|
||||
user_id: str = Field(..., description="用户ID")
|
||||
bot_id: str = Field(..., description="Bot ID")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
|
||||
ended_at: Optional[datetime] = Field(None, description="结束时间")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class CozeMessage(BaseModel):
|
||||
"""Coze 消息模型"""
|
||||
|
||||
message_id: str = Field(..., description="消息ID")
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
role: MessageRole = Field(..., description="消息角色")
|
||||
content: str = Field(..., description="消息内容")
|
||||
content_type: ContentType = Field(ContentType.TEXT, description="内容类型")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class StreamEvent(BaseModel):
|
||||
"""流式事件模型"""
|
||||
|
||||
event: StreamEventType = Field(..., description="事件类型")
|
||||
data: Dict[str, Any] = Field(..., description="事件数据")
|
||||
message_id: Optional[str] = Field(None, description="消息ID")
|
||||
content: Optional[str] = Field(None, description="内容")
|
||||
content_type: Optional[ContentType] = Field(None, description="内容类型")
|
||||
role: Optional[MessageRole] = Field(None, description="角色")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""创建会话请求"""
|
||||
|
||||
session_type: SessionType = Field(..., description="会话类型")
|
||||
user_id: str = Field(..., description="用户ID")
|
||||
course_id: Optional[str] = Field(None, description="课程ID (课程对话时必需)")
|
||||
training_topic: Optional[str] = Field(None, description="陪练主题 (陪练时可选)")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""创建会话响应"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
conversation_id: str = Field(..., description="Coze对话ID")
|
||||
bot_id: str = Field(..., description="Bot ID")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
"""发送消息请求"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
content: str = Field(..., description="消息内容")
|
||||
file_ids: List[str] = Field(default_factory=list, description="附件ID列表")
|
||||
stream: bool = Field(True, description="是否流式响应")
|
||||
|
||||
|
||||
class EndSessionRequest(BaseModel):
|
||||
"""结束会话请求"""
|
||||
|
||||
reason: Optional[str] = Field(None, description="结束原因")
|
||||
feedback: Optional[Dict[str, Any]] = Field(None, description="用户反馈")
|
||||
|
||||
|
||||
class EndSessionResponse(BaseModel):
|
||||
"""结束会话响应"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
ended_at: datetime = Field(..., description="结束时间")
|
||||
duration_seconds: int = Field(..., description="会话时长(秒)")
|
||||
message_count: int = Field(..., description="消息数量")
|
||||
|
||||
class Config:
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
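# Illustrative sketch: constructing a session-creation request with the models above.
# The IDs are placeholders; the serialization call depends on the installed Pydantic major version.
from app.services.ai.coze.models import CreateSessionRequest, SessionType

req = CreateSessionRequest(
    session_type=SessionType.COURSE_CHAT,
    user_id="42",
    course_id="7",
)
print(req.json())  # Pydantic v1 style; use req.model_dump_json() on Pydantic v2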
335
backend/app/services/ai/coze/service.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Coze 服务层实现
|
||||
处理会话管理、消息发送、流式响应等核心功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import AsyncIterator, Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from cozepy import ChatEventType, Message, MessageContentType
|
||||
|
||||
from .client import get_coze_client, get_bot_config, get_workspace_id
|
||||
from .models import (
|
||||
CozeSession,
|
||||
CozeMessage,
|
||||
StreamEvent,
|
||||
SessionType,
|
||||
MessageRole,
|
||||
ContentType,
|
||||
StreamEventType,
|
||||
CreateSessionRequest,
|
||||
CreateSessionResponse,
|
||||
SendMessageRequest,
|
||||
EndSessionRequest,
|
||||
EndSessionResponse,
|
||||
)
|
||||
from .exceptions import (
|
||||
CozeAPIError,
|
||||
CozeStreamError,
|
||||
CozeTimeoutError,
|
||||
map_coze_error_to_exception,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CozeService:
|
||||
"""Coze 服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self.client = get_coze_client()
# get_bot_config 需要传入 session_type,Bot ID 在 _get_bot_id_by_type 中按会话类型动态获取
self.workspace_id = get_workspace_id()
|
||||
|
||||
# 内存中的会话存储(生产环境应使用 Redis)
|
||||
self._sessions: Dict[str, CozeSession] = {}
|
||||
self._messages: Dict[str, List[CozeMessage]] = {}
|
||||
|
||||
async def create_session(
|
||||
self, request: CreateSessionRequest
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
创建新会话
|
||||
|
||||
Args:
|
||||
request: 创建会话请求
|
||||
|
||||
Returns:
|
||||
CreateSessionResponse: 会话信息
|
||||
"""
|
||||
try:
|
||||
# 根据会话类型选择 Bot
|
||||
bot_id = self._get_bot_id_by_type(request.session_type)
|
||||
|
||||
# 创建 Coze 对话
|
||||
conversation = await asyncio.to_thread(
|
||||
self.client.conversations.create, bot_id=bot_id
|
||||
)
|
||||
|
||||
# 创建本地会话记录
|
||||
session = CozeSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
conversation_id=conversation.id,
|
||||
session_type=request.session_type,
|
||||
user_id=request.user_id,
|
||||
bot_id=bot_id,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
|
||||
# 保存会话
|
||||
self._sessions[session.session_id] = session
|
||||
self._messages[session.session_id] = []
|
||||
|
||||
logger.info(
|
||||
f"创建会话成功",
|
||||
extra={
|
||||
"session_id": session.session_id,
|
||||
"conversation_id": conversation.id,
|
||||
"session_type": request.session_type.value,
|
||||
"user_id": request.user_id,
|
||||
},
|
||||
)
|
||||
|
||||
return CreateSessionResponse(
|
||||
session_id=session.session_id,
|
||||
conversation_id=session.conversation_id,
|
||||
bot_id=session.bot_id,
|
||||
created_at=session.created_at,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建会话失败: {e}", exc_info=True)
|
||||
raise map_coze_error_to_exception(e)
|
||||
|
||||
async def send_message(
|
||||
self, request: SendMessageRequest
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""
|
||||
发送消息并处理流式响应
|
||||
|
||||
Args:
|
||||
request: 发送消息请求
|
||||
|
||||
Yields:
|
||||
StreamEvent: 流式事件
|
||||
"""
|
||||
session = self._get_session(request.session_id)
|
||||
if not session:
|
||||
raise CozeAPIError(f"会话不存在: {request.session_id}")
|
||||
|
||||
# 记录用户消息
|
||||
user_message = CozeMessage(
|
||||
message_id=str(uuid.uuid4()),
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content=request.content,
|
||||
)
|
||||
self._messages[session.session_id].append(user_message)
|
||||
|
||||
try:
|
||||
# 构建消息历史
|
||||
messages = self._build_message_history(session.session_id)
|
||||
|
||||
# 调用 Coze API
|
||||
stream = await asyncio.to_thread(
|
||||
self.client.chat.stream,
|
||||
bot_id=session.bot_id,
|
||||
conversation_id=session.conversation_id,
|
||||
additional_messages=messages,
|
||||
auto_save_history=True,
|
||||
)
|
||||
|
||||
# 处理流式响应
|
||||
async for event in self._process_stream(stream, session.session_id):
|
||||
yield event
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"消息发送超时: session_id={request.session_id}")
|
||||
raise CozeTimeoutError("消息处理超时")
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}", exc_info=True)
|
||||
raise map_coze_error_to_exception(e)
|
||||
|
||||
async def end_session(
|
||||
self, session_id: str, request: EndSessionRequest
|
||||
) -> EndSessionResponse:
|
||||
"""
|
||||
结束会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
request: 结束会话请求
|
||||
|
||||
Returns:
|
||||
EndSessionResponse: 结束会话响应
|
||||
"""
|
||||
session = self._get_session(session_id)
|
||||
if not session:
|
||||
raise CozeAPIError(f"会话不存在: {session_id}")
|
||||
|
||||
# 更新会话状态
|
||||
session.ended_at = datetime.now()
|
||||
|
||||
# 计算会话统计
|
||||
duration_seconds = int((session.ended_at - session.created_at).total_seconds())
|
||||
message_count = len(self._messages.get(session_id, []))
|
||||
|
||||
# 记录结束原因和反馈
|
||||
if request.reason:
|
||||
session.metadata["end_reason"] = request.reason
|
||||
if request.feedback:
|
||||
session.metadata["feedback"] = request.feedback
|
||||
|
||||
logger.info(
|
||||
f"会话结束",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"duration_seconds": duration_seconds,
|
||||
"message_count": message_count,
|
||||
"reason": request.reason,
|
||||
},
|
||||
)
|
||||
|
||||
return EndSessionResponse(
|
||||
session_id=session_id,
|
||||
ended_at=session.ended_at,
|
||||
duration_seconds=duration_seconds,
|
||||
message_count=message_count,
|
||||
)
|
||||
|
||||
async def get_session_messages(
|
||||
self, session_id: str, limit: int = 50, offset: int = 0
|
||||
) -> List[CozeMessage]:
|
||||
"""获取会话消息历史"""
|
||||
messages = self._messages.get(session_id, [])
|
||||
return messages[offset : offset + limit]
|
||||
|
||||
def _get_bot_id_by_type(self, session_type: SessionType) -> str:
"""根据会话类型获取 Bot ID(get_bot_config 目前仅支持 course_chat 和 training)"""
if session_type == SessionType.COURSE_CHAT:
return get_bot_config("course_chat")["bot_id"]
# TRAINING / EXAM 暂时统一使用陪练 Bot
return get_bot_config("training")["bot_id"]
|
||||
|
||||
def _get_session(self, session_id: str) -> Optional[CozeSession]:
|
||||
"""获取会话"""
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def _build_message_history(self, session_id: str) -> List[Message]:
|
||||
"""构建消息历史"""
|
||||
messages = self._messages.get(session_id, [])
|
||||
history = []
|
||||
|
||||
for msg in messages[-10:]: # 只发送最近10条消息作为上下文
|
||||
history.append(
|
||||
Message(
|
||||
role=msg.role.value,
|
||||
content=msg.content,
|
||||
content_type=MessageContentType.TEXT,
|
||||
)
|
||||
)
|
||||
|
||||
return history
|
||||
|
||||
async def _process_stream(
|
||||
self, stream, session_id: str
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""处理流式响应"""
|
||||
assistant_message_id = str(uuid.uuid4())
|
||||
accumulated_content = []
|
||||
content_type = ContentType.TEXT
|
||||
|
||||
try:
|
||||
for event in stream:
|
||||
if event.event == ChatEventType.CONVERSATION_MESSAGE_DELTA:
|
||||
# 消息片段
|
||||
content = event.message.content
|
||||
accumulated_content.append(content)
|
||||
|
||||
# 检测卡片类型
|
||||
if (
|
||||
hasattr(event.message, "content_type")
|
||||
and event.message.content_type == "card"
|
||||
):
|
||||
content_type = ContentType.CARD
|
||||
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.MESSAGE_DELTA,
|
||||
data={
|
||||
"conversation_id": event.conversation_id,
|
||||
"message_id": assistant_message_id,
|
||||
"content": content,
|
||||
"content_type": content_type.value,
|
||||
},
|
||||
message_id=assistant_message_id,
|
||||
content=content,
|
||||
content_type=content_type,
|
||||
role=MessageRole.ASSISTANT,
|
||||
)
|
||||
|
||||
elif event.event == ChatEventType.CONVERSATION_MESSAGE_COMPLETED:
|
||||
# 消息完成
|
||||
full_content = "".join(accumulated_content)
|
||||
|
||||
# 保存助手消息
|
||||
assistant_message = CozeMessage(
|
||||
message_id=assistant_message_id,
|
||||
session_id=session_id,
|
||||
role=MessageRole.ASSISTANT,
|
||||
content=full_content,
|
||||
content_type=content_type,
|
||||
)
|
||||
self._messages[session_id].append(assistant_message)
|
||||
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.MESSAGE_COMPLETED,
|
||||
data={
|
||||
"conversation_id": event.conversation_id,
|
||||
"message_id": assistant_message_id,
|
||||
"content": full_content,
|
||||
"content_type": content_type.value,
|
||||
"usage": getattr(event, "usage", {}),
|
||||
},
|
||||
message_id=assistant_message_id,
|
||||
content=full_content,
|
||||
content_type=content_type,
|
||||
role=MessageRole.ASSISTANT,
|
||||
)
|
||||
|
||||
elif event.event == ChatEventType.ERROR:
|
||||
# 错误事件
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.ERROR,
|
||||
data={"error": str(event)},
|
||||
error=str(event),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式处理错误: {e}", exc_info=True)
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.ERROR, data={"error": str(e)}, error=str(e)
|
||||
)
|
||||
finally:
|
||||
# 发送结束事件
|
||||
yield StreamEvent(
|
||||
event=StreamEventType.DONE, data={"session_id": session_id}
|
||||
)
|
||||
|
||||
|
||||
# 全局服务实例
|
||||
_service: Optional[CozeService] = None
|
||||
|
||||
|
||||
def get_coze_service() -> CozeService:
|
||||
"""获取 Coze 服务单例"""
|
||||
global _service
|
||||
if _service is None:
|
||||
_service = CozeService()
|
||||
return _service
|
||||
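# Illustrative sketch (not part of this module): driving a session end-to-end with the
# service above. Assumes COZE_TRAINING_BOT_ID and workspace settings are configured;
# the user id and prompt text are placeholders.
from app.services.ai.coze.service import get_coze_service
from app.services.ai.coze.models import (
    CreateSessionRequest,
    SendMessageRequest,
    EndSessionRequest,
    SessionType,
    StreamEventType,
)

async def demo_training_chat():
    svc = get_coze_service()
    session = await svc.create_session(
        CreateSessionRequest(session_type=SessionType.TRAINING, user_id="42")
    )
    async for event in svc.send_message(
        SendMessageRequest(session_id=session.session_id, content="请开始一次接待演练")
    ):
        if event.event == StreamEventType.MESSAGE_DELTA and event.content:
            print(event.content, end="", flush=True)
    await svc.end_session(session.session_id, EndSessionRequest(reason="demo finished"))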
512
backend/app/services/ai/exam_generator_service.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""
|
||||
试题生成服务 V2 - Python 原生实现
|
||||
|
||||
功能:
|
||||
- 根据岗位和知识点动态生成考试题目
|
||||
- 支持错题重出模式
|
||||
- 调用 AI 生成并解析 JSON 结果
|
||||
|
||||
提供稳定可靠的试题生成能力。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import ExternalServiceError
|
||||
|
||||
from .ai_service import AIService, AIResponse
|
||||
from .llm_json_parser import parse_with_fallback, clean_llm_output
|
||||
from .prompts.exam_generator_prompts import (
|
||||
SYSTEM_PROMPT,
|
||||
USER_PROMPT,
|
||||
MISTAKE_REGEN_SYSTEM_PROMPT,
|
||||
MISTAKE_REGEN_USER_PROMPT,
|
||||
QUESTION_SCHEMA,
|
||||
DEFAULT_QUESTION_COUNTS,
|
||||
DEFAULT_DIFFICULTY_LEVEL,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExamGeneratorConfig:
|
||||
"""考试生成配置"""
|
||||
course_id: int
|
||||
position_id: int
|
||||
single_choice_count: int = DEFAULT_QUESTION_COUNTS["single_choice_count"]
|
||||
multiple_choice_count: int = DEFAULT_QUESTION_COUNTS["multiple_choice_count"]
|
||||
true_false_count: int = DEFAULT_QUESTION_COUNTS["true_false_count"]
|
||||
fill_blank_count: int = DEFAULT_QUESTION_COUNTS["fill_blank_count"]
|
||||
essay_count: int = DEFAULT_QUESTION_COUNTS["essay_count"]
|
||||
difficulty_level: int = DEFAULT_DIFFICULTY_LEVEL
|
||||
mistake_records: str = ""
|
||||
|
||||
@property
|
||||
def total_count(self) -> int:
|
||||
"""计算总题量"""
|
||||
return (
|
||||
self.single_choice_count +
|
||||
self.multiple_choice_count +
|
||||
self.true_false_count +
|
||||
self.fill_blank_count +
|
||||
self.essay_count
|
||||
)
|
||||
|
||||
@property
|
||||
def has_mistakes(self) -> bool:
|
||||
"""是否有错题记录"""
|
||||
return bool(self.mistake_records and self.mistake_records.strip())
|
||||
|
||||
|
||||
class ExamGeneratorService:
|
||||
"""
|
||||
试题生成服务 V2
|
||||
|
||||
使用 Python 原生实现。
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
service = ExamGeneratorService()
|
||||
result = await service.generate_exam(
|
||||
db=db_session,
|
||||
config=ExamGeneratorConfig(
|
||||
course_id=1,
|
||||
position_id=1,
|
||||
single_choice_count=5,
|
||||
multiple_choice_count=3,
|
||||
difficulty_level=3
|
||||
)
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务"""
|
||||
self.ai_service = AIService(module_code="exam_generator")
|
||||
|
||||
async def generate_exam(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
config: ExamGeneratorConfig
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成考试题目(主入口)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config: 考试生成配置
|
||||
|
||||
Returns:
|
||||
生成结果,包含 success、questions、total_count 等字段
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"开始生成试题 - course_id: {config.course_id}, position_id: {config.position_id}, "
|
||||
f"total_count: {config.total_count}, has_mistakes: {config.has_mistakes}"
|
||||
)
|
||||
|
||||
# 根据是否有错题记录,走不同分支
|
||||
if config.has_mistakes:
|
||||
return await self._regenerate_from_mistakes(db, config)
|
||||
else:
|
||||
return await self._generate_from_knowledge(db, config)
|
||||
|
||||
except ExternalServiceError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"试题生成失败 - course_id: {config.course_id}, error: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise ExternalServiceError(f"试题生成失败: {e}")
|
||||
|
||||
async def _generate_from_knowledge(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
config: ExamGeneratorConfig
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
基于知识点生成题目(无错题模式)
|
||||
|
||||
流程:
|
||||
1. 查询岗位信息
|
||||
2. 随机查询知识点
|
||||
3. 调用 AI 生成题目
|
||||
4. 解析并返回结果
|
||||
"""
|
||||
# 1. 查询岗位信息
|
||||
position_info = await self._query_position(db, config.position_id)
|
||||
if not position_info:
|
||||
raise ExternalServiceError(f"岗位不存在: position_id={config.position_id}")
|
||||
|
||||
logger.info(f"岗位信息: {position_info.get('name', 'unknown')}")
|
||||
|
||||
# 2. 随机查询知识点
|
||||
knowledge_points = await self._query_knowledge_points(
|
||||
db,
|
||||
config.course_id,
|
||||
config.total_count
|
||||
)
|
||||
if not knowledge_points:
|
||||
raise ExternalServiceError(
|
||||
f"课程没有可用的知识点: course_id={config.course_id}"
|
||||
)
|
||||
|
||||
logger.info(f"查询到 {len(knowledge_points)} 个知识点")
|
||||
|
||||
# 3. 构建提示词
|
||||
system_prompt = SYSTEM_PROMPT.format(
|
||||
total_count=config.total_count,
|
||||
single_choice_count=config.single_choice_count,
|
||||
multiple_choice_count=config.multiple_choice_count,
|
||||
true_false_count=config.true_false_count,
|
||||
fill_blank_count=config.fill_blank_count,
|
||||
essay_count=config.essay_count,
|
||||
difficulty_level=config.difficulty_level,
|
||||
)
|
||||
|
||||
user_prompt = USER_PROMPT.format(
|
||||
position_info=self._format_position_info(position_info),
|
||||
knowledge_points=self._format_knowledge_points(knowledge_points),
|
||||
)
|
||||
|
||||
# 4. 调用 AI 生成
|
||||
ai_response = await self._call_ai_generate(system_prompt, user_prompt)
|
||||
|
||||
logger.info(
|
||||
f"AI 生成完成 - provider: {ai_response.provider}, "
|
||||
f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
|
||||
)
|
||||
|
||||
# 5. 解析题目
|
||||
questions = self._parse_questions(ai_response.content)
|
||||
|
||||
logger.info(f"试题解析成功,数量: {len(questions)}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"questions": questions,
|
||||
"total_count": len(questions),
|
||||
"mode": "knowledge_based",
|
||||
"ai_provider": ai_response.provider,
|
||||
"ai_model": ai_response.model,
|
||||
"ai_tokens": ai_response.total_tokens,
|
||||
"ai_latency_ms": ai_response.latency_ms,
|
||||
}
|
||||
|
||||
async def _regenerate_from_mistakes(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
config: ExamGeneratorConfig
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
错题重出模式
|
||||
|
||||
流程:
|
||||
1. 构建错题重出提示词
|
||||
2. 调用 AI 生成新题
|
||||
3. 解析并返回结果
|
||||
"""
|
||||
logger.info("进入错题重出模式")
|
||||
|
||||
# 1. 构建提示词
|
||||
system_prompt = MISTAKE_REGEN_SYSTEM_PROMPT.format(
|
||||
difficulty_level=config.difficulty_level,
|
||||
)
|
||||
|
||||
user_prompt = MISTAKE_REGEN_USER_PROMPT.format(
|
||||
mistake_records=config.mistake_records,
|
||||
)
|
||||
|
||||
# 2. 调用 AI 生成
|
||||
ai_response = await self._call_ai_generate(system_prompt, user_prompt)
|
||||
|
||||
logger.info(
|
||||
f"错题重出完成 - provider: {ai_response.provider}, "
|
||||
f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
|
||||
)
|
||||
|
||||
# 3. 解析题目
|
||||
questions = self._parse_questions(ai_response.content)
|
||||
|
||||
logger.info(f"错题重出解析成功,数量: {len(questions)}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"questions": questions,
|
||||
"total_count": len(questions),
|
||||
"mode": "mistake_regen",
|
||||
"ai_provider": ai_response.provider,
|
||||
"ai_model": ai_response.model,
|
||||
"ai_tokens": ai_response.total_tokens,
|
||||
"ai_latency_ms": ai_response.latency_ms,
|
||||
}
|
||||
|
||||
async def _query_position(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
position_id: int
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
查询岗位信息
|
||||
|
||||
SQL:SELECT id, name, description, skills, level FROM positions
|
||||
WHERE id = :id AND is_deleted = FALSE
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT id, name, description, skills, level
|
||||
FROM positions
|
||||
WHERE id = :position_id AND is_deleted = FALSE
|
||||
"""),
|
||||
{"position_id": position_id}
|
||||
)
|
||||
row = result.fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
# 将 Row 转换为字典
|
||||
return {
|
||||
"id": row[0],
|
||||
"name": row[1],
|
||||
"description": row[2],
|
||||
"skills": row[3], # JSON 字段
|
||||
"level": row[4],
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询岗位信息失败: {e}")
|
||||
raise ExternalServiceError(f"查询岗位信息失败: {e}")
|
||||
|
||||
async def _query_knowledge_points(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
limit: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机查询知识点
|
||||
|
||||
SQL:SELECT kp.id, kp.name, kp.description, kp.topic_relation
|
||||
FROM knowledge_points kp
|
||||
INNER JOIN course_materials cm ON kp.material_id = cm.id
|
||||
WHERE kp.course_id = :course_id
|
||||
AND kp.is_deleted = FALSE
|
||||
AND cm.is_deleted = FALSE
|
||||
ORDER BY RAND()
|
||||
LIMIT :limit
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
text("""
|
||||
SELECT kp.id, kp.name, kp.description, kp.topic_relation
|
||||
FROM knowledge_points kp
|
||||
INNER JOIN course_materials cm ON kp.material_id = cm.id
|
||||
WHERE kp.course_id = :course_id
|
||||
AND kp.is_deleted = FALSE
|
||||
AND cm.is_deleted = FALSE
|
||||
ORDER BY RAND()
|
||||
LIMIT :limit
|
||||
"""),
|
||||
{"course_id": course_id, "limit": limit}
|
||||
)
|
||||
rows = result.fetchall()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": row[0],
|
||||
"name": row[1],
|
||||
"description": row[2],
|
||||
"topic_relation": row[3],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询知识点失败: {e}")
|
||||
raise ExternalServiceError(f"查询知识点失败: {e}")
|
||||
|
||||
async def _call_ai_generate(
|
||||
self,
|
||||
system_prompt: str,
|
||||
user_prompt: str
|
||||
) -> AIResponse:
|
||||
"""调用 AI 生成题目"""
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
response = await self.ai_service.chat(
|
||||
messages=messages,
|
||||
temperature=0.7, # 适当的创造性
|
||||
prompt_name="exam_generator"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _parse_questions(self, ai_output: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析 AI 输出的题目 JSON
|
||||
|
||||
使用 LLM JSON Parser 进行多层兜底解析
|
||||
"""
|
||||
# 先清洗输出
|
||||
cleaned_output, rules = clean_llm_output(ai_output)
|
||||
if rules:
|
||||
logger.debug(f"AI 输出已清洗: {rules}")
|
||||
|
||||
# 使用带 Schema 校验的解析
|
||||
questions = parse_with_fallback(
|
||||
cleaned_output,
|
||||
schema=QUESTION_SCHEMA,
|
||||
default=[],
|
||||
validate_schema=True,
|
||||
on_error="default"
|
||||
)
|
||||
|
||||
# 后处理:确保每个题目有必要字段
|
||||
processed_questions = []
|
||||
for i, q in enumerate(questions):
|
||||
if isinstance(q, dict):
|
||||
# 确保有 num 字段
|
||||
if "num" not in q:
|
||||
q["num"] = i + 1
|
||||
|
||||
# 确保 num 是整数
|
||||
try:
|
||||
q["num"] = int(q["num"])
|
||||
except (ValueError, TypeError):
|
||||
q["num"] = i + 1
|
||||
|
||||
# 确保有 type 字段
|
||||
if "type" not in q:
|
||||
# 根据是否有 options 推断类型
|
||||
if q.get("topic", {}).get("options"):
|
||||
q["type"] = "single_choice"
|
||||
else:
|
||||
q["type"] = "essay"
|
||||
|
||||
# 确保 knowledge_point_id 是整数或 None
|
||||
kp_id = q.get("knowledge_point_id")
|
||||
if kp_id is not None:
|
||||
try:
|
||||
q["knowledge_point_id"] = int(kp_id)
|
||||
except (ValueError, TypeError):
|
||||
q["knowledge_point_id"] = None
|
||||
|
||||
# 验证必要字段
|
||||
if q.get("topic") and q.get("correct"):
|
||||
processed_questions.append(q)
|
||||
else:
|
||||
logger.warning(f"题目缺少必要字段,已跳过: {q}")
|
||||
|
||||
if not processed_questions:
|
||||
logger.warning("未能解析出有效的题目")
|
||||
|
||||
return processed_questions
|
||||
|
||||
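# Illustrative example only: a minimal question dict that satisfies the post-processing
# checks above (num/type normalised, topic and correct present). The field contents are
# invented placeholders; the authoritative structure comes from QUESTION_SCHEMA.
EXAMPLE_PARSED_QUESTION = {
    "num": 1,
    "type": "single_choice",
    "knowledge_point_id": 12,
    "topic": {
        "title": "示例题干(占位)",
        "options": ["A. 选项一", "B. 选项二", "C. 选项三", "D. 选项四"],
    },
    "correct": "A",
}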
def _format_position_info(self, position: Dict[str, Any]) -> str:
|
||||
"""格式化岗位信息为文本"""
|
||||
lines = [
|
||||
f"岗位名称: {position.get('name', '未知')}",
|
||||
f"岗位等级: {position.get('level', '未设置')}",
|
||||
]
|
||||
|
||||
if position.get('description'):
|
||||
lines.append(f"岗位描述: {position['description']}")
|
||||
|
||||
skills = position.get('skills')
|
||||
if skills:
|
||||
# skills 可能是 JSON 字符串或列表
|
||||
if isinstance(skills, str):
|
||||
try:
|
||||
skills = json.loads(skills)
|
||||
except json.JSONDecodeError:
|
||||
skills = [skills]
|
||||
|
||||
if isinstance(skills, list) and skills:
|
||||
lines.append(f"核心技能: {', '.join(str(s) for s in skills)}")
|
||||
|
||||
return '\n'.join(lines)
|
||||
|
||||
def _format_knowledge_points(self, knowledge_points: List[Dict[str, Any]]) -> str:
|
||||
"""格式化知识点列表为文本"""
|
||||
lines = []
|
||||
for kp in knowledge_points:
|
||||
kp_text = f"【知识点 ID: {kp['id']}】{kp['name']}"
|
||||
if kp.get('description'):
|
||||
kp_text += f"\n{kp['description']}"
|
||||
if kp.get('topic_relation'):
|
||||
kp_text += f"\n关系描述: {kp['topic_relation']}"
|
||||
lines.append(kp_text)
|
||||
|
||||
return '\n\n'.join(lines)
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
exam_generator_service = ExamGeneratorService()
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
async def generate_exam(
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
position_id: int,
|
||||
single_choice_count: int = 4,
|
||||
multiple_choice_count: int = 2,
|
||||
true_false_count: int = 1,
|
||||
fill_blank_count: int = 2,
|
||||
essay_count: int = 1,
|
||||
difficulty_level: int = 3,
|
||||
mistake_records: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
便捷函数:生成考试题目
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
position_id: 岗位ID
|
||||
single_choice_count: 单选题数量
|
||||
multiple_choice_count: 多选题数量
|
||||
true_false_count: 判断题数量
|
||||
fill_blank_count: 填空题数量
|
||||
essay_count: 问答题数量
|
||||
difficulty_level: 难度等级(1-5)
|
||||
mistake_records: 错题记录JSON字符串
|
||||
|
||||
Returns:
|
||||
生成结果
|
||||
"""
|
||||
config = ExamGeneratorConfig(
|
||||
course_id=course_id,
|
||||
position_id=position_id,
|
||||
single_choice_count=single_choice_count,
|
||||
multiple_choice_count=multiple_choice_count,
|
||||
true_false_count=true_false_count,
|
||||
fill_blank_count=fill_blank_count,
|
||||
essay_count=essay_count,
|
||||
difficulty_level=difficulty_level,
|
||||
mistake_records=mistake_records,
|
||||
)
|
||||
|
||||
return await exam_generator_service.generate_exam(db, config)
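# Illustrative usage sketch (placeholder IDs): generating a small mixed-type paper
# with the convenience function above from an existing AsyncSession.
async def build_demo_paper(db: AsyncSession):
    result = await generate_exam(
        db=db,
        course_id=1,
        position_id=1,
        single_choice_count=3,
        multiple_choice_count=2,
        true_false_count=1,
        fill_blank_count=0,
        essay_count=0,
        difficulty_level=2,
    )
    return result["questions"]  # list of parsed question dicts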
548
backend/app/services/ai/knowledge_analysis_v2.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""
|
||||
知识点分析服务 V2 - Python 原生实现
|
||||
|
||||
功能:
|
||||
- 读取文档内容(PDF/Word/TXT)
|
||||
- 调用 AI 分析提取知识点
|
||||
- 解析 JSON 结果
|
||||
- 写入数据库
|
||||
|
||||
提供稳定可靠的知识点分析能力。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import ExternalServiceError
|
||||
from app.schemas.course import KnowledgePointCreate
|
||||
|
||||
from .ai_service import AIService, AIResponse
|
||||
from .llm_json_parser import parse_with_fallback, clean_llm_output
|
||||
from .prompts.knowledge_analysis_prompts import (
|
||||
SYSTEM_PROMPT,
|
||||
USER_PROMPT,
|
||||
KNOWLEDGE_POINT_SCHEMA,
|
||||
DEFAULT_KNOWLEDGE_TYPE,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 配置常量
|
||||
STATIC_UPLOADS_PREFIX = '/static/uploads/'
|
||||
MAX_CONTENT_LENGTH = 100000 # 最大文档内容长度(字符)
|
||||
MAX_KNOWLEDGE_POINTS = 20 # 最大知识点数量
|
||||
|
||||
|
||||
class KnowledgeAnalysisServiceV2:
|
||||
"""
|
||||
知识点分析服务 V2
|
||||
|
||||
使用 Python 原生实现。
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
service = KnowledgeAnalysisServiceV2()
|
||||
result = await service.analyze_course_material(
|
||||
db=db_session,
|
||||
course_id=1,
|
||||
material_id=10,
|
||||
file_url="/static/uploads/courses/1/doc.pdf",
|
||||
course_title="医美产品知识",
|
||||
user_id=1
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务"""
|
||||
self.ai_service = AIService(module_code="knowledge_analysis")
|
||||
self.upload_path = getattr(settings, 'UPLOAD_PATH', 'uploads')
|
||||
|
||||
async def analyze_course_material(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
material_id: int,
|
||||
file_url: str,
|
||||
course_title: str,
|
||||
user_id: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
分析课程资料并提取知识点
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
material_id: 资料ID
|
||||
file_url: 文件URL(相对路径)
|
||||
course_title: 课程标题
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
分析结果,包含 success、knowledge_points_count 等字段
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"开始知识点分析 V2 - course_id: {course_id}, material_id: {material_id}, "
|
||||
f"file_url: {file_url}"
|
||||
)
|
||||
|
||||
# 1. 解析文件路径
|
||||
file_path = self._resolve_file_path(file_url)
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
logger.info(f"文件路径解析成功: {file_path}")
|
||||
|
||||
# 2. 提取文档内容
|
||||
content = await self._extract_document_content(file_path)
|
||||
if not content or not content.strip():
|
||||
raise ValueError("文档内容为空")
|
||||
|
||||
logger.info(f"文档内容提取成功,长度: {len(content)} 字符")
|
||||
|
||||
# 3. 调用 AI 分析
|
||||
ai_response = await self._call_ai_analysis(content, course_title)
|
||||
|
||||
logger.info(
|
||||
f"AI 分析完成 - provider: {ai_response.provider}, "
|
||||
f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
|
||||
)
|
||||
|
||||
# 4. 解析 JSON 结果
|
||||
knowledge_points = self._parse_knowledge_points(ai_response.content)
|
||||
|
||||
logger.info(f"知识点解析成功,数量: {len(knowledge_points)}")
|
||||
|
||||
# 5. 删除旧的知识点
|
||||
await self._delete_old_knowledge_points(db, material_id)
|
||||
|
||||
# 6. 保存到数据库
|
||||
saved_count = await self._save_knowledge_points(
|
||||
db=db,
|
||||
course_id=course_id,
|
||||
material_id=material_id,
|
||||
knowledge_points=knowledge_points,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"知识点分析完成 - course_id: {course_id}, material_id: {material_id}, "
|
||||
f"saved_count: {saved_count}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"status": "completed",
|
||||
"knowledge_points_count": saved_count,
|
||||
"ai_provider": ai_response.provider,
|
||||
"ai_model": ai_response.model,
|
||||
"ai_tokens": ai_response.total_tokens,
|
||||
"ai_latency_ms": ai_response.latency_ms,
|
||||
}
|
||||
|
||||
except FileNotFoundError as e:
|
||||
logger.error(f"文件不存在: {e}")
|
||||
raise ExternalServiceError(f"分析文件不存在: {e}")
|
||||
except ValueError as e:
|
||||
logger.error(f"参数错误: {e}")
|
||||
raise ExternalServiceError(f"分析参数错误: {e}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"知识点分析失败 - course_id: {course_id}, material_id: {material_id}, "
|
||||
f"error: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise ExternalServiceError(f"知识点分析失败: {e}")
|
||||
|
||||
def _resolve_file_path(self, file_url: str) -> Path:
|
||||
"""解析文件 URL 为本地路径"""
|
||||
if file_url.startswith(STATIC_UPLOADS_PREFIX):
|
||||
relative_path = file_url.replace(STATIC_UPLOADS_PREFIX, '')
|
||||
return Path(self.upload_path) / relative_path
|
||||
elif file_url.startswith('/'):
|
||||
# 绝对路径
|
||||
return Path(file_url)
|
||||
else:
|
||||
# 相对路径
|
||||
return Path(self.upload_path) / file_url
|
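To make the three branches above concrete, here is a standalone sketch of the same mapping, assuming `UPLOAD_PATH` resolves to `"uploads"` (the fallback used in `__init__`):

```python
from pathlib import Path

UPLOAD_PATH = "uploads"  # assumed value of settings.UPLOAD_PATH


def resolve(file_url: str) -> Path:
    prefix = "/static/uploads/"
    if file_url.startswith(prefix):
        return Path(UPLOAD_PATH) / file_url[len(prefix):]
    if file_url.startswith("/"):
        return Path(file_url)              # treated as an absolute path
    return Path(UPLOAD_PATH) / file_url    # plain relative path


assert resolve("/static/uploads/courses/1/doc.pdf") == Path("uploads/courses/1/doc.pdf")
assert resolve("/tmp/doc.pdf") == Path("/tmp/doc.pdf")
assert resolve("courses/1/doc.pdf") == Path("uploads/courses/1/doc.pdf")
```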
||||
|
||||
async def _extract_document_content(self, file_path: Path) -> str:
|
||||
"""
|
||||
提取文档内容
|
||||
|
||||
支持:PDF、Word(docx)、文本文件
|
||||
"""
|
||||
suffix = file_path.suffix.lower()
|
||||
|
||||
try:
|
||||
if suffix == '.pdf':
|
||||
return await self._extract_pdf_content(file_path)
|
||||
elif suffix in ['.docx', '.doc']:
|
||||
return await self._extract_docx_content(file_path)
|
||||
elif suffix in ['.txt', '.md', '.text']:
|
||||
return await self._extract_text_content(file_path)
|
||||
else:
|
||||
# 尝试作为文本读取
|
||||
return await self._extract_text_content(file_path)
|
||||
except Exception as e:
|
||||
logger.error(f"文档内容提取失败: {file_path}, error: {e}")
|
||||
raise ValueError(f"无法读取文档内容: {e}")
|
||||
|
||||
async def _extract_pdf_content(self, file_path: Path) -> str:
|
||||
"""提取 PDF 内容"""
|
||||
try:
|
||||
from PyPDF2 import PdfReader
|
||||
|
||||
reader = PdfReader(str(file_path))
|
||||
text_parts = []
|
||||
|
||||
for page in reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
|
||||
content = '\n'.join(text_parts)
|
||||
|
||||
# 清理和截断
|
||||
content = self._clean_content(content)
|
||||
|
||||
return content
|
||||
|
||||
except ImportError:
|
||||
logger.error("PyPDF2 未安装,无法读取 PDF")
|
||||
raise ValueError("服务器未安装 PDF 读取组件")
|
||||
except Exception as e:
|
||||
logger.error(f"PDF 读取失败: {e}")
|
||||
raise ValueError(f"PDF 读取失败: {e}")
|
||||
|
||||
async def _extract_docx_content(self, file_path: Path) -> str:
|
||||
"""提取 Word 文档内容"""
|
||||
try:
|
||||
from docx import Document
|
||||
|
||||
doc = Document(str(file_path))
|
||||
text_parts = []
|
||||
|
||||
for para in doc.paragraphs:
|
||||
if para.text.strip():
|
||||
text_parts.append(para.text)
|
||||
|
||||
# 也提取表格内容
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
for cell in row.cells:
|
||||
if cell.text.strip():
|
||||
text_parts.append(cell.text)
|
||||
|
||||
content = '\n'.join(text_parts)
|
||||
content = self._clean_content(content)
|
||||
|
||||
return content
|
||||
|
||||
except ImportError:
|
||||
logger.error("python-docx 未安装,无法读取 Word 文档")
|
||||
raise ValueError("服务器未安装 Word 读取组件")
|
||||
except Exception as e:
|
||||
logger.error(f"Word 文档读取失败: {e}")
|
||||
raise ValueError(f"Word 文档读取失败: {e}")
|
||||
|
||||
async def _extract_text_content(self, file_path: Path) -> str:
|
||||
"""提取文本文件内容"""
|
||||
try:
|
||||
# 尝试多种编码
|
||||
encodings = ['utf-8', 'gbk', 'gb2312', 'latin-1']
|
||||
|
||||
for encoding in encodings:
|
||||
try:
|
||||
with open(file_path, 'r', encoding=encoding) as f:
|
||||
content = f.read()
|
||||
return self._clean_content(content)
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
raise ValueError("无法识别文件编码")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文本文件读取失败: {e}")
|
||||
raise ValueError(f"文本文件读取失败: {e}")
|
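The same encoding cascade, reduced to a standalone helper for clarity; the encoding order mirrors the method above.

```python
from pathlib import Path


def read_text_with_fallback(path: Path) -> str:
    # Try common encodings in order; only undecodable files fall through.
    for encoding in ("utf-8", "gbk", "gb2312", "latin-1"):
        try:
            return path.read_text(encoding=encoding)
        except UnicodeDecodeError:
            continue
    raise ValueError("无法识别文件编码")
```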
||||
|
||||
def _clean_content(self, content: str) -> str:
|
||||
"""清理和截断内容"""
|
||||
# 移除多余空白
|
||||
import re
|
||||
content = re.sub(r'\n{3,}', '\n\n', content)
|
||||
content = re.sub(r' {2,}', ' ', content)
|
||||
|
||||
# 截断过长内容
|
||||
if len(content) > MAX_CONTENT_LENGTH:
|
||||
logger.warning(f"文档内容过长,截断至 {MAX_CONTENT_LENGTH} 字符")
|
||||
content = content[:MAX_CONTENT_LENGTH] + "\n\n[内容已截断...]"
|
||||
|
||||
return content.strip()
|
||||
|
||||
async def _call_ai_analysis(
|
||||
self,
|
||||
content: str,
|
||||
course_title: str
|
||||
) -> AIResponse:
|
||||
"""调用 AI 进行知识点分析"""
|
||||
# 构建消息
|
||||
user_message = USER_PROMPT.format(
|
||||
course_name=course_title,
|
||||
content=content
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_message}
|
||||
]
|
||||
|
||||
# 调用 AI
|
||||
response = await self.ai_service.chat(
|
||||
messages=messages,
|
||||
temperature=0.1, # 低温度,保持输出稳定
|
||||
prompt_name="knowledge_analysis"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _parse_knowledge_points(self, ai_output: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析 AI 输出的知识点 JSON
|
||||
|
||||
使用 LLM JSON Parser 进行多层兜底解析
|
||||
"""
|
||||
# 先清洗输出
|
||||
cleaned_output, rules = clean_llm_output(ai_output)
|
||||
if rules:
|
||||
logger.debug(f"AI 输出已清洗: {rules}")
|
||||
|
||||
# 使用带 Schema 校验的解析
|
||||
knowledge_points = parse_with_fallback(
|
||||
cleaned_output,
|
||||
schema=KNOWLEDGE_POINT_SCHEMA,
|
||||
default=[],
|
||||
validate_schema=True,
|
||||
on_error="default"
|
||||
)
|
||||
|
||||
# 后处理:确保每个知识点有必要字段
|
||||
processed_points = []
|
||||
for i, kp in enumerate(knowledge_points):
|
||||
if i >= MAX_KNOWLEDGE_POINTS:
|
||||
logger.warning(f"知识点数量超过限制 {MAX_KNOWLEDGE_POINTS},截断")
|
||||
break
|
||||
|
||||
if isinstance(kp, dict):
|
||||
# 提取字段(兼容多种字段名)
|
||||
title = (
|
||||
kp.get('title') or
|
||||
kp.get('name') or
|
||||
kp.get('知识点名称') or
|
||||
f"知识点 {i + 1}"
|
||||
)
|
||||
content = (
|
||||
kp.get('content') or
|
||||
kp.get('description') or
|
||||
kp.get('知识点描述') or
|
||||
''
|
||||
)
|
||||
kp_type = (
|
||||
kp.get('type') or
|
||||
kp.get('知识点类型') or
|
||||
DEFAULT_KNOWLEDGE_TYPE
|
||||
)
|
||||
topic_relation = (
|
||||
kp.get('topic_relation') or
|
||||
kp.get('关系描述') or
|
||||
''
|
||||
)
|
||||
|
||||
if title and (content or topic_relation):
|
||||
processed_points.append({
|
||||
'title': title[:200], # 限制长度
|
||||
'content': content,
|
||||
'type': kp_type,
|
||||
'topic_relation': topic_relation,
|
||||
})
|
||||
|
||||
if not processed_points:
|
||||
logger.warning("未能解析出有效的知识点")
|
||||
|
||||
return processed_points
|
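The post-processing above accepts either the English or the Chinese field names from the model output; a quick illustration with made-up values:

```python
raw_points = [
    {"title": "玻尿酸基础", "content": "常见型号与适应症", "type": "概念"},
    {"知识点名称": "术后护理", "知识点描述": "24 小时内避免高温环境"},
]

normalized = []
for i, kp in enumerate(raw_points):
    title = kp.get("title") or kp.get("name") or kp.get("知识点名称") or f"知识点 {i + 1}"
    content = kp.get("content") or kp.get("description") or kp.get("知识点描述") or ""
    if title and content:
        normalized.append({"title": title[:200], "content": content})

# Both entries survive with the unified keys title / content.
```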
||||
|
||||
async def _delete_old_knowledge_points(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
material_id: int
|
||||
) -> int:
|
||||
"""删除资料关联的旧知识点"""
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
|
||||
result = await db.execute(
|
||||
text("DELETE FROM knowledge_points WHERE material_id = :material_id"),
|
||||
{"material_id": material_id}
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
if deleted_count > 0:
|
||||
logger.info(f"已删除旧知识点: material_id={material_id}, count={deleted_count}")
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除旧知识点失败: {e}")
|
||||
await db.rollback()
|
||||
raise
|
||||
|
||||
async def _save_knowledge_points(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
material_id: int,
|
||||
knowledge_points: List[Dict[str, Any]],
|
||||
user_id: int
|
||||
) -> int:
|
||||
"""保存知识点到数据库"""
|
||||
from app.services.course_service import knowledge_point_service
|
||||
|
||||
saved_count = 0
|
||||
|
||||
for kp_data in knowledge_points:
|
||||
try:
|
||||
kp_create = KnowledgePointCreate(
|
||||
name=kp_data['title'],
|
||||
description=kp_data.get('content', ''),
|
||||
type=kp_data.get('type', DEFAULT_KNOWLEDGE_TYPE),
|
||||
source=1, # AI 分析来源
|
||||
topic_relation=kp_data.get('topic_relation'),
|
||||
material_id=material_id
|
||||
)
|
||||
|
||||
await knowledge_point_service.create_knowledge_point(
|
||||
db=db,
|
||||
course_id=course_id,
|
||||
point_in=kp_create,
|
||||
created_by=user_id
|
||||
)
|
||||
saved_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"保存单个知识点失败: title={kp_data.get('title')}, error={e}"
|
||||
)
|
||||
continue
|
||||
|
||||
return saved_count
|
||||
|
||||
async def reanalyze_course_materials(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
course_title: str,
|
||||
user_id: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
重新分析课程的所有资料
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
course_title: 课程标题
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
分析结果汇总
|
||||
"""
|
||||
try:
|
||||
from app.services.course_service import course_service
|
||||
|
||||
# 获取课程的所有资料
|
||||
materials = await course_service.get_course_materials(db, course_id=course_id)
|
||||
|
||||
if not materials:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "该课程暂无资料需要分析",
|
||||
"materials_count": 0,
|
||||
"knowledge_points_count": 0
|
||||
}
|
||||
|
||||
total_knowledge_points = 0
|
||||
analysis_results = []
|
||||
|
||||
for material in materials:
|
||||
try:
|
||||
result = await self.analyze_course_material(
|
||||
db=db,
|
||||
course_id=course_id,
|
||||
material_id=material.id,
|
||||
file_url=material.file_url,
|
||||
course_title=course_title,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
kp_count = result.get('knowledge_points_count', 0)
|
||||
total_knowledge_points += kp_count
|
||||
|
||||
analysis_results.append({
|
||||
"material_id": material.id,
|
||||
"material_name": material.name,
|
||||
"success": True,
|
||||
"knowledge_points_count": kp_count
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"资料分析失败: material_id={material.id}, error={e}"
|
||||
)
|
||||
analysis_results.append({
|
||||
"material_id": material.id,
|
||||
"material_name": material.name,
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
success_count = sum(1 for r in analysis_results if r['success'])
|
||||
|
||||
logger.info(
|
||||
f"课程资料重新分析完成 - course_id: {course_id}, "
|
||||
f"materials: {len(materials)}, success: {success_count}, "
|
||||
f"total_knowledge_points: {total_knowledge_points}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"materials_count": len(materials),
|
||||
"success_count": success_count,
|
||||
"knowledge_points_count": total_knowledge_points,
|
||||
"analysis_results": analysis_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"课程资料重新分析失败 - course_id: {course_id}, error: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise ExternalServiceError(f"重新分析失败: {e}")
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
knowledge_analysis_service_v2 = KnowledgeAnalysisServiceV2()
|
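A hedged sketch of driving the module-level instance for a whole course; the course values are placeholders and the AsyncSession is assumed to be supplied by the caller.

```python
from sqlalchemy.ext.asyncio import AsyncSession


async def rebuild_course_knowledge(db: AsyncSession) -> None:
    summary = await knowledge_analysis_service_v2.reanalyze_course_materials(
        db=db,
        course_id=1,
        course_title="医美产品知识",
        user_id=1,
    )
    # Per-material failures are collected in summary["analysis_results"].
    print(summary["materials_count"], summary["knowledge_points_count"])
```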
707
backend/app/services/ai/llm_json_parser.py
Normal file
@@ -0,0 +1,707 @@
|
||||
"""
|
||||
LLM JSON Parser - 大模型 JSON 输出解析器
|
||||
|
||||
功能:
|
||||
- 使用 json-repair 库修复 AI 输出的 JSON
|
||||
- 处理中文标点、尾部逗号、Python 风格等问题
|
||||
- Schema 校验确保数据完整性
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
from app.services.ai.llm_json_parser import parse_llm_json, parse_with_fallback
|
||||
|
||||
# 简单解析
|
||||
result = parse_llm_json(ai_response)
|
||||
|
||||
# 带 Schema 校验和默认值
|
||||
result = parse_with_fallback(
|
||||
ai_response,
|
||||
schema=MY_SCHEMA,
|
||||
default=[]
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 尝试导入 json-repair
|
||||
try:
|
||||
from json_repair import loads as json_repair_loads
|
||||
from json_repair import repair_json
|
||||
HAS_JSON_REPAIR = True
|
||||
except ImportError:
|
||||
HAS_JSON_REPAIR = False
|
||||
logger.warning("json-repair 未安装,将使用内置修复逻辑")
|
||||
|
||||
# 尝试导入 jsonschema
|
||||
try:
|
||||
from jsonschema import validate, ValidationError, Draft7Validator
|
||||
HAS_JSONSCHEMA = True
|
||||
except ImportError:
|
||||
HAS_JSONSCHEMA = False
|
||||
logger.warning("jsonschema 未安装,将跳过 Schema 校验")
|
||||
|
||||
|
||||
# ==================== 异常类 ====================
|
||||
|
||||
class JSONParseError(Exception):
|
||||
"""JSON 解析错误基类"""
|
||||
def __init__(self, message: str, raw_text: str = "", issues: List[dict] = None):
|
||||
super().__init__(message)
|
||||
self.raw_text = raw_text
|
||||
self.issues = issues or []
|
||||
|
||||
|
||||
class JSONUnrecoverableError(JSONParseError):
|
||||
"""不可恢复的 JSON 错误"""
|
||||
pass
|
||||
|
||||
|
||||
# ==================== 解析结果 ====================
|
||||
|
||||
@dataclass
|
||||
class ParseResult:
|
||||
"""解析结果"""
|
||||
success: bool
|
||||
data: Any = None
|
||||
method: str = "" # direct / json_repair / preprocessed / fixed / completed / default
|
||||
issues: List[dict] = field(default_factory=list)
|
||||
raw_text: str = ""
|
||||
error: str = ""
|
||||
|
||||
|
||||
# ==================== 核心解析函数 ====================
|
||||
|
||||
def parse_llm_json(
|
||||
text: str,
|
||||
*,
|
||||
strict: bool = False,
|
||||
return_result: bool = False
|
||||
) -> Union[Any, ParseResult]:
|
||||
"""
|
||||
智能解析 LLM 输出的 JSON
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
strict: 严格模式,不进行自动修复
|
||||
return_result: 返回 ParseResult 对象而非直接数据
|
||||
|
||||
Returns:
|
||||
解析后的 JSON 对象,或 ParseResult(如果 return_result=True)
|
||||
|
||||
Raises:
|
||||
JSONUnrecoverableError: 所有修复尝试都失败
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
if return_result:
|
||||
return ParseResult(success=False, error="Empty input")
|
||||
raise JSONUnrecoverableError("Empty input", text)
|
||||
|
||||
text = text.strip()
|
||||
issues = []
|
||||
|
||||
# 第一层:直接解析
|
||||
try:
|
||||
data = json.loads(text)
|
||||
result = ParseResult(success=True, data=data, method="direct", raw_text=text)
|
||||
return result if return_result else data
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if strict:
|
||||
if return_result:
|
||||
return ParseResult(success=False, error="Strict mode: direct parse failed", raw_text=text)
|
||||
raise JSONUnrecoverableError("Strict mode: direct parse failed", text)
|
||||
|
||||
# 第二层:使用 json-repair(推荐)
|
||||
if HAS_JSON_REPAIR:
|
||||
try:
|
||||
data = json_repair_loads(text)
|
||||
issues.append({"type": "json_repair", "action": "Auto-repaired by json-repair library"})
|
||||
result = ParseResult(success=True, data=data, method="json_repair", issues=issues, raw_text=text)
|
||||
return result if return_result else data
|
||||
except Exception as e:
|
||||
logger.debug(f"json-repair 修复失败: {e}")
|
||||
|
||||
# 第三层:预处理(提取代码块、清理文字)
|
||||
preprocessed = _preprocess_text(text)
|
||||
if preprocessed != text:
|
||||
try:
|
||||
data = json.loads(preprocessed)
|
||||
issues.append({"type": "preprocessed", "action": "Extracted JSON from text"})
|
||||
result = ParseResult(success=True, data=data, method="preprocessed", issues=issues, raw_text=text)
|
||||
return result if return_result else data
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 再次尝试 json-repair
|
||||
if HAS_JSON_REPAIR:
|
||||
try:
|
||||
data = json_repair_loads(preprocessed)
|
||||
issues.append({"type": "json_repair_preprocessed", "action": "Repaired after preprocessing"})
|
||||
result = ParseResult(success=True, data=data, method="json_repair", issues=issues, raw_text=text)
|
||||
return result if return_result else data
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 第四层:自动修复
|
||||
fixed, fix_issues = _fix_json_format(preprocessed)
|
||||
issues.extend(fix_issues)
|
||||
|
||||
if fixed != preprocessed:
|
||||
try:
|
||||
data = json.loads(fixed)
|
||||
result = ParseResult(success=True, data=data, method="fixed", issues=issues, raw_text=text)
|
||||
return result if return_result else data
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 第五层:尝试补全截断的 JSON
|
||||
completed = _try_complete_json(fixed)
|
||||
if completed:
|
||||
try:
|
||||
data = json.loads(completed)
|
||||
issues.append({"type": "completed", "action": "Auto-completed truncated JSON"})
|
||||
result = ParseResult(success=True, data=data, method="completed", issues=issues, raw_text=text)
|
||||
return result if return_result else data
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 所有尝试都失败
|
||||
diagnosis = diagnose_json_error(fixed)
|
||||
if return_result:
|
||||
return ParseResult(
|
||||
success=False,
|
||||
method="failed",
|
||||
issues=issues + diagnosis.get("issues", []),
|
||||
raw_text=text,
|
||||
error=f"All parse attempts failed. Issues: {diagnosis}"
|
||||
)
|
||||
raise JSONUnrecoverableError(f"All parse attempts failed: {diagnosis}", text, issues)
|
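A small self-check of the layered fallback above. Which layer fires depends on whether the optional json-repair package is installed, so the method is only printed, not asserted.

```python
messy = '{"score": 88, "tags": ["礼貌", "专业",], }'   # trailing commas break json.loads

result = parse_llm_json(messy, return_result=True)
print(result.success)        # True once a repair layer strips the trailing commas
print(result.method)         # "json_repair" or "fixed", depending on installed deps
print(result.data["tags"])   # ['礼貌', '专业']
```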
||||
|
||||
|
||||
def parse_with_fallback(
|
||||
raw_text: str,
|
||||
schema: dict = None,
|
||||
default: Any = None,
|
||||
*,
|
||||
validate_schema: bool = True,
|
||||
on_error: str = "default" # "default" / "raise" / "none"
|
||||
) -> Any:
|
||||
"""
|
||||
带兜底的 JSON 解析
|
||||
|
||||
Args:
|
||||
raw_text: 原始文本
|
||||
schema: JSON Schema(可选)
|
||||
default: 默认值
|
||||
validate_schema: 是否进行 Schema 校验
|
||||
on_error: 错误处理方式
|
||||
|
||||
Returns:
|
||||
解析后的数据或默认值
|
||||
"""
|
||||
try:
|
||||
result = parse_llm_json(raw_text, return_result=True)
|
||||
|
||||
if not result.success:
|
||||
logger.warning(f"JSON 解析失败: {result.error}")
|
||||
if on_error == "raise":
|
||||
raise JSONUnrecoverableError(result.error, raw_text, result.issues)
|
||||
elif on_error == "none":
|
||||
return None
|
||||
return default
|
||||
|
||||
data = result.data
|
||||
|
||||
# Schema 校验
|
||||
if validate_schema and schema and HAS_JSONSCHEMA:
|
||||
is_valid, errors = validate_json_schema(data, schema)
|
||||
if not is_valid:
|
||||
logger.warning(f"Schema 校验失败: {errors}")
|
||||
if on_error == "raise":
|
||||
raise JSONUnrecoverableError(f"Schema validation failed: {errors}", raw_text)
|
||||
elif on_error == "none":
|
||||
return None
|
||||
return default
|
||||
|
||||
# 记录解析方法
|
||||
if result.method != "direct":
|
||||
logger.info(f"JSON 解析成功: method={result.method}, issues={result.issues}")
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"JSON 解析异常: {e}")
|
||||
if on_error == "raise":
|
||||
raise
|
||||
elif on_error == "none":
|
||||
return None
|
||||
return default
|
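An illustrative schema-guarded call; the schema below is a stand-in rather than one of the real module schemas, and the second result assumes the optional jsonschema dependency is installed so validation actually runs.

```python
DEMO_SCHEMA = {
    "type": "object",
    "required": ["total_score"],
    "properties": {"total_score": {"type": "number"}},
}

ok = parse_with_fallback('{"total_score": 90}', schema=DEMO_SCHEMA, default={})
bad = parse_with_fallback('{"total_score": "high"}', schema=DEMO_SCHEMA, default={})

print(ok)    # {'total_score': 90}
print(bad)   # {} -- fails schema validation, so the default comes back
```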
||||
|
||||
|
||||
# ==================== 预处理函数 ====================
|
||||
|
||||
def _preprocess_text(text: str) -> str:
|
||||
"""预处理文本:提取代码块、清理前后文字"""
|
||||
# 移除 BOM
|
||||
text = text.lstrip('\ufeff')
|
||||
|
||||
# 移除零宽字符
|
||||
text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text)
|
||||
|
||||
# 提取 Markdown 代码块
|
||||
patterns = [
|
||||
r'```json\s*([\s\S]*?)\s*```',
|
||||
r'```\s*([\s\S]*?)\s*```',
|
||||
r'`([^`]+)`',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
extracted = match.group(1).strip()
|
||||
if extracted.startswith(('{', '[')):
|
||||
text = extracted
|
||||
break
|
||||
|
||||
# 找到 JSON 边界
|
||||
text = _find_json_boundaries(text)
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _find_json_boundaries(text: str) -> str:
|
||||
"""找到 JSON 的起止位置"""
|
||||
# 找第一个 { 或 [
|
||||
start = -1
|
||||
for i, c in enumerate(text):
|
||||
if c in '{[':
|
||||
start = i
|
||||
break
|
||||
|
||||
if start == -1:
|
||||
return text
|
||||
|
||||
# 找最后一个匹配的 } 或 ]
|
||||
depth = 0
|
||||
end = -1
|
||||
in_string = False
|
||||
escape = False
|
||||
|
||||
for i in range(start, len(text)):
|
||||
c = text[i]
|
||||
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
|
||||
if c == '\\':
|
||||
escape = True
|
||||
continue
|
||||
|
||||
if c == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
continue
|
||||
|
||||
if c in '{[':
|
||||
depth += 1
|
||||
elif c in '}]':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
end = i + 1
|
||||
break
|
||||
|
||||
if end == -1:
|
||||
# 找最后一个 } 或 ]
|
||||
for i in range(len(text) - 1, start, -1):
|
||||
if text[i] in '}]':
|
||||
end = i + 1
|
||||
break
|
||||
|
||||
if end > start:
|
||||
return text[start:end]
|
||||
|
||||
return text[start:]
|
||||
|
||||
|
||||
# ==================== 修复函数 ====================
|
||||
|
||||
def _fix_json_format(text: str) -> Tuple[str, List[dict]]:
|
||||
"""修复常见 JSON 格式问题"""
|
||||
issues = []
|
||||
|
||||
# 1. 中文标点转英文
|
||||
cn_punctuation = {
|
||||
',': ',', '。': '.', ':': ':', ';': ';',
|
||||
'"': '"', '"': '"', ''': "'", ''': "'",
|
||||
'【': '[', '】': ']', '(': '(', ')': ')',
|
||||
'｛': '{', '｝': '}',
|
||||
}
|
||||
for cn, en in cn_punctuation.items():
|
||||
if cn in text:
|
||||
text = text.replace(cn, en)
|
||||
issues.append({"type": "chinese_punctuation", "from": cn, "to": en})
|
||||
|
||||
# 2. 移除注释
|
||||
if '//' in text:
|
||||
text = re.sub(r'//[^\n]*', '', text)
|
||||
issues.append({"type": "removed_comments", "style": "single-line"})
|
||||
|
||||
if '/*' in text:
|
||||
text = re.sub(r'/\*[\s\S]*?\*/', '', text)
|
||||
issues.append({"type": "removed_comments", "style": "multi-line"})
|
||||
|
||||
# 3. Python 风格转 JSON
|
||||
python_replacements = [
|
||||
(r'\bTrue\b', 'true'),
|
||||
(r'\bFalse\b', 'false'),
|
||||
(r'\bNone\b', 'null'),
|
||||
]
|
||||
for pattern, replacement in python_replacements:
|
||||
if re.search(pattern, text):
|
||||
text = re.sub(pattern, replacement, text)
|
||||
issues.append({"type": "python_style", "from": pattern, "to": replacement})
|
||||
|
||||
# 4. 移除尾部逗号
|
||||
trailing_comma_patterns = [
|
||||
(r',(\s*})', r'\1'),
|
||||
(r',(\s*\])', r'\1'),
|
||||
]
|
||||
for pattern, replacement in trailing_comma_patterns:
|
||||
if re.search(pattern, text):
|
||||
text = re.sub(pattern, replacement, text)
|
||||
issues.append({"type": "trailing_comma", "action": "removed"})
|
||||
|
||||
# 5. 修复单引号(谨慎处理)
|
||||
if text.count("'") > text.count('"') and re.match(r"^\s*\{?\s*'", text):
|
||||
text = re.sub(r"'([^']*)'(\s*:)", r'"\1"\2', text)
|
||||
text = re.sub(r":\s*'([^']*)'", r': "\1"', text)
|
||||
issues.append({"type": "single_quotes", "action": "replaced"})
|
||||
|
||||
return text, issues
|
||||
|
||||
|
||||
def _try_complete_json(text: str) -> Optional[str]:
|
||||
"""尝试补全截断的 JSON"""
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# 统计括号
|
||||
stack = []
|
||||
in_string = False
|
||||
escape = False
|
||||
|
||||
for c in text:
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
|
||||
if c == '\\':
|
||||
escape = True
|
||||
continue
|
||||
|
||||
if c == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
continue
|
||||
|
||||
if c in '{[':
|
||||
stack.append(c)
|
||||
elif c == '}':
|
||||
if stack and stack[-1] == '{':
|
||||
stack.pop()
|
||||
elif c == ']':
|
||||
if stack and stack[-1] == '[':
|
||||
stack.pop()
|
||||
|
||||
if not stack:
|
||||
return None # 已经平衡了
|
||||
|
||||
# 如果在字符串中,先闭合字符串
|
||||
if in_string:
|
||||
text += '"'
|
||||
|
||||
# 补全括号
|
||||
completion = ""
|
||||
for bracket in reversed(stack):
|
||||
if bracket == '{':
|
||||
completion += '}'
|
||||
elif bracket == '[':
|
||||
completion += ']'
|
||||
|
||||
return text + completion
|
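The completion step only appends the missing closing quote and brackets for the internal helper above; it cannot recover dropped keys or values. For example:

```python
truncated = '{"analysis": {"total_score": 82, "suggestions": ["多用开放式提问"'
print(_try_complete_json(truncated))
# -> {"analysis": {"total_score": 82, "suggestions": ["多用开放式提问"]}}
```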
||||
|
||||
|
||||
# ==================== Schema 校验 ====================
|
||||
|
||||
def validate_json_schema(data: Any, schema: dict) -> Tuple[bool, List[dict]]:
|
||||
"""
|
||||
校验 JSON 是否符合 Schema
|
||||
|
||||
Returns:
|
||||
(is_valid, errors)
|
||||
"""
|
||||
if not HAS_JSONSCHEMA:
|
||||
logger.warning("jsonschema 未安装,跳过校验")
|
||||
return True, []
|
||||
|
||||
try:
|
||||
validator = Draft7Validator(schema)
|
||||
errors = list(validator.iter_errors(data))
|
||||
|
||||
if errors:
|
||||
error_messages = [
|
||||
{
|
||||
"path": list(e.absolute_path),
|
||||
"message": e.message,
|
||||
"validator": e.validator
|
||||
}
|
||||
for e in errors
|
||||
]
|
||||
return False, error_messages
|
||||
|
||||
return True, []
|
||||
|
||||
except Exception as e:
|
||||
return False, [{"message": str(e)}]
|
||||
|
||||
|
||||
# ==================== 诊断函数 ====================
|
||||
|
||||
def diagnose_json_error(text: str) -> dict:
|
||||
"""诊断 JSON 错误"""
|
||||
issues = []
|
||||
|
||||
# 检查是否为空
|
||||
if not text or not text.strip():
|
||||
issues.append({
|
||||
"type": "empty_input",
|
||||
"severity": "critical",
|
||||
"suggestion": "输入为空"
|
||||
})
|
||||
return {"issues": issues, "fixable": False}
|
||||
|
||||
# 检查中文标点
|
||||
cn_punctuation = [',', '。', ':', ';', '“', '”', '‘', '’']
|
||||
for p in cn_punctuation:
|
||||
if p in text:
|
||||
issues.append({
|
||||
"type": "chinese_punctuation",
|
||||
"char": p,
|
||||
"severity": "low",
|
||||
"suggestion": f"将 {p} 替换为对应英文标点"
|
||||
})
|
||||
|
||||
# 检查代码块包裹
|
||||
if '```' in text:
|
||||
issues.append({
|
||||
"type": "markdown_wrapped",
|
||||
"severity": "low",
|
||||
"suggestion": "需要提取代码块内容"
|
||||
})
|
||||
|
||||
# 检查注释
|
||||
if '//' in text or '/*' in text:
|
||||
issues.append({
|
||||
"type": "has_comments",
|
||||
"severity": "low",
|
||||
"suggestion": "需要移除注释"
|
||||
})
|
||||
|
||||
# 检查 Python 风格
|
||||
if re.search(r'\b(True|False|None)\b', text):
|
||||
issues.append({
|
||||
"type": "python_style",
|
||||
"severity": "low",
|
||||
"suggestion": "将 True/False/None 转为 true/false/null"
|
||||
})
|
||||
|
||||
# 检查尾部逗号
|
||||
if re.search(r',\s*[}\]]', text):
|
||||
issues.append({
|
||||
"type": "trailing_comma",
|
||||
"severity": "low",
|
||||
"suggestion": "移除 } 或 ] 前的逗号"
|
||||
})
|
||||
|
||||
# 检查括号平衡
|
||||
open_braces = text.count('{') - text.count('}')
|
||||
open_brackets = text.count('[') - text.count(']')
|
||||
|
||||
if open_braces > 0:
|
||||
issues.append({
|
||||
"type": "unclosed_brace",
|
||||
"count": open_braces,
|
||||
"severity": "medium",
|
||||
"suggestion": f"缺少 {open_braces} 个 }}"
|
||||
})
|
||||
elif open_braces < 0:
|
||||
issues.append({
|
||||
"type": "extra_brace",
|
||||
"count": -open_braces,
|
||||
"severity": "medium",
|
||||
"suggestion": f"多余 {-open_braces} 个 }}"
|
||||
})
|
||||
|
||||
if open_brackets > 0:
|
||||
issues.append({
|
||||
"type": "unclosed_bracket",
|
||||
"count": open_brackets,
|
||||
"severity": "medium",
|
||||
"suggestion": f"缺少 {open_brackets} 个 ]"
|
||||
})
|
||||
elif open_brackets < 0:
|
||||
issues.append({
|
||||
"type": "extra_bracket",
|
||||
"count": -open_brackets,
|
||||
"severity": "medium",
|
||||
"suggestion": f"多余 {-open_brackets} 个 ]"
|
||||
})
|
||||
|
||||
# 检查引号平衡
|
||||
quote_count = text.count('"')
|
||||
if quote_count % 2 != 0:
|
||||
issues.append({
|
||||
"type": "unbalanced_quotes",
|
||||
"severity": "high",
|
||||
"suggestion": "引号数量不平衡,可能有未闭合的字符串"
|
||||
})
|
||||
|
||||
# 判断是否可修复
|
||||
fixable_types = {
|
||||
"chinese_punctuation", "markdown_wrapped", "has_comments",
|
||||
"python_style", "trailing_comma", "unclosed_brace", "unclosed_bracket"
|
||||
}
|
||||
fixable = all(i["type"] in fixable_types for i in issues)
|
||||
|
||||
return {
|
||||
"issues": issues,
|
||||
"issue_count": len(issues),
|
||||
"fixable": fixable,
|
||||
"severity": max(
|
||||
(i.get("severity", "low") for i in issues),
|
||||
key=lambda x: {"low": 1, "medium": 2, "high": 3, "critical": 4}.get(x, 0),
|
||||
default="low"
|
||||
)
|
||||
}
|
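Example diagnosis for a fragment with a comment and unclosed brackets (illustrative input only):

```python
report = diagnose_json_error('{"score": 90, // 备注\n "tags": ["a", "b"')
print(report["fixable"], report["severity"])   # expected: True medium
for issue in report["issues"]:
    print(issue["type"])   # has_comments / unclosed_brace / unclosed_bracket
```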
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
def safe_json_loads(text: str, default: Any = None) -> Any:
|
||||
"""安全的 json.loads,失败返回默认值"""
|
||||
try:
|
||||
return parse_llm_json(text)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def extract_json_from_text(text: str) -> Optional[str]:
|
||||
"""从文本中提取 JSON 字符串"""
|
||||
preprocessed = _preprocess_text(text)
|
||||
fixed, _ = _fix_json_format(preprocessed)
|
||||
|
||||
try:
|
||||
json.loads(fixed)
|
||||
return fixed
|
||||
except Exception:
|
||||
completed = _try_complete_json(fixed)
|
||||
if completed:
|
||||
try:
|
||||
json.loads(completed)
|
||||
return completed
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def clean_llm_output(text: str) -> Tuple[str, List[str]]:
|
||||
"""
|
||||
清洗大模型输出,返回清洗后的文本和应用的清洗规则
|
||||
|
||||
Args:
|
||||
text: 原始输出文本
|
||||
|
||||
Returns:
|
||||
(cleaned_text, applied_rules)
|
||||
"""
|
||||
if not text:
|
||||
return "", ["empty_input"]
|
||||
|
||||
applied_rules = []
|
||||
original = text
|
||||
|
||||
# 1. 去除 BOM 头
|
||||
if text.startswith('\ufeff'):
|
||||
text = text.lstrip('\ufeff')
|
||||
applied_rules.append("removed_bom")
|
||||
|
||||
# 2. 去除 ANSI 转义序列
|
||||
ansi_pattern = re.compile(r'\x1b\[[0-9;]*m')
|
||||
if ansi_pattern.search(text):
|
||||
text = ansi_pattern.sub('', text)
|
||||
applied_rules.append("removed_ansi")
|
||||
|
||||
# 3. 去除首尾空白
|
||||
text = text.strip()
|
||||
|
||||
# 4. 去除开头的客套话
|
||||
polite_patterns = [
|
||||
r'^好的[,,。.]?\s*',
|
||||
r'^当然[,,。.]?\s*',
|
||||
r'^没问题[,,。.]?\s*',
|
||||
r'^根据您的要求[,,。.]?\s*',
|
||||
r'^以下是.*?[::]\s*',
|
||||
r'^分析结果如下[::]\s*',
|
||||
r'^我来为您.*?[::]\s*',
|
||||
r'^这是.*?结果[::]\s*',
|
||||
]
|
||||
for pattern in polite_patterns:
|
||||
if re.match(pattern, text, re.IGNORECASE):
|
||||
text = re.sub(pattern, '', text, flags=re.IGNORECASE)
|
||||
applied_rules.append("removed_polite_prefix")
|
||||
break
|
||||
|
||||
# 5. 提取 Markdown JSON 代码块
|
||||
json_block_patterns = [
|
||||
r'```json\s*([\s\S]*?)\s*```',
|
||||
r'```\s*([\s\S]*?)\s*```',
|
||||
]
|
||||
for pattern in json_block_patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE)
|
||||
if match:
|
||||
extracted = match.group(1).strip()
|
||||
if extracted.startswith(('{', '[')):
|
||||
text = extracted
|
||||
applied_rules.append("extracted_code_block")
|
||||
break
|
||||
|
||||
# 6. 处理零宽字符
|
||||
zero_width = re.compile(r'[\u200b\u200c\u200d\ufeff]')
|
||||
if zero_width.search(text):
|
||||
text = zero_width.sub('', text)
|
||||
applied_rules.append("removed_zero_width")
|
||||
|
||||
return text.strip(), applied_rules
|
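Typical cleanup of a chatty model reply before parsing; the rule names come from the branches above.

```python
raw = '好的,以下是分析结果:\n```json\n{"total_score": 82}\n```'
cleaned, rules = clean_llm_output(raw)
print(cleaned)   # {"total_score": 82}
print(rules)     # ['removed_polite_prefix', 'extracted_code_block']
```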
377
backend/app/services/ai/practice_analysis_service.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
陪练分析报告服务 - Python 原生实现
|
||||
|
||||
功能:
|
||||
- 分析陪练对话历史
|
||||
- 生成综合评分、能力维度评估
|
||||
- 提供对话标注和改进建议
|
||||
|
||||
提供稳定可靠的陪练分析报告生成能力。
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .ai_service import AIService, AIResponse
|
||||
from .llm_json_parser import parse_with_fallback, clean_llm_output
|
||||
from .prompts.practice_analysis_prompts import (
|
||||
SYSTEM_PROMPT,
|
||||
USER_PROMPT,
|
||||
PRACTICE_ANALYSIS_SCHEMA,
|
||||
SCORE_BREAKDOWN_ITEMS,
|
||||
ABILITY_DIMENSIONS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ==================== 数据结构 ====================
|
||||
|
||||
@dataclass
|
||||
class ScoreBreakdownItem:
|
||||
"""分数细分项"""
|
||||
name: str
|
||||
score: float
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbilityDimensionItem:
|
||||
"""能力维度项"""
|
||||
name: str
|
||||
score: float
|
||||
feedback: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class DialogueAnnotation:
|
||||
"""对话标注"""
|
||||
sequence: int
|
||||
tags: List[str]
|
||||
comment: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Suggestion:
|
||||
"""改进建议"""
|
||||
title: str
|
||||
content: str
|
||||
example: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PracticeAnalysisResult:
|
||||
"""陪练分析结果"""
|
||||
success: bool
|
||||
total_score: float = 0.0
|
||||
score_breakdown: List[ScoreBreakdownItem] = field(default_factory=list)
|
||||
ability_dimensions: List[AbilityDimensionItem] = field(default_factory=list)
|
||||
dialogue_annotations: List[DialogueAnnotation] = field(default_factory=list)
|
||||
suggestions: List[Suggestion] = field(default_factory=list)
|
||||
ai_provider: str = ""
|
||||
ai_model: str = ""
|
||||
ai_tokens: int = 0
|
||||
ai_latency_ms: int = 0
|
||||
error: str = ""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典(兼容原有数据格式)"""
|
||||
return {
|
||||
"analysis": {
|
||||
"total_score": self.total_score,
|
||||
"score_breakdown": [
|
||||
{"name": s.name, "score": s.score, "description": s.description}
|
||||
for s in self.score_breakdown
|
||||
],
|
||||
"ability_dimensions": [
|
||||
{"name": d.name, "score": d.score, "feedback": d.feedback}
|
||||
for d in self.ability_dimensions
|
||||
],
|
||||
"dialogue_annotations": [
|
||||
{"sequence": a.sequence, "tags": a.tags, "comment": a.comment}
|
||||
for a in self.dialogue_annotations
|
||||
],
|
||||
"suggestions": [
|
||||
{"title": s.title, "content": s.content, "example": s.example}
|
||||
for s in self.suggestions
|
||||
],
|
||||
},
|
||||
"ai_provider": self.ai_provider,
|
||||
"ai_model": self.ai_model,
|
||||
"ai_tokens": self.ai_tokens,
|
||||
"ai_latency_ms": self.ai_latency_ms,
|
||||
}
|
||||
|
||||
def to_db_format(self) -> Dict[str, Any]:
|
||||
"""转换为数据库存储格式(兼容 PracticeReport 模型)"""
|
||||
return {
|
||||
"total_score": int(self.total_score),
|
||||
"score_breakdown": [
|
||||
{"name": s.name, "score": s.score, "description": s.description}
|
||||
for s in self.score_breakdown
|
||||
],
|
||||
"ability_dimensions": [
|
||||
{"name": d.name, "score": d.score, "feedback": d.feedback}
|
||||
for d in self.ability_dimensions
|
||||
],
|
||||
"dialogue_review": [
|
||||
{"sequence": a.sequence, "tags": a.tags, "comment": a.comment}
|
||||
for a in self.dialogue_annotations
|
||||
],
|
||||
"suggestions": [
|
||||
{"title": s.title, "content": s.content, "example": s.example}
|
||||
for s in self.suggestions
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ==================== 服务类 ====================
|
||||
|
||||
class PracticeAnalysisService:
|
||||
"""
|
||||
陪练分析报告服务
|
||||
|
||||
使用 Python 原生实现。
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
service = PracticeAnalysisService()
|
||||
result = await service.analyze(
|
||||
db=db_session, # 传入 db_session 用于记录调用日志
|
||||
dialogue_history=[
|
||||
{"speaker": "user", "content": "您好,我想咨询一下..."},
|
||||
{"speaker": "ai", "content": "您好!很高兴为您服务..."}
|
||||
]
|
||||
)
|
||||
print(result.total_score)
|
||||
print(result.suggestions)
|
||||
```
|
||||
"""
|
||||
|
||||
MODULE_CODE = "practice_analysis"
|
||||
|
||||
async def analyze(
|
||||
self,
|
||||
dialogue_history: List[Dict[str, Any]],
|
||||
db: Any = None # 数据库会话,用于记录 AI 调用日志
|
||||
) -> PracticeAnalysisResult:
|
||||
"""
|
||||
分析陪练对话
|
||||
|
||||
Args:
|
||||
dialogue_history: 对话历史列表,每项包含 speaker, content, timestamp 等字段
|
||||
db: 数据库会话,用于记录调用日志(符合 AI 接入规范)
|
||||
|
||||
Returns:
|
||||
PracticeAnalysisResult 分析结果
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始分析陪练对话 - 对话轮次: {len(dialogue_history)}")
|
||||
|
||||
# 1. 验证输入
|
||||
if not dialogue_history or len(dialogue_history) < 2:
|
||||
return PracticeAnalysisResult(
|
||||
success=False,
|
||||
error="对话记录太少,无法生成分析报告(至少需要2轮对话)"
|
||||
)
|
||||
|
||||
# 2. 格式化对话历史
|
||||
dialogue_text = self._format_dialogue_history(dialogue_history)
|
||||
|
||||
# 3. 创建 AIService 实例(传入 db_session 用于记录调用日志)
|
||||
self._ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
|
||||
|
||||
# 4. 调用 AI 分析
|
||||
ai_response = await self._call_ai_analysis(dialogue_text)
|
||||
|
||||
logger.info(
|
||||
f"AI 分析完成 - provider: {ai_response.provider}, "
|
||||
f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
|
||||
)
|
||||
|
||||
# 5. 解析 JSON 结果
|
||||
analysis_data = self._parse_analysis_result(ai_response.content)
|
||||
|
||||
# 6. 构建返回结果
|
||||
result = PracticeAnalysisResult(
|
||||
success=True,
|
||||
total_score=analysis_data.get("total_score", 0),
|
||||
score_breakdown=[
|
||||
ScoreBreakdownItem(
|
||||
name=s.get("name", ""),
|
||||
score=s.get("score", 0),
|
||||
description=s.get("description", "")
|
||||
)
|
||||
for s in analysis_data.get("score_breakdown", [])
|
||||
],
|
||||
ability_dimensions=[
|
||||
AbilityDimensionItem(
|
||||
name=d.get("name", ""),
|
||||
score=d.get("score", 0),
|
||||
feedback=d.get("feedback", "")
|
||||
)
|
||||
for d in analysis_data.get("ability_dimensions", [])
|
||||
],
|
||||
dialogue_annotations=[
|
||||
DialogueAnnotation(
|
||||
sequence=a.get("sequence", 0),
|
||||
tags=a.get("tags", []),
|
||||
comment=a.get("comment", "")
|
||||
)
|
||||
for a in analysis_data.get("dialogue_annotations", [])
|
||||
],
|
||||
suggestions=[
|
||||
Suggestion(
|
||||
title=s.get("title", ""),
|
||||
content=s.get("content", ""),
|
||||
example=s.get("example", "")
|
||||
)
|
||||
for s in analysis_data.get("suggestions", [])
|
||||
],
|
||||
ai_provider=ai_response.provider,
|
||||
ai_model=ai_response.model,
|
||||
ai_tokens=ai_response.total_tokens,
|
||||
ai_latency_ms=ai_response.latency_ms,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"陪练分析完成 - total_score: {result.total_score}, "
|
||||
f"annotations: {len(result.dialogue_annotations)}, "
|
||||
f"suggestions: {len(result.suggestions)}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"陪练分析失败: {e}", exc_info=True)
|
||||
return PracticeAnalysisResult(
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def _format_dialogue_history(self, dialogue_history: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
格式化对话历史为文本
|
||||
|
||||
Args:
|
||||
dialogue_history: 对话历史列表
|
||||
|
||||
Returns:
|
||||
格式化后的对话文本
|
||||
"""
|
||||
lines = []
|
||||
for i, d in enumerate(dialogue_history, 1):
|
||||
speaker = d.get('speaker', 'unknown')
|
||||
content = d.get('content', '')
|
||||
|
||||
# 统一说话者标识
|
||||
if speaker in ['user', 'employee', 'consultant', '员工', '用户']:
|
||||
speaker_label = '员工'
|
||||
elif speaker in ['ai', 'customer', 'client', '顾客', '客户', 'AI']:
|
||||
speaker_label = '顾客'
|
||||
else:
|
||||
speaker_label = speaker
|
||||
|
||||
lines.append(f"[{i}] {speaker_label}: {content}")
|
||||
|
||||
return '\n'.join(lines)
|
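The normalization above folds several speaker aliases into just two labels before the transcript is sent to the model; for example:

```python
history = [
    {"speaker": "user", "content": "您好,我想了解一下光子嫩肤"},
    {"speaker": "ai", "content": "可以的,请问您之前做过类似项目吗?"},
]
print(PracticeAnalysisService()._format_dialogue_history(history))
# [1] 员工: 您好,我想了解一下光子嫩肤
# [2] 顾客: 可以的,请问您之前做过类似项目吗?
```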
||||
|
||||
async def _call_ai_analysis(self, dialogue_text: str) -> AIResponse:
|
||||
"""调用 AI 进行分析"""
|
||||
user_message = USER_PROMPT.format(dialogue_history=dialogue_text)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_message}
|
||||
]
|
||||
|
||||
response = await self._ai_service.chat(
|
||||
messages=messages,
|
||||
temperature=0.7,
|
||||
prompt_name="practice_analysis"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _parse_analysis_result(self, ai_output: str) -> Dict[str, Any]:
|
||||
"""
|
||||
解析 AI 输出的分析结果 JSON
|
||||
|
||||
使用 LLM JSON Parser 进行多层兜底解析
|
||||
"""
|
||||
# 先清洗输出
|
||||
cleaned_output, rules = clean_llm_output(ai_output)
|
||||
if rules:
|
||||
logger.debug(f"AI 输出已清洗: {rules}")
|
||||
|
||||
# 使用带 Schema 校验的解析
|
||||
parsed = parse_with_fallback(
|
||||
cleaned_output,
|
||||
schema=PRACTICE_ANALYSIS_SCHEMA,
|
||||
default={"analysis": {}},
|
||||
validate_schema=True,
|
||||
on_error="default"
|
||||
)
|
||||
|
||||
# 提取 analysis 部分
|
||||
analysis = parsed.get("analysis", {})
|
||||
|
||||
# 确保 score_breakdown 完整
|
||||
existing_breakdown = {s.get("name") for s in analysis.get("score_breakdown", [])}
|
||||
for item_name in SCORE_BREAKDOWN_ITEMS:
|
||||
if item_name not in existing_breakdown:
|
||||
logger.warning(f"缺少分数维度: {item_name},使用默认值")
|
||||
analysis.setdefault("score_breakdown", []).append({
|
||||
"name": item_name,
|
||||
"score": 75,
|
||||
"description": "暂无详细评价"
|
||||
})
|
||||
|
||||
# 确保 ability_dimensions 完整
|
||||
existing_dims = {d.get("name") for d in analysis.get("ability_dimensions", [])}
|
||||
for dim_name in ABILITY_DIMENSIONS:
|
||||
if dim_name not in existing_dims:
|
||||
logger.warning(f"缺少能力维度: {dim_name},使用默认值")
|
||||
analysis.setdefault("ability_dimensions", []).append({
|
||||
"name": dim_name,
|
||||
"score": 75,
|
||||
"feedback": "暂无详细评价"
|
||||
})
|
||||
|
||||
# 确保有建议
|
||||
if not analysis.get("suggestions"):
|
||||
analysis["suggestions"] = [
|
||||
{
|
||||
"title": "持续练习",
|
||||
"content": "建议继续进行陪练练习,提升整体表现",
|
||||
"example": "每周进行2-3次陪练,针对薄弱环节重点练习"
|
||||
}
|
||||
]
|
||||
|
||||
return analysis
|
||||
|
||||
|
||||
# ==================== 全局实例 ====================
|
||||
|
||||
practice_analysis_service = PracticeAnalysisService()
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
async def analyze_practice_session(
|
||||
dialogue_history: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
便捷函数:分析陪练会话
|
||||
|
||||
Args:
|
||||
dialogue_history: 对话历史列表
|
||||
|
||||
Returns:
|
||||
分析结果字典(兼容原有格式)
|
||||
"""
|
||||
result = await practice_analysis_service.analyze(dialogue_history)
|
||||
return result.to_dict()
|
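A sketch of the end-to-end call a report endpoint might make. Passing the session lets the underlying AIService record the call log; `db=None` is also accepted by the signature above. The function name is illustrative.

```python
from typing import Any, Dict, List


async def build_practice_report(db: Any, dialogue_history: List[Dict[str, Any]]) -> Dict[str, Any]:
    result = await practice_analysis_service.analyze(dialogue_history, db=db)
    if not result.success:
        raise RuntimeError(result.error)
    return result.to_db_format()   # fields aligned with the PracticeReport model
```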
||||
|
||||
379
backend/app/services/ai/practice_scene_service.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""
|
||||
陪练场景准备服务 - Python 原生实现
|
||||
|
||||
功能:
|
||||
- 根据课程ID获取知识点
|
||||
- 调用 AI 生成陪练场景配置
|
||||
- 解析并返回结构化场景数据
|
||||
|
||||
提供稳定可靠的陪练场景提取能力。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import ExternalServiceError
|
||||
|
||||
from .ai_service import AIService, AIResponse
|
||||
from .llm_json_parser import parse_with_fallback, clean_llm_output
|
||||
from .prompts.practice_scene_prompts import (
|
||||
SYSTEM_PROMPT,
|
||||
USER_PROMPT,
|
||||
PRACTICE_SCENE_SCHEMA,
|
||||
DEFAULT_SCENE_TYPE,
|
||||
DEFAULT_DIFFICULTY,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ==================== 数据结构 ====================
|
||||
|
||||
@dataclass
|
||||
class PracticeScene:
|
||||
"""陪练场景数据结构"""
|
||||
name: str
|
||||
description: str
|
||||
background: str
|
||||
ai_role: str
|
||||
objectives: List[str]
|
||||
keywords: List[str]
|
||||
type: str = DEFAULT_SCENE_TYPE
|
||||
difficulty: str = DEFAULT_DIFFICULTY
|
||||
|
||||
|
||||
@dataclass
|
||||
class PracticeSceneResult:
|
||||
"""陪练场景生成结果"""
|
||||
success: bool
|
||||
scene: Optional[PracticeScene] = None
|
||||
raw_response: Dict[str, Any] = field(default_factory=dict)
|
||||
ai_provider: str = ""
|
||||
ai_model: str = ""
|
||||
ai_tokens: int = 0
|
||||
ai_latency_ms: int = 0
|
||||
knowledge_points_count: int = 0
|
||||
error: str = ""
|
||||
|
||||
|
||||
# ==================== 服务类 ====================
|
||||
|
||||
class PracticeSceneService:
|
||||
"""
|
||||
陪练场景准备服务
|
||||
|
||||
使用 Python 原生实现。
|
||||
|
||||
使用示例:
|
||||
```python
|
||||
service = PracticeSceneService()
|
||||
result = await service.prepare_practice_knowledge(
|
||||
db=db_session,
|
||||
course_id=1
|
||||
)
|
||||
if result.success:
|
||||
print(result.scene.name)
|
||||
print(result.scene.objectives)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务"""
|
||||
self.ai_service = AIService(module_code="practice_scene")
|
||||
|
||||
async def prepare_practice_knowledge(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int
|
||||
) -> PracticeSceneResult:
|
||||
"""
|
||||
准备陪练所需的知识内容并生成场景
|
||||
|
||||
陪练知识准备的 Python 实现。
|
||||
|
||||
Args:
|
||||
db: 数据库会话(支持多租户,由调用方传入对应租户的数据库连接)
|
||||
course_id: 课程ID
|
||||
|
||||
Returns:
|
||||
PracticeSceneResult: 包含场景配置和元信息的结果对象
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始陪练知识准备 - course_id: {course_id}")
|
||||
|
||||
# 1. 查询知识点
|
||||
knowledge_points = await self._fetch_knowledge_points(db, course_id)
|
||||
|
||||
if not knowledge_points:
|
||||
logger.warning(f"课程没有知识点 - course_id: {course_id}")
|
||||
return PracticeSceneResult(
|
||||
success=False,
|
||||
error=f"课程 {course_id} 没有可用的知识点"
|
||||
)
|
||||
|
||||
logger.info(f"获取到 {len(knowledge_points)} 个知识点 - course_id: {course_id}")
|
||||
|
||||
# 2. 格式化知识点为文本
|
||||
knowledge_text = self._format_knowledge_points(knowledge_points)
|
||||
|
||||
# 3. 调用 AI 生成场景
|
||||
ai_response = await self._call_ai_generation(knowledge_text)
|
||||
|
||||
logger.info(
|
||||
f"AI 生成完成 - provider: {ai_response.provider}, "
|
||||
f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
|
||||
)
|
||||
|
||||
# 4. 解析 JSON 结果
|
||||
scene_data = self._parse_scene_response(ai_response.content)
|
||||
|
||||
if not scene_data:
|
||||
logger.error(f"场景解析失败 - course_id: {course_id}")
|
||||
return PracticeSceneResult(
|
||||
success=False,
|
||||
raw_response={"ai_output": ai_response.content},
|
||||
ai_provider=ai_response.provider,
|
||||
ai_model=ai_response.model,
|
||||
ai_tokens=ai_response.total_tokens,
|
||||
ai_latency_ms=ai_response.latency_ms,
|
||||
knowledge_points_count=len(knowledge_points),
|
||||
error="AI 输出解析失败"
|
||||
)
|
||||
|
||||
# 5. 构建场景对象
|
||||
scene = self._build_scene_object(scene_data)
|
||||
|
||||
logger.info(
|
||||
f"陪练场景生成成功 - course_id: {course_id}, "
|
||||
f"scene_name: {scene.name}, type: {scene.type}"
|
||||
)
|
||||
|
||||
return PracticeSceneResult(
|
||||
success=True,
|
||||
scene=scene,
|
||||
raw_response=scene_data,
|
||||
ai_provider=ai_response.provider,
|
||||
ai_model=ai_response.model,
|
||||
ai_tokens=ai_response.total_tokens,
|
||||
ai_latency_ms=ai_response.latency_ms,
|
||||
knowledge_points_count=len(knowledge_points)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"陪练知识准备失败 - course_id: {course_id}, error: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return PracticeSceneResult(
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
async def _fetch_knowledge_points(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从数据库获取课程知识点
|
||||
"""
|
||||
# 知识点查询 SQL:
|
||||
# SELECT kp.name, kp.description
|
||||
# FROM knowledge_points kp
|
||||
# INNER JOIN course_materials cm ON kp.material_id = cm.id
|
||||
# WHERE kp.course_id = {course_id}
|
||||
# AND kp.is_deleted = 0
|
||||
# AND cm.is_deleted = 0
|
||||
# ORDER BY kp.id;
|
||||
|
||||
sql = text("""
|
||||
SELECT kp.name, kp.description
|
||||
FROM knowledge_points kp
|
||||
INNER JOIN course_materials cm ON kp.material_id = cm.id
|
||||
WHERE kp.course_id = :course_id
|
||||
AND kp.is_deleted = 0
|
||||
AND cm.is_deleted = 0
|
||||
ORDER BY kp.id
|
||||
""")
|
||||
|
||||
try:
|
||||
result = await db.execute(sql, {"course_id": course_id})
|
||||
rows = result.fetchall()
|
||||
|
||||
knowledge_points = []
|
||||
for row in rows:
|
||||
knowledge_points.append({
|
||||
"name": row[0],
|
||||
"description": row[1] or ""
|
||||
})
|
||||
|
||||
return knowledge_points
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询知识点失败: {e}")
|
||||
raise ExternalServiceError(f"数据库查询失败: {e}")
|
||||
|
||||
def _format_knowledge_points(self, knowledge_points: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
将知识点列表格式化为文本
|
||||
|
||||
Args:
|
||||
knowledge_points: 知识点列表
|
||||
|
||||
Returns:
|
||||
格式化后的文本
|
||||
"""
|
||||
lines = []
|
||||
for i, kp in enumerate(knowledge_points, 1):
|
||||
name = kp.get("name", "")
|
||||
description = kp.get("description", "")
|
||||
|
||||
if description:
|
||||
lines.append(f"{i}. {name}\n {description}")
|
||||
else:
|
||||
lines.append(f"{i}. {name}")
|
||||
|
||||
return "\n\n".join(lines)
|
||||
|
||||
async def _call_ai_generation(self, knowledge_text: str) -> AIResponse:
|
||||
"""
|
||||
调用 AI 生成陪练场景
|
||||
|
||||
Args:
|
||||
knowledge_text: 格式化后的知识点文本
|
||||
|
||||
Returns:
|
||||
AI 响应对象
|
||||
"""
|
||||
# 构建用户消息
|
||||
user_message = USER_PROMPT.format(knowledge_points=knowledge_text)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_message}
|
||||
]
|
||||
|
||||
# 调用 AI(自动降级:4sapi.com → OpenRouter)
|
||||
response = await self.ai_service.chat(
|
||||
messages=messages,
|
||||
temperature=0.7, # 适中的创意性
|
||||
prompt_name="practice_scene_generation"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _parse_scene_response(self, ai_output: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
解析 AI 输出的场景 JSON
|
||||
|
||||
使用 LLM JSON Parser 进行多层兜底解析
|
||||
|
||||
Args:
|
||||
ai_output: AI 原始输出
|
||||
|
||||
Returns:
|
||||
解析后的字典,失败返回 None
|
||||
"""
|
||||
# 先清洗输出
|
||||
cleaned_output, rules = clean_llm_output(ai_output)
|
||||
if rules:
|
||||
logger.debug(f"AI 输出已清洗: {rules}")
|
||||
|
||||
# 使用带 Schema 校验的解析
|
||||
result = parse_with_fallback(
|
||||
cleaned_output,
|
||||
schema=PRACTICE_SCENE_SCHEMA,
|
||||
default=None,
|
||||
validate_schema=True,
|
||||
on_error="none"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _build_scene_object(self, scene_data: Dict[str, Any]) -> PracticeScene:
|
||||
"""
|
||||
从解析的字典构建场景对象
|
||||
|
||||
Args:
|
||||
scene_data: 解析后的场景数据
|
||||
|
||||
Returns:
|
||||
PracticeScene 对象
|
||||
"""
|
||||
# 提取 scene 字段(JSON 格式为 {"scene": {...}})
|
||||
scene = scene_data.get("scene", scene_data)
|
||||
|
||||
return PracticeScene(
|
||||
name=scene.get("name", "陪练场景"),
|
||||
description=scene.get("description", ""),
|
||||
background=scene.get("background", ""),
|
||||
ai_role=scene.get("ai_role", "AI扮演客户"),
|
||||
objectives=scene.get("objectives", []),
|
||||
keywords=scene.get("keywords", []),
|
||||
type=scene.get("type", DEFAULT_SCENE_TYPE),
|
||||
difficulty=scene.get("difficulty", DEFAULT_DIFFICULTY)
|
||||
)
|
||||
|
||||
def scene_to_dict(self, scene: PracticeScene) -> Dict[str, Any]:
|
||||
"""
|
||||
将场景对象转换为字典
|
||||
|
||||
便于 API 响应序列化
|
||||
|
||||
Args:
|
||||
scene: PracticeScene 对象
|
||||
|
||||
Returns:
|
||||
字典格式的场景数据
|
||||
"""
|
||||
return {
|
||||
"scene": {
|
||||
"name": scene.name,
|
||||
"description": scene.description,
|
||||
"background": scene.background,
|
||||
"ai_role": scene.ai_role,
|
||||
"objectives": scene.objectives,
|
||||
"keywords": scene.keywords,
|
||||
"type": scene.type,
|
||||
"difficulty": scene.difficulty
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ==================== 全局实例 ====================
|
||||
|
||||
practice_scene_service = PracticeSceneService()
|
||||
|
||||
|
||||
# ==================== 便捷函数 ====================
|
||||
|
||||
async def prepare_practice_knowledge(
|
||||
db: AsyncSession,
|
||||
course_id: int
|
||||
) -> PracticeSceneResult:
|
||||
"""
|
||||
准备陪练所需的知识内容(便捷函数)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
|
||||
Returns:
|
||||
PracticeSceneResult 结果对象
|
||||
"""
|
||||
return await practice_scene_service.prepare_practice_knowledge(db, course_id)
|
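A short sketch of how an API layer might consume the convenience function above; the course_id and function name are placeholders for illustration.

```python
from typing import Any, Dict

from sqlalchemy.ext.asyncio import AsyncSession


async def get_scene_payload(db: AsyncSession, course_id: int = 1) -> Dict[str, Any]:
    result = await prepare_practice_knowledge(db, course_id)
    if not result.success:
        raise RuntimeError(result.error or "scene generation failed")
    return practice_scene_service.scene_to_dict(result.scene)
```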
57
backend/app/services/ai/prompts/__init__.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
提示词模板模块
|
||||
|
||||
遵循瑞小美提示词规范
|
||||
"""
|
||||
|
||||
from .knowledge_analysis_prompts import (
|
||||
PROMPT_META as KNOWLEDGE_ANALYSIS_PROMPT_META,
|
||||
SYSTEM_PROMPT as KNOWLEDGE_ANALYSIS_SYSTEM_PROMPT,
|
||||
USER_PROMPT as KNOWLEDGE_ANALYSIS_USER_PROMPT,
|
||||
KNOWLEDGE_POINT_SCHEMA,
|
||||
)
|
||||
|
||||
from .exam_generator_prompts import (
|
||||
PROMPT_META as EXAM_GENERATOR_PROMPT_META,
|
||||
SYSTEM_PROMPT as EXAM_GENERATOR_SYSTEM_PROMPT,
|
||||
USER_PROMPT as EXAM_GENERATOR_USER_PROMPT,
|
||||
MISTAKE_REGEN_SYSTEM_PROMPT,
|
||||
MISTAKE_REGEN_USER_PROMPT,
|
||||
QUESTION_SCHEMA,
|
||||
QUESTION_TYPES,
|
||||
DEFAULT_QUESTION_COUNTS,
|
||||
DEFAULT_DIFFICULTY_LEVEL,
|
||||
)
|
||||
|
||||
from .ability_analysis_prompts import (
|
||||
PROMPT_META as ABILITY_ANALYSIS_PROMPT_META,
|
||||
SYSTEM_PROMPT as ABILITY_ANALYSIS_SYSTEM_PROMPT,
|
||||
USER_PROMPT as ABILITY_ANALYSIS_USER_PROMPT,
|
||||
ABILITY_ANALYSIS_SCHEMA,
|
||||
ABILITY_DIMENSIONS,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Knowledge Analysis Prompts
|
||||
"KNOWLEDGE_ANALYSIS_PROMPT_META",
|
||||
"KNOWLEDGE_ANALYSIS_SYSTEM_PROMPT",
|
||||
"KNOWLEDGE_ANALYSIS_USER_PROMPT",
|
||||
"KNOWLEDGE_POINT_SCHEMA",
|
||||
# Exam Generator Prompts
|
||||
"EXAM_GENERATOR_PROMPT_META",
|
||||
"EXAM_GENERATOR_SYSTEM_PROMPT",
|
||||
"EXAM_GENERATOR_USER_PROMPT",
|
||||
"MISTAKE_REGEN_SYSTEM_PROMPT",
|
||||
"MISTAKE_REGEN_USER_PROMPT",
|
||||
"QUESTION_SCHEMA",
|
||||
"QUESTION_TYPES",
|
||||
"DEFAULT_QUESTION_COUNTS",
|
||||
"DEFAULT_DIFFICULTY_LEVEL",
|
||||
# Ability Analysis Prompts
|
||||
"ABILITY_ANALYSIS_PROMPT_META",
|
||||
"ABILITY_ANALYSIS_SYSTEM_PROMPT",
|
||||
"ABILITY_ANALYSIS_USER_PROMPT",
|
||||
"ABILITY_ANALYSIS_SCHEMA",
|
||||
"ABILITY_DIMENSIONS",
|
||||
]
|
||||
|
||||
215
backend/app/services/ai/prompts/ability_analysis_prompts.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
智能工牌能力分析与课程推荐提示词模板
|
||||
|
||||
功能:分析员工与顾客的对话记录,评估能力维度得分,并推荐适合的课程
|
||||
"""
|
||||
|
||||
# ==================== 元数据 ====================
|
||||
|
||||
PROMPT_META = {
|
||||
"name": "ability_analysis",
|
||||
"display_name": "智能工牌能力分析",
|
||||
"description": "分析员工与顾客对话,评估多维度能力得分,推荐个性化课程",
|
||||
"module": "kaopeilian",
|
||||
"variables": ["dialogue_history", "user_info", "courses"],
|
||||
"version": "1.0.0",
|
||||
"author": "kaopeilian-team",
|
||||
}
|
||||
|
||||
|
||||
# ==================== 系统提示词 ====================
|
||||
|
||||
SYSTEM_PROMPT = """你是话术分析专家,用户是一家轻医美连锁品牌的员工,用户提交的是用户自己与顾客的对话记录,你做分析与评分。并严格按照以下格式输出。并根据课程列表,为该用户提供选课建议。
|
||||
|
||||
输出标准:
|
||||
{
|
||||
"analysis": {
|
||||
"total_score": 82,
|
||||
"ability_dimensions": [
|
||||
{
|
||||
"name": "专业知识",
|
||||
"score": 88,
|
||||
"feedback": "产品知识扎实,能准确回答客户问题。建议:继续深化对新产品的了解。"
|
||||
},
|
||||
{
|
||||
"name": "沟通技巧",
|
||||
"score": 92,
|
||||
"feedback": "语言表达清晰流畅,善于倾听客户需求。建议:可以多使用开放式问题引导。"
|
||||
},
|
||||
{
|
||||
"name": "操作技能",
|
||||
"score": 85,
|
||||
"feedback": "基本操作熟练,流程规范。建议:提升复杂场景的应对速度。"
|
||||
},
|
||||
{
|
||||
"name": "客户服务",
|
||||
"score": 90,
|
||||
"feedback": "服务态度优秀,客户体验良好。建议:进一步提升个性化服务能力。"
|
||||
},
|
||||
{
|
||||
"name": "安全意识",
|
||||
"score": 79,
|
||||
"feedback": "基本安全规范掌握,但在细节提醒上还可加强。"
|
||||
},
|
||||
{
|
||||
"name": "应变能力",
|
||||
"score": 76,
|
||||
"feedback": "面对突发情况反应较快,但处理方式可以更灵活多样。"
|
||||
}
|
||||
],
|
||||
"course_recommendations": [
|
||||
{
|
||||
"course_id": 5,
|
||||
"course_name": "应变能力提升训练营",
|
||||
"recommendation_reason": "该课程专注于提升应变能力,包含大量实战案例分析和模拟演练,针对您当前的薄弱环节(应变能力76分)设计。通过学习可提升15分左右。",
|
||||
"priority": "high",
|
||||
"match_score": 95
|
||||
},
|
||||
{
|
||||
"course_id": 3,
|
||||
"course_name": "安全规范与操作标准",
|
||||
"recommendation_reason": "系统讲解安全规范和操作标准,通过案例教学帮助建立安全意识。当前您的安全意识得分为79分,通过本课程学习预计可提升12分。",
|
||||
"priority": "high",
|
||||
"match_score": 88
|
||||
},
|
||||
{
|
||||
"course_id": 7,
|
||||
"course_name": "高级销售技巧",
|
||||
"recommendation_reason": "进阶课程,帮助您将已有的沟通优势(92分)转化为更高级的销售技能,进一步巩固客户服务能力(90分)。",
|
||||
"priority": "medium",
|
||||
"match_score": 82
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
## 输出要求(严格执行)
|
||||
1. 直接输出纯净的 JSON,不要包含 Markdown 标记(如 ```json)
|
||||
2. 不要包含任何解释性文字
|
||||
3. 能力维度必须包含:专业知识、沟通技巧、操作技能、客户服务、安全意识、应变能力
|
||||
4. 课程推荐必须来自提供的课程列表,使用真实的 course_id
|
||||
5. 推荐课程数量:1-5个,优先推荐能补齐短板的课程
|
||||
6. priority 取值:high(得分<80的薄弱项)、medium(得分80-85)、low(锦上添花)
|
||||
|
||||
## 评分标准
|
||||
- 90-100:优秀
|
||||
- 80-89:良好
|
||||
- 70-79:一般
|
||||
- 60-69:需改进
|
||||
- <60:亟需提升"""
|
||||
|
||||
|
||||
# ==================== 用户提示词模板 ====================
|
||||
|
||||
USER_PROMPT = """对话记录:{dialogue_history}
|
||||
|
||||
---
|
||||
|
||||
用户的信息和岗位:{user_info}
|
||||
|
||||
---
|
||||
|
||||
所有可选课程:{courses}"""
|
||||
|
||||
|
||||
# ==================== JSON Schema ====================
|
||||
|
||||
ABILITY_ANALYSIS_SCHEMA = {
|
||||
"type": "object",
|
||||
"required": ["analysis"],
|
||||
"properties": {
|
||||
"analysis": {
|
||||
"type": "object",
|
||||
"required": ["total_score", "ability_dimensions", "course_recommendations"],
|
||||
"properties": {
|
||||
"total_score": {
|
||||
"type": "number",
|
||||
"description": "总体评分(0-100)",
|
||||
"minimum": 0,
|
||||
"maximum": 100
|
||||
},
|
||||
"ability_dimensions": {
|
||||
"type": "array",
|
||||
"description": "能力维度评分列表",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["name", "score", "feedback"],
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "能力维度名称"
|
||||
},
|
||||
"score": {
|
||||
"type": "number",
|
||||
"description": "该维度得分(0-100)",
|
||||
"minimum": 0,
|
||||
"maximum": 100
|
||||
},
|
||||
"feedback": {
|
||||
"type": "string",
|
||||
"description": "该维度的反馈和建议"
|
||||
}
|
||||
}
|
||||
},
|
||||
"minItems": 1
|
||||
},
|
||||
"course_recommendations": {
|
||||
"type": "array",
|
||||
"description": "课程推荐列表",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["course_id", "course_name", "recommendation_reason", "priority", "match_score"],
|
||||
"properties": {
|
||||
"course_id": {
|
||||
"type": "integer",
|
||||
"description": "课程ID"
|
||||
},
|
||||
"course_name": {
|
||||
"type": "string",
|
||||
"description": "课程名称"
|
||||
},
|
||||
"recommendation_reason": {
|
||||
"type": "string",
|
||||
"description": "推荐理由"
|
||||
},
|
||||
"priority": {
|
||||
"type": "string",
|
||||
"description": "推荐优先级",
|
||||
"enum": ["high", "medium", "low"]
|
||||
},
|
||||
"match_score": {
|
||||
"type": "number",
|
||||
"description": "匹配度得分(0-100)",
|
||||
"minimum": 0,
|
||||
"maximum": 100
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ==================== 能力维度常量 ====================
|
||||
|
||||
ABILITY_DIMENSIONS = [
|
||||
"专业知识",
|
||||
"沟通技巧",
|
||||
"操作技能",
|
||||
"客户服务",
|
||||
"安全意识",
|
||||
"应变能力",
|
||||
]
|
||||
|
||||
PRIORITY_LEVELS = ["high", "medium", "low"]
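下面是一个对模型输出做解析与校验的最小示意(假设项目依赖中包含 jsonschema;实际的解析、重试与降级逻辑由 AI 服务层负责):

import json
from jsonschema import validate, ValidationError  # 假设:项目依赖中包含 jsonschema


def parse_ability_analysis(raw_text: str) -> dict:
    """解析并校验模型返回的能力分析 JSON(示意)"""
    data = json.loads(raw_text)
    validate(instance=data, schema=ABILITY_ANALYSIS_SCHEMA)
    # 额外检查六个固定能力维度是否齐全
    names = {d["name"] for d in data["analysis"]["ability_dimensions"]}
    missing = [n for n in ABILITY_DIMENSIONS if n not in names]
    if missing:
        raise ValidationError(f"缺少能力维度: {missing}")
    return data["analysis"]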
48
backend/app/services/ai/prompts/answer_judge_prompts.py
Normal file
@@ -0,0 +1,48 @@
"""
答案判断器提示词模板

功能:判断填空题与问答题是否回答正确
"""


# ==================== 元数据 ====================

PROMPT_META = {
    "name": "answer_judge",
    "display_name": "答案判断器",
    "description": "判断填空题与问答题的答案是否正确",
    "module": "kaopeilian",
    "variables": ["question", "correct_answer", "user_answer", "analysis"],
    "version": "1.0.0",
    "author": "kaopeilian-team",
}


# ==================== 系统提示词 ====================

SYSTEM_PROMPT = """你是一个答案判断器,根据用户提交的答案,比对题目、答案、解析。给出正确或错误的判断。

注意:仅输出"正确"或"错误",无需更多字符和说明。"""


# ==================== 用户提示词模板 ====================

USER_PROMPT = """题目:{question}。
正确答案:{correct_answer}。
解析:{analysis}。

考生的回答:{user_answer}。"""


# ==================== 判断结果常量 ====================

CORRECT_KEYWORDS = ["正确", "correct", "true", "yes", "对", "是"]
INCORRECT_KEYWORDS = ["错误", "incorrect", "false", "no", "wrong", "不正确", "错"]
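下面是将判断器回复归一化为布尔值的最小示意(先匹配否定关键词的顺序是本示例的假设策略):

def parse_judge_result(reply: str) -> bool:
    """将判断器的回复归一化为布尔值(示意)"""
    text = reply.strip().lower()
    # 先匹配否定词,避免"不正确"中的"正确"被误判为肯定
    if any(k.lower() in text for k in INCORRECT_KEYWORDS):
        return False
    if any(k.lower() in text for k in CORRECT_KEYWORDS):
        return True
    # 无法识别时按错误处理(保守策略,示意)
    return False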
74
backend/app/services/ai/prompts/course_chat_prompts.py
Normal file
@@ -0,0 +1,74 @@
"""
课程对话提示词模板

功能:基于课程知识点进行智能问答
"""


# ==================== 元数据 ====================

PROMPT_META = {
    "name": "course_chat",
    "display_name": "与课程对话",
    "description": "基于课程知识点内容,为用户提供智能问答服务",
    "module": "kaopeilian",
    "variables": ["knowledge_base", "query"],
    "version": "2.0.0",
    "author": "kaopeilian-team",
}


# ==================== 系统提示词 ====================

SYSTEM_PROMPT = """你是知识拆解专家,精通以下知识库(课程)内容。请根据用户的问题,从知识库中找到最相关的信息,进行深入分析后,用简洁清晰的语言回答用户。为用户提供与课程对话的服务。

回答要求:

1. 直接针对问题核心,避免冗长铺垫
2. 使用通俗易懂的语言,必要时举例说明
3. 突出关键要点,帮助用户快速理解
4. 如果知识库中没有相关内容,请如实告知

知识库:
{knowledge_base}"""


# ==================== 用户提示词模板 ====================

USER_PROMPT = """{query}"""


# ==================== 知识库格式模板 ====================

KNOWLEDGE_ITEM_TEMPLATE = """【{name}】
{description}
"""


# ==================== 配置常量 ====================

# 会话历史窗口大小(保留最近 N 轮对话)
CONVERSATION_WINDOW_SIZE = 10

# 会话 TTL(秒)- 30 分钟
CONVERSATION_TTL = 1800

# 最大知识点数量
MAX_KNOWLEDGE_POINTS = 50

# 知识库最大字符数
MAX_KNOWLEDGE_BASE_LENGTH = 50000

# 默认模型
DEFAULT_CHAT_MODEL = "gemini-3-flash-preview"

# 温度参数(对话场景使用较高温度)
DEFAULT_TEMPERATURE = 0.7
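下面是用上述模板与常量拼接知识库文本的最小示意(知识点字段名 title/content 为假设,实际以 KnowledgePoint 模型为准):

def build_knowledge_base(points: list) -> str:
    """将知识点列表拼接为系统提示词中的知识库文本(示意)"""
    items = [
        KNOWLEDGE_ITEM_TEMPLATE.format(
            name=p.get("title", ""),
            description=p.get("content", ""),
        )
        for p in points[:MAX_KNOWLEDGE_POINTS]
    ]
    # 截断到最大字符数,避免提示词超长
    knowledge_base = "\n".join(items)[:MAX_KNOWLEDGE_BASE_LENGTH]
    return SYSTEM_PROMPT.format(knowledge_base=knowledge_base)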
300
backend/app/services/ai/prompts/exam_generator_prompts.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
试题生成器提示词模板
|
||||
|
||||
功能:根据岗位和知识点动态生成考试题目
|
||||
"""
|
||||
|
||||
# ==================== 元数据 ====================
|
||||
|
||||
PROMPT_META = {
|
||||
"name": "exam_generator",
|
||||
"display_name": "试题生成器",
|
||||
"description": "根据课程知识点和岗位特征,动态生成考试题目(单选、多选、判断、填空、问答)",
|
||||
"module": "kaopeilian",
|
||||
"variables": [
|
||||
"total_count",
|
||||
"single_choice_count",
|
||||
"multiple_choice_count",
|
||||
"true_false_count",
|
||||
"fill_blank_count",
|
||||
"essay_count",
|
||||
"difficulty_level",
|
||||
"position_info",
|
||||
"knowledge_points",
|
||||
],
|
||||
"version": "2.0.0",
|
||||
"author": "kaopeilian-team",
|
||||
}
|
||||
|
||||
|
||||
# ==================== 系统提示词(第一轮出题) ====================
|
||||
|
||||
SYSTEM_PROMPT = """## 角色
|
||||
你是一位经验丰富的考试出题专家,能够依据用户提供的知识内容,结合用户的岗位特征,随机地生成{total_count}题考题。你会以专业、严谨且清晰的方式出题。
|
||||
|
||||
## 输出{single_choice_count}道单选题
|
||||
1、每道题目只能有 1 个正确答案。
|
||||
2、干扰项要具有合理性和迷惑性,且所有选项必须与主题相关。
|
||||
3、答案解析要简明扼要,说明选择理由。
|
||||
4、为每道题记录出题来源的知识点 id。
|
||||
5、请以 JSON 格式输出。
|
||||
6、为每道题输出一个序号。
|
||||
|
||||
### 输出结构:
|
||||
{{
|
||||
"num": "题号",
|
||||
"type": "single_choice",
|
||||
"topic": {{
|
||||
"title": "清晰完整的题目描述",
|
||||
"options": {{
|
||||
"opt1": "A:符合语境的选项",
|
||||
"opt2": "B:符合语境的选项",
|
||||
"opt3": "C:符合语境的选项",
|
||||
"opt4": "D:符合语境的选项"
|
||||
}}
|
||||
}},
|
||||
"knowledge_point_id": "出题来源知识点的id",
|
||||
"correct": "其中一个选项的全部原文",
|
||||
"analysis": "准确的答案解析,包含选择原因和知识点说明"
|
||||
}}
|
||||
|
||||
- 严格按照以上格式输出
|
||||
|
||||
## 输出{multiple_choice_count}道多选题
|
||||
1、每道题目有多个正确答案。
|
||||
2、"type": "multiple_choice"
|
||||
3、其它事项同单选题。
|
||||
|
||||
## 输出{true_false_count}道判断题
|
||||
1、每道题目只有 "正确" 或 "错误" 两种答案。
|
||||
2、题目表述应明确清晰,避免歧义。
|
||||
3、题目应直接陈述事实或观点,便于做出是非判断。
|
||||
4、其它事项同单选题。
|
||||
|
||||
### 输出结构:
|
||||
{{
|
||||
"num": "题号",
|
||||
"type": "true_false",
|
||||
"topic": {{
|
||||
"title": "清晰完整的题目描述"
|
||||
}},
|
||||
"knowledge_point_id": " 出题来源知识点的id",
|
||||
"correct": "正确",
|
||||
"analysis": "准确的答案解析,包含判断原因和知识点说明"
|
||||
}}
|
||||
|
||||
- 严格按照以上格式输出
|
||||
|
||||
## 输出{fill_blank_count}道填空题
|
||||
1. 题干应明确完整,空缺处需用横线"___"标示,且只能有一处空缺
|
||||
2. 答案应唯一且明确,避免开放性表述
|
||||
3. 空缺长度应与答案长度大致匹配
|
||||
4. 解析需说明答案依据及相关知识点
|
||||
5. 其余要求与单选题一致
|
||||
|
||||
### 输出结构:
|
||||
{{
|
||||
"num": "题号",
|
||||
"type": "fill_blank",
|
||||
"topic": {{
|
||||
"title": "包含___空缺的题目描述"
|
||||
}},
|
||||
"knowledge_point_id": "出题来源知识点的id",
|
||||
"correct": "准确的填空答案",
|
||||
"analysis": "解析答案的依据和相关知识点说明"
|
||||
}}
|
||||
|
||||
- 严格按照以上格式输出
|
||||
|
||||
## 输出{essay_count}道问答题
|
||||
1. 问题应具体明确,限定回答范围
|
||||
2. 答案需条理清晰,突出核心要点
|
||||
3. 解析可补充扩展说明或评分要点
|
||||
4. 避免过于宽泛或需要主观发挥的问题
|
||||
5. 其余要求同单选题
|
||||
|
||||
### 输出结构:
|
||||
{{
|
||||
"num": "题号",
|
||||
"type": "essay",
|
||||
"topic": {{
|
||||
"title": "需要详细回答的问题描述"
|
||||
}},
|
||||
"knowledge_point_id": "出题来源知识点的id",
|
||||
"correct": "完整准确的参考答案(分点或连贯表述)",
|
||||
"analysis": "对答案的补充说明、评分要点或相关知识点扩展"
|
||||
}}
|
||||
|
||||
## 特殊要求
|
||||
1. 题目难度:{difficulty_level}级(5 级为最难)
|
||||
2. 避免使用模棱两可的表述
|
||||
3. 选项内容要互斥,不能有重叠
|
||||
4. 每个选项长度尽量均衡
|
||||
5. 正确答案(A、B、C、D)分布要合理,避免规律性
|
||||
6. 正确答案必须使用其中一个选项中的全部原文,严禁修改
|
||||
7. knowledge_point_id 必须是唯一的,即每道题的知识点来源只允许填一个 id。
|
||||
|
||||
## 输出格式要求
|
||||
请直接输出一个纯净的 JSON 数组(Array),不要包含 Markdown 标记(如 ```json),也不要包含任何解释性文字。
|
||||
|
||||
请按以上要求生成题目,确保每道题目质量。"""
|
||||
|
||||
|
||||
# ==================== 用户提示词模板(第一轮出题) ====================
|
||||
|
||||
USER_PROMPT = """# 请针对岗位特征、待出题的知识点内容进行出题。
|
||||
|
||||
## 岗位信息:
|
||||
|
||||
{position_info}
|
||||
|
||||
---
|
||||
|
||||
## 知识点:
|
||||
|
||||
{knowledge_points}"""
|
||||
|
||||
|
||||
# ==================== 错题重出系统提示词 ====================
|
||||
|
||||
MISTAKE_REGEN_SYSTEM_PROMPT = """## 角色
|
||||
你是一位经验丰富的考试出题专家,能够依据用户提供的错题记录,重新为用户出题。你会为每道错题重新出一题,你会以专业、严谨且清晰的方式出题。
|
||||
|
||||
## 输出单选题
|
||||
1、每道题目只能有 1 个正确答案。
|
||||
2、干扰项要具有合理性和迷惑性,且所有选项必须与主题相关。
|
||||
3、答案解析要简明扼要,说明选择理由。
|
||||
4、为每道题记录出题来源的知识点 id。
|
||||
5、请以 JSON 格式输出。
|
||||
6、为每道题输出一个序号。
|
||||
|
||||
### 输出结构:
|
||||
{{
|
||||
"num": "题号",
|
||||
"type": "single_choice",
|
||||
"topic": {{
|
||||
"title": "清晰完整的题目描述",
|
||||
"options": {{
|
||||
"opt1": "A:符合语境的选项",
|
||||
"opt2": "B:符合语境的选项",
|
||||
"opt3": "C:符合语境的选项",
|
||||
"opt4": "D:符合语境的选项"
|
||||
}}
|
||||
}},
|
||||
"knowledge_point_id": "出题来源知识点的id",
|
||||
"correct": "其中一个选项的全部原文",
|
||||
"analysis": "准确的答案解析,包含选择原因和知识点说明"
|
||||
}}
|
||||
|
||||
- 严格按照以上格式输出
|
||||
|
||||
|
||||
## 特殊要求
|
||||
1. 题目难度:{difficulty_level}级(5 级为最难)
|
||||
2. 避免使用模棱两可的表述
|
||||
3. 选项内容要互斥,不能有重叠
|
||||
4. 每个选项长度尽量均衡
|
||||
5. 正确答案(A、B、C、D)分布要合理,避免规律性
|
||||
6. 正确答案必须使用其中一个选项中的全部原文,严禁修改
|
||||
7. knowledge_point_id 必须是唯一的,即每道题的知识点来源只允许填一个 id。
|
||||
|
||||
## 输出格式要求
|
||||
请直接输出一个纯净的 JSON 数组(Array),不要包含 Markdown 标记(如 ```json),也不要包含任何解释性文字。
|
||||
|
||||
请按以上要求生成题目,确保每道题目质量。"""
|
||||
|
||||
|
||||
# ==================== 错题重出用户提示词 ====================
|
||||
|
||||
MISTAKE_REGEN_USER_PROMPT = """## 错题记录:
|
||||
|
||||
{mistake_records}"""
|
||||
|
||||
|
||||
# ==================== JSON Schema ====================
|
||||
|
||||
QUESTION_SCHEMA = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["num", "type", "topic", "correct"],
|
||||
"properties": {
|
||||
"num": {
|
||||
"oneOf": [
|
||||
{"type": "integer"},
|
||||
{"type": "string"}
|
||||
],
|
||||
"description": "题号"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": ["single_choice", "multiple_choice", "true_false", "fill_blank", "essay"],
|
||||
"description": "题目类型"
|
||||
},
|
||||
"topic": {
|
||||
"type": "object",
|
||||
"required": ["title"],
|
||||
"properties": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "题目标题"
|
||||
},
|
||||
"options": {
|
||||
"type": "object",
|
||||
"description": "选项(选择题必填)"
|
||||
}
|
||||
}
|
||||
},
|
||||
"knowledge_point_id": {
|
||||
"oneOf": [
|
||||
{"type": "integer"},
|
||||
{"type": "string"},
|
||||
{"type": "null"}
|
||||
],
|
||||
"description": "知识点ID"
|
||||
},
|
||||
"correct": {
|
||||
"type": "string",
|
||||
"description": "正确答案"
|
||||
},
|
||||
"analysis": {
|
||||
"type": "string",
|
||||
"description": "答案解析"
|
||||
}
|
||||
}
|
||||
},
|
||||
"minItems": 1,
|
||||
"maxItems": 50
|
||||
}
|
||||
|
||||
|
||||
# ==================== 题目类型常量 ====================
|
||||
|
||||
QUESTION_TYPES = {
|
||||
"single_choice": "单选题",
|
||||
"multiple_choice": "多选题",
|
||||
"true_false": "判断题",
|
||||
"fill_blank": "填空题",
|
||||
"essay": "问答题",
|
||||
}
|
||||
|
||||
# 默认题目数量配置
|
||||
DEFAULT_QUESTION_COUNTS = {
|
||||
"single_choice_count": 4,
|
||||
"multiple_choice_count": 2,
|
||||
"true_false_count": 1,
|
||||
"fill_blank_count": 2,
|
||||
"essay_count": 1,
|
||||
}
|
||||
|
||||
DEFAULT_DIFFICULTY_LEVEL = 3
|
||||
MAX_DIFFICULTY_LEVEL = 5
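下面是按题型数量与难度渲染出题系统提示词的最小示意(模板中的字面大括号已用 {{ }} 转义,因此可以直接使用 str.format;该函数为示例,并非服务层实现):

from typing import Optional


def build_exam_system_prompt(
    counts: Optional[dict] = None,
    difficulty_level: int = DEFAULT_DIFFICULTY_LEVEL,
) -> str:
    """按题型数量与难度渲染出题系统提示词(示意)"""
    counts = counts or DEFAULT_QUESTION_COUNTS
    # 各题型数量之和即总题数
    total_count = sum(counts.values())
    return SYSTEM_PROMPT.format(
        total_count=total_count,
        difficulty_level=min(difficulty_level, MAX_DIFFICULTY_LEVEL),
        **counts,
    )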
148
backend/app/services/ai/prompts/knowledge_analysis_prompts.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
知识点分析提示词模板
|
||||
|
||||
功能:从课程资料中提取知识点
|
||||
"""
|
||||
|
||||
# ==================== 元数据 ====================
|
||||
|
||||
PROMPT_META = {
|
||||
"name": "knowledge_analysis",
|
||||
"display_name": "知识点分析",
|
||||
"description": "从课程资料中提取和分析知识点,支持PDF/Word/文本等格式",
|
||||
"module": "kaopeilian",
|
||||
"variables": ["course_name", "content"],
|
||||
"version": "2.0.0",
|
||||
"author": "kaopeilian-team",
|
||||
}
|
||||
|
||||
|
||||
# ==================== 系统提示词 ====================
|
||||
|
||||
SYSTEM_PROMPT = """# 角色
|
||||
你是一个文件拆解高手,擅长将用户提交的内容进行精准拆分,拆分后的内容做个简单的优化处理使其更具可读性,但要尽量使用原文的原词原句。
|
||||
|
||||
## 技能
|
||||
### 技能 1: 内容拆分
|
||||
1. 当用户提交内容后,拆分为多段。
|
||||
2. 对拆分后的内容做简单优化,使其更具可读性,比如去掉奇怪符号(如换行符、乱码),若语句不通顺,或格式原因导致错位,则重新表达。用户可能会提交录音转文字的内容,因此可能是有错字的,注意修复这些小瑕疵。
|
||||
3. 优化过程中,尽量使用原文的原词原句,特别是话术类,必须保持原有的句式、保持原词原句,而不是重构。
|
||||
4. 注意是拆分而不是重写,不需要润色,尽量不做任何处理。
|
||||
5. 输出到 content。
|
||||
|
||||
### 技能 2: 为每一个选段概括一个标题
|
||||
1. 为每个拆分出来的选段概括一个标题,并输出到 title。
|
||||
|
||||
### 技能 3: 为每一个选段说明与主题的关联
|
||||
1. 详细说明这一段与全文核心主题的关联,并输出到 topic_relation。
|
||||
|
||||
### 技能 4: 为每一个选段打上一个类型标签
|
||||
1. 用户提交的内容很有可能是一个课程、一篇讲义、一个产品的说明书,通常是用户希望他公司的员工或高管学习的知识。
|
||||
2. 用户通常是医疗美容机构或轻医美、生活美容连锁品牌。
|
||||
3. 你要为每个选段打上一个知识类型的标签,最好是这几个类型中的一个:"理论知识", "诊断设计", "操作步骤", "沟通话术", "案例分析", "注意事项", "技巧方法", "客诉处理"。当然你也可以为这个选段匹配一个更适合的。
|
||||
|
||||
## 输出要求(严格按要求输出)
|
||||
请直接输出一个纯净的 JSON 数组(Array),不要包含 Markdown 标记(如 ```json),也不要包含任何解释性文字。格式如下:
|
||||
|
||||
[
|
||||
{
|
||||
"title": "知识点标题",
|
||||
"content": "知识点内容",
|
||||
"topic_relation": "知识点与主题的关系",
|
||||
"type": "知识点类型"
|
||||
},
|
||||
{
|
||||
"title": "第二个知识点标题",
|
||||
"content": "第二个知识点内容...",
|
||||
"topic_relation": "...",
|
||||
"type": "..."
|
||||
}
|
||||
]
|
||||
|
||||
## 限制
|
||||
- 仅围绕用户提交的内容进行拆分和关联标注,不涉及其他无关内容。
|
||||
- 拆分后的内容必须最大程度保持与原文一致。
|
||||
- 关联说明需清晰合理。
|
||||
- 不论如何,不要拆分超过 20 段!"""
|
||||
|
||||
|
||||
# ==================== 用户提示词模板 ====================
|
||||
|
||||
USER_PROMPT = """课程主题:{course_name}
|
||||
|
||||
## 用户提交的内容:
|
||||
|
||||
{content}
|
||||
|
||||
## 注意
|
||||
|
||||
- 以 JSON 格式输出
|
||||
- 不论如何,不要拆分超过 20 段!"""
|
||||
|
||||
|
||||
# ==================== JSON Schema ====================
|
||||
|
||||
KNOWLEDGE_POINT_SCHEMA = {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["title", "content", "type"],
|
||||
"properties": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "知识点标题",
|
||||
"maxLength": 200
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "知识点内容"
|
||||
},
|
||||
"topic_relation": {
|
||||
"type": "string",
|
||||
"description": "与主题的关系描述"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "知识点类型",
|
||||
"enum": [
|
||||
"理论知识",
|
||||
"诊断设计",
|
||||
"操作步骤",
|
||||
"沟通话术",
|
||||
"案例分析",
|
||||
"注意事项",
|
||||
"技巧方法",
|
||||
"客诉处理",
|
||||
"其他"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"minItems": 1,
|
||||
"maxItems": 20
|
||||
}
|
||||
|
||||
|
||||
# ==================== 知识点类型常量 ====================
|
||||
|
||||
KNOWLEDGE_POINT_TYPES = [
|
||||
"理论知识",
|
||||
"诊断设计",
|
||||
"操作步骤",
|
||||
"沟通话术",
|
||||
"案例分析",
|
||||
"注意事项",
|
||||
"技巧方法",
|
||||
"客诉处理",
|
||||
]
|
||||
|
||||
DEFAULT_KNOWLEDGE_TYPE = "理论知识"
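下面是对知识点拆分结果做校验与类型兜底的最小示意(假设项目依赖中包含 jsonschema;将"其他"归一为默认类型只是示例策略):

import json
from jsonschema import validate  # 假设:项目依赖中包含 jsonschema


def parse_knowledge_points(raw_text: str) -> list:
    """解析模型输出的知识点数组并做类型兜底(示意)"""
    points = json.loads(raw_text)
    validate(instance=points, schema=KNOWLEDGE_POINT_SCHEMA)
    for p in points:
        # schema 枚举中允许"其他",这里示意性地统一兜底为默认类型
        if p.get("type") not in KNOWLEDGE_POINT_TYPES:
            p["type"] = DEFAULT_KNOWLEDGE_TYPE
    return points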
193
backend/app/services/ai/prompts/practice_analysis_prompts.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
陪练分析报告提示词模板
|
||||
|
||||
功能:分析陪练对话,生成综合评分和改进建议
|
||||
"""
|
||||
|
||||
# ==================== 元数据 ====================
|
||||
|
||||
PROMPT_META = {
|
||||
"name": "practice_analysis",
|
||||
"display_name": "陪练分析报告",
|
||||
"description": "分析陪练对话,生成综合评分、能力维度评估、对话标注和改进建议",
|
||||
"module": "kaopeilian",
|
||||
"variables": ["dialogue_history"],
|
||||
"version": "1.0.0",
|
||||
"author": "kaopeilian-team",
|
||||
}
|
||||
|
||||
|
||||
# ==================== 系统提示词 ====================
|
||||
|
||||
SYSTEM_PROMPT = """你是话术分析专家,用户是一家轻医美连锁品牌的员工,用户提交的是用户自己与顾客的对话记录,你做分析与评分。并严格按照以下格式输出。
|
||||
|
||||
输出标准:
|
||||
{
|
||||
"analysis": {
|
||||
"total_score": 88,
|
||||
"score_breakdown": [
|
||||
{"name": "开场技巧", "score": 92, "description": "开场自然,快速建立信任"},
|
||||
{"name": "需求挖掘", "score": 90, "description": "能够有效识别客户需求"},
|
||||
{"name": "产品介绍", "score": 88, "description": "产品介绍清晰,重点突出"},
|
||||
{"name": "异议处理", "score": 85, "description": "处理客户异议还需加强"},
|
||||
{"name": "成交技巧", "score": 86, "description": "成交话术运用良好"}
|
||||
],
|
||||
"ability_dimensions": [
|
||||
{"name": "沟通表达", "score": 90, "feedback": "语言流畅,表达清晰,语调富有亲和力"},
|
||||
{"name": "倾听理解", "score": 92, "feedback": "能够准确理解客户意图,给予恰当回应"},
|
||||
{"name": "情绪控制", "score": 88, "feedback": "整体情绪稳定,面对异议时保持专业"},
|
||||
{"name": "专业知识", "score": 93, "feedback": "对医美项目知识掌握扎实"},
|
||||
{"name": "销售技巧", "score": 87, "feedback": "销售流程把控良好"},
|
||||
{"name": "应变能力", "score": 85, "feedback": "面对突发问题能够快速反应"}
|
||||
],
|
||||
"dialogue_annotations": [
|
||||
{"sequence": 1, "tags": ["亮点话术"], "comment": "开场专业,身份介绍清晰"},
|
||||
{"sequence": 3, "tags": ["金牌话术"], "comment": "巧妙引导,从客户角度出发"},
|
||||
{"sequence": 5, "tags": ["亮点话术"], "comment": "类比生动,让客户容易理解"},
|
||||
{"sequence": 7, "tags": ["金牌话术"], "comment": "专业解答,打消客户疑虑"}
|
||||
],
|
||||
"suggestions": [
|
||||
{"title": "控制语速", "content": "您的语速偏快,建议适当放慢,给客户更多思考时间", "example": "说完产品优势后,停顿2-3秒,观察客户反应"},
|
||||
{"title": "多用开放式问题", "content": "增加开放式问题的使用,更深入了解客户需求", "example": "您对未来的保障有什么期望?而不是您需要保险吗?"},
|
||||
{"title": "强化成交信号识别", "content": "客户已经表现出兴趣时,要及时推进成交", "example": "当客户问费用多少时,这是购买信号,应该立即报价并促成"}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
## 输出要求(严格执行)
|
||||
1. 直接输出纯净的 JSON,不要包含 Markdown 标记(如 ```json)
|
||||
2. 不要包含任何解释性文字
|
||||
3. score_breakdown 必须包含 5 项:开场技巧、需求挖掘、产品介绍、异议处理、成交技巧
|
||||
4. ability_dimensions 必须包含 6 项:沟通表达、倾听理解、情绪控制、专业知识、销售技巧、应变能力
|
||||
5. dialogue_annotations 标注有亮点或问题的对话轮次,tags 可选:亮点话术、金牌话术、待改进、问题话术
|
||||
6. suggestions 提供 2-4 条具体可操作的改进建议,每条包含 title、content、example
|
||||
|
||||
## 评分标准
|
||||
- 90-100:优秀
|
||||
- 80-89:良好
|
||||
- 70-79:一般
|
||||
- 60-69:需改进
|
||||
- <60:亟需提升"""
|
||||
|
||||
|
||||
# ==================== 用户提示词模板 ====================
|
||||
|
||||
USER_PROMPT = """{dialogue_history}"""
|
||||
|
||||
|
||||
# ==================== JSON Schema ====================
|
||||
|
||||
PRACTICE_ANALYSIS_SCHEMA = {
|
||||
"type": "object",
|
||||
"required": ["analysis"],
|
||||
"properties": {
|
||||
"analysis": {
|
||||
"type": "object",
|
||||
"required": ["total_score", "score_breakdown", "ability_dimensions", "dialogue_annotations", "suggestions"],
|
||||
"properties": {
|
||||
"total_score": {
|
||||
"type": "number",
|
||||
"description": "总体评分(0-100)",
|
||||
"minimum": 0,
|
||||
"maximum": 100
|
||||
},
|
||||
"score_breakdown": {
|
||||
"type": "array",
|
||||
"description": "分数细分(5项)",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["name", "score", "description"],
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "维度名称"},
|
||||
"score": {"type": "number", "description": "得分(0-100)"},
|
||||
"description": {"type": "string", "description": "评价描述"}
|
||||
}
|
||||
},
|
||||
"minItems": 5
|
||||
},
|
||||
"ability_dimensions": {
|
||||
"type": "array",
|
||||
"description": "能力维度评分(6项)",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["name", "score", "feedback"],
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "能力维度名称"},
|
||||
"score": {"type": "number", "description": "得分(0-100)"},
|
||||
"feedback": {"type": "string", "description": "反馈评语"}
|
||||
}
|
||||
},
|
||||
"minItems": 6
|
||||
},
|
||||
"dialogue_annotations": {
|
||||
"type": "array",
|
||||
"description": "对话标注",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["sequence", "tags", "comment"],
|
||||
"properties": {
|
||||
"sequence": {"type": "integer", "description": "对话轮次序号"},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"description": "标签列表",
|
||||
"items": {"type": "string"}
|
||||
},
|
||||
"comment": {"type": "string", "description": "点评内容"}
|
||||
}
|
||||
}
|
||||
},
|
||||
"suggestions": {
|
||||
"type": "array",
|
||||
"description": "改进建议",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["title", "content", "example"],
|
||||
"properties": {
|
||||
"title": {"type": "string", "description": "建议标题"},
|
||||
"content": {"type": "string", "description": "建议内容"},
|
||||
"example": {"type": "string", "description": "示例"}
|
||||
}
|
||||
},
|
||||
"minItems": 2,
|
||||
"maxItems": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ==================== 常量定义 ====================
|
||||
|
||||
SCORE_BREAKDOWN_ITEMS = [
|
||||
"开场技巧",
|
||||
"需求挖掘",
|
||||
"产品介绍",
|
||||
"异议处理",
|
||||
"成交技巧",
|
||||
]
|
||||
|
||||
ABILITY_DIMENSIONS = [
|
||||
"沟通表达",
|
||||
"倾听理解",
|
||||
"情绪控制",
|
||||
"专业知识",
|
||||
"销售技巧",
|
||||
"应变能力",
|
||||
]
|
||||
|
||||
ANNOTATION_TAGS = [
|
||||
"亮点话术",
|
||||
"金牌话术",
|
||||
"待改进",
|
||||
"问题话术",
|
||||
]
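下面是把对话轮次列表拼接为 USER_PROMPT 所需文本的最小示意(轮次字段名 role/content 为假设,实际以陪练对话记录的存储结构为准):

def format_dialogue_history(turns: list) -> str:
    """将对话轮次列表拼成分析所需的纯文本(示意)"""
    lines = [f"{t.get('role', '')}: {t.get('content', '')}" for t in turns]
    return USER_PROMPT.format(dialogue_history="\n".join(lines))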
192
backend/app/services/ai/prompts/practice_scene_prompts.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
陪练场景生成提示词模板
|
||||
|
||||
功能:根据课程知识点生成陪练场景配置
|
||||
"""
|
||||
|
||||
# ==================== 元数据 ====================
|
||||
|
||||
PROMPT_META = {
|
||||
"name": "practice_scene_generation",
|
||||
"display_name": "陪练场景生成",
|
||||
"description": "根据课程知识点生成 AI 陪练场景配置,包含场景名称、背景、AI 角色、练习目标等",
|
||||
"module": "kaopeilian",
|
||||
"variables": ["knowledge_points"],
|
||||
"version": "1.0.0",
|
||||
"author": "kaopeilian-team",
|
||||
}
|
||||
|
||||
|
||||
# ==================== 系统提示词 ====================
|
||||
|
||||
SYSTEM_PROMPT = """你是一个训练场景研究专家,能将用户提交的知识点,转变为一个模拟陪练的场景,并严格按照以下格式输出。
|
||||
|
||||
输出标准:
|
||||
|
||||
{
|
||||
"scene": {
|
||||
"name": "轻医美产品咨询陪练",
|
||||
"description": "模拟客户咨询轻医美产品的场景",
|
||||
"background": "客户对脸部抗衰项目感兴趣。",
|
||||
"ai_role": "AI扮演一位30岁女性客户",
|
||||
"objectives": ["了解客户需求", "介绍产品优势", "处理价格异议"],
|
||||
"keywords": ["抗衰", "玻尿酸", "价格"],
|
||||
"type": "product-intro",
|
||||
"difficulty": "intermediate"
|
||||
}
|
||||
}
|
||||
|
||||
## 字段说明
|
||||
|
||||
- **name**: 场景名称,简洁明了,体现陪练主题
|
||||
- **description**: 场景描述,说明这是什么样的模拟场景
|
||||
- **background**: 场景背景设定,描述客户的情况和需求
|
||||
- **ai_role**: AI 角色描述,说明 AI 扮演什么角色(通常是客户)
|
||||
- **objectives**: 练习目标数组,列出学员需要达成的目标
|
||||
- **keywords**: 关键词数组,从知识点中提取的核心关键词
|
||||
- **type**: 场景类型,可选值:
|
||||
- phone: 电话销售
|
||||
- face: 面对面销售
|
||||
- complaint: 客户投诉
|
||||
- after-sales: 售后服务
|
||||
- product-intro: 产品介绍
|
||||
- **difficulty**: 难度等级,可选值:
|
||||
- beginner: 入门
|
||||
- junior: 初级
|
||||
- intermediate: 中级
|
||||
- senior: 高级
|
||||
- expert: 专家
|
||||
|
||||
## 输出要求
|
||||
|
||||
1. 直接输出纯净的 JSON 对象,不要包含 Markdown 标记(如 ```json)
|
||||
2. 不要包含任何解释性文字
|
||||
3. 根据知识点内容合理设计场景,确保场景与知识点紧密相关
|
||||
4. objectives 至少包含 2-3 个具体可操作的目标
|
||||
5. keywords 提取 3-5 个核心关键词
|
||||
6. 根据知识点的复杂程度选择合适的 difficulty
|
||||
7. 根据知识点的应用场景选择合适的 type"""
|
||||
|
||||
|
||||
# ==================== 用户提示词模板 ====================
|
||||
|
||||
USER_PROMPT = """请根据以下知识点内容,生成一个模拟陪练场景:
|
||||
|
||||
## 知识点内容
|
||||
|
||||
{knowledge_points}
|
||||
|
||||
## 要求
|
||||
|
||||
- 以 JSON 格式输出
|
||||
- 场景要贴合知识点的实际应用场景
|
||||
- AI 角色要符合轻医美行业的客户特征
|
||||
- 练习目标要具体、可评估"""
|
||||
|
||||
|
||||
# ==================== JSON Schema ====================
|
||||
|
||||
PRACTICE_SCENE_SCHEMA = {
|
||||
"type": "object",
|
||||
"required": ["scene"],
|
||||
"properties": {
|
||||
"scene": {
|
||||
"type": "object",
|
||||
"required": ["name", "description", "background", "ai_role", "objectives", "keywords", "type", "difficulty"],
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "场景名称",
|
||||
"maxLength": 100
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "场景描述",
|
||||
"maxLength": 500
|
||||
},
|
||||
"background": {
|
||||
"type": "string",
|
||||
"description": "场景背景设定",
|
||||
"maxLength": 500
|
||||
},
|
||||
"ai_role": {
|
||||
"type": "string",
|
||||
"description": "AI 角色描述",
|
||||
"maxLength": 200
|
||||
},
|
||||
"objectives": {
|
||||
"type": "array",
|
||||
"description": "练习目标",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"minItems": 2,
|
||||
"maxItems": 5
|
||||
},
|
||||
"keywords": {
|
||||
"type": "array",
|
||||
"description": "关键词",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"minItems": 2,
|
||||
"maxItems": 8
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "场景类型",
|
||||
"enum": [
|
||||
"phone",
|
||||
"face",
|
||||
"complaint",
|
||||
"after-sales",
|
||||
"product-intro"
|
||||
]
|
||||
},
|
||||
"difficulty": {
|
||||
"type": "string",
|
||||
"description": "难度等级",
|
||||
"enum": [
|
||||
"beginner",
|
||||
"junior",
|
||||
"intermediate",
|
||||
"senior",
|
||||
"expert"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ==================== 场景类型常量 ====================
|
||||
|
||||
SCENE_TYPES = {
|
||||
"phone": "电话销售",
|
||||
"face": "面对面销售",
|
||||
"complaint": "客户投诉",
|
||||
"after-sales": "售后服务",
|
||||
"product-intro": "产品介绍",
|
||||
}
|
||||
|
||||
DIFFICULTY_LEVELS = {
|
||||
"beginner": "入门",
|
||||
"junior": "初级",
|
||||
"intermediate": "中级",
|
||||
"senior": "高级",
|
||||
"expert": "专家",
|
||||
}
|
||||
|
||||
# 默认值
|
||||
DEFAULT_SCENE_TYPE = "product-intro"
|
||||
DEFAULT_DIFFICULTY = "intermediate"
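下面是对生成的场景配置做枚举兜底的最小示意(完整的结构校验可配合 PRACTICE_SCENE_SCHEMA 使用;兜底策略仅为示例):

def normalize_scene(scene: dict) -> dict:
    """对生成的场景配置做枚举与默认值兜底(示意)"""
    if scene.get("type") not in SCENE_TYPES:
        scene["type"] = DEFAULT_SCENE_TYPE
    if scene.get("difficulty") not in DIFFICULTY_LEVELS:
        scene["difficulty"] = DEFAULT_DIFFICULTY
    # 缺失的数组字段补空列表,避免下游遍历报错
    scene.setdefault("objectives", [])
    scene.setdefault("keywords", [])
    return scene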
141
backend/app/services/auth_service.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
认证服务
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import UnauthorizedError
|
||||
from app.core.logger import logger
|
||||
from app.core.security import create_access_token, create_refresh_token, decode_token
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import Token
|
||||
from app.services.user_service import UserService
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""认证服务"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.user_service = UserService(db)
|
||||
|
||||
async def login(self, username: str, password: str) -> tuple[User, Token]:
|
||||
"""
|
||||
用户登录
|
||||
|
||||
Args:
|
||||
username: 用户名/邮箱/手机号
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
用户对象和令牌
|
||||
"""
|
||||
# 验证用户
|
||||
user = await self.user_service.authenticate(
|
||||
username=username, password=password
|
||||
)
|
||||
|
||||
if not user:
|
||||
logger.warning(
|
||||
"登录失败:用户名或密码错误",
|
||||
username=username,
|
||||
)
|
||||
raise UnauthorizedError("用户名或密码错误")
|
||||
|
||||
if not user.is_active:
|
||||
logger.warning(
|
||||
"登录失败:用户已被禁用",
|
||||
user_id=user.id,
|
||||
username=user.username,
|
||||
)
|
||||
raise UnauthorizedError("用户已被禁用")
|
||||
|
||||
# 生成令牌
|
||||
access_token = create_access_token(subject=user.id)
|
||||
refresh_token = create_refresh_token(subject=user.id)
|
||||
|
||||
# 更新最后登录时间
|
||||
await self.user_service.update_last_login(user.id)
|
||||
|
||||
# 记录日志
|
||||
logger.info(
|
||||
"用户登录成功",
|
||||
user_id=user.id,
|
||||
username=user.username,
|
||||
role=user.role,
|
||||
)
|
||||
|
||||
return user, Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
async def refresh_token(self, refresh_token: str) -> Token:
|
||||
"""
|
||||
刷新访问令牌
|
||||
|
||||
Args:
|
||||
refresh_token: 刷新令牌
|
||||
|
||||
Returns:
|
||||
新的令牌
|
||||
"""
|
||||
try:
|
||||
# 解码刷新令牌
|
||||
payload = decode_token(refresh_token)
|
||||
|
||||
# 验证令牌类型
|
||||
if payload.get("type") != "refresh":
|
||||
raise UnauthorizedError("无效的刷新令牌")
|
||||
|
||||
# 获取用户ID
|
||||
user_id = int(payload.get("sub"))
|
||||
|
||||
# 获取用户
|
||||
user = await self.user_service.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UnauthorizedError("用户不存在")
|
||||
|
||||
if not user.is_active:
|
||||
raise UnauthorizedError("用户已被禁用")
|
||||
|
||||
# 生成新的访问令牌
|
||||
access_token = create_access_token(subject=user.id)
|
||||
|
||||
logger.info(
|
||||
"令牌刷新成功",
|
||||
user_id=user.id,
|
||||
username=user.username,
|
||||
)
|
||||
|
||||
return Token(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token, # 保持原刷新令牌
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"令牌刷新失败",
|
||||
error=str(e),
|
||||
)
|
||||
raise UnauthorizedError("无效的刷新令牌")
|
||||
|
||||
async def logout(self, user_id: int) -> None:
|
||||
"""
|
||||
用户登出
|
||||
|
||||
注意:JWT是无状态的,实际的登出需要在客户端删除令牌
|
||||
这里只是记录日志,如果需要可以将令牌加入黑名单
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
"""
|
||||
user = await self.user_service.get_by_id(user_id)
|
||||
if user:
|
||||
logger.info(
|
||||
"用户登出",
|
||||
user_id=user.id,
|
||||
username=user.username,
|
||||
)
|
||||
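如 logout 的文档字符串所述,若需要真正使令牌失效,可以引入令牌黑名单;下面是一个基于 Redis 的最小示意(redis_client、键名前缀均为假设,当前项目并未实现该逻辑):

async def blacklist_token(redis_client, token: str, expires_in: int) -> None:
    """将令牌写入黑名单(示意;redis_client 为假设的异步 Redis 客户端)"""
    # 过期时间与令牌剩余有效期保持一致,到期后自动清理
    await redis_client.setex(f"token:blacklist:{token}", expires_in, "1")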
112
backend/app/services/base_service.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""基础服务类"""
|
||||
from typing import TypeVar, Generic, Type, Optional, List, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseService(Generic[ModelType]):
|
||||
"""
|
||||
基础服务类,提供通用的CRUD操作
|
||||
"""
|
||||
|
||||
def __init__(self, model: Type[ModelType]):
|
||||
self.model = model
|
||||
|
||||
async def get(self, db: AsyncSession, id: int) -> Optional[ModelType]:
|
||||
"""根据ID获取单个对象"""
|
||||
result = await db.execute(select(self.model).where(self.model.id == id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_id(self, db: AsyncSession, id: int) -> Optional[ModelType]:
|
||||
"""别名:按ID获取对象(兼容旧代码)"""
|
||||
return await self.get(db, id)
|
||||
|
||||
async def get_multi(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
query: Optional[Select] = None,
|
||||
) -> List[ModelType]:
|
||||
"""获取多个对象"""
|
||||
if query is None:
|
||||
query = select(self.model)
|
||||
|
||||
result = await db.execute(query.offset(skip).limit(limit))
|
||||
return result.scalars().all()
|
||||
|
||||
async def count(self, db: AsyncSession, *, query: Optional[Select] = None) -> int:
|
||||
"""统计数量"""
|
||||
if query is None:
|
||||
query = select(func.count()).select_from(self.model)
|
||||
else:
|
||||
query = select(func.count()).select_from(query.subquery())
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalar_one()
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: Any, **kwargs) -> ModelType:
|
||||
"""创建对象"""
|
||||
if hasattr(obj_in, "model_dump"):
|
||||
create_data = obj_in.model_dump()
|
||||
else:
|
||||
create_data = obj_in
|
||||
|
||||
# 合并额外参数
|
||||
create_data.update(kwargs)
|
||||
|
||||
db_obj = self.model(**create_data)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
async def update(
|
||||
self, db: AsyncSession, *, db_obj: ModelType, obj_in: Any, **kwargs
|
||||
) -> ModelType:
|
||||
"""更新对象"""
|
||||
if hasattr(obj_in, "model_dump"):
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
else:
|
||||
update_data = obj_in
|
||||
|
||||
# 合并额外参数(如 updated_by 等审计字段)
|
||||
if kwargs:
|
||||
update_data.update(kwargs)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(db_obj, field, value)
|
||||
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
|
||||
async def delete(self, db: AsyncSession, *, id: int) -> bool:
|
||||
"""删除对象"""
|
||||
obj = await self.get(db, id)
|
||||
if obj:
|
||||
await db.delete(obj)
|
||||
await db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def soft_delete(self, db: AsyncSession, *, id: int) -> bool:
|
||||
"""软删除对象"""
|
||||
from datetime import datetime
|
||||
|
||||
obj = await self.get(db, id)
|
||||
if obj and hasattr(obj, "is_deleted"):
|
||||
obj.is_deleted = True
|
||||
if hasattr(obj, "deleted_at"):
|
||||
obj.deleted_at = datetime.now()
|
||||
db.add(obj)
|
||||
await db.commit()
|
||||
return True
|
||||
return False
|
||||
137
backend/app/services/course_exam_service.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
课程考试设置服务
|
||||
"""
|
||||
from typing import Optional
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logger import get_logger
|
||||
from app.models.course_exam_settings import CourseExamSettings
|
||||
from app.schemas.course import CourseExamSettingsCreate, CourseExamSettingsUpdate
|
||||
from app.services.base_service import BaseService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CourseExamService(BaseService[CourseExamSettings]):
|
||||
"""课程考试设置服务"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(CourseExamSettings)
|
||||
|
||||
async def get_by_course_id(self, db: AsyncSession, course_id: int) -> Optional[CourseExamSettings]:
|
||||
"""
|
||||
根据课程ID获取考试设置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
|
||||
Returns:
|
||||
考试设置实例或None
|
||||
"""
|
||||
stmt = select(CourseExamSettings).where(
|
||||
CourseExamSettings.course_id == course_id,
|
||||
CourseExamSettings.is_deleted == False
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create_or_update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
settings_in: CourseExamSettingsCreate,
|
||||
user_id: int
|
||||
) -> CourseExamSettings:
|
||||
"""
|
||||
创建或更新课程考试设置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
settings_in: 考试设置数据
|
||||
user_id: 操作用户ID
|
||||
|
||||
Returns:
|
||||
考试设置实例
|
||||
"""
|
||||
# 检查是否已存在设置
|
||||
existing_settings = await self.get_by_course_id(db, course_id)
|
||||
|
||||
if existing_settings:
|
||||
# 更新现有设置
|
||||
update_data = settings_in.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(existing_settings, field, value)
|
||||
existing_settings.updated_by = user_id
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(existing_settings)
|
||||
|
||||
logger.info(f"更新课程考试设置成功", course_id=course_id, user_id=user_id)
|
||||
return existing_settings
|
||||
else:
|
||||
# 创建新设置
|
||||
create_data = settings_in.model_dump()
|
||||
create_data.update({
|
||||
"course_id": course_id,
|
||||
"created_by": user_id,
|
||||
"updated_by": user_id
|
||||
})
|
||||
|
||||
new_settings = CourseExamSettings(**create_data)
|
||||
db.add(new_settings)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(new_settings)
|
||||
|
||||
logger.info(f"创建课程考试设置成功", course_id=course_id, user_id=user_id)
|
||||
return new_settings
|
||||
|
||||
async def update(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
settings_in: CourseExamSettingsUpdate,
|
||||
user_id: int
|
||||
) -> CourseExamSettings:
|
||||
"""
|
||||
更新课程考试设置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
settings_in: 更新的考试设置数据
|
||||
user_id: 操作用户ID
|
||||
|
||||
Returns:
|
||||
更新后的考试设置实例
|
||||
"""
|
||||
# 获取现有设置
|
||||
settings = await self.get_by_course_id(db, course_id)
|
||||
if not settings:
|
||||
# 如果不存在,创建新的
|
||||
create_data = settings_in.model_dump(exclude_unset=True)
|
||||
return await self.create_or_update(
|
||||
db,
|
||||
course_id=course_id,
|
||||
settings_in=CourseExamSettingsCreate(**create_data),
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# 更新设置
|
||||
update_data = settings_in.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(settings, field, value)
|
||||
settings.updated_by = user_id
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(settings)
|
||||
|
||||
logger.info(f"更新课程考试设置成功", course_id=course_id, user_id=user_id)
|
||||
return settings
|
||||
|
||||
|
||||
# 创建服务实例
|
||||
course_exam_service = CourseExamService()
|
||||
194
backend/app/services/course_position_service.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
课程岗位分配服务
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import select, and_, delete, func
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logger import get_logger
|
||||
from app.models.position_course import PositionCourse
|
||||
from app.models.position import Position
|
||||
from app.models.position_member import PositionMember
|
||||
from app.schemas.course import CoursePositionAssignment, CoursePositionAssignmentInDB
|
||||
from app.services.base_service import BaseService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CoursePositionService(BaseService[PositionCourse]):
|
||||
"""课程岗位分配服务"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(PositionCourse)
|
||||
|
||||
async def get_course_positions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
course_type: Optional[str] = None
|
||||
) -> List[CoursePositionAssignmentInDB]:
|
||||
"""
|
||||
获取课程的岗位分配列表
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
course_type: 课程类型筛选
|
||||
|
||||
Returns:
|
||||
岗位分配列表
|
||||
"""
|
||||
# 构建查询
|
||||
conditions = [
|
||||
PositionCourse.course_id == course_id,
|
||||
PositionCourse.is_deleted == False
|
||||
]
|
||||
|
||||
if course_type:
|
||||
conditions.append(PositionCourse.course_type == course_type)
|
||||
|
||||
stmt = (
|
||||
select(PositionCourse)
|
||||
.options(selectinload(PositionCourse.position))
|
||||
.where(and_(*conditions))
|
||||
.order_by(PositionCourse.priority, PositionCourse.id)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
assignments = result.scalars().all()
|
||||
|
||||
# 转换为返回格式,并查询每个岗位的成员数量
|
||||
result_list = []
|
||||
for assignment in assignments:
|
||||
# 查询岗位成员数量
|
||||
member_count = 0
|
||||
if assignment.position_id:
|
||||
member_count_result = await db.execute(
|
||||
select(func.count(PositionMember.id)).where(
|
||||
and_(
|
||||
PositionMember.position_id == assignment.position_id,
|
||||
PositionMember.is_deleted == False
|
||||
)
|
||||
)
|
||||
)
|
||||
member_count = member_count_result.scalar() or 0
|
||||
|
||||
result_list.append(
|
||||
CoursePositionAssignmentInDB(
|
||||
id=assignment.id,
|
||||
course_id=assignment.course_id,
|
||||
position_id=assignment.position_id,
|
||||
course_type=assignment.course_type,
|
||||
priority=assignment.priority,
|
||||
position_name=assignment.position.name if assignment.position else None,
|
||||
position_description=assignment.position.description if assignment.position else None,
|
||||
member_count=member_count
|
||||
)
|
||||
)
|
||||
|
||||
return result_list
|
||||
|
||||
async def batch_assign_positions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
assignments: List[CoursePositionAssignment],
|
||||
user_id: int
|
||||
) -> List[CoursePositionAssignmentInDB]:
|
||||
"""
|
||||
批量分配课程到岗位
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
assignments: 岗位分配列表
|
||||
user_id: 操作用户ID
|
||||
|
||||
Returns:
|
||||
分配结果列表
|
||||
"""
|
||||
created_assignments = []
|
||||
|
||||
for assignment in assignments:
|
||||
# 检查是否已存在(注意:Result 只能消费一次,需保存结果)
|
||||
result = await db.execute(
|
||||
select(PositionCourse).where(
|
||||
PositionCourse.course_id == course_id,
|
||||
PositionCourse.position_id == assignment.position_id,
|
||||
PositionCourse.is_deleted == False,
|
||||
)
|
||||
)
|
||||
existing_assignment = result.scalar_one_or_none()
|
||||
|
||||
if existing_assignment:
|
||||
# 已存在则更新类型与优先级
|
||||
existing_assignment.course_type = assignment.course_type
|
||||
existing_assignment.priority = assignment.priority
|
||||
# PositionCourse 未继承 AuditMixin,不强制写入审计字段
|
||||
created_assignments.append(existing_assignment)
|
||||
else:
|
||||
# 新建分配关系
|
||||
new_assignment = PositionCourse(
|
||||
course_id=course_id,
|
||||
position_id=assignment.position_id,
|
||||
course_type=assignment.course_type,
|
||||
priority=assignment.priority,
|
||||
)
|
||||
db.add(new_assignment)
|
||||
created_assignments.append(new_assignment)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# 重新加载关联数据
|
||||
for obj in created_assignments:
|
||||
await db.refresh(obj)
|
||||
|
||||
logger.info("批量分配课程到岗位成功", course_id=course_id, count=len(assignments), user_id=user_id)
|
||||
|
||||
# 返回分配结果
|
||||
return await self.get_course_positions(db, course_id)
|
||||
|
||||
async def remove_position_assignment(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int,
|
||||
position_id: int,
|
||||
user_id: int
|
||||
) -> bool:
|
||||
"""
|
||||
移除课程的岗位分配
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
position_id: 岗位ID
|
||||
user_id: 操作用户ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
# 查找分配记录
|
||||
stmt = select(PositionCourse).where(
|
||||
PositionCourse.course_id == course_id,
|
||||
PositionCourse.position_id == position_id,
|
||||
PositionCourse.is_deleted == False
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
assignment = result.scalar_one_or_none()
|
||||
|
||||
if assignment:
|
||||
# 软删除
|
||||
assignment.is_deleted = True
|
||||
assignment.deleted_by = user_id
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"移除课程岗位分配成功", course_id=course_id, position_id=position_id, user_id=user_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# 创建服务实例
|
||||
course_position_service = CoursePositionService()
|
||||
837
backend/app/services/course_service.py
Normal file
@@ -0,0 +1,837 @@
|
||||
"""
|
||||
课程服务层
|
||||
"""
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select, or_, and_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.logger import get_logger
|
||||
from app.core.exceptions import NotFoundError, BadRequestError, ConflictError
|
||||
from app.models.course import (
|
||||
Course,
|
||||
CourseStatus,
|
||||
CourseMaterial,
|
||||
KnowledgePoint,
|
||||
GrowthPath,
|
||||
)
|
||||
from app.models.course_exam_settings import CourseExamSettings
|
||||
from app.models.position_member import PositionMember
|
||||
from app.models.position_course import PositionCourse
|
||||
from app.schemas.course import (
|
||||
CourseCreate,
|
||||
CourseUpdate,
|
||||
CourseList,
|
||||
CourseInDB,
|
||||
CourseMaterialCreate,
|
||||
KnowledgePointCreate,
|
||||
KnowledgePointUpdate,
|
||||
KnowledgePointInDB,
|
||||
GrowthPathCreate,
|
||||
)
|
||||
from app.schemas.base import PaginationParams, PaginatedResponse
|
||||
from app.services.base_service import BaseService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CourseService(BaseService[Course]):
|
||||
"""
|
||||
课程服务类
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(Course)
|
||||
|
||||
async def get_course_list(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
page_params: PaginationParams,
|
||||
filters: CourseList,
|
||||
user_id: Optional[int] = None,
|
||||
) -> PaginatedResponse[CourseInDB]:
|
||||
"""
|
||||
获取课程列表(支持筛选)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
page_params: 分页参数
|
||||
filters: 筛选条件
|
||||
user_id: 用户ID(用于记录访问日志)
|
||||
|
||||
Returns:
|
||||
分页的课程列表
|
||||
"""
|
||||
# 构建筛选条件
|
||||
filter_conditions = []
|
||||
|
||||
# 状态筛选(默认只显示已发布的课程)
|
||||
if filters.status is not None:
|
||||
filter_conditions.append(Course.status == filters.status)
|
||||
else:
|
||||
# 如果没有指定状态,默认只返回已发布的课程
|
||||
filter_conditions.append(Course.status == CourseStatus.PUBLISHED)
|
||||
|
||||
# 分类筛选
|
||||
if filters.category is not None:
|
||||
filter_conditions.append(Course.category == filters.category)
|
||||
|
||||
# 是否推荐筛选
|
||||
if filters.is_featured is not None:
|
||||
filter_conditions.append(Course.is_featured == filters.is_featured)
|
||||
|
||||
# 关键词搜索
|
||||
if filters.keyword:
|
||||
keyword = f"%{filters.keyword}%"
|
||||
filter_conditions.append(
|
||||
or_(Course.name.like(keyword), Course.description.like(keyword))
|
||||
)
|
||||
|
||||
# 记录查询日志
|
||||
logger.info(
|
||||
"查询课程列表",
|
||||
user_id=user_id,
|
||||
filters=filters.model_dump(exclude_none=True),
|
||||
page=page_params.page,
|
||||
size=page_params.page_size,
|
||||
)
|
||||
|
||||
# 执行分页查询
|
||||
query = select(Course).where(Course.is_deleted == False)
|
||||
|
||||
# 添加筛选条件
|
||||
if filter_conditions:
|
||||
query = query.where(and_(*filter_conditions))
|
||||
|
||||
# 添加排序:优先按sort_order升序,其次按创建时间降序(新课程优先)
|
||||
query = query.order_by(Course.sort_order.asc(), Course.created_at.desc())
|
||||
|
||||
# 获取总数
|
||||
count_query = (
|
||||
select(func.count()).select_from(Course).where(Course.is_deleted == False)
|
||||
)
|
||||
if filter_conditions:
|
||||
count_query = count_query.where(and_(*filter_conditions))
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# 分页
|
||||
query = query.offset(page_params.offset).limit(page_params.limit)
|
||||
query = query.options(selectinload(Course.materials))
|
||||
|
||||
# 执行查询
|
||||
result = await db.execute(query)
|
||||
courses = result.scalars().all()
|
||||
|
||||
# 获取用户所属的岗位ID列表
|
||||
user_position_ids = []
|
||||
if user_id:
|
||||
position_result = await db.execute(
|
||||
select(PositionMember.position_id).where(
|
||||
PositionMember.user_id == user_id,
|
||||
PositionMember.is_deleted == False
|
||||
)
|
||||
)
|
||||
user_position_ids = [row[0] for row in position_result.fetchall()]
|
||||
|
||||
# 批量查询课程的岗位分配信息
|
||||
course_ids = [c.id for c in courses]
|
||||
course_type_map = {}
|
||||
if course_ids and user_position_ids:
|
||||
position_course_result = await db.execute(
|
||||
select(PositionCourse.course_id, PositionCourse.course_type).where(
|
||||
PositionCourse.course_id.in_(course_ids),
|
||||
PositionCourse.position_id.in_(user_position_ids),
|
||||
PositionCourse.is_deleted == False
|
||||
)
|
||||
)
|
||||
# 构建课程类型映射:如果有多个岗位,优先取required
|
||||
for course_id, course_type in position_course_result.fetchall():
|
||||
if course_id not in course_type_map:
|
||||
course_type_map[course_id] = course_type
|
||||
elif course_type == 'required':
|
||||
course_type_map[course_id] = 'required'
|
||||
|
||||
# 转换为 Pydantic 模型,并附加课程类型
|
||||
course_list = []
|
||||
for course in courses:
|
||||
course_data = CourseInDB.model_validate(course)
|
||||
# 设置课程类型:如果用户有岗位分配则使用分配类型,否则为None
|
||||
course_data.course_type = course_type_map.get(course.id)
|
||||
course_list.append(course_data)
|
||||
|
||||
# 计算总页数
|
||||
pages = (total + page_params.page_size - 1) // page_params.page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=course_list,
|
||||
total=total,
|
||||
page=page_params.page,
|
||||
page_size=page_params.page_size,
|
||||
pages=pages,
|
||||
)
|
||||
|
||||
async def create_course(
|
||||
self, db: AsyncSession, *, course_in: CourseCreate, created_by: int
|
||||
) -> Course:
|
||||
"""
|
||||
创建课程
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_in: 课程创建数据
|
||||
created_by: 创建人ID
|
||||
|
||||
Returns:
|
||||
创建的课程
|
||||
"""
|
||||
# 检查名称是否重复
|
||||
existing = await db.execute(
|
||||
select(Course).where(
|
||||
and_(Course.name == course_in.name, Course.is_deleted == False)
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ConflictError(f"课程名称 '{course_in.name}' 已存在")
|
||||
|
||||
# 创建课程
|
||||
course_data = course_in.model_dump()
|
||||
course = await self.create(db, obj_in=course_data, created_by=created_by)
|
||||
|
||||
# 自动创建默认考试设置
|
||||
default_exam_settings = CourseExamSettings(
|
||||
course_id=course.id,
|
||||
created_by=created_by,
|
||||
updated_by=created_by
|
||||
# 其他字段使用模型定义的默认值:
|
||||
# single_choice_count=4, multiple_choice_count=2, true_false_count=1,
|
||||
# fill_blank_count=2, essay_count=1, duration_minutes=10, 等
|
||||
)
|
||||
db.add(default_exam_settings)
|
||||
await db.commit()
|
||||
await db.refresh(course)
|
||||
|
||||
logger.info(
|
||||
"创建课程", course_id=course.id, course_name=course.name, created_by=created_by
|
||||
)
|
||||
logger.info(
|
||||
"自动创建默认考试设置", course_id=course.id, exam_settings_id=default_exam_settings.id
|
||||
)
|
||||
|
||||
return course
|
||||
|
||||
async def update_course(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
course_id: int,
|
||||
course_in: CourseUpdate,
|
||||
updated_by: int,
|
||||
) -> Course:
|
||||
"""
|
||||
更新课程
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
course_in: 课程更新数据
|
||||
updated_by: 更新人ID
|
||||
|
||||
Returns:
|
||||
更新后的课程
|
||||
"""
|
||||
# 获取课程
|
||||
course = await self.get_by_id(db, course_id)
|
||||
if not course:
|
||||
raise NotFoundError(f"课程ID {course_id} 不存在")
|
||||
|
||||
# 检查名称是否重复(如果修改了名称)
|
||||
if course_in.name and course_in.name != course.name:
|
||||
existing = await db.execute(
|
||||
select(Course).where(
|
||||
and_(
|
||||
Course.name == course_in.name,
|
||||
Course.id != course_id,
|
||||
Course.is_deleted == False,
|
||||
)
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ConflictError(f"课程名称 '{course_in.name}' 已存在")
|
||||
|
||||
# 记录状态变更
|
||||
old_status = course.status
|
||||
|
||||
# 更新课程
|
||||
update_data = course_in.model_dump(exclude_unset=True)
|
||||
|
||||
# 如果状态变为已发布,记录发布时间
|
||||
if (
|
||||
update_data.get("status") == CourseStatus.PUBLISHED
|
||||
and old_status != CourseStatus.PUBLISHED
|
||||
):
|
||||
update_data["published_at"] = datetime.now()
|
||||
update_data["publisher_id"] = updated_by
|
||||
|
||||
course = await self.update(
|
||||
db, db_obj=course, obj_in=update_data, updated_by=updated_by
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"更新课程",
|
||||
course_id=course.id,
|
||||
course_name=course.name,
|
||||
old_status=old_status,
|
||||
new_status=course.status,
|
||||
updated_by=updated_by,
|
||||
)
|
||||
|
||||
return course
|
||||
|
||||
async def delete_course(
|
||||
self, db: AsyncSession, *, course_id: int, deleted_by: int
|
||||
) -> bool:
|
||||
"""
|
||||
删除课程(软删除 + 删除相关文件)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
deleted_by: 删除人ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from app.core.config import settings
|
||||
|
||||
course = await self.get_by_id(db, course_id)
|
||||
if not course:
|
||||
raise NotFoundError(f"课程ID {course_id} 不存在")
|
||||
|
||||
# 放开删除限制:任意状态均可软删除,由业务方自行控制
|
||||
|
||||
# 执行软删除(标记 is_deleted,记录删除时间),由审计日志记录操作者
|
||||
success = await self.soft_delete(db, id=course_id)
|
||||
|
||||
if success:
|
||||
# 删除课程文件夹及其所有内容
|
||||
course_folder = Path(settings.UPLOAD_PATH) / "courses" / str(course_id)
|
||||
if course_folder.exists() and course_folder.is_dir():
|
||||
try:
|
||||
shutil.rmtree(course_folder)
|
||||
logger.info(
|
||||
"删除课程文件夹成功",
|
||||
course_id=course_id,
|
||||
folder_path=str(course_folder),
|
||||
)
|
||||
except Exception as e:
|
||||
# 文件夹删除失败不影响业务流程,仅记录日志
|
||||
logger.error(
|
||||
"删除课程文件夹失败",
|
||||
course_id=course_id,
|
||||
folder_path=str(course_folder),
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
"删除课程",
|
||||
course_id=course_id,
|
||||
course_name=course.name,
|
||||
deleted_by=deleted_by,
|
||||
folder_deleted=course_folder.exists(),
|
||||
)
|
||||
|
||||
return success
|
||||
|
||||
async def add_course_material(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
course_id: int,
|
||||
material_in: CourseMaterialCreate,
|
||||
created_by: int,
|
||||
) -> CourseMaterial:
|
||||
"""
|
||||
添加课程资料
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
material_in: 资料创建数据
|
||||
created_by: 创建人ID
|
||||
|
||||
Returns:
|
||||
创建的课程资料
|
||||
"""
|
||||
# 检查课程是否存在
|
||||
course = await self.get_by_id(db, course_id)
|
||||
if not course:
|
||||
raise NotFoundError(f"课程ID {course_id} 不存在")
|
||||
|
||||
# 创建资料
|
||||
material_data = material_in.model_dump()
|
||||
material_data.update({
|
||||
"course_id": course_id,
|
||||
"created_by": created_by,
|
||||
"updated_by": created_by
|
||||
})
|
||||
|
||||
material = CourseMaterial(**material_data)
|
||||
db.add(material)
|
||||
await db.commit()
|
||||
await db.refresh(material)
|
||||
|
||||
logger.info(
|
||||
"添加课程资料",
|
||||
course_id=course_id,
|
||||
material_id=material.id,
|
||||
material_name=material.name,
|
||||
file_type=material.file_type,
|
||||
file_size=material.file_size,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
return material
|
||||
|
||||
async def get_course_materials(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
course_id: int,
|
||||
) -> List[CourseMaterial]:
|
||||
"""
|
||||
获取课程资料列表
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
|
||||
Returns:
|
||||
课程资料列表
|
||||
"""
|
||||
# 确认课程存在
|
||||
course = await self.get_by_id(db, course_id)
|
||||
if not course:
|
||||
raise NotFoundError(f"课程ID {course_id} 不存在")
|
||||
|
||||
stmt = (
|
||||
select(CourseMaterial)
|
||||
.where(
|
||||
CourseMaterial.course_id == course_id,
|
||||
CourseMaterial.is_deleted == False,
|
||||
)
|
||||
.order_by(CourseMaterial.sort_order.asc(), CourseMaterial.id.asc())
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
materials = result.scalars().all()
|
||||
|
||||
logger.info(
|
||||
"查询课程资料列表", course_id=course_id, count=len(materials)
|
||||
)
|
||||
|
||||
return materials
|
||||
|
||||
async def delete_course_material(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
course_id: int,
|
||||
material_id: int,
|
||||
deleted_by: int,
|
||||
) -> bool:
|
||||
"""
|
||||
删除课程资料(软删除 + 删除物理文件)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
material_id: 资料ID
|
||||
deleted_by: 删除人ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from app.core.config import settings
|
||||
|
||||
# 先确认课程存在
|
||||
course = await self.get_by_id(db, course_id)
|
||||
if not course:
|
||||
raise NotFoundError(f"课程ID {course_id} 不存在")
|
||||
|
||||
# 查找资料并校验归属
|
||||
material_stmt = select(CourseMaterial).where(
|
||||
CourseMaterial.id == material_id,
|
||||
CourseMaterial.course_id == course_id,
|
||||
CourseMaterial.is_deleted == False,
|
||||
)
|
||||
result = await db.execute(material_stmt)
|
||||
material = result.scalar_one_or_none()
|
||||
if not material:
|
||||
raise NotFoundError(f"课程资料ID {material_id} 不存在或已删除")
|
||||
|
||||
# 获取文件路径信息用于删除物理文件
|
||||
file_url = material.file_url
|
||||
|
||||
# 软删除数据库记录
|
||||
material.is_deleted = True
|
||||
material.deleted_at = datetime.now()
|
||||
if hasattr(material, "deleted_by"):
|
||||
# 兼容存在该字段的表
|
||||
setattr(material, "deleted_by", deleted_by)
|
||||
|
||||
db.add(material)
|
||||
await db.commit()
|
||||
|
||||
# 删除物理文件
|
||||
if file_url and file_url.startswith("/static/uploads/"):
|
||||
try:
|
||||
# 从URL中提取相对路径
|
||||
relative_path = file_url.replace("/static/uploads/", "")
|
||||
file_path = Path(settings.UPLOAD_PATH) / relative_path
|
||||
|
||||
# 检查文件是否存在并删除
|
||||
if file_path.exists() and file_path.is_file():
|
||||
os.remove(file_path)
|
||||
logger.info(
|
||||
"删除物理文件成功",
|
||||
file_path=str(file_path),
|
||||
material_id=material_id,
|
||||
)
|
||||
except Exception as e:
|
||||
# 物理文件删除失败不影响业务流程,仅记录日志
|
||||
logger.error(
|
||||
"删除物理文件失败",
|
||||
file_url=file_url,
|
||||
material_id=material_id,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
"删除课程资料",
|
||||
course_id=course_id,
|
||||
material_id=material_id,
|
||||
deleted_by=deleted_by,
|
||||
file_deleted=file_url is not None,
|
||||
)
|
||||
return True
|
||||
|
||||
async def get_material_knowledge_points(
|
||||
self, db: AsyncSession, material_id: int
|
||||
) -> List[KnowledgePointInDB]:
|
||||
"""获取资料关联的知识点列表"""
|
||||
|
||||
# 获取资料信息
|
||||
result = await db.execute(
|
||||
select(CourseMaterial).where(
|
||||
CourseMaterial.id == material_id,
|
||||
CourseMaterial.is_deleted == False
|
||||
)
|
||||
)
|
||||
material = result.scalar_one_or_none()
|
||||
|
||||
if not material:
|
||||
raise NotFoundError(f"资料ID {material_id} 不存在")
|
||||
|
||||
# 直接查询关联到该资料的知识点
|
||||
query = select(KnowledgePoint).where(
|
||||
KnowledgePoint.material_id == material_id,
|
||||
KnowledgePoint.is_deleted == False
|
||||
).order_by(KnowledgePoint.created_at.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
knowledge_points = result.scalars().all()
|
||||
|
||||
from app.schemas.course import KnowledgePointInDB
|
||||
return [KnowledgePointInDB.model_validate(kp) for kp in knowledge_points]
|
||||
|
||||
async def add_material_knowledge_points(
|
||||
self, db: AsyncSession, material_id: int, knowledge_point_ids: List[int]
|
||||
) -> List[KnowledgePointInDB]:
|
||||
"""
|
||||
为资料添加知识点关联
|
||||
|
||||
注意:自2025-09-27起,知识点直接通过material_id关联到资料,
|
||||
material_knowledge_points中间表已废弃。此方法将更新知识点的material_id字段。
|
||||
"""
|
||||
# 验证资料是否存在
|
||||
result = await db.execute(
|
||||
select(CourseMaterial).where(
|
||||
CourseMaterial.id == material_id,
|
||||
CourseMaterial.is_deleted == False
|
||||
)
|
||||
)
|
||||
material = result.scalar_one_or_none()
|
||||
|
||||
if not material:
|
||||
raise NotFoundError(f"资料ID {material_id} 不存在")
|
||||
|
||||
# 验证知识点是否存在且属于同一课程
|
||||
result = await db.execute(
|
||||
select(KnowledgePoint).where(
|
||||
KnowledgePoint.id.in_(knowledge_point_ids),
|
||||
KnowledgePoint.course_id == material.course_id,
|
||||
KnowledgePoint.is_deleted == False
|
||||
)
|
||||
)
|
||||
valid_knowledge_points = result.scalars().all()
|
||||
|
||||
if len(valid_knowledge_points) != len(knowledge_point_ids):
|
||||
raise BadRequestError("部分知识点不存在或不属于同一课程")
|
||||
|
||||
# 更新知识点的material_id字段
|
||||
added_knowledge_points = []
|
||||
for kp in valid_knowledge_points:
|
||||
# 更新知识点的资料关联
|
||||
kp.material_id = material_id
|
||||
added_knowledge_points.append(kp)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# 刷新对象以获取更新后的数据
|
||||
for kp in added_knowledge_points:
|
||||
await db.refresh(kp)
|
||||
|
||||
from app.schemas.course import KnowledgePointInDB
|
||||
return [KnowledgePointInDB.model_validate(kp) for kp in added_knowledge_points]
|
||||
|
||||
async def remove_material_knowledge_point(
|
||||
self, db: AsyncSession, material_id: int, knowledge_point_id: int
|
||||
) -> bool:
|
||||
"""
|
||||
移除资料的知识点关联(软删除知识点)
|
||||
|
||||
注意:自2025-09-27起,知识点直接通过material_id关联到资料,
|
||||
material_knowledge_points中间表已废弃。此方法将软删除知识点。
|
||||
"""
|
||||
# 查找知识点并验证归属
|
||||
result = await db.execute(
|
||||
select(KnowledgePoint).where(
|
||||
KnowledgePoint.id == knowledge_point_id,
|
||||
KnowledgePoint.material_id == material_id,
|
||||
KnowledgePoint.is_deleted == False
|
||||
)
|
||||
)
|
||||
knowledge_point = result.scalar_one_or_none()
|
||||
|
||||
if not knowledge_point:
|
||||
raise NotFoundError(f"知识点ID {knowledge_point_id} 不存在或不属于该资料")
|
||||
|
||||
# 软删除知识点
|
||||
knowledge_point.is_deleted = True
|
||||
knowledge_point.deleted_at = datetime.now()
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(
|
||||
"移除资料知识点关联",
|
||||
material_id=material_id,
|
||||
knowledge_point_id=knowledge_point_id,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class KnowledgePointService(BaseService[KnowledgePoint]):
|
||||
"""
|
||||
知识点服务类
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(KnowledgePoint)
|
||||
|
||||
async def get_knowledge_points_by_course(
|
||||
self, db: AsyncSession, *, course_id: int, material_id: Optional[int] = None
|
||||
) -> List[KnowledgePoint]:
|
||||
"""
|
||||
获取课程的知识点列表
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
material_id: 资料ID(可选,用于筛选特定资料的知识点)
|
||||
|
||||
Returns:
|
||||
知识点列表
|
||||
"""
|
||||
query = select(KnowledgePoint).where(
|
||||
and_(
|
||||
KnowledgePoint.course_id == course_id,
|
||||
KnowledgePoint.is_deleted == False,
|
||||
)
|
||||
)
|
||||
|
||||
if material_id is not None:
|
||||
query = query.where(KnowledgePoint.material_id == material_id)
|
||||
|
||||
query = query.order_by(KnowledgePoint.created_at.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def create_knowledge_point(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
course_id: int,
|
||||
point_in: KnowledgePointCreate,
|
||||
created_by: int,
|
||||
) -> KnowledgePoint:
|
||||
"""
|
||||
创建知识点
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
point_in: 知识点创建数据
|
||||
created_by: 创建人ID
|
||||
|
||||
Returns:
|
||||
创建的知识点
|
||||
"""
|
||||
# 检查课程是否存在
|
||||
course_service = CourseService()
|
||||
course = await course_service.get_by_id(db, course_id)
|
||||
if not course:
|
||||
raise NotFoundError(f"课程ID {course_id} 不存在")
|
||||
|
||||
# 创建知识点
|
||||
point_data = point_in.model_dump()
|
||||
point_data.update({"course_id": course_id})
|
||||
|
||||
knowledge_point = await self.create(
|
||||
db, obj_in=point_data, created_by=created_by
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"创建知识点",
|
||||
course_id=course_id,
|
||||
knowledge_point_id=knowledge_point.id,
|
||||
knowledge_point_name=knowledge_point.name,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
return knowledge_point
|
||||
|
||||
async def update_knowledge_point(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
point_id: int,
|
||||
point_in: KnowledgePointUpdate,
|
||||
updated_by: int,
|
||||
) -> KnowledgePoint:
|
||||
"""
|
||||
更新知识点
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
point_id: 知识点ID
|
||||
point_in: 知识点更新数据
|
||||
updated_by: 更新人ID
|
||||
|
||||
Returns:
|
||||
更新后的知识点
|
||||
"""
|
||||
knowledge_point = await self.get_by_id(db, point_id)
|
||||
if not knowledge_point:
|
||||
raise NotFoundError(f"知识点ID {point_id} 不存在")
|
||||
|
||||
# 验证关联资料是否存在
|
||||
if hasattr(point_in, 'material_id') and point_in.material_id:
|
||||
result = await db.execute(
|
||||
select(CourseMaterial).where(
|
||||
CourseMaterial.id == point_in.material_id,
|
||||
CourseMaterial.is_deleted == False
|
||||
)
|
||||
)
|
||||
material = result.scalar_one_or_none()
|
||||
if not material:
|
||||
raise NotFoundError(f"资料ID {point_in.material_id} 不存在")
|
||||
|
||||
# 更新知识点
|
||||
update_data = point_in.model_dump(exclude_unset=True)
|
||||
knowledge_point = await self.update(
|
||||
db, db_obj=knowledge_point, obj_in=update_data, updated_by=updated_by
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"更新知识点",
|
||||
knowledge_point_id=knowledge_point.id,
|
||||
knowledge_point_name=knowledge_point.name,
|
||||
updated_by=updated_by,
|
||||
)
|
||||
|
||||
return knowledge_point
|
||||
|
||||
|
||||
class GrowthPathService(BaseService[GrowthPath]):
|
||||
"""
|
||||
成长路径服务类
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(GrowthPath)
|
||||
|
||||
async def create_growth_path(
|
||||
self, db: AsyncSession, *, path_in: GrowthPathCreate, created_by: int
|
||||
) -> GrowthPath:
|
||||
"""
|
||||
创建成长路径
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
path_in: 成长路径创建数据
|
||||
created_by: 创建人ID
|
||||
|
||||
Returns:
|
||||
创建的成长路径
|
||||
"""
|
||||
# 检查名称是否重复
|
||||
existing = await db.execute(
|
||||
select(GrowthPath).where(
|
||||
and_(GrowthPath.name == path_in.name, GrowthPath.is_deleted == False)
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ConflictError(f"成长路径名称 '{path_in.name}' 已存在")
|
||||
|
||||
# 验证课程是否存在
|
||||
if path_in.courses:
|
||||
course_ids = [c.course_id for c in path_in.courses]
|
||||
course_service = CourseService()
|
||||
for course_id in course_ids:
|
||||
course = await course_service.get_by_id(db, course_id)
|
||||
if not course:
|
||||
raise NotFoundError(f"课程ID {course_id} 不存在")
|
||||
|
||||
# 创建成长路径
|
||||
path_data = path_in.model_dump()
|
||||
# 转换课程列表为JSON格式
|
||||
if path_data.get("courses"):
|
||||
path_data["courses"] = [c.model_dump() for c in path_in.courses]
|
||||
|
||||
growth_path = await self.create(db, obj_in=path_data, created_by=created_by)
|
||||
|
||||
logger.info(
|
||||
"创建成长路径",
|
||||
growth_path_id=growth_path.id,
|
||||
growth_path_name=growth_path.name,
|
||||
course_count=len(path_in.courses) if path_in.courses else 0,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
return growth_path
|
||||
|
||||
|
||||
# 创建服务实例
|
||||
course_service = CourseService()
|
||||
knowledge_point_service = KnowledgePointService()
|
||||
growth_path_service = GrowthPathService()
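# —— 用法示意(编辑者补充的最小示例,非本提交原有代码)——
# 假设调用方已通过依赖注入获得 AsyncSession(形参 db 为假设命名),
# CourseMaterialCreate 的具体字段以项目实际 schema 为准。
async def _example_add_and_list_materials(
    db: AsyncSession,
    course_id: int,
    material_in: CourseMaterialCreate,
    operator_id: int,
) -> List[CourseMaterial]:
    """示意:为课程追加一份资料,并返回该课程当前的资料列表"""
    await course_service.add_course_material(
        db, course_id=course_id, material_in=material_in, created_by=operator_id
    )
    return await course_service.get_course_materials(db, course_id=course_id)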
|
||||
65
backend/app/services/course_statistics_service.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
课程统计服务
|
||||
"""
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.models.exam import Exam
|
||||
from app.models.course import Course
|
||||
from app.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class CourseStatisticsService:
|
||||
"""课程统计服务类"""
|
||||
|
||||
async def update_course_student_count(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
course_id: int
|
||||
) -> int:
|
||||
"""
|
||||
更新课程学员数统计
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
course_id: 课程ID
|
||||
|
||||
Returns:
|
||||
更新后的学员数
|
||||
"""
|
||||
try:
|
||||
# 统计该课程的不同学员数(基于考试记录)
|
||||
stmt = select(func.count(func.distinct(Exam.user_id))).where(
|
||||
Exam.course_id == course_id,
|
||||
Exam.is_deleted == False
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
student_count = result.scalar_one() or 0
|
||||
|
||||
# 更新课程表
|
||||
course_stmt = select(Course).where(
|
||||
Course.id == course_id,
|
||||
Course.is_deleted == False
|
||||
)
|
||||
course_result = await db.execute(course_stmt)
|
||||
course = course_result.scalar_one_or_none()
|
||||
|
||||
if course:
|
||||
course.student_count = student_count
|
||||
await db.commit()
|
||||
logger.info(f"更新课程 {course_id} 学员数: {student_count}")
|
||||
return student_count
|
||||
else:
|
||||
logger.warning(f"课程 {course_id} 不存在")
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新课程学员数失败: {str(e)}", exc_info=True)
|
||||
await db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
course_statistics_service = CourseStatisticsService()
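# —— 用法示意(编辑者补充的最小示例,非本提交原有代码)——
# 典型调用时机是考试记录写入之后;db 会话由调用方传入(假设来自依赖注入)。
async def _example_refresh_student_count(db: AsyncSession, course_id: int) -> int:
    """示意:交卷落库后刷新课程的去重学员数,返回最新统计值"""
    return await course_statistics_service.update_course_student_count(db, course_id)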
|
||||
|
||||
97
backend/app/services/coze_broadcast_service.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Coze 播课服务
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from cozepy.exception import CozeError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.ai.coze.client import get_coze_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CozeBroadcastService:
|
||||
"""Coze 播课服务"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化配置"""
|
||||
self.workflow_id = settings.COZE_BROADCAST_WORKFLOW_ID
|
||||
self.space_id = settings.COZE_BROADCAST_SPACE_ID
|
||||
|
||||
def _get_client(self):
|
||||
"""获取新的 Coze 客户端(每次调用都创建新认证,避免token过期)"""
|
||||
return get_coze_client(force_new=True)
|
||||
|
||||
async def trigger_workflow(self, course_id: int) -> None:
|
||||
"""
|
||||
触发播课生成工作流(不等待结果)
|
||||
|
||||
Coze工作流会:
|
||||
1. 生成播课音频
|
||||
2. 直接将结果写入数据库
|
||||
|
||||
Args:
|
||||
course_id: 课程ID
|
||||
|
||||
Raises:
|
||||
CozeError: Coze API 调用失败
|
||||
"""
|
||||
logger.info(
|
||||
f"触发播课生成工作流",
|
||||
extra={
|
||||
"course_id": course_id,
|
||||
"workflow_id": self.workflow_id,
|
||||
"bot_id": settings.COZE_BROADCAST_BOT_ID or settings.COZE_PRACTICE_BOT_ID # 关联到同一工作空间的Bot
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
# 每次调用都获取新客户端(确保OAuth token有效)
|
||||
coze = self._get_client()
|
||||
|
||||
# 调用工作流(触发即返回,不等待结果)
|
||||
# 关键:添加bot_id参数,关联到OAuth应用下的Bot
|
||||
import asyncio
|
||||
result = await asyncio.to_thread(
|
||||
coze.workflows.runs.create,
|
||||
workflow_id=self.workflow_id,
|
||||
parameters={"course_id": str(course_id)},
|
||||
bot_id=settings.COZE_BROADCAST_BOT_ID or settings.COZE_PRACTICE_BOT_ID # 关联Bot,确保OAuth权限
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"播课生成工作流已触发",
|
||||
extra={
|
||||
"course_id": course_id,
|
||||
"execute_id": getattr(result, 'execute_id', None),
|
||||
"debug_url": getattr(result, 'debug_url', None)
|
||||
}
|
||||
)
|
||||
|
||||
except CozeError as e:
|
||||
logger.error(
|
||||
f"触发 Coze 工作流失败",
|
||||
extra={
|
||||
"course_id": course_id,
|
||||
"error": str(e),
|
||||
"error_code": getattr(e, 'code', None)
|
||||
}
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"触发播课生成工作流异常",
|
||||
extra={
|
||||
"course_id": course_id,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__
|
||||
}
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# 全局单例
|
||||
broadcast_service = CozeBroadcastService()
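# —— 用法示意(编辑者补充的最小示例,非本提交原有代码)——
# trigger_workflow 只负责触发、不等待结果;是否吞掉异常由业务自行决定,
# 这里示意为“记录日志但不阻断主流程”。
async def _example_trigger_broadcast(course_id: int) -> None:
    try:
        await broadcast_service.trigger_workflow(course_id)
    except Exception:  # 示意处理:触发失败不影响课程主流程
        logger.exception("播课工作流触发失败(示意处理)", extra={"course_id": course_id})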
|
||||
199
backend/app/services/coze_service.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Coze AI对话服务
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from cozepy import Coze, COZE_CN_BASE_URL, Message
|
||||
from cozepy.exception import CozeError, CozeAPIError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.ai.coze.client import get_auth_manager
|
||||
|
||||
# 注意:不再直接使用 TokenAuth,统一通过 get_auth_manager() 管理认证
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CozeService:
|
||||
"""Coze对话服务"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化Coze客户端"""
|
||||
if not settings.COZE_PRACTICE_BOT_ID:
|
||||
raise ValueError("COZE_PRACTICE_BOT_ID 未配置")
|
||||
|
||||
self.bot_id = settings.COZE_PRACTICE_BOT_ID
|
||||
self._auth_manager = get_auth_manager()
|
||||
|
||||
logger.info(
|
||||
f"CozeService初始化成功,Bot ID={self.bot_id}, "
|
||||
f"Base URL={COZE_CN_BASE_URL}"
|
||||
)
|
||||
|
||||
@property
|
||||
def client(self) -> Coze:
|
||||
"""获取Coze客户端(每次获取确保OAuth token有效)"""
|
||||
return self._auth_manager.get_client(force_new=True)
|
||||
|
||||
def build_scene_prompt(
|
||||
self,
|
||||
scene_name: str,
|
||||
scene_background: str,
|
||||
scene_ai_role: str,
|
||||
scene_objectives: list,
|
||||
scene_keywords: Optional[list] = None,
|
||||
scene_description: Optional[str] = None,
|
||||
user_message: str = ""
|
||||
) -> str:
|
||||
"""
|
||||
构建场景提示词(Markdown格式)
|
||||
|
||||
参数:
|
||||
scene_name: 场景名称
|
||||
scene_background: 场景背景
|
||||
scene_ai_role: AI角色描述
|
||||
scene_objectives: 练习目标列表
|
||||
scene_keywords: 关键词列表
|
||||
scene_description: 场景描述(可选)
|
||||
user_message: 用户第一句话
|
||||
|
||||
返回:
|
||||
完整的场景提示词(Markdown格式)
|
||||
"""
|
||||
# 构建练习目标
|
||||
objectives_text = "\n".join(
|
||||
f"{i+1}. {obj}" for i, obj in enumerate(scene_objectives)
|
||||
)
|
||||
|
||||
# 构建关键词
|
||||
keywords_text = ", ".join(scene_keywords) if scene_keywords else ""
|
||||
|
||||
# 构建完整提示词
|
||||
prompt = f"""# 陪练场景设定
|
||||
|
||||
## 场景名称
|
||||
{scene_name}
|
||||
"""
|
||||
|
||||
# 添加场景描述(如果有)
|
||||
if scene_description:
|
||||
prompt += f"""
|
||||
## 场景描述
|
||||
{scene_description}
|
||||
"""
|
||||
|
||||
prompt += f"""
|
||||
## 场景背景
|
||||
{scene_background}
|
||||
|
||||
## AI角色要求
|
||||
{scene_ai_role}
|
||||
|
||||
## 练习目标
|
||||
{objectives_text}
|
||||
"""
|
||||
|
||||
# 添加关键词(如果有)
|
||||
if keywords_text:
|
||||
prompt += f"""
|
||||
## 关键词
|
||||
{keywords_text}
|
||||
"""
|
||||
|
||||
prompt += f"""
|
||||
---
|
||||
|
||||
现在开始陪练对话。请你严格按照上述场景设定扮演角色,与学员进行实战对话练习。
|
||||
不要提及"场景设定"或"角色扮演"等元信息,直接进入角色开始对话。
|
||||
|
||||
学员的第一句话:{user_message}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def create_stream_chat(
|
||||
self,
|
||||
user_id: str,
|
||||
message: str,
|
||||
conversation_id: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
创建流式对话
|
||||
|
||||
参数:
|
||||
user_id: 用户ID
|
||||
message: 消息内容
|
||||
conversation_id: 对话ID(续接对话时使用)
|
||||
|
||||
返回:
|
||||
Coze流式对话迭代器
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"创建Coze流式对话,user_id={user_id}, "
|
||||
f"conversation_id={conversation_id}, "
|
||||
f"message_length={len(message)}"
|
||||
)
|
||||
|
||||
stream = self.client.chat.stream(
|
||||
bot_id=self.bot_id,
|
||||
user_id=user_id,
|
||||
additional_messages=[Message.build_user_question_text(message)],
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
|
||||
# 记录LogID用于排查问题
|
||||
if hasattr(stream, 'response') and hasattr(stream.response, 'logid'):
|
||||
logger.info(f"Coze对话创建成功,logid={stream.response.logid}")
|
||||
|
||||
return stream
|
||||
|
||||
except (CozeError, CozeAPIError) as e:
|
||||
logger.error(f"Coze API调用失败: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建Coze对话失败: {e}")
|
||||
raise
|
||||
|
||||
def cancel_chat(self, conversation_id: str, chat_id: str):
|
||||
"""
|
||||
中断对话
|
||||
|
||||
参数:
|
||||
conversation_id: 对话ID
|
||||
chat_id: 聊天ID
|
||||
"""
|
||||
try:
|
||||
logger.info(f"中断Coze对话,conversation_id={conversation_id}, chat_id={chat_id}")
|
||||
|
||||
result = self.client.chat.cancel(
|
||||
conversation_id=conversation_id,
|
||||
chat_id=chat_id
|
||||
)
|
||||
|
||||
logger.info(f"对话中断成功")
|
||||
return result
|
||||
|
||||
except (CozeError, CozeAPIError) as e:
|
||||
logger.error(f"中断对话失败: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"中断对话异常: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# 单例模式
|
||||
_coze_service: Optional[CozeService] = None
|
||||
|
||||
|
||||
def get_coze_service() -> CozeService:
|
||||
"""
|
||||
获取CozeService单例
|
||||
|
||||
用于FastAPI依赖注入
|
||||
"""
|
||||
global _coze_service
|
||||
if _coze_service is None:
|
||||
_coze_service = CozeService()
|
||||
return _coze_service
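# —— 用法示意(编辑者补充的最小示例,非本提交原有代码)——
# 先用场景配置拼装提示词,再发起流式对话;示例中的场景内容均为虚构,
# 流式事件的消费方式取决于所用 cozepy 版本,这里只演示到拿到 stream 为止。
def _example_start_practice_chat(user_id: str, first_message: str):
    service = get_coze_service()
    prompt = service.build_scene_prompt(
        scene_name="新客户首次到店接待",
        scene_background="客户首次到店咨询课程,尚未明确需求",
        scene_ai_role="扮演一位犹豫不决、关心价格的潜在客户",
        scene_objectives=["完成需求挖掘", "给出匹配的课程建议"],
        user_message=first_message,
    )
    return service.create_stream_chat(user_id=user_id, message=prompt)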
|
||||
|
||||
305
backend/app/services/document_converter.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
文档转换服务
|
||||
使用 LibreOffice 将 Office 文档转换为 PDF
|
||||
"""
|
||||
import os
|
||||
import logging
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentConverterService:
|
||||
"""文档转换服务类"""
|
||||
|
||||
# 支持转换的文件格式
|
||||
SUPPORTED_FORMATS = {'.docx', '.doc', '.pptx', '.ppt', '.xlsx', '.xls'}
|
||||
|
||||
# Excel文件格式(需要特殊处理页面布局)
|
||||
EXCEL_FORMATS = {'.xlsx', '.xls'}
|
||||
|
||||
def __init__(self):
|
||||
"""初始化转换服务"""
|
||||
self.converted_path = Path(settings.UPLOAD_PATH) / "converted"
|
||||
self.converted_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get_converted_file_path(self, course_id: int, material_id: int) -> Path:
|
||||
"""
|
||||
获取转换后的文件路径
|
||||
|
||||
Args:
|
||||
course_id: 课程ID
|
||||
material_id: 资料ID
|
||||
|
||||
Returns:
|
||||
转换后的PDF文件路径
|
||||
"""
|
||||
course_dir = self.converted_path / str(course_id)
|
||||
course_dir.mkdir(parents=True, exist_ok=True)
|
||||
return course_dir / f"{material_id}.pdf"
|
||||
|
||||
def need_convert(self, source_file: Path, converted_file: Path) -> bool:
|
||||
"""
|
||||
判断是否需要重新转换
|
||||
|
||||
Args:
|
||||
source_file: 源文件路径
|
||||
converted_file: 转换后的文件路径
|
||||
|
||||
Returns:
|
||||
是否需要转换
|
||||
"""
|
||||
# 如果转换文件不存在,需要转换
|
||||
if not converted_file.exists():
|
||||
return True
|
||||
|
||||
# 如果源文件不存在,不需要转换
|
||||
if not source_file.exists():
|
||||
return False
|
||||
|
||||
# 如果源文件修改时间晚于转换文件,需要重新转换
|
||||
source_mtime = source_file.stat().st_mtime
|
||||
converted_mtime = converted_file.stat().st_mtime
|
||||
|
||||
return source_mtime > converted_mtime
|
||||
|
||||
def convert_excel_to_html(
|
||||
self,
|
||||
source_file: str,
|
||||
course_id: int,
|
||||
material_id: int
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
将Excel文件转换为HTML(避免PDF分页问题)
|
||||
|
||||
Args:
|
||||
source_file: 源文件路径
|
||||
course_id: 课程ID
|
||||
material_id: 资料ID
|
||||
|
||||
Returns:
|
||||
转换后的HTML文件URL,失败返回None
|
||||
"""
|
||||
try:
|
||||
try:
|
||||
import openpyxl
|
||||
from openpyxl.utils import get_column_letter
|
||||
except ImportError as ie:
|
||||
logger.error(f"Excel转换依赖缺失: openpyxl 未安装。请运行 pip install openpyxl 或重建Docker镜像。错误: {str(ie)}")
|
||||
return None
|
||||
|
||||
source_path = Path(source_file)
|
||||
logger.info(f"开始Excel转HTML: source={source_file}, course_id={course_id}, material_id={material_id}")
|
||||
|
||||
# 获取HTML输出路径
|
||||
course_dir = self.converted_path / str(course_id)
|
||||
course_dir.mkdir(parents=True, exist_ok=True)
|
||||
html_file = course_dir / f"{material_id}.html"
|
||||
|
||||
# 检查缓存
|
||||
if html_file.exists():
|
||||
source_mtime = source_path.stat().st_mtime
|
||||
html_mtime = html_file.stat().st_mtime
|
||||
if source_mtime <= html_mtime:
|
||||
logger.info(f"使用缓存的HTML文件: {html_file}")
|
||||
return f"/static/uploads/converted/{course_id}/{material_id}.html"
|
||||
|
||||
# 读取Excel文件
|
||||
wb = openpyxl.load_workbook(source_file, data_only=True)
|
||||
|
||||
# 构建HTML
|
||||
html_content = '''<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; padding: 20px; background: #f5f5f5; }
|
||||
.sheet-tabs { display: flex; gap: 10px; margin-bottom: 20px; flex-wrap: wrap; }
|
||||
.sheet-tab { padding: 8px 16px; background: #fff; border: 1px solid #ddd; border-radius: 4px; cursor: pointer; }
|
||||
.sheet-tab.active { background: #409eff; color: white; border-color: #409eff; }
|
||||
.sheet-content { display: none; }
|
||||
.sheet-content.active { display: block; }
|
||||
table { border-collapse: collapse; width: 100%; background: white; box-shadow: 0 1px 3px rgba(0,0,0,0.1); }
|
||||
th, td { border: 1px solid #e4e7ed; padding: 8px 12px; text-align: left; white-space: nowrap; }
|
||||
th { background: #f5f7fa; font-weight: 600; position: sticky; top: 0; }
|
||||
tr:nth-child(even) { background: #fafafa; }
|
||||
tr:hover { background: #ecf5ff; }
|
||||
.table-wrapper { overflow-x: auto; max-height: 80vh; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
'''
|
||||
|
||||
# 生成sheet选项卡
|
||||
sheet_names = wb.sheetnames
|
||||
html_content += '<div class="sheet-tabs">\n'
|
||||
for i, name in enumerate(sheet_names):
|
||||
active = 'active' if i == 0 else ''
|
||||
html_content += f'<div class="sheet-tab {active}" onclick="showSheet({i})">{name}</div>\n'
|
||||
html_content += '</div>\n'
|
||||
|
||||
# 生成每个sheet的表格
|
||||
for i, sheet_name in enumerate(sheet_names):
|
||||
ws = wb[sheet_name]
|
||||
active = 'active' if i == 0 else ''
|
||||
html_content += f'<div class="sheet-content {active}" id="sheet-{i}">\n'
|
||||
html_content += '<div class="table-wrapper"><table>\n'
|
||||
|
||||
# 获取有效数据范围
|
||||
max_row = ws.max_row or 1
|
||||
max_col = ws.max_column or 1
|
||||
|
||||
for row_idx in range(1, min(max_row + 1, 1001)): # 限制最多1000行
|
||||
html_content += '<tr>'
|
||||
for col_idx in range(1, min(max_col + 1, 51)): # 限制最多50列
|
||||
cell = ws.cell(row=row_idx, column=col_idx)
|
||||
value = cell.value if cell.value is not None else ''
|
||||
tag = 'th' if row_idx == 1 else 'td'
|
||||
# 转义HTML特殊字符
|
||||
if isinstance(value, str):
|
||||
value = value.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
|
||||
html_content += f'<{tag}>{value}</{tag}>'
|
||||
html_content += '</tr>\n'
|
||||
|
||||
html_content += '</table></div></div>\n'
|
||||
|
||||
# 添加JavaScript
|
||||
html_content += '''
|
||||
<script>
|
||||
function showSheet(index) {
|
||||
document.querySelectorAll('.sheet-tab').forEach((tab, i) => {
|
||||
tab.classList.toggle('active', i === index);
|
||||
});
|
||||
document.querySelectorAll('.sheet-content').forEach((content, i) => {
|
||||
content.classList.toggle('active', i === index);
|
||||
});
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>'''
|
||||
|
||||
# 写入HTML文件
|
||||
with open(html_file, 'w', encoding='utf-8') as f:
|
||||
f.write(html_content)
|
||||
|
||||
logger.info(f"Excel转HTML成功: {html_file}")
|
||||
return f"/static/uploads/converted/{course_id}/{material_id}.html"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Excel转HTML失败: {source_file}, 错误: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def convert_to_pdf(
|
||||
self,
|
||||
source_file: str,
|
||||
course_id: int,
|
||||
material_id: int
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
将Office文档转换为PDF
|
||||
|
||||
Args:
|
||||
source_file: 源文件路径(绝对路径或相对路径)
|
||||
course_id: 课程ID
|
||||
material_id: 资料ID
|
||||
|
||||
Returns:
|
||||
转换后的PDF文件URL,失败返回None
|
||||
"""
|
||||
try:
|
||||
source_path = Path(source_file)
|
||||
|
||||
# 检查源文件是否存在
|
||||
if not source_path.exists():
|
||||
logger.error(f"源文件不存在: {source_file}")
|
||||
return None
|
||||
|
||||
# 检查文件格式是否支持
|
||||
file_ext = source_path.suffix.lower()
|
||||
if file_ext not in self.SUPPORTED_FORMATS:
|
||||
logger.error(f"不支持的文件格式: {file_ext}")
|
||||
return None
|
||||
|
||||
# Excel文件使用HTML预览(避免分页问题)
|
||||
if file_ext in self.EXCEL_FORMATS:
|
||||
return self.convert_excel_to_html(source_file, course_id, material_id)
|
||||
|
||||
# 获取转换后的文件路径
|
||||
converted_file = self.get_converted_file_path(course_id, material_id)
|
||||
|
||||
# 检查是否需要转换
|
||||
if not self.need_convert(source_path, converted_file):
|
||||
logger.info(f"使用缓存的转换文件: {converted_file}")
|
||||
return f"/static/uploads/converted/{course_id}/{material_id}.pdf"
|
||||
|
||||
# 执行转换
|
||||
logger.info(f"开始转换文档: {source_file} -> {converted_file}")
|
||||
|
||||
# 使用 LibreOffice 转换
|
||||
# --headless: 无界面模式
|
||||
# --convert-to pdf: 转换为PDF
|
||||
# --outdir: 输出目录
|
||||
output_dir = converted_file.parent
|
||||
|
||||
cmd = [
|
||||
'libreoffice',
|
||||
'--headless',
|
||||
'--convert-to', 'pdf',
|
||||
'--outdir', str(output_dir),
|
||||
str(source_path)
|
||||
]
|
||||
|
||||
# 执行转换命令(设置超时时间为60秒)
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
check=True
|
||||
)
|
||||
|
||||
# LibreOffice 转换后的文件名是源文件名.pdf
|
||||
# 需要重命名为 material_id.pdf
|
||||
temp_converted = output_dir / f"{source_path.stem}.pdf"
|
||||
if temp_converted.exists() and temp_converted != converted_file:
|
||||
temp_converted.rename(converted_file)
|
||||
|
||||
# 检查转换结果
|
||||
if converted_file.exists():
|
||||
logger.info(f"文档转换成功: {converted_file}")
|
||||
return f"/static/uploads/converted/{course_id}/{material_id}.pdf"
|
||||
else:
|
||||
logger.error(f"文档转换失败,输出文件不存在: {converted_file}")
|
||||
return None
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(f"文档转换超时: {source_file}")
|
||||
return None
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"文档转换失败: {source_file}, 错误: {e.stderr}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"文档转换异常: {source_file}, 错误: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def is_convertible(self, file_ext: str) -> bool:
|
||||
"""
|
||||
判断文件格式是否可转换
|
||||
|
||||
Args:
|
||||
file_ext: 文件扩展名(带点,如 .docx)
|
||||
|
||||
Returns:
|
||||
是否可转换
|
||||
"""
|
||||
return file_ext.lower() in self.SUPPORTED_FORMATS
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
document_converter = DocumentConverterService()
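# —— 用法示意(编辑者补充的最小示例,非本提交原有代码)——
# 上传接口在文件落盘后,可按扩展名决定是否生成在线预览;
# 返回值为可直接返回给前端的静态资源 URL,失败或无需转换时为 None。
def _example_build_preview(source_file: str, course_id: int, material_id: int) -> Optional[str]:
    ext = Path(source_file).suffix
    if not document_converter.is_convertible(ext):
        return None  # PDF、图片等无需转换,直接使用原文件预览
    return document_converter.convert_to_pdf(source_file, course_id, material_id)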
|
||||
|
||||
739
backend/app/services/employee_sync_service.py
Normal file
@@ -0,0 +1,739 @@
|
||||
"""
|
||||
员工同步服务
|
||||
从外部钉钉员工表同步员工数据到考培练系统
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import selectinload
|
||||
import asyncio
|
||||
|
||||
from app.core.logger import get_logger
|
||||
from app.core.security import get_password_hash
|
||||
from app.models.user import User, Team
|
||||
from app.models.position import Position
|
||||
from app.models.position_member import PositionMember
|
||||
from app.schemas.user import UserCreate
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EmployeeSyncService:
|
||||
"""员工同步服务"""
|
||||
|
||||
# 外部数据库连接配置
|
||||
EXTERNAL_DB_URL = "mysql+aiomysql://neuron_new:NWxGM6CQoMLKyEszXhfuLBIIo1QbeK@120.77.144.233:29613/neuron_new?charset=utf8mb4"
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.external_engine = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
self.external_engine = create_async_engine(
|
||||
self.EXTERNAL_DB_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
if self.external_engine:
|
||||
await self.external_engine.dispose()
|
||||
|
||||
async def fetch_employees_from_dingtalk(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从钉钉员工表获取在职员工数据
|
||||
|
||||
Returns:
|
||||
员工数据列表
|
||||
"""
|
||||
logger.info("开始从钉钉员工表获取数据...")
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
员工姓名,
|
||||
手机号,
|
||||
邮箱,
|
||||
所属部门,
|
||||
职位,
|
||||
工号,
|
||||
是否领导,
|
||||
是否在职,
|
||||
钉钉用户ID,
|
||||
入职日期,
|
||||
工作地点
|
||||
FROM v_钉钉员工表
|
||||
WHERE 是否在职 = 1
|
||||
ORDER BY 员工姓名
|
||||
"""
|
||||
|
||||
async with self.external_engine.connect() as conn:
|
||||
result = await conn.execute(text(query))
|
||||
rows = result.fetchall()
|
||||
|
||||
employees = []
|
||||
for row in rows:
|
||||
employees.append({
|
||||
'full_name': row[0],
|
||||
'phone': row[1],
|
||||
'email': row[2],
|
||||
'department': row[3],
|
||||
'position': row[4],
|
||||
'employee_no': row[5],
|
||||
'is_leader': bool(row[6]),
|
||||
'is_active': bool(row[7]),
|
||||
'dingtalk_id': row[8],
|
||||
'join_date': row[9],
|
||||
'work_location': row[10]
|
||||
})
|
||||
|
||||
logger.info(f"获取到 {len(employees)} 条在职员工数据")
|
||||
return employees
|
||||
|
||||
def generate_email(self, phone: str, original_email: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
生成邮箱地址
|
||||
如果原始邮箱为空或格式无效,生成 {手机号}@rxm.com
|
||||
|
||||
Args:
|
||||
phone: 手机号
|
||||
original_email: 原始邮箱
|
||||
|
||||
Returns:
|
||||
邮箱地址
|
||||
"""
|
||||
if original_email and original_email.strip():
|
||||
email = original_email.strip()
|
||||
# 验证邮箱格式:检查是否有@后直接跟点号等无效格式
|
||||
if '@' in email and not email.startswith('@') and not email.endswith('@'):
|
||||
# 检查@后面是否直接是点号
|
||||
at_index = email.index('@')
|
||||
if at_index + 1 < len(email) and email[at_index + 1] != '.':
|
||||
# 检查是否有域名部分
|
||||
domain = email[at_index + 1:]
|
||||
if '.' in domain and len(domain) > 2:
|
||||
return email
|
||||
|
||||
# 如果邮箱无效或为空,使用手机号生成
|
||||
if phone:
|
||||
return f"{phone}@rxm.com"
|
||||
|
||||
return None
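    # 行为示例(编辑者补充,便于理解上面的校验分支):
    #   generate_email("13800000000", None)        -> "13800000000@rxm.com"
    #   generate_email("13800000000", "a@.com")    -> "13800000000@rxm.com"(@ 后直接跟点号,视为无效)
    #   generate_email("13800000000", "a@corp.cn") -> "a@corp.cn"(域名含点号且长度合法,原样保留)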
|
||||
|
||||
def determine_role(self, is_leader: bool) -> str:
|
||||
"""
|
||||
确定用户角色
|
||||
|
||||
Args:
|
||||
is_leader: 是否领导
|
||||
|
||||
Returns:
|
||||
角色: manager 或 trainee
|
||||
"""
|
||||
return 'manager' if is_leader else 'trainee'
|
||||
|
||||
async def create_or_get_team(self, department_name: str, leader_id: Optional[int] = None) -> Team:
|
||||
"""
|
||||
创建或获取部门团队
|
||||
|
||||
Args:
|
||||
department_name: 部门名称
|
||||
leader_id: 负责人ID
|
||||
|
||||
Returns:
|
||||
团队对象
|
||||
"""
|
||||
if not department_name or department_name.strip() == '':
|
||||
return None
|
||||
|
||||
department_name = department_name.strip()
|
||||
|
||||
# 检查团队是否已存在
|
||||
stmt = select(Team).where(
|
||||
Team.name == department_name,
|
||||
Team.is_deleted == False
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
team = result.scalar_one_or_none()
|
||||
|
||||
if team:
|
||||
# 更新负责人
|
||||
if leader_id and not team.leader_id:
|
||||
team.leader_id = leader_id
|
||||
logger.info(f"更新团队 {department_name} 的负责人")
|
||||
return team
|
||||
|
||||
# 创建新团队
|
||||
# 生成团队代码:基于部门名称哈希取模的简化处理(注意:内建 hash 受 PYTHONHASHSEED 影响,跨进程结果可能不同)
|
||||
team_code = f"DEPT_{hash(department_name) % 100000:05d}"
|
||||
|
||||
team = Team(
|
||||
name=department_name,
|
||||
code=team_code,
|
||||
description=f"{department_name}",
|
||||
team_type='department',
|
||||
is_active=True,
|
||||
leader_id=leader_id
|
||||
)
|
||||
|
||||
self.db.add(team)
|
||||
await self.db.flush() # 获取ID但不提交
|
||||
|
||||
logger.info(f"创建团队: {department_name} (ID: {team.id})")
|
||||
return team
|
||||
|
||||
async def create_or_get_position(self, position_name: str) -> Optional[Position]:
|
||||
"""
|
||||
创建或获取岗位
|
||||
|
||||
Args:
|
||||
position_name: 岗位名称
|
||||
|
||||
Returns:
|
||||
岗位对象
|
||||
"""
|
||||
if not position_name or position_name.strip() == '':
|
||||
return None
|
||||
|
||||
position_name = position_name.strip()
|
||||
|
||||
# 检查岗位是否已存在
|
||||
stmt = select(Position).where(
|
||||
Position.name == position_name,
|
||||
Position.is_deleted == False
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
position = result.scalar_one_or_none()
|
||||
|
||||
if position:
|
||||
return position
|
||||
|
||||
# 创建新岗位
|
||||
position_code = f"POS_{hash(position_name) % 100000:05d}"
|
||||
|
||||
position = Position(
|
||||
name=position_name,
|
||||
code=position_code,
|
||||
description=f"{position_name}",
|
||||
status='active'
|
||||
)
|
||||
|
||||
self.db.add(position)
|
||||
await self.db.flush()
|
||||
|
||||
logger.info(f"创建岗位: {position_name} (ID: {position.id})")
|
||||
return position
|
||||
|
||||
async def create_user(self, employee_data: Dict[str, Any]) -> Optional[User]:
|
||||
"""
|
||||
创建用户
|
||||
|
||||
Args:
|
||||
employee_data: 员工数据
|
||||
|
||||
Returns:
|
||||
用户对象或None(如果创建失败)
|
||||
"""
|
||||
phone = employee_data.get('phone')
|
||||
full_name = employee_data.get('full_name')
|
||||
|
||||
if not phone:
|
||||
logger.warning(f"员工 {full_name} 没有手机号,跳过")
|
||||
return None
|
||||
|
||||
# 检查用户是否已存在(通过手机号)
|
||||
stmt = select(User).where(
|
||||
User.phone == phone,
|
||||
User.is_deleted == False
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
existing_user = result.scalar_one_or_none()
|
||||
|
||||
if existing_user:
|
||||
logger.info(f"用户已存在: {phone} ({full_name})")
|
||||
return existing_user
|
||||
|
||||
# 生成邮箱
|
||||
email = self.generate_email(phone, employee_data.get('email'))
|
||||
|
||||
# 检查邮箱是否已被其他用户使用(避免唯一索引冲突)
|
||||
if email:
|
||||
email_check_stmt = select(User).where(
|
||||
User.email == email,
|
||||
User.is_deleted == False
|
||||
)
|
||||
email_result = await self.db.execute(email_check_stmt)
|
||||
if email_result.scalar_one_or_none():
|
||||
# 邮箱已存在,使用手机号生成唯一邮箱
|
||||
email = f"{phone}@rxm.com"
|
||||
logger.warning(f"邮箱 {employee_data.get('email')} 已被使用,为员工 {full_name} 生成新邮箱: {email}")
|
||||
|
||||
# 确定角色
|
||||
role = self.determine_role(employee_data.get('is_leader', False))
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username=phone, # 使用手机号作为用户名
|
||||
email=email,
|
||||
phone=phone,
|
||||
hashed_password=get_password_hash('123456'), # 初始密码
|
||||
full_name=full_name,
|
||||
role=role,
|
||||
is_active=True,
|
||||
is_verified=True
|
||||
)
|
||||
|
||||
self.db.add(user)
|
||||
await self.db.flush()
|
||||
|
||||
logger.info(f"创建用户: {phone} ({full_name}) - 角色: {role}")
|
||||
return user
|
||||
|
||||
async def sync_employees(self) -> Dict[str, Any]:
|
||||
"""
|
||||
执行完整的员工同步流程
|
||||
|
||||
Returns:
|
||||
同步结果统计
|
||||
"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("开始员工同步")
|
||||
logger.info("=" * 60)
|
||||
|
||||
stats = {
|
||||
'total_employees': 0,
|
||||
'users_created': 0,
|
||||
'users_skipped': 0,
|
||||
'teams_created': 0,
|
||||
'positions_created': 0,
|
||||
'errors': [],
|
||||
'start_time': datetime.now()
|
||||
}
|
||||
|
||||
try:
|
||||
# 1. 获取员工数据
|
||||
employees = await self.fetch_employees_from_dingtalk()
|
||||
stats['total_employees'] = len(employees)
|
||||
|
||||
if not employees:
|
||||
logger.warning("没有获取到员工数据")
|
||||
return stats
|
||||
|
||||
# 2. 创建用户和相关数据
|
||||
for employee in employees:
|
||||
try:
|
||||
# 创建用户
|
||||
user = await self.create_user(employee)
|
||||
if not user:
|
||||
stats['users_skipped'] += 1
|
||||
continue
|
||||
|
||||
stats['users_created'] += 1
|
||||
|
||||
# 创建部门团队
|
||||
department = employee.get('department')
|
||||
if department:
|
||||
team = await self.create_or_get_team(
|
||||
department,
|
||||
leader_id=user.id if employee.get('is_leader') else None
|
||||
)
|
||||
if team:
|
||||
# 用SQL直接插入user_teams关联表(避免懒加载问题)
|
||||
await self._add_user_to_team(user.id, team.id)
|
||||
logger.info(f"关联用户 {user.full_name} 到团队 {team.name}")
|
||||
|
||||
# 创建岗位
|
||||
position_name = employee.get('position')
|
||||
if position_name:
|
||||
position = await self.create_or_get_position(position_name)
|
||||
if position:
|
||||
# 检查是否已经关联
|
||||
stmt = select(PositionMember).where(
|
||||
PositionMember.position_id == position.id,
|
||||
PositionMember.user_id == user.id,
|
||||
PositionMember.is_deleted == False
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
existing_member = result.scalar_one_or_none()
|
||||
|
||||
if not existing_member:
|
||||
# 创建岗位成员关联
|
||||
position_member = PositionMember(
|
||||
position_id=position.id,
|
||||
user_id=user.id,
|
||||
role='member'
|
||||
)
|
||||
self.db.add(position_member)
|
||||
logger.info(f"关联用户 {user.full_name} 到岗位 {position.name}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"处理员工 {employee.get('full_name')} 时出错: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
stats['errors'].append(error_msg)
|
||||
continue
|
||||
|
||||
# 3. 提交所有更改
|
||||
await self.db.commit()
|
||||
logger.info("✅ 数据库事务已提交")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"员工同步失败: {str(e)}")
|
||||
await self.db.rollback()
|
||||
stats['errors'].append(str(e))
|
||||
raise
|
||||
|
||||
finally:
|
||||
stats['end_time'] = datetime.now()
|
||||
stats['duration'] = (stats['end_time'] - stats['start_time']).total_seconds()
|
||||
|
||||
# 4. 输出统计信息
|
||||
logger.info("=" * 60)
|
||||
logger.info("同步完成统计")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"总员工数: {stats['total_employees']}")
|
||||
logger.info(f"创建用户: {stats['users_created']}")
|
||||
logger.info(f"跳过用户: {stats['users_skipped']}")
|
||||
logger.info(f"耗时: {stats['duration']:.2f}秒")
|
||||
|
||||
if stats['errors']:
|
||||
logger.warning(f"错误数量: {len(stats['errors'])}")
|
||||
for error in stats['errors']:
|
||||
logger.warning(f" - {error}")
|
||||
|
||||
return stats
|
||||
|
||||
async def preview_sync_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
预览待同步的员工数据(不执行实际同步)
|
||||
|
||||
Returns:
|
||||
预览数据
|
||||
"""
|
||||
logger.info("预览待同步员工数据...")
|
||||
|
||||
employees = await self.fetch_employees_from_dingtalk()
|
||||
|
||||
preview = {
|
||||
'total_count': len(employees),
|
||||
'employees': [],
|
||||
'departments': set(),
|
||||
'positions': set(),
|
||||
'leaders_count': 0,
|
||||
'trainees_count': 0
|
||||
}
|
||||
|
||||
for emp in employees:
|
||||
role = self.determine_role(emp.get('is_leader', False))
|
||||
email = self.generate_email(emp.get('phone'), emp.get('email'))
|
||||
|
||||
preview['employees'].append({
|
||||
'full_name': emp.get('full_name'),
|
||||
'phone': emp.get('phone'),
|
||||
'email': email,
|
||||
'department': emp.get('department'),
|
||||
'position': emp.get('position'),
|
||||
'role': role,
|
||||
'is_leader': emp.get('is_leader')
|
||||
})
|
||||
|
||||
if emp.get('department'):
|
||||
preview['departments'].add(emp.get('department'))
|
||||
if emp.get('position'):
|
||||
preview['positions'].add(emp.get('position'))
|
||||
|
||||
if role == 'manager':
|
||||
preview['leaders_count'] += 1
|
||||
else:
|
||||
preview['trainees_count'] += 1
|
||||
|
||||
preview['departments'] = list(preview['departments'])
|
||||
preview['positions'] = list(preview['positions'])
|
||||
|
||||
return preview
|
||||
|
||||
async def _add_user_to_team(self, user_id: int, team_id: int) -> None:
|
||||
"""
|
||||
将用户添加到团队(直接SQL操作,避免懒加载问题)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
team_id: 团队ID
|
||||
"""
|
||||
# 先检查是否已存在关联
|
||||
check_result = await self.db.execute(
|
||||
text("SELECT 1 FROM user_teams WHERE user_id = :user_id AND team_id = :team_id"),
|
||||
{"user_id": user_id, "team_id": team_id}
|
||||
)
|
||||
if check_result.scalar() is None:
|
||||
# 不存在则插入
|
||||
await self.db.execute(
|
||||
text("INSERT INTO user_teams (user_id, team_id, role) VALUES (:user_id, :team_id, 'member')"),
|
||||
{"user_id": user_id, "team_id": team_id}
|
||||
)
|
||||
|
||||
async def _cleanup_user_related_data(self, user_id: int) -> None:
|
||||
"""
|
||||
清理用户关联数据(用于删除用户前)
|
||||
|
||||
Args:
|
||||
user_id: 要清理的用户ID
|
||||
"""
|
||||
logger.info(f"清理用户 {user_id} 的关联数据...")
|
||||
|
||||
# 删除用户的考试记录
|
||||
await self.db.execute(
|
||||
text("DELETE FROM exam_results WHERE exam_id IN (SELECT id FROM exams WHERE user_id = :user_id)"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
await self.db.execute(
|
||||
text("DELETE FROM exams WHERE user_id = :user_id"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
|
||||
# 删除用户的错题记录
|
||||
await self.db.execute(
|
||||
text("DELETE FROM exam_mistakes WHERE user_id = :user_id"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
|
||||
# 删除用户的能力评估记录
|
||||
await self.db.execute(
|
||||
text("DELETE FROM ability_assessments WHERE user_id = :user_id"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
|
||||
# 删除用户的岗位关联
|
||||
await self.db.execute(
|
||||
text("DELETE FROM position_members WHERE user_id = :user_id"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
|
||||
# 删除用户的团队关联
|
||||
await self.db.execute(
|
||||
text("DELETE FROM user_teams WHERE user_id = :user_id"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
|
||||
# 删除用户的陪练会话
|
||||
await self.db.execute(
|
||||
text("DELETE FROM practice_sessions WHERE user_id = :user_id"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
|
||||
# 删除用户的任务分配
|
||||
await self.db.execute(
|
||||
text("DELETE FROM task_assignments WHERE user_id = :user_id"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
|
||||
# 删除用户创建的任务的分配记录
|
||||
await self.db.execute(
|
||||
text("DELETE FROM task_assignments WHERE task_id IN (SELECT id FROM tasks WHERE creator_id = :user_id)"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
# 删除用户创建的任务
|
||||
await self.db.execute(
|
||||
text("DELETE FROM tasks WHERE creator_id = :user_id"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
|
||||
# 将用户作为负责人的团队的leader_id设为NULL
|
||||
await self.db.execute(
|
||||
text("UPDATE teams SET leader_id = NULL WHERE leader_id = :user_id"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
|
||||
logger.info(f"用户 {user_id} 的关联数据清理完成")
|
||||
|
||||
async def incremental_sync_employees(self) -> Dict[str, Any]:
|
||||
"""
|
||||
增量同步员工数据
|
||||
- 新增钉钉有但系统没有的员工
|
||||
- 删除系统有但钉钉没有的员工(物理删除)
|
||||
- 跳过两边都存在的员工(不做任何修改)
|
||||
|
||||
Returns:
|
||||
同步结果统计
|
||||
"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("开始增量员工同步")
|
||||
logger.info("=" * 60)
|
||||
|
||||
stats = {
|
||||
'added_count': 0,
|
||||
'deleted_count': 0,
|
||||
'skipped_count': 0,
|
||||
'added_users': [],
|
||||
'deleted_users': [],
|
||||
'errors': [],
|
||||
'start_time': datetime.now()
|
||||
}
|
||||
|
||||
try:
|
||||
# 1. 获取钉钉在职员工数据
|
||||
dingtalk_employees = await self.fetch_employees_from_dingtalk()
|
||||
dingtalk_phones = {emp.get('phone') for emp in dingtalk_employees if emp.get('phone')}
|
||||
logger.info(f"钉钉在职员工数量: {len(dingtalk_phones)}")
|
||||
|
||||
# 2. 获取系统现有用户(排除admin和已软删除的)
|
||||
stmt = select(User).where(
|
||||
User.is_deleted == False,
|
||||
User.username != 'admin'
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
system_users = result.scalars().all()
|
||||
system_phones = {user.phone for user in system_users if user.phone}
|
||||
logger.info(f"系统现有员工数量(排除admin): {len(system_phones)}")
|
||||
|
||||
# 3. 计算需要新增、删除、跳过的员工
|
||||
phones_to_add = dingtalk_phones - system_phones
|
||||
phones_to_delete = system_phones - dingtalk_phones
|
||||
phones_to_skip = dingtalk_phones & system_phones
|
||||
|
||||
logger.info(f"待新增: {len(phones_to_add)}, 待删除: {len(phones_to_delete)}, 跳过: {len(phones_to_skip)}")
|
||||
|
||||
stats['skipped_count'] = len(phones_to_skip)
|
||||
|
||||
# 4. 新增员工
|
||||
for employee in dingtalk_employees:
|
||||
phone = employee.get('phone')
|
||||
if not phone or phone not in phones_to_add:
|
||||
continue
|
||||
|
||||
try:
|
||||
# 创建用户
|
||||
user = await self.create_user(employee)
|
||||
if not user:
|
||||
continue
|
||||
|
||||
stats['added_count'] += 1
|
||||
stats['added_users'].append({
|
||||
'full_name': user.full_name,
|
||||
'phone': user.phone,
|
||||
'role': user.role
|
||||
})
|
||||
|
||||
# 创建部门团队
|
||||
department = employee.get('department')
|
||||
if department:
|
||||
team = await self.create_or_get_team(
|
||||
department,
|
||||
leader_id=user.id if employee.get('is_leader') else None
|
||||
)
|
||||
if team:
|
||||
# 用SQL直接插入user_teams关联表(避免懒加载问题)
|
||||
await self._add_user_to_team(user.id, team.id)
|
||||
logger.info(f"关联用户 {user.full_name} 到团队 {team.name}")
|
||||
|
||||
# 创建岗位
|
||||
position_name = employee.get('position')
|
||||
if position_name:
|
||||
position = await self.create_or_get_position(position_name)
|
||||
if position:
|
||||
# 检查是否已经关联
|
||||
stmt = select(PositionMember).where(
|
||||
PositionMember.position_id == position.id,
|
||||
PositionMember.user_id == user.id,
|
||||
PositionMember.is_deleted == False
|
||||
)
|
||||
result = await self.db.execute(stmt)
|
||||
existing_member = result.scalar_one_or_none()
|
||||
|
||||
if not existing_member:
|
||||
position_member = PositionMember(
|
||||
position_id=position.id,
|
||||
user_id=user.id,
|
||||
role='member'
|
||||
)
|
||||
self.db.add(position_member)
|
||||
logger.info(f"关联用户 {user.full_name} 到岗位 {position.name}")
|
||||
|
||||
logger.info(f"✅ 新增员工: {user.full_name} ({user.phone})")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"新增员工 {employee.get('full_name')} 失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
stats['errors'].append(error_msg)
|
||||
continue
|
||||
|
||||
# 5. 删除离职员工(物理删除)
|
||||
# 先flush之前的新增操作,避免与删除操作冲突
|
||||
await self.db.flush()
|
||||
|
||||
# 收集需要删除的用户ID
|
||||
users_to_delete = []
|
||||
for user in system_users:
|
||||
if user.phone and user.phone in phones_to_delete:
|
||||
# 双重保护:确保不删除admin
|
||||
if user.username == 'admin' or user.role == 'admin':
|
||||
logger.warning(f"⚠️ 跳过删除管理员账户: {user.username}")
|
||||
continue
|
||||
|
||||
users_to_delete.append({
|
||||
'id': user.id,
|
||||
'full_name': user.full_name,
|
||||
'phone': user.phone,
|
||||
'username': user.username
|
||||
})
|
||||
|
||||
# 批量删除用户及其关联数据
|
||||
for user_info in users_to_delete:
|
||||
try:
|
||||
user_id = user_info['id']
|
||||
|
||||
# 先清理关联数据(外键约束)
|
||||
await self._cleanup_user_related_data(user_id)
|
||||
|
||||
# 用SQL直接删除用户(避免ORM的级联操作冲突)
|
||||
await self.db.execute(
|
||||
text("DELETE FROM users WHERE id = :user_id"),
|
||||
{"user_id": user_id}
|
||||
)
|
||||
|
||||
stats['deleted_users'].append({
|
||||
'full_name': user_info['full_name'],
|
||||
'phone': user_info['phone'],
|
||||
'username': user_info['username']
|
||||
})
|
||||
stats['deleted_count'] += 1
|
||||
logger.info(f"🗑️ 删除离职员工: {user_info['full_name']} ({user_info['phone']})")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"删除员工 {user_info['full_name']} 失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
stats['errors'].append(error_msg)
|
||||
continue
|
||||
|
||||
# 6. 提交所有更改
|
||||
await self.db.commit()
|
||||
logger.info("✅ 数据库事务已提交")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"增量同步失败: {str(e)}")
|
||||
await self.db.rollback()
|
||||
stats['errors'].append(str(e))
|
||||
raise
|
||||
|
||||
finally:
|
||||
stats['end_time'] = datetime.now()
|
||||
stats['duration'] = (stats['end_time'] - stats['start_time']).total_seconds()
|
||||
|
||||
# 7. 输出统计信息
|
||||
logger.info("=" * 60)
|
||||
logger.info("增量同步完成统计")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"新增员工: {stats['added_count']}")
|
||||
logger.info(f"删除员工: {stats['deleted_count']}")
|
||||
logger.info(f"跳过员工: {stats['skipped_count']}")
|
||||
logger.info(f"耗时: {stats['duration']:.2f}秒")
|
||||
|
||||
if stats['errors']:
|
||||
logger.warning(f"错误数量: {len(stats['errors'])}")
|
||||
|
||||
return stats
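# —— 用法示意(编辑者补充的最小示例,非本提交原有代码)——
# 服务以异步上下文管理器管理外部库引擎的生命周期,调用方只需传入本系统的 AsyncSession(假设来自依赖注入)。
async def _example_run_incremental_sync(db: AsyncSession) -> Dict[str, Any]:
    async with EmployeeSyncService(db) as service:
        return await service.incremental_sync_employees()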
|
||||
|
||||
486
backend/app/services/exam_report_service.py
Normal file
@@ -0,0 +1,486 @@
|
||||
"""
|
||||
考试报告和错题统计服务
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_, or_, desc, case, text
|
||||
from app.models.exam import Exam
|
||||
from app.models.exam_mistake import ExamMistake
|
||||
from app.models.course import Course, KnowledgePoint
|
||||
from app.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ExamReportService:
|
||||
"""考试报告服务类"""
|
||||
|
||||
@staticmethod
|
||||
async def get_exam_report(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取成绩报告汇总数据
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
start_date: 开始日期(YYYY-MM-DD)
|
||||
end_date: 结束日期(YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
Dict: 包含overview、trends、subjects、recent_exams的完整报告数据
|
||||
"""
|
||||
logger.info(f"获取成绩报告 - user_id: {user_id}, start_date: {start_date}, end_date: {end_date}")
|
||||
|
||||
# 构建基础查询条件
|
||||
conditions = [Exam.user_id == user_id]
|
||||
|
||||
# 添加时间范围条件
|
||||
if start_date:
|
||||
conditions.append(Exam.start_time >= start_date)
|
||||
if end_date:
|
||||
# 结束日期包含当天全部
|
||||
end_datetime = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
|
||||
conditions.append(Exam.start_time < end_datetime)
|
||||
|
||||
# 1. 获取概览数据
|
||||
overview = await ExamReportService._get_overview(db, conditions)
|
||||
|
||||
# 2. 获取趋势数据(最近30天)
|
||||
trends = await ExamReportService._get_trends(db, user_id, conditions)
|
||||
|
||||
# 3. 获取科目分析
|
||||
subjects = await ExamReportService._get_subjects(db, conditions)
|
||||
|
||||
# 4. 获取最近考试记录
|
||||
recent_exams = await ExamReportService._get_recent_exams(db, conditions)
|
||||
|
||||
return {
|
||||
"overview": overview,
|
||||
"trends": trends,
|
||||
"subjects": subjects,
|
||||
"recent_exams": recent_exams
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _get_overview(db: AsyncSession, conditions: List) -> Dict[str, Any]:
|
||||
"""获取概览数据"""
|
||||
# 查询统计数据(使用round1_score作为主要成绩)
|
||||
stmt = select(
|
||||
func.count(Exam.id).label("total_exams"),
|
||||
func.avg(Exam.round1_score).label("avg_score"),
|
||||
func.sum(Exam.question_count).label("total_questions"),
|
||||
func.count(case((Exam.is_passed == True, 1))).label("passed_exams")
|
||||
).where(
|
||||
and_(*conditions, Exam.round1_score.isnot(None))
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
stats = result.one()
|
||||
|
||||
total_exams = stats.total_exams or 0
|
||||
passed_exams = stats.passed_exams or 0
|
||||
|
||||
return {
|
||||
"avg_score": round(float(stats.avg_score or 0), 1),
|
||||
"total_exams": total_exams,
|
||||
"pass_rate": round((passed_exams / total_exams * 100) if total_exams > 0 else 0, 1),
|
||||
"total_questions": stats.total_questions or 0
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _get_trends(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
base_conditions: List
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取成绩趋势(最近30天)"""
|
||||
# 计算30天前的日期
|
||||
thirty_days_ago = datetime.now() - timedelta(days=30)
|
||||
|
||||
# 查询最近30天的考试数据,按日期分组
|
||||
stmt = select(
|
||||
func.date(Exam.start_time).label("exam_date"),
|
||||
func.avg(Exam.round1_score).label("avg_score")
|
||||
).where(
|
||||
and_(
|
||||
Exam.user_id == user_id,
|
||||
Exam.start_time >= thirty_days_ago,
|
||||
Exam.round1_score.isnot(None)
|
||||
)
|
||||
).group_by(
|
||||
func.date(Exam.start_time)
|
||||
).order_by(
|
||||
func.date(Exam.start_time)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
trend_data = result.all()
|
||||
|
||||
# 转换为前端需要的格式
|
||||
trends = []
|
||||
for row in trend_data:
|
||||
trends.append({
|
||||
"date": row.exam_date.strftime("%Y-%m-%d") if row.exam_date else "",
|
||||
"avg_score": round(float(row.avg_score or 0), 1)
|
||||
})
|
||||
|
||||
return trends
|
||||
|
||||
@staticmethod
|
||||
async def _get_subjects(db: AsyncSession, conditions: List) -> List[Dict[str, Any]]:
|
||||
"""获取科目分析"""
|
||||
# 关联course表,按课程分组统计
|
||||
stmt = select(
|
||||
Exam.course_id,
|
||||
Course.name.label("course_name"),
|
||||
func.avg(Exam.round1_score).label("avg_score"),
|
||||
func.count(Exam.id).label("exam_count"),
|
||||
func.max(Exam.round1_score).label("max_score"),
|
||||
func.min(Exam.round1_score).label("min_score"),
|
||||
func.count(case((Exam.is_passed == True, 1))).label("passed_count")
|
||||
).join(
|
||||
Course, Exam.course_id == Course.id
|
||||
).where(
|
||||
and_(*conditions, Exam.round1_score.isnot(None))
|
||||
).group_by(
|
||||
Exam.course_id, Course.name
|
||||
).order_by(
|
||||
desc(func.count(Exam.id))
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
subject_data = result.all()
|
||||
|
||||
subjects = []
|
||||
for row in subject_data:
|
||||
exam_count = row.exam_count or 0
|
||||
passed_count = row.passed_count or 0
|
||||
|
||||
subjects.append({
|
||||
"course_id": row.course_id,
|
||||
"course_name": row.course_name,
|
||||
"avg_score": round(float(row.avg_score or 0), 1),
|
||||
"exam_count": exam_count,
|
||||
"max_score": round(float(row.max_score or 0), 1),
|
||||
"min_score": round(float(row.min_score or 0), 1),
|
||||
"pass_rate": round((passed_count / exam_count * 100) if exam_count > 0 else 0, 1)
|
||||
})
|
||||
|
||||
return subjects
|
||||
|
||||
    @staticmethod
    async def _get_recent_exams(db: AsyncSession, conditions: List) -> List[Dict[str, Any]]:
        """Get the 10 most recent exam records."""
        # Query the last 10 exams, including the three round scores
        stmt = select(
            Exam.id,
            Exam.course_id,
            Course.name.label("course_name"),
            Exam.score,
            Exam.total_score,
            Exam.is_passed,
            Exam.start_time,
            Exam.end_time,
            Exam.round1_score,
            Exam.round2_score,
            Exam.round3_score
        ).join(
            Course, Exam.course_id == Course.id
        ).where(
            and_(*conditions)
        ).order_by(
            desc(Exam.created_at)  # Sort by creation time to avoid rows with NULL start_time
        ).limit(10)

        result = await db.execute(stmt)
        exam_data = result.all()

        recent_exams = []
        for row in exam_data:
            # Compute exam duration
            duration_seconds = None
            if row.start_time and row.end_time:
                duration_seconds = int((row.end_time - row.start_time).total_seconds())

            recent_exams.append({
                "id": row.id,
                "course_id": row.course_id,
                "course_name": row.course_name,
                # Explicit None checks so a legitimate score of 0 is not reported as null
                "score": round(float(row.score), 1) if row.score is not None else None,
                "total_score": round(float(row.total_score or 100), 1),
                "is_passed": row.is_passed,
                "duration_seconds": duration_seconds,
                "start_time": row.start_time.isoformat() if row.start_time else None,
                "end_time": row.end_time.isoformat() if row.end_time else None,
                "round_scores": {
                    "round1": round(float(row.round1_score), 1) if row.round1_score is not None else None,
                    "round2": round(float(row.round2_score), 1) if row.round2_score is not None else None,
                    "round3": round(float(row.round3_score), 1) if row.round3_score is not None else None
                }
            })

        return recent_exams


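# A small sketch of the duration and round-score formatting used above; the explicit
# "is not None" checks keep a legitimate 0 score from being reported as null.
from datetime import datetime

def format_duration(start, end):
    return int((end - start).total_seconds()) if start and end else None

def format_score(value):
    return round(float(value), 1) if value is not None else None

assert format_duration(datetime(2025, 1, 1, 9, 0), datetime(2025, 1, 1, 9, 30)) == 1800
assert format_score(0) == 0.0 and format_score(None) is None
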
class MistakeService:
    """Mistake (wrong answer) service."""

    @staticmethod
    async def get_mistakes_list(
        db: AsyncSession,
        user_id: int,
        exam_id: Optional[int] = None,
        course_id: Optional[int] = None,
        question_type: Optional[str] = None,
        search: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        page: int = 1,
        size: int = 10
    ) -> Dict[str, Any]:
        """
        Get the mistake list (supports multi-dimensional filtering).

        Args:
            db: database session
            user_id: user ID
            exam_id: exam ID (optional)
            course_id: course ID (optional)
            question_type: question type (optional)
            search: keyword search (optional)
            start_date: start date (optional)
            end_date: end date (optional)
            page: page number
            size: page size

        Returns:
            Dict: paginated data with items, total, page, size, pages
        """
        logger.info(f"Fetching mistake list - user_id: {user_id}, exam_id: {exam_id}, course_id: {course_id}")

        # Build query conditions
        conditions = [ExamMistake.user_id == user_id]

        if exam_id:
            conditions.append(ExamMistake.exam_id == exam_id)

        if question_type:
            conditions.append(ExamMistake.question_type == question_type)

        if search:
            conditions.append(ExamMistake.question_content.like(f"%{search}%"))

        if start_date:
            conditions.append(ExamMistake.created_at >= start_date)

        if end_date:
            end_datetime = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
            conditions.append(ExamMistake.created_at < end_datetime)

        # When course_id is given, filter through the joined exam
        if course_id:
            conditions.append(Exam.course_id == course_id)

        # Count total
        count_stmt = select(func.count(ExamMistake.id)).select_from(ExamMistake).join(
            Exam, ExamMistake.exam_id == Exam.id
        ).where(and_(*conditions))

        total_result = await db.execute(count_stmt)
        total = total_result.scalar() or 0

        # Query the current page
        offset = (page - 1) * size

        stmt = select(
            ExamMistake.id,
            ExamMistake.exam_id,
            Exam.course_id,
            Course.name.label("course_name"),
            ExamMistake.question_content,
            ExamMistake.correct_answer,
            ExamMistake.user_answer,
            ExamMistake.question_type,
            ExamMistake.knowledge_point_id,
            KnowledgePoint.name.label("knowledge_point_name"),
            ExamMistake.created_at
        ).select_from(ExamMistake).join(
            Exam, ExamMistake.exam_id == Exam.id
        ).join(
            Course, Exam.course_id == Course.id
        ).outerjoin(
            KnowledgePoint, ExamMistake.knowledge_point_id == KnowledgePoint.id
        ).where(
            and_(*conditions)
        ).order_by(
            desc(ExamMistake.created_at)
        ).offset(offset).limit(size)

        result = await db.execute(stmt)
        mistakes = result.all()

        # Build response data
        items = []
        for row in mistakes:
            items.append({
                "id": row.id,
                "exam_id": row.exam_id,
                "course_id": row.course_id,
                "course_name": row.course_name,
                "question_content": row.question_content,
                "correct_answer": row.correct_answer,
                "user_answer": row.user_answer,
                "question_type": row.question_type,
                "knowledge_point_id": row.knowledge_point_id,
                "knowledge_point_name": row.knowledge_point_name,
                "created_at": row.created_at
            })

        pages = (total + size - 1) // size

        return {
            "items": items,
            "total": total,
            "page": page,
            "size": size,
            "pages": pages
        }

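    # The pagination arithmetic used above, as a standalone helper: integer ceil of
    # total / size, plus the OFFSET formula for the paged query.
    def paginate(total, page, size):
        return {
            "offset": (page - 1) * size,
            "pages": (total + size - 1) // size if size > 0 else 0,
        }

    assert paginate(total=21, page=3, size=10) == {"offset": 20, "pages": 3}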
    @staticmethod
    async def get_mistakes_statistics(
        db: AsyncSession,
        user_id: int,
        course_id: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Get mistake statistics.

        Args:
            db: database session
            user_id: user ID
            course_id: course ID (optional)

        Returns:
            Dict: statistics with total, by_course, by_type, by_time
        """
        logger.info(f"Fetching mistake statistics - user_id: {user_id}, course_id: {course_id}")

        # Base conditions
        base_conditions = [ExamMistake.user_id == user_id]
        if course_id:
            base_conditions.append(Exam.course_id == course_id)

        # 1. Overall count
        count_stmt = select(func.count(ExamMistake.id)).select_from(ExamMistake).join(
            Exam, ExamMistake.exam_id == Exam.id
        ).where(and_(*base_conditions))

        total_result = await db.execute(count_stmt)
        total = total_result.scalar() or 0

        # 2. Count by course
        by_course_stmt = select(
            Exam.course_id,
            Course.name.label("course_name"),
            func.count(ExamMistake.id).label("count")
        ).select_from(ExamMistake).join(
            Exam, ExamMistake.exam_id == Exam.id
        ).join(
            Course, Exam.course_id == Course.id
        ).where(
            ExamMistake.user_id == user_id
        ).group_by(
            Exam.course_id, Course.name
        ).order_by(
            desc(func.count(ExamMistake.id))
        )

        by_course_result = await db.execute(by_course_stmt)
        by_course_data = by_course_result.all()

        by_course = [
            {
                "course_id": row.course_id,
                "course_name": row.course_name,
                "count": row.count
            }
            for row in by_course_data
        ]

        # 3. Count by question type
        by_type_stmt = select(
            ExamMistake.question_type,
            func.count(ExamMistake.id).label("count")
        ).where(
            and_(ExamMistake.user_id == user_id, ExamMistake.question_type.isnot(None))
        ).group_by(
            ExamMistake.question_type
        )

        by_type_result = await db.execute(by_type_stmt)
        by_type_data = by_type_result.all()

        # Question type display names
        type_names = {
            "single": "单选题",
            "multiple": "多选题",
            "judge": "判断题",
            "blank": "填空题",
            "essay": "问答题"
        }

        by_type = [
            {
                "type": row.question_type,
                "type_name": type_names.get(row.question_type, "未知类型"),
                "count": row.count
            }
            for row in by_type_data
        ]

        # 4. Count by time window
        now = datetime.now()
        week_ago = now - timedelta(days=7)
        month_ago = now - timedelta(days=30)
        quarter_ago = now - timedelta(days=90)

        # Last week
        week_stmt = select(func.count(ExamMistake.id)).where(
            and_(ExamMistake.user_id == user_id, ExamMistake.created_at >= week_ago)
        )
        week_result = await db.execute(week_stmt)
        week_count = week_result.scalar() or 0

        # Last month
        month_stmt = select(func.count(ExamMistake.id)).where(
            and_(ExamMistake.user_id == user_id, ExamMistake.created_at >= month_ago)
        )
        month_result = await db.execute(month_stmt)
        month_count = month_result.scalar() or 0

        # Last quarter
        quarter_stmt = select(func.count(ExamMistake.id)).where(
            and_(ExamMistake.user_id == user_id, ExamMistake.created_at >= quarter_ago)
        )
        quarter_result = await db.execute(quarter_stmt)
        quarter_count = quarter_result.scalar() or 0

        by_time = {
            "week": week_count,
            "month": month_count,
            "quarter": quarter_count
        }

        return {
            "total": total,
            "by_course": by_course,
            "by_type": by_type,
            "by_time": by_time
        }

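# A hedged usage sketch for the two MistakeService entry points above. The session
# factory name (async_session_factory) is an assumption; adapt it to the project's
# actual dependency wiring before running.
import asyncio

async def _demo_mistakes(async_session_factory):
    async with async_session_factory() as db:
        page = await MistakeService.get_mistakes_list(db, user_id=1, course_id=2, page=1, size=10)
        stats = await MistakeService.get_mistakes_statistics(db, user_id=1)
        print(page["total"], stats["by_time"])

# asyncio.run(_demo_mistakes(async_session_factory))  # requires a configured database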
439
backend/app/services/exam_service.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
考试服务层
|
||||
"""
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_, or_, desc
|
||||
from app.models.exam import Exam, Question, ExamResult
|
||||
from app.models.exam_mistake import ExamMistake
|
||||
from app.models.course import Course, KnowledgePoint
|
||||
from app.core.exceptions import BusinessException, ErrorCode
|
||||
|
||||
|
||||
class ExamService:
|
||||
"""考试服务类"""
|
||||
|
||||
@staticmethod
|
||||
async def start_exam(
|
||||
db: AsyncSession, user_id: int, course_id: int, question_count: int = 10
|
||||
) -> Exam:
|
||||
"""
|
||||
开始考试
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
course_id: 课程ID
|
||||
question_count: 题目数量
|
||||
|
||||
Returns:
|
||||
Exam: 考试实例
|
||||
"""
|
||||
# 检查课程是否存在
|
||||
course = await db.get(Course, course_id)
|
||||
if not course:
|
||||
raise BusinessException(error_code=ErrorCode.NOT_FOUND, message="课程不存在")
|
||||
|
||||
# 获取该课程的所有可用题目
|
||||
stmt = select(Question).where(
|
||||
and_(Question.course_id == course_id, Question.is_active == True)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
all_questions = result.scalars().all()
|
||||
|
||||
if not all_questions:
|
||||
raise BusinessException(
|
||||
error_code=ErrorCode.VALIDATION_ERROR, message="该课程暂无题目"
|
||||
)
|
||||
|
||||
# 随机选择题目
|
||||
selected_questions = random.sample(
|
||||
all_questions, min(question_count, len(all_questions))
|
||||
)
|
||||
|
||||
# 构建题目数据
|
||||
questions_data = []
|
||||
for q in selected_questions:
|
||||
question_data = {
|
||||
"id": str(q.id),
|
||||
"type": q.question_type,
|
||||
"title": q.title,
|
||||
"content": q.content,
|
||||
"options": q.options,
|
||||
"score": q.score,
|
||||
}
|
||||
questions_data.append(question_data)
|
||||
|
||||
# 创建考试记录
|
||||
exam = Exam(
|
||||
user_id=user_id,
|
||||
course_id=course_id,
|
||||
exam_name=f"{course.name} - 随机测试",
|
||||
question_count=len(selected_questions),
|
||||
total_score=sum(q.score for q in selected_questions),
|
||||
pass_score=sum(q.score for q in selected_questions) * 0.6,
|
||||
duration_minutes=60,
|
||||
status="started",
|
||||
questions={"questions": questions_data},
|
||||
)
|
||||
|
||||
db.add(exam)
|
||||
await db.commit()
|
||||
await db.refresh(exam)
|
||||
|
||||
return exam
|
||||
|
||||
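# A standalone sketch of the question sampling and score bookkeeping done in start_exam
# above: pick up to N questions at random, total score is their sum, pass mark is 60%
# of the total. The sample question dicts are hypothetical.
import random

def build_paper(questions, question_count=10):
    selected = random.sample(questions, min(question_count, len(questions)))
    total = sum(q["score"] for q in selected)
    return {"questions": selected, "total_score": total, "pass_score": total * 0.6}

paper = build_paper([{"id": i, "score": 10} for i in range(20)], question_count=5)
assert paper["total_score"] == 50 and round(paper["pass_score"], 6) == 30.0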
@staticmethod
|
||||
async def submit_exam(
|
||||
db: AsyncSession, user_id: int, exam_id: int, answers: List[Dict[str, str]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
提交考试答案
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
exam_id: 考试ID
|
||||
answers: 答案列表
|
||||
|
||||
Returns:
|
||||
Dict: 考试结果
|
||||
"""
|
||||
# 获取考试记录
|
||||
stmt = select(Exam).where(and_(Exam.id == exam_id, Exam.user_id == user_id))
|
||||
result = await db.execute(stmt)
|
||||
exam = result.scalar_one_or_none()
|
||||
|
||||
if not exam:
|
||||
raise BusinessException(error_code=ErrorCode.NOT_FOUND, message="考试记录不存在")
|
||||
|
||||
if exam.status != "started":
|
||||
raise BusinessException(
|
||||
error_code=ErrorCode.VALIDATION_ERROR, message="考试已结束或已提交"
|
||||
)
|
||||
|
||||
# 检查考试是否超时
|
||||
if datetime.now() > exam.start_time + timedelta(
|
||||
minutes=exam.duration_minutes
|
||||
):
|
||||
exam.status = "timeout"
|
||||
await db.commit()
|
||||
raise BusinessException(
|
||||
error_code=ErrorCode.VALIDATION_ERROR, message="考试已超时"
|
||||
)
|
||||
|
||||
# 处理答案
|
||||
answers_dict = {ans["question_id"]: ans["answer"] for ans in answers}
|
||||
total_score = 0.0
|
||||
correct_count = 0
|
||||
|
||||
# 批量获取题目
|
||||
question_ids = [int(ans["question_id"]) for ans in answers]
|
||||
stmt = select(Question).where(Question.id.in_(question_ids))
|
||||
result = await db.execute(stmt)
|
||||
questions_map = {str(q.id): q for q in result.scalars().all()}
|
||||
|
||||
# 创建答题结果记录
|
||||
for question_data in exam.questions["questions"]:
|
||||
question_id = question_data["id"]
|
||||
question = questions_map.get(question_id)
|
||||
|
||||
if not question:
|
||||
continue
|
||||
|
||||
user_answer = answers_dict.get(question_id, "")
|
||||
is_correct = user_answer == question.correct_answer
|
||||
|
||||
if is_correct:
|
||||
total_score += question.score
|
||||
correct_count += 1
|
||||
|
||||
# 创建答题结果记录
|
||||
exam_result = ExamResult(
|
||||
exam_id=exam_id,
|
||||
question_id=int(question_id),
|
||||
user_answer=user_answer,
|
||||
is_correct=is_correct,
|
||||
score=question.score if is_correct else 0.0,
|
||||
)
|
||||
db.add(exam_result)
|
||||
|
||||
# 更新题目使用统计
|
||||
question.usage_count += 1
|
||||
if is_correct:
|
||||
question.correct_count += 1
|
||||
|
||||
# 更新考试记录
|
||||
exam.end_time = datetime.now()
|
||||
exam.score = total_score
|
||||
exam.is_passed = total_score >= exam.pass_score
|
||||
exam.status = "submitted"
|
||||
exam.answers = {"answers": answers}
|
||||
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"exam_id": exam_id,
|
||||
"total_score": total_score,
|
||||
"pass_score": exam.pass_score,
|
||||
"is_passed": exam.is_passed,
|
||||
"correct_count": correct_count,
|
||||
"total_count": exam.question_count,
|
||||
"accuracy": correct_count / exam.question_count
|
||||
if exam.question_count > 0
|
||||
else 0,
|
||||
}
|
||||
|
||||
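# A standalone mirror of the grading loop in submit_exam: an exact-match answer earns
# the question's score, and accuracy is correct / total. Sample data is hypothetical.
def grade(questions, answers):
    """questions: {id: {"correct_answer": str, "score": float}}, answers: {id: str}."""
    score = 0.0
    correct = 0
    for qid, q in questions.items():
        if answers.get(qid, "") == q["correct_answer"]:
            score += q["score"]
            correct += 1
    return {"score": score, "correct": correct, "accuracy": correct / len(questions) if questions else 0}

result = grade({"1": {"correct_answer": "A", "score": 5.0}, "2": {"correct_answer": "B", "score": 5.0}}, {"1": "A"})
assert result == {"score": 5.0, "correct": 1, "accuracy": 0.5}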
@staticmethod
|
||||
async def get_exam_detail(
|
||||
db: AsyncSession, user_id: int, exam_id: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取考试详情
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
exam_id: 考试ID
|
||||
|
||||
Returns:
|
||||
Dict: 考试详情
|
||||
"""
|
||||
# 获取考试记录
|
||||
stmt = select(Exam).where(and_(Exam.id == exam_id, Exam.user_id == user_id))
|
||||
result = await db.execute(stmt)
|
||||
exam = result.scalar_one_or_none()
|
||||
|
||||
if not exam:
|
||||
raise BusinessException(error_code=ErrorCode.NOT_FOUND, message="考试记录不存在")
|
||||
|
||||
# 构建返回数据
|
||||
exam_data = {
|
||||
"id": exam.id,
|
||||
"course_id": exam.course_id,
|
||||
"exam_name": exam.exam_name,
|
||||
"question_count": exam.question_count,
|
||||
"total_score": exam.total_score,
|
||||
"pass_score": exam.pass_score,
|
||||
"start_time": exam.start_time.isoformat() if exam.start_time else None,
|
||||
"end_time": exam.end_time.isoformat() if exam.end_time else None,
|
||||
"duration_minutes": exam.duration_minutes,
|
||||
"status": exam.status,
|
||||
"score": exam.score,
|
||||
"is_passed": exam.is_passed,
|
||||
"questions": exam.questions,
|
||||
}
|
||||
|
||||
# 如果考试已提交,获取答题详情
|
||||
if exam.status == "submitted" and exam.answers:
|
||||
stmt = select(ExamResult).where(ExamResult.exam_id == exam_id)
|
||||
result = await db.execute(stmt)
|
||||
results = result.scalars().all()
|
||||
|
||||
results_data = []
|
||||
for r in results:
|
||||
results_data.append(
|
||||
{
|
||||
"question_id": r.question_id,
|
||||
"user_answer": r.user_answer,
|
||||
"is_correct": r.is_correct,
|
||||
"score": r.score,
|
||||
}
|
||||
)
|
||||
|
||||
exam_data["results"] = results_data
|
||||
exam_data["answers"] = exam.answers
|
||||
|
||||
return exam_data
|
||||
|
||||
@staticmethod
|
||||
async def get_exam_records(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
page: int = 1,
|
||||
size: int = 10,
|
||||
course_id: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取考试记录列表(包含统计数据)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
page: 页码
|
||||
size: 每页数量
|
||||
course_id: 课程ID(可选)
|
||||
|
||||
Returns:
|
||||
Dict: 考试记录列表(包含统计信息)
|
||||
"""
|
||||
# 构建查询条件
|
||||
conditions = [Exam.user_id == user_id]
|
||||
if course_id:
|
||||
conditions.append(Exam.course_id == course_id)
|
||||
|
||||
# 查询总数
|
||||
count_stmt = select(func.count(Exam.id)).where(and_(*conditions))
|
||||
total = await db.scalar(count_stmt)
|
||||
|
||||
# 查询考试数据(JOIN courses表获取课程名称)
|
||||
offset = (page - 1) * size
|
||||
stmt = (
|
||||
select(Exam, Course.name.label("course_name"))
|
||||
.join(Course, Exam.course_id == Course.id)
|
||||
.where(and_(*conditions))
|
||||
.order_by(Exam.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(size)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
rows = result.all()
|
||||
|
||||
# 构建返回数据
|
||||
items = []
|
||||
for exam, course_name in rows:
|
||||
# 1. 计算用时
|
||||
duration_seconds = None
|
||||
if exam.start_time and exam.end_time:
|
||||
duration_seconds = int((exam.end_time - exam.start_time).total_seconds())
|
||||
|
||||
# 2. 统计错题数
|
||||
mistakes_stmt = select(func.count(ExamMistake.id)).where(
|
||||
ExamMistake.exam_id == exam.id
|
||||
)
|
||||
wrong_count = await db.scalar(mistakes_stmt) or 0
|
||||
|
||||
# 3. 计算正确数和正确率
|
||||
correct_count = exam.question_count - wrong_count if exam.question_count else 0
|
||||
accuracy = None
|
||||
if exam.question_count and exam.question_count > 0:
|
||||
accuracy = round((correct_count / exam.question_count) * 100, 1)
|
||||
|
||||
# 4. 分题型统计
|
||||
question_type_stats = []
|
||||
if exam.questions:
|
||||
try:
|
||||
# Parse the questions JSON (stored as {"questions": [...]}) and tally per-type counts and scores
raw_questions = json.loads(exam.questions) if isinstance(exam.questions, str) else exam.questions
questions_data = raw_questions.get("questions", []) if isinstance(raw_questions, dict) else raw_questions
type_totals = {}
type_scores = {}  # total score per question type

for q in questions_data:
q_type = q.get("type", "unknown")
q_score = q.get("score", 0)
type_totals[q_type] = type_totals.get(q_type, 0) + 1
type_scores[q_type] = type_scores.get(q_type, 0) + q_score
|
||||
# 查询错题按题型统计
|
||||
mistakes_by_type_stmt = (
|
||||
select(ExamMistake.question_type, func.count(ExamMistake.id))
|
||||
.where(ExamMistake.exam_id == exam.id)
|
||||
.group_by(ExamMistake.question_type)
|
||||
)
|
||||
mistakes_by_type_result = await db.execute(mistakes_by_type_stmt)
|
||||
mistakes_by_type = dict(mistakes_by_type_result.all())
|
||||
|
||||
# 题型名称映射
|
||||
type_name_map = {
|
||||
"single": "单选题",
|
||||
"multiple": "多选题",
|
||||
"judge": "判断题",
|
||||
"blank": "填空题",
|
||||
"essay": "问答题"
|
||||
}
|
||||
|
||||
# 组装分题型统计
|
||||
for q_type, total in type_totals.items():
|
||||
wrong = mistakes_by_type.get(q_type, 0)
|
||||
correct = total - wrong
|
||||
type_accuracy = round((correct / total) * 100, 1) if total > 0 else 0
|
||||
|
||||
question_type_stats.append({
|
||||
"type": type_name_map.get(q_type, q_type),
|
||||
"type_code": q_type,
|
||||
"total": total,
|
||||
"correct": correct,
|
||||
"wrong": wrong,
|
||||
"accuracy": type_accuracy,
|
||||
"total_score": type_scores.get(q_type, 0)
|
||||
})
|
||||
except (json.JSONDecodeError, TypeError, KeyError) as e:
|
||||
# 如果JSON解析失败,返回空统计
|
||||
question_type_stats = []
|
||||
|
||||
items.append(
|
||||
{
|
||||
"id": exam.id,
|
||||
"course_id": exam.course_id,
|
||||
"course_name": course_name,
|
||||
"exam_name": exam.exam_name,
|
||||
"question_count": exam.question_count,
|
||||
"total_score": exam.total_score,
|
||||
"score": exam.score,
|
||||
"is_passed": exam.is_passed,
|
||||
"status": exam.status,
|
||||
"start_time": exam.start_time.isoformat() if exam.start_time else None,
|
||||
"end_time": exam.end_time.isoformat() if exam.end_time else None,
|
||||
"created_at": exam.created_at.isoformat(),
|
||||
# 统计字段
|
||||
"accuracy": accuracy,
|
||||
"correct_count": correct_count,
|
||||
"wrong_count": wrong_count,
|
||||
"duration_seconds": duration_seconds,
|
||||
"question_type_stats": question_type_stats,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"size": size,
|
||||
"pages": (total + size - 1) // size,
|
||||
}
|
||||
|
||||
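# A standalone sketch of the per-question-type rollup built above, starting from the
# parsed question list and a {type: wrong_count} mapping; sample values are hypothetical.
def type_stats(questions, wrong_by_type):
    totals, scores = {}, {}
    for q in questions:
        t = q.get("type", "unknown")
        totals[t] = totals.get(t, 0) + 1
        scores[t] = scores.get(t, 0) + q.get("score", 0)
    return [
        {
            "type_code": t,
            "total": n,
            "correct": n - wrong_by_type.get(t, 0),
            "wrong": wrong_by_type.get(t, 0),
            "accuracy": round((n - wrong_by_type.get(t, 0)) / n * 100, 1),
            "total_score": scores[t],
        }
        for t, n in totals.items()
    ]

assert type_stats([{"type": "single", "score": 5}] * 4, {"single": 1})[0]["accuracy"] == 75.0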
@staticmethod
|
||||
async def get_exam_statistics(
|
||||
db: AsyncSession, user_id: int, course_id: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取考试统计信息
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
course_id: 课程ID(可选)
|
||||
|
||||
Returns:
|
||||
Dict: 统计信息
|
||||
"""
|
||||
# 构建查询条件
|
||||
conditions = [Exam.user_id == user_id, Exam.status == "submitted"]
|
||||
if course_id:
|
||||
conditions.append(Exam.course_id == course_id)
|
||||
|
||||
# 查询统计数据
|
||||
stmt = select(
|
||||
func.count(Exam.id).label("total_exams"),
|
||||
func.count(func.nullif(Exam.is_passed, False)).label("passed_exams"),
|
||||
func.avg(Exam.score).label("avg_score"),
|
||||
func.max(Exam.score).label("max_score"),
|
||||
func.min(Exam.score).label("min_score"),
|
||||
).where(and_(*conditions))
|
||||
|
||||
result = await db.execute(stmt)
|
||||
stats = result.one()
|
||||
|
||||
return {
|
||||
"total_exams": stats.total_exams or 0,
|
||||
"passed_exams": stats.passed_exams or 0,
|
||||
"pass_rate": (stats.passed_exams / stats.total_exams * 100)
|
||||
if stats.total_exams > 0
|
||||
else 0,
|
||||
"avg_score": float(stats.avg_score or 0),
|
||||
"max_score": float(stats.max_score or 0),
|
||||
"min_score": float(stats.min_score or 0),
|
||||
}
|
||||
0
backend/app/services/external/__init__.py
vendored
Normal file
330
backend/app/services/notification_service.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
站内消息通知服务
|
||||
提供通知的CRUD操作和业务逻辑
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
from sqlalchemy import select, and_, desc, func, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.logger import get_logger
|
||||
from app.models.notification import Notification
|
||||
from app.models.user import User
|
||||
from app.schemas.notification import (
|
||||
NotificationCreate,
|
||||
NotificationBatchCreate,
|
||||
NotificationResponse,
|
||||
NotificationType,
|
||||
)
|
||||
from app.services.base_service import BaseService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class NotificationService(BaseService[Notification]):
|
||||
"""
|
||||
站内消息通知服务
|
||||
|
||||
提供通知的创建、查询、标记已读等功能
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(Notification)
|
||||
|
||||
async def create_notification(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
notification_in: NotificationCreate
|
||||
) -> Notification:
|
||||
"""
|
||||
创建单个通知
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
notification_in: 通知创建数据
|
||||
|
||||
Returns:
|
||||
创建的通知对象
|
||||
"""
|
||||
notification = Notification(
|
||||
user_id=notification_in.user_id,
|
||||
title=notification_in.title,
|
||||
content=notification_in.content,
|
||||
type=notification_in.type.value if isinstance(notification_in.type, NotificationType) else notification_in.type,
|
||||
related_id=notification_in.related_id,
|
||||
related_type=notification_in.related_type,
|
||||
sender_id=notification_in.sender_id,
|
||||
is_read=False
|
||||
)
|
||||
|
||||
db.add(notification)
|
||||
await db.commit()
|
||||
await db.refresh(notification)
|
||||
|
||||
logger.info(
|
||||
"创建通知成功",
|
||||
notification_id=notification.id,
|
||||
user_id=notification_in.user_id,
|
||||
type=notification_in.type
|
||||
)
|
||||
|
||||
return notification
|
||||
|
||||
async def batch_create_notifications(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
batch_in: NotificationBatchCreate
|
||||
) -> List[Notification]:
|
||||
"""
|
||||
批量创建通知(发送给多个用户)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
batch_in: 批量通知创建数据
|
||||
|
||||
Returns:
|
||||
创建的通知列表
|
||||
"""
|
||||
notifications = []
|
||||
notification_type = batch_in.type.value if isinstance(batch_in.type, NotificationType) else batch_in.type
|
||||
|
||||
for user_id in batch_in.user_ids:
|
||||
notification = Notification(
|
||||
user_id=user_id,
|
||||
title=batch_in.title,
|
||||
content=batch_in.content,
|
||||
type=notification_type,
|
||||
related_id=batch_in.related_id,
|
||||
related_type=batch_in.related_type,
|
||||
sender_id=batch_in.sender_id,
|
||||
is_read=False
|
||||
)
|
||||
notifications.append(notification)
|
||||
db.add(notification)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# 刷新所有对象
|
||||
for notification in notifications:
|
||||
await db.refresh(notification)
|
||||
|
||||
logger.info(
|
||||
"批量创建通知成功",
|
||||
count=len(notifications),
|
||||
user_ids=batch_in.user_ids,
|
||||
type=batch_in.type
|
||||
)
|
||||
|
||||
return notifications
|
||||
|
||||
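# A hedged usage sketch for batch_create_notifications. NotificationBatchCreate and
# notification_service come from this module and its schemas; the "exam" type value,
# the message text, and the calling context (db session, ids) are assumptions made
# purely for illustration.
async def notify_exam_published(db, user_ids, course_name, sender_id):
    batch = NotificationBatchCreate(
        user_ids=user_ids,
        title="New exam available",
        content=f"A new exam for {course_name} is ready, please take it soon.",
        type="exam",  # assumed type value; use the project's NotificationType member in practice
        sender_id=sender_id,
    )
    return await notification_service.batch_create_notifications(db, batch)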
async def get_user_notifications(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 20,
|
||||
is_read: Optional[bool] = None,
|
||||
notification_type: Optional[str] = None
|
||||
) -> Tuple[List[NotificationResponse], int, int]:
|
||||
"""
|
||||
获取用户的通知列表
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
skip: 跳过数量
|
||||
limit: 返回数量
|
||||
is_read: 是否已读筛选
|
||||
notification_type: 通知类型筛选
|
||||
|
||||
Returns:
|
||||
(通知列表, 总数, 未读数)
|
||||
"""
|
||||
# 构建基础查询条件
|
||||
conditions = [Notification.user_id == user_id]
|
||||
|
||||
if is_read is not None:
|
||||
conditions.append(Notification.is_read == is_read)
|
||||
|
||||
if notification_type:
|
||||
conditions.append(Notification.type == notification_type)
|
||||
|
||||
# 查询通知列表(带发送者信息)
|
||||
stmt = (
|
||||
select(Notification)
|
||||
.where(and_(*conditions))
|
||||
.order_by(desc(Notification.created_at))
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
notifications = result.scalars().all()
|
||||
|
||||
# 统计总数
|
||||
count_stmt = select(func.count()).select_from(Notification).where(and_(*conditions))
|
||||
total_result = await db.execute(count_stmt)
|
||||
total = total_result.scalar_one()
|
||||
|
||||
# 统计未读数
|
||||
unread_stmt = (
|
||||
select(func.count())
|
||||
.select_from(Notification)
|
||||
.where(and_(Notification.user_id == user_id, Notification.is_read == False))
|
||||
)
|
||||
unread_result = await db.execute(unread_stmt)
|
||||
unread_count = unread_result.scalar_one()
|
||||
|
||||
# 获取发送者信息
|
||||
sender_ids = [n.sender_id for n in notifications if n.sender_id]
|
||||
sender_names = {}
|
||||
if sender_ids:
|
||||
sender_stmt = select(User.id, User.full_name).where(User.id.in_(sender_ids))
|
||||
sender_result = await db.execute(sender_stmt)
|
||||
sender_names = {row[0]: row[1] for row in sender_result.fetchall()}
|
||||
|
||||
# 构建响应
|
||||
responses = []
|
||||
for notification in notifications:
|
||||
response = NotificationResponse(
|
||||
id=notification.id,
|
||||
user_id=notification.user_id,
|
||||
title=notification.title,
|
||||
content=notification.content,
|
||||
type=notification.type,
|
||||
is_read=notification.is_read,
|
||||
related_id=notification.related_id,
|
||||
related_type=notification.related_type,
|
||||
sender_id=notification.sender_id,
|
||||
sender_name=sender_names.get(notification.sender_id) if notification.sender_id else None,
|
||||
created_at=notification.created_at,
|
||||
updated_at=notification.updated_at
|
||||
)
|
||||
responses.append(response)
|
||||
|
||||
return responses, total, unread_count
|
||||
|
||||
async def get_unread_count(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: int
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
获取用户未读通知数量
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
(未读数, 总数)
|
||||
"""
|
||||
# 统计未读数
|
||||
unread_stmt = (
|
||||
select(func.count())
|
||||
.select_from(Notification)
|
||||
.where(and_(Notification.user_id == user_id, Notification.is_read == False))
|
||||
)
|
||||
unread_result = await db.execute(unread_stmt)
|
||||
unread_count = unread_result.scalar_one()
|
||||
|
||||
# 统计总数
|
||||
total_stmt = (
|
||||
select(func.count())
|
||||
.select_from(Notification)
|
||||
.where(Notification.user_id == user_id)
|
||||
)
|
||||
total_result = await db.execute(total_stmt)
|
||||
total = total_result.scalar_one()
|
||||
|
||||
return unread_count, total
|
||||
|
||||
async def mark_as_read(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
notification_ids: Optional[List[int]] = None
|
||||
) -> int:
|
||||
"""
|
||||
标记通知为已读
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
notification_ids: 通知ID列表,为空则标记全部
|
||||
|
||||
Returns:
|
||||
更新的数量
|
||||
"""
|
||||
conditions = [
|
||||
Notification.user_id == user_id,
|
||||
Notification.is_read == False
|
||||
]
|
||||
|
||||
if notification_ids:
|
||||
conditions.append(Notification.id.in_(notification_ids))
|
||||
|
||||
stmt = (
|
||||
update(Notification)
|
||||
.where(and_(*conditions))
|
||||
.values(is_read=True)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
|
||||
logger.info(
|
||||
"标记通知已读",
|
||||
user_id=user_id,
|
||||
notification_ids=notification_ids,
|
||||
updated_count=updated_count
|
||||
)
|
||||
|
||||
return updated_count
|
||||
|
||||
async def delete_notification(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
notification_id: int
|
||||
) -> bool:
|
||||
"""
|
||||
删除通知
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
notification_id: 通知ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
stmt = select(Notification).where(
|
||||
and_(
|
||||
Notification.id == notification_id,
|
||||
Notification.user_id == user_id
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
notification = result.scalar_one_or_none()
|
||||
|
||||
if notification:
|
||||
await db.delete(notification)
|
||||
await db.commit()
|
||||
|
||||
logger.info(
|
||||
"删除通知成功",
|
||||
notification_id=notification_id,
|
||||
user_id=user_id
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# 创建服务实例
|
||||
notification_service = NotificationService()
|
||||
|
||||
356
backend/app/services/scrm_service.py
Normal file
@@ -0,0 +1,356 @@
|
||||
"""
|
||||
SCRM 系统对接服务
|
||||
|
||||
提供给 SCRM 系统调用的数据查询服务
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy import select, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.position import Position
|
||||
from app.models.position_member import PositionMember
|
||||
from app.models.position_course import PositionCourse
|
||||
from app.models.course import Course, KnowledgePoint, CourseMaterial
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SCRMService:
|
||||
"""SCRM 系统数据查询服务"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_employee_position(
|
||||
self,
|
||||
userid: Optional[str] = None,
|
||||
name: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据企微 userid 或员工姓名获取员工岗位信息
|
||||
|
||||
Args:
|
||||
userid: 企微员工 userid(可选)
|
||||
name: 员工姓名(可选,支持模糊匹配)
|
||||
|
||||
Returns:
|
||||
员工岗位信息字典,包含 employee_id, userid, name, positions
|
||||
如果员工不存在返回 None
|
||||
如果按姓名搜索有多个结果,返回列表
|
||||
"""
|
||||
query = (
|
||||
select(User)
|
||||
.options(selectinload(User.position_memberships).selectinload(PositionMember.position))
|
||||
.where(User.is_deleted.is_(False))
|
||||
)
|
||||
|
||||
# 优先按 wework_userid 精确匹配
|
||||
if userid:
|
||||
query = query.where(User.wework_userid == userid)
|
||||
result = await self.db.execute(query)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user:
|
||||
return self._build_employee_position_data(user)
|
||||
|
||||
# 其次按姓名匹配(支持精确匹配和模糊匹配)
|
||||
if name:
|
||||
# 先尝试精确匹配
|
||||
exact_query = query.where(User.full_name == name)
|
||||
result = await self.db.execute(exact_query)
|
||||
users = result.scalars().all()
|
||||
|
||||
# 如果精确匹配没有结果,尝试模糊匹配
|
||||
if not users:
|
||||
fuzzy_query = query.where(User.full_name.ilike(f"%{name}%"))
|
||||
result = await self.db.execute(fuzzy_query)
|
||||
users = result.scalars().all()
|
||||
|
||||
if len(users) == 1:
|
||||
return self._build_employee_position_data(users[0])
|
||||
elif len(users) > 1:
|
||||
# 多个匹配结果,返回列表供选择
|
||||
return {
|
||||
"multiple_matches": True,
|
||||
"count": len(users),
|
||||
"employees": [
|
||||
{
|
||||
"employee_id": u.id,
|
||||
"userid": u.wework_userid,
|
||||
"name": u.full_name or u.username,
|
||||
"phone": u.phone[-4:] if u.phone else None # 只显示手机号后4位
|
||||
}
|
||||
for u in users
|
||||
]
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def _build_employee_position_data(self, user: User) -> Dict[str, Any]:
|
||||
"""构建员工岗位数据"""
|
||||
positions = []
|
||||
for i, pm in enumerate(user.position_memberships):
|
||||
if pm.is_deleted or pm.position.is_deleted:
|
||||
continue
|
||||
positions.append({
|
||||
"position_id": pm.position.id,
|
||||
"position_name": pm.position.name,
|
||||
"is_primary": i == 0, # 第一个为主岗位
|
||||
"joined_at": pm.joined_at.strftime("%Y-%m-%d") if pm.joined_at else None
|
||||
})
|
||||
|
||||
return {
|
||||
"employee_id": user.id,
|
||||
"userid": user.wework_userid,
|
||||
"name": user.full_name or user.username,
|
||||
"positions": positions
|
||||
}
|
||||
|
||||
async def get_employee_position_by_id(self, employee_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
根据员工ID获取岗位信息
|
||||
|
||||
Args:
|
||||
employee_id: 员工ID(users表主键)
|
||||
|
||||
Returns:
|
||||
员工岗位信息字典
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(User)
|
||||
.options(selectinload(User.position_memberships).selectinload(PositionMember.position))
|
||||
.where(User.id == employee_id, User.is_deleted == False)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
return self._build_employee_position_data(user)
|
||||
|
||||
async def get_position_courses(
|
||||
self,
|
||||
position_id: int,
|
||||
course_type: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定岗位的课程列表
|
||||
|
||||
Args:
|
||||
position_id: 岗位ID
|
||||
course_type: 课程类型筛选(required/optional/all)
|
||||
|
||||
Returns:
|
||||
岗位课程信息字典,包含 position_id, position_name, courses
|
||||
如果岗位不存在返回 None
|
||||
"""
|
||||
# 查询岗位
|
||||
position_result = await self.db.execute(
|
||||
select(Position).where(Position.id == position_id, Position.is_deleted.is_(False))
|
||||
)
|
||||
position = position_result.scalar_one_or_none()
|
||||
|
||||
if not position:
|
||||
return None
|
||||
|
||||
# 查询岗位课程关联
|
||||
query = (
|
||||
select(PositionCourse, Course)
|
||||
.join(Course, PositionCourse.course_id == Course.id)
|
||||
.where(
|
||||
PositionCourse.position_id == position_id,
|
||||
PositionCourse.is_deleted.is_(False),
|
||||
Course.is_deleted.is_(False),
|
||||
Course.status == "published" # 只返回已发布的课程
|
||||
)
|
||||
.order_by(PositionCourse.priority.desc())
|
||||
)
|
||||
|
||||
# 课程类型筛选
|
||||
if course_type and course_type != "all":
|
||||
query = query.where(PositionCourse.course_type == course_type)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
pc_courses = result.all()
|
||||
|
||||
# 构建课程列表,并统计知识点数量
|
||||
courses = []
|
||||
for pc, course in pc_courses:
|
||||
# 统计该课程的知识点数量
|
||||
kp_count_result = await self.db.execute(
|
||||
select(func.count(KnowledgePoint.id))
|
||||
.where(
|
||||
KnowledgePoint.course_id == course.id,
|
||||
KnowledgePoint.is_deleted.is_(False)
|
||||
)
|
||||
)
|
||||
kp_count = kp_count_result.scalar() or 0
|
||||
|
||||
courses.append({
|
||||
"course_id": course.id,
|
||||
"course_name": course.name,
|
||||
"course_type": pc.course_type,
|
||||
"priority": pc.priority,
|
||||
"knowledge_point_count": kp_count
|
||||
})
|
||||
|
||||
return {
|
||||
"position_id": position.id,
|
||||
"position_name": position.name,
|
||||
"courses": courses
|
||||
}
|
||||
|
||||
async def search_knowledge_points(
|
||||
self,
|
||||
keywords: List[str],
|
||||
position_id: Optional[int] = None,
|
||||
course_ids: Optional[List[int]] = None,
|
||||
knowledge_type: Optional[str] = None,
|
||||
limit: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
搜索知识点
|
||||
|
||||
Args:
|
||||
keywords: 搜索关键词列表
|
||||
position_id: 岗位ID(用于优先排序)
|
||||
course_ids: 限定课程范围
|
||||
knowledge_type: 知识点类型筛选
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
搜索结果字典,包含 total 和 items
|
||||
"""
|
||||
# 基础查询
|
||||
query = (
|
||||
select(KnowledgePoint, Course)
|
||||
.join(Course, KnowledgePoint.course_id == Course.id)
|
||||
.where(
|
||||
KnowledgePoint.is_deleted.is_(False),
|
||||
Course.is_deleted.is_(False),
|
||||
Course.status == "published"
|
||||
)
|
||||
)
|
||||
|
||||
# 关键词搜索条件(在名称和描述中搜索)
|
||||
keyword_conditions = []
|
||||
for keyword in keywords:
|
||||
keyword_conditions.append(
|
||||
or_(
|
||||
KnowledgePoint.name.ilike(f"%{keyword}%"),
|
||||
KnowledgePoint.description.ilike(f"%{keyword}%")
|
||||
)
|
||||
)
|
||||
if keyword_conditions:
|
||||
query = query.where(or_(*keyword_conditions))
|
||||
|
||||
# 课程范围筛选
|
||||
if course_ids:
|
||||
query = query.where(KnowledgePoint.course_id.in_(course_ids))
|
||||
|
||||
# 知识点类型筛选
|
||||
if knowledge_type:
|
||||
query = query.where(KnowledgePoint.type == knowledge_type)
|
||||
|
||||
# 如果指定了岗位,优先返回该岗位相关课程的知识点
|
||||
if position_id:
|
||||
# 获取该岗位的课程ID列表
|
||||
pos_course_result = await self.db.execute(
|
||||
select(PositionCourse.course_id)
|
||||
.where(
|
||||
PositionCourse.position_id == position_id,
|
||||
PositionCourse.is_deleted.is_(False)
|
||||
)
|
||||
)
|
||||
pos_course_ids = [row[0] for row in pos_course_result.all()]
|
||||
|
||||
if pos_course_ids:
|
||||
# 使用 CASE WHEN 进行排序:岗位相关课程优先
|
||||
from sqlalchemy import case
|
||||
priority_order = case(
|
||||
(KnowledgePoint.course_id.in_(pos_course_ids), 0),
|
||||
else_=1
|
||||
)
|
||||
query = query.order_by(priority_order, KnowledgePoint.id.desc())
|
||||
else:
|
||||
query = query.order_by(KnowledgePoint.id.desc())
|
||||
else:
|
||||
query = query.order_by(KnowledgePoint.id.desc())
|
||||
|
||||
# 执行查询
|
||||
result = await self.db.execute(query.limit(limit))
|
||||
kp_courses = result.all()
|
||||
|
||||
# 计算相关度分数(简单实现:匹配的关键词越多分数越高)
|
||||
def calc_relevance(kp: KnowledgePoint) -> float:
|
||||
text = f"{kp.name} {kp.description or ''}"
|
||||
matched = sum(1 for kw in keywords if kw.lower() in text.lower())
|
||||
return round(matched / len(keywords), 2) if keywords else 1.0
|
||||
|
||||
# 构建结果
|
||||
items = []
|
||||
for kp, course in kp_courses:
|
||||
items.append({
|
||||
"knowledge_point_id": kp.id,
|
||||
"name": kp.name,
|
||||
"course_id": course.id,
|
||||
"course_name": course.name,
|
||||
"type": kp.type,
|
||||
"relevance_score": calc_relevance(kp)
|
||||
})
|
||||
|
||||
# 按相关度分数排序
|
||||
items.sort(key=lambda x: x["relevance_score"], reverse=True)
|
||||
|
||||
return {
|
||||
"total": len(items),
|
||||
"items": items
|
||||
}
|
||||
|
||||
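# A standalone mirror of the calc_relevance scoring above: the fraction of keywords that
# appear (case-insensitively) in the knowledge point's name/description text.
def relevance(keywords, text):
    if not keywords:
        return 1.0
    matched = sum(1 for kw in keywords if kw.lower() in text.lower())
    return round(matched / len(keywords), 2)

assert relevance(["price", "quote"], "Price quote workflow") == 1.0
assert relevance(["price", "contract"], "Price quote workflow") == 0.5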
async def get_knowledge_point_detail(self, kp_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取知识点详情
|
||||
|
||||
Args:
|
||||
kp_id: 知识点ID
|
||||
|
||||
Returns:
|
||||
知识点详情字典
|
||||
如果知识点不存在返回 None
|
||||
"""
|
||||
# 查询知识点及关联的课程和资料
|
||||
result = await self.db.execute(
|
||||
select(KnowledgePoint, Course, CourseMaterial)
|
||||
.join(Course, KnowledgePoint.course_id == Course.id)
|
||||
.outerjoin(CourseMaterial, KnowledgePoint.material_id == CourseMaterial.id)
|
||||
.where(
|
||||
KnowledgePoint.id == kp_id,
|
||||
KnowledgePoint.is_deleted.is_(False)
|
||||
)
|
||||
)
|
||||
row = result.one_or_none()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
kp, course, material = row
|
||||
|
||||
return {
|
||||
"knowledge_point_id": kp.id,
|
||||
"name": kp.name,
|
||||
"course_id": course.id,
|
||||
"course_name": course.name,
|
||||
"type": kp.type,
|
||||
"content": kp.description or "", # description 作为知识点内容
|
||||
"material_id": material.id if material else None,
|
||||
"material_type": material.file_type if material else None,
|
||||
"material_url": material.file_url if material else None,
|
||||
"topic_relation": kp.topic_relation,
|
||||
"source": kp.source,
|
||||
"created_at": kp.created_at.strftime("%Y-%m-%d %H:%M:%S") if kp.created_at else None
|
||||
}
|
||||
|
||||
708
backend/app/services/statistics_service.py
Normal file
@@ -0,0 +1,708 @@
|
||||
"""
|
||||
统计分析服务
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_, or_, case, desc, distinct
|
||||
from app.models.exam import Exam, Question
|
||||
from app.models.exam_mistake import ExamMistake
|
||||
from app.models.course import Course, KnowledgePoint
|
||||
from app.models.practice import PracticeSession
|
||||
from app.models.training import TrainingSession
|
||||
from app.core.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StatisticsService:
|
||||
"""统计分析服务类"""
|
||||
|
||||
@staticmethod
|
||||
def _get_date_range(period: str) -> Tuple[datetime, datetime]:
|
||||
"""
|
||||
根据period返回开始和结束日期
|
||||
|
||||
Args:
|
||||
period: 时间范围 (week/month/quarter/halfYear/year)
|
||||
|
||||
Returns:
|
||||
Tuple[datetime, datetime]: (开始日期, 结束日期)
|
||||
"""
|
||||
end_date = datetime.now()
|
||||
|
||||
if period == "week":
|
||||
start_date = end_date - timedelta(days=7)
|
||||
elif period == "month":
|
||||
start_date = end_date - timedelta(days=30)
|
||||
elif period == "quarter":
|
||||
start_date = end_date - timedelta(days=90)
|
||||
elif period == "halfYear":
|
||||
start_date = end_date - timedelta(days=180)
|
||||
elif period == "year":
|
||||
start_date = end_date - timedelta(days=365)
|
||||
else:
|
||||
# 默认一个月
|
||||
start_date = end_date - timedelta(days=30)
|
||||
|
||||
return start_date, end_date
|
||||
|
||||
@staticmethod
|
||||
def _calculate_change_rate(current: float, previous: float) -> float:
|
||||
"""
|
||||
计算环比变化率
|
||||
|
||||
Args:
|
||||
current: 当前值
|
||||
previous: 上期值
|
||||
|
||||
Returns:
|
||||
float: 变化率(百分比)
|
||||
"""
|
||||
if previous == 0:
|
||||
return 0 if current == 0 else 100
|
||||
return round(((current - previous) / previous) * 100, 1)
|
||||
|
||||
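# The change-rate formula above, as a standalone function with the zero-baseline
# convention spelled out; the sample numbers are arbitrary spot checks.
def change_rate(current, previous):
    if previous == 0:
        return 0 if current == 0 else 100
    return round((current - previous) / previous * 100, 1)

assert change_rate(75.0, 60.0) == 25.0
assert change_rate(10.0, 0) == 100
assert change_rate(0, 0) == 0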
@staticmethod
|
||||
async def get_key_metrics(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
course_id: Optional[int] = None,
|
||||
period: str = "month"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取关键指标
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
course_id: 课程ID(可选)
|
||||
period: 时间范围
|
||||
|
||||
Returns:
|
||||
Dict: 包含学习效率、知识覆盖率、平均用时、进步速度的指标数据
|
||||
"""
|
||||
logger.info(f"获取关键指标 - user_id: {user_id}, course_id: {course_id}, period: {period}")
|
||||
|
||||
start_date, end_date = StatisticsService._get_date_range(period)
|
||||
|
||||
# 构建基础查询条件
|
||||
exam_conditions = [
|
||||
Exam.user_id == user_id,
|
||||
Exam.start_time >= start_date,
|
||||
Exam.start_time <= end_date,
|
||||
Exam.round1_score.isnot(None)
|
||||
]
|
||||
if course_id:
|
||||
exam_conditions.append(Exam.course_id == course_id)
|
||||
|
||||
# 1. 学习效率 = (总题数 - 错题数) / 总题数
|
||||
# 获取总题数
|
||||
total_questions_stmt = select(
|
||||
func.coalesce(func.sum(Exam.question_count), 0)
|
||||
).where(and_(*exam_conditions))
|
||||
total_questions = await db.scalar(total_questions_stmt) or 0
|
||||
|
||||
# 获取错题数
|
||||
mistake_conditions = [ExamMistake.user_id == user_id]
|
||||
if course_id:
|
||||
mistake_conditions.append(
|
||||
ExamMistake.exam_id.in_(
|
||||
select(Exam.id).where(Exam.course_id == course_id)
|
||||
)
|
||||
)
|
||||
mistake_stmt = select(func.count(ExamMistake.id)).where(
|
||||
and_(*mistake_conditions)
|
||||
)
|
||||
mistake_count = await db.scalar(mistake_stmt) or 0
|
||||
|
||||
# 计算学习效率
|
||||
learning_efficiency = 0.0
|
||||
if total_questions > 0:
|
||||
correct_questions = total_questions - mistake_count
|
||||
learning_efficiency = round((correct_questions / total_questions) * 100, 1)
|
||||
|
||||
# 计算上期学习效率(用于环比)
|
||||
prev_start_date = start_date - (end_date - start_date)
|
||||
prev_exam_conditions = [
|
||||
Exam.user_id == user_id,
|
||||
Exam.start_time >= prev_start_date,
|
||||
Exam.start_time < start_date,
|
||||
Exam.round1_score.isnot(None)
|
||||
]
|
||||
if course_id:
|
||||
prev_exam_conditions.append(Exam.course_id == course_id)
|
||||
|
||||
prev_total_questions = await db.scalar(
|
||||
select(func.coalesce(func.sum(Exam.question_count), 0)).where(
|
||||
and_(*prev_exam_conditions)
|
||||
)
|
||||
) or 0
|
||||
|
||||
prev_mistake_count = await db.scalar(
|
||||
select(func.count(ExamMistake.id)).where(
|
||||
and_(
|
||||
ExamMistake.user_id == user_id,
|
||||
ExamMistake.exam_id.in_(
|
||||
select(Exam.id).where(and_(*prev_exam_conditions))
|
||||
)
|
||||
)
|
||||
)
|
||||
) or 0
|
||||
|
||||
prev_efficiency = 0.0
|
||||
if prev_total_questions > 0:
|
||||
prev_correct = prev_total_questions - prev_mistake_count
|
||||
prev_efficiency = (prev_correct / prev_total_questions) * 100
|
||||
|
||||
efficiency_change = StatisticsService._calculate_change_rate(
|
||||
learning_efficiency, prev_efficiency
|
||||
)
|
||||
|
||||
# 2. 知识覆盖率 = 已掌握知识点数 / 总知识点数
|
||||
# 获取总知识点数
|
||||
kp_conditions = []
|
||||
if course_id:
|
||||
kp_conditions.append(KnowledgePoint.course_id == course_id)
|
||||
|
||||
total_kp_stmt = select(func.count(KnowledgePoint.id)).where(
|
||||
and_(KnowledgePoint.is_deleted == False, *kp_conditions)
|
||||
)
|
||||
total_kp = await db.scalar(total_kp_stmt) or 0
|
||||
|
||||
# 获取错误的知识点数(至少错过一次的)
|
||||
mistake_kp_stmt = select(
|
||||
func.count(distinct(ExamMistake.knowledge_point_id))
|
||||
).where(
|
||||
and_(
|
||||
ExamMistake.user_id == user_id,
|
||||
ExamMistake.knowledge_point_id.isnot(None),
|
||||
*([ExamMistake.exam_id.in_(
|
||||
select(Exam.id).where(Exam.course_id == course_id)
|
||||
)] if course_id else [])
|
||||
)
|
||||
)
|
||||
mistake_kp = await db.scalar(mistake_kp_stmt) or 0
|
||||
|
||||
# 计算知识覆盖率(掌握的知识点 = 总知识点 - 错误知识点)
|
||||
knowledge_coverage = 0.0
|
||||
if total_kp > 0:
|
||||
mastered_kp = max(0, total_kp - mistake_kp)
|
||||
knowledge_coverage = round((mastered_kp / total_kp) * 100, 1)
|
||||
|
||||
# 上期知识覆盖率(简化:假设知识点总数不变)
|
||||
prev_mistake_kp = await db.scalar(
|
||||
select(func.count(distinct(ExamMistake.knowledge_point_id))).where(
|
||||
and_(
|
||||
ExamMistake.user_id == user_id,
|
||||
ExamMistake.knowledge_point_id.isnot(None),
|
||||
ExamMistake.exam_id.in_(
|
||||
select(Exam.id).where(and_(*prev_exam_conditions))
|
||||
)
|
||||
)
|
||||
)
|
||||
) or 0
|
||||
|
||||
prev_coverage = 0.0
|
||||
if total_kp > 0:
|
||||
prev_mastered = max(0, total_kp - prev_mistake_kp)
|
||||
prev_coverage = (prev_mastered / total_kp) * 100
|
||||
|
||||
coverage_change = StatisticsService._calculate_change_rate(
|
||||
knowledge_coverage, prev_coverage
|
||||
)
|
||||
|
||||
# 3. 平均用时 = 总考试时长 / 总题数
|
||||
total_duration_stmt = select(
|
||||
func.coalesce(func.sum(Exam.duration_minutes), 0)
|
||||
).where(and_(*exam_conditions))
|
||||
total_duration = await db.scalar(total_duration_stmt) or 0
|
||||
|
||||
avg_time_per_question = 0.0
|
||||
if total_questions > 0:
|
||||
avg_time_per_question = round((total_duration / total_questions), 1)
|
||||
|
||||
# 上期平均用时
|
||||
prev_total_duration = await db.scalar(
|
||||
select(func.coalesce(func.sum(Exam.duration_minutes), 0)).where(
|
||||
and_(*prev_exam_conditions)
|
||||
)
|
||||
) or 0
|
||||
|
||||
prev_avg_time = 0.0
|
||||
if prev_total_questions > 0:
|
||||
prev_avg_time = prev_total_duration / prev_total_questions
|
||||
|
||||
# 平均用时的环比是负增长表示好(时间减少)
|
||||
time_change = StatisticsService._calculate_change_rate(
|
||||
avg_time_per_question, prev_avg_time
|
||||
)
|
||||
|
||||
# 4. 进步速度 = (本期平均分 - 上期平均分) / 上期平均分
|
||||
avg_score_stmt = select(func.avg(Exam.round1_score)).where(
|
||||
and_(*exam_conditions)
|
||||
)
|
||||
avg_score = await db.scalar(avg_score_stmt) or 0
|
||||
|
||||
prev_avg_score_stmt = select(func.avg(Exam.round1_score)).where(
|
||||
and_(*prev_exam_conditions)
|
||||
)
|
||||
prev_avg_score = await db.scalar(prev_avg_score_stmt) or 0
|
||||
|
||||
progress_speed = StatisticsService._calculate_change_rate(
|
||||
float(avg_score), float(prev_avg_score)
|
||||
)
|
||||
|
||||
return {
|
||||
"learningEfficiency": {
|
||||
"value": learning_efficiency,
|
||||
"unit": "%",
|
||||
"change": efficiency_change,
|
||||
"description": "正确题数/总练习题数"
|
||||
},
|
||||
"knowledgeCoverage": {
|
||||
"value": knowledge_coverage,
|
||||
"unit": "%",
|
||||
"change": coverage_change,
|
||||
"description": "已掌握知识点/总知识点"
|
||||
},
|
||||
"avgTimePerQuestion": {
|
||||
"value": avg_time_per_question,
|
||||
"unit": "分/题",
|
||||
"change": time_change,
|
||||
"description": "平均每道题的答题时间"
|
||||
},
|
||||
"progressSpeed": {
|
||||
"value": abs(progress_speed),
|
||||
"unit": "%",
|
||||
"change": progress_speed,
|
||||
"description": "成绩提升速度"
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def get_score_distribution(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
course_id: Optional[int] = None,
|
||||
period: str = "month"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取成绩分布统计
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
course_id: 课程ID(可选)
|
||||
period: 时间范围
|
||||
|
||||
Returns:
|
||||
Dict: 成绩分布数据(优秀、良好、中等、及格、不及格)
|
||||
"""
|
||||
logger.info(f"获取成绩分布 - user_id: {user_id}, course_id: {course_id}, period: {period}")
|
||||
|
||||
start_date, end_date = StatisticsService._get_date_range(period)
|
||||
|
||||
# 构建查询条件
|
||||
conditions = [
|
||||
Exam.user_id == user_id,
|
||||
Exam.start_time >= start_date,
|
||||
Exam.start_time <= end_date,
|
||||
Exam.round1_score.isnot(None)
|
||||
]
|
||||
if course_id:
|
||||
conditions.append(Exam.course_id == course_id)
|
||||
|
||||
# 使用case when统计各分数段的数量
|
||||
stmt = select(
|
||||
func.count(case((Exam.round1_score >= 90, 1))).label("excellent"),
|
||||
func.count(case((and_(Exam.round1_score >= 80, Exam.round1_score < 90), 1))).label("good"),
|
||||
func.count(case((and_(Exam.round1_score >= 70, Exam.round1_score < 80), 1))).label("medium"),
|
||||
func.count(case((and_(Exam.round1_score >= 60, Exam.round1_score < 70), 1))).label("pass_count"),
|
||||
func.count(case((Exam.round1_score < 60, 1))).label("fail")
|
||||
).where(and_(*conditions))
|
||||
|
||||
result = await db.execute(stmt)
|
||||
row = result.one()
|
||||
|
||||
return {
|
||||
"excellent": row.excellent or 0, # 优秀(90-100)
|
||||
"good": row.good or 0, # 良好(80-89)
|
||||
"medium": row.medium or 0, # 中等(70-79)
|
||||
"pass": row.pass_count or 0, # 及格(60-69)
|
||||
"fail": row.fail or 0 # 不及格(<60)
|
||||
}
|
||||
|
||||
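# A standalone mirror of the CASE WHEN banding above: bucket round-1 scores into the
# five grade bands returned to the frontend. The sample scores are arbitrary.
def score_distribution(scores):
    dist = {"excellent": 0, "good": 0, "medium": 0, "pass": 0, "fail": 0}
    for s in scores:
        if s >= 90:
            dist["excellent"] += 1
        elif s >= 80:
            dist["good"] += 1
        elif s >= 70:
            dist["medium"] += 1
        elif s >= 60:
            dist["pass"] += 1
        else:
            dist["fail"] += 1
    return dist

assert score_distribution([95, 85, 72, 65, 40]) == {
    "excellent": 1, "good": 1, "medium": 1, "pass": 1, "fail": 1
}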
@staticmethod
|
||||
async def get_difficulty_analysis(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
course_id: Optional[int] = None,
|
||||
period: str = "month"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取题目难度分析
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
course_id: 课程ID(可选)
|
||||
period: 时间范围
|
||||
|
||||
Returns:
|
||||
Dict: 各难度题目的正确率统计
|
||||
"""
|
||||
logger.info(f"获取难度分析 - user_id: {user_id}, course_id: {course_id}, period: {period}")
|
||||
|
||||
start_date, end_date = StatisticsService._get_date_range(period)
|
||||
|
||||
# 获取用户在时间范围内的考试
|
||||
exam_conditions = [
|
||||
Exam.user_id == user_id,
|
||||
Exam.start_time >= start_date,
|
||||
Exam.start_time <= end_date
|
||||
]
|
||||
if course_id:
|
||||
exam_conditions.append(Exam.course_id == course_id)
|
||||
|
||||
exam_ids_stmt = select(Exam.id).where(and_(*exam_conditions))
|
||||
result = await db.execute(exam_ids_stmt)
|
||||
exam_ids = [row[0] for row in result.all()]
|
||||
|
||||
if not exam_ids:
|
||||
# 没有考试数据,返回默认值
|
||||
return {
"简单题": 100.0,
"中等题": 100.0,
"困难题": 100.0,
"综合题": 100.0,
"应用题": 100.0
}
|
||||
|
||||
# 统计各难度的总题数和错题数
|
||||
difficulty_stats = {}
|
||||
|
||||
for difficulty in ["easy", "medium", "hard"]:
|
||||
# 总题数:从exams的questions字段中统计(这里简化处理)
|
||||
# 由于questions字段是JSON,我们通过question_count估算
|
||||
# 实际应用中可以解析JSON或通过exam_results表统计
|
||||
|
||||
# 错题数:从exam_mistakes通过question_id关联查询
|
||||
mistake_stmt = select(func.count(ExamMistake.id)).select_from(
|
||||
ExamMistake
|
||||
).join(
|
||||
Question, ExamMistake.question_id == Question.id
|
||||
).where(
|
||||
and_(
|
||||
ExamMistake.user_id == user_id,
|
||||
ExamMistake.exam_id.in_(exam_ids),
|
||||
Question.difficulty == difficulty
|
||||
)
|
||||
)
|
||||
|
||||
mistake_count = await db.scalar(mistake_stmt) or 0
|
||||
|
||||
# 总题数:该难度的题目在用户考试中出现的次数
|
||||
# 简化处理:假设每次考试平均包含该难度题目的比例
|
||||
total_questions_stmt = select(
|
||||
func.coalesce(func.sum(Exam.question_count), 0)
|
||||
).where(and_(*exam_conditions))
|
||||
total_count = await db.scalar(total_questions_stmt) or 0
|
||||
total_count = int(total_count) # 转换为int避免Decimal类型问题
|
||||
|
||||
# 简化算法:假设easy:medium:hard = 3:2:1
|
||||
if difficulty == "easy":
|
||||
estimated_count = int(total_count * 0.5)
|
||||
elif difficulty == "medium":
|
||||
estimated_count = int(total_count * 0.3)
|
||||
else: # hard
|
||||
estimated_count = int(total_count * 0.2)
|
||||
|
||||
# 计算正确率
|
||||
if estimated_count > 0:
|
||||
correct_count = max(0, estimated_count - mistake_count)
|
||||
accuracy = round((correct_count / estimated_count) * 100, 1)
|
||||
else:
|
||||
accuracy = 100.0
|
||||
|
||||
difficulty_stats[difficulty] = accuracy
|
||||
|
||||
# 综合题和应用题使用中等和困难题的平均值
|
||||
difficulty_stats["综合题"] = round((difficulty_stats["medium"] + difficulty_stats["hard"]) / 2, 1)
|
||||
difficulty_stats["应用题"] = round((difficulty_stats["medium"] + difficulty_stats["hard"]) / 2, 1)
|
||||
|
||||
return {
|
||||
"简单题": difficulty_stats["easy"],
|
||||
"中等题": difficulty_stats["medium"],
|
||||
"困难题": difficulty_stats["hard"],
|
||||
"综合题": difficulty_stats["综合题"],
|
||||
"应用题": difficulty_stats["应用题"]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def get_knowledge_mastery(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
course_id: Optional[int] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取知识点掌握度
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
course_id: 课程ID(可选)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 知识点掌握度列表
|
||||
"""
|
||||
logger.info(f"获取知识点掌握度 - user_id: {user_id}, course_id: {course_id}")
|
||||
|
||||
# 获取知识点列表
|
||||
kp_conditions = [KnowledgePoint.is_deleted == False]
|
||||
if course_id:
|
||||
kp_conditions.append(KnowledgePoint.course_id == course_id)
|
||||
|
||||
kp_stmt = select(KnowledgePoint).where(and_(*kp_conditions)).limit(10)
|
||||
result = await db.execute(kp_stmt)
|
||||
knowledge_points = result.scalars().all()
|
||||
|
||||
if not knowledge_points:
|
||||
# 没有知识点数据,返回默认数据
|
||||
return [
|
||||
{"name": "基础概念", "mastery": 85.0},
|
||||
{"name": "核心知识", "mastery": 72.0},
|
||||
{"name": "实践应用", "mastery": 68.0},
|
||||
{"name": "综合运用", "mastery": 58.0},
|
||||
{"name": "高级技巧", "mastery": 75.0},
|
||||
{"name": "案例分析", "mastery": 62.0}
|
||||
]
|
||||
|
||||
mastery_list = []
|
||||
|
||||
for kp in knowledge_points:
|
||||
# 统计该知识点的错误次数
|
||||
mistake_stmt = select(func.count(ExamMistake.id)).where(
|
||||
and_(
|
||||
ExamMistake.user_id == user_id,
|
||||
ExamMistake.knowledge_point_id == kp.id
|
||||
)
|
||||
)
|
||||
mistake_count = await db.scalar(mistake_stmt) or 0
|
||||
|
||||
# 假设每个知识点平均被考查10次(简化处理)
|
||||
estimated_total = 10
|
||||
|
||||
# 计算掌握度
|
||||
if estimated_total > 0:
|
||||
correct_count = max(0, estimated_total - mistake_count)
|
||||
mastery = round((correct_count / estimated_total) * 100, 1)
|
||||
else:
|
||||
mastery = 100.0
|
||||
|
||||
mastery_list.append({
|
||||
"name": kp.name[:10], # 限制名称长度
|
||||
"mastery": mastery
|
||||
})
|
||||
|
||||
return mastery_list[:6] # 最多返回6个知识点
|
||||
|
||||
@staticmethod
|
||||
async def get_study_time_stats(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
course_id: Optional[int] = None,
|
||||
period: str = "month"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取学习时长统计
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
course_id: 课程ID(可选)
|
||||
period: 时间范围
|
||||
|
||||
Returns:
|
||||
Dict: 学习时长和练习时长的日期分布数据
|
||||
"""
|
||||
logger.info(f"获取学习时长统计 - user_id: {user_id}, course_id: {course_id}, period: {period}")
|
||||
|
||||
start_date, end_date = StatisticsService._get_date_range(period)
|
||||
|
||||
# 获取天数
|
||||
days = (end_date - start_date).days
|
||||
if days > 30:
|
||||
days = 30 # 最多显示30天
|
||||
|
||||
# 生成日期列表
|
||||
date_list = []
|
||||
for i in range(days):
|
||||
date = end_date - timedelta(days=days - i - 1)
|
||||
date_list.append(date.date())
|
||||
|
||||
# 初始化数据
|
||||
study_time_data = {str(d): 0.0 for d in date_list}
|
||||
practice_time_data = {str(d): 0.0 for d in date_list}
|
||||
|
||||
# 统计考试时长(学习时长)
|
||||
exam_conditions = [
|
||||
Exam.user_id == user_id,
|
||||
Exam.start_time >= start_date,
|
||||
Exam.start_time <= end_date
|
||||
]
|
||||
if course_id:
|
||||
exam_conditions.append(Exam.course_id == course_id)
|
||||
|
||||
exam_stmt = select(
|
||||
func.date(Exam.start_time).label("date"),
|
||||
func.sum(Exam.duration_minutes).label("total_minutes")
|
||||
).where(
|
||||
and_(*exam_conditions)
|
||||
).group_by(
|
||||
func.date(Exam.start_time)
|
||||
)
|
||||
|
||||
exam_result = await db.execute(exam_stmt)
|
||||
for row in exam_result.all():
|
||||
date_str = str(row.date)
|
||||
if date_str in study_time_data:
|
||||
study_time_data[date_str] = round(float(row.total_minutes) / 60, 1)
|
||||
|
||||
# 统计陪练时长(练习时长)
|
||||
practice_conditions = [
|
||||
PracticeSession.user_id == user_id,
|
||||
PracticeSession.start_time >= start_date,
|
||||
PracticeSession.start_time <= end_date,
|
||||
PracticeSession.status == "completed"
|
||||
]
|
||||
|
||||
practice_stmt = select(
|
||||
func.date(PracticeSession.start_time).label("date"),
|
||||
func.sum(PracticeSession.duration_seconds).label("total_seconds")
|
||||
).where(
|
||||
and_(*practice_conditions)
|
||||
).group_by(
|
||||
func.date(PracticeSession.start_time)
|
||||
)
|
||||
|
||||
practice_result = await db.execute(practice_stmt)
|
||||
for row in practice_result.all():
|
||||
date_str = str(row.date)
|
||||
if date_str in practice_time_data:
|
||||
practice_time_data[date_str] = round(float(row.total_seconds) / 3600, 1)
|
||||
|
||||
# 如果period是week,返回星期几标签
|
||||
if period == "week":
|
||||
weekday_labels = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
|
||||
labels = [weekday_labels[d.weekday()] for d in date_list]  # 按日期实际对应的星期取标签,避免错位
|
||||
else:
|
||||
# 其他情况返回日期
|
||||
labels = [d.strftime("%m-%d") for d in date_list]
|
||||
|
||||
study_values = [study_time_data[str(d)] for d in date_list]
|
||||
practice_values = [practice_time_data[str(d)] for d in date_list]
|
||||
|
||||
return {
|
||||
"labels": labels,
|
||||
"studyTime": study_values,
|
||||
"practiceTime": practice_values
|
||||
}
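该返回结构直接供前端图表使用。下面是一个示意返回值(数据为假设值):period 为 week 时 labels 为星期标签,其余情况为 "月-日" 格式的日期。

example_week_stats = {
    "labels": ["周一", "周二", "周三", "周四", "周五", "周六", "周日"],
    "studyTime": [1.5, 0.0, 2.0, 1.0, 0.5, 3.0, 0.0],     # 小时,来自考试时长
    "practiceTime": [0.5, 1.0, 0.0, 0.5, 0.0, 1.5, 2.0],  # 小时,来自陪练时长
}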
|
||||
|
||||
@staticmethod
|
||||
async def get_detail_data(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
course_id: Optional[int] = None,
|
||||
period: str = "month"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取详细统计数据(按日期)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
course_id: 课程ID(可选)
|
||||
period: 时间范围
|
||||
|
||||
Returns:
|
||||
List[Dict]: 每日详细统计数据
|
||||
"""
|
||||
logger.info(f"获取详细数据 - user_id: {user_id}, course_id: {course_id}, period: {period}")
|
||||
|
||||
start_date, end_date = StatisticsService._get_date_range(period)
|
||||
|
||||
# 构建查询条件
|
||||
exam_conditions = [
|
||||
Exam.user_id == user_id,
|
||||
Exam.start_time >= start_date,
|
||||
Exam.start_time <= end_date,
|
||||
Exam.round1_score.isnot(None)
|
||||
]
|
||||
if course_id:
|
||||
exam_conditions.append(Exam.course_id == course_id)
|
||||
|
||||
# 按日期分组统计
|
||||
stmt = select(
|
||||
func.date(Exam.start_time).label("date"),
|
||||
func.count(Exam.id).label("exam_count"),
|
||||
func.avg(Exam.round1_score).label("avg_score"),
|
||||
func.sum(Exam.duration_minutes).label("total_duration"),
|
||||
func.sum(Exam.question_count).label("total_questions")
|
||||
).where(
|
||||
and_(*exam_conditions)
|
||||
).group_by(
|
||||
func.date(Exam.start_time)
|
||||
).order_by(
|
||||
desc(func.date(Exam.start_time))
|
||||
).limit(10) # 最多返回10条
|
||||
|
||||
result = await db.execute(stmt)
|
||||
rows = result.all()
|
||||
|
||||
detail_list = []
|
||||
|
||||
for row in rows:
|
||||
date_str = row.date.strftime("%Y-%m-%d")
|
||||
exam_count = row.exam_count or 0
|
||||
avg_score = round(float(row.avg_score or 0), 1)
|
||||
study_time = round(float(row.total_duration or 0) / 60, 1)
|
||||
question_count = row.total_questions or 0
|
||||
|
||||
# 统计当天的错题数
|
||||
mistake_stmt = select(func.count(ExamMistake.id)).where(
|
||||
and_(
|
||||
ExamMistake.user_id == user_id,
|
||||
ExamMistake.exam_id.in_(
|
||||
select(Exam.id).where(
|
||||
and_(
|
||||
Exam.user_id == user_id,
|
||||
func.date(Exam.start_time) == row.date
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
mistake_count = await db.scalar(mistake_stmt) or 0
|
||||
|
||||
# 计算正确率
|
||||
accuracy = 0.0
|
||||
if question_count > 0:
|
||||
correct_count = question_count - mistake_count
|
||||
accuracy = round((correct_count / question_count) * 100, 1)
|
||||
|
||||
# 计算进步指数(基于平均分)
|
||||
improvement = min(100, max(0, int(avg_score)))
|
||||
|
||||
detail_list.append({
|
||||
"date": date_str,
|
||||
"examCount": exam_count,
|
||||
"avgScore": avg_score,
|
||||
"studyTime": study_time,
|
||||
"questionCount": question_count,
|
||||
"accuracy": accuracy,
|
||||
"improvement": improvement
|
||||
})
|
||||
|
||||
return detail_list
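以下是统计接口层调用本服务的一个最小示意(与上文方法位于同一模块;build_statistics_payload 为假设的函数名,user_id、period 等取值仅为示例):

from sqlalchemy.ext.asyncio import AsyncSession

async def build_statistics_payload(db: AsyncSession, user_id: int) -> dict:
    # 汇总知识点掌握度、学习时长分布与每日明细,供统计页一次性返回
    mastery = await StatisticsService.get_knowledge_mastery(db, user_id)
    time_stats = await StatisticsService.get_study_time_stats(db, user_id, period="week")
    details = await StatisticsService.get_detail_data(db, user_id, period="month")
    return {"mastery": mastery, "timeStats": time_stats, "details": details}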
|
||||
|
||||
170
backend/app/services/system_log_service.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
系统日志服务
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.system_log import SystemLog
|
||||
from app.schemas.system_log import SystemLogCreate, SystemLogQuery
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SystemLogService:
|
||||
"""系统日志服务类"""
|
||||
|
||||
async def create_log(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
log_data: SystemLogCreate
|
||||
) -> SystemLog:
|
||||
"""
|
||||
创建系统日志
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
log_data: 日志数据
|
||||
|
||||
Returns:
|
||||
创建的日志对象
|
||||
"""
|
||||
try:
|
||||
log = SystemLog(**log_data.model_dump())
|
||||
db.add(log)
|
||||
await db.commit()
|
||||
await db.refresh(log)
|
||||
return log
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"创建系统日志失败: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_logs(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
query_params: SystemLogQuery
|
||||
) -> tuple[list[SystemLog], int]:
|
||||
"""
|
||||
查询系统日志列表
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
query_params: 查询参数
|
||||
|
||||
Returns:
|
||||
(日志列表, 总数)
|
||||
"""
|
||||
try:
|
||||
# 构建基础查询
|
||||
stmt = select(SystemLog)
|
||||
count_stmt = select(func.count(SystemLog.id))
|
||||
|
||||
# 应用筛选条件
|
||||
filters = []
|
||||
|
||||
if query_params.level:
|
||||
filters.append(SystemLog.level == query_params.level)
|
||||
|
||||
if query_params.type:
|
||||
filters.append(SystemLog.type == query_params.type)
|
||||
|
||||
if query_params.user:
|
||||
filters.append(SystemLog.user == query_params.user)
|
||||
|
||||
if query_params.keyword:
|
||||
filters.append(SystemLog.message.like(f"%{query_params.keyword}%"))
|
||||
|
||||
if query_params.start_date:
|
||||
filters.append(SystemLog.created_at >= query_params.start_date)
|
||||
|
||||
if query_params.end_date:
|
||||
filters.append(SystemLog.created_at <= query_params.end_date)
|
||||
|
||||
# 应用所有筛选条件
|
||||
if filters:
|
||||
stmt = stmt.where(*filters)
|
||||
count_stmt = count_stmt.where(*filters)
|
||||
|
||||
# 获取总数
|
||||
result = await db.execute(count_stmt)
|
||||
total = result.scalar_one()
|
||||
|
||||
# 应用排序和分页
|
||||
stmt = stmt.order_by(SystemLog.created_at.desc())
|
||||
stmt = stmt.offset((query_params.page - 1) * query_params.page_size)
|
||||
stmt = stmt.limit(query_params.page_size)
|
||||
|
||||
# 执行查询
|
||||
result = await db.execute(stmt)
|
||||
logs = result.scalars().all()
|
||||
|
||||
return list(logs), total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询系统日志失败: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_log_by_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
log_id: int
|
||||
) -> Optional[SystemLog]:
|
||||
"""
|
||||
根据ID获取日志详情
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
log_id: 日志ID
|
||||
|
||||
Returns:
|
||||
日志对象或None
|
||||
"""
|
||||
try:
|
||||
stmt = select(SystemLog).where(SystemLog.id == log_id)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"获取日志详情失败: {str(e)}")
|
||||
raise
|
||||
|
||||
async def delete_logs_before_date(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
before_date: datetime
|
||||
) -> int:
|
||||
"""
|
||||
删除指定日期之前的日志(用于日志清理)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
before_date: 截止日期
|
||||
|
||||
Returns:
|
||||
删除的日志数量
|
||||
"""
|
||||
try:
|
||||
stmt = select(SystemLog).where(SystemLog.created_at < before_date)
|
||||
result = await db.execute(stmt)
|
||||
logs = result.scalars().all()
|
||||
|
||||
count = len(logs)
|
||||
for log in logs:
|
||||
await db.delete(log)
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"已删除 {count} 条日志记录")
|
||||
return count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"删除日志失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
system_log_service = SystemLogService()
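以下是一个调用示意(假设 SystemLogCreate / SystemLogQuery 的字段与上文查询条件一一对应,即 level、type、user、message、keyword 及分页参数;字段名为推断,以实际 schema 定义为准):

from app.schemas.system_log import SystemLogCreate, SystemLogQuery

async def demo_system_log(db):
    # 写入一条操作日志(字段取值仅为示例)
    await system_log_service.create_log(db, SystemLogCreate(
        level="INFO", type="auth", user="admin", message="用户登录成功",
    ))
    # 按关键字分页查询
    logs, total = await system_log_service.get_logs(db, SystemLogQuery(
        keyword="登录", page=1, page_size=20,
    ))
    return logs, total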
|
||||
|
||||
|
||||
|
||||
214
backend/app/services/task_service.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
任务服务
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, func, and_, case
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from app.models.task import Task, TaskCourse, TaskAssignment, TaskStatus, AssignmentStatus
|
||||
from app.models.course import Course
|
||||
from app.schemas.task import TaskCreate, TaskUpdate, TaskStatsResponse
|
||||
from app.services.base_service import BaseService
|
||||
|
||||
|
||||
class TaskService(BaseService[Task]):
|
||||
"""任务服务"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(Task)
|
||||
|
||||
async def create_task(self, db: AsyncSession, task_in: TaskCreate, creator_id: int) -> Task:
|
||||
"""创建任务"""
|
||||
# 创建任务
|
||||
task = Task(
|
||||
title=task_in.title,
|
||||
description=task_in.description,
|
||||
priority=task_in.priority,
|
||||
deadline=task_in.deadline,
|
||||
requirements=task_in.requirements,
|
||||
creator_id=creator_id,
|
||||
status=TaskStatus.PENDING
|
||||
)
|
||||
db.add(task)
|
||||
await db.flush()
|
||||
|
||||
# 关联课程
|
||||
for course_id in task_in.course_ids:
|
||||
task_course = TaskCourse(task_id=task.id, course_id=course_id)
|
||||
db.add(task_course)
|
||||
|
||||
# 分配用户
|
||||
for user_id in task_in.user_ids:
|
||||
assignment = TaskAssignment(task_id=task.id, user_id=user_id)
|
||||
db.add(assignment)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return task
|
||||
|
||||
async def get_tasks(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
status: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20
|
||||
) -> tuple[List[Task], int]:
|
||||
"""获取任务列表"""
|
||||
stmt = select(Task).where(Task.is_deleted == False)
|
||||
|
||||
if status:
|
||||
stmt = stmt.where(Task.status == status)
|
||||
|
||||
stmt = stmt.order_by(Task.created_at.desc())
|
||||
|
||||
# 获取总数
|
||||
count_stmt = select(func.count()).select_from(Task).where(Task.is_deleted == False)
|
||||
if status:
|
||||
count_stmt = count_stmt.where(Task.status == status)
|
||||
total = (await db.execute(count_stmt)).scalar_one()
|
||||
|
||||
# 分页
|
||||
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
|
||||
result = await db.execute(stmt)
|
||||
tasks = result.scalars().all()
|
||||
|
||||
return tasks, total
|
||||
|
||||
async def get_task_detail(self, db: AsyncSession, task_id: int) -> Optional[Task]:
|
||||
"""获取任务详情"""
|
||||
stmt = select(Task).where(
|
||||
and_(Task.id == task_id, Task.is_deleted == False)
|
||||
).options(
|
||||
joinedload(Task.course_links).joinedload(TaskCourse.course),
|
||||
joinedload(Task.assignments)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.unique().scalar_one_or_none()
|
||||
|
||||
async def update_task(self, db: AsyncSession, task_id: int, task_in: TaskUpdate) -> Optional[Task]:
|
||||
"""更新任务"""
|
||||
stmt = select(Task).where(and_(Task.id == task_id, Task.is_deleted == False))
|
||||
result = await db.execute(stmt)
|
||||
task = result.scalar_one_or_none()
|
||||
|
||||
if not task:
|
||||
return None
|
||||
|
||||
update_data = task_in.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(task, field, value)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(task)
|
||||
return task
|
||||
|
||||
async def delete_task(self, db: AsyncSession, task_id: int) -> bool:
|
||||
"""删除任务(软删除)"""
|
||||
stmt = select(Task).where(and_(Task.id == task_id, Task.is_deleted == False))
|
||||
result = await db.execute(stmt)
|
||||
task = result.scalar_one_or_none()
|
||||
|
||||
if not task:
|
||||
return False
|
||||
|
||||
task.is_deleted = True
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
async def get_task_stats(self, db: AsyncSession) -> TaskStatsResponse:
|
||||
"""获取任务统计"""
|
||||
# 总任务数
|
||||
total_stmt = select(func.count()).select_from(Task).where(Task.is_deleted == False)
|
||||
total = (await db.execute(total_stmt)).scalar_one()
|
||||
|
||||
# 各状态任务数
|
||||
status_stmt = select(
|
||||
Task.status,
|
||||
func.count(Task.id)
|
||||
).where(Task.is_deleted == False).group_by(Task.status)
|
||||
status_result = await db.execute(status_stmt)
|
||||
status_counts = dict(status_result.all())
|
||||
|
||||
# 平均完成率
|
||||
avg_stmt = select(func.avg(Task.progress)).where(
|
||||
and_(Task.is_deleted == False, Task.status != TaskStatus.EXPIRED)
|
||||
)
|
||||
avg_completion = (await db.execute(avg_stmt)).scalar_one() or 0.0
|
||||
|
||||
return TaskStatsResponse(
|
||||
total=total,
|
||||
ongoing=status_counts.get(TaskStatus.ONGOING.value, 0),
|
||||
completed=status_counts.get(TaskStatus.COMPLETED.value, 0),
|
||||
expired=status_counts.get(TaskStatus.EXPIRED.value, 0),
|
||||
avg_completion_rate=round(avg_completion, 1)
|
||||
)
|
||||
|
||||
async def update_task_progress(self, db: AsyncSession, task_id: int) -> int:
|
||||
"""
|
||||
更新任务进度
|
||||
|
||||
计算已完成的分配数占总分配数的百分比
|
||||
"""
|
||||
# 统计总分配数和完成数
|
||||
stmt = select(
|
||||
func.count(TaskAssignment.id).label('total'),
|
||||
func.sum(
|
||||
case(  # case 是 SQLAlchemy 顶层构造,误用 func.case 会生成不存在的数据库函数
|
||||
(TaskAssignment.status == AssignmentStatus.COMPLETED, 1),
|
||||
else_=0
|
||||
)
|
||||
).label('completed')
|
||||
).where(TaskAssignment.task_id == task_id)
|
||||
|
||||
result = (await db.execute(stmt)).first()
|
||||
total = result.total or 0
|
||||
completed = result.completed or 0
|
||||
|
||||
if total == 0:
|
||||
progress = 0
|
||||
else:
|
||||
progress = int((completed / total) * 100)
|
||||
|
||||
# 更新任务进度
|
||||
task_stmt = select(Task).where(and_(Task.id == task_id, Task.is_deleted == False))
|
||||
task_result = await db.execute(task_stmt)
|
||||
task = task_result.scalar_one_or_none()
|
||||
|
||||
if task:
|
||||
task.progress = progress
|
||||
await db.commit()
|
||||
|
||||
return progress
|
||||
|
||||
async def update_task_status(self, db: AsyncSession, task_id: int):
|
||||
"""
|
||||
更新任务状态
|
||||
|
||||
根据进度和截止时间自动更新任务状态
|
||||
"""
|
||||
task = await self.get_task_detail(db, task_id)
|
||||
if not task:
|
||||
return
|
||||
|
||||
# 计算并更新进度
|
||||
progress = await self.update_task_progress(db, task_id)
|
||||
|
||||
# 自动更新状态
|
||||
now = datetime.now()
|
||||
|
||||
if progress == 100:
|
||||
# 完全完成
|
||||
task.status = TaskStatus.COMPLETED
|
||||
elif task.deadline and now > task.deadline and task.status != TaskStatus.COMPLETED:
|
||||
# 已过期且未完成
|
||||
task.status = TaskStatus.EXPIRED
|
||||
elif progress > 0 and task.status == TaskStatus.PENDING:
|
||||
# 已开始但未完成
|
||||
task.status = TaskStatus.ONGOING
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
task_service = TaskService()
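下面给出管理端创建任务并刷新进度的示意流程(demo_create_task 为假设的函数名;TaskCreate 的字段取自上文 create_task 实际读取的属性,priority、deadline 等取值仅为示例):

from datetime import datetime, timedelta
from app.schemas.task import TaskCreate

async def demo_create_task(db, creator_id: int):
    task = await task_service.create_task(db, TaskCreate(
        title="新员工话术培训",
        description="完成指定课程并参加一次陪练",
        priority="high",
        deadline=datetime.now() + timedelta(days=7),
        requirements="考试得分不低于80",
        course_ids=[1, 2],
        user_ids=[10, 11],
    ), creator_id=creator_id)
    # 分配状态变化后,重新计算进度并联动更新任务状态
    await task_service.update_task_status(db, task.id)
    return task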
|
||||
|
||||
372
backend/app/services/training_service.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""陪练服务层"""
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, func
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.models.training import (
|
||||
TrainingScene,
|
||||
TrainingSession,
|
||||
TrainingMessage,
|
||||
TrainingReport,
|
||||
TrainingSceneStatus,
|
||||
TrainingSessionStatus,
|
||||
MessageRole,
|
||||
MessageType,
|
||||
)
|
||||
from app.schemas.training import (
|
||||
TrainingSceneCreate,
|
||||
TrainingSceneUpdate,
|
||||
TrainingSessionCreate,
|
||||
TrainingSessionUpdate,
|
||||
TrainingMessageCreate,
|
||||
TrainingReportCreate,
|
||||
StartTrainingRequest,
|
||||
StartTrainingResponse,
|
||||
EndTrainingRequest,
|
||||
EndTrainingResponse,
|
||||
)
|
||||
from app.services.base_service import BaseService
|
||||
|
||||
# from app.services.ai.coze.client import CozeClient
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class TrainingSceneService(BaseService[TrainingScene]):
|
||||
"""陪练场景服务"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(TrainingScene)
|
||||
|
||||
async def get_active_scenes(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
category: Optional[str] = None,
|
||||
is_public: Optional[bool] = None,
|
||||
user_level: Optional[int] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 20,
|
||||
) -> List[TrainingScene]:
|
||||
"""获取激活的陪练场景列表"""
|
||||
query = select(self.model).where(
|
||||
and_(
|
||||
self.model.status == TrainingSceneStatus.ACTIVE,
|
||||
self.model.is_deleted == False,
|
||||
)
|
||||
)
|
||||
|
||||
if category:
|
||||
query = query.where(self.model.category == category)
|
||||
|
||||
if is_public is not None:
|
||||
query = query.where(self.model.is_public == is_public)
|
||||
|
||||
if user_level is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
self.model.required_level.is_(None),
|
||||
self.model.required_level <= user_level,
|
||||
)
|
||||
)
|
||||
|
||||
return await self.get_multi(db, skip=skip, limit=limit, query=query)
|
||||
|
||||
async def create_scene(
|
||||
self, db: AsyncSession, *, scene_in: TrainingSceneCreate, created_by: int
|
||||
) -> TrainingScene:
|
||||
"""创建陪练场景"""
|
||||
return await self.create(
|
||||
db, obj_in=scene_in, created_by=created_by, updated_by=created_by
|
||||
)
|
||||
|
||||
async def update_scene(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
scene_id: int,
|
||||
scene_in: TrainingSceneUpdate,
|
||||
updated_by: int,
|
||||
) -> Optional[TrainingScene]:
|
||||
"""更新陪练场景"""
|
||||
scene = await self.get(db, scene_id)
|
||||
if not scene or scene.is_deleted:
|
||||
return None
|
||||
|
||||
scene.updated_by = updated_by
|
||||
return await self.update(db, db_obj=scene, obj_in=scene_in)
|
||||
|
||||
|
||||
class TrainingSessionService(BaseService[TrainingSession]):
|
||||
"""陪练会话服务"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(TrainingSession)
|
||||
self.scene_service = TrainingSceneService()
|
||||
self.message_service = TrainingMessageService()
|
||||
self.report_service = TrainingReportService()
|
||||
# TODO: 等Coze网关模块实现后替换
|
||||
self._coze_client = None
|
||||
|
||||
@property
|
||||
def coze_client(self):
|
||||
"""延迟初始化Coze客户端"""
|
||||
if self._coze_client is None:
|
||||
try:
|
||||
# from app.services.ai.coze.client import CozeClient
|
||||
# self._coze_client = CozeClient()
|
||||
logger.warning("Coze客户端暂未实现,使用模拟模式")
|
||||
self._coze_client = None
|
||||
except ImportError:
|
||||
logger.warning("Coze客户端未实现,使用模拟模式")
|
||||
return self._coze_client
|
||||
|
||||
async def start_training(
|
||||
self, db: AsyncSession, *, request: StartTrainingRequest, user_id: int
|
||||
) -> StartTrainingResponse:
|
||||
"""开始陪练会话"""
|
||||
# 验证场景
|
||||
scene = await self.scene_service.get(db, request.scene_id)
|
||||
if not scene or scene.is_deleted or scene.status != TrainingSceneStatus.ACTIVE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="陪练场景不存在或未激活"
|
||||
)
|
||||
|
||||
# 检查用户等级
|
||||
# TODO: 从User服务获取用户等级
|
||||
user_level = 1 # 临时模拟
|
||||
if scene.required_level and user_level < scene.required_level:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户等级不足")
|
||||
|
||||
# 创建会话
|
||||
session_data = TrainingSessionCreate(
|
||||
scene_id=request.scene_id, session_config=request.config
|
||||
)
|
||||
|
||||
session = await self.create(
|
||||
db, obj_in=session_data, user_id=user_id, created_by=user_id
|
||||
)
|
||||
|
||||
# 初始化Coze会话
|
||||
coze_conversation_id = None
|
||||
if self.coze_client and scene.ai_config:
|
||||
try:
|
||||
bot_id = scene.ai_config.get("bot_id", settings.coze_training_bot_id)
|
||||
if bot_id:
|
||||
# 创建Coze会话
|
||||
coze_result = await self.coze_client.create_conversation(
|
||||
bot_id=bot_id,
|
||||
user_id=str(user_id),
|
||||
meta_data={
|
||||
"scene_id": scene.id,
|
||||
"scene_name": scene.name,
|
||||
"session_id": session.id,
|
||||
},
|
||||
)
|
||||
coze_conversation_id = coze_result.get("conversation_id")
|
||||
|
||||
# 更新会话的Coze ID
|
||||
session.coze_conversation_id = coze_conversation_id
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"创建Coze会话失败: {e}")
|
||||
|
||||
# 加载场景信息
|
||||
await db.refresh(session, ["scene"])
|
||||
|
||||
# 构造WebSocket URL(如果需要)
|
||||
websocket_url = None
|
||||
if coze_conversation_id:
|
||||
websocket_url = f"ws://localhost:8000/ws/v1/training/{session.id}"
|
||||
|
||||
return StartTrainingResponse(
|
||||
session_id=session.id,
|
||||
coze_conversation_id=coze_conversation_id,
|
||||
scene=scene,
|
||||
websocket_url=websocket_url,
|
||||
)
|
||||
|
||||
async def end_training(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session_id: int,
|
||||
request: EndTrainingRequest,
|
||||
user_id: int,
|
||||
) -> EndTrainingResponse:
|
||||
"""结束陪练会话"""
|
||||
# 获取会话
|
||||
session = await self.get(db, session_id)
|
||||
if not session or session.user_id != user_id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="会话不存在")
|
||||
|
||||
if session.status in [
|
||||
TrainingSessionStatus.COMPLETED,
|
||||
TrainingSessionStatus.CANCELLED,
|
||||
]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="会话已结束")
|
||||
|
||||
# 计算持续时间
|
||||
end_time = datetime.now()
|
||||
duration_seconds = int((end_time - session.start_time).total_seconds())
|
||||
|
||||
# 更新会话状态
|
||||
update_data = TrainingSessionUpdate(
|
||||
status=TrainingSessionStatus.COMPLETED,
|
||||
end_time=end_time,
|
||||
duration_seconds=duration_seconds,
|
||||
)
|
||||
session = await self.update(db, db_obj=session, obj_in=update_data)
|
||||
|
||||
# 生成报告
|
||||
report = None
|
||||
if request.generate_report:
|
||||
report = await self._generate_report(
|
||||
db, session_id=session_id, user_id=user_id
|
||||
)
|
||||
|
||||
# 加载关联数据
|
||||
await db.refresh(session, ["scene"])
|
||||
if report:
|
||||
await db.refresh(report, ["session"])
|
||||
|
||||
return EndTrainingResponse(session=session, report=report)
|
||||
|
||||
async def get_user_sessions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
user_id: int,
|
||||
scene_id: Optional[int] = None,
|
||||
status: Optional[TrainingSessionStatus] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 20,
|
||||
) -> List[TrainingSession]:
|
||||
"""获取用户的陪练会话列表"""
|
||||
query = select(self.model).where(self.model.user_id == user_id)
|
||||
|
||||
if scene_id:
|
||||
query = query.where(self.model.scene_id == scene_id)
|
||||
|
||||
if status:
|
||||
query = query.where(self.model.status == status)
|
||||
|
||||
query = query.order_by(self.model.created_at.desc())
|
||||
|
||||
return await self.get_multi(db, skip=skip, limit=limit, query=query)
|
||||
|
||||
async def _generate_report(
|
||||
self, db: AsyncSession, *, session_id: int, user_id: int
|
||||
) -> Optional[TrainingReport]:
|
||||
"""生成陪练报告(内部方法)"""
|
||||
# 获取会话消息
|
||||
messages = await self.message_service.get_session_messages(
|
||||
db, session_id=session_id
|
||||
)
|
||||
|
||||
# TODO: 调用AI分析服务生成报告
|
||||
# 这里先生成模拟报告
|
||||
report_data = TrainingReportCreate(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
overall_score=85.5,
|
||||
dimension_scores={"表达能力": 88.0, "逻辑思维": 85.0, "专业知识": 82.0, "应变能力": 87.0},
|
||||
strengths=["表达清晰,语言流畅", "能够快速理解问题并作出回应", "展现了良好的专业素养"],
|
||||
weaknesses=["部分专业术语使用不够准确", "回答有时过于冗长,需要更加精炼"],
|
||||
suggestions=["加强专业知识的学习,特别是术语的准确使用", "练习更加简洁有力的表达方式", "增加实际案例的积累,丰富回答内容"],
|
||||
detailed_analysis="整体表现良好,展现了扎实的基础知识和良好的沟通能力...",
|
||||
statistics={
|
||||
"total_messages": len(messages),
|
||||
"user_messages": len(
|
||||
[m for m in messages if m.role == MessageRole.USER]
|
||||
),
|
||||
"avg_response_time": 2.5,
|
||||
"total_words": 1500,
|
||||
},
|
||||
)
|
||||
|
||||
return await self.report_service.create(
|
||||
db, obj_in=report_data, created_by=user_id
|
||||
)
|
||||
|
||||
|
||||
class TrainingMessageService(BaseService[TrainingMessage]):
|
||||
"""陪练消息服务"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(TrainingMessage)
|
||||
|
||||
async def create_message(
|
||||
self, db: AsyncSession, *, message_in: TrainingMessageCreate
|
||||
) -> TrainingMessage:
|
||||
"""创建消息"""
|
||||
return await self.create(db, obj_in=message_in)
|
||||
|
||||
async def get_session_messages(
|
||||
self, db: AsyncSession, *, session_id: int, skip: int = 0, limit: int = 100
|
||||
) -> List[TrainingMessage]:
|
||||
"""获取会话的所有消息"""
|
||||
query = (
|
||||
select(self.model)
|
||||
.where(self.model.session_id == session_id)
|
||||
.order_by(self.model.created_at)
|
||||
)
|
||||
|
||||
return await self.get_multi(db, skip=skip, limit=limit, query=query)
|
||||
|
||||
async def save_voice_message(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
session_id: int,
|
||||
role: MessageRole,
|
||||
content: str,
|
||||
voice_url: str,
|
||||
voice_duration: float,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> TrainingMessage:
|
||||
"""保存语音消息"""
|
||||
message_data = TrainingMessageCreate(
|
||||
session_id=session_id,
|
||||
role=role,
|
||||
type=MessageType.VOICE,
|
||||
content=content,
|
||||
voice_url=voice_url,
|
||||
voice_duration=voice_duration,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return await self.create(db, obj_in=message_data)
|
||||
|
||||
|
||||
class TrainingReportService(BaseService[TrainingReport]):
|
||||
"""陪练报告服务"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(TrainingReport)
|
||||
|
||||
async def get_by_session(
|
||||
self, db: AsyncSession, *, session_id: int
|
||||
) -> Optional[TrainingReport]:
|
||||
"""根据会话ID获取报告"""
|
||||
result = await db.execute(
|
||||
select(self.model).where(self.model.session_id == session_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_user_reports(
|
||||
self, db: AsyncSession, *, user_id: int, skip: int = 0, limit: int = 20
|
||||
) -> List[TrainingReport]:
|
||||
"""获取用户的所有报告"""
|
||||
query = (
|
||||
select(self.model)
|
||||
.where(self.model.user_id == user_id)
|
||||
.order_by(self.model.created_at.desc())
|
||||
)
|
||||
|
||||
return await self.get_multi(db, skip=skip, limit=limit, query=query)
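下面是一次完整陪练流程的调用示意(demo_training_flow 为假设的函数名;StartTrainingRequest 的 config、EndTrainingRequest 的 generate_report 等字段取自上文使用处,具体结构以实际 schema 为准):

from app.schemas.training import StartTrainingRequest, EndTrainingRequest

session_service = TrainingSessionService()

async def demo_training_flow(db, user_id: int, scene_id: int):
    # 开始陪练:校验场景并创建会话(Coze 未接入时走模拟模式)
    start = await session_service.start_training(
        db,
        request=StartTrainingRequest(scene_id=scene_id, config={}),  # config 结构为假设
        user_id=user_id,
    )
    # ……用户与 AI 多轮对话,消息经 TrainingMessageService 落库……
    # 结束陪练并生成报告
    end = await session_service.end_training(
        db,
        session_id=start.session_id,
        request=EndTrainingRequest(generate_report=True),
        user_id=user_id,
    )
    return end.report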
|
||||
423
backend/app/services/user_service.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
用户服务
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import and_, or_, select, func
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.exceptions import ConflictError, NotFoundError
|
||||
from app.core.logger import logger
|
||||
from app.core.security import get_password_hash, verify_password
|
||||
from app.models.user import Team, User, user_teams
|
||||
from app.schemas.user import UserCreate, UserFilter, UserUpdate
|
||||
from app.services.base_service import BaseService
|
||||
|
||||
|
||||
class UserService(BaseService[User]):
|
||||
"""用户服务"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
super().__init__(User)
|
||||
self.db = db
|
||||
|
||||
async def get_by_id(self, user_id: int) -> Optional[User]:
|
||||
"""根据ID获取用户"""
|
||||
result = await self.db.execute(
|
||||
select(User).where(User.id == user_id, User.is_deleted == False)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_username(self, username: str) -> Optional[User]:
|
||||
"""根据用户名获取用户"""
|
||||
result = await self.db.execute(
|
||||
select(User).where(
|
||||
User.username == username,
|
||||
User.is_deleted == False,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional[User]:
|
||||
"""根据邮箱获取用户"""
|
||||
result = await self.db.execute(
|
||||
select(User).where(
|
||||
User.email == email,
|
||||
User.is_deleted == False,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_phone(self, phone: str) -> Optional[User]:
|
||||
"""根据手机号获取用户"""
|
||||
result = await self.db.execute(
|
||||
select(User).where(
|
||||
User.phone == phone,
|
||||
User.is_deleted == False,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _check_username_exists_all(self, username: str) -> Optional[User]:
|
||||
"""
|
||||
检查用户名是否已存在(包括已删除的用户)
|
||||
用于创建用户时检查唯一性约束
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(User).where(User.username == username)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _check_email_exists_all(self, email: str) -> Optional[User]:
|
||||
"""
|
||||
检查邮箱是否已存在(包括已删除的用户)
|
||||
用于创建用户时检查唯一性约束
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(User).where(User.email == email)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _check_phone_exists_all(self, phone: str) -> Optional[User]:
|
||||
"""
|
||||
检查手机号是否已存在(包括已删除的用户)
|
||||
用于创建用户时检查唯一性约束
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(User).where(User.phone == phone)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create_user(
|
||||
self,
|
||||
*,
|
||||
obj_in: UserCreate,
|
||||
created_by: Optional[int] = None,
|
||||
) -> User:
|
||||
"""创建用户"""
|
||||
# 检查用户名是否已存在(包括已删除的用户,防止唯一键冲突)
|
||||
existing_user = await self._check_username_exists_all(obj_in.username)
|
||||
if existing_user:
|
||||
if existing_user.is_deleted:
|
||||
raise ConflictError(f"用户名 {obj_in.username} 已被使用(历史用户),请更换其他用户名")
|
||||
else:
|
||||
raise ConflictError(f"用户名 {obj_in.username} 已存在")
|
||||
|
||||
# 检查邮箱是否已存在(包括已删除的用户)
|
||||
if obj_in.email:
|
||||
existing_email = await self._check_email_exists_all(obj_in.email)
|
||||
if existing_email:
|
||||
if existing_email.is_deleted:
|
||||
raise ConflictError(f"邮箱 {obj_in.email} 已被使用(历史用户),请更换其他邮箱")
|
||||
else:
|
||||
raise ConflictError(f"邮箱 {obj_in.email} 已存在")
|
||||
|
||||
# 检查手机号是否已存在(包括已删除的用户)
|
||||
if obj_in.phone:
|
||||
existing_phone = await self._check_phone_exists_all(obj_in.phone)
|
||||
if existing_phone:
|
||||
if existing_phone.is_deleted:
|
||||
raise ConflictError(f"手机号 {obj_in.phone} 已被使用(历史用户),请更换其他手机号")
|
||||
else:
|
||||
raise ConflictError(f"手机号 {obj_in.phone} 已存在")
|
||||
|
||||
# 创建用户数据
|
||||
user_data = obj_in.model_dump(exclude={"password"})
|
||||
user_data["hashed_password"] = get_password_hash(obj_in.password)
|
||||
# 注意:User模型不包含created_by字段,该信息记录在日志中
|
||||
# user_data["created_by"] = created_by
|
||||
|
||||
try:
|
||||
# 创建用户
|
||||
user = await self.create(db=self.db, obj_in=user_data)
|
||||
except IntegrityError as e:
|
||||
# 捕获数据库唯一键冲突异常,返回友好错误信息
|
||||
await self.db.rollback()
|
||||
error_msg = str(e.orig) if e.orig else str(e)
|
||||
logger.warning(
|
||||
"创建用户时发生唯一键冲突",
|
||||
username=obj_in.username,
|
||||
email=obj_in.email,
|
||||
error=error_msg,
|
||||
)
|
||||
if "username" in error_msg.lower():
|
||||
raise ConflictError(f"用户名 {obj_in.username} 已被占用,请更换其他用户名")
|
||||
elif "email" in error_msg.lower():
|
||||
raise ConflictError(f"邮箱 {obj_in.email} 已被占用,请更换其他邮箱")
|
||||
elif "phone" in error_msg.lower():
|
||||
raise ConflictError(f"手机号 {obj_in.phone} 已被占用,请更换其他手机号")
|
||||
else:
|
||||
raise ConflictError(f"创建用户失败:数据冲突,请检查用户名、邮箱或手机号是否重复")
|
||||
|
||||
# 记录日志
|
||||
logger.info(
|
||||
"用户创建成功",
|
||||
user_id=user.id,
|
||||
username=user.username,
|
||||
role=user.role,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
async def update_user(
|
||||
self,
|
||||
*,
|
||||
user_id: int,
|
||||
obj_in: UserUpdate,
|
||||
updated_by: Optional[int] = None,
|
||||
) -> User:
|
||||
"""更新用户"""
|
||||
user = await self.get_by_id(user_id)
|
||||
if not user:
|
||||
raise NotFoundError("用户不存在")
|
||||
|
||||
# 如果更新邮箱,检查是否已存在
|
||||
if obj_in.email and obj_in.email != user.email:
|
||||
if await self.get_by_email(obj_in.email):
|
||||
raise ConflictError(f"邮箱 {obj_in.email} 已存在")
|
||||
|
||||
# 如果更新手机号,检查是否已存在
|
||||
if obj_in.phone and obj_in.phone != user.phone:
|
||||
if await self.get_by_phone(obj_in.phone):
|
||||
raise ConflictError(f"手机号 {obj_in.phone} 已存在")
|
||||
|
||||
# 更新用户数据
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
update_data["updated_by"] = updated_by
|
||||
|
||||
user = await self.update(db=self.db, db_obj=user, obj_in=update_data)
|
||||
|
||||
# 记录日志
|
||||
logger.info(
|
||||
"用户更新成功",
|
||||
user_id=user.id,
|
||||
username=user.username,
|
||||
updated_fields=list(update_data.keys()),
|
||||
updated_by=updated_by,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
async def update_password(
|
||||
self,
|
||||
*,
|
||||
user_id: int,
|
||||
old_password: str,
|
||||
new_password: str,
|
||||
) -> User:
|
||||
"""更新密码"""
|
||||
user = await self.get_by_id(user_id)
|
||||
if not user:
|
||||
raise NotFoundError("用户不存在")
|
||||
|
||||
# 验证旧密码
|
||||
if not verify_password(old_password, user.hashed_password):
|
||||
raise ConflictError("旧密码错误")
|
||||
|
||||
# 更新密码
|
||||
update_data = {
|
||||
"hashed_password": get_password_hash(new_password),
|
||||
"password_changed_at": datetime.now(),
|
||||
}
|
||||
user = await self.update(db=self.db, db_obj=user, obj_in=update_data)
|
||||
|
||||
# 记录日志
|
||||
logger.info(
|
||||
"用户密码更新成功",
|
||||
user_id=user.id,
|
||||
username=user.username,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
async def update_last_login(self, user_id: int) -> None:
|
||||
"""更新最后登录时间"""
|
||||
user = await self.get_by_id(user_id)
|
||||
if user:
|
||||
await self.update(
|
||||
db=self.db,
|
||||
db_obj=user,
|
||||
obj_in={"last_login_at": datetime.now()},
|
||||
)
|
||||
|
||||
async def get_users_with_filter(
|
||||
self,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
filter_params: UserFilter,
|
||||
) -> tuple[List[User], int]:
|
||||
"""根据筛选条件获取用户列表"""
|
||||
# 构建筛选条件
|
||||
filters = [User.is_deleted == False]
|
||||
|
||||
if filter_params.role:
|
||||
filters.append(User.role == filter_params.role)
|
||||
|
||||
if filter_params.is_active is not None:
|
||||
filters.append(User.is_active == filter_params.is_active)
|
||||
|
||||
if filter_params.keyword:
|
||||
keyword = f"%{filter_params.keyword}%"
|
||||
filters.append(
|
||||
or_(
|
||||
User.username.like(keyword),
|
||||
User.email.like(keyword),
|
||||
User.full_name.like(keyword),
|
||||
)
|
||||
)
|
||||
|
||||
if filter_params.team_id:
|
||||
# 通过团队ID筛选用户
|
||||
subquery = select(user_teams.c.user_id).where(
|
||||
user_teams.c.team_id == filter_params.team_id
|
||||
)
|
||||
filters.append(User.id.in_(subquery))
|
||||
|
||||
# 构建查询
|
||||
query = select(User).where(and_(*filters))
|
||||
|
||||
# 获取用户列表
|
||||
users = await self.get_multi(self.db, skip=skip, limit=limit, query=query)
|
||||
|
||||
# 获取总数
|
||||
count_query = select(func.count(User.id)).where(and_(*filters))
|
||||
count_result = await self.db.execute(count_query)
|
||||
total = count_result.scalar()
|
||||
|
||||
return users, total
|
||||
|
||||
async def add_user_to_team(
|
||||
self,
|
||||
*,
|
||||
user_id: int,
|
||||
team_id: int,
|
||||
role: str = "member",
|
||||
) -> None:
|
||||
"""将用户添加到团队"""
|
||||
# 检查用户是否存在
|
||||
user = await self.get_by_id(user_id)
|
||||
if not user:
|
||||
raise NotFoundError("用户不存在")
|
||||
|
||||
# 检查团队是否存在
|
||||
team_result = await self.db.execute(
|
||||
select(Team).where(Team.id == team_id, Team.is_deleted == False)
|
||||
)
|
||||
team = team_result.scalar_one_or_none()
|
||||
if not team:
|
||||
raise NotFoundError("团队不存在")
|
||||
|
||||
# 检查是否已在团队中
|
||||
existing = await self.db.execute(
|
||||
select(user_teams).where(
|
||||
user_teams.c.user_id == user_id,
|
||||
user_teams.c.team_id == team_id,
|
||||
)
|
||||
)
|
||||
if existing.first():
|
||||
raise ConflictError("用户已在该团队中")
|
||||
|
||||
# 添加到团队
|
||||
await self.db.execute(
|
||||
user_teams.insert().values(
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
role=role,
|
||||
joined_at=datetime.now(),
|
||||
)
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
# 记录日志
|
||||
logger.info(
|
||||
"用户加入团队",
|
||||
user_id=user_id,
|
||||
username=user.username,
|
||||
team_id=team_id,
|
||||
team_name=team.name,
|
||||
role=role,
|
||||
)
|
||||
|
||||
async def remove_user_from_team(
|
||||
self,
|
||||
*,
|
||||
user_id: int,
|
||||
team_id: int,
|
||||
) -> None:
|
||||
"""从团队中移除用户"""
|
||||
# 删除关联
|
||||
result = await self.db.execute(
|
||||
user_teams.delete().where(
|
||||
user_teams.c.user_id == user_id,
|
||||
user_teams.c.team_id == team_id,
|
||||
)
|
||||
)
|
||||
|
||||
if result.rowcount == 0:
|
||||
raise NotFoundError("用户不在该团队中")
|
||||
|
||||
await self.db.commit()
|
||||
|
||||
# 记录日志
|
||||
logger.info(
|
||||
"用户离开团队",
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
)
|
||||
|
||||
async def soft_delete(self, *, db_obj: User) -> User:
|
||||
"""
|
||||
软删除用户
|
||||
|
||||
Args:
|
||||
db_obj: 用户对象
|
||||
|
||||
Returns:
|
||||
软删除后的用户对象
|
||||
"""
|
||||
db_obj.is_deleted = True
|
||||
db_obj.deleted_at = datetime.now()
|
||||
self.db.add(db_obj)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(db_obj)
|
||||
|
||||
logger.info(
|
||||
"用户软删除成功",
|
||||
user_id=db_obj.id,
|
||||
username=db_obj.username,
|
||||
)
|
||||
|
||||
return db_obj
|
||||
|
||||
async def authenticate(
|
||||
self,
|
||||
*,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> Optional[User]:
|
||||
"""用户认证"""
|
||||
# 尝试用户名登录
|
||||
user = await self.get_by_username(username)
|
||||
|
||||
# 尝试邮箱登录
|
||||
if not user:
|
||||
user = await self.get_by_email(username)
|
||||
|
||||
# 尝试手机号登录
|
||||
if not user:
|
||||
user = await self.get_by_phone(username)
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return None
|
||||
|
||||
return user
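以下是注册与登录校验的调用示意(demo_user_flow 为假设的函数名;UserCreate 的字段按上文唯一性检查涉及的 username/email/phone/password 推断,取值仅为示例,其余字段以 schema 为准):

from app.schemas.user import UserCreate

async def demo_user_flow(db):
    service = UserService(db)
    user = await service.create_user(
        obj_in=UserCreate(
            username="zhangsan",
            password="Secret@123",
            email="zhangsan@example.com",
            phone="13800138000",
        ),
        created_by=1,
    )
    # 登录:用户名、邮箱、手机号三种方式任一皆可
    authed = await service.authenticate(username="zhangsan", password="Secret@123")
    return user, authed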
|
||||
510
backend/app/services/yanji_service.py
Normal file
@@ -0,0 +1,510 @@
|
||||
"""
|
||||
言迹智能工牌API服务
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class YanjiService:
|
||||
"""言迹智能工牌API服务类"""
|
||||
|
||||
def __init__(self):
|
||||
self.base_url = settings.YANJI_API_BASE
|
||||
self.client_id = settings.YANJI_CLIENT_ID
|
||||
self.client_secret = settings.YANJI_CLIENT_SECRET
|
||||
self.tenant_id = settings.YANJI_TENANT_ID
|
||||
self.estate_id = int(settings.YANJI_ESTATE_ID)
|
||||
|
||||
# Token缓存
|
||||
self._access_token: Optional[str] = None
|
||||
self._token_expires_at: Optional[datetime] = None
|
||||
|
||||
async def get_access_token(self) -> str:
|
||||
"""
|
||||
获取或刷新access_token
|
||||
|
||||
Returns:
|
||||
access_token字符串
|
||||
"""
|
||||
# 检查缓存的token是否仍然有效(提前5分钟刷新)
|
||||
if self._access_token and self._token_expires_at:
|
||||
if datetime.now() < self._token_expires_at - timedelta(minutes=5):
|
||||
return self._access_token
|
||||
|
||||
# 获取新的token
|
||||
url = f"{self.base_url}/oauth/token"
|
||||
params = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, params=params, timeout=30.0)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
self._access_token = data["access_token"]
|
||||
expires_in = data.get("expires_in", 3600) # 默认1小时
|
||||
self._token_expires_at = datetime.now() + timedelta(seconds=expires_in)
|
||||
|
||||
logger.info(f"言迹API token获取成功,有效期至: {self._token_expires_at}")
|
||||
return self._access_token
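作为参考,这里给出该方法所假定的 token 接口返回结构(字段以言迹实际接口文档为准,数值仅为示例):

example_token_response = {
    "access_token": "xxxxxxxx",
    "expires_in": 3600,  # 秒;本服务会在过期前 5 分钟主动刷新
}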
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: Optional[Dict] = None,
|
||||
json_data: Optional[Dict] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
统一的HTTP请求方法
|
||||
|
||||
Args:
|
||||
method: HTTP方法(GET/POST等)
|
||||
path: API路径
|
||||
params: Query参数
|
||||
json_data: Body参数(JSON)
|
||||
|
||||
Returns:
|
||||
响应数据(data字段)
|
||||
|
||||
Raises:
|
||||
Exception: API调用失败
|
||||
"""
|
||||
token = await self.get_access_token()
|
||||
url = f"{self.base_url}{path}"
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
params=params,
|
||||
json=json_data,
|
||||
headers=headers,
|
||||
timeout=60.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# 言迹API: code='0'或code=0表示成功
|
||||
code = result.get("code")
|
||||
if str(code) != '0':
|
||||
error_msg = result.get("msg", "Unknown error")
|
||||
logger.error(f"言迹API调用失败: {error_msg}, result={result}")
|
||||
raise Exception(f"言迹API错误: {error_msg}")
|
||||
|
||||
# data可能为None,返回空字典或空列表由调用方判断
|
||||
return result.get("data")
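_request 假定言迹接口统一返回如下信封结构(code 为 '0' 或 0 表示成功);示例数据仅用于说明上面的解析逻辑:

example_envelope = {
    "code": "0",
    "msg": "success",
    "data": {"records": []},  # data 可能为 None、对象或数组,由调用方自行判断
}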
|
||||
|
||||
async def get_visit_audios(
|
||||
self, external_visit_ids: List[str]
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
根据来访单ID获取录音信息
|
||||
|
||||
Args:
|
||||
external_visit_ids: 三方来访ID列表(最多10个)
|
||||
|
||||
Returns:
|
||||
录音信息列表
|
||||
"""
|
||||
if not external_visit_ids:
|
||||
return []
|
||||
|
||||
if len(external_visit_ids) > 10:
|
||||
logger.warning(f"来访单ID数量超过限制,截取前10个")
|
||||
external_visit_ids = external_visit_ids[:10]
|
||||
|
||||
data = await self._request(
|
||||
method="POST",
|
||||
path="/api/beauty/v1/visit/audios",
|
||||
json_data={
|
||||
"estateId": self.estate_id,
|
||||
"externalVisitIds": external_visit_ids,
|
||||
},
|
||||
)
|
||||
|
||||
if data is None:
|
||||
logger.info(f"获取来访录音信息: 无数据")
|
||||
return []
|
||||
|
||||
records = data.get("records", [])
|
||||
logger.info(f"获取来访录音信息成功: {len(records)}条")
|
||||
return records
|
||||
|
||||
async def get_audio_asr_result(self, audio_id: int) -> Dict:
|
||||
"""
|
||||
获取录音的ASR分析结果(对话文本)
|
||||
|
||||
Args:
|
||||
audio_id: 录音ID
|
||||
|
||||
Returns:
|
||||
ASR分析结果,包含对话文本数组
|
||||
"""
|
||||
data = await self._request(
|
||||
method="GET",
|
||||
path="/api/beauty/v1/audio/asr-analysed",
|
||||
params={"estateId": self.estate_id, "audioId": audio_id},
|
||||
)
|
||||
|
||||
# 检查data是否为None
|
||||
if data is None:
|
||||
logger.warning(f"录音ASR结果为None: audio_id={audio_id}")
|
||||
return {}
|
||||
|
||||
# data是一个数组,取第一个元素
|
||||
if isinstance(data, list) and len(data) > 0:
|
||||
result = data[0]
|
||||
logger.info(
|
||||
f"获取录音ASR结果成功: audio_id={audio_id}, "
|
||||
f"对话数={len(result.get('result', []))}"
|
||||
)
|
||||
return result
|
||||
else:
|
||||
logger.warning(f"录音ASR结果为空: audio_id={audio_id}")
|
||||
return {}
|
||||
|
||||
async def get_recent_conversations(
|
||||
self, consultant_phone: str, limit: int = 10
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
获取员工最近N条对话记录
|
||||
|
||||
业务逻辑:
|
||||
1. 通过员工手机号获取录音列表(目前使用模拟数据)
|
||||
2. 对每个录音获取ASR分析结果
|
||||
3. 组合返回完整的对话记录
|
||||
|
||||
Args:
|
||||
consultant_phone: 员工手机号
|
||||
limit: 获取数量,默认10条
|
||||
|
||||
Returns:
|
||||
对话记录列表,格式:
|
||||
[{
|
||||
"audio_id": 123,
|
||||
"visit_id": "xxx",
|
||||
"start_time": "2025-01-15 10:30:00",
|
||||
"duration": 120000,
|
||||
"consultant_name": "张三",
|
||||
"consultant_phone": "13800138000",
|
||||
"conversation": [
|
||||
{"role": "consultant", "text": "您好..."},
|
||||
{"role": "customer", "text": "你好..."}
|
||||
]
|
||||
}]
|
||||
"""
|
||||
# TODO: 目前言迹API没有直接通过手机号查询录音的接口
|
||||
# 需要先获取来访单列表,再获取录音
|
||||
# 这里暂时返回空列表,后续根据实际业务需求补充
|
||||
|
||||
logger.warning(
|
||||
f"获取员工对话记录功能需要额外的业务逻辑支持 "
|
||||
f"(consultant_phone={consultant_phone}, limit={limit})"
|
||||
)
|
||||
|
||||
# 返回空列表,表示暂未实现
|
||||
return []
|
||||
|
||||
async def get_conversations_by_visit_ids(
|
||||
self, external_visit_ids: List[str]
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
根据来访单ID列表获取对话记录
|
||||
|
||||
Args:
|
||||
external_visit_ids: 三方来访ID列表
|
||||
|
||||
Returns:
|
||||
对话记录列表
|
||||
"""
|
||||
if not external_visit_ids:
|
||||
return []
|
||||
|
||||
# 1. 获取录音信息
|
||||
audio_records = await self.get_visit_audios(external_visit_ids)
|
||||
|
||||
if not audio_records:
|
||||
logger.info("没有找到录音记录")
|
||||
return []
|
||||
|
||||
# 2. 对每个录音获取ASR分析结果
|
||||
conversations = []
|
||||
for audio in audio_records:
|
||||
audio_id = audio.get("id")
|
||||
if not audio_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
asr_result = await self.get_audio_asr_result(audio_id)
|
||||
|
||||
# 解析对话文本
|
||||
conversation_messages = []
|
||||
for item in asr_result.get("result", []):
|
||||
role = "consultant" if item.get("role") == -1 else "customer"
|
||||
conversation_messages.append({
|
||||
"role": role,
|
||||
"text": item.get("text", ""),
|
||||
"begin_time": item.get("beginTime"),
|
||||
"end_time": item.get("endTime"),
|
||||
})
|
||||
|
||||
# 组合完整对话记录
|
||||
conversations.append({
|
||||
"audio_id": audio_id,
|
||||
"visit_id": audio.get("externalVisitId", ""),
|
||||
"start_time": audio.get("startTime", ""),
|
||||
"duration": audio.get("duration", 0),
|
||||
"consultant_name": audio.get("consultantName", ""),
|
||||
"consultant_phone": audio.get("consultantPhone", ""),
|
||||
"conversation": conversation_messages,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取录音ASR结果失败: audio_id={audio_id}, error={e}")
|
||||
continue
|
||||
|
||||
logger.info(f"成功获取{len(conversations)}条对话记录")
|
||||
return conversations
|
||||
|
||||
async def get_audio_list(self, phone: str) -> List[Dict]:
|
||||
"""
|
||||
获取员工的录音列表(模拟)
|
||||
|
||||
注意:言迹API暂时没有提供通过手机号直接查询录音列表的接口
|
||||
这里使用模拟数据,返回假想的录音列表
|
||||
|
||||
Args:
|
||||
phone: 员工手机号
|
||||
|
||||
Returns:
|
||||
录音信息列表
|
||||
"""
|
||||
logger.info(f"获取员工录音列表(模拟): phone={phone}")
|
||||
|
||||
# 模拟返回10条录音记录
|
||||
mock_audios = []
|
||||
base_time = datetime.now()
|
||||
|
||||
# 模拟不同时长的录音(单位:毫秒)
durations = [25000, 45000, 180000, 240000, 120000, 90000, 60000, 300000, 420000, 150000]
for i in range(10):
|
||||
|
||||
mock_audios.append({
|
||||
"id": f"mock_audio_{i+1}",
|
||||
"externalVisitId": f"visit_{i+1}",
|
||||
"startTime": (base_time - timedelta(days=i)).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"duration": durations[i], # 毫秒
|
||||
"consultantName": "模拟员工",
|
||||
"consultantPhone": phone
|
||||
})
|
||||
|
||||
return mock_audios
|
||||
|
||||
async def get_employee_conversations_for_analysis(
|
||||
self,
|
||||
phone: str,
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取员工最近N条录音的模拟对话数据(用于能力分析)
|
||||
|
||||
Args:
|
||||
phone: 员工手机号
|
||||
limit: 获取数量,默认10条
|
||||
|
||||
Returns:
|
||||
对话数据列表,格式:
|
||||
[{
|
||||
"audio_id": "mock_audio_1",
|
||||
"duration_seconds": 25,
|
||||
"start_time": "2025-10-15 10:30:00",
|
||||
"dialogue_history": [
|
||||
{"speaker": "consultant", "content": "您好..."},
|
||||
{"speaker": "customer", "content": "你好..."}
|
||||
]
|
||||
}]
|
||||
"""
|
||||
# 1. 获取录音列表
|
||||
audios = await self.get_audio_list(phone)
|
||||
|
||||
if not audios:
|
||||
logger.warning(f"未找到员工的录音记录: phone={phone}")
|
||||
return []
|
||||
|
||||
# 2. 筛选前limit条
|
||||
selected_audios = audios[:limit]
|
||||
|
||||
# 3. 为每条录音生成模拟对话
|
||||
conversations = []
|
||||
for audio in selected_audios:
|
||||
conversation = self._generate_mock_conversation(audio)
|
||||
conversations.append(conversation)
|
||||
|
||||
logger.info(f"生成模拟对话数据: phone={phone}, count={len(conversations)}")
|
||||
return conversations
|
||||
|
||||
def _generate_mock_conversation(self, audio: Dict) -> Dict:
|
||||
"""
|
||||
为录音生成模拟对话数据
|
||||
|
||||
根据录音时长选择不同复杂度的对话模板:
|
||||
- <30秒: 短对话(4-6轮)
|
||||
- 30秒-5分钟: 中等对话(8-12轮)
|
||||
- >5分钟: 长对话(15-20轮,完整销售流程)
|
||||
|
||||
Args:
|
||||
audio: 录音信息字典
|
||||
|
||||
Returns:
|
||||
对话数据字典
|
||||
"""
|
||||
duration = int(audio.get('duration', 60000)) // 1000 # 转换为秒
|
||||
|
||||
# 根据时长选择对话模板
|
||||
if duration < 30:
|
||||
dialogue = self._short_conversation_template()
|
||||
elif duration < 300:
|
||||
dialogue = self._medium_conversation_template()
|
||||
else:
|
||||
dialogue = self._long_conversation_template()
|
||||
|
||||
return {
|
||||
"audio_id": audio.get('id'),
|
||||
"duration_seconds": duration,
|
||||
"start_time": audio.get('startTime'),
|
||||
"dialogue_history": dialogue
|
||||
}
|
||||
|
||||
def _short_conversation_template(self) -> List[Dict]:
|
||||
"""短对话模板(<30秒)- 4-6轮对话"""
|
||||
templates = [
|
||||
[
|
||||
{"speaker": "consultant", "content": "您好,欢迎光临曼尼斐绮,请问有什么可以帮到您?"},
|
||||
{"speaker": "customer", "content": "你好,我想了解一下面部护理项目"},
|
||||
{"speaker": "consultant", "content": "好的,我们有多种面部护理方案,请问您主要关注哪方面呢?"},
|
||||
{"speaker": "customer", "content": "主要是想改善皮肤暗沉"},
|
||||
{"speaker": "consultant", "content": "明白了,针对皮肤暗沉,我推荐我们的美白焕肤套餐"}
|
||||
],
|
||||
[
|
||||
{"speaker": "consultant", "content": "您好,请问需要什么帮助吗?"},
|
||||
{"speaker": "customer", "content": "我想咨询一下祛斑项目"},
|
||||
{"speaker": "consultant", "content": "好的,请问您主要是哪种类型的斑点呢?"},
|
||||
{"speaker": "customer", "content": "脸颊两侧有些黄褐斑"},
|
||||
{"speaker": "consultant", "content": "了解,我们有专门针对黄褐斑的光子嫩肤项目,效果很不错"}
|
||||
],
|
||||
[
|
||||
{"speaker": "consultant", "content": "欢迎光临,有什么可以帮您的吗?"},
|
||||
{"speaker": "customer", "content": "我想预约做个面部护理"},
|
||||
{"speaker": "consultant", "content": "好的,请问您之前做过我们的项目吗?"},
|
||||
{"speaker": "customer", "content": "没有,第一次来"},
|
||||
{"speaker": "consultant", "content": "那我建议您先做个免费的皮肤检测,帮您制定个性化方案"},
|
||||
{"speaker": "customer", "content": "好的,那现在可以吗?"}
|
||||
]
|
||||
]
|
||||
return random.choice(templates)
|
||||
|
||||
def _medium_conversation_template(self) -> List[Dict]:
|
||||
"""中等对话模板(30秒-5分钟)- 8-12轮对话"""
|
||||
templates = [
|
||||
[
|
||||
{"speaker": "consultant", "content": "您好,欢迎光临曼尼斐绮,我是美容顾问小王,请问怎么称呼您?"},
|
||||
{"speaker": "customer", "content": "你好,我姓李"},
|
||||
{"speaker": "consultant", "content": "李女士您好,请问今天是第一次了解我们的项目吗?"},
|
||||
{"speaker": "customer", "content": "是的,之前在网上看到你们的介绍"},
|
||||
{"speaker": "consultant", "content": "好的,您对哪方面的美容项目比较感兴趣呢?"},
|
||||
{"speaker": "customer", "content": "我想改善面部松弛的问题,最近感觉皮肤没有以前紧致了"},
|
||||
{"speaker": "consultant", "content": "我理解您的困扰。请问您多大年龄?平时有做面部护理吗?"},
|
||||
{"speaker": "customer", "content": "我35岁,平时就是用护肤品,没做过专业护理"},
|
||||
{"speaker": "consultant", "content": "明白了。35岁开始注重抗衰是很及时的。我们有几种方案,比如射频紧肤、超声刀提拉,还有胶原蛋白再生项目"},
|
||||
{"speaker": "customer", "content": "这几种有什么区别吗?"},
|
||||
{"speaker": "consultant", "content": "射频主要是刺激胶原蛋白增生,效果温和持久。超声刀作用更深层,提拉效果更明显但价格稍高。我建议您先做个皮肤检测,看具体适合哪种"},
|
||||
{"speaker": "customer", "content": "好的,那先做个检测吧"}
|
||||
],
|
||||
[
|
||||
{"speaker": "consultant", "content": "您好,欢迎光临,我是美容顾问晓雯,请问您是第一次来吗?"},
|
||||
{"speaker": "customer", "content": "是的,朋友推荐过来看看"},
|
||||
{"speaker": "consultant", "content": "太好了,请问您朋友是做的什么项目呢?"},
|
||||
{"speaker": "customer", "content": "她做的好像是什么水光针"},
|
||||
{"speaker": "consultant", "content": "水光针确实是我们很受欢迎的项目。请问您今天主要想了解哪方面呢?"},
|
||||
{"speaker": "customer", "content": "我主要是皮肤有点粗糙,毛孔也大"},
|
||||
{"speaker": "consultant", "content": "嗯,针对毛孔粗大和皮肤粗糙,水光针确实有不错的效果。不过我建议先看看您的具体情况"},
|
||||
{"speaker": "customer", "content": "需要检查吗?"},
|
||||
{"speaker": "consultant", "content": "是的,我们有专业的皮肤检测仪,可以看到肉眼看不到的皮肤问题,这样制定方案更精准"},
|
||||
{"speaker": "customer", "content": "好的,那检查一下吧"},
|
||||
{"speaker": "consultant", "content": "好的,请这边来,检查大概需要5分钟"}
|
||||
]
|
||||
]
|
||||
return random.choice(templates)
|
||||
|
||||
def _long_conversation_template(self) -> List[Dict]:
|
||||
"""长对话模板(>5分钟)- 15-20轮对话,完整销售流程"""
|
||||
templates = [
|
||||
[
|
||||
{"speaker": "consultant", "content": "您好,欢迎光临曼尼斐绮,我是资深美容顾问晓雯,请问怎么称呼您?"},
|
||||
{"speaker": "customer", "content": "你好,我姓陈"},
|
||||
{"speaker": "consultant", "content": "陈女士您好,看您气色很好,平时应该很注重保养吧?"},
|
||||
{"speaker": "customer", "content": "还好吧,基本的护肤品会用"},
|
||||
{"speaker": "consultant", "content": "这样啊。那今天是专程过来了解我们的项目,还是朋友推荐的呢?"},
|
||||
{"speaker": "customer", "content": "我闺蜜在你们这做过,说效果不错,所以想来看看"},
|
||||
{"speaker": "consultant", "content": "太好了,请问您闺蜜做的是什么项目呢?"},
|
||||
{"speaker": "customer", "content": "好像是什么光子嫩肤"},
|
||||
{"speaker": "consultant", "content": "明白了,光子嫩肤确实是我们的明星项目。不过每个人的皮肤状况不同,我先帮您做个详细的皮肤检测,看看最适合您的方案好吗?"},
|
||||
{"speaker": "customer", "content": "好的"},
|
||||
{"speaker": "consultant", "content": "陈女士,通过检测我看到您的皮肤主要有三个问题:一是T区毛孔粗大,二是两颊有轻微色斑,三是皮肤缺水。您平时有感觉到这些问题吗?"},
|
||||
{"speaker": "customer", "content": "对,毛孔确实有点大,色斑是最近才发现的"},
|
||||
{"speaker": "consultant", "content": "嗯,这些问题如果不及时处理会越来越明显。针对您的情况,我建议做一个综合性的美白嫩肤方案"},
|
||||
{"speaker": "customer", "content": "具体是怎么做的?"},
|
||||
{"speaker": "consultant", "content": "我们采用光子嫩肤配合水光针的组合疗程。光子嫩肤主要解决色斑和毛孔问题,水光针补水锁水,效果相辅相成"},
|
||||
{"speaker": "customer", "content": "听起来不错,大概需要多少钱?"},
|
||||
{"speaker": "consultant", "content": "我们现在正好有活动,光子嫩肤单次原价3800,水光针单次2600,组合套餐优惠后只要5800,相当于打了九折"},
|
||||
{"speaker": "customer", "content": "嗯...还是有点贵"},
|
||||
{"speaker": "consultant", "content": "我理解您的顾虑。但是陈女士,您想想,这个价格是一次性投入,效果却能维持3-6个月。平均下来每天不到30块钱,换来的是皮肤的明显改善"},
|
||||
{"speaker": "customer", "content": "这倒也是..."},
|
||||
{"speaker": "consultant", "content": "而且这个活动就到本月底,下个月恢复原价的话就要6400了。您今天如果确定的话,我还可以帮您申请赠送一次基础补水护理"},
|
||||
{"speaker": "customer", "content": "那行吧,今天就定了"},
|
||||
{"speaker": "consultant", "content": "太好了!陈女士您做了个很明智的决定。我现在帮您预约最近的时间,您看周三下午方便吗?"}
|
||||
],
|
||||
[
|
||||
{"speaker": "consultant", "content": "您好,欢迎光临,我是美容顾问小张,请问您贵姓?"},
|
||||
{"speaker": "customer", "content": "我姓王"},
|
||||
{"speaker": "consultant", "content": "王女士您好,请坐。今天想了解什么项目呢?"},
|
||||
{"speaker": "customer", "content": "我想做个面部提升,感觉脸有点下垂了"},
|
||||
{"speaker": "consultant", "content": "嗯,我看得出来您平时很注重保养。请问您今年多大年龄?"},
|
||||
{"speaker": "customer", "content": "我42了"},
|
||||
{"speaker": "consultant", "content": "42岁这个年龄段,确实容易出现轻微松弛。您之前有做过抗衰项目吗?"},
|
||||
{"speaker": "customer", "content": "做过几次普通的面部护理,但感觉效果不明显"},
|
||||
{"speaker": "consultant", "content": "普通护理主要是表层保养,对于松弛问题作用有限。您的情况需要更深层的治疗"},
|
||||
{"speaker": "customer", "content": "那有什么好的方案吗?"},
|
||||
{"speaker": "consultant", "content": "针对您的情况,我推荐热玛吉或者超声刀。这两种都是通过热能刺激深层胶原蛋白重组,达到紧致提升的效果"},
|
||||
{"speaker": "customer", "content": "这两种有什么区别?"},
|
||||
{"speaker": "consultant", "content": "热玛吉作用在真皮层,效果更自然持久,适合轻中度松弛。超声刀能到达筋膜层,提拉力度更强,适合松弛比较明显的情况"},
|
||||
{"speaker": "customer", "content": "我的情况适合哪种?"},
|
||||
{"speaker": "consultant", "content": "从您的面部状况来看,我建议选择热玛吉。您的松弛程度属于轻度,热玛吉的效果会更自然,恢复期也更短"},
|
||||
{"speaker": "customer", "content": "费用大概多少?"},
|
||||
{"speaker": "consultant", "content": "热玛吉全脸的话,我们的价格是28800元。不过您今天来的时机很好,我们正在做周年庆活动,可以优惠到23800"},
|
||||
{"speaker": "customer", "content": "还是挺贵的啊"},
|
||||
{"speaker": "consultant", "content": "王女士,我理解您的感受。但是热玛吉一次治疗效果可以维持2-3年,平均每天只要20多块钱。而且这是一次性投入,不需要反复做"},
|
||||
{"speaker": "customer", "content": "效果真的能维持那么久吗?"},
|
||||
{"speaker": "consultant", "content": "这是有科学依据的。热玛吉刺激的是您自身的胶原蛋白再生,不是外来填充,所以效果持久自然。我们有很多客户都做过,反馈都很好"},
|
||||
{"speaker": "customer", "content": "那我考虑一下吧"},
|
||||
{"speaker": "consultant", "content": "可以的。不过这个活动优惠就到本周日,下周就恢复原价了。而且名额有限,您要是确定的话最好尽快预约"},
|
||||
{"speaker": "customer", "content": "好吧,那我今天就定下来吧"}
|
||||
]
|
||||
]
|
||||
return random.choice(templates)
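最后给出一个取用这些模拟对话的调用示意(demo_yanji_mock 为假设的函数名,手机号为示例值;实例化 YanjiService 需已配置 YANJI_* 相关环境变量):

async def demo_yanji_mock(phone: str = "13800138000"):
    service = YanjiService()
    # 取最近若干条模拟录音并生成对话数据
    conversations = await service.get_employee_conversations_for_analysis(phone, limit=3)
    for conv in conversations:
        print(conv["audio_id"], conv["duration_seconds"], len(conv["dialogue_history"]))
    return conversations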