feat: 初始化考培练系统项目

- 从服务器拉取完整代码
- 按框架规范整理项目结构
- 配置 Drone CI 测试环境部署
- 包含后端(FastAPI)、前端(Vue3)、管理端

技术栈: Vue3 + TypeScript + FastAPI + MySQL
This commit is contained in:
111
2026-01-24 19:33:28 +08:00
commit 998211c483
1197 changed files with 228429 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""业务逻辑服务包"""

View File

@@ -0,0 +1,272 @@
"""
能力评估服务
用于分析用户对话数据,生成能力评估报告和课程推荐
使用 Python 原生实现
"""
import json
import logging
from typing import Dict, Any, List, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.ability import AbilityAssessment
from app.services.ai import ability_analysis_service
logger = logging.getLogger(__name__)
class AbilityAssessmentService:
    """Ability-assessment service.

    Pulls an employee's recorded conversations from the Yanji smart-badge
    service, runs the native Python analysis pipeline over them, persists the
    outcome as an ``AbilityAssessment`` row, and exposes read APIs for the
    assessment history and single-record detail.
    """

    async def analyze_yanji_conversations(
        self,
        user_id: int,
        phone: str,
        db: AsyncSession,
        yanji_service,
        engine: Literal["v2"] = "v2"
    ) -> Dict[str, Any]:
        """Analyze Yanji conversations and produce an ability assessment
        with course recommendations.

        Args:
            user_id: User ID.
            phone: User phone number (used to fetch Yanji recordings).
            db: Database session.
            yanji_service: Yanji service instance.
            engine: Engine type; "v2" = native Python implementation
                (currently the only accepted value).

        Returns:
            Assessment result dict containing:
                - assessment_id: ID of the persisted assessment record
                - total_score: overall score
                - dimensions: list of ability dimensions
                - recommended_courses: list of recommended courses
                - conversation_count: number of conversations analyzed

        Raises:
            ValueError: No recordings were found for this employee.
            Exception: The AI analysis failed or another error occurred.
        """
        logger.info(f"开始分析言迹对话: user_id={user_id}, phone={phone}, engine={engine}")

        # 1. Fetch the employee's conversation data (at most 10 recordings).
        conversations = await yanji_service.get_employee_conversations_for_analysis(
            phone=phone,
            limit=10
        )
        if not conversations:
            logger.warning(f"未找到员工的录音记录: user_id={user_id}, phone={phone}")
            raise ValueError("未找到该员工的录音记录")

        # 2. Merge the dialogue history of every conversation into one list.
        all_dialogues = []
        for conv in conversations:
            all_dialogues.extend(conv['dialogue_history'])
        logger.info(
            f"准备分析: user_id={user_id}, "
            f"对话数={len(conversations)}, "
            f"总轮次={len(all_dialogues)}"
        )

        used_engine = "v2"
        # 3. Run the native Python analysis implementation.
        logger.info(f"调用原生能力分析服务")
        # Flatten the dialogue history into prompt-ready text.
        dialogue_text = self._format_dialogues_for_analysis(all_dialogues)
        # Call the native analysis service.
        result = await ability_analysis_service.analyze(
            db=db,
            user_id=user_id,
            dialogue_history=dialogue_text
        )
        if not result.success:
            raise Exception(f"能力分析失败: {result.error}")

        # Convert the dataclass result into the legacy dict layout that the
        # extraction step below expects.
        analysis_result = {
            "analysis": {
                "total_score": result.total_score,
                "ability_dimensions": [
                    {"name": d.name, "score": d.score, "feedback": d.feedback}
                    for d in result.ability_dimensions
                ],
                "course_recommendations": [
                    {
                        "course_id": c.course_id,
                        "course_name": c.course_name,
                        "recommendation_reason": c.recommendation_reason,
                        "priority": c.priority,
                        "match_score": c.match_score,
                    }
                    for c in result.course_recommendations
                ]
            }
        }
        logger.info(
            f"能力分析完成 - total_score: {result.total_score}, "
            f"provider: {result.ai_provider}, latency: {result.ai_latency_ms}ms"
        )

        # 4. Extract the pieces we need from the legacy-format result.
        analysis = analysis_result.get('analysis', {})
        ability_dims = analysis.get('ability_dimensions', [])
        course_recs = analysis.get('course_recommendations', [])
        total_score = analysis.get('total_score')
        logger.info(
            f"分析完成 (engine={used_engine}): total_score={total_score}, "
            f"dimensions={len(ability_dims)}, courses={len(course_recs)}"
        )

        # 5. Persist the assessment record.
        assessment = AbilityAssessment(
            user_id=user_id,
            source_type='yanji_badge',
            source_id=','.join([str(c['audio_id']) for c in conversations]),
            total_score=total_score,
            ability_dimensions=ability_dims,
            recommended_courses=course_recs,
            conversation_count=len(conversations)
        )
        db.add(assessment)
        await db.commit()
        # Refresh so DB-generated fields (id, analyzed_at) are populated.
        await db.refresh(assessment)
        logger.info(
            f"评估记录已保存: assessment_id={assessment.id}, "
            f"user_id={user_id}, total_score={total_score}"
        )

        # 6. Return the assessment summary.
        return {
            "assessment_id": assessment.id,
            "total_score": total_score,
            "dimensions": ability_dims,
            "recommended_courses": course_recs,
            "conversation_count": len(conversations),
            "analyzed_at": assessment.analyzed_at,
            "engine": used_engine,
        }

    def _format_dialogues_for_analysis(self, dialogues: List[Dict[str, Any]]) -> str:
        """Format a dialogue-history list as numbered plain text.

        Args:
            dialogues: Dialogue entries; each item carries ``speaker`` and
                ``content`` fields.

        Returns:
            Formatted dialogue text, one "[i] speaker: content" line per turn.
        """
        lines = []
        for i, d in enumerate(dialogues, 1):
            speaker = d.get('speaker', 'unknown')
            content = d.get('content', '')
            # Normalize the various speaker identifiers to a common label.
            if speaker in ['consultant', 'employee', 'user', '员工']:
                speaker_label = '员工'
            elif speaker in ['customer', 'client', '顾客', '客户']:
                speaker_label = '顾客'
            else:
                speaker_label = speaker
            lines.append(f"[{i}] {speaker_label}: {content}")
        return '\n'.join(lines)

    async def get_user_assessment_history(
        self,
        user_id: int,
        db: AsyncSession,
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Fetch the user's ability-assessment history, newest first.

        Args:
            user_id: User ID.
            db: Database session.
            limit: Maximum number of records to return.

        Returns:
            List of assessment-history record dicts.
        """
        stmt = (
            select(AbilityAssessment)
            .where(AbilityAssessment.user_id == user_id)
            .order_by(AbilityAssessment.analyzed_at.desc())
            .limit(limit)
        )
        result = await db.execute(stmt)
        assessments = result.scalars().all()
        history = []
        for assessment in assessments:
            history.append({
                "id": assessment.id,
                "source_type": assessment.source_type,
                "total_score": assessment.total_score,
                "ability_dimensions": assessment.ability_dimensions,
                "recommended_courses": assessment.recommended_courses,
                "conversation_count": assessment.conversation_count,
                "analyzed_at": assessment.analyzed_at.isoformat() if assessment.analyzed_at else None,
                "created_at": assessment.created_at.isoformat() if assessment.created_at else None
            })
        logger.info(f"获取评估历史: user_id={user_id}, count={len(history)}")
        return history

    async def get_assessment_detail(
        self,
        assessment_id: int,
        db: AsyncSession
    ) -> Dict[str, Any]:
        """Fetch the full details of a single assessment record.

        Args:
            assessment_id: Assessment record ID.
            db: Database session.

        Returns:
            Detailed assessment info dict.

        Raises:
            ValueError: The assessment record does not exist.
        """
        stmt = select(AbilityAssessment).where(AbilityAssessment.id == assessment_id)
        result = await db.execute(stmt)
        assessment = result.scalar_one_or_none()
        if not assessment:
            raise ValueError(f"评估记录不存在: assessment_id={assessment_id}")
        return {
            "id": assessment.id,
            "user_id": assessment.user_id,
            "source_type": assessment.source_type,
            "source_id": assessment.source_id,
            "total_score": assessment.total_score,
            "ability_dimensions": assessment.ability_dimensions,
            "recommended_courses": assessment.recommended_courses,
            "conversation_count": assessment.conversation_count,
            "analyzed_at": assessment.analyzed_at.isoformat() if assessment.analyzed_at else None,
            "created_at": assessment.created_at.isoformat() if assessment.created_at else None
        }
def get_ability_assessment_service() -> AbilityAssessmentService:
    """Build a fresh AbilityAssessmentService (FastAPI dependency factory)."""
    service = AbilityAssessmentService()
    return service

View File

@@ -0,0 +1,151 @@
"""
AI 服务模块
包含:
- AIService: 本地 AI 服务(支持 4sapi + OpenRouter 降级)
- LLM JSON Parser: 大模型 JSON 输出解析器
- KnowledgeAnalysisServiceV2: 知识点分析服务Python 原生实现)
- ExamGeneratorService: 试题生成服务Python 原生实现)
- CourseChatServiceV2: 课程对话服务Python 原生实现)
- PracticeSceneService: 陪练场景准备服务Python 原生实现)
- AbilityAnalysisService: 智能工牌能力分析服务Python 原生实现)
- AnswerJudgeService: 答案判断服务Python 原生实现)
- PracticeAnalysisService: 陪练分析报告服务Python 原生实现)
"""
from .ai_service import (
AIService,
AIResponse,
AIConfig,
AIServiceError,
AIProvider,
DEFAULT_MODEL,
MODEL_ANALYSIS,
MODEL_CREATIVE,
MODEL_IMAGE_GEN,
quick_chat,
)
from .llm_json_parser import (
parse_llm_json,
parse_with_fallback,
safe_json_loads,
clean_llm_output,
diagnose_json_error,
validate_json_schema,
ParseResult,
JSONParseError,
JSONUnrecoverableError,
)
from .knowledge_analysis_v2 import (
KnowledgeAnalysisServiceV2,
knowledge_analysis_service_v2,
)
from .exam_generator_service import (
ExamGeneratorService,
ExamGeneratorConfig,
exam_generator_service,
generate_exam,
)
from .course_chat_service import (
CourseChatServiceV2,
course_chat_service_v2,
)
from .practice_scene_service import (
PracticeSceneService,
PracticeScene,
PracticeSceneResult,
practice_scene_service,
prepare_practice_knowledge,
)
from .ability_analysis_service import (
AbilityAnalysisService,
AbilityAnalysisResult,
AbilityDimension,
CourseRecommendation,
ability_analysis_service,
)
from .answer_judge_service import (
AnswerJudgeService,
JudgeResult,
answer_judge_service,
judge_answer,
)
from .practice_analysis_service import (
PracticeAnalysisService,
PracticeAnalysisResult,
ScoreBreakdownItem,
AbilityDimensionItem,
DialogueAnnotation,
Suggestion,
practice_analysis_service,
analyze_practice_session,
)
# Public API of the AI services package; keep this list in sync with the
# imports above so `from app.services.ai import *` stays well-defined.
__all__ = [
    # AI Service
    "AIService",
    "AIResponse",
    "AIConfig",
    "AIServiceError",
    "AIProvider",
    "DEFAULT_MODEL",
    "MODEL_ANALYSIS",
    "MODEL_CREATIVE",
    "MODEL_IMAGE_GEN",
    "quick_chat",
    # JSON Parser
    "parse_llm_json",
    "parse_with_fallback",
    "safe_json_loads",
    "clean_llm_output",
    "diagnose_json_error",
    "validate_json_schema",
    "ParseResult",
    "JSONParseError",
    "JSONUnrecoverableError",
    # Knowledge Analysis V2
    "KnowledgeAnalysisServiceV2",
    "knowledge_analysis_service_v2",
    # Exam Generator V2
    "ExamGeneratorService",
    "ExamGeneratorConfig",
    "exam_generator_service",
    "generate_exam",
    # Course Chat V2
    "CourseChatServiceV2",
    "course_chat_service_v2",
    # Practice Scene V2
    "PracticeSceneService",
    "PracticeScene",
    "PracticeSceneResult",
    "practice_scene_service",
    "prepare_practice_knowledge",
    # Ability Analysis V2
    "AbilityAnalysisService",
    "AbilityAnalysisResult",
    "AbilityDimension",
    "CourseRecommendation",
    "ability_analysis_service",
    # Answer Judge V2
    "AnswerJudgeService",
    "JudgeResult",
    "answer_judge_service",
    "judge_answer",
    # Practice Analysis V2
    "PracticeAnalysisService",
    "PracticeAnalysisResult",
    "ScoreBreakdownItem",
    "AbilityDimensionItem",
    "DialogueAnnotation",
    "Suggestion",
    "practice_analysis_service",
    "analyze_practice_session",
]

View File

@@ -0,0 +1,479 @@
"""
智能工牌能力分析与课程推荐服务 - Python 原生实现
功能:
- 分析员工与顾客的对话记录
- 评估多维度能力得分
- 基于能力短板推荐课程
提供稳定可靠的能力分析和课程推荐能力。
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import ExternalServiceError
from .ai_service import AIService, AIResponse
from .llm_json_parser import parse_with_fallback, clean_llm_output
from .prompts.ability_analysis_prompts import (
SYSTEM_PROMPT,
USER_PROMPT,
ABILITY_ANALYSIS_SCHEMA,
ABILITY_DIMENSIONS,
)
logger = logging.getLogger(__name__)
# ==================== 数据结构 ====================
@dataclass
class AbilityDimension:
    """Score for a single ability dimension."""
    name: str      # dimension name (e.g. a named skill area)
    score: float   # score assigned by the AI for this dimension
    feedback: str  # textual feedback explaining the score
@dataclass
class CourseRecommendation:
    """A single recommended course for the analyzed user."""
    course_id: int              # ID of the recommended course
    course_name: str            # course display name
    recommendation_reason: str  # why this course was recommended
    priority: str               # one of: high, medium, low
    match_score: float          # how well the course matches the user's gaps
@dataclass
class AbilityAnalysisResult:
    """Outcome of one ability-analysis run, including AI call metadata."""
    success: bool
    total_score: float = 0.0
    ability_dimensions: List[AbilityDimension] = field(default_factory=list)
    course_recommendations: List[CourseRecommendation] = field(default_factory=list)
    ai_provider: str = ""
    ai_model: str = ""
    ai_tokens: int = 0
    ai_latency_ms: int = 0
    error: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this result (nested dataclasses included) into a plain dict."""
        dimension_items = []
        for dim in self.ability_dimensions:
            dimension_items.append(
                {"name": dim.name, "score": dim.score, "feedback": dim.feedback}
            )
        recommendation_items = []
        for rec in self.course_recommendations:
            recommendation_items.append({
                "course_id": rec.course_id,
                "course_name": rec.course_name,
                "recommendation_reason": rec.recommendation_reason,
                "priority": rec.priority,
                "match_score": rec.match_score,
            })
        return {
            "success": self.success,
            "total_score": self.total_score,
            "ability_dimensions": dimension_items,
            "course_recommendations": recommendation_items,
            "ai_provider": self.ai_provider,
            "ai_model": self.ai_model,
            "ai_tokens": self.ai_tokens,
            "ai_latency_ms": self.ai_latency_ms,
            "error": self.error,
        }
@dataclass
class UserPositionInfo:
    """A user's position (job role) record as read from the positions table."""
    position_id: int                 # primary key of the position row
    position_name: str               # position display name
    code: str                        # position code ("" when NULL)
    description: str                 # position description ("" when NULL)
    skills: Optional[Dict[str, Any]] # decoded skills JSON; None when absent/invalid
    level: str                       # position level ("" when NULL)
    status: str                      # position status ("" when NULL)
@dataclass
class CourseInfo:
    """A published course record as read from the courses table."""
    id: int                    # course primary key
    name: str                  # course name
    description: str           # course description ("" when NULL)
    category: str              # course category ("" when NULL)
    tags: Optional[List[str]]  # decoded tags JSON; None when absent/invalid
    difficulty_level: int      # difficulty level (defaults to 3 when NULL)
    duration_hours: float      # duration in hours (0 when NULL)
# ==================== 服务类 ====================
class AbilityAnalysisService:
    """
    Smart-badge ability-analysis service (native Python implementation).

    Analyzes employee/customer conversation transcripts, scores the employee
    across several ability dimensions, and recommends courses that address
    the weakest dimensions.

    Usage example:
    ```python
    service = AbilityAnalysisService()
    result = await service.analyze(
        db=db_session,
        user_id=1,
        dialogue_history="顾客:你好,我想了解一下你们的服务..."
    )
    print(result.total_score)
    print(result.course_recommendations)
    ```
    """

    def __init__(self):
        """Initialize the service with a module-scoped AI client."""
        self.ai_service = AIService(module_code="ability_analysis")

    async def analyze(
        self,
        db: AsyncSession,
        user_id: int,
        dialogue_history: str
    ) -> AbilityAnalysisResult:
        """Analyze an employee's ability and recommend courses.

        Args:
            db: Database session (multi-tenant: each tenant passes its own).
            user_id: User ID.
            dialogue_history: Conversation transcript text.

        Returns:
            AbilityAnalysisResult. Note this method never raises: every
            failure is captured and returned with ``success=False`` and
            the error message in ``error``.
        """
        try:
            logger.info(f"开始能力分析 - user_id: {user_id}")
            # 1. Validate input.
            if not dialogue_history or not dialogue_history.strip():
                return AbilityAnalysisResult(
                    success=False,
                    error="对话记录不能为空"
                )
            # 2. Load the user's position info for prompt context.
            user_positions = await self._get_user_positions(db, user_id)
            user_info_str = self._format_user_info(user_positions)
            logger.info(f"用户岗位信息: {len(user_positions)} 个岗位")
            # 3. Load every published course as the recommendation catalog.
            courses = await self._get_published_courses(db)
            courses_str = self._format_courses(courses)
            logger.info(f"可选课程: {len(courses)}")
            # 4. Run the AI analysis call.
            ai_response = await self._call_ai_analysis(
                dialogue_history=dialogue_history,
                user_info=user_info_str,
                courses=courses_str
            )
            logger.info(
                f"AI 分析完成 - provider: {ai_response.provider}, "
                f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
            )
            # 5. Parse the JSON payload out of the model output.
            analysis_data = self._parse_analysis_result(ai_response.content, courses)
            # 6. Build the result object (defaults fill any missing fields).
            result = AbilityAnalysisResult(
                success=True,
                total_score=analysis_data.get("total_score", 0),
                ability_dimensions=[
                    AbilityDimension(
                        name=d.get("name", ""),
                        score=d.get("score", 0),
                        feedback=d.get("feedback", "")
                    )
                    for d in analysis_data.get("ability_dimensions", [])
                ],
                course_recommendations=[
                    CourseRecommendation(
                        course_id=c.get("course_id", 0),
                        course_name=c.get("course_name", ""),
                        recommendation_reason=c.get("recommendation_reason", ""),
                        priority=c.get("priority", "medium"),
                        match_score=c.get("match_score", 0)
                    )
                    for c in analysis_data.get("course_recommendations", [])
                ],
                ai_provider=ai_response.provider,
                ai_model=ai_response.model,
                ai_tokens=ai_response.total_tokens,
                ai_latency_ms=ai_response.latency_ms,
            )
            logger.info(
                f"能力分析完成 - user_id: {user_id}, total_score: {result.total_score}, "
                f"recommendations: {len(result.course_recommendations)}"
            )
            return result
        except Exception as e:
            # Deliberate catch-all: callers get a failed result, never an exception.
            logger.error(
                f"能力分析失败 - user_id: {user_id}, error: {e}",
                exc_info=True
            )
            return AbilityAnalysisResult(
                success=False,
                error=str(e)
            )

    async def _get_user_positions(
        self,
        db: AsyncSession,
        user_id: int
    ) -> List[UserPositionInfo]:
        """Query the positions assigned to the user (soft-deleted rows excluded)."""
        query = text("""
            SELECT
                p.id as position_id,
                p.name as position_name,
                p.code,
                p.description,
                p.skills,
                p.level,
                p.status
            FROM positions p
            INNER JOIN position_members pm ON p.id = pm.position_id
            WHERE pm.user_id = :user_id
            AND pm.is_deleted = 0
            AND p.is_deleted = 0
        """)
        result = await db.execute(query, {"user_id": user_id})
        rows = result.fetchall()
        positions = []
        for row in rows:
            # skills may arrive as a JSON string or an already-decoded value
            # depending on the MySQL driver; tolerate both and invalid JSON.
            skills = None
            if row.skills:
                if isinstance(row.skills, str):
                    try:
                        skills = json.loads(row.skills)
                    except json.JSONDecodeError:
                        skills = None
                else:
                    skills = row.skills
            positions.append(UserPositionInfo(
                position_id=row.position_id,
                position_name=row.position_name,
                code=row.code or "",
                description=row.description or "",
                skills=skills,
                level=row.level or "",
                status=row.status or ""
            ))
        return positions

    async def _get_published_courses(self, db: AsyncSession) -> List[CourseInfo]:
        """Query every published, non-deleted course, ordered by sort_order."""
        query = text("""
            SELECT
                id,
                name,
                description,
                category,
                tags,
                difficulty_level,
                duration_hours
            FROM courses
            WHERE status = 'published'
            AND is_deleted = FALSE
            ORDER BY sort_order
        """)
        result = await db.execute(query)
        rows = result.fetchall()
        courses = []
        for row in rows:
            # tags may arrive as a JSON string or an already-decoded value;
            # tolerate both and fall back to None on invalid JSON.
            tags = None
            if row.tags:
                if isinstance(row.tags, str):
                    try:
                        tags = json.loads(row.tags)
                    except json.JSONDecodeError:
                        tags = None
                else:
                    tags = row.tags
            courses.append(CourseInfo(
                id=row.id,
                name=row.name,
                description=row.description or "",
                category=row.category or "",
                tags=tags,
                difficulty_level=row.difficulty_level or 3,
                duration_hours=row.duration_hours or 0
            ))
        return courses

    def _format_user_info(self, positions: List[UserPositionInfo]) -> str:
        """Format the user's positions as prompt-ready text."""
        if not positions:
            return "暂无岗位信息"
        lines = []
        for p in positions:
            info = f"- 岗位:{p.position_name}{p.code}"
            if p.level:
                info += f",级别:{p.level}"
            if p.description:
                info += f"\n 描述:{p.description}"
            if p.skills:
                skills_str = json.dumps(p.skills, ensure_ascii=False)
                info += f"\n 核心技能:{skills_str}"
            lines.append(info)
        return "\n".join(lines)

    def _format_courses(self, courses: List[CourseInfo]) -> str:
        """Format the course catalog as prompt-ready text."""
        if not courses:
            return "暂无可选课程"
        lines = []
        for c in courses:
            info = f"- ID: {c.id}, 课程名称: {c.name}"
            if c.category:
                info += f", 分类: {c.category}"
            if c.difficulty_level:
                info += f", 难度: {c.difficulty_level}"
            if c.duration_hours:
                info += f", 时长: {c.duration_hours}小时"
            if c.description:
                # Truncate overly long descriptions to keep the prompt small.
                desc = c.description[:100] + "..." if len(c.description) > 100 else c.description
                info += f"\n 描述: {desc}"
            lines.append(info)
        return "\n".join(lines)

    async def _call_ai_analysis(
        self,
        dialogue_history: str,
        user_info: str,
        courses: str
    ) -> AIResponse:
        """Run the AI ability-analysis chat call and return the raw response."""
        # Fill the user-prompt template.
        user_message = USER_PROMPT.format(
            dialogue_history=dialogue_history,
            user_info=user_info,
            courses=courses
        )
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_message}
        ]
        # Call the AI (AIService falls back 4sapi -> OpenRouter automatically).
        response = await self.ai_service.chat(
            messages=messages,
            temperature=0.7,  # keep some creativity in the feedback text
            prompt_name="ability_analysis"
        )
        return response

    def _parse_analysis_result(
        self,
        ai_output: str,
        courses: List[CourseInfo]
    ) -> Dict[str, Any]:
        """Parse the AI's JSON output with layered fallbacks.

        Uses the shared LLM JSON parser (clean -> parse -> schema-validate),
        then post-processes: recommendations pointing at unknown course IDs
        are dropped, and any expected ability dimension missing from the
        output is filled with a neutral default (score 70).
        """
        # Clean the raw model output first (strip code fences etc.).
        cleaned_output, rules = clean_llm_output(ai_output)
        if rules:
            logger.debug(f"AI 输出已清洗: {rules}")
        # Parse with schema validation and a safe default on failure.
        parsed = parse_with_fallback(
            cleaned_output,
            schema=ABILITY_ANALYSIS_SCHEMA,
            default={"analysis": {}},
            validate_schema=True,
            on_error="default"
        )
        # Extract the "analysis" section.
        analysis = parsed.get("analysis", {})
        # Post-process: keep only recommendations whose course ID exists.
        valid_course_ids = {c.id for c in courses}
        valid_recommendations = []
        for rec in analysis.get("course_recommendations", []):
            course_id = rec.get("course_id")
            if course_id in valid_course_ids:
                valid_recommendations.append(rec)
            else:
                logger.warning(f"推荐的课程ID不存在: {course_id}")
        analysis["course_recommendations"] = valid_recommendations
        # Ensure every expected ability dimension is present.
        existing_dims = {d.get("name") for d in analysis.get("ability_dimensions", [])}
        for dim_name in ABILITY_DIMENSIONS:
            if dim_name not in existing_dims:
                logger.warning(f"缺少能力维度: {dim_name},使用默认值")
                analysis.setdefault("ability_dimensions", []).append({
                    "name": dim_name,
                    "score": 70,
                    "feedback": "暂无具体评价"
                })
        return analysis
# ==================== Global instance ====================
# Module-level singleton used by callers (e.g. AbilityAssessmentService).
ability_analysis_service = AbilityAnalysisService()

View File

@@ -0,0 +1,747 @@
"""
本地 AI 服务 - 遵循瑞小美 AI 接入规范
功能:
- 支持 4sapi.com首选和 OpenRouter备选自动降级
- 统一的请求/响应格式
- 调用日志记录
"""
import json
import logging
import time
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from enum import Enum
import httpx
logger = logging.getLogger(__name__)
class AIProvider(Enum):
    """AI provider identifiers."""
    PRIMARY = "4sapi"        # primary: 4sapi.com
    FALLBACK = "openrouter"  # fallback: OpenRouter
@dataclass
class AIResponse:
    """Normalized AI response, identical across providers."""
    content: str                                             # AI reply content
    model: str = ""                                          # model actually used
    provider: str = ""                                       # provider actually used
    input_tokens: int = 0                                    # prompt token count
    output_tokens: int = 0                                   # completion token count
    total_tokens: int = 0                                    # total token count
    cost: float = 0.0                                        # cost in USD
    latency_ms: int = 0                                      # response latency (ms)
    raw_response: Dict[str, Any] = field(default_factory=dict)  # raw provider payload
    images: List[str] = field(default_factory=list)          # image-generation results
    annotations: Dict[str, Any] = field(default_factory=dict)   # PDF-parsing annotations
@dataclass
class AIConfig:
    """AI service configuration."""
    primary_api_key: str      # general key (Gemini/DeepSeek etc.)
    anthropic_api_key: str = ""  # Claude-specific key
    primary_base_url: str = "https://4sapi.com/v1"
    fallback_api_key: str = ""
    fallback_base_url: str = "https://openrouter.ai/api/v1"
    default_model: str = "claude-opus-4-5-20251101-thinking"  # strongest model by default
    timeout: float = 120.0    # per-request timeout in seconds
    max_retries: int = 2
# Claude model list (these require anthropic_api_key).
# NOTE(review): this constant is not referenced by is_claude_model below,
# which matches on the "claude"/"anthropic" substrings instead — confirm
# whether the list is still needed.
CLAUDE_MODELS = [
    "claude-opus-4-5-20251101-thinking",
    "claude-opus-4-5-20251101",
    "claude-sonnet-4-20250514",
    "claude-3-opus",
    "claude-3-sonnet",
    "claude-3-haiku",
]
def is_claude_model(model: str) -> bool:
    """Return True when *model* names a Claude/Anthropic model (case-insensitive)."""
    lowered = model.lower()
    if "claude" in lowered:
        return True
    return "anthropic" in lowered
# Model-name mapping: 4sapi -> OpenRouter.
# 4sapi uses short names; OpenRouter uses full "vendor/model" paths.
MODEL_MAPPING = {
    "gemini-3-flash-preview": "google/gemini-3-flash-preview",
    "gemini-3-pro-preview": "google/gemini-3-pro-preview",
    "claude-opus-4-5-20251101-thinking": "anthropic/claude-opus-4.5",
    "gemini-2.5-flash-image-preview": "google/gemini-2.0-flash-exp:free",
}
# Reverse mapping: OpenRouter -> 4sapi.
MODEL_MAPPING_REVERSE = {v: k for k, v in MODEL_MAPPING.items()}
class AIServiceError(Exception):
    """Raised when an AI provider call fails.

    Carries the failing provider name and HTTP status code (0 when the
    failure happened before a response was received).
    """

    def __init__(self, message: str, provider: str = "", status_code: int = 0):
        self.provider = provider
        self.status_code = status_code
        super().__init__(message)
class AIService:
"""
本地 AI 服务
遵循瑞小美 AI 接入规范:
- 首选 4sapi.com失败自动降级到 OpenRouter
- 统一的响应格式
- 自动模型名称转换
使用示例:
```python
ai = AIService(module_code="knowledge_analysis")
response = await ai.chat(
messages=[
{"role": "system", "content": "你是助手"},
{"role": "user", "content": "你好"}
],
prompt_name="greeting"
)
print(response.content)
```
"""
    def __init__(
        self,
        module_code: str = "default",
        config: Optional[AIConfig] = None,
        db_session: Any = None
    ):
        """Initialize the AI service.

        Config load priority (per the Ruixiaomei AI integration spec):
        1. an explicitly passed ``config``
        2. the admin DB ai_config tables (recommended)
        3. environment variables (fallback)

        Args:
            module_code: Module identifier, used for usage statistics.
            config: AI config; when None it is read from the DB / env vars.
            db_session: DB session, used for call logging and config reads.
        """
        self.module_code = module_code
        self.db_session = db_session
        self.config = config or self._load_config(db_session)
        logger.info(f"AIService 初始化: module={module_code}, primary={self.config.primary_base_url}")
    def _load_config(self, db_session: Any) -> AIConfig:
        """Load the AI configuration.

        Priority (per the Ruixiaomei AI integration spec):
        1. admin DB ``tenant_configs`` table (recommended, via DynamicConfig)
        2. environment variables (fallback)

        Args:
            db_session: Database session (optional; only used for logging).

        Returns:
            An AIConfig object.
        """
        # Try the admin DB first (synchronous access at construction time).
        try:
            config = self._load_config_from_admin_db()
            if config:
                logger.info("✅ AI 配置已从管理库tenant_configs加载")
                return config
        except Exception as e:
            logger.debug(f"从管理库加载 AI 配置失败: {e}")
        # Fall back to environment variables.
        logger.info("AI 配置从环境变量加载")
        return self._load_config_from_env()
    def _load_config_from_admin_db(self) -> Optional[AIConfig]:
        """Load AI config from the admin DB ``tenant_configs`` table.

        Queries the kaopeilian_admin.tenant_configs table synchronously
        using a throwaway engine (this runs during service construction).

        Returns:
            An AIConfig object, or None when no usable data is found.
        """
        import os
        # Resolve the current tenant code.
        tenant_code = os.getenv("TENANT_CODE", "demo")
        # Admin DB connection settings from the environment.
        admin_db_host = os.getenv("ADMIN_DB_HOST", "prod-mysql")
        admin_db_port = int(os.getenv("ADMIN_DB_PORT", "3306"))
        admin_db_user = os.getenv("ADMIN_DB_USER", "root")
        admin_db_password = os.getenv("ADMIN_DB_PASSWORD", "")
        admin_db_name = os.getenv("ADMIN_DB_NAME", "kaopeilian_admin")
        if not admin_db_password:
            logger.debug("ADMIN_DB_PASSWORD 未配置,跳过管理库配置加载")
            return None
        try:
            from sqlalchemy import create_engine, text
            import urllib.parse
            # Build the connection URL (password must be URL-escaped).
            encoded_password = urllib.parse.quote_plus(admin_db_password)
            admin_db_url = f"mysql+pymysql://{admin_db_user}:{encoded_password}@{admin_db_host}:{admin_db_port}/{admin_db_name}?charset=utf8mb4"
            engine = create_engine(admin_db_url, pool_pre_ping=True)
            with engine.connect() as conn:
                # 1. Resolve the active tenant's id.
                result = conn.execute(
                    text("SELECT id FROM tenants WHERE code = :code AND status = 'active'"),
                    {"code": tenant_code}
                )
                row = result.fetchone()
                if not row:
                    logger.debug(f"租户 {tenant_code} 不存在或未激活")
                    engine.dispose()
                    return None
                tenant_id = row[0]
                # 2. Fetch the tenant's AI config rows.
                result = conn.execute(
                    text("""
                        SELECT config_key, config_value
                        FROM tenant_configs
                        WHERE tenant_id = :tenant_id AND config_group = 'ai'
                    """),
                    {"tenant_id": tenant_id}
                )
                rows = result.fetchall()
                engine.dispose()
                if not rows:
                    logger.debug(f"租户 {tenant_code} 无 AI 配置")
                    return None
                # Convert the key/value rows to a dict.
                config_dict = {row[0]: row[1] for row in rows}
                # The primary API key is mandatory.
                primary_key = config_dict.get("AI_PRIMARY_API_KEY", "")
                if not primary_key:
                    logger.warning(f"租户 {tenant_code} 的 AI_PRIMARY_API_KEY 为空")
                    return None
                logger.info(f"✅ 从管理库加载租户 {tenant_code} 的 AI 配置成功")
                return AIConfig(
                    primary_api_key=primary_key,
                    anthropic_api_key=config_dict.get("AI_ANTHROPIC_API_KEY", ""),
                    primary_base_url=config_dict.get("AI_PRIMARY_BASE_URL", "https://4sapi.com/v1"),
                    fallback_api_key=config_dict.get("AI_FALLBACK_API_KEY", ""),
                    fallback_base_url=config_dict.get("AI_FALLBACK_BASE_URL", "https://openrouter.ai/api/v1"),
                    default_model=config_dict.get("AI_DEFAULT_MODEL", "claude-opus-4-5-20251101-thinking"),
                    timeout=float(config_dict.get("AI_TIMEOUT", "120")),
                )
        except Exception as e:
            # Best-effort: any DB error just falls back to env-var config.
            logger.debug(f"从管理库读取 AI 配置异常: {e}")
            return None
    def _load_config_from_env(self) -> AIConfig:
        """Load the AI configuration from environment variables.

        ⚠️ Hard requirements (per the Ruixiaomei AI integration spec):
        - API keys must never be hard-coded in source.
        - Keys must be supplied via environment variables.

        Required environment variables:
        - AI_PRIMARY_API_KEY: general key (Gemini/DeepSeek etc.)
        - AI_ANTHROPIC_API_KEY: Claude-specific key
        """
        import os
        primary_api_key = os.getenv("AI_PRIMARY_API_KEY", "")
        anthropic_api_key = os.getenv("AI_ANTHROPIC_API_KEY", "")
        # Warn (but do not fail) when required keys are missing.
        if not primary_api_key:
            logger.warning("⚠️ AI_PRIMARY_API_KEY 未配置AI 服务可能无法正常工作")
        if not anthropic_api_key:
            logger.warning("⚠️ AI_ANTHROPIC_API_KEY 未配置Claude 模型调用将失败")
        return AIConfig(
            # General key (Gemini/DeepSeek and other non-Anthropic models).
            primary_api_key=primary_api_key,
            # Claude-specific key.
            anthropic_api_key=anthropic_api_key,
            primary_base_url=os.getenv("AI_PRIMARY_BASE_URL", "https://4sapi.com/v1"),
            fallback_api_key=os.getenv("AI_FALLBACK_API_KEY", ""),
            fallback_base_url=os.getenv("AI_FALLBACK_BASE_URL", "https://openrouter.ai/api/v1"),
            # Default model: "strongest first" policy -> Claude Opus 4.5.
            default_model=os.getenv("AI_DEFAULT_MODEL", "claude-opus-4-5-20251101-thinking"),
            timeout=float(os.getenv("AI_TIMEOUT", "120")),
        )
def _convert_model_name(self, model: str, provider: AIProvider) -> str:
"""
转换模型名称以匹配服务商格式
Args:
model: 原始模型名称
provider: 目标服务商
Returns:
转换后的模型名称
"""
if provider == AIProvider.FALLBACK:
# 4sapi -> OpenRouter
return MODEL_MAPPING.get(model, f"google/{model}" if "/" not in model else model)
else:
# OpenRouter -> 4sapi
return MODEL_MAPPING_REVERSE.get(model, model.split("/")[-1] if "/" in model else model)
    async def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        prompt_name: str = "default",
        **kwargs
    ) -> AIResponse:
        """Text chat completion with automatic provider fallback.

        Args:
            messages: Message list [{"role": "system/user/assistant", "content": "..."}].
            model: Model name; None uses the configured default model.
            temperature: Sampling temperature.
            max_tokens: Maximum output token count.
            prompt_name: Prompt name, used for statistics.
            **kwargs: Extra parameters.
                NOTE(review): currently ignored — not merged into the
                request payload; confirm whether that is intentional.

        Returns:
            An AIResponse object.

        Raises:
            AIServiceError: Both providers failed, or the primary failed and
                no fallback API key is configured.
        """
        model = model or self.config.default_model
        # Build the request body.
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens:
            payload["max_tokens"] = max_tokens
        # Try the primary provider first.
        try:
            return await self._call_provider(
                provider=AIProvider.PRIMARY,
                endpoint="/chat/completions",
                payload=payload,
                prompt_name=prompt_name
            )
        except AIServiceError as e:
            logger.warning(f"首选服务商调用失败: {e}, 尝试降级到备选服务商")
            # Without a fallback API key there is nothing to degrade to.
            if not self.config.fallback_api_key:
                raise
            # Degrade to the fallback provider; convert the model name first.
            fallback_model = self._convert_model_name(model, AIProvider.FALLBACK)
            payload["model"] = fallback_model
            return await self._call_provider(
                provider=AIProvider.FALLBACK,
                endpoint="/chat/completions",
                payload=payload,
                prompt_name=prompt_name
            )
    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        prompt_name: str = "default",
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Streaming text chat with automatic provider fallback.

        Args:
            messages: Message list [{"role": "system/user/assistant", "content": "..."}].
            model: Model name; None uses the configured default model.
            temperature: Sampling temperature.
            max_tokens: Maximum output token count.
            prompt_name: Prompt name, used for statistics.
            **kwargs: Extra parameters.
                NOTE(review): currently ignored — not merged into the
                request payload; confirm whether that is intentional.

        Yields:
            str: Text chunks as they arrive from the provider.
        """
        model = model or self.config.default_model
        # Build the request body (stream mode).
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "stream": True,
        }
        if max_tokens:
            payload["max_tokens"] = max_tokens
        # Try the primary provider first.
        try:
            async for chunk in self._call_provider_stream(
                provider=AIProvider.PRIMARY,
                endpoint="/chat/completions",
                payload=payload,
                prompt_name=prompt_name
            ):
                yield chunk
            return
        except AIServiceError as e:
            logger.warning(f"首选服务商流式调用失败: {e}, 尝试降级到备选服务商")
            # Without a fallback API key there is nothing to degrade to.
            if not self.config.fallback_api_key:
                raise
        # Degrade to the fallback provider; convert the model name first.
        fallback_model = self._convert_model_name(model, AIProvider.FALLBACK)
        payload["model"] = fallback_model
        async for chunk in self._call_provider_stream(
            provider=AIProvider.FALLBACK,
            endpoint="/chat/completions",
            payload=payload,
            prompt_name=prompt_name
        ):
            yield chunk
    async def _call_provider_stream(
        self,
        provider: AIProvider,
        endpoint: str,
        payload: Dict[str, Any],
        prompt_name: str
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion from one specific provider.

        Args:
            provider: Target provider.
            endpoint: API endpoint path (e.g. "/chat/completions").
            payload: Request body.
            prompt_name: Prompt name (for statistics).

        Yields:
            str: Text chunks parsed from the SSE stream.

        Raises:
            AIServiceError: Non-200 response, timeout, or network error.
        """
        # Pick base URL and API key for the provider.
        if provider == AIProvider.PRIMARY:
            base_url = self.config.primary_base_url
            # Claude models use the dedicated Anthropic key; others the general key.
            model = payload.get("model", "")
            if is_claude_model(model) and self.config.anthropic_api_key:
                api_key = self.config.anthropic_api_key
                logger.debug(f"[Stream] 使用 Claude 专属 Key 调用模型: {model}")
            else:
                api_key = self.config.primary_api_key
        else:
            api_key = self.config.fallback_api_key
            base_url = self.config.fallback_base_url
        url = f"{base_url.rstrip('/')}{endpoint}"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        # OpenRouter requires extra attribution headers.
        if provider == AIProvider.FALLBACK:
            headers["HTTP-Referer"] = "https://kaopeilian.ireborn.com.cn"
            headers["X-Title"] = "KaoPeiLian"
        start_time = time.time()
        try:
            timeout = httpx.Timeout(self.config.timeout, connect=10.0)
            async with httpx.AsyncClient(timeout=timeout) as client:
                logger.info(f"流式调用 AI 服务: provider={provider.value}, model={payload.get('model')}")
                async with client.stream("POST", url, json=payload, headers=headers) as response:
                    # Check the response status before consuming the stream.
                    if response.status_code != 200:
                        error_text = await response.aread()
                        logger.error(f"AI 服务流式返回错误: status={response.status_code}, body={error_text[:500]}")
                        raise AIServiceError(
                            f"API 流式请求失败: HTTP {response.status_code}",
                            provider=provider.value,
                            status_code=response.status_code
                        )
                    # Consume the SSE stream line by line.
                    async for line in response.aiter_lines():
                        if not line or not line.strip():
                            continue
                        # SSE data lines are prefixed with "data: ".
                        if line.startswith("data: "):
                            data_str = line[6:]  # strip the "data: " prefix
                            # "[DONE]" marks the end of the stream.
                            if data_str.strip() == "[DONE]":
                                logger.info(f"流式响应完成: provider={provider.value}")
                                return
                            try:
                                event_data = json.loads(data_str)
                                # Extract the delta content, if any.
                                choices = event_data.get("choices", [])
                                if choices:
                                    delta = choices[0].get("delta", {})
                                    content = delta.get("content", "")
                                    if content:
                                        yield content
                            except json.JSONDecodeError as e:
                                # Skip malformed SSE chunks rather than abort the stream.
                                logger.debug(f"解析流式数据失败: {e} - 数据: {data_str[:100]}")
                                continue
            # Reached only when the stream ends without a "[DONE]" marker.
            latency_ms = int((time.time() - start_time) * 1000)
            logger.info(f"流式调用完成: provider={provider.value}, latency={latency_ms}ms")
        except httpx.TimeoutException:
            latency_ms = int((time.time() - start_time) * 1000)
            logger.error(f"AI 服务流式超时: provider={provider.value}, latency={latency_ms}ms")
            raise AIServiceError(f"流式请求超时({self.config.timeout}秒)", provider=provider.value)
        except httpx.RequestError as e:
            logger.error(f"AI 服务流式网络错误: provider={provider.value}, error={e}")
            raise AIServiceError(f"流式网络错误: {e}", provider=provider.value)
    async def _call_provider(
        self,
        provider: AIProvider,
        endpoint: str,
        payload: Dict[str, Any],
        prompt_name: str
    ) -> AIResponse:
        """Call one specific provider and return the parsed response.

        Args:
            provider: Target provider.
            endpoint: API endpoint path (e.g. "/chat/completions").
            payload: Request body.
            prompt_name: Prompt name (for statistics / call logging).

        Returns:
            An AIResponse object.

        Raises:
            AIServiceError: Non-200 response, timeout, or network error.
        """
        # Pick base URL and API key for the provider.
        if provider == AIProvider.PRIMARY:
            base_url = self.config.primary_base_url
            # Claude models use the dedicated Anthropic key; others the general key.
            model = payload.get("model", "")
            if is_claude_model(model) and self.config.anthropic_api_key:
                api_key = self.config.anthropic_api_key
                logger.debug(f"使用 Claude 专属 Key 调用模型: {model}")
            else:
                api_key = self.config.primary_api_key
        else:
            api_key = self.config.fallback_api_key
            base_url = self.config.fallback_base_url
        url = f"{base_url.rstrip('/')}{endpoint}"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        # OpenRouter requires extra attribution headers.
        if provider == AIProvider.FALLBACK:
            headers["HTTP-Referer"] = "https://kaopeilian.ireborn.com.cn"
            headers["X-Title"] = "KaoPeiLian"
        start_time = time.time()
        try:
            async with httpx.AsyncClient(timeout=self.config.timeout) as client:
                logger.info(f"调用 AI 服务: provider={provider.value}, model={payload.get('model')}")
                response = await client.post(url, json=payload, headers=headers)
                latency_ms = int((time.time() - start_time) * 1000)
                # Check the response status.
                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"AI 服务返回错误: status={response.status_code}, body={error_text[:500]}")
                    raise AIServiceError(
                        f"API 请求失败: HTTP {response.status_code}",
                        provider=provider.value,
                        status_code=response.status_code
                    )
                data = response.json()
                # Parse into the normalized AIResponse shape.
                ai_response = self._parse_response(data, provider, latency_ms)
                # Log the successful call.
                logger.info(
                    f"AI 调用成功: provider={provider.value}, model={ai_response.model}, "
                    f"tokens={ai_response.total_tokens}, latency={latency_ms}ms"
                )
                # Persist a call-log record (no-op without a DB session).
                await self._log_call(prompt_name, ai_response)
                return ai_response
        except httpx.TimeoutException:
            latency_ms = int((time.time() - start_time) * 1000)
            logger.error(f"AI 服务超时: provider={provider.value}, latency={latency_ms}ms")
            raise AIServiceError(f"请求超时({self.config.timeout}秒)", provider=provider.value)
        except httpx.RequestError as e:
            logger.error(f"AI 服务网络错误: provider={provider.value}, error={e}")
            raise AIServiceError(f"网络错误: {e}", provider=provider.value)
def _parse_response(
self,
data: Dict[str, Any],
provider: AIProvider,
latency_ms: int
) -> AIResponse:
"""解析 API 响应"""
# 提取内容
choices = data.get("choices", [])
if not choices:
raise AIServiceError("响应中没有 choices")
message = choices[0].get("message", {})
content = message.get("content", "")
# 提取 usage
usage = data.get("usage", {})
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", input_tokens + output_tokens)
# 提取费用(如果有)
cost = usage.get("total_cost", 0.0)
return AIResponse(
content=content,
model=data.get("model", ""),
provider=provider.value,
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
cost=cost,
latency_ms=latency_ms,
raw_response=data
)
async def _log_call(self, prompt_name: str, response: AIResponse) -> None:
"""记录调用日志到数据库"""
if not self.db_session:
return
try:
# TODO: 实现调用日志记录
# 可以参考 ai_call_logs 表结构
pass
except Exception as e:
logger.warning(f"记录 AI 调用日志失败: {e}")
async def analyze_document(
self,
content: str,
prompt: str,
model: Optional[str] = None,
prompt_name: str = "document_analysis"
) -> AIResponse:
"""
分析文档内容
Args:
content: 文档内容
prompt: 分析提示词
model: 模型名称
prompt_name: 提示词名称
Returns:
AIResponse 响应对象
"""
messages = [
{"role": "user", "content": f"{prompt}\n\n文档内容:\n{content}"}
]
return await self.chat(
messages=messages,
model=model,
temperature=0.1, # 文档分析使用低温度
prompt_name=prompt_name
)
# Convenience helpers
async def quick_chat(
    messages: List[Dict[str, str]],
    model: Optional[str] = None,
    module_code: str = "quick"
) -> str:
    """
    One-shot chat helper that returns only the reply text.

    Args:
        messages: Chat message list.
        model: Optional model name override.
        module_code: Module identifier for call logging.

    Returns:
        The AI reply's plain-text content.
    """
    service = AIService(module_code=module_code)
    result = await service.chat(messages, model=model)
    return result.content
# Model constants (per the Ruixiaomei AI integration spec)
# Ordered by priority: primary > standard > fast
MODEL_PRIMARY = "claude-opus-4-5-20251101-thinking"  # 🥇 primary: tried first for every task
MODEL_STANDARD = "gemini-3-pro-preview"  # 🥈 standard: fallback after Claude fails
MODEL_FAST = "gemini-3-flash-preview"  # 🥉 fast: last-resort fallback
MODEL_IMAGE = "gemini-2.5-flash-image-preview"  # 🖼️ dedicated image-generation model
MODEL_VIDEO = "veo3.1-pro"  # 🎬 dedicated video-generation model
# Aliases kept for backward compatibility with older code
DEFAULT_MODEL = MODEL_PRIMARY  # default to the strongest model
MODEL_ANALYSIS = MODEL_PRIMARY
MODEL_CREATIVE = MODEL_STANDARD
MODEL_IMAGE_GEN = MODEL_IMAGE

View File

@@ -0,0 +1,197 @@
"""
答案判断服务 - Python 原生实现
功能:
- 判断填空题与问答题的答案是否正确
- 通过 AI 语义理解比对用户答案与标准答案
提供稳定可靠的答案判断能力。
"""
import logging
from dataclasses import dataclass
from typing import Any, Optional
from .ai_service import AIService, AIResponse
from .prompts.answer_judge_prompts import (
SYSTEM_PROMPT,
USER_PROMPT,
CORRECT_KEYWORDS,
INCORRECT_KEYWORDS,
)
logger = logging.getLogger(__name__)
@dataclass
class JudgeResult:
    """Result of an answer-correctness judgement."""
    is_correct: bool        # final verdict parsed from the AI output
    raw_response: str       # raw AI response text (or an error description)
    ai_provider: str = ""   # provider that served the call
    ai_model: str = ""      # model actually used
    ai_tokens: int = 0      # total tokens consumed
    ai_latency_ms: int = 0  # end-to-end call latency in milliseconds
class AnswerJudgeService:
    """
    Answer judging service (native Python implementation).

    Asks an AI model to compare a user's answer against the reference
    answer, then parses the model's free-text verdict into a boolean.

    Example:
        ```python
        service = AnswerJudgeService()
        result = await service.judge(
            db=db_session,  # pass db_session so the AI call is logged
            question="玻尿酸的主要作用是什么?",
            correct_answer="补水保湿、填充塑形",
            user_answer="保湿和塑形",
            analysis="玻尿酸具有补水保湿和填充塑形两大功能"
        )
        print(result.is_correct)  # True
        ```
    """
    MODULE_CODE = "answer_judge"

    async def judge(
        self,
        question: str,
        correct_answer: str,
        user_answer: str,
        analysis: str = "",
        db: Any = None  # database session, used to record the AI call log
    ) -> "JudgeResult":
        """
        Judge whether the user's answer is correct.

        Args:
            question: Question text.
            correct_answer: Reference answer.
            user_answer: The user's answer.
            analysis: Optional answer explanation passed to the model as context.
            db: Database session for call logging (per the AI integration spec).

        Returns:
            JudgeResult with the verdict and AI call metadata. On any failure
            the verdict is conservatively False and raw_response carries the error.
        """
        try:
            logger.info(
                f"开始判断答案 - question: {question[:50]}..., "
                f"user_answer: {user_answer[:50]}..."
            )
            # Create an AIService bound to this module (db_session enables call logging)
            ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
            # Build the prompt from the templates
            user_prompt = USER_PROMPT.format(
                question=question,
                correct_answer=correct_answer,
                user_answer=user_answer,
                analysis=analysis or ""
            )
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt}
            ]
            ai_response = await ai_service.chat(
                messages=messages,
                temperature=0.1,  # low temperature for a stable verdict
                prompt_name="answer_judge"
            )
            logger.info(
                f"AI 判断完成 - provider: {ai_response.provider}, "
                f"response: {ai_response.content}, "
                f"latency: {ai_response.latency_ms}ms"
            )
            # Parse the AI's free-text verdict
            is_correct = self._parse_judge_result(ai_response.content)
            logger.info(f"答案判断结果: {is_correct}")
            return JudgeResult(
                is_correct=is_correct,
                raw_response=ai_response.content,
                ai_provider=ai_response.provider,
                ai_model=ai_response.model,
                ai_tokens=ai_response.total_tokens,
                ai_latency_ms=ai_response.latency_ms,
            )
        except Exception as e:
            logger.error(f"答案判断失败: {e}", exc_info=True)
            # Fail closed: treat any error as an incorrect answer
            return JudgeResult(
                is_correct=False,
                raw_response=f"判断失败: {e}",
            )

    def _parse_judge_result(
        self,
        ai_output: str,
        correct_keywords: Optional[list] = None,
        incorrect_keywords: Optional[list] = None,
    ) -> bool:
        """
        Parse the AI's verdict text into a boolean.

        Args:
            ai_output: Raw text returned by the AI.
            correct_keywords: Keywords indicating a correct answer
                (defaults to CORRECT_KEYWORDS from the prompt module).
            incorrect_keywords: Keywords indicating an incorrect answer
                (defaults to INCORRECT_KEYWORDS).

        Returns:
            True when judged correct, False when incorrect or unrecognized.
        """
        if correct_keywords is None:
            correct_keywords = CORRECT_KEYWORDS
        if incorrect_keywords is None:
            incorrect_keywords = INCORRECT_KEYWORDS
        # Normalize the output for case-insensitive matching
        output = ai_output.strip().lower()
        # BUG FIX: check negative keywords FIRST. A negative verdict often
        # contains a positive keyword as a substring (e.g. "不正确" contains
        # "正确", "incorrect" contains "correct"), so testing positives first
        # misclassified wrong answers as correct.
        for keyword in incorrect_keywords:
            if keyword.lower() in output:
                return False
        for keyword in correct_keywords:
            if keyword.lower() in output:
                return True
        # Unrecognized output: fail closed and report it
        logger.warning(f"无法解析判断结果,默认返回错误: {ai_output}")
        return False
# ==================== Global instance ====================
answer_judge_service = AnswerJudgeService()
# ==================== Convenience helpers ====================
async def judge_answer(
    question: str,
    correct_answer: str,
    user_answer: str,
    analysis: str = ""
) -> bool:
    """
    Convenience wrapper: judge an answer and return only the boolean verdict.

    Args:
        question: Question text.
        correct_answer: Reference answer.
        user_answer: The user's answer.
        analysis: Optional answer explanation.

    Returns:
        True when the answer is judged correct, otherwise False.
    """
    outcome = await answer_judge_service.judge(
        question=question,
        correct_answer=correct_answer,
        user_answer=user_answer,
        analysis=analysis,
    )
    return outcome.is_correct

View File

@@ -0,0 +1,757 @@
"""
课程对话服务 V2 - Python 原生实现
功能:
- 查询课程知识点作为知识库
- 调用 AI 进行对话
- 支持流式输出
- 多轮对话历史管理Redis 缓存)
提供稳定可靠的课程对话能力。
"""
import json
import logging
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import ExternalServiceError
from .ai_service import AIService
from .prompts.course_chat_prompts import (
SYSTEM_PROMPT,
USER_PROMPT,
KNOWLEDGE_ITEM_TEMPLATE,
CONVERSATION_WINDOW_SIZE,
CONVERSATION_TTL,
MAX_KNOWLEDGE_POINTS,
MAX_KNOWLEDGE_BASE_LENGTH,
DEFAULT_CHAT_MODEL,
DEFAULT_TEMPERATURE,
)
logger = logging.getLogger(__name__)
# Redis key prefix/suffix for the per-user conversation index
CONVERSATION_INDEX_PREFIX = "course_chat:user:"
CONVERSATION_INDEX_SUFFIX = ":conversations"
# Key prefix for per-conversation metadata records
CONVERSATION_META_PREFIX = "course_chat:meta:"
# Index TTL (kept in sync with the conversation-data TTL)
CONVERSATION_INDEX_TTL = CONVERSATION_TTL
class CourseChatServiceV2:
    """
    Course chat service V2 (native Python implementation).

    Loads course knowledge points as the knowledge base, calls the AI
    service (optionally streaming), and keeps multi-turn conversation
    history in Redis.

    Example:
        ```python
        service = CourseChatServiceV2()
        # non-streaming chat
        response = await service.chat(
            db=db_session,
            course_id=1,
            query="什么是玻尿酸?",
            user_id=1,
            conversation_id=None
        )
        # streaming chat
        async for chunk in service.chat_stream(
            db=db_session,
            course_id=1,
            query="什么是玻尿酸?",
            user_id=1,
            conversation_id=None
        ):
            print(chunk, end="", flush=True)
        ```
    """
    # Redis key prefix for conversation history
    CONVERSATION_KEY_PREFIX = "course_chat:conversation:"
    # Module identifier used for AI call logging
    MODULE_CODE = "course_chat"

    def __init__(self):
        """No eager state; AIService instances are created per call so a db_session can be passed in."""
        pass
    async def chat(
        self,
        db: AsyncSession,
        course_id: int,
        query: str,
        user_id: int,
        conversation_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Chat with a course (non-streaming).

        Args:
            db: Database session.
            course_id: Course id.
            query: User question.
            user_id: User id.
            conversation_id: Conversation id; pass to continue an existing conversation.

        Returns:
            Dict with answer, conversation_id and AI call metadata.

        Raises:
            ExternalServiceError: Wraps any failure in the chat flow.
        """
        try:
            logger.info(
                f"开始课程对话 V2 - course_id: {course_id}, user_id: {user_id}, "
                f"conversation_id: {conversation_id}"
            )
            # 1. Load the course knowledge points as the knowledge base
            knowledge_base = await self._get_course_knowledge(db, course_id)
            if not knowledge_base:
                logger.warning(f"课程 {course_id} 没有知识点,使用空知识库")
                knowledge_base = "(该课程暂无知识点内容)"
            # 2. Get or create the conversation id
            is_new_conversation = False
            if not conversation_id:
                conversation_id = self._generate_conversation_id(user_id, course_id)
                is_new_conversation = True
                logger.info(f"创建新会话: {conversation_id}")
            # 3. Build the message list (system prompt + history + current query)
            messages = await self._build_messages(
                knowledge_base=knowledge_base,
                query=query,
                user_id=user_id,
                conversation_id=conversation_id
            )
            # 4. Create an AIService and call it (db_session enables call logging)
            ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
            response = await ai_service.chat(
                messages=messages,
                model=DEFAULT_CHAT_MODEL,
                temperature=DEFAULT_TEMPERATURE,
                prompt_name="course_chat"
            )
            answer = response.content
            # 5. Persist the new turn into the Redis-backed history
            await self._save_conversation_history(
                conversation_id=conversation_id,
                user_message=query,
                assistant_message=answer
            )
            # 6. Maintain the per-user conversation index
            if is_new_conversation:
                await self._add_to_conversation_index(user_id, conversation_id, course_id)
            else:
                await self._update_conversation_index(user_id, conversation_id)
            logger.info(
                f"课程对话完成 - course_id: {course_id}, conversation_id: {conversation_id}, "
                f"provider: {response.provider}, tokens: {response.total_tokens}"
            )
            return {
                "success": True,
                "answer": answer,
                "conversation_id": conversation_id,
                "ai_provider": response.provider,
                "ai_model": response.model,
                "ai_tokens": response.total_tokens,
                "ai_latency_ms": response.latency_ms,
            }
        except Exception as e:
            logger.error(
                f"课程对话失败 - course_id: {course_id}, user_id: {user_id}, error: {e}",
                exc_info=True
            )
            raise ExternalServiceError(f"课程对话失败: {e}")
    async def chat_stream(
        self,
        db: AsyncSession,
        course_id: int,
        query: str,
        user_id: int,
        conversation_id: Optional[str] = None
    ) -> AsyncGenerator[Tuple[str, Optional[str]], None]:
        """
        Chat with a course (streaming output).

        Args:
            db: Database session.
            course_id: Course id.
            query: User question.
            user_id: User id.
            conversation_id: Conversation id; pass to continue an existing conversation.

        Yields:
            Tuple[str, Optional[str]]: (event type, data)
            - ("conversation_started", conversation_id): a new conversation was created
            - ("chunk", text): text fragment
            - ("end", None): stream finished
            - ("error", message): failure (errors are yielded, not raised)
        """
        full_answer = ""
        try:
            logger.info(
                f"开始流式课程对话 V2 - course_id: {course_id}, user_id: {user_id}, "
                f"conversation_id: {conversation_id}"
            )
            # 1. Load the course knowledge points as the knowledge base
            knowledge_base = await self._get_course_knowledge(db, course_id)
            if not knowledge_base:
                logger.warning(f"课程 {course_id} 没有知识点,使用空知识库")
                knowledge_base = "(该课程暂无知识点内容)"
            # 2. Get or create the conversation id
            is_new_conversation = False
            if not conversation_id:
                conversation_id = self._generate_conversation_id(user_id, course_id)
                is_new_conversation = True
                logger.info(f"创建新会话: {conversation_id}")
            # 3. Emit a conversation-started event for new conversations
            if is_new_conversation:
                yield ("conversation_started", conversation_id)
            # 4. Build the message list
            messages = await self._build_messages(
                knowledge_base=knowledge_base,
                query=query,
                user_id=user_id,
                conversation_id=conversation_id
            )
            # 5. Create an AIService and stream the reply (db_session enables call logging)
            ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
            async for chunk in ai_service.chat_stream(
                messages=messages,
                model=DEFAULT_CHAT_MODEL,
                temperature=DEFAULT_TEMPERATURE,
                prompt_name="course_chat"
            ):
                full_answer += chunk
                yield ("chunk", chunk)
            # 6. Emit the end event
            yield ("end", None)
            # 7. Persist the completed turn into the Redis-backed history
            await self._save_conversation_history(
                conversation_id=conversation_id,
                user_message=query,
                assistant_message=full_answer
            )
            # 8. Maintain the per-user conversation index
            if is_new_conversation:
                await self._add_to_conversation_index(user_id, conversation_id, course_id)
            else:
                await self._update_conversation_index(user_id, conversation_id)
            logger.info(
                f"流式课程对话完成 - course_id: {course_id}, conversation_id: {conversation_id}, "
                f"answer_length: {len(full_answer)}"
            )
        except Exception as e:
            logger.error(
                f"流式课程对话失败 - course_id: {course_id}, user_id: {user_id}, error: {e}",
                exc_info=True
            )
            yield ("error", str(e))
    async def _get_course_knowledge(
        self,
        db: AsyncSession,
        course_id: int
    ) -> str:
        """
        Fetch the course's knowledge points and build the knowledge-base text.

        Args:
            db: Database session.
            course_id: Course id.

        Returns:
            Knowledge-base text ("" when the course has no knowledge points).
        """
        try:
            # Query knowledge points joined through non-deleted course materials
            query = text("""
                SELECT kp.name, kp.description
                FROM knowledge_points kp
                INNER JOIN course_materials cm ON kp.material_id = cm.id
                WHERE kp.course_id = :course_id
                AND kp.is_deleted = 0
                AND cm.is_deleted = 0
                ORDER BY kp.id
                LIMIT :limit
            """)
            result = await db.execute(
                query,
                {"course_id": course_id, "limit": MAX_KNOWLEDGE_POINTS}
            )
            rows = result.fetchall()
            if not rows:
                logger.warning(f"课程 {course_id} 没有关联的知识点")
                return ""
            # Build the knowledge-base text, capped at MAX_KNOWLEDGE_BASE_LENGTH
            knowledge_items = []
            total_length = 0
            for row in rows:
                name = row[0] or ""
                description = row[1] or ""
                item = KNOWLEDGE_ITEM_TEMPLATE.format(
                    name=name,
                    description=description
                )
                # Stop once adding another item would exceed the length budget
                if total_length + len(item) > MAX_KNOWLEDGE_BASE_LENGTH:
                    logger.warning(
                        f"知识库文本已达到最大长度限制 {MAX_KNOWLEDGE_BASE_LENGTH}"
                        f"停止添加更多知识点"
                    )
                    break
                knowledge_items.append(item)
                total_length += len(item)
            knowledge_base = "\n".join(knowledge_items)
            logger.info(
                f"获取课程知识点成功 - course_id: {course_id}, "
                f"count: {len(knowledge_items)}, length: {len(knowledge_base)}"
            )
            return knowledge_base
        except Exception as e:
            logger.error(f"获取课程知识点失败: {e}")
            raise
async def _build_messages(
self,
knowledge_base: str,
query: str,
user_id: int,
conversation_id: str
) -> List[Dict[str, str]]:
"""
构建消息列表(包含历史对话)
Args:
knowledge_base: 知识库文本
query: 当前用户问题
user_id: 用户ID
conversation_id: 会话ID
Returns:
消息列表
"""
messages = []
# 1. 系统提示词
system_content = SYSTEM_PROMPT.format(knowledge_base=knowledge_base)
messages.append({"role": "system", "content": system_content})
# 2. 获取历史对话
history = await self._get_conversation_history(conversation_id)
# 限制历史窗口大小
if len(history) > CONVERSATION_WINDOW_SIZE * 2:
history = history[-(CONVERSATION_WINDOW_SIZE * 2):]
# 添加历史消息
messages.extend(history)
# 3. 当前用户问题
user_content = USER_PROMPT.format(query=query)
messages.append({"role": "user", "content": user_content})
logger.debug(
f"构建消息列表 - total: {len(messages)}, history: {len(history)}"
)
return messages
def _generate_conversation_id(self, user_id: int, course_id: int) -> str:
"""生成会话ID"""
unique_id = uuid.uuid4().hex[:8]
return f"conv_{user_id}_{course_id}_{unique_id}"
    async def _get_conversation_history(
        self,
        conversation_id: str
    ) -> List[Dict[str, str]]:
        """
        Load conversation history from Redis.

        Args:
            conversation_id: Conversation id.

        Returns:
            Message list [{"role": "user/assistant", "content": "..."}];
            empty when Redis is unavailable or holds no data.
        """
        try:
            from app.core.redis import get_redis_client
            redis = get_redis_client()
            key = f"{self.CONVERSATION_KEY_PREFIX}{conversation_id}"
            data = await redis.get(key)
            if not data:
                return []
            history = json.loads(data)
            return history
        except RuntimeError:
            # Redis not initialized: degrade gracefully to an empty history
            logger.warning("Redis 未初始化,无法获取会话历史")
            return []
        except Exception as e:
            logger.warning(f"获取会话历史失败: {e}")
            return []
    async def _save_conversation_history(
        self,
        conversation_id: str,
        user_message: str,
        assistant_message: str
    ) -> None:
        """
        Append a turn to the Redis-backed conversation history.

        Args:
            conversation_id: Conversation id.
            user_message: The user's message.
            assistant_message: The AI reply.
        """
        try:
            from app.core.redis import get_redis_client
            redis = get_redis_client()
            key = f"{self.CONVERSATION_KEY_PREFIX}{conversation_id}"
            # Load the existing history first
            history = await self._get_conversation_history(conversation_id)
            # Append the new turn
            history.append({"role": "user", "content": user_message})
            history.append({"role": "assistant", "content": assistant_message})
            # Trim to the configured window (user+assistant pairs)
            max_messages = CONVERSATION_WINDOW_SIZE * 2
            if len(history) > max_messages:
                history = history[-max_messages:]
            # Persist with TTL
            await redis.setex(
                key,
                CONVERSATION_TTL,
                json.dumps(history, ensure_ascii=False)
            )
            logger.debug(
                f"保存会话历史成功 - conversation_id: {conversation_id}, "
                f"messages: {len(history)}"
            )
        except RuntimeError:
            # Redis not initialized: skip persistence (history is best-effort)
            logger.warning("Redis 未初始化,无法保存会话历史")
        except Exception as e:
            logger.warning(f"保存会话历史失败: {e}")
async def get_conversation_messages(
self,
conversation_id: str,
user_id: int
) -> List[Dict[str, Any]]:
"""
获取会话的历史消息
Args:
conversation_id: 会话ID
user_id: 用户ID用于权限验证
Returns:
消息列表
"""
# 验证会话ID是否属于该用户
if not conversation_id.startswith(f"conv_{user_id}_"):
logger.warning(
f"用户 {user_id} 尝试访问不属于自己的会话: {conversation_id}"
)
return []
history = await self._get_conversation_history(conversation_id)
# 格式化返回数据
messages = []
for i, msg in enumerate(history):
messages.append({
"id": i,
"role": msg["role"],
"content": msg["content"],
})
return messages
    async def _add_to_conversation_index(
        self,
        user_id: int,
        conversation_id: str,
        course_id: int
    ) -> None:
        """
        Register a new conversation in the user's index.

        Args:
            user_id: User id.
            conversation_id: Conversation id.
            course_id: Course id.
        """
        try:
            from app.core.redis import get_redis_client
            redis = get_redis_client()
            # 1. Add to the user's conversation index (sorted set, score = timestamp)
            index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
            timestamp = time.time()
            await redis.zadd(index_key, {conversation_id: timestamp})
            await redis.expire(index_key, CONVERSATION_INDEX_TTL)
            # 2. Store the conversation metadata
            meta_key = f"{CONVERSATION_META_PREFIX}{conversation_id}"
            meta_data = {
                "conversation_id": conversation_id,
                "user_id": user_id,
                "course_id": course_id,
                "created_at": timestamp,
                "updated_at": timestamp,
            }
            await redis.setex(
                meta_key,
                CONVERSATION_INDEX_TTL,
                json.dumps(meta_data, ensure_ascii=False)
            )
            logger.debug(
                f"会话已添加到索引 - user_id: {user_id}, conversation_id: {conversation_id}"
            )
        except RuntimeError:
            logger.warning("Redis 未初始化,无法添加会话索引")
        except Exception as e:
            logger.warning(f"添加会话索引失败: {e}")
    async def _update_conversation_index(
        self,
        user_id: int,
        conversation_id: str
    ) -> None:
        """
        Refresh a conversation's last-active timestamp in the index.

        Args:
            user_id: User id.
            conversation_id: Conversation id.
        """
        try:
            from app.core.redis import get_redis_client
            redis = get_redis_client()
            # Bump the timestamp in the sorted-set index
            index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
            timestamp = time.time()
            await redis.zadd(index_key, {conversation_id: timestamp})
            await redis.expire(index_key, CONVERSATION_INDEX_TTL)
            # Update updated_at in the metadata record (if it still exists)
            meta_key = f"{CONVERSATION_META_PREFIX}{conversation_id}"
            meta_data = await redis.get(meta_key)
            if meta_data:
                meta = json.loads(meta_data)
                meta["updated_at"] = timestamp
                await redis.setex(
                    meta_key,
                    CONVERSATION_INDEX_TTL,
                    json.dumps(meta, ensure_ascii=False)
                )
            logger.debug(
                f"会话索引已更新 - user_id: {user_id}, conversation_id: {conversation_id}"
            )
        except RuntimeError:
            logger.warning("Redis 未初始化,无法更新会话索引")
        except Exception as e:
            logger.warning(f"更新会话索引失败: {e}")
    async def list_user_conversations(
        self,
        user_id: int,
        limit: int = 20
    ) -> List[Dict[str, Any]]:
        """
        List a user's conversations, most recently updated first.

        Args:
            user_id: User id.
            limit: Maximum number of conversations to return.

        Returns:
            Conversation summaries ordered by update time descending;
            empty when Redis is unavailable or there is no data.
        """
        try:
            from app.core.redis import get_redis_client
            redis = get_redis_client()
            # 1. Fetch the most recent conversation ids from the index (descending)
            index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
            conversation_ids = await redis.zrevrange(index_key, 0, limit - 1)
            if not conversation_ids:
                logger.debug(f"用户 {user_id} 没有会话记录")
                return []
            # 2. Resolve each conversation's metadata and last message
            conversations = []
            for conv_id in conversation_ids:
                # Redis may return bytes; normalize to str
                if isinstance(conv_id, bytes):
                    conv_id = conv_id.decode('utf-8')
                # Load the metadata record
                meta_key = f"{CONVERSATION_META_PREFIX}{conv_id}"
                meta_data = await redis.get(meta_key)
                if meta_data:
                    if isinstance(meta_data, bytes):
                        meta_data = meta_data.decode('utf-8')
                    meta = json.loads(meta_data)
                else:
                    # Metadata expired: recover course_id from the conversation id,
                    # format: conv_{user_id}_{course_id}_{uuid}
                    parts = conv_id.split('_')
                    course_id = int(parts[2]) if len(parts) >= 3 else 0
                    meta = {
                        "conversation_id": conv_id,
                        "user_id": user_id,
                        "course_id": course_id,
                        "created_at": time.time(),
                        "updated_at": time.time(),
                    }
                # Use the last assistant message as the preview
                history = await self._get_conversation_history(conv_id)
                last_message = ""
                if history:
                    # Walk backwards to the most recent assistant message
                    for msg in reversed(history):
                        if msg["role"] == "assistant":
                            last_message = msg["content"][:100]  # first 100 characters
                            if len(msg["content"]) > 100:
                                last_message += "..."
                            break
                conversations.append({
                    "id": conv_id,
                    "course_id": meta.get("course_id"),
                    "created_at": meta.get("created_at"),
                    "updated_at": meta.get("updated_at"),
                    "last_message": last_message,
                    "message_count": len(history),
                })
            logger.info(f"获取用户会话列表 - user_id: {user_id}, count: {len(conversations)}")
            return conversations
        except RuntimeError:
            logger.warning("Redis 未初始化,无法获取会话列表")
            return []
        except Exception as e:
            logger.warning(f"获取会话列表失败: {e}")
            return []
# 别名方法,供 API 层调用
async def get_conversations(
self,
user_id: int,
course_id: Optional[int] = None,
limit: int = 20
) -> List[Dict[str, Any]]:
"""
获取用户的会话列表(别名方法)
Args:
user_id: 用户ID
course_id: 课程ID可选用于过滤
limit: 返回数量限制
Returns:
会话列表
"""
conversations = await self.list_user_conversations(user_id, limit)
# 如果指定了 course_id进行过滤
if course_id is not None:
conversations = [
c for c in conversations
if c.get("course_id") == course_id
]
return conversations
async def get_messages(
self,
conversation_id: str,
user_id: int,
limit: int = 50
) -> List[Dict[str, Any]]:
"""
获取会话历史消息(别名方法)
Args:
conversation_id: 会话ID
user_id: 用户ID用于权限验证
limit: 返回数量限制
Returns:
消息列表
"""
messages = await self.get_conversation_messages(conversation_id, limit)
return messages
# Global singleton instance
course_chat_service_v2 = CourseChatServiceV2()

View File

@@ -0,0 +1,61 @@
"""
Coze AI 服务模块
"""
from .client import get_coze_client, get_auth_manager, get_bot_config, get_workspace_id
from .service import get_coze_service, CozeService
from .models import (
SessionType,
MessageRole,
ContentType,
StreamEventType,
CozeSession,
CozeMessage,
StreamEvent,
CreateSessionRequest,
CreateSessionResponse,
SendMessageRequest,
EndSessionRequest,
EndSessionResponse,
)
from .exceptions import (
CozeException,
CozeAuthError,
CozeAPIError,
CozeRateLimitError,
CozeTimeoutError,
CozeStreamError,
map_coze_error_to_exception,
)
__all__ = [
# Client
"get_coze_client",
"get_auth_manager",
"get_bot_config",
"get_workspace_id",
# Service
"get_coze_service",
"CozeService",
# Models
"SessionType",
"MessageRole",
"ContentType",
"StreamEventType",
"CozeSession",
"CozeMessage",
"StreamEvent",
"CreateSessionRequest",
"CreateSessionResponse",
"SendMessageRequest",
"EndSessionRequest",
"EndSessionResponse",
# Exceptions
"CozeException",
"CozeAuthError",
"CozeAPIError",
"CozeRateLimitError",
"CozeTimeoutError",
"CozeStreamError",
"map_coze_error_to_exception",
]

View File

@@ -0,0 +1,203 @@
"""
Coze AI 客户端管理
负责管理 Coze API 的认证和客户端实例
"""
from functools import lru_cache
from typing import Optional, Dict, Any
import logging
from pathlib import Path
from cozepy import Coze, TokenAuth, JWTAuth, COZE_CN_BASE_URL
from app.core.config import get_settings
logger = logging.getLogger(__name__)
class CozeAuthManager:
    """
    Coze authentication manager.

    Builds authenticated cozepy clients, preferring OAuth (auto-refreshing
    JWT) over a personal access token (PAT, which expires).
    """

    def __init__(self):
        self.settings = get_settings()
        # Cached client, reused unless force_new is requested
        self._client: Optional[Coze] = None

    def _create_pat_auth(self) -> TokenAuth:
        """Build personal-access-token auth from COZE_API_TOKEN."""
        if not self.settings.COZE_API_TOKEN:
            raise ValueError("COZE_API_TOKEN 未配置")
        return TokenAuth(token=self.settings.COZE_API_TOKEN)

    def _oauth_configured(self) -> bool:
        """Whether all three OAuth settings are present."""
        return all([
            self.settings.COZE_OAUTH_CLIENT_ID,
            self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
            self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
        ])

    def _create_oauth_auth(self) -> JWTAuth:
        """
        Build OAuth JWT auth.

        Reads the private key from disk; shared by get_client and
        get_oauth_token (previously duplicated in both).

        Raises:
            ValueError: When the OAuth configuration is incomplete.
            FileNotFoundError: When the private-key file is missing.
        """
        if not self._oauth_configured():
            raise ValueError("OAuth 配置不完整")
        # Load the signing key from disk
        private_key_path = Path(self.settings.COZE_OAUTH_PRIVATE_KEY_PATH)
        if not private_key_path.exists():
            raise FileNotFoundError(f"私钥文件不存在: {private_key_path}")
        with open(private_key_path, "r") as f:
            private_key = f.read()
        try:
            return JWTAuth(
                client_id=self.settings.COZE_OAUTH_CLIENT_ID,
                private_key=private_key,
                public_key_id=self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
                base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL,  # CN-region API
            )
        except Exception as e:
            logger.error(f"创建 OAuth 认证失败: {e}")
            raise

    def get_client(self, force_new: bool = False) -> Coze:
        """
        Get a Coze client instance.

        Args:
            force_new: Force a fresh client (for long-running requests, to
                avoid token expiry); fresh clients are not cached.

        Auth priority:
            1. OAuth (preferred when fully configured; tokens auto-refresh).
            2. PAT (only when OAuth is unconfigured; note: PATs expire).
        """
        if self._client is not None and not force_new:
            return self._client
        if self._oauth_configured():
            # OAuth fully configured: must use OAuth, never fall back to PAT
            try:
                auth = self._create_oauth_auth()
                auth_type = "OAuth"
                logger.info("使用 OAuth 认证")
            except Exception as e:
                # Fail hard rather than silently using a possibly-expired PAT
                logger.error(f"OAuth 认证创建失败: {e}")
                raise ValueError(f"OAuth 认证失败,请检查私钥文件和配置: {e}")
        elif self.settings.COZE_API_TOKEN:
            # OAuth not configured: fall back to PAT
            auth = self._create_pat_auth()
            auth_type = "PAT"
            logger.warning("使用 PAT 认证注意PAT会过期建议配置OAuth")
        else:
            raise ValueError("Coze 认证未配置:需要配置 OAuth 或 PAT Token")
        client = Coze(
            auth=auth, base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL
        )
        logger.debug(f"Coze客户端创建成功认证方式: {auth_type}, force_new: {force_new}")
        # Only cache clients that were not force-created
        if not force_new:
            self._client = client
        return client

    def reset(self):
        """Drop the cached client so the next get_client builds a new one."""
        self._client = None

    def get_oauth_token(self) -> str:
        """
        Get an OAuth JWT token (for direct use by the frontend).

        Returns:
            The signed JWT string (JWTAuth.token signs it internally).
        """
        # Reuse the shared builder instead of duplicating key loading here
        return self._create_oauth_auth().token
@lru_cache()
def get_auth_manager() -> CozeAuthManager:
    """Return the process-wide auth-manager singleton (lru_cache-backed)."""
    return CozeAuthManager()
def get_coze_client(force_new: bool = False) -> Coze:
    """
    Get a Coze client.

    Args:
        force_new: Force a fresh client (for long-running requests such as workflows).
    """
    manager = get_auth_manager()
    return manager.get_client(force_new=force_new)
def get_workspace_id() -> str:
    """Return the configured workspace id, or raise when unset."""
    workspace_id = get_settings().COZE_WORKSPACE_ID
    if not workspace_id:
        raise ValueError("COZE_WORKSPACE_ID 未配置")
    return workspace_id
def get_bot_config(session_type: str) -> Dict[str, Any]:
    """
    Resolve the Bot configuration for a session type.

    Args:
        session_type: Session type ("course_chat" or "training").

    Returns:
        Dict with bot_id and workspace_id.

    Raises:
        ValueError: For unknown session types or missing configuration.
    """
    settings = get_settings()
    if session_type == "course_chat":
        bot_id = settings.COZE_CHAT_BOT_ID
        setting_name = "COZE_CHAT_BOT_ID"
    elif session_type == "training":
        bot_id = settings.COZE_TRAINING_BOT_ID
        setting_name = "COZE_TRAINING_BOT_ID"
    else:
        raise ValueError(f"不支持的会话类型: {session_type}")
    if not bot_id:
        raise ValueError(f"{setting_name} 未配置")
    return {"bot_id": bot_id, "workspace_id": settings.COZE_WORKSPACE_ID}

View File

@@ -0,0 +1,44 @@
"""Coze客户端临时模拟等Agent-Coze实现后替换"""
import logging
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)
class CozeClient:
    """
    Mock Coze client.

    TODO: replace this class with the real Coze gateway client once the
    Agent-Coze module is implemented.
    """

    async def create_conversation(
        self, bot_id: str, user_id: str, meta_data: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Create a conversation (mocked)."""
        logger.info(f"模拟创建Coze会话: bot_id={bot_id}, user_id={user_id}")
        # Return a mocked conversation payload
        return {
            "conversation_id": f"mock_conversation_{user_id}_{bot_id[:8]}",
            "bot_id": bot_id,
            "status": "active",
        }

    async def send_message(
        self, conversation_id: str, content: str, message_type: str = "text"
    ) -> Dict[str, Any]:
        """Send a message (mocked)."""
        logger.info(f"模拟发送消息到会话 {conversation_id}: {content[:50]}...")
        # Return a mocked assistant reply
        return {
            "message_id": f"mock_msg_{conversation_id[:8]}",
            "content": f"这是对'{content[:30]}...'的模拟回复",
            "role": "assistant",
        }

    async def end_conversation(self, conversation_id: str) -> Dict[str, Any]:
        """End a conversation (mocked)."""
        logger.info(f"模拟结束会话: {conversation_id}")
        return {"status": "completed", "conversation_id": conversation_id}

View File

@@ -0,0 +1,101 @@
"""
Coze 服务异常定义
"""
from typing import Optional, Dict, Any
class CozeException(Exception):
    """Base exception for the Coze service."""

    def __init__(
        self,
        message: str,
        code: Optional[str] = None,
        status_code: Optional[int] = None,
        details: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(message)
        self.message = message
        self.code = code
        self.status_code = status_code
        # Never share a mutable default between instances
        self.details = details if details is not None else {}


class CozeAuthError(CozeException):
    """Authentication failure."""
    pass


class CozeAPIError(CozeException):
    """Generic API call failure."""
    pass


class CozeRateLimitError(CozeException):
    """Rate limit exceeded."""
    pass


class CozeTimeoutError(CozeException):
    """Request timed out."""
    pass


class CozeStreamError(CozeException):
    """Streaming response failure."""
    pass


def map_coze_error_to_exception(error: Exception) -> CozeException:
    """
    Map a raw Coze SDK error onto the unified exception hierarchy.

    Args:
        error: The original exception.

    Returns:
        CozeException: The mapped exception (auth, rate-limit, timeout, or
        generic API error), carrying the original message in details.
    """
    text = str(error)
    lowered = text.lower()
    detail = {"original_error": text}
    # Classify by message substring, in precedence order:
    # auth -> rate limit -> timeout -> generic API error.
    if "authentication" in lowered or "unauthorized" in lowered:
        return CozeAuthError(
            message="Coze 认证失败",
            code="COZE_AUTH_ERROR",
            status_code=401,
            details=detail,
        )
    if "rate limit" in lowered:
        return CozeRateLimitError(
            message="Coze API 速率限制",
            code="COZE_RATE_LIMIT",
            status_code=429,
            details=detail,
        )
    if "timeout" in lowered:
        return CozeTimeoutError(
            message="Coze API 调用超时",
            code="COZE_TIMEOUT",
            status_code=504,
            details=detail,
        )
    return CozeAPIError(
        message="Coze API 调用失败",
        code="COZE_API_ERROR",
        status_code=500,
        details=detail,
    )

View File

@@ -0,0 +1,136 @@
"""
Coze 服务数据模型
"""
from typing import Optional, List, Dict, Any, Literal
from datetime import datetime
from pydantic import BaseModel, Field
from enum import Enum
class SessionType(str, Enum):
    """Session type."""
    COURSE_CHAT = "course_chat"  # course Q&A chat
    TRAINING = "training"  # sparring / practice session
    EXAM = "exam"  # exam session


class MessageRole(str, Enum):
    """Message role."""
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class ContentType(str, Enum):
    """Message content type."""
    TEXT = "text"
    CARD = "card"
    IMAGE = "image"
    FILE = "file"


class StreamEventType(str, Enum):
    """Streaming event type."""
    MESSAGE_START = "conversation.message.start"
    MESSAGE_DELTA = "conversation.message.delta"
    MESSAGE_COMPLETED = "conversation.message.completed"
    ERROR = "error"
    DONE = "done"
class CozeSession(BaseModel):
"""Coze 会话模型"""
session_id: str = Field(..., description="会话ID")
conversation_id: str = Field(..., description="Coze对话ID")
session_type: SessionType = Field(..., description="会话类型")
user_id: str = Field(..., description="用户ID")
bot_id: str = Field(..., description="Bot ID")
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
ended_at: Optional[datetime] = Field(None, description="结束时间")
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
class Config:
json_encoders = {datetime: lambda v: v.isoformat()}
class CozeMessage(BaseModel):
"""Coze 消息模型"""
message_id: str = Field(..., description="消息ID")
session_id: str = Field(..., description="会话ID")
role: MessageRole = Field(..., description="消息角色")
content: str = Field(..., description="消息内容")
content_type: ContentType = Field(ContentType.TEXT, description="内容类型")
created_at: datetime = Field(default_factory=datetime.now, description="创建时间")
metadata: Dict[str, Any] = Field(default_factory=dict, description="元数据")
class Config:
json_encoders = {datetime: lambda v: v.isoformat()}
class StreamEvent(BaseModel):
"""流式事件模型"""
event: StreamEventType = Field(..., description="事件类型")
data: Dict[str, Any] = Field(..., description="事件数据")
message_id: Optional[str] = Field(None, description="消息ID")
content: Optional[str] = Field(None, description="内容")
content_type: Optional[ContentType] = Field(None, description="内容类型")
role: Optional[MessageRole] = Field(None, description="角色")
error: Optional[str] = Field(None, description="错误信息")
class CreateSessionRequest(BaseModel):
    """Request body for creating a new session."""
    session_type: SessionType = Field(..., description="会话类型")
    user_id: str = Field(..., description="用户ID")
    course_id: Optional[str] = Field(None, description="课程ID (课程对话时必需)")
    training_topic: Optional[str] = Field(None, description="陪练主题 (陪练时可选)")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")
class CreateSessionResponse(BaseModel):
    """Response body returned after a session is created."""
    session_id: str = Field(..., description="会话ID")
    conversation_id: str = Field(..., description="Coze对话ID")
    bot_id: str = Field(..., description="Bot ID")
    created_at: datetime = Field(..., description="创建时间")
    class Config:
        # Serialize datetimes as ISO-8601 strings when dumping to JSON.
        json_encoders = {datetime: lambda v: v.isoformat()}
class SendMessageRequest(BaseModel):
    """Request body for sending a message into a session."""
    session_id: str = Field(..., description="会话ID")
    content: str = Field(..., description="消息内容")
    file_ids: List[str] = Field(default_factory=list, description="附件ID列表")
    stream: bool = Field(True, description="是否流式响应")
class EndSessionRequest(BaseModel):
    """Request body for ending a session."""
    reason: Optional[str] = Field(None, description="结束原因")
    feedback: Optional[Dict[str, Any]] = Field(None, description="用户反馈")
class EndSessionResponse(BaseModel):
    """Response body returned when a session is ended."""
    session_id: str = Field(..., description="会话ID")
    ended_at: datetime = Field(..., description="结束时间")
    duration_seconds: int = Field(..., description="会话时长(秒)")
    message_count: int = Field(..., description="消息数量")
    class Config:
        # Serialize datetimes as ISO-8601 strings when dumping to JSON.
        json_encoders = {datetime: lambda v: v.isoformat()}

View File

@@ -0,0 +1,335 @@
"""
Coze 服务层实现
处理会话管理、消息发送、流式响应等核心功能
"""
import asyncio
import json
import logging
import uuid
from typing import AsyncIterator, Dict, Any, List, Optional
from datetime import datetime
from cozepy import ChatEventType, Message, MessageContentType
from .client import get_coze_client, get_bot_config, get_workspace_id
from .models import (
CozeSession,
CozeMessage,
StreamEvent,
SessionType,
MessageRole,
ContentType,
StreamEventType,
CreateSessionRequest,
CreateSessionResponse,
SendMessageRequest,
EndSessionRequest,
EndSessionResponse,
)
from .exceptions import (
CozeAPIError,
CozeStreamError,
CozeTimeoutError,
map_coze_error_to_exception,
)
logger = logging.getLogger(__name__)
class CozeService:
    """Coze service layer: session management, message sending, streaming."""
    def __init__(self):
        self.client = get_coze_client()
        self.bot_config = get_bot_config()
        self.workspace_id = get_workspace_id()
        # In-memory session store (production should use Redis).
        self._sessions: Dict[str, CozeSession] = {}
        self._messages: Dict[str, List[CozeMessage]] = {}
    async def create_session(
        self, request: CreateSessionRequest
    ) -> CreateSessionResponse:
        """
        Create a new session.

        Args:
            request: session creation request

        Returns:
            CreateSessionResponse: the created session's identifiers
        """
        try:
            # Pick the bot according to the session type.
            bot_id = self._get_bot_id_by_type(request.session_type)
            # Create the Coze conversation (blocking SDK call, run in a thread).
            conversation = await asyncio.to_thread(
                self.client.conversations.create, bot_id=bot_id
            )
            # Create the local session record.
            session = CozeSession(
                session_id=str(uuid.uuid4()),
                conversation_id=conversation.id,
                session_type=request.session_type,
                user_id=request.user_id,
                bot_id=bot_id,
                metadata=request.metadata,
            )
            # Persist session and an empty message list in memory.
            self._sessions[session.session_id] = session
            self._messages[session.session_id] = []
            logger.info(
                f"创建会话成功",
                extra={
                    "session_id": session.session_id,
                    "conversation_id": conversation.id,
                    "session_type": request.session_type.value,
                    "user_id": request.user_id,
                },
            )
            return CreateSessionResponse(
                session_id=session.session_id,
                conversation_id=session.conversation_id,
                bot_id=session.bot_id,
                created_at=session.created_at,
            )
        except Exception as e:
            logger.error(f"创建会话失败: {e}", exc_info=True)
            raise map_coze_error_to_exception(e)
    async def send_message(
        self, request: SendMessageRequest
    ) -> AsyncIterator[StreamEvent]:
        """
        Send a message and stream back the response.

        Args:
            request: message request (session id, content, ...)

        Yields:
            StreamEvent: delta / completed / error / done events

        Raises:
            CozeAPIError: if the session does not exist
            CozeTimeoutError: if processing times out
        """
        session = self._get_session(request.session_id)
        if not session:
            raise CozeAPIError(f"会话不存在: {request.session_id}")
        # Record the user message in the local history.
        user_message = CozeMessage(
            message_id=str(uuid.uuid4()),
            session_id=session.session_id,
            role=MessageRole.USER,
            content=request.content,
        )
        self._messages[session.session_id].append(user_message)
        try:
            # Build the message history sent as context.
            # NOTE(review): this includes the user message appended just above
            # while auto_save_history=True also persists it server-side —
            # presumably intended, but verify messages are not duplicated.
            messages = self._build_message_history(session.session_id)
            # Call the Coze chat API.
            stream = await asyncio.to_thread(
                self.client.chat.stream,
                bot_id=session.bot_id,
                conversation_id=session.conversation_id,
                additional_messages=messages,
                auto_save_history=True,
            )
            # Forward the streamed events.
            async for event in self._process_stream(stream, session.session_id):
                yield event
        except asyncio.TimeoutError:
            logger.error(f"消息发送超时: session_id={request.session_id}")
            raise CozeTimeoutError("消息处理超时")
        except Exception as e:
            logger.error(f"发送消息失败: {e}", exc_info=True)
            raise map_coze_error_to_exception(e)
    async def end_session(
        self, session_id: str, request: EndSessionRequest
    ) -> EndSessionResponse:
        """
        End a session and compute its summary statistics.

        Args:
            session_id: session identifier
            request: end-session request (reason / feedback)

        Returns:
            EndSessionResponse: duration and message count

        Raises:
            CozeAPIError: if the session does not exist
        """
        session = self._get_session(session_id)
        if not session:
            raise CozeAPIError(f"会话不存在: {session_id}")
        # Mark the session as ended.
        session.ended_at = datetime.now()
        # Compute session statistics.
        duration_seconds = int((session.ended_at - session.created_at).total_seconds())
        message_count = len(self._messages.get(session_id, []))
        # Record the end reason and user feedback, if provided.
        if request.reason:
            session.metadata["end_reason"] = request.reason
        if request.feedback:
            session.metadata["feedback"] = request.feedback
        logger.info(
            f"会话结束",
            extra={
                "session_id": session_id,
                "duration_seconds": duration_seconds,
                "message_count": message_count,
                "reason": request.reason,
            },
        )
        return EndSessionResponse(
            session_id=session_id,
            ended_at=session.ended_at,
            duration_seconds=duration_seconds,
            message_count=message_count,
        )
    async def get_session_messages(
        self, session_id: str, limit: int = 50, offset: int = 0
    ) -> List[CozeMessage]:
        """Return a page of the session's message history."""
        messages = self._messages.get(session_id, [])
        return messages[offset : offset + limit]
    def _get_bot_id_by_type(self, session_type: SessionType) -> str:
        """Map a session type to its configured bot id (training is the fallback)."""
        mapping = {
            SessionType.COURSE_CHAT: self.bot_config["course_chat"],
            SessionType.TRAINING: self.bot_config["training"],
            SessionType.EXAM: self.bot_config["exam"],
        }
        return mapping.get(session_type, self.bot_config["training"])
    def _get_session(self, session_id: str) -> Optional[CozeSession]:
        """Look up a session record; None if unknown."""
        return self._sessions.get(session_id)
    def _build_message_history(self, session_id: str) -> List[Message]:
        """Build the SDK message list from local history."""
        messages = self._messages.get(session_id, [])
        history = []
        for msg in messages[-10:]:  # only the last 10 messages as context
            history.append(
                Message(
                    role=msg.role.value,
                    content=msg.content,
                    content_type=MessageContentType.TEXT,
                )
            )
        return history
    async def _process_stream(
        self, stream, session_id: str
    ) -> AsyncIterator[StreamEvent]:
        """Translate SDK stream events into StreamEvent objects.

        NOTE(review): the `for event in stream` loop iterates the SDK stream
        synchronously inside an async generator, which can block the event
        loop while waiting for network data — confirm whether the SDK offers
        an async stream or whether iteration should be offloaded to a thread.
        """
        assistant_message_id = str(uuid.uuid4())
        accumulated_content = []
        content_type = ContentType.TEXT
        try:
            for event in stream:
                if event.event == ChatEventType.CONVERSATION_MESSAGE_DELTA:
                    # Incremental content fragment.
                    content = event.message.content
                    accumulated_content.append(content)
                    # Detect card-type content.
                    if (
                        hasattr(event.message, "content_type")
                        and event.message.content_type == "card"
                    ):
                        content_type = ContentType.CARD
                    yield StreamEvent(
                        event=StreamEventType.MESSAGE_DELTA,
                        data={
                            "conversation_id": event.conversation_id,
                            "message_id": assistant_message_id,
                            "content": content,
                            "content_type": content_type.value,
                        },
                        message_id=assistant_message_id,
                        content=content,
                        content_type=content_type,
                        role=MessageRole.ASSISTANT,
                    )
                elif event.event == ChatEventType.CONVERSATION_MESSAGE_COMPLETED:
                    # Message finished: persist the full assistant reply.
                    full_content = "".join(accumulated_content)
                    # Save the assistant message locally.
                    assistant_message = CozeMessage(
                        message_id=assistant_message_id,
                        session_id=session_id,
                        role=MessageRole.ASSISTANT,
                        content=full_content,
                        content_type=content_type,
                    )
                    self._messages[session_id].append(assistant_message)
                    yield StreamEvent(
                        event=StreamEventType.MESSAGE_COMPLETED,
                        data={
                            "conversation_id": event.conversation_id,
                            "message_id": assistant_message_id,
                            "content": full_content,
                            "content_type": content_type.value,
                            "usage": getattr(event, "usage", {}),
                        },
                        message_id=assistant_message_id,
                        content=full_content,
                        content_type=content_type,
                        role=MessageRole.ASSISTANT,
                    )
                elif event.event == ChatEventType.ERROR:
                    # Error event from the SDK stream.
                    yield StreamEvent(
                        event=StreamEventType.ERROR,
                        data={"error": str(event)},
                        error=str(event),
                    )
        except Exception as e:
            logger.error(f"流式处理错误: {e}", exc_info=True)
            yield StreamEvent(
                event=StreamEventType.ERROR, data={"error": str(e)}, error=str(e)
            )
        finally:
            # Always terminate the stream with a DONE event.
            yield StreamEvent(
                event=StreamEventType.DONE, data={"session_id": session_id}
            )
# Module-level singleton holder
_service: Optional[CozeService] = None
def get_coze_service() -> CozeService:
    """Return the shared CozeService, creating it lazily on first use."""
    global _service
    if _service is not None:
        return _service
    _service = CozeService()
    return _service

View File

@@ -0,0 +1,512 @@
"""
试题生成服务 V2 - Python 原生实现
功能:
- 根据岗位和知识点动态生成考试题目
- 支持错题重出模式
- 调用 AI 生成并解析 JSON 结果
提供稳定可靠的试题生成能力。
"""
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import ExternalServiceError
from .ai_service import AIService, AIResponse
from .llm_json_parser import parse_with_fallback, clean_llm_output
from .prompts.exam_generator_prompts import (
SYSTEM_PROMPT,
USER_PROMPT,
MISTAKE_REGEN_SYSTEM_PROMPT,
MISTAKE_REGEN_USER_PROMPT,
QUESTION_SCHEMA,
DEFAULT_QUESTION_COUNTS,
DEFAULT_DIFFICULTY_LEVEL,
)
logger = logging.getLogger(__name__)
@dataclass
class ExamGeneratorConfig:
    """Configuration for one exam-generation run.

    Per-type question counts default to the module-level defaults; a
    non-blank ``mistake_records`` switches generation to regen mode.
    """
    course_id: int
    position_id: int
    single_choice_count: int = DEFAULT_QUESTION_COUNTS["single_choice_count"]
    multiple_choice_count: int = DEFAULT_QUESTION_COUNTS["multiple_choice_count"]
    true_false_count: int = DEFAULT_QUESTION_COUNTS["true_false_count"]
    fill_blank_count: int = DEFAULT_QUESTION_COUNTS["fill_blank_count"]
    essay_count: int = DEFAULT_QUESTION_COUNTS["essay_count"]
    difficulty_level: int = DEFAULT_DIFFICULTY_LEVEL
    mistake_records: str = ""
    @property
    def total_count(self) -> int:
        """Total number of questions across every question type."""
        return sum((
            self.single_choice_count,
            self.multiple_choice_count,
            self.true_false_count,
            self.fill_blank_count,
            self.essay_count,
        ))
    @property
    def has_mistakes(self) -> bool:
        """True when a non-blank mistake-record payload was supplied."""
        return bool(self.mistake_records and self.mistake_records.strip())
class ExamGeneratorService:
    """
    Exam question generation service V2 (native Python implementation).

    Usage:
    ```python
    service = ExamGeneratorService()
    result = await service.generate_exam(
        db=db_session,
        config=ExamGeneratorConfig(
            course_id=1,
            position_id=1,
            single_choice_count=5,
            multiple_choice_count=3,
            difficulty_level=3
        )
    )
    ```
    """
    def __init__(self):
        """Initialize with the exam-generator AI module."""
        self.ai_service = AIService(module_code="exam_generator")
    async def generate_exam(
        self,
        db: AsyncSession,
        config: ExamGeneratorConfig
    ) -> Dict[str, Any]:
        """
        Generate exam questions (main entry point).

        Args:
            db: database session
            config: exam generation configuration

        Returns:
            Result dict with success, questions, total_count, mode and AI stats.

        Raises:
            ExternalServiceError: on lookup or AI failures.
        """
        try:
            logger.info(
                f"开始生成试题 - course_id: {config.course_id}, position_id: {config.position_id}, "
                f"total_count: {config.total_count}, has_mistakes: {config.has_mistakes}"
            )
            # Branch on whether mistake records were supplied.
            if config.has_mistakes:
                return await self._regenerate_from_mistakes(db, config)
            else:
                return await self._generate_from_knowledge(db, config)
        except ExternalServiceError:
            raise
        except Exception as e:
            logger.error(
                f"试题生成失败 - course_id: {config.course_id}, error: {e}",
                exc_info=True
            )
            raise ExternalServiceError(f"试题生成失败: {e}")
    async def _generate_from_knowledge(
        self,
        db: AsyncSession,
        config: ExamGeneratorConfig
    ) -> Dict[str, Any]:
        """
        Generate questions from knowledge points (no-mistakes mode).

        Flow:
        1. Query position info
        2. Randomly sample knowledge points
        3. Call the AI to generate questions
        4. Parse and return the result
        """
        # 1. Query the position record.
        position_info = await self._query_position(db, config.position_id)
        if not position_info:
            raise ExternalServiceError(f"岗位不存在: position_id={config.position_id}")
        logger.info(f"岗位信息: {position_info.get('name', 'unknown')}")
        # 2. Randomly sample knowledge points.
        knowledge_points = await self._query_knowledge_points(
            db,
            config.course_id,
            config.total_count
        )
        if not knowledge_points:
            raise ExternalServiceError(
                f"课程没有可用的知识点: course_id={config.course_id}"
            )
        logger.info(f"查询到 {len(knowledge_points)} 个知识点")
        # 3. Build the prompts.
        system_prompt = SYSTEM_PROMPT.format(
            total_count=config.total_count,
            single_choice_count=config.single_choice_count,
            multiple_choice_count=config.multiple_choice_count,
            true_false_count=config.true_false_count,
            fill_blank_count=config.fill_blank_count,
            essay_count=config.essay_count,
            difficulty_level=config.difficulty_level,
        )
        user_prompt = USER_PROMPT.format(
            position_info=self._format_position_info(position_info),
            knowledge_points=self._format_knowledge_points(knowledge_points),
        )
        # 4. Call the AI.
        ai_response = await self._call_ai_generate(system_prompt, user_prompt)
        logger.info(
            f"AI 生成完成 - provider: {ai_response.provider}, "
            f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
        )
        # 5. Parse the questions.
        questions = self._parse_questions(ai_response.content)
        logger.info(f"试题解析成功,数量: {len(questions)}")
        return {
            "success": True,
            "questions": questions,
            "total_count": len(questions),
            "mode": "knowledge_based",
            "ai_provider": ai_response.provider,
            "ai_model": ai_response.model,
            "ai_tokens": ai_response.total_tokens,
            "ai_latency_ms": ai_response.latency_ms,
        }
    async def _regenerate_from_mistakes(
        self,
        db: AsyncSession,
        config: ExamGeneratorConfig
    ) -> Dict[str, Any]:
        """
        Regenerate questions from mistake records.

        Flow:
        1. Build regen prompts from the mistake records
        2. Call the AI to generate new questions
        3. Parse and return the result
        """
        logger.info("进入错题重出模式")
        # 1. Build the prompts.
        system_prompt = MISTAKE_REGEN_SYSTEM_PROMPT.format(
            difficulty_level=config.difficulty_level,
        )
        user_prompt = MISTAKE_REGEN_USER_PROMPT.format(
            mistake_records=config.mistake_records,
        )
        # 2. Call the AI.
        ai_response = await self._call_ai_generate(system_prompt, user_prompt)
        logger.info(
            f"错题重出完成 - provider: {ai_response.provider}, "
            f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
        )
        # 3. Parse the questions.
        questions = self._parse_questions(ai_response.content)
        logger.info(f"错题重出解析成功,数量: {len(questions)}")
        return {
            "success": True,
            "questions": questions,
            "total_count": len(questions),
            "mode": "mistake_regen",
            "ai_provider": ai_response.provider,
            "ai_model": ai_response.model,
            "ai_tokens": ai_response.total_tokens,
            "ai_latency_ms": ai_response.latency_ms,
        }
    async def _query_position(
        self,
        db: AsyncSession,
        position_id: int
    ) -> Optional[Dict[str, Any]]:
        """
        Query the position record by id (soft-deleted rows excluded).

        Returns a plain dict or None when the position does not exist.
        """
        try:
            result = await db.execute(
                text("""
                    SELECT id, name, description, skills, level
                    FROM positions
                    WHERE id = :position_id AND is_deleted = FALSE
                """),
                {"position_id": position_id}
            )
            row = result.fetchone()
            if not row:
                return None
            # Convert the Row to a dict.
            return {
                "id": row[0],
                "name": row[1],
                "description": row[2],
                "skills": row[3],  # JSON column
                "level": row[4],
            }
        except Exception as e:
            logger.error(f"查询岗位信息失败: {e}")
            raise ExternalServiceError(f"查询岗位信息失败: {e}")
    async def _query_knowledge_points(
        self,
        db: AsyncSession,
        course_id: int,
        limit: int
    ) -> List[Dict[str, Any]]:
        """
        Randomly sample up to ``limit`` knowledge points of the course.

        Uses MySQL's RAND() for random ordering; soft-deleted knowledge
        points and materials are excluded.
        """
        try:
            result = await db.execute(
                text("""
                    SELECT kp.id, kp.name, kp.description, kp.topic_relation
                    FROM knowledge_points kp
                    INNER JOIN course_materials cm ON kp.material_id = cm.id
                    WHERE kp.course_id = :course_id
                    AND kp.is_deleted = FALSE
                    AND cm.is_deleted = FALSE
                    ORDER BY RAND()
                    LIMIT :limit
                """),
                {"course_id": course_id, "limit": limit}
            )
            rows = result.fetchall()
            return [
                {
                    "id": row[0],
                    "name": row[1],
                    "description": row[2],
                    "topic_relation": row[3],
                }
                for row in rows
            ]
        except Exception as e:
            logger.error(f"查询知识点失败: {e}")
            raise ExternalServiceError(f"查询知识点失败: {e}")
    async def _call_ai_generate(
        self,
        system_prompt: str,
        user_prompt: str
    ) -> AIResponse:
        """Send the prompts to the AI service and return its response."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        response = await self.ai_service.chat(
            messages=messages,
            temperature=0.7,  # moderate creativity for question variety
            prompt_name="exam_generator"
        )
        return response
    def _parse_questions(self, ai_output: str) -> List[Dict[str, Any]]:
        """
        Parse the AI question JSON with multi-layer fallback parsing,
        then normalize each question's fields.
        """
        # Clean the raw output first.
        cleaned_output, rules = clean_llm_output(ai_output)
        if rules:
            logger.debug(f"AI 输出已清洗: {rules}")
        # Parse with schema validation and a safe default.
        questions = parse_with_fallback(
            cleaned_output,
            schema=QUESTION_SCHEMA,
            default=[],
            validate_schema=True,
            on_error="default"
        )
        # Post-process: ensure each question carries the required fields.
        processed_questions = []
        for i, q in enumerate(questions):
            if isinstance(q, dict):
                # Ensure a `num` field exists.
                if "num" not in q:
                    q["num"] = i + 1
                # Coerce `num` to int, falling back to the position.
                try:
                    q["num"] = int(q["num"])
                except (ValueError, TypeError):
                    q["num"] = i + 1
                # Ensure a `type` field exists.
                if "type" not in q:
                    # Infer from the presence of options.
                    if q.get("topic", {}).get("options"):
                        q["type"] = "single_choice"
                    else:
                        q["type"] = "essay"
                # `knowledge_point_id` must be an int or None.
                kp_id = q.get("knowledge_point_id")
                if kp_id is not None:
                    try:
                        q["knowledge_point_id"] = int(kp_id)
                    except (ValueError, TypeError):
                        q["knowledge_point_id"] = None
                # Keep only questions with both topic and answer.
                if q.get("topic") and q.get("correct"):
                    processed_questions.append(q)
                else:
                    logger.warning(f"题目缺少必要字段,已跳过: {q}")
        if not processed_questions:
            logger.warning("未能解析出有效的题目")
        return processed_questions
    def _format_position_info(self, position: Dict[str, Any]) -> str:
        """Render the position record as prompt-ready text."""
        lines = [
            f"岗位名称: {position.get('name', '未知')}",
            f"岗位等级: {position.get('level', '未设置')}",
        ]
        if position.get('description'):
            lines.append(f"岗位描述: {position['description']}")
        skills = position.get('skills')
        if skills:
            # `skills` may arrive as a JSON string or a list.
            if isinstance(skills, str):
                try:
                    skills = json.loads(skills)
                except json.JSONDecodeError:
                    skills = [skills]
            if isinstance(skills, list) and skills:
                lines.append(f"核心技能: {', '.join(str(s) for s in skills)}")
        return '\n'.join(lines)
    def _format_knowledge_points(self, knowledge_points: List[Dict[str, Any]]) -> str:
        """Render the knowledge points as prompt-ready text."""
        lines = []
        for kp in knowledge_points:
            kp_text = f"【知识点 ID: {kp['id']}{kp['name']}"
            if kp.get('description'):
                kp_text += f"\n{kp['description']}"
            if kp.get('topic_relation'):
                kp_text += f"\n关系描述: {kp['topic_relation']}"
            lines.append(kp_text)
        return '\n\n'.join(lines)
# Module-level service instance
exam_generator_service = ExamGeneratorService()
# ==================== Convenience helpers ====================
async def generate_exam(
    db: AsyncSession,
    course_id: int,
    position_id: int,
    single_choice_count: int = 4,
    multiple_choice_count: int = 2,
    true_false_count: int = 1,
    fill_blank_count: int = 2,
    essay_count: int = 1,
    difficulty_level: int = 3,
    mistake_records: str = ""
) -> Dict[str, Any]:
    """
    Convenience wrapper: build an ExamGeneratorConfig and delegate to the
    module-level ExamGeneratorService.

    Args:
        db: database session
        course_id: course id
        position_id: position id
        single_choice_count: number of single-choice questions
        multiple_choice_count: number of multiple-choice questions
        true_false_count: number of true/false questions
        fill_blank_count: number of fill-in-the-blank questions
        essay_count: number of essay questions
        difficulty_level: difficulty (1-5)
        mistake_records: mistake records as a JSON string (enables regen mode)

    Returns:
        Generation result dict from ExamGeneratorService.generate_exam.
    """
    return await exam_generator_service.generate_exam(
        db,
        ExamGeneratorConfig(
            course_id=course_id,
            position_id=position_id,
            single_choice_count=single_choice_count,
            multiple_choice_count=multiple_choice_count,
            true_false_count=true_false_count,
            fill_blank_count=fill_blank_count,
            essay_count=essay_count,
            difficulty_level=difficulty_level,
            mistake_records=mistake_records,
        ),
    )

View File

@@ -0,0 +1,548 @@
"""
知识点分析服务 V2 - Python 原生实现
功能:
- 读取文档内容PDF/Word/TXT
- 调用 AI 分析提取知识点
- 解析 JSON 结果
- 写入数据库
提供稳定可靠的知识点分析能力。
"""
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.exceptions import ExternalServiceError
from app.schemas.course import KnowledgePointCreate
from .ai_service import AIService, AIResponse
from .llm_json_parser import parse_with_fallback, clean_llm_output
from .prompts.knowledge_analysis_prompts import (
SYSTEM_PROMPT,
USER_PROMPT,
KNOWLEDGE_POINT_SCHEMA,
DEFAULT_KNOWLEDGE_TYPE,
)
logger = logging.getLogger(__name__)
# Configuration constants
STATIC_UPLOADS_PREFIX = '/static/uploads/'
MAX_CONTENT_LENGTH = 100000  # Maximum document content length (characters)
MAX_KNOWLEDGE_POINTS = 20  # Maximum number of knowledge points kept per analysis
class KnowledgeAnalysisServiceV2:
    """
    Knowledge-point analysis service V2 (native Python implementation).

    Usage:
    ```python
    service = KnowledgeAnalysisServiceV2()
    result = await service.analyze_course_material(
        db=db_session,
        course_id=1,
        material_id=10,
        file_url="/static/uploads/courses/1/doc.pdf",
        course_title="医美产品知识",
        user_id=1
    )
    ```
    """
    def __init__(self):
        """Initialize with the knowledge-analysis AI module and upload root."""
        self.ai_service = AIService(module_code="knowledge_analysis")
        self.upload_path = getattr(settings, 'UPLOAD_PATH', 'uploads')
    async def analyze_course_material(
        self,
        db: AsyncSession,
        course_id: int,
        material_id: int,
        file_url: str,
        course_title: str,
        user_id: int
    ) -> Dict[str, Any]:
        """
        Analyze one course material and extract knowledge points.

        Args:
            db: database session
            course_id: course id
            material_id: material id
            file_url: file URL (relative path)
            course_title: course title
            user_id: id of the user requesting the analysis

        Returns:
            Result dict with success, knowledge_points_count and AI stats.

        Raises:
            ExternalServiceError: on missing files, bad input or AI failures.
        """
        try:
            logger.info(
                f"开始知识点分析 V2 - course_id: {course_id}, material_id: {material_id}, "
                f"file_url: {file_url}"
            )
            # 1. Resolve the file URL to a local path.
            file_path = self._resolve_file_path(file_url)
            if not file_path.exists():
                raise FileNotFoundError(f"文件不存在: {file_path}")
            logger.info(f"文件路径解析成功: {file_path}")
            # 2. Extract the document content.
            content = await self._extract_document_content(file_path)
            if not content or not content.strip():
                raise ValueError("文档内容为空")
            logger.info(f"文档内容提取成功,长度: {len(content)} 字符")
            # 3. Run the AI analysis.
            ai_response = await self._call_ai_analysis(content, course_title)
            logger.info(
                f"AI 分析完成 - provider: {ai_response.provider}, "
                f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
            )
            # 4. Parse the JSON result.
            knowledge_points = self._parse_knowledge_points(ai_response.content)
            logger.info(f"知识点解析成功,数量: {len(knowledge_points)}")
            # 5. Delete the material's previous knowledge points.
            await self._delete_old_knowledge_points(db, material_id)
            # 6. Persist the new knowledge points.
            saved_count = await self._save_knowledge_points(
                db=db,
                course_id=course_id,
                material_id=material_id,
                knowledge_points=knowledge_points,
                user_id=user_id
            )
            logger.info(
                f"知识点分析完成 - course_id: {course_id}, material_id: {material_id}, "
                f"saved_count: {saved_count}"
            )
            return {
                "success": True,
                "status": "completed",
                "knowledge_points_count": saved_count,
                "ai_provider": ai_response.provider,
                "ai_model": ai_response.model,
                "ai_tokens": ai_response.total_tokens,
                "ai_latency_ms": ai_response.latency_ms,
            }
        except FileNotFoundError as e:
            logger.error(f"文件不存在: {e}")
            raise ExternalServiceError(f"分析文件不存在: {e}")
        except ValueError as e:
            logger.error(f"参数错误: {e}")
            raise ExternalServiceError(f"分析参数错误: {e}")
        except Exception as e:
            logger.error(
                f"知识点分析失败 - course_id: {course_id}, material_id: {material_id}, "
                f"error: {e}",
                exc_info=True
            )
            raise ExternalServiceError(f"知识点分析失败: {e}")
    def _resolve_file_path(self, file_url: str) -> Path:
        """Resolve a file URL to a local filesystem path."""
        if file_url.startswith(STATIC_UPLOADS_PREFIX):
            relative_path = file_url.replace(STATIC_UPLOADS_PREFIX, '')
            return Path(self.upload_path) / relative_path
        elif file_url.startswith('/'):
            # Absolute path.
            return Path(file_url)
        else:
            # Relative path under the upload root.
            return Path(self.upload_path) / file_url
    async def _extract_document_content(self, file_path: Path) -> str:
        """
        Extract text content from a document.

        Supports PDF, Word (docx) and plain-text files; anything else is
        attempted as plain text.
        """
        suffix = file_path.suffix.lower()
        try:
            if suffix == '.pdf':
                return await self._extract_pdf_content(file_path)
            elif suffix in ['.docx', '.doc']:
                return await self._extract_docx_content(file_path)
            elif suffix in ['.txt', '.md', '.text']:
                return await self._extract_text_content(file_path)
            else:
                # Fall back to reading as text.
                return await self._extract_text_content(file_path)
        except Exception as e:
            logger.error(f"文档内容提取失败: {file_path}, error: {e}")
            raise ValueError(f"无法读取文档内容: {e}")
    async def _extract_pdf_content(self, file_path: Path) -> str:
        """Extract text from a PDF via PyPDF2 (import failure is reported)."""
        try:
            from PyPDF2 import PdfReader
            reader = PdfReader(str(file_path))
            text_parts = []
            for page in reader.pages:
                text = page.extract_text()
                if text:
                    text_parts.append(text)
            content = '\n'.join(text_parts)
            # Clean and truncate.
            content = self._clean_content(content)
            return content
        except ImportError:
            logger.error("PyPDF2 未安装,无法读取 PDF")
            raise ValueError("服务器未安装 PDF 读取组件")
        except Exception as e:
            logger.error(f"PDF 读取失败: {e}")
            raise ValueError(f"PDF 读取失败: {e}")
    async def _extract_docx_content(self, file_path: Path) -> str:
        """Extract text (paragraphs and table cells) from a Word document."""
        try:
            from docx import Document
            doc = Document(str(file_path))
            text_parts = []
            for para in doc.paragraphs:
                if para.text.strip():
                    text_parts.append(para.text)
            # Also extract table contents.
            for table in doc.tables:
                for row in table.rows:
                    for cell in row.cells:
                        if cell.text.strip():
                            text_parts.append(cell.text)
            content = '\n'.join(text_parts)
            content = self._clean_content(content)
            return content
        except ImportError:
            logger.error("python-docx 未安装,无法读取 Word 文档")
            raise ValueError("服务器未安装 Word 读取组件")
        except Exception as e:
            logger.error(f"Word 文档读取失败: {e}")
            raise ValueError(f"Word 文档读取失败: {e}")
    async def _extract_text_content(self, file_path: Path) -> str:
        """Read a text file, trying several encodings in order."""
        try:
            # Try multiple encodings.
            encodings = ['utf-8', 'gbk', 'gb2312', 'latin-1']
            for encoding in encodings:
                try:
                    with open(file_path, 'r', encoding=encoding) as f:
                        content = f.read()
                    return self._clean_content(content)
                except UnicodeDecodeError:
                    continue
            raise ValueError("无法识别文件编码")
        except Exception as e:
            logger.error(f"文本文件读取失败: {e}")
            raise ValueError(f"文本文件读取失败: {e}")
    def _clean_content(self, content: str) -> str:
        """Collapse excess whitespace and truncate over-long content."""
        # Remove redundant whitespace.
        import re
        content = re.sub(r'\n{3,}', '\n\n', content)
        content = re.sub(r' {2,}', ' ', content)
        # Truncate over-long content.
        if len(content) > MAX_CONTENT_LENGTH:
            logger.warning(f"文档内容过长,截断至 {MAX_CONTENT_LENGTH} 字符")
            content = content[:MAX_CONTENT_LENGTH] + "\n\n[内容已截断...]"
        return content.strip()
    async def _call_ai_analysis(
        self,
        content: str,
        course_title: str
    ) -> AIResponse:
        """Send the document content to the AI for knowledge-point analysis."""
        # Build the messages.
        user_message = USER_PROMPT.format(
            course_name=course_title,
            content=content
        )
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_message}
        ]
        # Call the AI.
        response = await self.ai_service.chat(
            messages=messages,
            temperature=0.1,  # low temperature for stable, structured output
            prompt_name="knowledge_analysis"
        )
        return response
    def _parse_knowledge_points(self, ai_output: str) -> List[Dict[str, Any]]:
        """
        Parse the AI knowledge-point JSON with multi-layer fallback parsing,
        then normalize fields (accepting several alternative key names).
        """
        # Clean the raw output first.
        cleaned_output, rules = clean_llm_output(ai_output)
        if rules:
            logger.debug(f"AI 输出已清洗: {rules}")
        # Parse with schema validation and a safe default.
        knowledge_points = parse_with_fallback(
            cleaned_output,
            schema=KNOWLEDGE_POINT_SCHEMA,
            default=[],
            validate_schema=True,
            on_error="default"
        )
        # Post-process: ensure each point carries the required fields.
        processed_points = []
        for i, kp in enumerate(knowledge_points):
            if i >= MAX_KNOWLEDGE_POINTS:
                logger.warning(f"知识点数量超过限制 {MAX_KNOWLEDGE_POINTS},截断")
                break
            if isinstance(kp, dict):
                # Extract fields, accepting several alternative key names.
                title = (
                    kp.get('title') or
                    kp.get('name') or
                    kp.get('知识点名称') or
                    f"知识点 {i + 1}"
                )
                content = (
                    kp.get('content') or
                    kp.get('description') or
                    kp.get('知识点描述') or
                    ''
                )
                kp_type = (
                    kp.get('type') or
                    kp.get('知识点类型') or
                    DEFAULT_KNOWLEDGE_TYPE
                )
                topic_relation = (
                    kp.get('topic_relation') or
                    kp.get('关系描述') or
                    ''
                )
                if title and (content or topic_relation):
                    processed_points.append({
                        'title': title[:200],  # cap the title length
                        'content': content,
                        'type': kp_type,
                        'topic_relation': topic_relation,
                    })
        if not processed_points:
            logger.warning("未能解析出有效的知识点")
        return processed_points
    async def _delete_old_knowledge_points(
        self,
        db: AsyncSession,
        material_id: int
    ) -> int:
        """Delete the material's previously stored knowledge points.

        NOTE(review): this commits the deletion before the new points are
        saved; if the subsequent save fails the old points are already gone.
        Confirm whether delete+save should share one transaction.
        """
        try:
            from sqlalchemy import text
            result = await db.execute(
                text("DELETE FROM knowledge_points WHERE material_id = :material_id"),
                {"material_id": material_id}
            )
            await db.commit()
            deleted_count = result.rowcount
            if deleted_count > 0:
                logger.info(f"已删除旧知识点: material_id={material_id}, count={deleted_count}")
            return deleted_count
        except Exception as e:
            logger.error(f"删除旧知识点失败: {e}")
            await db.rollback()
            raise
    async def _save_knowledge_points(
        self,
        db: AsyncSession,
        course_id: int,
        material_id: int,
        knowledge_points: List[Dict[str, Any]],
        user_id: int
    ) -> int:
        """Persist knowledge points; failures on single points are skipped."""
        from app.services.course_service import knowledge_point_service
        saved_count = 0
        for kp_data in knowledge_points:
            try:
                kp_create = KnowledgePointCreate(
                    name=kp_data['title'],
                    description=kp_data.get('content', ''),
                    type=kp_data.get('type', DEFAULT_KNOWLEDGE_TYPE),
                    source=1,  # source 1 = AI analysis
                    topic_relation=kp_data.get('topic_relation'),
                    material_id=material_id
                )
                await knowledge_point_service.create_knowledge_point(
                    db=db,
                    course_id=course_id,
                    point_in=kp_create,
                    created_by=user_id
                )
                saved_count += 1
            except Exception as e:
                logger.warning(
                    f"保存单个知识点失败: title={kp_data.get('title')}, error={e}"
                )
                continue
        return saved_count
    async def reanalyze_course_materials(
        self,
        db: AsyncSession,
        course_id: int,
        course_title: str,
        user_id: int
    ) -> Dict[str, Any]:
        """
        Re-run the analysis for every material of a course.

        Args:
            db: database session
            course_id: course id
            course_title: course title
            user_id: id of the user requesting the analysis

        Returns:
            Summary dict with per-material results and total counts.

        Raises:
            ExternalServiceError: if the overall run fails.
        """
        try:
            from app.services.course_service import course_service
            # Fetch every material of the course.
            materials = await course_service.get_course_materials(db, course_id=course_id)
            if not materials:
                return {
                    "success": True,
                    "message": "该课程暂无资料需要分析",
                    "materials_count": 0,
                    "knowledge_points_count": 0
                }
            total_knowledge_points = 0
            analysis_results = []
            for material in materials:
                try:
                    result = await self.analyze_course_material(
                        db=db,
                        course_id=course_id,
                        material_id=material.id,
                        file_url=material.file_url,
                        course_title=course_title,
                        user_id=user_id
                    )
                    kp_count = result.get('knowledge_points_count', 0)
                    total_knowledge_points += kp_count
                    analysis_results.append({
                        "material_id": material.id,
                        "material_name": material.name,
                        "success": True,
                        "knowledge_points_count": kp_count
                    })
                except Exception as e:
                    # One failing material does not abort the whole run.
                    logger.error(
                        f"资料分析失败: material_id={material.id}, error={e}"
                    )
                    analysis_results.append({
                        "material_id": material.id,
                        "material_name": material.name,
                        "success": False,
                        "error": str(e)
                    })
            success_count = sum(1 for r in analysis_results if r['success'])
            logger.info(
                f"课程资料重新分析完成 - course_id: {course_id}, "
                f"materials: {len(materials)}, success: {success_count}, "
                f"total_knowledge_points: {total_knowledge_points}"
            )
            return {
                "success": True,
                "materials_count": len(materials),
                "success_count": success_count,
                "knowledge_points_count": total_knowledge_points,
                "analysis_results": analysis_results
            }
        except Exception as e:
            logger.error(
                f"课程资料重新分析失败 - course_id: {course_id}, error: {e}",
                exc_info=True
            )
            raise ExternalServiceError(f"重新分析失败: {e}")
# Module-level singleton instance
knowledge_analysis_service_v2 = KnowledgeAnalysisServiceV2()

View File

@@ -0,0 +1,707 @@
"""
LLM JSON Parser - 大模型 JSON 输出解析器
功能:
- 使用 json-repair 库修复 AI 输出的 JSON
- 处理中文标点、尾部逗号、Python 风格等问题
- Schema 校验确保数据完整性
使用示例:
```python
from app.services.ai.llm_json_parser import parse_llm_json, parse_with_fallback
# 简单解析
result = parse_llm_json(ai_response)
# 带 Schema 校验和默认值
result = parse_with_fallback(
ai_response,
schema=MY_SCHEMA,
default=[]
)
```
"""
import json
import re
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
# 尝试导入 json-repair
try:
from json_repair import loads as json_repair_loads
from json_repair import repair_json
HAS_JSON_REPAIR = True
except ImportError:
HAS_JSON_REPAIR = False
logger.warning("json-repair 未安装,将使用内置修复逻辑")
# 尝试导入 jsonschema
try:
from jsonschema import validate, ValidationError, Draft7Validator
HAS_JSONSCHEMA = True
except ImportError:
HAS_JSONSCHEMA = False
logger.warning("jsonschema 未安装,将跳过 Schema 校验")
# ==================== 异常类 ====================
class JSONParseError(Exception):
    """Base error for LLM JSON parsing failures.

    Carries the raw text that failed to parse and the repair issues
    collected along the way.
    """
    def __init__(self, message: str, raw_text: str = "", issues: List[dict] = None):
        super().__init__(message)
        self.raw_text = raw_text
        self.issues = issues or []
class JSONUnrecoverableError(JSONParseError):
    """Raised when every repair strategy has failed."""
    pass
# ==================== 解析结果 ====================
@dataclass
class ParseResult:
    """Outcome of one parse attempt (data plus how it was obtained)."""
    success: bool
    data: Any = None
    method: str = ""  # direct / json_repair / preprocessed / fixed / completed / default
    issues: List[dict] = field(default_factory=list)
    raw_text: str = ""
    error: str = ""
# ==================== 核心解析函数 ====================
def parse_llm_json(
    text: str,
    *,
    strict: bool = False,
    return_result: bool = False
) -> Union[Any, ParseResult]:
    """
    Parse JSON emitted by an LLM, repairing common defects when needed.

    Tries progressively more aggressive strategies, returning at the first
    success:
      1. plain ``json.loads``
      2. the ``json-repair`` library (when installed)
      3. preprocessing (code-fence extraction, boundary trimming) + retries
      4. rule-based format fixes (punctuation, comments, trailing commas, ...)
      5. auto-completion of truncated JSON

    Args:
        text: Raw model output.
        strict: If True, only attempt the direct parse and fail otherwise.
        return_result: If True, return a ParseResult instead of the bare data.

    Returns:
        The parsed JSON value, or a ParseResult when ``return_result`` is True.

    Raises:
        JSONUnrecoverableError: When every strategy fails (and
            ``return_result`` is False).
    """
    if not text or not text.strip():
        if return_result:
            return ParseResult(success=False, error="Empty input")
        raise JSONUnrecoverableError("Empty input", text)
    text = text.strip()
    issues = []
    # Layer 1: direct parse.
    try:
        data = json.loads(text)
        result = ParseResult(success=True, data=data, method="direct", raw_text=text)
        return result if return_result else data
    except json.JSONDecodeError:
        pass
    if strict:
        if return_result:
            return ParseResult(success=False, error="Strict mode: direct parse failed", raw_text=text)
        raise JSONUnrecoverableError("Strict mode: direct parse failed", text)
    # Layer 2: json-repair library (preferred repair path when available).
    if HAS_JSON_REPAIR:
        try:
            data = json_repair_loads(text)
            issues.append({"type": "json_repair", "action": "Auto-repaired by json-repair library"})
            result = ParseResult(success=True, data=data, method="json_repair", issues=issues, raw_text=text)
            return result if return_result else data
        except Exception as e:
            logger.debug(f"json-repair 修复失败: {e}")
    # Layer 3: preprocess (extract code fences, strip surrounding prose), retry.
    preprocessed = _preprocess_text(text)
    if preprocessed != text:
        try:
            data = json.loads(preprocessed)
            issues.append({"type": "preprocessed", "action": "Extracted JSON from text"})
            result = ParseResult(success=True, data=data, method="preprocessed", issues=issues, raw_text=text)
            return result if return_result else data
        except json.JSONDecodeError:
            pass
        # Give json-repair a second chance on the preprocessed text.
        if HAS_JSON_REPAIR:
            try:
                data = json_repair_loads(preprocessed)
                issues.append({"type": "json_repair_preprocessed", "action": "Repaired after preprocessing"})
                result = ParseResult(success=True, data=data, method="json_repair", issues=issues, raw_text=text)
                return result if return_result else data
            except Exception:
                pass
    # Layer 4: rule-based format fixes.
    fixed, fix_issues = _fix_json_format(preprocessed)
    issues.extend(fix_issues)
    if fixed != preprocessed:
        try:
            data = json.loads(fixed)
            result = ParseResult(success=True, data=data, method="fixed", issues=issues, raw_text=text)
            return result if return_result else data
        except json.JSONDecodeError:
            pass
    # Layer 5: try to complete truncated JSON.
    completed = _try_complete_json(fixed)
    if completed:
        try:
            data = json.loads(completed)
            issues.append({"type": "completed", "action": "Auto-completed truncated JSON"})
            result = ParseResult(success=True, data=data, method="completed", issues=issues, raw_text=text)
            return result if return_result else data
        except json.JSONDecodeError:
            pass
    # Every strategy failed: diagnose and report.
    diagnosis = diagnose_json_error(fixed)
    if return_result:
        return ParseResult(
            success=False,
            method="failed",
            issues=issues + diagnosis.get("issues", []),
            raw_text=text,
            error=f"All parse attempts failed. Issues: {diagnosis}"
        )
    raise JSONUnrecoverableError(f"All parse attempts failed: {diagnosis}", text, issues)
def parse_with_fallback(
    raw_text: str,
    schema: dict = None,
    default: Any = None,
    *,
    validate_schema: bool = True,
    on_error: str = "default"  # "default" / "raise" / "none"
) -> Any:
    """
    Parse LLM JSON output with graceful degradation.

    Args:
        raw_text: Raw model output.
        schema: Optional JSON Schema to validate the parsed data against.
        default: Value returned on failure when ``on_error="default"``.
        validate_schema: Whether to run Schema validation at all.
        on_error: Failure policy -- "default" returns *default*,
            "none" returns None, "raise" propagates an exception.

    Returns:
        The parsed (and optionally validated) data, or the fallback value.
    """
    try:
        outcome = parse_llm_json(raw_text, return_result=True)
        if not outcome.success:
            logger.warning(f"JSON 解析失败: {outcome.error}")
            if on_error == "raise":
                raise JSONUnrecoverableError(outcome.error, raw_text, outcome.issues)
            return None if on_error == "none" else default
        parsed = outcome.data
        # Optional Schema validation (skipped when jsonschema is unavailable).
        if validate_schema and schema and HAS_JSONSCHEMA:
            ok, problems = validate_json_schema(parsed, schema)
            if not ok:
                logger.warning(f"Schema 校验失败: {problems}")
                if on_error == "raise":
                    raise JSONUnrecoverableError(f"Schema validation failed: {problems}", raw_text)
                return None if on_error == "none" else default
        # Record non-trivial parse paths for observability.
        if outcome.method != "direct":
            logger.info(f"JSON 解析成功: method={outcome.method}, issues={outcome.issues}")
        return parsed
    except Exception as e:
        logger.error(f"JSON 解析异常: {e}")
        if on_error == "raise":
            raise
        return None if on_error == "none" else default
# ==================== 预处理函数 ====================
def _preprocess_text(text: str) -> str:
    """Normalize raw LLM output: drop BOM/zero-width characters, unwrap
    Markdown code fences, and trim to the apparent JSON boundaries."""
    # Strip the BOM and invisible zero-width characters.
    cleaned = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text.lstrip('\ufeff'))
    # Unwrap Markdown code fences / inline code, preferring ```json blocks.
    fence_patterns = (
        r'```json\s*([\s\S]*?)\s*```',
        r'```\s*([\s\S]*?)\s*```',
        r'`([^`]+)`',
    )
    for fence in fence_patterns:
        hit = re.search(fence, cleaned, re.IGNORECASE)
        if hit is None:
            continue
        candidate = hit.group(1).strip()
        if candidate.startswith(('{', '[')):
            cleaned = candidate
            break
    # Trim leading/trailing prose around the JSON payload.
    return _find_json_boundaries(cleaned).strip()
def _find_json_boundaries(text: str) -> str:
"""找到 JSON 的起止位置"""
# 找第一个 { 或 [
start = -1
for i, c in enumerate(text):
if c in '{[':
start = i
break
if start == -1:
return text
# 找最后一个匹配的 } 或 ]
depth = 0
end = -1
in_string = False
escape = False
for i in range(start, len(text)):
c = text[i]
if escape:
escape = False
continue
if c == '\\':
escape = True
continue
if c == '"':
in_string = not in_string
continue
if in_string:
continue
if c in '{[':
depth += 1
elif c in '}]':
depth -= 1
if depth == 0:
end = i + 1
break
if end == -1:
# 找最后一个 } 或 ]
for i in range(len(text) - 1, start, -1):
if text[i] in '}]':
end = i + 1
break
if end > start:
return text[start:end]
return text[start:]
# ==================== 修复函数 ====================
def _fix_json_format(text: str) -> Tuple[str, List[dict]]:
"""修复常见 JSON 格式问题"""
issues = []
# 1. 中文标点转英文
cn_punctuation = {
'': ',', '': '.', '': ':', '': ';',
'"': '"', '"': '"', ''': "'", ''': "'",
'': '[', '': ']', '': '(', '': ')',
'': '{', '': '}',
}
for cn, en in cn_punctuation.items():
if cn in text:
text = text.replace(cn, en)
issues.append({"type": "chinese_punctuation", "from": cn, "to": en})
# 2. 移除注释
if '//' in text:
text = re.sub(r'//[^\n]*', '', text)
issues.append({"type": "removed_comments", "style": "single-line"})
if '/*' in text:
text = re.sub(r'/\*[\s\S]*?\*/', '', text)
issues.append({"type": "removed_comments", "style": "multi-line"})
# 3. Python 风格转 JSON
python_replacements = [
(r'\bTrue\b', 'true'),
(r'\bFalse\b', 'false'),
(r'\bNone\b', 'null'),
]
for pattern, replacement in python_replacements:
if re.search(pattern, text):
text = re.sub(pattern, replacement, text)
issues.append({"type": "python_style", "from": pattern, "to": replacement})
# 4. 移除尾部逗号
trailing_comma_patterns = [
(r',(\s*})', r'\1'),
(r',(\s*\])', r'\1'),
]
for pattern, replacement in trailing_comma_patterns:
if re.search(pattern, text):
text = re.sub(pattern, replacement, text)
issues.append({"type": "trailing_comma", "action": "removed"})
# 5. 修复单引号(谨慎处理)
if text.count("'") > text.count('"') and re.match(r"^\s*\{?\s*'", text):
text = re.sub(r"'([^']*)'(\s*:)", r'"\1"\2', text)
text = re.sub(r":\s*'([^']*)'", r': "\1"', text)
issues.append({"type": "single_quotes", "action": "replaced"})
return text, issues
def _try_complete_json(text: str) -> Optional[str]:
"""尝试补全截断的 JSON"""
if not text:
return None
# 统计括号
stack = []
in_string = False
escape = False
for c in text:
if escape:
escape = False
continue
if c == '\\':
escape = True
continue
if c == '"':
in_string = not in_string
continue
if in_string:
continue
if c in '{[':
stack.append(c)
elif c == '}':
if stack and stack[-1] == '{':
stack.pop()
elif c == ']':
if stack and stack[-1] == '[':
stack.pop()
if not stack:
return None # 已经平衡了
# 如果在字符串中,先闭合字符串
if in_string:
text += '"'
# 补全括号
completion = ""
for bracket in reversed(stack):
if bracket == '{':
completion += '}'
elif bracket == '[':
completion += ']'
return text + completion
# ==================== Schema 校验 ====================
def validate_json_schema(data: Any, schema: dict) -> Tuple[bool, List[dict]]:
    """
    Validate *data* against a JSON Schema (Draft 7).

    Degrades gracefully: when the optional ``jsonschema`` dependency is not
    installed, validation is skipped and the data is treated as valid.

    Returns:
        (is_valid, errors) -- *errors* is a list of dicts with the failing
        path, message, and validator name.
    """
    if not HAS_JSONSCHEMA:
        logger.warning("jsonschema 未安装,跳过校验")
        return True, []
    try:
        validator = Draft7Validator(schema)
        errors = list(validator.iter_errors(data))
        if errors:
            error_messages = [
                {
                    "path": list(e.absolute_path),
                    "message": e.message,
                    "validator": e.validator
                }
                for e in errors
            ]
            return False, error_messages
        return True, []
    except Exception as e:
        # The schema itself may be malformed; report rather than raise.
        return False, [{"message": str(e)}]
# ==================== 诊断函数 ====================
def diagnose_json_error(text: str) -> dict:
"""诊断 JSON 错误"""
issues = []
# 检查是否为空
if not text or not text.strip():
issues.append({
"type": "empty_input",
"severity": "critical",
"suggestion": "输入为空"
})
return {"issues": issues, "fixable": False}
# 检查中文标点
cn_punctuation = ['', '', '', '', '"', '"', ''', ''']
for p in cn_punctuation:
if p in text:
issues.append({
"type": "chinese_punctuation",
"char": p,
"severity": "low",
"suggestion": f"{p} 替换为对应英文标点"
})
# 检查代码块包裹
if '```' in text:
issues.append({
"type": "markdown_wrapped",
"severity": "low",
"suggestion": "需要提取代码块内容"
})
# 检查注释
if '//' in text or '/*' in text:
issues.append({
"type": "has_comments",
"severity": "low",
"suggestion": "需要移除注释"
})
# 检查 Python 风格
if re.search(r'\b(True|False|None)\b', text):
issues.append({
"type": "python_style",
"severity": "low",
"suggestion": "将 True/False/None 转为 true/false/null"
})
# 检查尾部逗号
if re.search(r',\s*[}\]]', text):
issues.append({
"type": "trailing_comma",
"severity": "low",
"suggestion": "移除 } 或 ] 前的逗号"
})
# 检查括号平衡
open_braces = text.count('{') - text.count('}')
open_brackets = text.count('[') - text.count(']')
if open_braces > 0:
issues.append({
"type": "unclosed_brace",
"count": open_braces,
"severity": "medium",
"suggestion": f"缺少 {open_braces}}}"
})
elif open_braces < 0:
issues.append({
"type": "extra_brace",
"count": -open_braces,
"severity": "medium",
"suggestion": f"多余 {-open_braces}}}"
})
if open_brackets > 0:
issues.append({
"type": "unclosed_bracket",
"count": open_brackets,
"severity": "medium",
"suggestion": f"缺少 {open_brackets} 个 ]"
})
elif open_brackets < 0:
issues.append({
"type": "extra_bracket",
"count": -open_brackets,
"severity": "medium",
"suggestion": f"多余 {-open_brackets} 个 ]"
})
# 检查引号平衡
quote_count = text.count('"')
if quote_count % 2 != 0:
issues.append({
"type": "unbalanced_quotes",
"severity": "high",
"suggestion": "引号数量不平衡,可能有未闭合的字符串"
})
# 判断是否可修复
fixable_types = {
"chinese_punctuation", "markdown_wrapped", "has_comments",
"python_style", "trailing_comma", "unclosed_brace", "unclosed_bracket"
}
fixable = all(i["type"] in fixable_types for i in issues)
return {
"issues": issues,
"issue_count": len(issues),
"fixable": fixable,
"severity": max(
(i.get("severity", "low") for i in issues),
key=lambda x: {"low": 1, "medium": 2, "high": 3, "critical": 4}.get(x, 0),
default="low"
)
}
# ==================== 便捷函数 ====================
def safe_json_loads(text: str, default: Any = None) -> Any:
    """Parse *text* leniently; return *default* instead of raising."""
    try:
        parsed = parse_llm_json(text)
    except Exception:
        return default
    return parsed
def extract_json_from_text(text: str) -> Optional[str]:
    """Extract a parseable JSON *string* from free-form text, or None."""
    candidate, _ = _fix_json_format(_preprocess_text(text))
    # Try the fixed text first, then its auto-completed variant.
    for attempt in (candidate, _try_complete_json(candidate)):
        if not attempt:
            continue
        try:
            json.loads(attempt)
        except Exception:
            continue
        return attempt
    return None
def clean_llm_output(text: str) -> Tuple[str, List[str]]:
    """
    Clean raw LLM output before JSON parsing.

    Removes the BOM, ANSI escapes, and zero-width characters, strips leading
    pleasantries, and unwraps Markdown JSON code fences.

    Args:
        text: Raw model output.

    Returns:
        (cleaned_text, applied_rules) -- *applied_rules* names each rule
        that actually fired.
    """
    if not text:
        return "", ["empty_input"]
    applied_rules = []
    # 1. Byte-order mark.
    if text.startswith('\ufeff'):
        text = text.lstrip('\ufeff')
        applied_rules.append("removed_bom")
    # 2. ANSI colour/escape sequences.
    ansi_pattern = re.compile(r'\x1b\[[0-9;]*m')
    if ansi_pattern.search(text):
        text = ansi_pattern.sub('', text)
        applied_rules.append("removed_ansi")
    # 3. Surrounding whitespace.
    text = text.strip()
    # 4. Leading pleasantries ("好的,...", "以下是:", etc.).
    # NOTE(review): the colon character classes had lost the full-width
    # colon in transit; restored as [::] (full-width + ASCII) -- confirm
    # against the original source.
    polite_patterns = [
        r'^好的[,。.]?\s*',
        r'^当然[,。.]?\s*',
        r'^没问题[,。.]?\s*',
        r'^根据您的要求[,。.]?\s*',
        r'^以下是.*?[::]\s*',
        r'^分析结果如下[::]\s*',
        r'^我来为您.*?[::]\s*',
        r'^这是.*?结果[::]\s*',
    ]
    for pattern in polite_patterns:
        if re.match(pattern, text, re.IGNORECASE):
            text = re.sub(pattern, '', text, flags=re.IGNORECASE)
            applied_rules.append("removed_polite_prefix")
            break
    # 5. Markdown JSON code fences.
    json_block_patterns = [
        r'```json\s*([\s\S]*?)\s*```',
        r'```\s*([\s\S]*?)\s*```',
    ]
    for pattern in json_block_patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            extracted = match.group(1).strip()
            if extracted.startswith(('{', '[')):
                text = extracted
                applied_rules.append("extracted_code_block")
                break
    # 6. Zero-width characters.
    zero_width = re.compile(r'[\u200b\u200c\u200d\ufeff]')
    if zero_width.search(text):
        text = zero_width.sub('', text)
        applied_rules.append("removed_zero_width")
    return text.strip(), applied_rules

View File

@@ -0,0 +1,377 @@
"""
陪练分析报告服务 - Python 原生实现
功能:
- 分析陪练对话历史
- 生成综合评分、能力维度评估
- 提供对话标注和改进建议
提供稳定可靠的陪练分析报告生成能力。
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from .ai_service import AIService, AIResponse
from .llm_json_parser import parse_with_fallback, clean_llm_output
from .prompts.practice_analysis_prompts import (
SYSTEM_PROMPT,
USER_PROMPT,
PRACTICE_ANALYSIS_SCHEMA,
SCORE_BREAKDOWN_ITEMS,
ABILITY_DIMENSIONS,
)
logger = logging.getLogger(__name__)
# ==================== 数据结构 ====================
@dataclass
class ScoreBreakdownItem:
    """One scored sub-item of the overall practice score."""
    name: str  # score-item name (canonical set comes from SCORE_BREAKDOWN_ITEMS)
    score: float  # numeric score for this item
    description: str  # textual evaluation for this item
@dataclass
class AbilityDimensionItem:
    """Score and feedback for a single ability dimension."""
    name: str  # dimension name (canonical set comes from ABILITY_DIMENSIONS)
    score: float  # numeric score for this dimension
    feedback: str  # evaluation text with improvement hints
@dataclass
class DialogueAnnotation:
    """Reviewer annotation attached to one dialogue turn."""
    sequence: int  # index of the annotated turn -- presumably 1-based; confirm against prompt
    tags: List[str]  # short labels for the turn
    comment: str  # free-text remark on the turn
@dataclass
class Suggestion:
    """An improvement suggestion produced by the analysis."""
    title: str  # short headline
    content: str  # what to improve and why
    example: str  # concrete example illustrating the suggestion
@dataclass
class PracticeAnalysisResult:
    """Full result of a practice-session analysis run."""
    success: bool
    total_score: float = 0.0
    score_breakdown: List[ScoreBreakdownItem] = field(default_factory=list)
    ability_dimensions: List[AbilityDimensionItem] = field(default_factory=list)
    dialogue_annotations: List[DialogueAnnotation] = field(default_factory=list)
    suggestions: List[Suggestion] = field(default_factory=list)
    ai_provider: str = ""
    ai_model: str = ""
    ai_tokens: int = 0
    ai_latency_ms: int = 0
    error: str = ""

    # -- serialization helpers -------------------------------------------
    def _scores_as_dicts(self) -> List[Dict[str, Any]]:
        return [{"name": item.name, "score": item.score, "description": item.description}
                for item in self.score_breakdown]

    def _dimensions_as_dicts(self) -> List[Dict[str, Any]]:
        return [{"name": item.name, "score": item.score, "feedback": item.feedback}
                for item in self.ability_dimensions]

    def _annotations_as_dicts(self) -> List[Dict[str, Any]]:
        return [{"sequence": item.sequence, "tags": item.tags, "comment": item.comment}
                for item in self.dialogue_annotations]

    def _suggestions_as_dicts(self) -> List[Dict[str, Any]]:
        return [{"title": item.title, "content": item.content, "example": item.example}
                for item in self.suggestions]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the legacy API payload shape."""
        return {
            "analysis": {
                "total_score": self.total_score,
                "score_breakdown": self._scores_as_dicts(),
                "ability_dimensions": self._dimensions_as_dicts(),
                "dialogue_annotations": self._annotations_as_dicts(),
                "suggestions": self._suggestions_as_dicts(),
            },
            "ai_provider": self.ai_provider,
            "ai_model": self.ai_model,
            "ai_tokens": self.ai_tokens,
            "ai_latency_ms": self.ai_latency_ms,
        }

    def to_db_format(self) -> Dict[str, Any]:
        """Serialize to the PracticeReport persistence shape (note the
        ``dialogue_review`` key and integer total score)."""
        return {
            "total_score": int(self.total_score),
            "score_breakdown": self._scores_as_dicts(),
            "ability_dimensions": self._dimensions_as_dicts(),
            "dialogue_review": self._annotations_as_dicts(),
            "suggestions": self._suggestions_as_dicts(),
        }
# ==================== 服务类 ====================
class PracticeAnalysisService:
    """
    Practice-session analysis service (陪练分析报告服务).

    Sends a coaching-practice dialogue to the LLM and returns an overall
    score, per-item score breakdown, per-dimension ability scores,
    turn-level annotations, and improvement suggestions.

    Note: this class is exposed as a module-level singleton, so per-request
    state (the AIService bound to a request's DB session) is kept local to
    each ``analyze`` call instead of being stored on ``self`` -- the
    previous implementation mutated ``self._ai_service`` per call, which
    races under concurrent requests.

    Example:
        ```python
        service = PracticeAnalysisService()
        result = await service.analyze(
            db=db_session,  # used for AI call logging
            dialogue_history=[
                {"speaker": "user", "content": "您好,我想咨询一下..."},
                {"speaker": "ai", "content": "您好!很高兴为您服务..."}
            ]
        )
        print(result.total_score)
        print(result.suggestions)
        ```
    """
    MODULE_CODE = "practice_analysis"

    async def analyze(
        self,
        dialogue_history: List[Dict[str, Any]],
        db: Any = None  # DB session used to persist AI call logs
    ) -> "PracticeAnalysisResult":
        """
        Analyze a practice dialogue.

        Args:
            dialogue_history: Dialogue turns; each item carries ``speaker``,
                ``content``, and optionally ``timestamp``.
            db: Database session used to record the AI call log (per the AI
                integration regulation); may be None.

        Returns:
            PracticeAnalysisResult -- on any failure, ``success=False`` with
            ``error`` set; this method does not raise.
        """
        try:
            logger.info(f"开始分析陪练对话 - 对话轮次: {len(dialogue_history)}")
            # 1. Validate input: at least two turns are needed for a report.
            if not dialogue_history or len(dialogue_history) < 2:
                return PracticeAnalysisResult(
                    success=False,
                    error="对话记录太少无法生成分析报告至少需要2轮对话"
                )
            # 2. Flatten the dialogue into prompt-ready text.
            dialogue_text = self._format_dialogue_history(dialogue_history)
            # 3. Build a request-scoped AIService. Kept local (not on self)
            #    because this object is shared as a module-level singleton.
            ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
            # 4. Call the model.
            ai_response = await self._call_ai_analysis(ai_service, dialogue_text)
            logger.info(
                f"AI 分析完成 - provider: {ai_response.provider}, "
                f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
            )
            # 5. Parse the JSON analysis payload (with fallbacks/defaults).
            analysis_data = self._parse_analysis_result(ai_response.content)
            # 6. Assemble the structured result.
            result = PracticeAnalysisResult(
                success=True,
                total_score=analysis_data.get("total_score", 0),
                score_breakdown=[
                    ScoreBreakdownItem(
                        name=s.get("name", ""),
                        score=s.get("score", 0),
                        description=s.get("description", "")
                    )
                    for s in analysis_data.get("score_breakdown", [])
                ],
                ability_dimensions=[
                    AbilityDimensionItem(
                        name=d.get("name", ""),
                        score=d.get("score", 0),
                        feedback=d.get("feedback", "")
                    )
                    for d in analysis_data.get("ability_dimensions", [])
                ],
                dialogue_annotations=[
                    DialogueAnnotation(
                        sequence=a.get("sequence", 0),
                        tags=a.get("tags", []),
                        comment=a.get("comment", "")
                    )
                    for a in analysis_data.get("dialogue_annotations", [])
                ],
                suggestions=[
                    Suggestion(
                        title=s.get("title", ""),
                        content=s.get("content", ""),
                        example=s.get("example", "")
                    )
                    for s in analysis_data.get("suggestions", [])
                ],
                ai_provider=ai_response.provider,
                ai_model=ai_response.model,
                ai_tokens=ai_response.total_tokens,
                ai_latency_ms=ai_response.latency_ms,
            )
            logger.info(
                f"陪练分析完成 - total_score: {result.total_score}, "
                f"annotations: {len(result.dialogue_annotations)}, "
                f"suggestions: {len(result.suggestions)}"
            )
            return result
        except Exception as e:
            logger.error(f"陪练分析失败: {e}", exc_info=True)
            return PracticeAnalysisResult(
                success=False,
                error=str(e)
            )

    def _format_dialogue_history(self, dialogue_history: List[Dict[str, Any]]) -> str:
        """
        Render the dialogue turns as numbered "[n] role: text" lines.

        Speaker identifiers are normalized to the two roles 员工 (staff)
        and 顾客 (customer); unknown identifiers pass through unchanged.
        """
        lines = []
        for i, d in enumerate(dialogue_history, 1):
            speaker = d.get('speaker', 'unknown')
            content = d.get('content', '')
            # Normalize the many possible speaker aliases to two roles.
            if speaker in ['user', 'employee', 'consultant', '员工', '用户']:
                speaker_label = '员工'
            elif speaker in ['ai', 'customer', 'client', '顾客', '客户', 'AI']:
                speaker_label = '顾客'
            else:
                speaker_label = speaker
            lines.append(f"[{i}] {speaker_label}: {content}")
        return '\n'.join(lines)

    async def _call_ai_analysis(self, ai_service: "AIService", dialogue_text: str) -> "AIResponse":
        """Send the analysis prompt via the request-scoped AIService."""
        user_message = USER_PROMPT.format(dialogue_history=dialogue_text)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_message}
        ]
        response = await ai_service.chat(
            messages=messages,
            temperature=0.7,
            prompt_name="practice_analysis"
        )
        return response

    def _parse_analysis_result(self, ai_output: str) -> Dict[str, Any]:
        """
        Parse the model's JSON output with multi-layer fallbacks, then patch
        in neutral defaults so every expected score item / ability dimension
        is present and at least one suggestion exists.
        """
        # Clean the raw output first (fences, BOM, pleasantries, ...).
        cleaned_output, rules = clean_llm_output(ai_output)
        if rules:
            logger.debug(f"AI 输出已清洗: {rules}")
        # Schema-validated parse with a safe default shape.
        parsed = parse_with_fallback(
            cleaned_output,
            schema=PRACTICE_ANALYSIS_SCHEMA,
            default={"analysis": {}},
            validate_schema=True,
            on_error="default"
        )
        analysis = parsed.get("analysis", {})
        # Fill in any missing score-breakdown items with neutral defaults.
        existing_breakdown = {s.get("name") for s in analysis.get("score_breakdown", [])}
        for item_name in SCORE_BREAKDOWN_ITEMS:
            if item_name not in existing_breakdown:
                logger.warning(f"缺少分数维度: {item_name},使用默认值")
                analysis.setdefault("score_breakdown", []).append({
                    "name": item_name,
                    "score": 75,
                    "description": "暂无详细评价"
                })
        # Fill in any missing ability dimensions likewise.
        existing_dims = {d.get("name") for d in analysis.get("ability_dimensions", [])}
        for dim_name in ABILITY_DIMENSIONS:
            if dim_name not in existing_dims:
                logger.warning(f"缺少能力维度: {dim_name},使用默认值")
                analysis.setdefault("ability_dimensions", []).append({
                    "name": dim_name,
                    "score": 75,
                    "feedback": "暂无详细评价"
                })
        # Guarantee at least one suggestion.
        if not analysis.get("suggestions"):
            analysis["suggestions"] = [
                {
                    "title": "持续练习",
                    "content": "建议继续进行陪练练习,提升整体表现",
                    "example": "每周进行2-3次陪练针对薄弱环节重点练习"
                }
            ]
        return analysis
# ==================== Module-level singleton ====================
practice_analysis_service = PracticeAnalysisService()
# ==================== Convenience functions ====================
async def analyze_practice_session(
    dialogue_history: List[Dict[str, Any]],
    db: Any = None
) -> Dict[str, Any]:
    """
    Convenience wrapper: analyze a practice session via the shared service.

    Args:
        dialogue_history: Dialogue turns (speaker/content dicts).
        db: Optional DB session; forwarded so AI calls are logged (the
            previous version dropped it, silently disabling call logging).

    Returns:
        The analysis result as a dict (legacy payload shape).
    """
    result = await practice_analysis_service.analyze(dialogue_history, db=db)
    return result.to_dict()

View File

@@ -0,0 +1,379 @@
"""
陪练场景准备服务 - Python 原生实现
功能:
- 根据课程ID获取知识点
- 调用 AI 生成陪练场景配置
- 解析并返回结构化场景数据
提供稳定可靠的陪练场景提取能力。
"""
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import ExternalServiceError
from .ai_service import AIService, AIResponse
from .llm_json_parser import parse_with_fallback, clean_llm_output
from .prompts.practice_scene_prompts import (
SYSTEM_PROMPT,
USER_PROMPT,
PRACTICE_SCENE_SCHEMA,
DEFAULT_SCENE_TYPE,
DEFAULT_DIFFICULTY,
)
logger = logging.getLogger(__name__)
# ==================== 数据结构 ====================
@dataclass
class PracticeScene:
    """Structured configuration for one generated practice scenario."""
    name: str  # scenario title
    description: str  # short summary of the scenario
    background: str  # situational background for the trainee
    ai_role: str  # role the AI plays (e.g. the customer)
    objectives: List[str]  # training objectives to achieve
    keywords: List[str]  # key terms the scenario revolves around
    type: str = DEFAULT_SCENE_TYPE  # scenario category (default from prompt config)
    difficulty: str = DEFAULT_DIFFICULTY  # difficulty level (default from prompt config)
@dataclass
class PracticeSceneResult:
    """Outcome of a scene-generation run, with AI call metadata."""
    success: bool
    scene: Optional[PracticeScene] = None  # populated only when success is True
    raw_response: Dict[str, Any] = field(default_factory=dict)  # parsed AI payload (or raw output on parse failure)
    ai_provider: str = ""
    ai_model: str = ""
    ai_tokens: int = 0
    ai_latency_ms: int = 0
    knowledge_points_count: int = 0  # number of knowledge points fed to the model
    error: str = ""  # failure description when success is False
# ==================== 服务类 ====================
class PracticeSceneService:
    """
    Practice-scene preparation service (陪练场景准备服务).

    Loads a course's knowledge points, asks the LLM to generate a practice
    scenario from them, and returns the parsed, structured scene.

    Note: AI calls are made through a *request-scoped* ``AIService`` built
    with the caller's DB session, so each call is logged per the AI
    integration regulation (the previous implementation used a session-less
    client created in ``__init__``, which never logged). ``self.ai_service``
    is kept as a fallback for backward compatibility.

    Example:
        ```python
        service = PracticeSceneService()
        result = await service.prepare_practice_knowledge(db=db_session, course_id=1)
        if result.success:
            print(result.scene.name)
            print(result.scene.objectives)
        ```
    """

    def __init__(self):
        """Create the default (session-less) AI client, kept for backward compatibility."""
        self.ai_service = AIService(module_code="practice_scene")

    async def prepare_practice_knowledge(
        self,
        db: "AsyncSession",
        course_id: int
    ) -> "PracticeSceneResult":
        """
        Prepare the knowledge content for practice and generate a scene.

        Args:
            db: Database session (multi-tenant: the caller passes the
                session for the correct tenant database).
            course_id: Course ID whose knowledge points drive generation.

        Returns:
            PracticeSceneResult -- on any failure, ``success=False`` with
            ``error`` set; this method does not raise.
        """
        try:
            logger.info(f"开始陪练知识准备 - course_id: {course_id}")
            # 1. Load the course's knowledge points.
            knowledge_points = await self._fetch_knowledge_points(db, course_id)
            if not knowledge_points:
                logger.warning(f"课程没有知识点 - course_id: {course_id}")
                return PracticeSceneResult(
                    success=False,
                    error=f"课程 {course_id} 没有可用的知识点"
                )
            logger.info(f"获取到 {len(knowledge_points)} 个知识点 - course_id: {course_id}")
            # 2. Format the knowledge points as prompt text.
            knowledge_text = self._format_knowledge_points(knowledge_points)
            # 3. Call the AI through a request-scoped client so the call is
            #    logged against this request's DB session.
            ai_service = AIService(module_code="practice_scene", db_session=db)
            ai_response = await self._call_ai_generation(knowledge_text, ai_service)
            logger.info(
                f"AI 生成完成 - provider: {ai_response.provider}, "
                f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
            )
            # 4. Parse the JSON scene payload.
            scene_data = self._parse_scene_response(ai_response.content)
            if not scene_data:
                logger.error(f"场景解析失败 - course_id: {course_id}")
                return PracticeSceneResult(
                    success=False,
                    raw_response={"ai_output": ai_response.content},
                    ai_provider=ai_response.provider,
                    ai_model=ai_response.model,
                    ai_tokens=ai_response.total_tokens,
                    ai_latency_ms=ai_response.latency_ms,
                    knowledge_points_count=len(knowledge_points),
                    error="AI 输出解析失败"
                )
            # 5. Build the scene object.
            scene = self._build_scene_object(scene_data)
            logger.info(
                f"陪练场景生成成功 - course_id: {course_id}, "
                f"scene_name: {scene.name}, type: {scene.type}"
            )
            return PracticeSceneResult(
                success=True,
                scene=scene,
                raw_response=scene_data,
                ai_provider=ai_response.provider,
                ai_model=ai_response.model,
                ai_tokens=ai_response.total_tokens,
                ai_latency_ms=ai_response.latency_ms,
                knowledge_points_count=len(knowledge_points)
            )
        except Exception as e:
            logger.error(
                f"陪练知识准备失败 - course_id: {course_id}, error: {e}",
                exc_info=True
            )
            return PracticeSceneResult(
                success=False,
                error=str(e)
            )

    async def _fetch_knowledge_points(
        self,
        db: "AsyncSession",
        course_id: int
    ) -> List[Dict[str, Any]]:
        """
        Load the course's non-deleted knowledge points (joined against
        non-deleted course materials), ordered by knowledge-point ID.

        Raises:
            ExternalServiceError: On any database failure.
        """
        sql = text("""
            SELECT kp.name, kp.description
            FROM knowledge_points kp
            INNER JOIN course_materials cm ON kp.material_id = cm.id
            WHERE kp.course_id = :course_id
              AND kp.is_deleted = 0
              AND cm.is_deleted = 0
            ORDER BY kp.id
        """)
        try:
            result = await db.execute(sql, {"course_id": course_id})
            rows = result.fetchall()
            knowledge_points = []
            for row in rows:
                knowledge_points.append({
                    "name": row[0],
                    "description": row[1] or ""
                })
            return knowledge_points
        except Exception as e:
            logger.error(f"查询知识点失败: {e}")
            raise ExternalServiceError(f"数据库查询失败: {e}")

    def _format_knowledge_points(self, knowledge_points: List[Dict[str, Any]]) -> str:
        """
        Render knowledge points as a numbered list, with descriptions
        indented under their names; entries are separated by blank lines.
        """
        lines = []
        for i, kp in enumerate(knowledge_points, 1):
            name = kp.get("name", "")
            description = kp.get("description", "")
            if description:
                lines.append(f"{i}. {name}\n   {description}")
            else:
                lines.append(f"{i}. {name}")
        return "\n\n".join(lines)

    async def _call_ai_generation(
        self,
        knowledge_text: str,
        ai_service: "AIService" = None
    ) -> "AIResponse":
        """
        Ask the AI to generate a practice scene.

        Args:
            knowledge_text: Formatted knowledge-point text.
            ai_service: Request-scoped client; falls back to the default
                session-less client when omitted (backward compatibility).

        Returns:
            The AI response object.
        """
        client = ai_service if ai_service is not None else self.ai_service
        user_message = USER_PROMPT.format(knowledge_points=knowledge_text)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_message}
        ]
        # Moderate temperature for some creative variety in scenes.
        response = await client.chat(
            messages=messages,
            temperature=0.7,
            prompt_name="practice_scene_generation"
        )
        return response

    def _parse_scene_response(self, ai_output: str) -> Optional[Dict[str, Any]]:
        """
        Parse the AI's scene JSON with multi-layer fallbacks.

        Returns:
            The parsed dict, or None when parsing/validation fails.
        """
        cleaned_output, rules = clean_llm_output(ai_output)
        if rules:
            logger.debug(f"AI 输出已清洗: {rules}")
        result = parse_with_fallback(
            cleaned_output,
            schema=PRACTICE_SCENE_SCHEMA,
            default=None,
            validate_schema=True,
            on_error="none"
        )
        return result

    def _build_scene_object(self, scene_data: Dict[str, Any]) -> "PracticeScene":
        """
        Build a PracticeScene from the parsed payload, accepting either the
        wrapped ``{"scene": {...}}`` shape or a bare scene dict, and filling
        defaults for any missing field.
        """
        scene = scene_data.get("scene", scene_data)
        return PracticeScene(
            name=scene.get("name", "陪练场景"),
            description=scene.get("description", ""),
            background=scene.get("background", ""),
            ai_role=scene.get("ai_role", "AI扮演客户"),
            objectives=scene.get("objectives", []),
            keywords=scene.get("keywords", []),
            type=scene.get("type", DEFAULT_SCENE_TYPE),
            difficulty=scene.get("difficulty", DEFAULT_DIFFICULTY)
        )

    def scene_to_dict(self, scene: "PracticeScene") -> Dict[str, Any]:
        """Serialize a PracticeScene into the API response shape."""
        return {
            "scene": {
                "name": scene.name,
                "description": scene.description,
                "background": scene.background,
                "ai_role": scene.ai_role,
                "objectives": scene.objectives,
                "keywords": scene.keywords,
                "type": scene.type,
                "difficulty": scene.difficulty
            }
        }
# ==================== Module-level singleton ====================
practice_scene_service = PracticeSceneService()
# ==================== Convenience functions ====================
async def prepare_practice_knowledge(
    db: AsyncSession,
    course_id: int
) -> PracticeSceneResult:
    """
    Convenience wrapper around the shared PracticeSceneService singleton.

    Args:
        db: Tenant-scoped database session.
        course_id: Course whose knowledge points drive scene generation.

    Returns:
        PracticeSceneResult from the singleton service.
    """
    return await practice_scene_service.prepare_practice_knowledge(db, course_id)

View File

@@ -0,0 +1,57 @@
"""
提示词模板模块
遵循瑞小美提示词规范
"""
from .knowledge_analysis_prompts import (
PROMPT_META as KNOWLEDGE_ANALYSIS_PROMPT_META,
SYSTEM_PROMPT as KNOWLEDGE_ANALYSIS_SYSTEM_PROMPT,
USER_PROMPT as KNOWLEDGE_ANALYSIS_USER_PROMPT,
KNOWLEDGE_POINT_SCHEMA,
)
from .exam_generator_prompts import (
PROMPT_META as EXAM_GENERATOR_PROMPT_META,
SYSTEM_PROMPT as EXAM_GENERATOR_SYSTEM_PROMPT,
USER_PROMPT as EXAM_GENERATOR_USER_PROMPT,
MISTAKE_REGEN_SYSTEM_PROMPT,
MISTAKE_REGEN_USER_PROMPT,
QUESTION_SCHEMA,
QUESTION_TYPES,
DEFAULT_QUESTION_COUNTS,
DEFAULT_DIFFICULTY_LEVEL,
)
from .ability_analysis_prompts import (
PROMPT_META as ABILITY_ANALYSIS_PROMPT_META,
SYSTEM_PROMPT as ABILITY_ANALYSIS_SYSTEM_PROMPT,
USER_PROMPT as ABILITY_ANALYSIS_USER_PROMPT,
ABILITY_ANALYSIS_SCHEMA,
ABILITY_DIMENSIONS,
)
__all__ = [
# Knowledge Analysis Prompts
"KNOWLEDGE_ANALYSIS_PROMPT_META",
"KNOWLEDGE_ANALYSIS_SYSTEM_PROMPT",
"KNOWLEDGE_ANALYSIS_USER_PROMPT",
"KNOWLEDGE_POINT_SCHEMA",
# Exam Generator Prompts
"EXAM_GENERATOR_PROMPT_META",
"EXAM_GENERATOR_SYSTEM_PROMPT",
"EXAM_GENERATOR_USER_PROMPT",
"MISTAKE_REGEN_SYSTEM_PROMPT",
"MISTAKE_REGEN_USER_PROMPT",
"QUESTION_SCHEMA",
"QUESTION_TYPES",
"DEFAULT_QUESTION_COUNTS",
"DEFAULT_DIFFICULTY_LEVEL",
# Ability Analysis Prompts
"ABILITY_ANALYSIS_PROMPT_META",
"ABILITY_ANALYSIS_SYSTEM_PROMPT",
"ABILITY_ANALYSIS_USER_PROMPT",
"ABILITY_ANALYSIS_SCHEMA",
"ABILITY_DIMENSIONS",
]

View File

@@ -0,0 +1,215 @@
"""
智能工牌能力分析与课程推荐提示词模板
功能:分析员工与顾客的对话记录,评估能力维度得分,并推荐适合的课程
"""
# ==================== Metadata ====================
# Registry descriptor for this prompt template (consumed by prompt tooling).
PROMPT_META = {
    "name": "ability_analysis",
    "display_name": "智能工牌能力分析",
    "description": "分析员工与顾客对话,评估多维度能力得分,推荐个性化课程",
    "module": "kaopeilian",
    "variables": ["dialogue_history", "user_info", "courses"],  # placeholders in USER_PROMPT
    "version": "1.0.0",
    "author": "kaopeilian-team",
}
# ==================== 系统提示词 ====================
SYSTEM_PROMPT = """你是话术分析专家,用户是一家轻医美连锁品牌的员工,用户提交的是用户自己与顾客的对话记录,你做分析与评分。并严格按照以下格式输出。并根据课程列表,为该用户提供选课建议。
输出标准:
{
"analysis": {
"total_score": 82,
"ability_dimensions": [
{
"name": "专业知识",
"score": 88,
"feedback": "产品知识扎实,能准确回答客户问题。建议:继续深化对新产品的了解。"
},
{
"name": "沟通技巧",
"score": 92,
"feedback": "语言表达清晰流畅,善于倾听客户需求。建议:可以多使用开放式问题引导。"
},
{
"name": "操作技能",
"score": 85,
"feedback": "基本操作熟练,流程规范。建议:提升复杂场景的应对速度。"
},
{
"name": "客户服务",
"score": 90,
"feedback": "服务态度优秀,客户体验良好。建议:进一步提升个性化服务能力。"
},
{
"name": "安全意识",
"score": 79,
"feedback": "基本安全规范掌握,但在细节提醒上还可加强。"
},
{
"name": "应变能力",
"score": 76,
"feedback": "面对突发情况反应较快,但处理方式可以更灵活多样。"
}
],
"course_recommendations": [
{
"course_id": 5,
"course_name": "应变能力提升训练营",
"recommendation_reason": "该课程专注于提升应变能力包含大量实战案例分析和模拟演练针对您当前的薄弱环节应变能力76分设计。通过学习可提升15分左右。",
"priority": "high",
"match_score": 95
},
{
"course_id": 3,
"course_name": "安全规范与操作标准",
"recommendation_reason": "系统讲解安全规范和操作标准通过案例教学帮助建立安全意识。当前您的安全意识得分为79分通过本课程学习预计可提升12分。",
"priority": "high",
"match_score": 88
},
{
"course_id": 7,
"course_name": "高级销售技巧",
"recommendation_reason": "进阶课程帮助您将已有的沟通优势92分转化为更高级的销售技能进一步巩固客户服务能力90分",
"priority": "medium",
"match_score": 82
}
]
}
}
## 输出要求(严格执行)
1. 直接输出纯净的 JSON不要包含 Markdown 标记(如 ```json
2. 不要包含任何解释性文字
3. 能力维度必须包含:专业知识、沟通技巧、操作技能、客户服务、安全意识、应变能力
4. 课程推荐必须来自提供的课程列表,使用真实的 course_id
5. 推荐课程数量1-5个优先推荐能补齐短板的课程
6. priority 取值high得分<80的薄弱项、medium得分80-85、low锦上添花
## 评分标准
- 90-100优秀
- 80-89良好
- 70-79一般
- 60-69需改进
- <60亟需提升"""
# ==================== User prompt template ====================
# Variables: dialogue_history, user_info, courses (see PROMPT_META).
USER_PROMPT = """对话记录:{dialogue_history}
---
用户的信息和岗位:{user_info}
---
所有可选课程:{courses}"""
# ==================== JSON Schema ====================
ABILITY_ANALYSIS_SCHEMA = {
"type": "object",
"required": ["analysis"],
"properties": {
"analysis": {
"type": "object",
"required": ["total_score", "ability_dimensions", "course_recommendations"],
"properties": {
"total_score": {
"type": "number",
"description": "总体评分0-100",
"minimum": 0,
"maximum": 100
},
"ability_dimensions": {
"type": "array",
"description": "能力维度评分列表",
"items": {
"type": "object",
"required": ["name", "score", "feedback"],
"properties": {
"name": {
"type": "string",
"description": "能力维度名称"
},
"score": {
"type": "number",
"description": "该维度得分0-100",
"minimum": 0,
"maximum": 100
},
"feedback": {
"type": "string",
"description": "该维度的反馈和建议"
}
}
},
"minItems": 1
},
"course_recommendations": {
"type": "array",
"description": "课程推荐列表",
"items": {
"type": "object",
"required": ["course_id", "course_name", "recommendation_reason", "priority", "match_score"],
"properties": {
"course_id": {
"type": "integer",
"description": "课程ID"
},
"course_name": {
"type": "string",
"description": "课程名称"
},
"recommendation_reason": {
"type": "string",
"description": "推荐理由"
},
"priority": {
"type": "string",
"description": "推荐优先级",
"enum": ["high", "medium", "low"]
},
"match_score": {
"type": "number",
"description": "匹配度得分0-100",
"minimum": 0,
"maximum": 100
}
}
}
}
}
}
}
}
# ==================== Ability dimension constants ====================
# Canonical dimension names; must match the names SYSTEM_PROMPT requires
# in the model's output.
ABILITY_DIMENSIONS = [
    "专业知识",
    "沟通技巧",
    "操作技能",
    "客户服务",
    "安全意识",
    "应变能力",
]
# Allowed values for a course recommendation's "priority" field.
PRIORITY_LEVELS = ["high", "medium", "low"]

View File

@@ -0,0 +1,48 @@
"""
答案判断器提示词模板
功能:判断填空题与问答题是否回答正确
"""
# ==================== 元数据 ====================
PROMPT_META = {
"name": "answer_judge",
"display_name": "答案判断器",
"description": "判断填空题与问答题的答案是否正确",
"module": "kaopeilian",
"variables": ["question", "correct_answer", "user_answer", "analysis"],
"version": "1.0.0",
"author": "kaopeilian-team",
}
# ==================== 系统提示词 ====================
SYSTEM_PROMPT = """你是一个答案判断器,根据用户提交的答案,比对题目、答案、解析。给出正确或错误的判断。
注意:仅输出"正确""错误",无需更多字符和说明。"""
# ==================== 用户提示词模板 ====================
USER_PROMPT = """题目:{question}
正确答案:{correct_answer}
解析:{analysis}
考生的回答:{user_answer}"""
# ==================== 判断结果常量 ====================
CORRECT_KEYWORDS = ["正确", "correct", "true", "yes", "", ""]
INCORRECT_KEYWORDS = ["错误", "incorrect", "false", "no", "wrong", "不正确", ""]

View File

@@ -0,0 +1,74 @@
"""
课程对话提示词模板
功能:基于课程知识点进行智能问答
"""
# ==================== 元数据 ====================
PROMPT_META = {
"name": "course_chat",
"display_name": "与课程对话",
"description": "基于课程知识点内容,为用户提供智能问答服务",
"module": "kaopeilian",
"variables": ["knowledge_base", "query"],
"version": "2.0.0",
"author": "kaopeilian-team",
}
# ==================== 系统提示词 ====================
SYSTEM_PROMPT = """你是知识拆解专家,精通以下知识库(课程)内容。请根据用户的问题,从知识库中找到最相关的信息,进行深入分析后,用简洁清晰的语言回答用户。为用户提供与课程对话的服务。
回答要求:
1. 直接针对问题核心,避免冗长铺垫
2. 使用通俗易懂的语言,必要时举例说明
3. 突出关键要点,帮助用户快速理解
4. 如果知识库中没有相关内容,请如实告知
知识库:
{knowledge_base}"""
# ==================== 用户提示词模板 ====================
USER_PROMPT = """{query}"""
# ==================== 知识库格式模板 ====================
KNOWLEDGE_ITEM_TEMPLATE = """{name}
{description}
"""
# ==================== 配置常量 ====================
# 会话历史窗口大小(保留最近 N 轮对话)
CONVERSATION_WINDOW_SIZE = 10
# 会话 TTL- 30 分钟
CONVERSATION_TTL = 1800
# 最大知识点数量
MAX_KNOWLEDGE_POINTS = 50
# 知识库最大字符数
MAX_KNOWLEDGE_BASE_LENGTH = 50000
# 默认模型
DEFAULT_CHAT_MODEL = "gemini-3-flash-preview"
# 温度参数(对话场景使用较高温度)
DEFAULT_TEMPERATURE = 0.7

View File

@@ -0,0 +1,300 @@
"""
试题生成器提示词模板
功能:根据岗位和知识点动态生成考试题目
"""
# ==================== 元数据 ====================
PROMPT_META = {
"name": "exam_generator",
"display_name": "试题生成器",
"description": "根据课程知识点和岗位特征,动态生成考试题目(单选、多选、判断、填空、问答)",
"module": "kaopeilian",
"variables": [
"total_count",
"single_choice_count",
"multiple_choice_count",
"true_false_count",
"fill_blank_count",
"essay_count",
"difficulty_level",
"position_info",
"knowledge_points",
],
"version": "2.0.0",
"author": "kaopeilian-team",
}
# ==================== 系统提示词(第一轮出题) ====================
SYSTEM_PROMPT = """## 角色
你是一位经验丰富的考试出题专家,能够依据用户提供的知识内容,结合用户的岗位特征,随机地生成{total_count}题考题。你会以专业、严谨且清晰的方式出题。
## 输出{single_choice_count}道单选题
1、每道题目只能有 1 个正确答案。
2、干扰项要具有合理性和迷惑性且所有选项必须与主题相关。
3、答案解析要简明扼要说明选择理由。
4、为每道题记录出题来源的知识点 id。
5、请以 JSON 格式输出。
6、为每道题输出一个序号。
### 输出结构:
{{
"num": "题号",
"type": "single_choice",
"topic": {{
"title": "清晰完整的题目描述",
"options": {{
"opt1": "A符合语境的选项",
"opt2": "B符合语境的选项",
"opt3": "C符合语境的选项",
"opt4": "D符合语境的选项"
}}
}},
"knowledge_point_id": "出题来源知识点的id",
"correct": "其中一个选项的全部原文",
"analysis": "准确的答案解析,包含选择原因和知识点说明"
}}
- 严格按照以上格式输出
## 输出{multiple_choice_count}道多选题
1、每道题目有多个正确答案。
2、"type": "multiple_choice"
3、其它事项同单选题。
## 输出{true_false_count}道判断题
1、每道题目只有 "正确""错误" 两种答案。
2、题目表述应明确清晰避免歧义。
3、题目应直接陈述事实或观点便于做出是非判断。
4、其它事项同单选题。
### 输出结构:
{{
"num": "题号",
"type": "true_false",
"topic": {{
"title": "清晰完整的题目描述"
}},
"knowledge_point_id": " 出题来源知识点的id",
"correct": "正确",
"analysis": "准确的答案解析,包含判断原因和知识点说明"
}}
- 严格按照以上格式输出
## 输出{fill_blank_count}道填空题
1. 题干应明确完整,空缺处需用横线"___"标示,且只能有一处空缺
2. 答案应唯一且明确,避免开放性表述
3. 空缺长度应与答案长度大致匹配
4. 解析需说明答案依据及相关知识点
5. 其余要求与单选题一致
### 输出结构:
{{
"num": "题号",
"type": "fill_blank",
"topic": {{
"title": "包含___空缺的题目描述"
}},
"knowledge_point_id": "出题来源知识点的id",
"correct": "准确的填空答案",
"analysis": "解析答案的依据和相关知识点说明"
}}
- 严格按照以上格式输出
### 输出{essay_count}道问答题
1. 问题应具体明确,限定回答范围
2. 答案需条理清晰,突出核心要点
3. 解析可补充扩展说明或评分要点
4. 避免过于宽泛或需要主观发挥的问题
5. 其余要求同单选题
### 输出结构:
{{
"num": "题号",
"type": "essay",
"topic": {{
"title": "需要详细回答的问题描述"
}},
"knowledge_point_id": "出题来源知识点的id",
"correct": "完整准确的参考答案(分点或连贯表述)",
"analysis": "对答案的补充说明、评分要点或相关知识点扩展"
}}
## 特殊要求
1. 题目难度:{difficulty_level}5 级为最难)
2. 避免使用模棱两可的表述
3. 选项内容要互斥,不能有重叠
4. 每个选项长度尽量均衡
5. 正确答案A、B、C、D分布要合理避免规律性
6. 正确答案必须使用其中一个选项中的全部原文,严禁修改
7. knowledge_point_id 必须是唯一的,即每道题的知识点来源只允许填一个 id。
## 输出格式要求
请直接输出一个纯净的 JSON 数组Array不要包含 Markdown 标记(如 ```json也不要包含任何解释性文字。
请按以上要求生成题目,确保每道题目质量。"""
# ==================== 用户提示词模板(第一轮出题) ====================
USER_PROMPT = """# 请针对岗位特征、待出题的知识点内容进行出题。
## 岗位信息:
{position_info}
---
## 知识点:
{knowledge_points}"""
# ==================== 错题重出系统提示词 ====================
MISTAKE_REGEN_SYSTEM_PROMPT = """## 角色
你是一位经验丰富的考试出题专家,能够依据用户提供的错题记录,重新为用户出题。你会为每道错题重新出一题,你会以专业、严谨且清晰的方式出题。
## 输出单选题
1、每道题目只能有 1 个正确答案。
2、干扰项要具有合理性和迷惑性且所有选项必须与主题相关。
3、答案解析要简明扼要说明选择理由。
4、为每道题记录出题来源的知识点 id。
5、请以 JSON 格式输出。
6、为每道题输出一个序号。
### 输出结构:
{{
"num": "题号",
"type": "single_choice",
"topic": {{
"title": "清晰完整的题目描述",
"options": {{
"opt1": "A符合语境的选项",
"opt2": "B符合语境的选项",
"opt3": "C符合语境的选项",
"opt4": "D符合语境的选项"
}}
}},
"knowledge_point_id": "出题来源知识点的id",
"correct": "其中一个选项的全部原文",
"analysis": "准确的答案解析,包含选择原因和知识点说明"
}}
- 严格按照以上格式输出
## 特殊要求
1. 题目难度:{difficulty_level}5 级为最难)
2. 避免使用模棱两可的表述
3. 选项内容要互斥,不能有重叠
4. 每个选项长度尽量均衡
5. 正确答案A、B、C、D分布要合理避免规律性
6. 正确答案必须使用其中一个选项中的全部原文,严禁修改
7. knowledge_point_id 必须是唯一的,即每道题的知识点来源只允许填一个 id。
## 输出格式要求
请直接输出一个纯净的 JSON 数组Array不要包含 Markdown 标记(如 ```json也不要包含任何解释性文字。
请按以上要求生成题目,确保每道题目质量。"""
# ==================== 错题重出用户提示词 ====================
MISTAKE_REGEN_USER_PROMPT = """## 错题记录:
{mistake_records}"""
# ==================== JSON Schema ====================
QUESTION_SCHEMA = {
"type": "array",
"items": {
"type": "object",
"required": ["num", "type", "topic", "correct"],
"properties": {
"num": {
"oneOf": [
{"type": "integer"},
{"type": "string"}
],
"description": "题号"
},
"type": {
"type": "string",
"enum": ["single_choice", "multiple_choice", "true_false", "fill_blank", "essay"],
"description": "题目类型"
},
"topic": {
"type": "object",
"required": ["title"],
"properties": {
"title": {
"type": "string",
"description": "题目标题"
},
"options": {
"type": "object",
"description": "选项(选择题必填)"
}
}
},
"knowledge_point_id": {
"oneOf": [
{"type": "integer"},
{"type": "string"},
{"type": "null"}
],
"description": "知识点ID"
},
"correct": {
"type": "string",
"description": "正确答案"
},
"analysis": {
"type": "string",
"description": "答案解析"
}
}
},
"minItems": 1,
"maxItems": 50
}
# ==================== 题目类型常量 ====================
QUESTION_TYPES = {
"single_choice": "单选题",
"multiple_choice": "多选题",
"true_false": "判断题",
"fill_blank": "填空题",
"essay": "问答题",
}
# 默认题目数量配置
DEFAULT_QUESTION_COUNTS = {
"single_choice_count": 4,
"multiple_choice_count": 2,
"true_false_count": 1,
"fill_blank_count": 2,
"essay_count": 1,
}
DEFAULT_DIFFICULTY_LEVEL = 3
MAX_DIFFICULTY_LEVEL = 5

View File

@@ -0,0 +1,148 @@
"""
知识点分析提示词模板
功能:从课程资料中提取知识点
"""
# ==================== 元数据 ====================
PROMPT_META = {
"name": "knowledge_analysis",
"display_name": "知识点分析",
"description": "从课程资料中提取和分析知识点支持PDF/Word/文本等格式",
"module": "kaopeilian",
"variables": ["course_name", "content"],
"version": "2.0.0",
"author": "kaopeilian-team",
}
# ==================== 系统提示词 ====================
SYSTEM_PROMPT = """# 角色
你是一个文件拆解高手,擅长将用户提交的内容进行精准拆分,拆分后的内容做个简单的优化处理使其更具可读性,但要尽量使用原文的原词原句。
## 技能
### 技能 1: 内容拆分
1. 当用户提交内容后,拆分为多段。
2. 对拆分后的内容做简单优化,使其更具可读性,比如去掉奇怪符号(如换行符、乱码),若语句不通顺,或格式原因导致错位,则重新表达。用户可能会提交录音转文字的内容,因此可能是有错字的,注意修复这些小瑕疵。
3. 优化过程中,尽量使用原文的原词原句,特别是话术类,必须保持原有的句式、保持原词原句,而不是重构。
4. 注意是拆分而不是重写,不需要润色,尽量不做任何处理。
5. 输出到 content。
### 技能 2: 为每一个选段概括一个标题
1. 为每个拆分出来的选段概括一个标题,并输出到 title。
### 技能 3: 为每一个选段说明与主题的关联
1. 详细说明这一段与全文核心主题的关联,并输出到 topic_relation。
### 技能 4: 为每一个选段打上一个类型标签
1. 用户提交的内容很有可能是一个课程、一篇讲义、一个产品的说明书,通常是用户希望他公司的员工或高管学习的知识。
2. 用户通常是医疗美容机构或轻医美、生活美容连锁品牌。
3. 你要为每个选段打上一个知识类型的标签,最好是这几个类型中的一个:"理论知识", "诊断设计", "操作步骤", "沟通话术", "案例分析", "注意事项", "技巧方法", "客诉处理"。当然你也可以为这个选段匹配一个更适合的。
## 输出要求(严格按要求输出)
请直接输出一个纯净的 JSON 数组Array不要包含 Markdown 标记(如 ```json也不要包含任何解释性文字。格式如下
[
{
"title": "知识点标题",
"content": "知识点内容",
"topic_relation": "知识点与主题的关系",
"type": "知识点类型"
},
{
"title": "第二个知识点标题",
"content": "第二个知识点内容...",
"topic_relation": "...",
"type": "..."
}
]
## 限制
- 仅围绕用户提交的内容进行拆分和关联标注,不涉及其他无关内容。
- 拆分后的内容必须最大程度保持与原文一致。
- 关联说明需清晰合理。
- 不论如何,不要拆分超过 20 段!"""
# ==================== 用户提示词模板 ====================
USER_PROMPT = """课程主题:{course_name}
## 用户提交的内容:
{content}
## 注意
- 以json的格式输出
- 不论如何不要拆分超过20 段!"""
# ==================== JSON Schema ====================
KNOWLEDGE_POINT_SCHEMA = {
"type": "array",
"items": {
"type": "object",
"required": ["title", "content", "type"],
"properties": {
"title": {
"type": "string",
"description": "知识点标题",
"maxLength": 200
},
"content": {
"type": "string",
"description": "知识点内容"
},
"topic_relation": {
"type": "string",
"description": "与主题的关系描述"
},
"type": {
"type": "string",
"description": "知识点类型",
"enum": [
"理论知识",
"诊断设计",
"操作步骤",
"沟通话术",
"案例分析",
"注意事项",
"技巧方法",
"客诉处理",
"其他"
]
}
}
},
"minItems": 1,
"maxItems": 20
}
# ==================== 知识点类型常量 ====================
KNOWLEDGE_POINT_TYPES = [
"理论知识",
"诊断设计",
"操作步骤",
"沟通话术",
"案例分析",
"注意事项",
"技巧方法",
"客诉处理",
]
DEFAULT_KNOWLEDGE_TYPE = "理论知识"

View File

@@ -0,0 +1,193 @@
"""
陪练分析报告提示词模板
功能:分析陪练对话,生成综合评分和改进建议
"""
# ==================== 元数据 ====================
PROMPT_META = {
"name": "practice_analysis",
"display_name": "陪练分析报告",
"description": "分析陪练对话,生成综合评分、能力维度评估、对话标注和改进建议",
"module": "kaopeilian",
"variables": ["dialogue_history"],
"version": "1.0.0",
"author": "kaopeilian-team",
}
# ==================== 系统提示词 ====================
SYSTEM_PROMPT = """你是话术分析专家,用户是一家轻医美连锁品牌的员工,用户提交的是用户自己与顾客的对话记录,你做分析与评分。并严格按照以下格式输出。
输出标准:
{
"analysis": {
"total_score": 88,
"score_breakdown": [
{"name": "开场技巧", "score": 92, "description": "开场自然,快速建立信任"},
{"name": "需求挖掘", "score": 90, "description": "能够有效识别客户需求"},
{"name": "产品介绍", "score": 88, "description": "产品介绍清晰,重点突出"},
{"name": "异议处理", "score": 85, "description": "处理客户异议还需加强"},
{"name": "成交技巧", "score": 86, "description": "成交话术运用良好"}
],
"ability_dimensions": [
{"name": "沟通表达", "score": 90, "feedback": "语言流畅,表达清晰,语调富有亲和力"},
{"name": "倾听理解", "score": 92, "feedback": "能够准确理解客户意图,给予恰当回应"},
{"name": "情绪控制", "score": 88, "feedback": "整体情绪稳定,面对异议时保持专业"},
{"name": "专业知识", "score": 93, "feedback": "对医美项目知识掌握扎实"},
{"name": "销售技巧", "score": 87, "feedback": "销售流程把控良好"},
{"name": "应变能力", "score": 85, "feedback": "面对突发问题能够快速反应"}
],
"dialogue_annotations": [
{"sequence": 1, "tags": ["亮点话术"], "comment": "开场专业,身份介绍清晰"},
{"sequence": 3, "tags": ["金牌话术"], "comment": "巧妙引导,从客户角度出发"},
{"sequence": 5, "tags": ["亮点话术"], "comment": "类比生动,让客户容易理解"},
{"sequence": 7, "tags": ["金牌话术"], "comment": "专业解答,打消客户疑虑"}
],
"suggestions": [
{"title": "控制语速", "content": "您的语速偏快,建议适当放慢,给客户更多思考时间", "example": "说完产品优势后停顿2-3秒观察客户反应"},
{"title": "多用开放式问题", "content": "增加开放式问题的使用,更深入了解客户需求", "example": "您对未来的保障有什么期望?而不是您需要保险吗?"},
{"title": "强化成交信号识别", "content": "客户已经表现出兴趣时,要及时推进成交", "example": "当客户问费用多少时,这是购买信号,应该立即报价并促成"}
]
}
}
## 输出要求(严格执行)
1. 直接输出纯净的 JSON不要包含 Markdown 标记(如 ```json
2. 不要包含任何解释性文字
3. score_breakdown 必须包含 5 项:开场技巧、需求挖掘、产品介绍、异议处理、成交技巧
4. ability_dimensions 必须包含 6 项:沟通表达、倾听理解、情绪控制、专业知识、销售技巧、应变能力
5. dialogue_annotations 标注有亮点或问题的对话轮次tags 可选:亮点话术、金牌话术、待改进、问题话术
6. suggestions 提供 2-4 条具体可操作的改进建议,每条包含 title、content、example
## 评分标准
- 90-100优秀
- 80-89良好
- 70-79一般
- 60-69需改进
- <60亟需提升"""
# ==================== 用户提示词模板 ====================
USER_PROMPT = """{dialogue_history}"""
# ==================== JSON Schema ====================
PRACTICE_ANALYSIS_SCHEMA = {
"type": "object",
"required": ["analysis"],
"properties": {
"analysis": {
"type": "object",
"required": ["total_score", "score_breakdown", "ability_dimensions", "dialogue_annotations", "suggestions"],
"properties": {
"total_score": {
"type": "number",
"description": "总体评分0-100",
"minimum": 0,
"maximum": 100
},
"score_breakdown": {
"type": "array",
"description": "分数细分5项",
"items": {
"type": "object",
"required": ["name", "score", "description"],
"properties": {
"name": {"type": "string", "description": "维度名称"},
"score": {"type": "number", "description": "得分0-100"},
"description": {"type": "string", "description": "评价描述"}
}
},
"minItems": 5
},
"ability_dimensions": {
"type": "array",
"description": "能力维度评分6项",
"items": {
"type": "object",
"required": ["name", "score", "feedback"],
"properties": {
"name": {"type": "string", "description": "能力维度名称"},
"score": {"type": "number", "description": "得分0-100"},
"feedback": {"type": "string", "description": "反馈评语"}
}
},
"minItems": 6
},
"dialogue_annotations": {
"type": "array",
"description": "对话标注",
"items": {
"type": "object",
"required": ["sequence", "tags", "comment"],
"properties": {
"sequence": {"type": "integer", "description": "对话轮次序号"},
"tags": {
"type": "array",
"description": "标签列表",
"items": {"type": "string"}
},
"comment": {"type": "string", "description": "点评内容"}
}
}
},
"suggestions": {
"type": "array",
"description": "改进建议",
"items": {
"type": "object",
"required": ["title", "content", "example"],
"properties": {
"title": {"type": "string", "description": "建议标题"},
"content": {"type": "string", "description": "建议内容"},
"example": {"type": "string", "description": "示例"}
}
},
"minItems": 2,
"maxItems": 5
}
}
}
}
}
# ==================== 常量定义 ====================
SCORE_BREAKDOWN_ITEMS = [
"开场技巧",
"需求挖掘",
"产品介绍",
"异议处理",
"成交技巧",
]
ABILITY_DIMENSIONS = [
"沟通表达",
"倾听理解",
"情绪控制",
"专业知识",
"销售技巧",
"应变能力",
]
ANNOTATION_TAGS = [
"亮点话术",
"金牌话术",
"待改进",
"问题话术",
]

View File

@@ -0,0 +1,192 @@
"""
陪练场景生成提示词模板
功能:根据课程知识点生成陪练场景配置
"""
# ==================== 元数据 ====================
PROMPT_META = {
"name": "practice_scene_generation",
"display_name": "陪练场景生成",
"description": "根据课程知识点生成 AI 陪练场景配置包含场景名称、背景、AI 角色、练习目标等",
"module": "kaopeilian",
"variables": ["knowledge_points"],
"version": "1.0.0",
"author": "kaopeilian-team",
}
# ==================== 系统提示词 ====================
SYSTEM_PROMPT = """你是一个训练场景研究专家,能将用户提交的知识点,转变为一个模拟陪练的场景,并严格按照以下格式输出。
输出标准:
{
"scene": {
"name": "轻医美产品咨询陪练",
"description": "模拟客户咨询轻医美产品的场景",
"background": "客户对脸部抗衰项目感兴趣。",
"ai_role": "AI扮演一位30岁女性客户",
"objectives": ["了解客户需求", "介绍产品优势", "处理价格异议"],
"keywords": ["抗衰", "玻尿酸", "价格"],
"type": "product-intro",
"difficulty": "intermediate"
}
}
## 字段说明
- **name**: 场景名称,简洁明了,体现陪练主题
- **description**: 场景描述,说明这是什么样的模拟场景
- **background**: 场景背景设定,描述客户的情况和需求
- **ai_role**: AI 角色描述,说明 AI 扮演什么角色(通常是客户)
- **objectives**: 练习目标数组,列出学员需要达成的目标
- **keywords**: 关键词数组,从知识点中提取的核心关键词
- **type**: 场景类型,可选值:
- phone: 电话销售
- face: 面对面销售
- complaint: 客户投诉
- after-sales: 售后服务
- product-intro: 产品介绍
- **difficulty**: 难度等级,可选值:
- beginner: 入门
- junior: 初级
- intermediate: 中级
- senior: 高级
- expert: 专家
## 输出要求
1. 直接输出纯净的 JSON 对象,不要包含 Markdown 标记(如 ```json
2. 不要包含任何解释性文字
3. 根据知识点内容合理设计场景,确保场景与知识点紧密相关
4. objectives 至少包含 2-3 个具体可操作的目标
5. keywords 提取 3-5 个核心关键词
6. 根据知识点的复杂程度选择合适的 difficulty
7. 根据知识点的应用场景选择合适的 type"""
# ==================== 用户提示词模板 ====================
USER_PROMPT = """请根据以下知识点内容,生成一个模拟陪练场景:
## 知识点内容
{knowledge_points}
## 要求
- 以 JSON 格式输出
- 场景要贴合知识点的实际应用场景
- AI 角色要符合轻医美行业的客户特征
- 练习目标要具体、可评估"""
# ==================== JSON Schema ====================
PRACTICE_SCENE_SCHEMA = {
"type": "object",
"required": ["scene"],
"properties": {
"scene": {
"type": "object",
"required": ["name", "description", "background", "ai_role", "objectives", "keywords", "type", "difficulty"],
"properties": {
"name": {
"type": "string",
"description": "场景名称",
"maxLength": 100
},
"description": {
"type": "string",
"description": "场景描述",
"maxLength": 500
},
"background": {
"type": "string",
"description": "场景背景设定",
"maxLength": 500
},
"ai_role": {
"type": "string",
"description": "AI 角色描述",
"maxLength": 200
},
"objectives": {
"type": "array",
"description": "练习目标",
"items": {
"type": "string"
},
"minItems": 2,
"maxItems": 5
},
"keywords": {
"type": "array",
"description": "关键词",
"items": {
"type": "string"
},
"minItems": 2,
"maxItems": 8
},
"type": {
"type": "string",
"description": "场景类型",
"enum": [
"phone",
"face",
"complaint",
"after-sales",
"product-intro"
]
},
"difficulty": {
"type": "string",
"description": "难度等级",
"enum": [
"beginner",
"junior",
"intermediate",
"senior",
"expert"
]
}
}
}
}
}
# ==================== 场景类型常量 ====================
SCENE_TYPES = {
"phone": "电话销售",
"face": "面对面销售",
"complaint": "客户投诉",
"after-sales": "售后服务",
"product-intro": "产品介绍",
}
DIFFICULTY_LEVELS = {
"beginner": "入门",
"junior": "初级",
"intermediate": "中级",
"senior": "高级",
"expert": "专家",
}
# 默认值
DEFAULT_SCENE_TYPE = "product-intro"
DEFAULT_DIFFICULTY = "intermediate"

View File

@@ -0,0 +1,141 @@
"""
认证服务
"""
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import UnauthorizedError
from app.core.logger import logger
from app.core.security import create_access_token, create_refresh_token, decode_token
from app.models.user import User
from app.schemas.auth import Token
from app.services.user_service import UserService
class AuthService:
    """Authentication service: login, token refresh and logout."""

    def __init__(self, db: AsyncSession):
        self.db = db
        self.user_service = UserService(db)

    async def login(self, username: str, password: str) -> tuple[User, Token]:
        """
        Log a user in.

        Args:
            username: username / e-mail / phone number
            password: plain-text password

        Returns:
            The authenticated user and an (access, refresh) token pair.

        Raises:
            UnauthorizedError: wrong credentials or disabled account.
        """
        # Verify the credentials.
        user = await self.user_service.authenticate(
            username=username, password=password
        )
        if not user:
            logger.warning(
                "登录失败:用户名或密码错误",
                username=username,
            )
            raise UnauthorizedError("用户名或密码错误")
        if not user.is_active:
            logger.warning(
                "登录失败:用户已被禁用",
                user_id=user.id,
                username=user.username,
            )
            raise UnauthorizedError("用户已被禁用")
        # Issue the token pair.
        access_token = create_access_token(subject=user.id)
        refresh_token = create_refresh_token(subject=user.id)
        # Record the login time.
        await self.user_service.update_last_login(user.id)
        logger.info(
            "用户登录成功",
            user_id=user.id,
            username=user.username,
            role=user.role,
        )
        return user, Token(
            access_token=access_token,
            refresh_token=refresh_token,
        )

    async def refresh_token(self, refresh_token: str) -> Token:
        """
        Exchange a refresh token for a new access token.

        The original refresh token is returned unchanged so the client
        keeps a single long-lived credential.

        Args:
            refresh_token: the refresh token issued at login

        Returns:
            A new token pair (fresh access token, same refresh token).

        Raises:
            UnauthorizedError: token invalid/expired, wrong token type,
                user missing, or user disabled.
        """
        try:
            # Decode and validate the refresh token.
            payload = decode_token(refresh_token)
            if payload.get("type") != "refresh":
                raise UnauthorizedError("无效的刷新令牌")
            user_id = int(payload.get("sub"))
            user = await self.user_service.get_by_id(user_id)
            if not user:
                raise UnauthorizedError("用户不存在")
            if not user.is_active:
                raise UnauthorizedError("用户已被禁用")
        except UnauthorizedError:
            # Fix: previously these specific auth errors were swallowed by a
            # blanket `except Exception` and replaced with a generic message;
            # propagate them unchanged so the caller sees the real reason.
            raise
        except Exception as e:
            # Decode failures, malformed payloads ("sub" missing), etc.
            logger.error(
                "令牌刷新失败",
                error=str(e),
            )
            # Chain the original cause for debuggability.
            raise UnauthorizedError("无效的刷新令牌") from e
        # Issue a new access token; keep the original refresh token.
        access_token = create_access_token(subject=user.id)
        logger.info(
            "令牌刷新成功",
            user_id=user.id,
            username=user.username,
        )
        return Token(
            access_token=access_token,
            refresh_token=refresh_token,  # keep the original refresh token
        )

    async def logout(self, user_id: int) -> None:
        """
        Log a user out.

        JWTs are stateless, so the client simply discards its tokens; this
        only writes an audit log. Add a token blacklist here if required.

        Args:
            user_id: id of the user logging out
        """
        user = await self.user_service.get_by_id(user_id)
        if user:
            logger.info(
                "用户登出",
                user_id=user.id,
                username=user.username,
            )

View File

@@ -0,0 +1,112 @@
"""基础服务类"""
from typing import TypeVar, Generic, Type, Optional, List, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from sqlalchemy.sql import Select
from app.models.base import BaseModel
ModelType = TypeVar("ModelType", bound=BaseModel)
class BaseService(Generic[ModelType]):
    """
    Generic async CRUD service base class.

    Subclasses bind a SQLAlchemy model via ``__init__`` and inherit
    get / list / count / create / update / delete helpers.
    """

    def __init__(self, model: Type[ModelType]):
        self.model = model

    async def get(self, db: AsyncSession, id: int) -> Optional[ModelType]:
        """Fetch a single object by primary key, or None if absent."""
        result = await db.execute(select(self.model).where(self.model.id == id))
        return result.scalar_one_or_none()

    async def get_by_id(self, db: AsyncSession, id: int) -> Optional[ModelType]:
        """Alias of :meth:`get`, kept for backward compatibility."""
        return await self.get(db, id)

    async def get_multi(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        query: Optional[Select] = None,
    ) -> List[ModelType]:
        """Fetch a page of objects; a custom base query may be supplied."""
        if query is None:
            query = select(self.model)
        result = await db.execute(query.offset(skip).limit(limit))
        # Materialize as a plain list to match the declared return type.
        return list(result.scalars().all())

    async def count(self, db: AsyncSession, *, query: Optional[Select] = None) -> int:
        """Count rows of the whole table, or of a custom query's result."""
        if query is None:
            query = select(func.count()).select_from(self.model)
        else:
            # Wrap the caller's query so its filters apply to the count.
            query = select(func.count()).select_from(query.subquery())
        result = await db.execute(query)
        return result.scalar_one()

    async def create(self, db: AsyncSession, *, obj_in: Any, **kwargs) -> ModelType:
        """
        Create and persist an object.

        Args:
            db: database session
            obj_in: a Pydantic model or a plain dict of column values
            **kwargs: extra fields merged in (e.g. audit ``created_by``)

        Returns:
            The refreshed, persisted instance.
        """
        if hasattr(obj_in, "model_dump"):
            create_data = obj_in.model_dump()
        else:
            # Fix: copy so the merge below never mutates the caller's dict.
            create_data = dict(obj_in)
        create_data.update(kwargs)
        db_obj = self.model(**create_data)
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self, db: AsyncSession, *, db_obj: ModelType, obj_in: Any, **kwargs
    ) -> ModelType:
        """
        Update ``db_obj`` from a Pydantic model or dict and persist it.

        Extra keyword arguments (e.g. ``updated_by``) are merged into the
        update payload.
        """
        if hasattr(obj_in, "model_dump"):
            update_data = obj_in.model_dump(exclude_unset=True)
        else:
            # Fix: copy so the merge below never mutates the caller's dict.
            update_data = dict(obj_in)
        if kwargs:
            update_data.update(kwargs)
        for field, value in update_data.items():
            setattr(db_obj, field, value)
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def delete(self, db: AsyncSession, *, id: int) -> bool:
        """Hard-delete by primary key. Returns True if a row was removed."""
        obj = await self.get(db, id)
        if obj:
            await db.delete(obj)
            await db.commit()
            return True
        return False

    async def soft_delete(self, db: AsyncSession, *, id: int) -> bool:
        """
        Soft-delete by primary key.

        Sets ``is_deleted`` (and ``deleted_at`` when present). Returns True
        only when the object exists and supports soft deletion.
        """
        from datetime import datetime
        obj = await self.get(db, id)
        if obj and hasattr(obj, "is_deleted"):
            obj.is_deleted = True
            if hasattr(obj, "deleted_at"):
                obj.deleted_at = datetime.now()
            db.add(obj)
            await db.commit()
            return True
        return False

View File

@@ -0,0 +1,137 @@
"""
课程考试设置服务
"""
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logger import get_logger
from app.models.course_exam_settings import CourseExamSettings
from app.schemas.course import CourseExamSettingsCreate, CourseExamSettingsUpdate
from app.services.base_service import BaseService
logger = get_logger(__name__)
class CourseExamService(BaseService[CourseExamSettings]):
    """Course exam settings service (one settings row per course)."""

    def __init__(self):
        super().__init__(CourseExamSettings)

    async def get_by_course_id(self, db: AsyncSession, course_id: int) -> Optional[CourseExamSettings]:
        """
        Fetch the (non-deleted) exam settings for a course.

        Args:
            db: database session
            course_id: course id

        Returns:
            The settings instance, or None if the course has none.
        """
        stmt = select(CourseExamSettings).where(
            CourseExamSettings.course_id == course_id,
            CourseExamSettings.is_deleted == False
        )
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def create_or_update(
        self,
        db: AsyncSession,
        course_id: int,
        settings_in: CourseExamSettingsCreate,
        user_id: int
    ) -> CourseExamSettings:
        """
        Create the course's exam settings, or update them if they exist.

        Args:
            db: database session
            course_id: course id
            settings_in: full settings payload
            user_id: operator user id (recorded in audit fields)

        Returns:
            The persisted settings instance.
        """
        # Check whether settings already exist for this course.
        existing_settings = await self.get_by_course_id(db, course_id)
        if existing_settings:
            # Update the existing row in place (only fields the caller set).
            update_data = settings_in.model_dump(exclude_unset=True)
            for field, value in update_data.items():
                setattr(existing_settings, field, value)
            existing_settings.updated_by = user_id
            await db.commit()
            await db.refresh(existing_settings)
            logger.info(f"更新课程考试设置成功", course_id=course_id, user_id=user_id)
            return existing_settings
        else:
            # Create a new settings row with audit fields filled in.
            create_data = settings_in.model_dump()
            create_data.update({
                "course_id": course_id,
                "created_by": user_id,
                "updated_by": user_id
            })
            new_settings = CourseExamSettings(**create_data)
            db.add(new_settings)
            await db.commit()
            await db.refresh(new_settings)
            logger.info(f"创建课程考试设置成功", course_id=course_id, user_id=user_id)
            return new_settings

    async def update(
        self,
        db: AsyncSession,
        course_id: int,
        settings_in: CourseExamSettingsUpdate,
        user_id: int
    ) -> CourseExamSettings:
        """
        Update a course's exam settings, creating them when absent.

        NOTE(review): this overrides ``BaseService.update`` with a different
        signature (course_id instead of db_obj) — callers must use this form.

        Args:
            db: database session
            course_id: course id
            settings_in: partial settings payload (unset fields untouched)
            user_id: operator user id

        Returns:
            The updated (or newly created) settings instance.
        """
        # Load the existing settings row.
        settings = await self.get_by_course_id(db, course_id)
        if not settings:
            # None yet — fall back to create-or-update with the given fields.
            create_data = settings_in.model_dump(exclude_unset=True)
            return await self.create_or_update(
                db,
                course_id=course_id,
                settings_in=CourseExamSettingsCreate(**create_data),
                user_id=user_id
            )
        # Apply only the fields explicitly provided by the caller.
        update_data = settings_in.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(settings, field, value)
        settings.updated_by = user_id
        await db.commit()
        await db.refresh(settings)
        logger.info(f"更新课程考试设置成功", course_id=course_id, user_id=user_id)
        return settings

# Module-level singleton service instance
course_exam_service = CourseExamService()

View File

@@ -0,0 +1,194 @@
"""
课程岗位分配服务
"""
from typing import List, Optional
from sqlalchemy import select, and_, delete, func
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logger import get_logger
from app.models.position_course import PositionCourse
from app.models.position import Position
from app.models.position_member import PositionMember
from app.schemas.course import CoursePositionAssignment, CoursePositionAssignmentInDB
from app.services.base_service import BaseService
logger = get_logger(__name__)
class CoursePositionService(BaseService[PositionCourse]):
    """Course-to-position assignment service."""

    def __init__(self):
        super().__init__(PositionCourse)

    async def get_course_positions(
        self,
        db: AsyncSession,
        course_id: int,
        course_type: Optional[str] = None
    ) -> List[CoursePositionAssignmentInDB]:
        """
        List the position assignments of a course.

        Args:
            db: database session
            course_id: course id
            course_type: optional course-type filter

        Returns:
            Assignment rows enriched with position name/description and the
            current member count of each assigned position.
        """
        # Build filter conditions (soft-deleted rows excluded).
        conditions = [
            PositionCourse.course_id == course_id,
            PositionCourse.is_deleted == False
        ]
        if course_type:
            conditions.append(PositionCourse.course_type == course_type)
        stmt = (
            select(PositionCourse)
            .options(selectinload(PositionCourse.position))
            .where(and_(*conditions))
            .order_by(PositionCourse.priority, PositionCourse.id)
        )
        result = await db.execute(stmt)
        assignments = result.scalars().all()
        # Convert to the response schema; also count members per position.
        # NOTE(review): this issues one count query per assignment (N+1);
        # consider a single grouped count query if assignment lists grow.
        result_list = []
        for assignment in assignments:
            member_count = 0
            if assignment.position_id:
                member_count_result = await db.execute(
                    select(func.count(PositionMember.id)).where(
                        and_(
                            PositionMember.position_id == assignment.position_id,
                            PositionMember.is_deleted == False
                        )
                    )
                )
                member_count = member_count_result.scalar() or 0
            result_list.append(
                CoursePositionAssignmentInDB(
                    id=assignment.id,
                    course_id=assignment.course_id,
                    position_id=assignment.position_id,
                    course_type=assignment.course_type,
                    priority=assignment.priority,
                    position_name=assignment.position.name if assignment.position else None,
                    position_description=assignment.position.description if assignment.position else None,
                    member_count=member_count
                )
            )
        return result_list

    async def batch_assign_positions(
        self,
        db: AsyncSession,
        course_id: int,
        assignments: List[CoursePositionAssignment],
        user_id: int
    ) -> List[CoursePositionAssignmentInDB]:
        """
        Assign a course to several positions in one call.

        Existing assignments get their type/priority refreshed; missing ones
        are created.

        Args:
            db: database session
            course_id: course id
            assignments: requested (position_id, course_type, priority) rows
            user_id: operator user id (logged only)

        Returns:
            The refreshed assignment list for the course.
        """
        created_assignments = []
        for assignment in assignments:
            # Check for an existing row. (A Result can only be consumed once,
            # so the scalar is captured immediately.)
            result = await db.execute(
                select(PositionCourse).where(
                    PositionCourse.course_id == course_id,
                    PositionCourse.position_id == assignment.position_id,
                    PositionCourse.is_deleted == False,
                )
            )
            existing_assignment = result.scalar_one_or_none()
            if existing_assignment:
                # Already assigned: refresh type and priority.
                existing_assignment.course_type = assignment.course_type
                existing_assignment.priority = assignment.priority
                # PositionCourse does not inherit AuditMixin, so no audit
                # fields are forced here.
                created_assignments.append(existing_assignment)
            else:
                # Create a new assignment row.
                new_assignment = PositionCourse(
                    course_id=course_id,
                    position_id=assignment.position_id,
                    course_type=assignment.course_type,
                    priority=assignment.priority,
                )
                db.add(new_assignment)
                created_assignments.append(new_assignment)
        await db.commit()
        # Reload ORM state after the commit.
        for obj in created_assignments:
            await db.refresh(obj)
        logger.info("批量分配课程到岗位成功", course_id=course_id, count=len(assignments), user_id=user_id)
        # Return the full, freshly loaded assignment list.
        return await self.get_course_positions(db, course_id)

    async def remove_position_assignment(
        self,
        db: AsyncSession,
        course_id: int,
        position_id: int,
        user_id: int
    ) -> bool:
        """
        Remove (soft-delete) a course's assignment to one position.

        Args:
            db: database session
            course_id: course id
            position_id: position id
            user_id: operator user id (stored in ``deleted_by``)

        Returns:
            True if an assignment was found and removed, False otherwise.
        """
        # Locate the assignment row.
        stmt = select(PositionCourse).where(
            PositionCourse.course_id == course_id,
            PositionCourse.position_id == position_id,
            PositionCourse.is_deleted == False
        )
        result = await db.execute(stmt)
        assignment = result.scalar_one_or_none()
        if assignment:
            # Soft delete so the assignment history is preserved.
            assignment.is_deleted = True
            assignment.deleted_by = user_id
            await db.commit()
            logger.info(f"移除课程岗位分配成功", course_id=course_id, position_id=position_id, user_id=user_id)
            return True
        return False

# Module-level singleton service instance
course_position_service = CoursePositionService()

View File

@@ -0,0 +1,837 @@
"""
课程服务层
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from sqlalchemy import select, or_, and_, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.logger import get_logger
from app.core.exceptions import NotFoundError, BadRequestError, ConflictError
from app.models.course import (
Course,
CourseStatus,
CourseMaterial,
KnowledgePoint,
GrowthPath,
)
from app.models.course_exam_settings import CourseExamSettings
from app.models.position_member import PositionMember
from app.models.position_course import PositionCourse
from app.schemas.course import (
CourseCreate,
CourseUpdate,
CourseList,
CourseInDB,
CourseMaterialCreate,
KnowledgePointCreate,
KnowledgePointUpdate,
KnowledgePointInDB,
GrowthPathCreate,
)
from app.schemas.base import PaginationParams, PaginatedResponse
from app.services.base_service import BaseService
logger = get_logger(__name__)
class CourseService(BaseService[Course]):
"""
课程服务类
"""
    def __init__(self):
        """Bind the generic CRUD base service to the Course model."""
        super().__init__(Course)
    async def get_course_list(
        self,
        db: AsyncSession,
        *,
        page_params: PaginationParams,
        filters: CourseList,
        user_id: Optional[int] = None,
    ) -> PaginatedResponse[CourseInDB]:
        """
        List courses with filtering and pagination.

        Args:
            db: database session
            page_params: pagination parameters
            filters: filter conditions (status/category/featured/keyword)
            user_id: optional user id; used for logging and to resolve the
                course type (e.g. required/optional) from the user's positions

        Returns:
            Paginated course list; each item carries ``course_type`` when one
            of the user's positions is assigned to the course.
        """
        # Build filter conditions.
        filter_conditions = []
        # Status filter (defaults to published courses only).
        if filters.status is not None:
            filter_conditions.append(Course.status == filters.status)
        else:
            # No explicit status: show published courses only.
            filter_conditions.append(Course.status == CourseStatus.PUBLISHED)
        # Category filter.
        if filters.category is not None:
            filter_conditions.append(Course.category == filters.category)
        # Featured filter.
        if filters.is_featured is not None:
            filter_conditions.append(Course.is_featured == filters.is_featured)
        # Keyword search over course name and description.
        if filters.keyword:
            keyword = f"%{filters.keyword}%"
            filter_conditions.append(
                or_(Course.name.like(keyword), Course.description.like(keyword))
            )
        # Audit log of the query.
        logger.info(
            "查询课程列表",
            user_id=user_id,
            filters=filters.model_dump(exclude_none=True),
            page=page_params.page,
            size=page_params.page_size,
        )
        # Base query excludes soft-deleted rows.
        query = select(Course).where(Course.is_deleted == False)
        if filter_conditions:
            query = query.where(and_(*filter_conditions))
        # Sort: sort_order ascending first, then newest courses first.
        query = query.order_by(Course.sort_order.asc(), Course.created_at.desc())
        # Total row count under the same filters.
        count_query = (
            select(func.count()).select_from(Course).where(Course.is_deleted == False)
        )
        if filter_conditions:
            count_query = count_query.where(and_(*filter_conditions))
        total_result = await db.execute(count_query)
        total = total_result.scalar() or 0
        # Apply pagination and eager-load course materials.
        query = query.offset(page_params.offset).limit(page_params.limit)
        query = query.options(selectinload(Course.materials))
        result = await db.execute(query)
        courses = result.scalars().all()
        # Resolve the ids of the positions the user belongs to.
        user_position_ids = []
        if user_id:
            position_result = await db.execute(
                select(PositionMember.position_id).where(
                    PositionMember.user_id == user_id,
                    PositionMember.is_deleted == False
                )
            )
            user_position_ids = [row[0] for row in position_result.fetchall()]
        # Bulk-load position-course assignments for this page of courses.
        course_ids = [c.id for c in courses]
        course_type_map = {}
        if course_ids and user_position_ids:
            position_course_result = await db.execute(
                select(PositionCourse.course_id, PositionCourse.course_type).where(
                    PositionCourse.course_id.in_(course_ids),
                    PositionCourse.position_id.in_(user_position_ids),
                    PositionCourse.is_deleted == False
                )
            )
            # Map course -> type; 'required' wins when several of the user's
            # positions assign the same course.
            for course_id, course_type in position_course_result.fetchall():
                if course_id not in course_type_map:
                    course_type_map[course_id] = course_type
                elif course_type == 'required':
                    course_type_map[course_id] = 'required'
        # Convert to Pydantic models and attach the resolved course type.
        course_list = []
        for course in courses:
            course_data = CourseInDB.model_validate(course)
            # None when none of the user's positions include this course.
            course_data.course_type = course_type_map.get(course.id)
            course_list.append(course_data)
        # Total page count (ceiling division).
        pages = (total + page_params.page_size - 1) // page_params.page_size
        return PaginatedResponse(
            items=course_list,
            total=total,
            page=page_params.page,
            page_size=page_params.page_size,
            pages=pages,
        )
async def create_course(
self, db: AsyncSession, *, course_in: CourseCreate, created_by: int
) -> Course:
"""
创建课程
Args:
db: 数据库会话
course_in: 课程创建数据
created_by: 创建人ID
Returns:
创建的课程
"""
# 检查名称是否重复
existing = await db.execute(
select(Course).where(
and_(Course.name == course_in.name, Course.is_deleted == False)
)
)
if existing.scalar_one_or_none():
raise ConflictError(f"课程名称 '{course_in.name}' 已存在")
# 创建课程
course_data = course_in.model_dump()
course = await self.create(db, obj_in=course_data, created_by=created_by)
# 自动创建默认考试设置
default_exam_settings = CourseExamSettings(
course_id=course.id,
created_by=created_by,
updated_by=created_by
# 其他字段使用模型定义的默认值:
# single_choice_count=4, multiple_choice_count=2, true_false_count=1,
# fill_blank_count=2, essay_count=1, duration_minutes=10, 等
)
db.add(default_exam_settings)
await db.commit()
await db.refresh(course)
logger.info(
"创建课程", course_id=course.id, course_name=course.name, created_by=created_by
)
logger.info(
"自动创建默认考试设置", course_id=course.id, exam_settings_id=default_exam_settings.id
)
return course
async def update_course(
self,
db: AsyncSession,
*,
course_id: int,
course_in: CourseUpdate,
updated_by: int,
) -> Course:
"""
更新课程
Args:
db: 数据库会话
course_id: 课程ID
course_in: 课程更新数据
updated_by: 更新人ID
Returns:
更新后的课程
"""
# 获取课程
course = await self.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
# 检查名称是否重复(如果修改了名称)
if course_in.name and course_in.name != course.name:
existing = await db.execute(
select(Course).where(
and_(
Course.name == course_in.name,
Course.id != course_id,
Course.is_deleted == False,
)
)
)
if existing.scalar_one_or_none():
raise ConflictError(f"课程名称 '{course_in.name}' 已存在")
# 记录状态变更
old_status = course.status
# 更新课程
update_data = course_in.model_dump(exclude_unset=True)
# 如果状态变为已发布,记录发布时间
if (
update_data.get("status") == CourseStatus.PUBLISHED
and old_status != CourseStatus.PUBLISHED
):
update_data["published_at"] = datetime.now()
update_data["publisher_id"] = updated_by
course = await self.update(
db, db_obj=course, obj_in=update_data, updated_by=updated_by
)
logger.info(
"更新课程",
course_id=course.id,
course_name=course.name,
old_status=old_status,
new_status=course.status,
updated_by=updated_by,
)
return course
async def delete_course(
self, db: AsyncSession, *, course_id: int, deleted_by: int
) -> bool:
"""
删除课程(软删除 + 删除相关文件)
Args:
db: 数据库会话
course_id: 课程ID
deleted_by: 删除人ID
Returns:
是否删除成功
"""
import shutil
from pathlib import Path
from app.core.config import settings
course = await self.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
# 放开删除限制:任意状态均可软删除,由业务方自行控制
# 执行软删除(标记 is_deleted记录删除时间由审计日志记录操作者
success = await self.soft_delete(db, id=course_id)
if success:
# 删除课程文件夹及其所有内容
course_folder = Path(settings.UPLOAD_PATH) / "courses" / str(course_id)
if course_folder.exists() and course_folder.is_dir():
try:
shutil.rmtree(course_folder)
logger.info(
"删除课程文件夹成功",
course_id=course_id,
folder_path=str(course_folder),
)
except Exception as e:
# 文件夹删除失败不影响业务流程,仅记录日志
logger.error(
"删除课程文件夹失败",
course_id=course_id,
folder_path=str(course_folder),
error=str(e),
)
logger.warning(
"删除课程",
course_id=course_id,
course_name=course.name,
deleted_by=deleted_by,
folder_deleted=course_folder.exists(),
)
return success
async def add_course_material(
self,
db: AsyncSession,
*,
course_id: int,
material_in: CourseMaterialCreate,
created_by: int,
) -> CourseMaterial:
"""
添加课程资料
Args:
db: 数据库会话
course_id: 课程ID
material_in: 资料创建数据
created_by: 创建人ID
Returns:
创建的课程资料
"""
# 检查课程是否存在
course = await self.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
# 创建资料
material_data = material_in.model_dump()
material_data.update({
"course_id": course_id,
"created_by": created_by,
"updated_by": created_by
})
material = CourseMaterial(**material_data)
db.add(material)
await db.commit()
await db.refresh(material)
logger.info(
"添加课程资料",
course_id=course_id,
material_id=material.id,
material_name=material.name,
file_type=material.file_type,
file_size=material.file_size,
created_by=created_by,
)
return material
async def get_course_materials(
self,
db: AsyncSession,
*,
course_id: int,
) -> List[CourseMaterial]:
"""
获取课程资料列表
Args:
db: 数据库会话
course_id: 课程ID
Returns:
课程资料列表
"""
# 确认课程存在
course = await self.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
stmt = (
select(CourseMaterial)
.where(
CourseMaterial.course_id == course_id,
CourseMaterial.is_deleted == False,
)
.order_by(CourseMaterial.sort_order.asc(), CourseMaterial.id.asc())
)
result = await db.execute(stmt)
materials = result.scalars().all()
logger.info(
"查询课程资料列表", course_id=course_id, count=len(materials)
)
return materials
async def delete_course_material(
self,
db: AsyncSession,
*,
course_id: int,
material_id: int,
deleted_by: int,
) -> bool:
"""
删除课程资料(软删除 + 删除物理文件)
Args:
db: 数据库会话
course_id: 课程ID
material_id: 资料ID
deleted_by: 删除人ID
Returns:
是否删除成功
"""
import os
from pathlib import Path
from app.core.config import settings
# 先确认课程存在
course = await self.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
# 查找资料并校验归属
material_stmt = select(CourseMaterial).where(
CourseMaterial.id == material_id,
CourseMaterial.course_id == course_id,
CourseMaterial.is_deleted == False,
)
result = await db.execute(material_stmt)
material = result.scalar_one_or_none()
if not material:
raise NotFoundError(f"课程资料ID {material_id} 不存在或已删除")
# 获取文件路径信息用于删除物理文件
file_url = material.file_url
# 软删除数据库记录
material.is_deleted = True
material.deleted_at = datetime.now()
if hasattr(material, "deleted_by"):
# 兼容存在该字段的表
setattr(material, "deleted_by", deleted_by)
db.add(material)
await db.commit()
# 删除物理文件
if file_url and file_url.startswith("/static/uploads/"):
try:
# 从URL中提取相对路径
relative_path = file_url.replace("/static/uploads/", "")
file_path = Path(settings.UPLOAD_PATH) / relative_path
# 检查文件是否存在并删除
if file_path.exists() and file_path.is_file():
os.remove(file_path)
logger.info(
"删除物理文件成功",
file_path=str(file_path),
material_id=material_id,
)
except Exception as e:
# 物理文件删除失败不影响业务流程,仅记录日志
logger.error(
"删除物理文件失败",
file_url=file_url,
material_id=material_id,
error=str(e),
)
logger.warning(
"删除课程资料",
course_id=course_id,
material_id=material_id,
deleted_by=deleted_by,
file_deleted=file_url is not None,
)
return True
async def get_material_knowledge_points(
self, db: AsyncSession, material_id: int
) -> List[KnowledgePointInDB]:
"""获取资料关联的知识点列表"""
# 获取资料信息
result = await db.execute(
select(CourseMaterial).where(
CourseMaterial.id == material_id,
CourseMaterial.is_deleted == False
)
)
material = result.scalar_one_or_none()
if not material:
raise NotFoundError(f"资料ID {material_id} 不存在")
# 直接查询关联到该资料的知识点
query = select(KnowledgePoint).where(
KnowledgePoint.material_id == material_id,
KnowledgePoint.is_deleted == False
).order_by(KnowledgePoint.created_at.desc())
result = await db.execute(query)
knowledge_points = result.scalars().all()
from app.schemas.course import KnowledgePointInDB
return [KnowledgePointInDB.model_validate(kp) for kp in knowledge_points]
async def add_material_knowledge_points(
self, db: AsyncSession, material_id: int, knowledge_point_ids: List[int]
) -> List[KnowledgePointInDB]:
"""
为资料添加知识点关联
注意自2025-09-27起知识点直接通过material_id关联到资料
material_knowledge_points中间表已废弃。此方法将更新知识点的material_id字段。
"""
# 验证资料是否存在
result = await db.execute(
select(CourseMaterial).where(
CourseMaterial.id == material_id,
CourseMaterial.is_deleted == False
)
)
material = result.scalar_one_or_none()
if not material:
raise NotFoundError(f"资料ID {material_id} 不存在")
# 验证知识点是否存在且属于同一课程
result = await db.execute(
select(KnowledgePoint).where(
KnowledgePoint.id.in_(knowledge_point_ids),
KnowledgePoint.course_id == material.course_id,
KnowledgePoint.is_deleted == False
)
)
valid_knowledge_points = result.scalars().all()
if len(valid_knowledge_points) != len(knowledge_point_ids):
raise BadRequestError("部分知识点不存在或不属于同一课程")
# 更新知识点的material_id字段
added_knowledge_points = []
for kp in valid_knowledge_points:
# 更新知识点的资料关联
kp.material_id = material_id
added_knowledge_points.append(kp)
await db.commit()
# 刷新对象以获取更新后的数据
for kp in added_knowledge_points:
await db.refresh(kp)
from app.schemas.course import KnowledgePointInDB
return [KnowledgePointInDB.model_validate(kp) for kp in added_knowledge_points]
async def remove_material_knowledge_point(
self, db: AsyncSession, material_id: int, knowledge_point_id: int
) -> bool:
"""
移除资料的知识点关联(软删除知识点)
注意自2025-09-27起知识点直接通过material_id关联到资料
material_knowledge_points中间表已废弃。此方法将软删除知识点。
"""
# 查找知识点并验证归属
result = await db.execute(
select(KnowledgePoint).where(
KnowledgePoint.id == knowledge_point_id,
KnowledgePoint.material_id == material_id,
KnowledgePoint.is_deleted == False
)
)
knowledge_point = result.scalar_one_or_none()
if not knowledge_point:
raise NotFoundError(f"知识点ID {knowledge_point_id} 不存在或不属于该资料")
# 软删除知识点
knowledge_point.is_deleted = True
knowledge_point.deleted_at = datetime.now()
await db.commit()
logger.info(
"移除资料知识点关联",
material_id=material_id,
knowledge_point_id=knowledge_point_id,
)
return True
class KnowledgePointService(BaseService[KnowledgePoint]):
    """Service layer for course knowledge points."""

    def __init__(self):
        super().__init__(KnowledgePoint)

    async def get_knowledge_points_by_course(
        self, db: AsyncSession, *, course_id: int, material_id: Optional[int] = None
    ) -> List[KnowledgePoint]:
        """
        List a course's knowledge points, newest first.

        Args:
            db: Async database session.
            course_id: ID of the course.
            material_id: Optional material ID to narrow the result.

        Returns:
            Matching knowledge points.
        """
        conditions = [
            KnowledgePoint.course_id == course_id,
            KnowledgePoint.is_deleted == False,
        ]
        if material_id is not None:
            conditions.append(KnowledgePoint.material_id == material_id)
        stmt = (
            select(KnowledgePoint)
            .where(and_(*conditions))
            .order_by(KnowledgePoint.created_at.desc())
        )
        return (await db.execute(stmt)).scalars().all()

    async def create_knowledge_point(
        self,
        db: AsyncSession,
        *,
        course_id: int,
        point_in: KnowledgePointCreate,
        created_by: int,
    ) -> KnowledgePoint:
        """
        Create a knowledge point under a course.

        Args:
            db: Async database session.
            course_id: ID of the owning course.
            point_in: Knowledge point creation payload.
            created_by: ID of the creating user.

        Returns:
            The created knowledge point.

        Raises:
            NotFoundError: If the course does not exist.
        """
        # The point must belong to an existing course.
        if not await CourseService().get_by_id(db, course_id):
            raise NotFoundError(f"课程ID {course_id} 不存在")

        payload = point_in.model_dump()
        payload["course_id"] = course_id
        knowledge_point = await self.create(
            db, obj_in=payload, created_by=created_by
        )
        logger.info(
            "创建知识点",
            course_id=course_id,
            knowledge_point_id=knowledge_point.id,
            knowledge_point_name=knowledge_point.name,
            created_by=created_by,
        )
        return knowledge_point

    async def update_knowledge_point(
        self,
        db: AsyncSession,
        *,
        point_id: int,
        point_in: KnowledgePointUpdate,
        updated_by: int,
    ) -> KnowledgePoint:
        """
        Update a knowledge point.

        Args:
            db: Async database session.
            point_id: ID of the knowledge point.
            point_in: Partial update payload (unset fields are ignored).
            updated_by: ID of the updating user.

        Returns:
            The updated knowledge point.

        Raises:
            NotFoundError: If the point, or a newly referenced material,
                does not exist.
        """
        knowledge_point = await self.get_by_id(db, point_id)
        if not knowledge_point:
            raise NotFoundError(f"知识点ID {point_id} 不存在")

        # When relinking to a material, that material must exist.
        if hasattr(point_in, 'material_id') and point_in.material_id:
            material = (
                await db.execute(
                    select(CourseMaterial).where(
                        CourseMaterial.id == point_in.material_id,
                        CourseMaterial.is_deleted == False,
                    )
                )
            ).scalar_one_or_none()
            if not material:
                raise NotFoundError(f"资料ID {point_in.material_id} 不存在")

        knowledge_point = await self.update(
            db,
            db_obj=knowledge_point,
            obj_in=point_in.model_dump(exclude_unset=True),
            updated_by=updated_by,
        )
        logger.info(
            "更新知识点",
            knowledge_point_id=knowledge_point.id,
            knowledge_point_name=knowledge_point.name,
            updated_by=updated_by,
        )
        return knowledge_point
class GrowthPathService(BaseService[GrowthPath]):
    """Service layer for growth paths."""

    def __init__(self):
        super().__init__(GrowthPath)

    async def create_growth_path(
        self, db: AsyncSession, *, path_in: GrowthPathCreate, created_by: int
    ) -> GrowthPath:
        """
        Create a growth path.

        Args:
            db: Async database session.
            path_in: Growth path creation payload.
            created_by: ID of the creating user.

        Returns:
            The created growth path.

        Raises:
            ConflictError: If the name is already taken.
            NotFoundError: If any referenced course does not exist.
        """
        # Path names are unique among non-deleted rows.
        clash = await db.execute(
            select(GrowthPath).where(
                and_(GrowthPath.name == path_in.name, GrowthPath.is_deleted == False)
            )
        )
        if clash.scalar_one_or_none():
            raise ConflictError(f"成长路径名称 '{path_in.name}' 已存在")

        # Every referenced course must exist.
        if path_in.courses:
            course_service = CourseService()
            for entry in path_in.courses:
                if not await course_service.get_by_id(db, entry.course_id):
                    raise NotFoundError(f"课程ID {entry.course_id} 不存在")

        payload = path_in.model_dump()
        # Courses are stored as a JSON-serialisable list of dicts.
        if payload.get("courses"):
            payload["courses"] = [c.model_dump() for c in path_in.courses]
        growth_path = await self.create(db, obj_in=payload, created_by=created_by)
        logger.info(
            "创建成长路径",
            growth_path_id=growth_path.id,
            growth_path_name=growth_path.name,
            course_count=len(path_in.courses) if path_in.courses else 0,
            created_by=created_by,
        )
        return growth_path
# Module-level singleton service instances used by the API layer.
course_service = CourseService()
knowledge_point_service = KnowledgePointService()
growth_path_service = GrowthPathService()

View File

@@ -0,0 +1,65 @@
"""
课程统计服务
"""
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.exam import Exam
from app.models.course import Course
from app.core.logger import get_logger
logger = get_logger(__name__)
class CourseStatisticsService:
    """Maintains derived statistics on courses."""

    async def update_course_student_count(
        self,
        db: AsyncSession,
        course_id: int
    ) -> int:
        """
        Recompute and persist a course's student count.

        The count is the number of distinct users with non-deleted exam
        records for the course.

        Args:
            db: Async database session.
            course_id: ID of the course.

        Returns:
            The updated student count (0 when the course is missing).
        """
        try:
            # Distinct exam takers for this course.
            count_stmt = select(func.count(func.distinct(Exam.user_id))).where(
                Exam.course_id == course_id,
                Exam.is_deleted == False,
            )
            student_count = (await db.execute(count_stmt)).scalar_one() or 0

            # Write the figure back onto the course row.
            course = (
                await db.execute(
                    select(Course).where(
                        Course.id == course_id,
                        Course.is_deleted == False,
                    )
                )
            ).scalar_one_or_none()
            if course is None:
                logger.warning(f"课程 {course_id} 不存在")
                return 0

            course.student_count = student_count
            await db.commit()
            logger.info(f"更新课程 {course_id} 学员数: {student_count}")
            return student_count
        except Exception as e:
            logger.error(f"更新课程学员数失败: {str(e)}", exc_info=True)
            await db.rollback()
            raise
# Module-level singleton instance.
course_statistics_service = CourseStatisticsService()

View File

@@ -0,0 +1,97 @@
"""
Coze 播课服务
"""
import logging
from typing import Optional
from cozepy.exception import CozeError
from app.core.config import settings
from app.services.ai.coze.client import get_coze_client
logger = logging.getLogger(__name__)
class CozeBroadcastService:
    """Coze broadcast-course (audio) generation service."""

    def __init__(self):
        """Load workflow/space identifiers from application settings."""
        self.workflow_id = settings.COZE_BROADCAST_WORKFLOW_ID
        self.space_id = settings.COZE_BROADCAST_SPACE_ID

    def _get_client(self):
        """Build a fresh Coze client (new auth per call avoids expired tokens)."""
        return get_coze_client(force_new=True)

    async def trigger_workflow(self, course_id: int) -> None:
        """
        Fire the broadcast-generation workflow without awaiting its result.

        The Coze workflow:
        1. generates the broadcast audio, and
        2. writes the result into the database itself.

        Args:
            course_id: ID of the course to generate audio for.

        Raises:
            CozeError: If the Coze API call fails.
        """
        import asyncio

        # Associate the run with a bot in the same workspace so the OAuth
        # application's permissions apply.
        bot_id = settings.COZE_BROADCAST_BOT_ID or settings.COZE_PRACTICE_BOT_ID
        logger.info(
            f"触发播课生成工作流",
            extra={
                "course_id": course_id,
                "workflow_id": self.workflow_id,
                "bot_id": bot_id,
            },
        )
        try:
            # A fresh client per call keeps the OAuth token valid.
            coze = self._get_client()
            # Trigger-and-return; run the blocking SDK call off the event loop.
            result = await asyncio.to_thread(
                coze.workflows.runs.create,
                workflow_id=self.workflow_id,
                parameters={"course_id": str(course_id)},
                bot_id=bot_id,
            )
            logger.info(
                f"播课生成工作流已触发",
                extra={
                    "course_id": course_id,
                    "execute_id": getattr(result, 'execute_id', None),
                    "debug_url": getattr(result, 'debug_url', None),
                },
            )
        except CozeError as e:
            logger.error(
                f"触发 Coze 工作流失败",
                extra={
                    "course_id": course_id,
                    "error": str(e),
                    "error_code": getattr(e, 'code', None),
                },
            )
            raise
        except Exception as e:
            logger.error(
                f"触发播课生成工作流异常",
                extra={
                    "course_id": course_id,
                    "error": str(e),
                    "error_type": type(e).__name__,
                },
            )
            raise
# Process-wide singleton instance.
broadcast_service = CozeBroadcastService()

View File

@@ -0,0 +1,199 @@
"""
Coze AI对话服务
"""
import logging
from typing import Optional
from cozepy import Coze, COZE_CN_BASE_URL, Message
from cozepy.exception import CozeError, CozeAPIError
from app.core.config import settings
from app.services.ai.coze.client import get_auth_manager
# 注意:不再直接使用 TokenAuth统一通过 get_auth_manager() 管理认证
logger = logging.getLogger(__name__)
class CozeService:
    """Coze chat service for sparring-practice conversations."""

    def __init__(self):
        """Initialise the Coze client; requires COZE_PRACTICE_BOT_ID to be set."""
        if not settings.COZE_PRACTICE_BOT_ID:
            raise ValueError("COZE_PRACTICE_BOT_ID 未配置")
        self.bot_id = settings.COZE_PRACTICE_BOT_ID
        self._auth_manager = get_auth_manager()
        logger.info(
            f"CozeService初始化成功Bot ID={self.bot_id}, "
            f"Base URL={COZE_CN_BASE_URL}"
        )

    @property
    def client(self) -> Coze:
        """Return a Coze client, fetched fresh each time so the OAuth token is valid."""
        return self._auth_manager.get_client(force_new=True)

    def build_scene_prompt(
        self,
        scene_name: str,
        scene_background: str,
        scene_ai_role: str,
        scene_objectives: list,
        scene_keywords: Optional[list] = None,
        scene_description: Optional[str] = None,
        user_message: str = ""
    ) -> str:
        """
        Build the scene prompt (Markdown).

        Args:
            scene_name: Scene name.
            scene_background: Scene background.
            scene_ai_role: Description of the role the AI should play.
            scene_objectives: List of practice objectives.
            scene_keywords: Optional list of keywords.
            scene_description: Optional scene description.
            user_message: The learner's opening message.

        Returns:
            The complete scene prompt in Markdown.
        """
        # Numbered practice objectives.
        objectives_text = "\n".join(
            f"{i+1}. {obj}" for i, obj in enumerate(scene_objectives)
        )
        # Comma-separated keyword list (empty string when none given).
        keywords_text = ", ".join(scene_keywords) if scene_keywords else ""
        # Assemble the prompt. The template text is sent to the model and is
        # intentionally kept in Chinese — do not translate these literals.
        prompt = f"""# 陪练场景设定
## 场景名称
{scene_name}
"""
        # Optional description section.
        if scene_description:
            prompt += f"""
## 场景描述
{scene_description}
"""
        prompt += f"""
## 场景背景
{scene_background}
## AI角色要求
{scene_ai_role}
## 练习目标
{objectives_text}
"""
        # Optional keywords section.
        if keywords_text:
            prompt += f"""
## 关键词
{keywords_text}
"""
        prompt += f"""
---
现在开始陪练对话。请你严格按照上述场景设定扮演角色,与学员进行实战对话练习。
不要提及"场景设定""角色扮演"等元信息,直接进入角色开始对话。
学员的第一句话:{user_message}
"""
        return prompt

    def create_stream_chat(
        self,
        user_id: str,
        message: str,
        conversation_id: Optional[str] = None
    ):
        """
        Start a streaming chat.

        Args:
            user_id: User ID.
            message: Message content.
            conversation_id: Conversation ID when continuing an existing chat.

        Returns:
            The Coze streaming-chat iterator.

        Raises:
            CozeError, CozeAPIError: On Coze API failures (re-raised).
        """
        try:
            logger.info(
                f"创建Coze流式对话user_id={user_id}, "
                f"conversation_id={conversation_id}, "
                f"message_length={len(message)}"
            )
            stream = self.client.chat.stream(
                bot_id=self.bot_id,
                user_id=user_id,
                additional_messages=[Message.build_user_question_text(message)],
                conversation_id=conversation_id
            )
            # Record the LogID for troubleshooting with Coze support.
            if hasattr(stream, 'response') and hasattr(stream.response, 'logid'):
                logger.info(f"Coze对话创建成功logid={stream.response.logid}")
            return stream
        except (CozeError, CozeAPIError) as e:
            logger.error(f"Coze API调用失败: {e}")
            raise
        except Exception as e:
            logger.error(f"创建Coze对话失败: {e}")
            raise

    def cancel_chat(self, conversation_id: str, chat_id: str):
        """
        Cancel an in-flight chat.

        Args:
            conversation_id: Conversation ID.
            chat_id: Chat ID.
        """
        try:
            logger.info(f"中断Coze对话conversation_id={conversation_id}, chat_id={chat_id}")
            result = self.client.chat.cancel(
                conversation_id=conversation_id,
                chat_id=chat_id
            )
            logger.info(f"对话中断成功")
            return result
        except (CozeError, CozeAPIError) as e:
            logger.error(f"中断对话失败: {e}")
            raise
        except Exception as e:
            logger.error(f"中断对话异常: {e}")
            raise
# Lazily-created singleton storage for the Coze service.
_coze_service: Optional[CozeService] = None
def get_coze_service() -> CozeService:
    """
    Return the CozeService singleton.

    Intended for FastAPI dependency injection; the instance is created
    lazily on first use.
    """
    global _coze_service
    if _coze_service is None:
        _coze_service = CozeService()
    return _coze_service

View File

@@ -0,0 +1,305 @@
"""
文档转换服务
使用 LibreOffice 将 Office 文档转换为 PDF
"""
import os
import logging
import subprocess
from pathlib import Path
from typing import Optional
from datetime import datetime
from app.core.config import settings
logger = logging.getLogger(__name__)
class DocumentConverterService:
    """Document conversion service.

    Renders Office documents to PDF via LibreOffice; Excel workbooks are
    rendered to a self-contained HTML preview instead (PDF pagination of
    spreadsheets is poor).
    """

    # File extensions eligible for conversion.
    SUPPORTED_FORMATS = {'.docx', '.doc', '.pptx', '.ppt', '.xlsx', '.xls'}
    # Excel formats need special page-layout handling, so they take the
    # HTML path instead of PDF.
    EXCEL_FORMATS = {'.xlsx', '.xls'}

    def __init__(self):
        """Create the converted-output directory under the upload root."""
        self.converted_path = Path(settings.UPLOAD_PATH) / "converted"
        self.converted_path.mkdir(parents=True, exist_ok=True)

    def get_converted_file_path(self, course_id: int, material_id: int) -> Path:
        """
        Return the output path for a converted file (creates the directory).

        Args:
            course_id: Course ID (used as the subdirectory name).
            material_id: Material ID (used as the file name).

        Returns:
            Path of the converted PDF file.
        """
        course_dir = self.converted_path / str(course_id)
        course_dir.mkdir(parents=True, exist_ok=True)
        return course_dir / f"{material_id}.pdf"

    def need_convert(self, source_file: Path, converted_file: Path) -> bool:
        """
        Decide whether (re)conversion is needed.

        Args:
            source_file: Source document path.
            converted_file: Previously converted output path.

        Returns:
            True when the output is missing or older than the source.
        """
        # Output missing: must convert.
        if not converted_file.exists():
            return True
        # Source missing: nothing to convert from.
        if not source_file.exists():
            return False
        # Source newer than output: re-convert.
        source_mtime = source_file.stat().st_mtime
        converted_mtime = converted_file.stat().st_mtime
        return source_mtime > converted_mtime

    def convert_excel_to_html(
        self,
        source_file: str,
        course_id: int,
        material_id: int
    ) -> Optional[str]:
        """
        Convert an Excel workbook to an HTML preview (avoids PDF pagination
        problems).

        Args:
            source_file: Source file path.
            course_id: Course ID.
            material_id: Material ID.

        Returns:
            Public URL of the generated HTML file, or None on failure.
        """
        try:
            try:
                import openpyxl
                from openpyxl.utils import get_column_letter
            except ImportError as ie:
                logger.error(f"Excel转换依赖缺失: openpyxl 未安装。请运行 pip install openpyxl 或重建Docker镜像。错误: {str(ie)}")
                return None
            source_path = Path(source_file)
            logger.info(f"开始Excel转HTML: source={source_file}, course_id={course_id}, material_id={material_id}")
            # Output location for the HTML preview.
            course_dir = self.converted_path / str(course_id)
            course_dir.mkdir(parents=True, exist_ok=True)
            html_file = course_dir / f"{material_id}.html"
            # Serve the cached HTML when it is at least as new as the source.
            if html_file.exists():
                source_mtime = source_path.stat().st_mtime
                html_mtime = html_file.stat().st_mtime
                if source_mtime <= html_mtime:
                    logger.info(f"使用缓存的HTML文件: {html_file}")
                    return f"/static/uploads/converted/{course_id}/{material_id}.html"
            # data_only=True reads computed cell values rather than formulas.
            wb = openpyxl.load_workbook(source_file, data_only=True)
            # Page skeleton: styles plus sheet-tab chrome.
            html_content = '''<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: Arial, sans-serif; padding: 20px; background: #f5f5f5; }
.sheet-tabs { display: flex; gap: 10px; margin-bottom: 20px; flex-wrap: wrap; }
.sheet-tab { padding: 8px 16px; background: #fff; border: 1px solid #ddd; border-radius: 4px; cursor: pointer; }
.sheet-tab.active { background: #409eff; color: white; border-color: #409eff; }
.sheet-content { display: none; }
.sheet-content.active { display: block; }
table { border-collapse: collapse; width: 100%; background: white; box-shadow: 0 1px 3px rgba(0,0,0,0.1); }
th, td { border: 1px solid #e4e7ed; padding: 8px 12px; text-align: left; white-space: nowrap; }
th { background: #f5f7fa; font-weight: 600; position: sticky; top: 0; }
tr:nth-child(even) { background: #fafafa; }
tr:hover { background: #ecf5ff; }
.table-wrapper { overflow-x: auto; max-height: 80vh; }
</style>
</head>
<body>
'''
            # One clickable tab per worksheet; the first is active.
            sheet_names = wb.sheetnames
            html_content += '<div class="sheet-tabs">\n'
            for i, name in enumerate(sheet_names):
                active = 'active' if i == 0 else ''
                html_content += f'<div class="sheet-tab {active}" onclick="showSheet({i})">{name}</div>\n'
            html_content += '</div>\n'
            # One table per worksheet.
            for i, sheet_name in enumerate(sheet_names):
                ws = wb[sheet_name]
                active = 'active' if i == 0 else ''
                html_content += f'<div class="sheet-content {active}" id="sheet-{i}">\n'
                html_content += '<div class="table-wrapper"><table>\n'
                # Effective data extent of the sheet.
                max_row = ws.max_row or 1
                max_col = ws.max_column or 1
                for row_idx in range(1, min(max_row + 1, 1001)):  # cap at 1000 rows
                    html_content += '<tr>'
                    for col_idx in range(1, min(max_col + 1, 51)):  # cap at 50 columns
                        cell = ws.cell(row=row_idx, column=col_idx)
                        value = cell.value if cell.value is not None else ''
                        tag = 'th' if row_idx == 1 else 'td'
                        # Escape HTML special characters in string cells.
                        # NOTE(review): non-string cell values and sheet names
                        # are emitted unescaped — confirm they cannot carry
                        # attacker-controlled markup.
                        if isinstance(value, str):
                            value = value.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
                        html_content += f'<{tag}>{value}</{tag}>'
                    html_content += '</tr>\n'
                html_content += '</table></div></div>\n'
            # Tab-switching script.
            html_content += '''
<script>
function showSheet(index) {
document.querySelectorAll('.sheet-tab').forEach((tab, i) => {
tab.classList.toggle('active', i === index);
});
document.querySelectorAll('.sheet-content').forEach((content, i) => {
content.classList.toggle('active', i === index);
});
}
</script>
</body>
</html>'''
            # Write the HTML file.
            with open(html_file, 'w', encoding='utf-8') as f:
                f.write(html_content)
            logger.info(f"Excel转HTML成功: {html_file}")
            return f"/static/uploads/converted/{course_id}/{material_id}.html"
        except Exception as e:
            logger.error(f"Excel转HTML失败: {source_file}, 错误: {str(e)}", exc_info=True)
            return None

    def convert_to_pdf(
        self,
        source_file: str,
        course_id: int,
        material_id: int
    ) -> Optional[str]:
        """
        Convert an Office document to PDF.

        Args:
            source_file: Source file path (absolute or relative).
            course_id: Course ID.
            material_id: Material ID.

        Returns:
            Public URL of the converted PDF, or None on failure.
        """
        try:
            source_path = Path(source_file)
            # The source must exist.
            if not source_path.exists():
                logger.error(f"源文件不存在: {source_file}")
                return None
            # Only whitelisted Office formats are convertible.
            file_ext = source_path.suffix.lower()
            if file_ext not in self.SUPPORTED_FORMATS:
                logger.error(f"不支持的文件格式: {file_ext}")
                return None
            # Excel is previewed as HTML to avoid pagination problems.
            if file_ext in self.EXCEL_FORMATS:
                return self.convert_excel_to_html(source_file, course_id, material_id)
            # Target path of the converted PDF.
            converted_file = self.get_converted_file_path(course_id, material_id)
            # Reuse a cached conversion when it is up to date.
            if not self.need_convert(source_path, converted_file):
                logger.info(f"使用缓存的转换文件: {converted_file}")
                return f"/static/uploads/converted/{course_id}/{material_id}.pdf"
            # Perform the conversion.
            logger.info(f"开始转换文档: {source_file} -> {converted_file}")
            # LibreOffice flags:
            #   --headless: run without a UI
            #   --convert-to pdf: output format
            #   --outdir: output directory
            output_dir = converted_file.parent
            cmd = [
                'libreoffice',
                '--headless',
                '--convert-to', 'pdf',
                '--outdir', str(output_dir),
                str(source_path)
            ]
            # Run the conversion with a 60-second timeout.
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=60,
                check=True
            )
            # LibreOffice names its output after the source file's stem;
            # rename it to <material_id>.pdf.
            # NOTE(review): concurrent conversions of same-stem sources into
            # the same directory could collide on this temp name — confirm
            # that cannot happen.
            temp_converted = output_dir / f"{source_path.stem}.pdf"
            if temp_converted.exists() and temp_converted != converted_file:
                temp_converted.rename(converted_file)
            # Verify the output exists.
            if converted_file.exists():
                logger.info(f"文档转换成功: {converted_file}")
                return f"/static/uploads/converted/{course_id}/{material_id}.pdf"
            else:
                logger.error(f"文档转换失败,输出文件不存在: {converted_file}")
                return None
        except subprocess.TimeoutExpired:
            logger.error(f"文档转换超时: {source_file}")
            return None
        except subprocess.CalledProcessError as e:
            logger.error(f"文档转换失败: {source_file}, 错误: {e.stderr}")
            return None
        except Exception as e:
            logger.error(f"文档转换异常: {source_file}, 错误: {str(e)}", exc_info=True)
            return None

    def is_convertible(self, file_ext: str) -> bool:
        """
        Report whether a file extension is convertible.

        Args:
            file_ext: File extension including the dot (e.g. ".docx").

        Returns:
            True when the format is supported.
        """
        return file_ext.lower() in self.SUPPORTED_FORMATS
# Module-level singleton instance.
document_converter = DocumentConverterService()

View File

@@ -0,0 +1,739 @@
"""
员工同步服务
从外部钉钉员工表同步员工数据到考培练系统
"""
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import selectinload
import asyncio
from app.core.logger import get_logger
from app.core.security import get_password_hash
from app.models.user import User, Team
from app.models.position import Position
from app.models.position_member import PositionMember
from app.schemas.user import UserCreate
logger = get_logger(__name__)
class EmployeeSyncService:
    """Employee synchronisation service (DingTalk view -> training system)."""
    # External database connection string for the DingTalk employee view.
    # SECURITY NOTE(review): credentials are hard-coded in source control —
    # move them into configuration/secret management.
    EXTERNAL_DB_URL = "mysql+aiomysql://neuron_new:NWxGM6CQoMLKyEszXhfuLBIIo1QbeK@120.77.144.233:29613/neuron_new?charset=utf8mb4"
    def __init__(self, db: AsyncSession):
        # Local (target) database session.
        self.db = db
        # Created in __aenter__, disposed in __aexit__.
        self.external_engine = None
    async def __aenter__(self):
        """Async context entry: create the external database engine."""
        self.external_engine = create_async_engine(
            self.EXTERNAL_DB_URL,
            echo=False,
            pool_pre_ping=True,  # validate pooled connections before use
            pool_recycle=3600  # recycle connections hourly to avoid server-side timeouts
        )
        return self
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context exit: dispose the external database engine."""
        if self.external_engine:
            await self.external_engine.dispose()
async def fetch_employees_from_dingtalk(self) -> List[Dict[str, Any]]:
"""
从钉钉员工表获取在职员工数据
Returns:
员工数据列表
"""
logger.info("开始从钉钉员工表获取数据...")
query = """
SELECT
员工姓名,
手机号,
邮箱,
所属部门,
职位,
工号,
是否领导,
是否在职,
钉钉用户ID,
入职日期,
工作地点
FROM v_钉钉员工表
WHERE 是否在职 = 1
ORDER BY 员工姓名
"""
async with self.external_engine.connect() as conn:
result = await conn.execute(text(query))
rows = result.fetchall()
employees = []
for row in rows:
employees.append({
'full_name': row[0],
'phone': row[1],
'email': row[2],
'department': row[3],
'position': row[4],
'employee_no': row[5],
'is_leader': bool(row[6]),
'is_active': bool(row[7]),
'dingtalk_id': row[8],
'join_date': row[9],
'work_location': row[10]
})
logger.info(f"获取到 {len(employees)} 条在职员工数据")
return employees
def generate_email(self, phone: str, original_email: Optional[str]) -> Optional[str]:
"""
生成邮箱地址
如果原始邮箱为空或格式无效,生成 {手机号}@rxm.com
Args:
phone: 手机号
original_email: 原始邮箱
Returns:
邮箱地址
"""
if original_email and original_email.strip():
email = original_email.strip()
# 验证邮箱格式:检查是否有@后直接跟点号等无效格式
if '@' in email and not email.startswith('@') and not email.endswith('@'):
# 检查@后面是否直接是点号
at_index = email.index('@')
if at_index + 1 < len(email) and email[at_index + 1] != '.':
# 检查是否有域名部分
domain = email[at_index + 1:]
if '.' in domain and len(domain) > 2:
return email
# 如果邮箱无效或为空,使用手机号生成
if phone:
return f"{phone}@rxm.com"
return None
def determine_role(self, is_leader: bool) -> str:
"""
确定用户角色
Args:
is_leader: 是否领导
Returns:
角色: manager 或 trainee
"""
return 'manager' if is_leader else 'trainee'
async def create_or_get_team(self, department_name: str, leader_id: Optional[int] = None) -> Team:
"""
创建或获取部门团队
Args:
department_name: 部门名称
leader_id: 负责人ID
Returns:
团队对象
"""
if not department_name or department_name.strip() == '':
return None
department_name = department_name.strip()
# 检查团队是否已存在
stmt = select(Team).where(
Team.name == department_name,
Team.is_deleted == False
)
result = await self.db.execute(stmt)
team = result.scalar_one_or_none()
if team:
# 更新负责人
if leader_id and not team.leader_id:
team.leader_id = leader_id
logger.info(f"更新团队 {department_name} 的负责人")
return team
# 创建新团队
# 生成团队代码:使用部门名称的拼音首字母或简化处理
team_code = f"DEPT_{hash(department_name) % 100000:05d}"
team = Team(
name=department_name,
code=team_code,
description=f"{department_name}",
team_type='department',
is_active=True,
leader_id=leader_id
)
self.db.add(team)
await self.db.flush() # 获取ID但不提交
logger.info(f"创建团队: {department_name} (ID: {team.id})")
return team
async def create_or_get_position(self, position_name: str) -> Optional[Position]:
"""
创建或获取岗位
Args:
position_name: 岗位名称
Returns:
岗位对象
"""
if not position_name or position_name.strip() == '':
return None
position_name = position_name.strip()
# 检查岗位是否已存在
stmt = select(Position).where(
Position.name == position_name,
Position.is_deleted == False
)
result = await self.db.execute(stmt)
position = result.scalar_one_or_none()
if position:
return position
# 创建新岗位
position_code = f"POS_{hash(position_name) % 100000:05d}"
position = Position(
name=position_name,
code=position_code,
description=f"{position_name}",
status='active'
)
self.db.add(position)
await self.db.flush()
logger.info(f"创建岗位: {position_name} (ID: {position.id})")
return position
async def create_user(self, employee_data: Dict[str, Any]) -> Optional[User]:
"""
创建用户
Args:
employee_data: 员工数据
Returns:
用户对象或None如果创建失败
"""
phone = employee_data.get('phone')
full_name = employee_data.get('full_name')
if not phone:
logger.warning(f"员工 {full_name} 没有手机号,跳过")
return None
# 检查用户是否已存在(通过手机号)
stmt = select(User).where(
User.phone == phone,
User.is_deleted == False
)
result = await self.db.execute(stmt)
existing_user = result.scalar_one_or_none()
if existing_user:
logger.info(f"用户已存在: {phone} ({full_name})")
return existing_user
# 生成邮箱
email = self.generate_email(phone, employee_data.get('email'))
# 检查邮箱是否已被其他用户使用(避免唯一索引冲突)
if email:
email_check_stmt = select(User).where(
User.email == email,
User.is_deleted == False
)
email_result = await self.db.execute(email_check_stmt)
if email_result.scalar_one_or_none():
# 邮箱已存在,使用手机号生成唯一邮箱
email = f"{phone}@rxm.com"
logger.warning(f"邮箱 {employee_data.get('email')} 已被使用,为员工 {full_name} 生成新邮箱: {email}")
# 确定角色
role = self.determine_role(employee_data.get('is_leader', False))
# 创建用户
user = User(
username=phone, # 使用手机号作为用户名
email=email,
phone=phone,
hashed_password=get_password_hash('123456'), # 初始密码
full_name=full_name,
role=role,
is_active=True,
is_verified=True
)
self.db.add(user)
await self.db.flush()
logger.info(f"创建用户: {phone} ({full_name}) - 角色: {role}")
return user
async def sync_employees(self) -> Dict[str, Any]:
    """Run the full employee synchronisation pipeline.

    Fetches employees from DingTalk, creates user accounts together with
    their team and position links, and commits everything as a single
    transaction (rolled back entirely on an unrecoverable error).

    Returns:
        Run statistics: counts, per-employee errors, start/end time, duration.
    """
    logger.info("=" * 60)
    logger.info("开始员工同步")
    logger.info("=" * 60)
    # Running counters returned to the caller. `errors` collects
    # per-employee failure messages so one bad record does not abort the run.
    # NOTE(review): teams_created / positions_created are never incremented
    # anywhere below — confirm whether they should be updated or removed.
    stats = {
        'total_employees': 0,
        'users_created': 0,
        'users_skipped': 0,
        'teams_created': 0,
        'positions_created': 0,
        'errors': [],
        'start_time': datetime.now()
    }
    try:
        # 1. Fetch the employee roster from DingTalk.
        employees = await self.fetch_employees_from_dingtalk()
        stats['total_employees'] = len(employees)
        if not employees:
            logger.warning("没有获取到员工数据")
            return stats
        # 2. Create users and their related records.
        for employee in employees:
            try:
                # Create (or reuse) the user account; None means "no phone".
                user = await self.create_user(employee)
                if not user:
                    stats['users_skipped'] += 1
                    continue
                stats['users_created'] += 1
                # Attach the user to their department's team.
                department = employee.get('department')
                if department:
                    team = await self.create_or_get_team(
                        department,
                        leader_id=user.id if employee.get('is_leader') else None
                    )
                    if team:
                        # Insert into user_teams via raw SQL to avoid
                        # lazy-loading the ORM relationship.
                        await self._add_user_to_team(user.id, team.id)
                        logger.info(f"关联用户 {user.full_name} 到团队 {team.name}")
                # Attach the user to their position.
                position_name = employee.get('position')
                if position_name:
                    position = await self.create_or_get_position(position_name)
                    if position:
                        # Skip when the membership link already exists.
                        stmt = select(PositionMember).where(
                            PositionMember.position_id == position.id,
                            PositionMember.user_id == user.id,
                            PositionMember.is_deleted == False
                        )
                        result = await self.db.execute(stmt)
                        existing_member = result.scalar_one_or_none()
                        if not existing_member:
                            # Create the position membership link.
                            position_member = PositionMember(
                                position_id=position.id,
                                user_id=user.id,
                                role='member'
                            )
                            self.db.add(position_member)
                            logger.info(f"关联用户 {user.full_name} 到岗位 {position.name}")
            except Exception as e:
                # Record the failure and keep processing remaining employees.
                error_msg = f"处理员工 {employee.get('full_name')} 时出错: {str(e)}"
                logger.error(error_msg)
                stats['errors'].append(error_msg)
                continue
        # 3. Commit all changes at once.
        await self.db.commit()
        logger.info("✅ 数据库事务已提交")
    except Exception as e:
        # Unrecoverable failure: roll back the whole transaction and re-raise.
        logger.error(f"员工同步失败: {str(e)}")
        await self.db.rollback()
        stats['errors'].append(str(e))
        raise
    finally:
        # Timing is recorded on both success and failure paths.
        stats['end_time'] = datetime.now()
        stats['duration'] = (stats['end_time'] - stats['start_time']).total_seconds()
    # 4. Log the run statistics.
    logger.info("=" * 60)
    logger.info("同步完成统计")
    logger.info("=" * 60)
    logger.info(f"总员工数: {stats['total_employees']}")
    logger.info(f"创建用户: {stats['users_created']}")
    logger.info(f"跳过用户: {stats['users_skipped']}")
    logger.info(f"耗时: {stats['duration']:.2f}秒")
    if stats['errors']:
        logger.warning(f"错误数量: {len(stats['errors'])}")
        for error in stats['errors']:
            logger.warning(f" - {error}")
    return stats
async def preview_sync_data(self) -> Dict[str, Any]:
    """Preview the employees that would be synchronised, without writing anything.

    Returns:
        A summary dict with the per-employee rows plus distinct departments,
        positions and leader/trainee counts.
    """
    logger.info("预览待同步员工数据...")
    employees = await self.fetch_employees_from_dingtalk()

    rows = []
    departments = set()
    positions = set()
    leaders = 0
    for emp in employees:
        # Role is derived the same way the real sync would derive it.
        role = self.determine_role(emp.get('is_leader', False))
        rows.append({
            'full_name': emp.get('full_name'),
            'phone': emp.get('phone'),
            'email': self.generate_email(emp.get('phone'), emp.get('email')),
            'department': emp.get('department'),
            'position': emp.get('position'),
            'role': role,
            'is_leader': emp.get('is_leader')
        })
        if emp.get('department'):
            departments.add(emp.get('department'))
        if emp.get('position'):
            positions.add(emp.get('position'))
        if role == 'manager':
            leaders += 1

    return {
        'total_count': len(employees),
        'employees': rows,
        'departments': list(departments),
        'positions': list(positions),
        'leaders_count': leaders,
        'trainees_count': len(employees) - leaders
    }
async def _add_user_to_team(self, user_id: int, team_id: int) -> None:
    """Link a user to a team via raw SQL on user_teams (idempotent).

    Raw SQL is used deliberately so the ORM relationship is never
    lazy-loaded.

    Args:
        user_id: id of the user to link
        team_id: id of the team to link the user to
    """
    params = {"user_id": user_id, "team_id": team_id}
    # Probe for an existing link first so the insert stays idempotent.
    existing = await self.db.execute(
        text("SELECT 1 FROM user_teams WHERE user_id = :user_id AND team_id = :team_id"),
        params
    )
    if existing.scalar() is None:
        await self.db.execute(
            text("INSERT INTO user_teams (user_id, team_id, role) VALUES (:user_id, :team_id, 'member')"),
            params
        )
async def _cleanup_user_related_data(self, user_id: int) -> None:
    """Delete everything that references a user, before deleting the user row.

    The statements run in dependency order so no foreign-key constraint
    fires mid-way.

    Args:
        user_id: id of the user whose related rows are removed
    """
    logger.info(f"清理用户 {user_id} 的关联数据...")
    cleanup_statements = (
        # Exam results of the user's exams, then the exams themselves.
        "DELETE FROM exam_results WHERE exam_id IN (SELECT id FROM exams WHERE user_id = :user_id)",
        "DELETE FROM exams WHERE user_id = :user_id",
        # Mistake records.
        "DELETE FROM exam_mistakes WHERE user_id = :user_id",
        # Ability assessments.
        "DELETE FROM ability_assessments WHERE user_id = :user_id",
        # Position and team memberships.
        "DELETE FROM position_members WHERE user_id = :user_id",
        "DELETE FROM user_teams WHERE user_id = :user_id",
        # Practice sessions.
        "DELETE FROM practice_sessions WHERE user_id = :user_id",
        # Assignments held by the user, then assignments of tasks the user
        # created, then those tasks.
        "DELETE FROM task_assignments WHERE user_id = :user_id",
        "DELETE FROM task_assignments WHERE task_id IN (SELECT id FROM tasks WHERE creator_id = :user_id)",
        "DELETE FROM tasks WHERE creator_id = :user_id",
        # Detach the user from any team they lead.
        "UPDATE teams SET leader_id = NULL WHERE leader_id = :user_id",
    )
    for sql in cleanup_statements:
        await self.db.execute(text(sql), {"user_id": user_id})
    logger.info(f"用户 {user_id} 的关联数据清理完成")
async def incremental_sync_employees(self) -> Dict[str, Any]:
    """Incrementally synchronise employees with DingTalk.

    - Adds employees present in DingTalk but missing from the system.
    - Physically deletes system users no longer present in DingTalk.
    - Leaves employees that exist on both sides untouched.

    The diff is keyed purely on phone number.

    Returns:
        Run statistics: added/deleted/skipped counts, affected users,
        errors and timing.
    """
    logger.info("=" * 60)
    logger.info("开始增量员工同步")
    logger.info("=" * 60)
    stats = {
        'added_count': 0,
        'deleted_count': 0,
        'skipped_count': 0,
        'added_users': [],
        'deleted_users': [],
        'errors': [],
        'start_time': datetime.now()
    }
    try:
        # 1. Fetch the active employee roster from DingTalk.
        dingtalk_employees = await self.fetch_employees_from_dingtalk()
        dingtalk_phones = {emp.get('phone') for emp in dingtalk_employees if emp.get('phone')}
        logger.info(f"钉钉在职员工数量: {len(dingtalk_phones)}")
        # 2. Load current system users (excluding admin and soft-deleted rows).
        stmt = select(User).where(
            User.is_deleted == False,
            User.username != 'admin'
        )
        result = await self.db.execute(stmt)
        system_users = result.scalars().all()
        system_phones = {user.phone for user in system_users if user.phone}
        logger.info(f"系统现有员工数量排除admin: {len(system_phones)}")
        # 3. Diff the two sides by phone number.
        phones_to_add = dingtalk_phones - system_phones
        phones_to_delete = system_phones - dingtalk_phones
        phones_to_skip = dingtalk_phones & system_phones
        logger.info(f"待新增: {len(phones_to_add)}, 待删除: {len(phones_to_delete)}, 跳过: {len(phones_to_skip)}")
        stats['skipped_count'] = len(phones_to_skip)
        # 4. Create the newly hired employees.
        for employee in dingtalk_employees:
            phone = employee.get('phone')
            if not phone or phone not in phones_to_add:
                continue
            try:
                # Create the user account.
                user = await self.create_user(employee)
                if not user:
                    continue
                stats['added_count'] += 1
                stats['added_users'].append({
                    'full_name': user.full_name,
                    'phone': user.phone,
                    'role': user.role
                })
                # Attach the user to their department's team.
                department = employee.get('department')
                if department:
                    team = await self.create_or_get_team(
                        department,
                        leader_id=user.id if employee.get('is_leader') else None
                    )
                    if team:
                        # Raw SQL insert into user_teams avoids lazy loading
                        # the ORM relationship.
                        await self._add_user_to_team(user.id, team.id)
                        logger.info(f"关联用户 {user.full_name} 到团队 {team.name}")
                # Attach the user to their position.
                position_name = employee.get('position')
                if position_name:
                    position = await self.create_or_get_position(position_name)
                    if position:
                        # Skip when the membership link already exists.
                        stmt = select(PositionMember).where(
                            PositionMember.position_id == position.id,
                            PositionMember.user_id == user.id,
                            PositionMember.is_deleted == False
                        )
                        result = await self.db.execute(stmt)
                        existing_member = result.scalar_one_or_none()
                        if not existing_member:
                            position_member = PositionMember(
                                position_id=position.id,
                                user_id=user.id,
                                role='member'
                            )
                            self.db.add(position_member)
                            logger.info(f"关联用户 {user.full_name} 到岗位 {position.name}")
                logger.info(f"✅ 新增员工: {user.full_name} ({user.phone})")
            except Exception as e:
                # Record the failure and keep processing remaining employees.
                error_msg = f"新增员工 {employee.get('full_name')} 失败: {str(e)}"
                logger.error(error_msg)
                stats['errors'].append(error_msg)
                continue
        # 5. Physically delete employees that left.
        # Flush pending inserts first so they cannot conflict with deletes.
        await self.db.flush()
        # Collect the users to delete.
        users_to_delete = []
        for user in system_users:
            if user.phone and user.phone in phones_to_delete:
                # Double safeguard: never delete the admin account.
                if user.username == 'admin' or user.role == 'admin':
                    logger.warning(f"⚠️ 跳过删除管理员账户: {user.username}")
                    continue
                users_to_delete.append({
                    'id': user.id,
                    'full_name': user.full_name,
                    'phone': user.phone,
                    'username': user.username
                })
        # Delete each user together with all rows that reference them.
        for user_info in users_to_delete:
            try:
                user_id = user_info['id']
                # Clear referencing rows first (foreign-key constraints).
                await self._cleanup_user_related_data(user_id)
                # Raw SQL delete avoids ORM cascade conflicts.
                await self.db.execute(
                    text("DELETE FROM users WHERE id = :user_id"),
                    {"user_id": user_id}
                )
                stats['deleted_users'].append({
                    'full_name': user_info['full_name'],
                    'phone': user_info['phone'],
                    'username': user_info['username']
                })
                stats['deleted_count'] += 1
                logger.info(f"🗑️ 删除离职员工: {user_info['full_name']} ({user_info['phone']})")
            except Exception as e:
                # Record the failure and keep processing remaining deletions.
                error_msg = f"删除员工 {user_info['full_name']} 失败: {str(e)}"
                logger.error(error_msg)
                stats['errors'].append(error_msg)
                continue
        # 6. Commit everything at once.
        await self.db.commit()
        logger.info("✅ 数据库事务已提交")
    except Exception as e:
        # Unrecoverable failure: roll back the whole transaction and re-raise.
        logger.error(f"增量同步失败: {str(e)}")
        await self.db.rollback()
        stats['errors'].append(str(e))
        raise
    finally:
        # Timing is recorded on both success and failure paths.
        stats['end_time'] = datetime.now()
        stats['duration'] = (stats['end_time'] - stats['start_time']).total_seconds()
    # 7. Log the run statistics.
    logger.info("=" * 60)
    logger.info("增量同步完成统计")
    logger.info("=" * 60)
    logger.info(f"新增员工: {stats['added_count']}")
    logger.info(f"删除员工: {stats['deleted_count']}")
    logger.info(f"跳过员工: {stats['skipped_count']}")
    logger.info(f"耗时: {stats['duration']:.2f}秒")
    if stats['errors']:
        logger.warning(f"错误数量: {len(stats['errors'])}")
    return stats

View File

@@ -0,0 +1,486 @@
"""
考试报告和错题统计服务
"""
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_, or_, desc, case, text
from app.models.exam import Exam
from app.models.exam_mistake import ExamMistake
from app.models.course import Course, KnowledgePoint
from app.core.logger import get_logger
logger = get_logger(__name__)
class ExamReportService:
    """Exam score-report service: overview, trend, per-course and recent-exam stats."""

    @staticmethod
    async def get_exam_report(
        db: AsyncSession,
        user_id: int,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build the aggregated score report for one user.

        Args:
            db: database session
            user_id: user id
            start_date: inclusive start date (YYYY-MM-DD), optional
            end_date: inclusive end date (YYYY-MM-DD), optional

        Returns:
            Dict with ``overview``, ``trends``, ``subjects`` and ``recent_exams``.
        """
        logger.info(f"获取成绩报告 - user_id: {user_id}, start_date: {start_date}, end_date: {end_date}")
        # Base filter shared by the sub-queries below.
        conditions = [Exam.user_id == user_id]
        if start_date:
            conditions.append(Exam.start_time >= start_date)
        if end_date:
            # Include the whole end day by comparing against the next midnight.
            end_datetime = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
            conditions.append(Exam.start_time < end_datetime)
        # 1. Overview numbers.
        overview = await ExamReportService._get_overview(db, conditions)
        # 2. 30-day score trend.
        trends = await ExamReportService._get_trends(db, user_id, conditions)
        # 3. Per-course breakdown.
        subjects = await ExamReportService._get_subjects(db, conditions)
        # 4. Most recent exams.
        recent_exams = await ExamReportService._get_recent_exams(db, conditions)
        return {
            "overview": overview,
            "trends": trends,
            "subjects": subjects,
            "recent_exams": recent_exams
        }

    @staticmethod
    async def _get_overview(db: AsyncSession, conditions: List) -> Dict[str, Any]:
        """Aggregate totals over graded exams (round1_score is the main score)."""
        stmt = select(
            func.count(Exam.id).label("total_exams"),
            func.avg(Exam.round1_score).label("avg_score"),
            func.sum(Exam.question_count).label("total_questions"),
            # case((cond, 1)) yields NULL when is_passed is not True, so
            # count() effectively counts only passed exams.
            func.count(case((Exam.is_passed == True, 1))).label("passed_exams")
        ).where(
            and_(*conditions, Exam.round1_score.isnot(None))
        )
        result = await db.execute(stmt)
        stats = result.one()
        total_exams = stats.total_exams or 0
        passed_exams = stats.passed_exams or 0
        return {
            "avg_score": round(float(stats.avg_score or 0), 1),
            "total_exams": total_exams,
            "pass_rate": round((passed_exams / total_exams * 100) if total_exams > 0 else 0, 1),
            "total_questions": stats.total_questions or 0
        }

    @staticmethod
    async def _get_trends(
        db: AsyncSession,
        user_id: int,
        base_conditions: List
    ) -> List[Dict[str, Any]]:
        """Daily average round1_score over the last 30 days.

        Note: ``base_conditions`` is accepted for signature symmetry with the
        other helpers but is intentionally not applied — the trend always
        covers a fixed 30-day window regardless of the report's date filter.
        """
        thirty_days_ago = datetime.now() - timedelta(days=30)
        # Group graded exams by calendar day.
        stmt = select(
            func.date(Exam.start_time).label("exam_date"),
            func.avg(Exam.round1_score).label("avg_score")
        ).where(
            and_(
                Exam.user_id == user_id,
                Exam.start_time >= thirty_days_ago,
                Exam.round1_score.isnot(None)
            )
        ).group_by(
            func.date(Exam.start_time)
        ).order_by(
            func.date(Exam.start_time)
        )
        result = await db.execute(stmt)
        trend_data = result.all()
        return [
            {
                "date": row.exam_date.strftime("%Y-%m-%d") if row.exam_date else "",
                "avg_score": round(float(row.avg_score or 0), 1)
            }
            for row in trend_data
        ]

    @staticmethod
    async def _get_subjects(db: AsyncSession, conditions: List) -> List[Dict[str, Any]]:
        """Per-course statistics over graded exams, most-taken courses first."""
        stmt = select(
            Exam.course_id,
            Course.name.label("course_name"),
            func.avg(Exam.round1_score).label("avg_score"),
            func.count(Exam.id).label("exam_count"),
            func.max(Exam.round1_score).label("max_score"),
            func.min(Exam.round1_score).label("min_score"),
            func.count(case((Exam.is_passed == True, 1))).label("passed_count")
        ).join(
            Course, Exam.course_id == Course.id
        ).where(
            and_(*conditions, Exam.round1_score.isnot(None))
        ).group_by(
            Exam.course_id, Course.name
        ).order_by(
            desc(func.count(Exam.id))
        )
        result = await db.execute(stmt)
        subject_data = result.all()
        subjects = []
        for row in subject_data:
            exam_count = row.exam_count or 0
            passed_count = row.passed_count or 0
            subjects.append({
                "course_id": row.course_id,
                "course_name": row.course_name,
                "avg_score": round(float(row.avg_score or 0), 1),
                "exam_count": exam_count,
                "max_score": round(float(row.max_score or 0), 1),
                "min_score": round(float(row.min_score or 0), 1),
                "pass_rate": round((passed_count / exam_count * 100) if exam_count > 0 else 0, 1)
            })
        return subjects

    @staticmethod
    async def _get_recent_exams(db: AsyncSession, conditions: List) -> List[Dict[str, Any]]:
        """Return the 10 most recent exams, including per-round scores."""
        stmt = select(
            Exam.id,
            Exam.course_id,
            Course.name.label("course_name"),
            Exam.score,
            Exam.total_score,
            Exam.is_passed,
            Exam.start_time,
            Exam.end_time,
            Exam.round1_score,
            Exam.round2_score,
            Exam.round3_score
        ).join(
            Course, Exam.course_id == Course.id
        ).where(
            and_(*conditions)
        ).order_by(
            desc(Exam.created_at)  # order by creation time: start_time can be NULL
        ).limit(10)
        result = await db.execute(stmt)
        exam_data = result.all()
        recent_exams = []
        for row in exam_data:
            # Exam duration in seconds, when both timestamps exist.
            duration_seconds = None
            if row.start_time and row.end_time:
                duration_seconds = int((row.end_time - row.start_time).total_seconds())
            recent_exams.append({
                "id": row.id,
                "course_id": row.course_id,
                "course_name": row.course_name,
                # Fixed: compare against None so a legitimate score of 0 is
                # reported as 0.0 instead of being dropped as "missing".
                "score": round(float(row.score), 1) if row.score is not None else None,
                "total_score": round(float(row.total_score or 100), 1),
                "is_passed": row.is_passed,
                "duration_seconds": duration_seconds,
                "start_time": row.start_time.isoformat() if row.start_time else None,
                "end_time": row.end_time.isoformat() if row.end_time else None,
                "round_scores": {
                    # Same fix as above for each round's score.
                    "round1": round(float(row.round1_score), 1) if row.round1_score is not None else None,
                    "round2": round(float(row.round2_score), 1) if row.round2_score is not None else None,
                    "round3": round(float(row.round3_score), 1) if row.round3_score is not None else None
                }
            })
        return recent_exams
class MistakeService:
    """Mistake (wrong-answer) query service."""

    @staticmethod
    async def get_mistakes_list(
        db: AsyncSession,
        user_id: int,
        exam_id: Optional[int] = None,
        course_id: Optional[int] = None,
        question_type: Optional[str] = None,
        search: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        page: int = 1,
        size: int = 10
    ) -> Dict[str, Any]:
        """Return a filtered, paginated list of a user's mistakes.

        Args:
            db: database session
            user_id: owner of the mistakes
            exam_id: optional exam filter
            course_id: optional course filter (applied via the joined Exam)
            question_type: optional question-type filter
            search: optional substring match on the question content
            start_date: optional inclusive start date (YYYY-MM-DD)
            end_date: optional inclusive end date (YYYY-MM-DD)
            page: 1-based page number
            size: page size

        Returns:
            Dict with ``items``, ``total``, ``page``, ``size`` and ``pages``.
        """
        logger.info(f"获取错题列表 - user_id: {user_id}, exam_id: {exam_id}, course_id: {course_id}")
        # Base filter: only this user's mistakes.
        conditions = [ExamMistake.user_id == user_id]
        if exam_id:
            conditions.append(ExamMistake.exam_id == exam_id)
        if question_type:
            conditions.append(ExamMistake.question_type == question_type)
        if search:
            conditions.append(ExamMistake.question_content.like(f"%{search}%"))
        if start_date:
            conditions.append(ExamMistake.created_at >= start_date)
        if end_date:
            # Include the whole end day by comparing against the next midnight.
            end_datetime = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
            conditions.append(ExamMistake.created_at < end_datetime)
        # course_id lives on Exam, so it is filtered through the join below.
        if course_id:
            conditions.append(Exam.course_id == course_id)
        # Total count (the join is required because conditions may reference Exam).
        count_stmt = select(func.count(ExamMistake.id)).select_from(ExamMistake).join(
            Exam, ExamMistake.exam_id == Exam.id
        ).where(and_(*conditions))
        total_result = await db.execute(count_stmt)
        total = total_result.scalar() or 0
        # One page of rows, newest first; knowledge point is optional (outer join).
        offset = (page - 1) * size
        stmt = select(
            ExamMistake.id,
            ExamMistake.exam_id,
            Exam.course_id,
            Course.name.label("course_name"),
            ExamMistake.question_content,
            ExamMistake.correct_answer,
            ExamMistake.user_answer,
            ExamMistake.question_type,
            ExamMistake.knowledge_point_id,
            KnowledgePoint.name.label("knowledge_point_name"),
            ExamMistake.created_at
        ).select_from(ExamMistake).join(
            Exam, ExamMistake.exam_id == Exam.id
        ).join(
            Course, Exam.course_id == Course.id
        ).outerjoin(
            KnowledgePoint, ExamMistake.knowledge_point_id == KnowledgePoint.id
        ).where(
            and_(*conditions)
        ).order_by(
            desc(ExamMistake.created_at)
        ).offset(offset).limit(size)
        result = await db.execute(stmt)
        mistakes = result.all()
        # Shape the rows for the API response.
        items = []
        for row in mistakes:
            items.append({
                "id": row.id,
                "exam_id": row.exam_id,
                "course_id": row.course_id,
                "course_name": row.course_name,
                "question_content": row.question_content,
                "correct_answer": row.correct_answer,
                "user_answer": row.user_answer,
                "question_type": row.question_type,
                "knowledge_point_id": row.knowledge_point_id,
                "knowledge_point_name": row.knowledge_point_name,
                "created_at": row.created_at
            })
        pages = (total + size - 1) // size
        return {
            "items": items,
            "total": total,
            "page": page,
            "size": size,
            "pages": pages
        }

    @staticmethod
    async def get_mistakes_statistics(
        db: AsyncSession,
        user_id: int,
        course_id: Optional[int] = None
    ) -> Dict[str, Any]:
        """Return mistake statistics for a user.

        Args:
            db: database session
            user_id: user id
            course_id: optional course filter

        Returns:
            Dict with ``total``, ``by_course``, ``by_type`` and ``by_time``.

        NOTE(review): ``course_id`` only narrows the ``total`` figure below;
        the by_course / by_type / by_time breakdowns always span all courses.
        Confirm whether that asymmetry is intended.
        """
        logger.info(f"获取错题统计 - user_id: {user_id}, course_id: {course_id}")
        # Conditions for the total count (optionally narrowed to one course).
        base_conditions = [ExamMistake.user_id == user_id]
        if course_id:
            base_conditions.append(Exam.course_id == course_id)
        # 1. Total mistakes.
        count_stmt = select(func.count(ExamMistake.id)).select_from(ExamMistake).join(
            Exam, ExamMistake.exam_id == Exam.id
        ).where(and_(*base_conditions))
        total_result = await db.execute(count_stmt)
        total = total_result.scalar() or 0
        # 2. Breakdown by course (all courses, most mistakes first).
        by_course_stmt = select(
            Exam.course_id,
            Course.name.label("course_name"),
            func.count(ExamMistake.id).label("count")
        ).select_from(ExamMistake).join(
            Exam, ExamMistake.exam_id == Exam.id
        ).join(
            Course, Exam.course_id == Course.id
        ).where(
            ExamMistake.user_id == user_id
        ).group_by(
            Exam.course_id, Course.name
        ).order_by(
            desc(func.count(ExamMistake.id))
        )
        by_course_result = await db.execute(by_course_stmt)
        by_course_data = by_course_result.all()
        by_course = [
            {
                "course_id": row.course_id,
                "course_name": row.course_name,
                "count": row.count
            }
            for row in by_course_data
        ]
        # 3. Breakdown by question type (rows with a NULL type are excluded).
        by_type_stmt = select(
            ExamMistake.question_type,
            func.count(ExamMistake.id).label("count")
        ).where(
            and_(ExamMistake.user_id == user_id, ExamMistake.question_type.isnot(None))
        ).group_by(
            ExamMistake.question_type
        )
        by_type_result = await db.execute(by_type_stmt)
        by_type_data = by_type_result.all()
        # Display names for the question-type codes.
        type_names = {
            "single": "单选题",
            "multiple": "多选题",
            "judge": "判断题",
            "blank": "填空题",
            "essay": "问答题"
        }
        by_type = [
            {
                "type": row.question_type,
                "type_name": type_names.get(row.question_type, "未知类型"),
                "count": row.count
            }
            for row in by_type_data
        ]
        # 4. Breakdown by recency windows.
        now = datetime.now()
        week_ago = now - timedelta(days=7)
        month_ago = now - timedelta(days=30)
        quarter_ago = now - timedelta(days=90)
        # Last 7 days.
        week_stmt = select(func.count(ExamMistake.id)).where(
            and_(ExamMistake.user_id == user_id, ExamMistake.created_at >= week_ago)
        )
        week_result = await db.execute(week_stmt)
        week_count = week_result.scalar() or 0
        # Last 30 days.
        month_stmt = select(func.count(ExamMistake.id)).where(
            and_(ExamMistake.user_id == user_id, ExamMistake.created_at >= month_ago)
        )
        month_result = await db.execute(month_stmt)
        month_count = month_result.scalar() or 0
        # Last 90 days.
        quarter_stmt = select(func.count(ExamMistake.id)).where(
            and_(ExamMistake.user_id == user_id, ExamMistake.created_at >= quarter_ago)
        )
        quarter_result = await db.execute(quarter_stmt)
        quarter_count = quarter_result.scalar() or 0
        by_time = {
            "week": week_count,
            "month": month_count,
            "quarter": quarter_count
        }
        return {
            "total": total,
            "by_course": by_course,
            "by_type": by_type,
            "by_time": by_time
        }

View File

@@ -0,0 +1,439 @@
"""
考试服务层
"""
import json
import random
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_, or_, desc
from app.models.exam import Exam, Question, ExamResult
from app.models.exam_mistake import ExamMistake
from app.models.course import Course, KnowledgePoint
from app.core.exceptions import BusinessException, ErrorCode
class ExamService:
    """Exam lifecycle service: start, submit, inspect and list exams."""

    @staticmethod
    async def start_exam(
        db: AsyncSession, user_id: int, course_id: int, question_count: int = 10
    ) -> Exam:
        """Start an exam by drawing random questions from a course.

        Args:
            db: database session
            user_id: user taking the exam
            course_id: course to draw questions from
            question_count: requested number of questions (capped by availability)

        Returns:
            Exam: the persisted exam row.

        Raises:
            BusinessException: when the course does not exist or has no questions.
        """
        # The course must exist.
        course = await db.get(Course, course_id)
        if not course:
            raise BusinessException(error_code=ErrorCode.NOT_FOUND, message="课程不存在")
        # Load all active questions for the course.
        stmt = select(Question).where(
            and_(Question.course_id == course_id, Question.is_active == True)
        )
        result = await db.execute(stmt)
        all_questions = result.scalars().all()
        if not all_questions:
            raise BusinessException(
                error_code=ErrorCode.VALIDATION_ERROR, message="该课程暂无题目"
            )
        # Draw without replacement; never more than the pool size.
        selected_questions = random.sample(
            all_questions, min(question_count, len(all_questions))
        )
        # Snapshot the questions (without correct answers) into the exam row.
        questions_data = []
        for q in selected_questions:
            question_data = {
                "id": str(q.id),
                "type": q.question_type,
                "title": q.title,
                "content": q.content,
                "options": q.options,
                "score": q.score,
            }
            questions_data.append(question_data)
        # Hoisted: the total is needed twice (total_score and pass_score).
        total_score = sum(q.score for q in selected_questions)
        exam = Exam(
            user_id=user_id,
            course_id=course_id,
            exam_name=f"{course.name} - 随机测试",
            question_count=len(selected_questions),
            total_score=total_score,
            pass_score=total_score * 0.6,  # pass mark is 60% of the total
            duration_minutes=60,
            status="started",
            questions={"questions": questions_data},
        )
        db.add(exam)
        await db.commit()
        await db.refresh(exam)
        return exam

    @staticmethod
    async def submit_exam(
        db: AsyncSession, user_id: int, exam_id: int, answers: List[Dict[str, str]]
    ) -> Dict[str, Any]:
        """Grade and record a user's submitted answers.

        Args:
            db: database session
            user_id: user id (must own the exam)
            exam_id: exam id
            answers: list of {"question_id": ..., "answer": ...} dicts

        Returns:
            Dict: grading summary (score, pass flag, accuracy).

        Raises:
            BusinessException: when the exam is missing, already finished,
                or past its time limit.
        """
        stmt = select(Exam).where(and_(Exam.id == exam_id, Exam.user_id == user_id))
        result = await db.execute(stmt)
        exam = result.scalar_one_or_none()
        if not exam:
            raise BusinessException(error_code=ErrorCode.NOT_FOUND, message="考试记录不存在")
        if exam.status != "started":
            raise BusinessException(
                error_code=ErrorCode.VALIDATION_ERROR, message="考试已结束或已提交"
            )
        # Reject submissions past the time limit and mark the exam timed out.
        if datetime.now() > exam.start_time + timedelta(
            minutes=exam.duration_minutes
        ):
            exam.status = "timeout"
            await db.commit()
            raise BusinessException(
                error_code=ErrorCode.VALIDATION_ERROR, message="考试已超时"
            )
        # Index answers by question id for O(1) lookup while grading.
        answers_dict = {ans["question_id"]: ans["answer"] for ans in answers}
        total_score = 0.0
        correct_count = 0
        # Fetch all answered questions in a single query.
        question_ids = [int(ans["question_id"]) for ans in answers]
        stmt = select(Question).where(Question.id.in_(question_ids))
        result = await db.execute(stmt)
        questions_map = {str(q.id): q for q in result.scalars().all()}
        # Grade each question that was part of the exam snapshot.
        for question_data in exam.questions["questions"]:
            question_id = question_data["id"]
            question = questions_map.get(question_id)
            if not question:
                continue
            user_answer = answers_dict.get(question_id, "")
            is_correct = user_answer == question.correct_answer
            if is_correct:
                total_score += question.score
                correct_count += 1
            # Persist the per-question result.
            exam_result = ExamResult(
                exam_id=exam_id,
                question_id=int(question_id),
                user_answer=user_answer,
                is_correct=is_correct,
                score=question.score if is_correct else 0.0,
            )
            db.add(exam_result)
            # Update the question's usage statistics.
            question.usage_count += 1
            if is_correct:
                question.correct_count += 1
        # Finalise the exam row.
        exam.end_time = datetime.now()
        exam.score = total_score
        exam.is_passed = total_score >= exam.pass_score
        exam.status = "submitted"
        exam.answers = {"answers": answers}
        await db.commit()
        return {
            "exam_id": exam_id,
            "total_score": total_score,
            "pass_score": exam.pass_score,
            "is_passed": exam.is_passed,
            "correct_count": correct_count,
            "total_count": exam.question_count,
            "accuracy": correct_count / exam.question_count
            if exam.question_count > 0
            else 0,
        }

    @staticmethod
    async def get_exam_detail(
        db: AsyncSession, user_id: int, exam_id: int
    ) -> Dict[str, Any]:
        """Return one exam's detail, including per-question results if submitted.

        Args:
            db: database session
            user_id: user id (must own the exam)
            exam_id: exam id

        Returns:
            Dict: exam attributes plus, for submitted exams, per-question results.

        Raises:
            BusinessException: when the exam does not exist for this user.
        """
        stmt = select(Exam).where(and_(Exam.id == exam_id, Exam.user_id == user_id))
        result = await db.execute(stmt)
        exam = result.scalar_one_or_none()
        if not exam:
            raise BusinessException(error_code=ErrorCode.NOT_FOUND, message="考试记录不存在")
        exam_data = {
            "id": exam.id,
            "course_id": exam.course_id,
            "exam_name": exam.exam_name,
            "question_count": exam.question_count,
            "total_score": exam.total_score,
            "pass_score": exam.pass_score,
            "start_time": exam.start_time.isoformat() if exam.start_time else None,
            "end_time": exam.end_time.isoformat() if exam.end_time else None,
            "duration_minutes": exam.duration_minutes,
            "status": exam.status,
            "score": exam.score,
            "is_passed": exam.is_passed,
            "questions": exam.questions,
        }
        # Attach per-question results only for submitted exams.
        if exam.status == "submitted" and exam.answers:
            stmt = select(ExamResult).where(ExamResult.exam_id == exam_id)
            result = await db.execute(stmt)
            results = result.scalars().all()
            results_data = []
            for r in results:
                results_data.append(
                    {
                        "question_id": r.question_id,
                        "user_answer": r.user_answer,
                        "is_correct": r.is_correct,
                        "score": r.score,
                    }
                )
            exam_data["results"] = results_data
            exam_data["answers"] = exam.answers
        return exam_data

    @staticmethod
    async def get_exam_records(
        db: AsyncSession,
        user_id: int,
        page: int = 1,
        size: int = 10,
        course_id: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Return a page of the user's exam records with per-exam statistics.

        Args:
            db: database session
            user_id: user id
            page: 1-based page number
            size: page size
            course_id: optional course filter

        Returns:
            Dict: paginated items with accuracy, timing and per-type stats.
        """
        conditions = [Exam.user_id == user_id]
        if course_id:
            conditions.append(Exam.course_id == course_id)
        # Total row count.
        count_stmt = select(func.count(Exam.id)).where(and_(*conditions))
        total = await db.scalar(count_stmt)
        # Page of exams joined with the course name.
        offset = (page - 1) * size
        stmt = (
            select(Exam, Course.name.label("course_name"))
            .join(Course, Exam.course_id == Course.id)
            .where(and_(*conditions))
            .order_by(Exam.created_at.desc())
            .offset(offset)
            .limit(size)
        )
        result = await db.execute(stmt)
        rows = result.all()
        items = []
        for exam, course_name in rows:
            # 1. Exam duration in seconds, when both timestamps exist.
            duration_seconds = None
            if exam.start_time and exam.end_time:
                duration_seconds = int((exam.end_time - exam.start_time).total_seconds())
            # 2. Number of mistakes recorded for this exam.
            mistakes_stmt = select(func.count(ExamMistake.id)).where(
                ExamMistake.exam_id == exam.id
            )
            wrong_count = await db.scalar(mistakes_stmt) or 0
            # 3. Correct count and accuracy derived from the mistake count.
            correct_count = exam.question_count - wrong_count if exam.question_count else 0
            accuracy = None
            if exam.question_count and exam.question_count > 0:
                accuracy = round((correct_count / exam.question_count) * 100, 1)
            # 4. Per-question-type statistics.
            question_type_stats = []
            if exam.questions:
                try:
                    questions_data = json.loads(exam.questions) if isinstance(exam.questions, str) else exam.questions
                    # Fix: start_exam stores the snapshot as {"questions": [...]}.
                    # Iterating that dict directly yields the string key
                    # "questions" and crashed with AttributeError on q.get().
                    # Unwrap the key here while still accepting a bare list.
                    if isinstance(questions_data, dict):
                        questions_data = questions_data.get("questions", [])
                    type_totals = {}
                    type_scores = {}  # total achievable score per question type
                    for q in questions_data:
                        q_type = q.get("type", "unknown")
                        q_score = q.get("score", 0)
                        type_totals[q_type] = type_totals.get(q_type, 0) + 1
                        type_scores[q_type] = type_scores.get(q_type, 0) + q_score
                    # Mistakes grouped by question type.
                    mistakes_by_type_stmt = (
                        select(ExamMistake.question_type, func.count(ExamMistake.id))
                        .where(ExamMistake.exam_id == exam.id)
                        .group_by(ExamMistake.question_type)
                    )
                    mistakes_by_type_result = await db.execute(mistakes_by_type_stmt)
                    mistakes_by_type = dict(mistakes_by_type_result.all())
                    # Display names for the question-type codes.
                    type_name_map = {
                        "single": "单选题",
                        "multiple": "多选题",
                        "judge": "判断题",
                        "blank": "填空题",
                        "essay": "问答题"
                    }
                    # Assemble the per-type statistics.
                    for q_type, type_total in type_totals.items():
                        wrong = mistakes_by_type.get(q_type, 0)
                        correct = type_total - wrong
                        type_accuracy = round((correct / type_total) * 100, 1) if type_total > 0 else 0
                        question_type_stats.append({
                            "type": type_name_map.get(q_type, q_type),
                            "type_code": q_type,
                            "total": type_total,
                            "correct": correct,
                            "wrong": wrong,
                            "accuracy": type_accuracy,
                            "total_score": type_scores.get(q_type, 0)
                        })
                except (json.JSONDecodeError, TypeError, KeyError, AttributeError):
                    # Malformed snapshot: degrade gracefully to empty stats.
                    question_type_stats = []
            items.append(
                {
                    "id": exam.id,
                    "course_id": exam.course_id,
                    "course_name": course_name,
                    "exam_name": exam.exam_name,
                    "question_count": exam.question_count,
                    "total_score": exam.total_score,
                    "score": exam.score,
                    "is_passed": exam.is_passed,
                    "status": exam.status,
                    "start_time": exam.start_time.isoformat() if exam.start_time else None,
                    "end_time": exam.end_time.isoformat() if exam.end_time else None,
                    "created_at": exam.created_at.isoformat(),
                    # Derived statistics.
                    "accuracy": accuracy,
                    "correct_count": correct_count,
                    "wrong_count": wrong_count,
                    "duration_seconds": duration_seconds,
                    "question_type_stats": question_type_stats,
                }
            )
        return {
            "items": items,
            "total": total,
            "page": page,
            "size": size,
            "pages": (total + size - 1) // size,
        }

    @staticmethod
    async def get_exam_statistics(
        db: AsyncSession, user_id: int, course_id: Optional[int] = None
    ) -> Dict[str, Any]:
        """Aggregate exam statistics over submitted exams.

        Args:
            db: database session
            user_id: user id
            course_id: optional course filter

        Returns:
            Dict: totals, pass rate and min/avg/max scores.
        """
        conditions = [Exam.user_id == user_id, Exam.status == "submitted"]
        if course_id:
            conditions.append(Exam.course_id == course_id)
        # nullif(is_passed, False) maps False to NULL, so count() yields the
        # number of passed exams only.
        stmt = select(
            func.count(Exam.id).label("total_exams"),
            func.count(func.nullif(Exam.is_passed, False)).label("passed_exams"),
            func.avg(Exam.score).label("avg_score"),
            func.max(Exam.score).label("max_score"),
            func.min(Exam.score).label("min_score"),
        ).where(and_(*conditions))
        result = await db.execute(stmt)
        stats = result.one()
        return {
            "total_exams": stats.total_exams or 0,
            "passed_exams": stats.passed_exams or 0,
            "pass_rate": (stats.passed_exams / stats.total_exams * 100)
            if stats.total_exams > 0
            else 0,
            "avg_score": float(stats.avg_score or 0),
            "max_score": float(stats.max_score or 0),
            "min_score": float(stats.min_score or 0),
        }

View File

View File

@@ -0,0 +1,330 @@
"""
站内消息通知服务
提供通知的CRUD操作和业务逻辑
"""
from typing import List, Optional, Tuple
from sqlalchemy import select, and_, desc, func, update
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logger import get_logger
from app.models.notification import Notification
from app.models.user import User
from app.schemas.notification import (
NotificationCreate,
NotificationBatchCreate,
NotificationResponse,
NotificationType,
)
from app.services.base_service import BaseService
logger = get_logger(__name__)
class NotificationService(BaseService[Notification]):
"""
站内消息通知服务
提供通知的创建、查询、标记已读等功能
"""
def __init__(self):
    # Bind the generic BaseService CRUD helpers to the Notification model.
    super().__init__(Notification)
async def create_notification(
self,
db: AsyncSession,
notification_in: NotificationCreate
) -> Notification:
"""
创建单个通知
Args:
db: 数据库会话
notification_in: 通知创建数据
Returns:
创建的通知对象
"""
notification = Notification(
user_id=notification_in.user_id,
title=notification_in.title,
content=notification_in.content,
type=notification_in.type.value if isinstance(notification_in.type, NotificationType) else notification_in.type,
related_id=notification_in.related_id,
related_type=notification_in.related_type,
sender_id=notification_in.sender_id,
is_read=False
)
db.add(notification)
await db.commit()
await db.refresh(notification)
logger.info(
"创建通知成功",
notification_id=notification.id,
user_id=notification_in.user_id,
type=notification_in.type
)
return notification
async def batch_create_notifications(
self,
db: AsyncSession,
batch_in: NotificationBatchCreate
) -> List[Notification]:
"""
批量创建通知(发送给多个用户)
Args:
db: 数据库会话
batch_in: 批量通知创建数据
Returns:
创建的通知列表
"""
notifications = []
notification_type = batch_in.type.value if isinstance(batch_in.type, NotificationType) else batch_in.type
for user_id in batch_in.user_ids:
notification = Notification(
user_id=user_id,
title=batch_in.title,
content=batch_in.content,
type=notification_type,
related_id=batch_in.related_id,
related_type=batch_in.related_type,
sender_id=batch_in.sender_id,
is_read=False
)
notifications.append(notification)
db.add(notification)
await db.commit()
# 刷新所有对象
for notification in notifications:
await db.refresh(notification)
logger.info(
"批量创建通知成功",
count=len(notifications),
user_ids=batch_in.user_ids,
type=batch_in.type
)
return notifications
async def get_user_notifications(
self,
db: AsyncSession,
user_id: int,
skip: int = 0,
limit: int = 20,
is_read: Optional[bool] = None,
notification_type: Optional[str] = None
) -> Tuple[List[NotificationResponse], int, int]:
"""
获取用户的通知列表
Args:
db: 数据库会话
user_id: 用户ID
skip: 跳过数量
limit: 返回数量
is_read: 是否已读筛选
notification_type: 通知类型筛选
Returns:
(通知列表, 总数, 未读数)
"""
# 构建基础查询条件
conditions = [Notification.user_id == user_id]
if is_read is not None:
conditions.append(Notification.is_read == is_read)
if notification_type:
conditions.append(Notification.type == notification_type)
# 查询通知列表(带发送者信息)
stmt = (
select(Notification)
.where(and_(*conditions))
.order_by(desc(Notification.created_at))
.offset(skip)
.limit(limit)
)
result = await db.execute(stmt)
notifications = result.scalars().all()
# 统计总数
count_stmt = select(func.count()).select_from(Notification).where(and_(*conditions))
total_result = await db.execute(count_stmt)
total = total_result.scalar_one()
# 统计未读数
unread_stmt = (
select(func.count())
.select_from(Notification)
.where(and_(Notification.user_id == user_id, Notification.is_read == False))
)
unread_result = await db.execute(unread_stmt)
unread_count = unread_result.scalar_one()
# 获取发送者信息
sender_ids = [n.sender_id for n in notifications if n.sender_id]
sender_names = {}
if sender_ids:
sender_stmt = select(User.id, User.full_name).where(User.id.in_(sender_ids))
sender_result = await db.execute(sender_stmt)
sender_names = {row[0]: row[1] for row in sender_result.fetchall()}
# 构建响应
responses = []
for notification in notifications:
response = NotificationResponse(
id=notification.id,
user_id=notification.user_id,
title=notification.title,
content=notification.content,
type=notification.type,
is_read=notification.is_read,
related_id=notification.related_id,
related_type=notification.related_type,
sender_id=notification.sender_id,
sender_name=sender_names.get(notification.sender_id) if notification.sender_id else None,
created_at=notification.created_at,
updated_at=notification.updated_at
)
responses.append(response)
return responses, total, unread_count
async def get_unread_count(
self,
db: AsyncSession,
user_id: int
) -> Tuple[int, int]:
"""
获取用户未读通知数量
Args:
db: 数据库会话
user_id: 用户ID
Returns:
(未读数, 总数)
"""
# 统计未读数
unread_stmt = (
select(func.count())
.select_from(Notification)
.where(and_(Notification.user_id == user_id, Notification.is_read == False))
)
unread_result = await db.execute(unread_stmt)
unread_count = unread_result.scalar_one()
# 统计总数
total_stmt = (
select(func.count())
.select_from(Notification)
.where(Notification.user_id == user_id)
)
total_result = await db.execute(total_stmt)
total = total_result.scalar_one()
return unread_count, total
async def mark_as_read(
self,
db: AsyncSession,
user_id: int,
notification_ids: Optional[List[int]] = None
) -> int:
"""
标记通知为已读
Args:
db: 数据库会话
user_id: 用户ID
notification_ids: 通知ID列表为空则标记全部
Returns:
更新的数量
"""
conditions = [
Notification.user_id == user_id,
Notification.is_read == False
]
if notification_ids:
conditions.append(Notification.id.in_(notification_ids))
stmt = (
update(Notification)
.where(and_(*conditions))
.values(is_read=True)
)
result = await db.execute(stmt)
await db.commit()
updated_count = result.rowcount
logger.info(
"标记通知已读",
user_id=user_id,
notification_ids=notification_ids,
updated_count=updated_count
)
return updated_count
async def delete_notification(
self,
db: AsyncSession,
user_id: int,
notification_id: int
) -> bool:
"""
删除通知
Args:
db: 数据库会话
user_id: 用户ID
notification_id: 通知ID
Returns:
是否删除成功
"""
stmt = select(Notification).where(
and_(
Notification.id == notification_id,
Notification.user_id == user_id
)
)
result = await db.execute(stmt)
notification = result.scalar_one_or_none()
if notification:
await db.delete(notification)
await db.commit()
logger.info(
"删除通知成功",
notification_id=notification_id,
user_id=user_id
)
return True
return False
# 创建服务实例
notification_service = NotificationService()

View File

@@ -0,0 +1,356 @@
"""
SCRM 系统对接服务
提供给 SCRM 系统调用的数据查询服务
"""
import logging
from typing import List, Optional, Dict, Any
from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models.user import User
from app.models.position import Position
from app.models.position_member import PositionMember
from app.models.position_course import PositionCourse
from app.models.course import Course, KnowledgePoint, CourseMaterial
logger = logging.getLogger(__name__)
class SCRMService:
    """SCRM system data query service (employees, positions, courses, KPs)."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_employee_position(
        self,
        userid: Optional[str] = None,
        name: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Look up an employee's position info by WeWork userid or by name.

        Args:
            userid: WeWork employee userid (optional, exact match).
            name: Employee name (optional; exact then fuzzy match).

        Returns:
            Dict with employee_id, userid, name and positions; a
            multiple-matches dict when a name search is ambiguous;
            None when nothing matches.
        """
        base_query = (
            select(User)
            .options(selectinload(User.position_memberships).selectinload(PositionMember.position))
            .where(User.is_deleted.is_(False))
        )
        # Exact match on wework_userid first.
        if userid:
            result = await self.db.execute(base_query.where(User.wework_userid == userid))
            user = result.scalar_one_or_none()
            if user:
                return self._build_employee_position_data(user)
        # Fall back to name matching. FIX: build the name queries from
        # base_query, NOT the userid-filtered query — previously a failed
        # userid lookup carried its filter into the name search, making
        # this fallback unreachable.
        if name:
            # Exact match first.
            exact_query = base_query.where(User.full_name == name)
            result = await self.db.execute(exact_query)
            users = result.scalars().all()
            # Fuzzy match only when the exact match found nothing.
            if not users:
                fuzzy_query = base_query.where(User.full_name.ilike(f"%{name}%"))
                result = await self.db.execute(fuzzy_query)
                users = result.scalars().all()
            if len(users) == 1:
                return self._build_employee_position_data(users[0])
            elif len(users) > 1:
                # Ambiguous: return the candidate list for the caller to pick.
                return {
                    "multiple_matches": True,
                    "count": len(users),
                    "employees": [
                        {
                            "employee_id": u.id,
                            "userid": u.wework_userid,
                            "name": u.full_name or u.username,
                            "phone": u.phone[-4:] if u.phone else None  # last 4 digits only
                        }
                        for u in users
                    ]
                }
        return None

    def _build_employee_position_data(self, user: User) -> Dict[str, Any]:
        """
        Build the position payload for one employee.

        FIX: the first NON-DELETED membership is the primary position.
        Previously the flag compared against the raw membership index, so
        a soft-deleted first membership left no primary position at all.
        """
        positions = []
        for pm in user.position_memberships:
            if pm.is_deleted or pm.position.is_deleted:
                continue
            positions.append({
                "position_id": pm.position.id,
                "position_name": pm.position.name,
                "is_primary": len(positions) == 0,  # first surviving membership
                "joined_at": pm.joined_at.strftime("%Y-%m-%d") if pm.joined_at else None
            })
        return {
            "employee_id": user.id,
            "userid": user.wework_userid,
            "name": user.full_name or user.username,
            "positions": positions
        }

    async def get_employee_position_by_id(self, employee_id: int) -> Optional[Dict[str, Any]]:
        """
        Fetch an employee's position info by primary key.

        Args:
            employee_id: users table primary key.

        Returns:
            Position payload dict, or None when the user does not exist.
        """
        result = await self.db.execute(
            select(User)
            .options(selectinload(User.position_memberships).selectinload(PositionMember.position))
            .where(User.id == employee_id, User.is_deleted == False)
        )
        user = result.scalar_one_or_none()
        if not user:
            return None
        return self._build_employee_position_data(user)

    async def get_position_courses(
        self,
        position_id: int,
        course_type: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        List the published courses attached to a position.

        Args:
            position_id: Position ID.
            course_type: Optional filter (required/optional/all).

        Returns:
            Dict with position_id, position_name and courses, or None
            when the position does not exist.
        """
        position_result = await self.db.execute(
            select(Position).where(Position.id == position_id, Position.is_deleted.is_(False))
        )
        position = position_result.scalar_one_or_none()
        if not position:
            return None
        query = (
            select(PositionCourse, Course)
            .join(Course, PositionCourse.course_id == Course.id)
            .where(
                PositionCourse.position_id == position_id,
                PositionCourse.is_deleted.is_(False),
                Course.is_deleted.is_(False),
                Course.status == "published"  # only published courses
            )
            .order_by(PositionCourse.priority.desc())
        )
        if course_type and course_type != "all":
            query = query.where(PositionCourse.course_type == course_type)
        result = await self.db.execute(query)
        pc_courses = result.all()
        # One count query per course (N+1); acceptable while position course
        # lists stay small — switch to a grouped count if they grow.
        courses = []
        for pc, course in pc_courses:
            kp_count_result = await self.db.execute(
                select(func.count(KnowledgePoint.id))
                .where(
                    KnowledgePoint.course_id == course.id,
                    KnowledgePoint.is_deleted.is_(False)
                )
            )
            kp_count = kp_count_result.scalar() or 0
            courses.append({
                "course_id": course.id,
                "course_name": course.name,
                "course_type": pc.course_type,
                "priority": pc.priority,
                "knowledge_point_count": kp_count
            })
        return {
            "position_id": position.id,
            "position_name": position.name,
            "courses": courses
        }

    async def search_knowledge_points(
        self,
        keywords: List[str],
        position_id: Optional[int] = None,
        course_ids: Optional[List[int]] = None,
        knowledge_type: Optional[str] = None,
        limit: int = 10
    ) -> Dict[str, Any]:
        """
        Search knowledge points by keywords.

        Args:
            keywords: Search keywords (OR-matched in name/description).
            position_id: When given, knowledge points from that position's
                courses are ranked first.
            course_ids: Optional course scope.
            knowledge_type: Optional type filter.
            limit: Max number of results.

        Returns:
            Dict with total and items, sorted by relevance score.
        """
        # Base query: only non-deleted KPs of published courses.
        query = (
            select(KnowledgePoint, Course)
            .join(Course, KnowledgePoint.course_id == Course.id)
            .where(
                KnowledgePoint.is_deleted.is_(False),
                Course.is_deleted.is_(False),
                Course.status == "published"
            )
        )
        # Keyword conditions over name and description.
        keyword_conditions = []
        for keyword in keywords:
            keyword_conditions.append(
                or_(
                    KnowledgePoint.name.ilike(f"%{keyword}%"),
                    KnowledgePoint.description.ilike(f"%{keyword}%")
                )
            )
        if keyword_conditions:
            query = query.where(or_(*keyword_conditions))
        # Course scope filter.
        if course_ids:
            query = query.where(KnowledgePoint.course_id.in_(course_ids))
        # Knowledge point type filter.
        if knowledge_type:
            query = query.where(KnowledgePoint.type == knowledge_type)
        # When a position is given, rank its courses' KPs first.
        if position_id:
            pos_course_result = await self.db.execute(
                select(PositionCourse.course_id)
                .where(
                    PositionCourse.position_id == position_id,
                    PositionCourse.is_deleted.is_(False)
                )
            )
            pos_course_ids = [row[0] for row in pos_course_result.all()]
            if pos_course_ids:
                # CASE WHEN ordering: position-related courses get bucket 0.
                from sqlalchemy import case
                priority_order = case(
                    (KnowledgePoint.course_id.in_(pos_course_ids), 0),
                    else_=1
                )
                query = query.order_by(priority_order, KnowledgePoint.id.desc())
            else:
                query = query.order_by(KnowledgePoint.id.desc())
        else:
            query = query.order_by(KnowledgePoint.id.desc())
        result = await self.db.execute(query.limit(limit))
        kp_courses = result.all()

        # Simple relevance: fraction of keywords found in name+description.
        def calc_relevance(kp: KnowledgePoint) -> float:
            text = f"{kp.name} {kp.description or ''}"
            matched = sum(1 for kw in keywords if kw.lower() in text.lower())
            return round(matched / len(keywords), 2) if keywords else 1.0

        items = []
        for kp, course in kp_courses:
            items.append({
                "knowledge_point_id": kp.id,
                "name": kp.name,
                "course_id": course.id,
                "course_name": course.name,
                "type": kp.type,
                "relevance_score": calc_relevance(kp)
            })
        # Sort by relevance, best first.
        items.sort(key=lambda x: x["relevance_score"], reverse=True)
        return {
            "total": len(items),
            "items": items
        }

    async def get_knowledge_point_detail(self, kp_id: int) -> Optional[Dict[str, Any]]:
        """
        Fetch the detail of one knowledge point.

        Args:
            kp_id: Knowledge point ID.

        Returns:
            Detail dict (with optional material info), or None when the
            knowledge point does not exist.
        """
        # Join course (inner) and material (outer — material is optional).
        result = await self.db.execute(
            select(KnowledgePoint, Course, CourseMaterial)
            .join(Course, KnowledgePoint.course_id == Course.id)
            .outerjoin(CourseMaterial, KnowledgePoint.material_id == CourseMaterial.id)
            .where(
                KnowledgePoint.id == kp_id,
                KnowledgePoint.is_deleted.is_(False)
            )
        )
        row = result.one_or_none()
        if not row:
            return None
        kp, course, material = row
        return {
            "knowledge_point_id": kp.id,
            "name": kp.name,
            "course_id": course.id,
            "course_name": course.name,
            "type": kp.type,
            "content": kp.description or "",  # description serves as the KP content
            "material_id": material.id if material else None,
            "material_type": material.file_type if material else None,
            "material_url": material.file_url if material else None,
            "topic_relation": kp.topic_relation,
            "source": kp.source,
            "created_at": kp.created_at.strftime("%Y-%m-%d %H:%M:%S") if kp.created_at else None
        }

View File

@@ -0,0 +1,708 @@
"""
统计分析服务
"""
from datetime import datetime, timedelta
from typing import List, Optional, Dict, Any, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_, or_, case, desc, distinct
from app.models.exam import Exam, Question
from app.models.exam_mistake import ExamMistake
from app.models.course import Course, KnowledgePoint
from app.models.practice import PracticeSession
from app.models.training import TrainingSession
from app.core.logger import get_logger
logger = get_logger(__name__)
class StatisticsService:
    """Statistical analysis service for exam/practice learning data."""

    @staticmethod
    def _get_date_range(period: str) -> Tuple[datetime, datetime]:
        """
        Return the (start, end) datetimes for a named period.

        Args:
            period: One of week/month/quarter/halfYear/year.

        Returns:
            Tuple[datetime, datetime]: (start, end); unknown values fall
            back to one month.
        """
        end_date = datetime.now()
        if period == "week":
            start_date = end_date - timedelta(days=7)
        elif period == "month":
            start_date = end_date - timedelta(days=30)
        elif period == "quarter":
            start_date = end_date - timedelta(days=90)
        elif period == "halfYear":
            start_date = end_date - timedelta(days=180)
        elif period == "year":
            start_date = end_date - timedelta(days=365)
        else:
            # Default: one month.
            start_date = end_date - timedelta(days=30)
        return start_date, end_date

    @staticmethod
    def _calculate_change_rate(current: float, previous: float) -> float:
        """
        Period-over-period change rate in percent.

        Args:
            current: Current period value.
            previous: Previous period value.

        Returns:
            float: Change rate (%), rounded to 1 decimal; 0 when both are
            zero, 100 when only the previous period is zero.
        """
        if previous == 0:
            return 0 if current == 0 else 100
        return round(((current - previous) / previous) * 100, 1)

    @staticmethod
    async def get_key_metrics(
        db: AsyncSession,
        user_id: int,
        course_id: Optional[int] = None,
        period: str = "month"
    ) -> Dict[str, Any]:
        """
        Key learning metrics: efficiency, knowledge coverage, average time
        per question and progress speed, each with a period-over-period change.

        Args:
            db: Database session.
            user_id: User ID.
            course_id: Optional course filter.
            period: Named time range (see _get_date_range).

        Returns:
            Dict with learningEfficiency / knowledgeCoverage /
            avgTimePerQuestion / progressSpeed entries.
        """
        logger.info(f"获取关键指标 - user_id: {user_id}, course_id: {course_id}, period: {period}")
        start_date, end_date = StatisticsService._get_date_range(period)
        # Base filters: this user's scored exams inside the period.
        exam_conditions = [
            Exam.user_id == user_id,
            Exam.start_time >= start_date,
            Exam.start_time <= end_date,
            Exam.round1_score.isnot(None)
        ]
        if course_id:
            exam_conditions.append(Exam.course_id == course_id)
        # 1. Learning efficiency = (total questions - mistakes) / total questions.
        total_questions_stmt = select(
            func.coalesce(func.sum(Exam.question_count), 0)
        ).where(and_(*exam_conditions))
        total_questions = await db.scalar(total_questions_stmt) or 0
        # FIX: restrict mistakes to exams inside the current period, mirroring
        # the previous-period computation below; before, all-time mistakes were
        # subtracted from period-only question totals, which could even drive
        # the efficiency negative. (exam_conditions already carries course_id.)
        mistake_conditions = [
            ExamMistake.user_id == user_id,
            ExamMistake.exam_id.in_(
                select(Exam.id).where(and_(*exam_conditions))
            )
        ]
        mistake_stmt = select(func.count(ExamMistake.id)).where(
            and_(*mistake_conditions)
        )
        mistake_count = await db.scalar(mistake_stmt) or 0
        learning_efficiency = 0.0
        if total_questions > 0:
            correct_questions = total_questions - mistake_count
            learning_efficiency = round((correct_questions / total_questions) * 100, 1)
        # Previous period (same length, immediately before) for change rates.
        prev_start_date = start_date - (end_date - start_date)
        prev_exam_conditions = [
            Exam.user_id == user_id,
            Exam.start_time >= prev_start_date,
            Exam.start_time < start_date,
            Exam.round1_score.isnot(None)
        ]
        if course_id:
            prev_exam_conditions.append(Exam.course_id == course_id)
        prev_total_questions = await db.scalar(
            select(func.coalesce(func.sum(Exam.question_count), 0)).where(
                and_(*prev_exam_conditions)
            )
        ) or 0
        prev_mistake_count = await db.scalar(
            select(func.count(ExamMistake.id)).where(
                and_(
                    ExamMistake.user_id == user_id,
                    ExamMistake.exam_id.in_(
                        select(Exam.id).where(and_(*prev_exam_conditions))
                    )
                )
            )
        ) or 0
        prev_efficiency = 0.0
        if prev_total_questions > 0:
            prev_correct = prev_total_questions - prev_mistake_count
            prev_efficiency = (prev_correct / prev_total_questions) * 100
        efficiency_change = StatisticsService._calculate_change_rate(
            learning_efficiency, prev_efficiency
        )
        # 2. Knowledge coverage = mastered knowledge points / total KPs.
        kp_conditions = []
        if course_id:
            kp_conditions.append(KnowledgePoint.course_id == course_id)
        total_kp_stmt = select(func.count(KnowledgePoint.id)).where(
            and_(KnowledgePoint.is_deleted == False, *kp_conditions)
        )
        total_kp = await db.scalar(total_kp_stmt) or 0
        # KPs the user got wrong at least once (all-time).
        # NOTE(review): unlike the efficiency metric this is not period-bound —
        # confirm "ever wrong" is the intended coverage semantics.
        mistake_kp_stmt = select(
            func.count(distinct(ExamMistake.knowledge_point_id))
        ).where(
            and_(
                ExamMistake.user_id == user_id,
                ExamMistake.knowledge_point_id.isnot(None),
                *([ExamMistake.exam_id.in_(
                    select(Exam.id).where(Exam.course_id == course_id)
                )] if course_id else [])
            )
        )
        mistake_kp = await db.scalar(mistake_kp_stmt) or 0
        knowledge_coverage = 0.0
        if total_kp > 0:
            mastered_kp = max(0, total_kp - mistake_kp)
            knowledge_coverage = round((mastered_kp / total_kp) * 100, 1)
        # Previous-period coverage (simplified: total KP count assumed stable).
        prev_mistake_kp = await db.scalar(
            select(func.count(distinct(ExamMistake.knowledge_point_id))).where(
                and_(
                    ExamMistake.user_id == user_id,
                    ExamMistake.knowledge_point_id.isnot(None),
                    ExamMistake.exam_id.in_(
                        select(Exam.id).where(and_(*prev_exam_conditions))
                    )
                )
            )
        ) or 0
        prev_coverage = 0.0
        if total_kp > 0:
            prev_mastered = max(0, total_kp - prev_mistake_kp)
            prev_coverage = (prev_mastered / total_kp) * 100
        coverage_change = StatisticsService._calculate_change_rate(
            knowledge_coverage, prev_coverage
        )
        # 3. Average time per question = total exam minutes / total questions.
        total_duration_stmt = select(
            func.coalesce(func.sum(Exam.duration_minutes), 0)
        ).where(and_(*exam_conditions))
        total_duration = await db.scalar(total_duration_stmt) or 0
        avg_time_per_question = 0.0
        if total_questions > 0:
            avg_time_per_question = round((total_duration / total_questions), 1)
        prev_total_duration = await db.scalar(
            select(func.coalesce(func.sum(Exam.duration_minutes), 0)).where(
                and_(*prev_exam_conditions)
            )
        ) or 0
        prev_avg_time = 0.0
        if prev_total_questions > 0:
            prev_avg_time = prev_total_duration / prev_total_questions
        # A negative change here is an improvement (less time spent per question).
        time_change = StatisticsService._calculate_change_rate(
            avg_time_per_question, prev_avg_time
        )
        # 4. Progress speed = change rate of the average score vs last period.
        avg_score_stmt = select(func.avg(Exam.round1_score)).where(
            and_(*exam_conditions)
        )
        avg_score = await db.scalar(avg_score_stmt) or 0
        prev_avg_score_stmt = select(func.avg(Exam.round1_score)).where(
            and_(*prev_exam_conditions)
        )
        prev_avg_score = await db.scalar(prev_avg_score_stmt) or 0
        progress_speed = StatisticsService._calculate_change_rate(
            float(avg_score), float(prev_avg_score)
        )
        return {
            "learningEfficiency": {
                "value": learning_efficiency,
                "unit": "%",
                "change": efficiency_change,
                "description": "正确题数/总练习题数"
            },
            "knowledgeCoverage": {
                "value": knowledge_coverage,
                "unit": "%",
                "change": coverage_change,
                "description": "已掌握知识点/总知识点"
            },
            "avgTimePerQuestion": {
                "value": avg_time_per_question,
                "unit": "分/题",
                "change": time_change,
                "description": "平均每道题的答题时间"
            },
            "progressSpeed": {
                "value": abs(progress_speed),
                "unit": "%",
                "change": progress_speed,
                "description": "成绩提升速度"
            }
        }

    @staticmethod
    async def get_score_distribution(
        db: AsyncSession,
        user_id: int,
        course_id: Optional[int] = None,
        period: str = "month"
    ) -> Dict[str, Any]:
        """
        Score-band distribution of the user's scored exams in the period.

        Args:
            db: Database session.
            user_id: User ID.
            course_id: Optional course filter.
            period: Named time range.

        Returns:
            Dict with counts for excellent/good/medium/pass/fail bands.
        """
        logger.info(f"获取成绩分布 - user_id: {user_id}, course_id: {course_id}, period: {period}")
        start_date, end_date = StatisticsService._get_date_range(period)
        conditions = [
            Exam.user_id == user_id,
            Exam.start_time >= start_date,
            Exam.start_time <= end_date,
            Exam.round1_score.isnot(None)
        ]
        if course_id:
            conditions.append(Exam.course_id == course_id)
        # All five buckets counted in a single CASE WHEN aggregate query.
        stmt = select(
            func.count(case((Exam.round1_score >= 90, 1))).label("excellent"),
            func.count(case((and_(Exam.round1_score >= 80, Exam.round1_score < 90), 1))).label("good"),
            func.count(case((and_(Exam.round1_score >= 70, Exam.round1_score < 80), 1))).label("medium"),
            func.count(case((and_(Exam.round1_score >= 60, Exam.round1_score < 70), 1))).label("pass_count"),
            func.count(case((Exam.round1_score < 60, 1))).label("fail")
        ).where(and_(*conditions))
        result = await db.execute(stmt)
        row = result.one()
        return {
            "excellent": row.excellent or 0,  # 90-100
            "good": row.good or 0,            # 80-89
            "medium": row.medium or 0,        # 70-79
            "pass": row.pass_count or 0,      # 60-69
            "fail": row.fail or 0             # <60
        }

    @staticmethod
    async def get_difficulty_analysis(
        db: AsyncSession,
        user_id: int,
        course_id: Optional[int] = None,
        period: str = "month"
    ) -> Dict[str, Any]:
        """
        Estimated accuracy per question difficulty for the period.

        Args:
            db: Database session.
            user_id: User ID.
            course_id: Optional course filter.
            period: Named time range.

        Returns:
            Dict of accuracy (%) keyed by the Chinese difficulty labels.
        """
        logger.info(f"获取难度分析 - user_id: {user_id}, course_id: {course_id}, period: {period}")
        start_date, end_date = StatisticsService._get_date_range(period)
        exam_conditions = [
            Exam.user_id == user_id,
            Exam.start_time >= start_date,
            Exam.start_time <= end_date
        ]
        if course_id:
            exam_conditions.append(Exam.course_id == course_id)
        exam_ids_stmt = select(Exam.id).where(and_(*exam_conditions))
        result = await db.execute(exam_ids_stmt)
        exam_ids = [row[0] for row in result.all()]
        if not exam_ids:
            # FIX: return the SAME key set as the normal path below —
            # previously this branch used "easy"/"medium"/"hard" keys while
            # the normal path returned "简单题"/"中等题"/"困难题".
            return {
                "简单题": 100.0,
                "中等题": 100.0,
                "困难题": 100.0,
                "综合题": 100.0,
                "应用题": 100.0
            }
        # Total question count over the period — loop-invariant, computed once.
        total_questions_stmt = select(
            func.coalesce(func.sum(Exam.question_count), 0)
        ).where(and_(*exam_conditions))
        total_count = await db.scalar(total_questions_stmt) or 0
        total_count = int(total_count)  # avoid Decimal arithmetic issues
        # Per-difficulty stats; mistake counts come from exam_mistakes joined
        # through the questions table.
        difficulty_stats = {}
        for difficulty in ["easy", "medium", "hard"]:
            mistake_stmt = select(func.count(ExamMistake.id)).select_from(
                ExamMistake
            ).join(
                Question, ExamMistake.question_id == Question.id
            ).where(
                and_(
                    ExamMistake.user_id == user_id,
                    ExamMistake.exam_id.in_(exam_ids),
                    Question.difficulty == difficulty
                )
            )
            mistake_count = await db.scalar(mistake_stmt) or 0
            # Heuristic split: assume easy:medium:hard ≈ 5:3:2 of all questions
            # (exam question composition is not stored per difficulty).
            if difficulty == "easy":
                estimated_count = int(total_count * 0.5)
            elif difficulty == "medium":
                estimated_count = int(total_count * 0.3)
            else:  # hard
                estimated_count = int(total_count * 0.2)
            if estimated_count > 0:
                correct_count = max(0, estimated_count - mistake_count)
                accuracy = round((correct_count / estimated_count) * 100, 1)
            else:
                accuracy = 100.0
            difficulty_stats[difficulty] = accuracy
        # Composite/application categories: mean of medium and hard accuracy.
        difficulty_stats["综合题"] = round((difficulty_stats["medium"] + difficulty_stats["hard"]) / 2, 1)
        difficulty_stats["应用题"] = round((difficulty_stats["medium"] + difficulty_stats["hard"]) / 2, 1)
        return {
            "简单题": difficulty_stats["easy"],
            "中等题": difficulty_stats["medium"],
            "困难题": difficulty_stats["hard"],
            "综合题": difficulty_stats["综合题"],
            "应用题": difficulty_stats["应用题"]
        }

    @staticmethod
    async def get_knowledge_mastery(
        db: AsyncSession,
        user_id: int,
        course_id: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Per-knowledge-point mastery estimate for a user.

        Args:
            db: Database session.
            user_id: User ID.
            course_id: Optional course filter.

        Returns:
            Up to 6 entries of {"name", "mastery"}; a canned placeholder
            list when no knowledge points exist.
        """
        logger.info(f"获取知识点掌握度 - user_id: {user_id}, course_id: {course_id}")
        kp_conditions = [KnowledgePoint.is_deleted == False]
        if course_id:
            kp_conditions.append(KnowledgePoint.course_id == course_id)
        kp_stmt = select(KnowledgePoint).where(and_(*kp_conditions)).limit(10)
        result = await db.execute(kp_stmt)
        knowledge_points = result.scalars().all()
        if not knowledge_points:
            # No knowledge point data — return placeholder defaults.
            return [
                {"name": "基础概念", "mastery": 85.0},
                {"name": "核心知识", "mastery": 72.0},
                {"name": "实践应用", "mastery": 68.0},
                {"name": "综合运用", "mastery": 58.0},
                {"name": "高级技巧", "mastery": 75.0},
                {"name": "案例分析", "mastery": 62.0}
            ]
        mastery_list = []
        for kp in knowledge_points:
            # Count how often the user got this knowledge point wrong.
            mistake_stmt = select(func.count(ExamMistake.id)).where(
                and_(
                    ExamMistake.user_id == user_id,
                    ExamMistake.knowledge_point_id == kp.id
                )
            )
            mistake_count = await db.scalar(mistake_stmt) or 0
            # Heuristic: assume each knowledge point is examined ~10 times.
            estimated_total = 10
            if estimated_total > 0:
                correct_count = max(0, estimated_total - mistake_count)
                mastery = round((correct_count / estimated_total) * 100, 1)
            else:
                mastery = 100.0
            mastery_list.append({
                "name": kp.name[:10],  # truncate long names for chart labels
                "mastery": mastery
            })
        return mastery_list[:6]  # at most 6 knowledge points

    @staticmethod
    async def get_study_time_stats(
        db: AsyncSession,
        user_id: int,
        course_id: Optional[int] = None,
        period: str = "month"
    ) -> Dict[str, Any]:
        """
        Daily study (exam) and practice time series for charting.

        Args:
            db: Database session.
            user_id: User ID.
            course_id: Optional course filter (exam time only).
            period: Named time range.

        Returns:
            Dict with labels, studyTime (hours) and practiceTime (hours).
        """
        logger.info(f"获取学习时长统计 - user_id: {user_id}, course_id: {course_id}, period: {period}")
        start_date, end_date = StatisticsService._get_date_range(period)
        days = (end_date - start_date).days
        if days > 30:
            days = 30  # cap the chart at 30 days
        # Chronological date axis ending today.
        date_list = []
        for i in range(days):
            date = end_date - timedelta(days=days - i - 1)
            date_list.append(date.date())
        study_time_data = {str(d): 0.0 for d in date_list}
        practice_time_data = {str(d): 0.0 for d in date_list}
        # Exam duration per day (study time, minutes → hours).
        exam_conditions = [
            Exam.user_id == user_id,
            Exam.start_time >= start_date,
            Exam.start_time <= end_date
        ]
        if course_id:
            exam_conditions.append(Exam.course_id == course_id)
        exam_stmt = select(
            func.date(Exam.start_time).label("date"),
            func.sum(Exam.duration_minutes).label("total_minutes")
        ).where(
            and_(*exam_conditions)
        ).group_by(
            func.date(Exam.start_time)
        )
        exam_result = await db.execute(exam_stmt)
        for row in exam_result.all():
            date_str = str(row.date)
            if date_str in study_time_data:
                study_time_data[date_str] = round(float(row.total_minutes) / 60, 1)
        # Practice-session duration per day (seconds → hours).
        # NOTE(review): practice time is not filtered by course_id — confirm
        # whether practice sessions carry a course association at all.
        practice_conditions = [
            PracticeSession.user_id == user_id,
            PracticeSession.start_time >= start_date,
            PracticeSession.start_time <= end_date,
            PracticeSession.status == "completed"
        ]
        practice_stmt = select(
            func.date(PracticeSession.start_time).label("date"),
            func.sum(PracticeSession.duration_seconds).label("total_seconds")
        ).where(
            and_(*practice_conditions)
        ).group_by(
            func.date(PracticeSession.start_time)
        )
        practice_result = await db.execute(practice_stmt)
        for row in practice_result.all():
            date_str = str(row.date)
            if date_str in practice_time_data:
                practice_time_data[date_str] = round(float(row.total_seconds) / 3600, 1)
        if period == "week":
            # FIX: label each date with its ACTUAL weekday; the old code
            # always started the labels at 周一 regardless of which weekday
            # the 7-day range actually began on.
            weekday_labels = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
            labels = [weekday_labels[d.weekday()] for d in date_list]
        else:
            # Other periods use month-day labels.
            labels = [d.strftime("%m-%d") for d in date_list]
        study_values = [study_time_data[str(d)] for d in date_list]
        practice_values = [practice_time_data[str(d)] for d in date_list]
        return {
            "labels": labels,
            "studyTime": study_values,
            "practiceTime": practice_values
        }

    @staticmethod
    async def get_detail_data(
        db: AsyncSession,
        user_id: int,
        course_id: Optional[int] = None,
        period: str = "month"
    ) -> List[Dict[str, Any]]:
        """
        Per-day detail rows for the statistics table, newest first, max 10.

        Args:
            db: Database session.
            user_id: User ID.
            course_id: Optional course filter.
            period: Named time range.

        Returns:
            List of daily dicts (examCount, avgScore, studyTime hours,
            questionCount, accuracy %, improvement index).
        """
        logger.info(f"获取详细数据 - user_id: {user_id}, course_id: {course_id}, period: {period}")
        start_date, end_date = StatisticsService._get_date_range(period)
        exam_conditions = [
            Exam.user_id == user_id,
            Exam.start_time >= start_date,
            Exam.start_time <= end_date,
            Exam.round1_score.isnot(None)
        ]
        if course_id:
            exam_conditions.append(Exam.course_id == course_id)
        # Group by calendar day.
        stmt = select(
            func.date(Exam.start_time).label("date"),
            func.count(Exam.id).label("exam_count"),
            func.avg(Exam.round1_score).label("avg_score"),
            func.sum(Exam.duration_minutes).label("total_duration"),
            func.sum(Exam.question_count).label("total_questions")
        ).where(
            and_(*exam_conditions)
        ).group_by(
            func.date(Exam.start_time)
        ).order_by(
            desc(func.date(Exam.start_time))
        ).limit(10)  # at most 10 daily rows
        result = await db.execute(stmt)
        rows = result.all()
        detail_list = []
        for row in rows:
            date_str = row.date.strftime("%Y-%m-%d")
            exam_count = row.exam_count or 0
            avg_score = round(float(row.avg_score or 0), 1)
            study_time = round(float(row.total_duration or 0) / 60, 1)
            question_count = row.total_questions or 0
            # Mistakes made on that day's exams.
            mistake_stmt = select(func.count(ExamMistake.id)).where(
                and_(
                    ExamMistake.user_id == user_id,
                    ExamMistake.exam_id.in_(
                        select(Exam.id).where(
                            and_(
                                Exam.user_id == user_id,
                                func.date(Exam.start_time) == row.date
                            )
                        )
                    )
                )
            )
            mistake_count = await db.scalar(mistake_stmt) or 0
            accuracy = 0.0
            if question_count > 0:
                correct_count = question_count - mistake_count
                accuracy = round((correct_count / question_count) * 100, 1)
            # Improvement index clamped to [0, 100], derived from the avg score.
            improvement = min(100, max(0, int(avg_score)))
            detail_list.append({
                "date": date_str,
                "examCount": exam_count,
                "avgScore": avg_score,
                "studyTime": study_time,
                "questionCount": question_count,
                "accuracy": accuracy,
                "improvement": improvement
            })
        return detail_list

View File

@@ -0,0 +1,170 @@
"""
系统日志服务
"""
import logging
from datetime import datetime
from typing import Optional

from sqlalchemy import delete, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.system_log import SystemLog
from app.schemas.system_log import SystemLogCreate, SystemLogQuery
logger = logging.getLogger(__name__)
class SystemLogService:
    """System log service: create, query and purge SystemLog rows."""

    async def create_log(
        self,
        db: AsyncSession,
        log_data: SystemLogCreate
    ) -> SystemLog:
        """
        Persist one system log entry.

        Args:
            db: Database session.
            log_data: Log payload.

        Returns:
            The created SystemLog object.

        Raises:
            Exception: re-raised after rollback when the insert fails.
        """
        try:
            log = SystemLog(**log_data.model_dump())
            db.add(log)
            await db.commit()
            await db.refresh(log)
            return log
        except Exception as e:
            await db.rollback()
            logger.error(f"创建系统日志失败: {str(e)}")
            raise

    async def get_logs(
        self,
        db: AsyncSession,
        query_params: SystemLogQuery
    ) -> tuple[list[SystemLog], int]:
        """
        Query system logs with filters and pagination.

        Args:
            db: Database session.
            query_params: Filter / pagination parameters.

        Returns:
            (list of logs for the requested page, total matching count).
        """
        try:
            stmt = select(SystemLog)
            count_stmt = select(func.count(SystemLog.id))
            # Collect filters once, apply to both the page and count queries.
            filters = []
            if query_params.level:
                filters.append(SystemLog.level == query_params.level)
            if query_params.type:
                filters.append(SystemLog.type == query_params.type)
            if query_params.user:
                filters.append(SystemLog.user == query_params.user)
            if query_params.keyword:
                filters.append(SystemLog.message.like(f"%{query_params.keyword}%"))
            if query_params.start_date:
                filters.append(SystemLog.created_at >= query_params.start_date)
            if query_params.end_date:
                filters.append(SystemLog.created_at <= query_params.end_date)
            if filters:
                stmt = stmt.where(*filters)
                count_stmt = count_stmt.where(*filters)
            # Total first, then the page itself.
            result = await db.execute(count_stmt)
            total = result.scalar_one()
            # Newest first, paginated.
            stmt = stmt.order_by(SystemLog.created_at.desc())
            stmt = stmt.offset((query_params.page - 1) * query_params.page_size)
            stmt = stmt.limit(query_params.page_size)
            result = await db.execute(stmt)
            logs = result.scalars().all()
            return list(logs), total
        except Exception as e:
            logger.error(f"查询系统日志失败: {str(e)}")
            raise

    async def get_log_by_id(
        self,
        db: AsyncSession,
        log_id: int
    ) -> Optional[SystemLog]:
        """
        Fetch a single log entry by primary key.

        Args:
            db: Database session.
            log_id: Log ID.

        Returns:
            The SystemLog object, or None when not found.
        """
        try:
            stmt = select(SystemLog).where(SystemLog.id == log_id)
            result = await db.execute(stmt)
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error(f"获取日志详情失败: {str(e)}")
            raise

    async def delete_logs_before_date(
        self,
        db: AsyncSession,
        before_date: datetime
    ) -> int:
        """
        Delete all logs created before the given date (retention cleanup).

        FIX: issues a single bulk DELETE statement instead of loading every
        matching row and deleting them one by one (O(n) round trips on a
        table that grows without bound).

        Args:
            db: Database session.
            before_date: Cutoff datetime (exclusive).

        Returns:
            Number of deleted rows.

        Raises:
            Exception: re-raised after rollback when the delete fails.
        """
        try:
            stmt = delete(SystemLog).where(SystemLog.created_at < before_date)
            result = await db.execute(stmt)
            await db.commit()
            count = result.rowcount or 0
            logger.info(f"已删除 {count} 条日志记录")
            return count
        except Exception as e:
            await db.rollback()
            logger.error(f"删除日志失败: {str(e)}")
            raise
# 创建全局服务实例
system_log_service = SystemLogService()

View File

@@ -0,0 +1,214 @@
"""
任务服务
"""
from typing import List, Optional
from datetime import datetime
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.models.task import Task, TaskCourse, TaskAssignment, TaskStatus, AssignmentStatus
from app.models.course import Course
from app.schemas.task import TaskCreate, TaskUpdate, TaskStatsResponse
from app.services.base_service import BaseService
class TaskService(BaseService[Task]):
    """Task service: CRUD, statistics and automatic progress/status updates."""

    def __init__(self):
        super().__init__(Task)

    async def create_task(self, db: AsyncSession, task_in: TaskCreate, creator_id: int) -> Task:
        """
        Create a task, link its courses and assign it to users.

        Args:
            db: async database session
            task_in: task payload (includes course_ids and user_ids)
            creator_id: id of the creating user

        Returns:
            The persisted task (status starts at PENDING).
        """
        task = Task(
            title=task_in.title,
            description=task_in.description,
            priority=task_in.priority,
            deadline=task_in.deadline,
            requirements=task_in.requirements,
            creator_id=creator_id,
            status=TaskStatus.PENDING
        )
        db.add(task)
        # Flush to obtain task.id before creating the link rows below.
        await db.flush()

        for course_id in task_in.course_ids:
            db.add(TaskCourse(task_id=task.id, course_id=course_id))
        for user_id in task_in.user_ids:
            db.add(TaskAssignment(task_id=task.id, user_id=user_id))

        await db.commit()
        await db.refresh(task)
        return task

    async def get_tasks(
        self,
        db: AsyncSession,
        status: Optional[str] = None,
        page: int = 1,
        page_size: int = 20
    ) -> tuple[List[Task], int]:
        """
        List non-deleted tasks, newest first, optionally filtered by status.

        Note: the return annotation is `tuple[List[Task], int]` — the
        earlier `(List[Task], int)` spelling evaluated to a tuple *object*,
        not a type.

        Returns:
            (tasks for the requested page, total matching count)
        """
        stmt = select(Task).where(Task.is_deleted == False)
        if status:
            stmt = stmt.where(Task.status == status)
        stmt = stmt.order_by(Task.created_at.desc())

        # Count with the same filters so total matches the listing.
        count_stmt = select(func.count()).select_from(Task).where(Task.is_deleted == False)
        if status:
            count_stmt = count_stmt.where(Task.status == status)
        total = (await db.execute(count_stmt)).scalar_one()

        stmt = stmt.offset((page - 1) * page_size).limit(page_size)
        result = await db.execute(stmt)
        tasks = result.scalars().all()
        return tasks, total

    async def get_task_detail(self, db: AsyncSession, task_id: int) -> Optional[Task]:
        """
        Fetch one task with its course links and assignments eagerly loaded.

        Returns:
            The task, or None when missing/soft-deleted.
        """
        stmt = select(Task).where(
            and_(Task.id == task_id, Task.is_deleted == False)
        ).options(
            joinedload(Task.course_links).joinedload(TaskCourse.course),
            joinedload(Task.assignments)
        )
        result = await db.execute(stmt)
        # unique() is required when joinedload produces duplicated parent rows.
        return result.unique().scalar_one_or_none()

    async def update_task(self, db: AsyncSession, task_id: int, task_in: TaskUpdate) -> Optional[Task]:
        """
        Partially update a task (only fields explicitly set in task_in).

        Returns:
            The updated task, or None when missing/soft-deleted.
        """
        stmt = select(Task).where(and_(Task.id == task_id, Task.is_deleted == False))
        result = await db.execute(stmt)
        task = result.scalar_one_or_none()
        if not task:
            return None
        update_data = task_in.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(task, field, value)
        await db.commit()
        await db.refresh(task)
        return task

    async def delete_task(self, db: AsyncSession, task_id: int) -> bool:
        """
        Soft-delete a task.

        Returns:
            True when the task existed and was marked deleted, else False.
        """
        stmt = select(Task).where(and_(Task.id == task_id, Task.is_deleted == False))
        result = await db.execute(stmt)
        task = result.scalar_one_or_none()
        if not task:
            return False
        task.is_deleted = True
        await db.commit()
        return True

    async def get_task_stats(self, db: AsyncSession) -> TaskStatsResponse:
        """
        Aggregate task statistics: total count, per-status counts and the
        average completion rate over non-expired tasks.
        """
        total_stmt = select(func.count()).select_from(Task).where(Task.is_deleted == False)
        total = (await db.execute(total_stmt)).scalar_one()

        status_stmt = select(
            Task.status,
            func.count(Task.id)
        ).where(Task.is_deleted == False).group_by(Task.status)
        status_result = await db.execute(status_stmt)
        # Depending on the column type, the grouped key may come back as a
        # TaskStatus member or its raw string value; normalize to the string
        # value so the .get(...) lookups below always match.
        status_counts = {}
        for status_value, count in status_result.all():
            key = status_value.value if isinstance(status_value, TaskStatus) else status_value
            status_counts[key] = count

        avg_stmt = select(func.avg(Task.progress)).where(
            and_(Task.is_deleted == False, Task.status != TaskStatus.EXPIRED)
        )
        avg_completion = (await db.execute(avg_stmt)).scalar_one() or 0.0

        return TaskStatsResponse(
            total=total,
            ongoing=status_counts.get(TaskStatus.ONGOING.value, 0),
            completed=status_counts.get(TaskStatus.COMPLETED.value, 0),
            expired=status_counts.get(TaskStatus.EXPIRED.value, 0),
            # AVG may return Decimal on MySQL; cast before rounding.
            avg_completion_rate=round(float(avg_completion), 1)
        )

    async def update_task_progress(self, db: AsyncSession, task_id: int) -> int:
        """
        Recompute and persist a task's progress as the percentage of its
        assignments that are COMPLETED.

        Returns:
            The new progress value (0-100).
        """
        # `case` is a top-level SQLAlchemy construct; `func.case(...)` is
        # not valid (the `else_` keyword is rejected at query-build time).
        from sqlalchemy import case

        stmt = select(
            func.count(TaskAssignment.id).label('total'),
            func.sum(
                case(
                    (TaskAssignment.status == AssignmentStatus.COMPLETED, 1),
                    else_=0
                )
            ).label('completed')
        ).where(TaskAssignment.task_id == task_id)
        result = (await db.execute(stmt)).first()
        total = result.total or 0
        completed = result.completed or 0

        progress = 0 if total == 0 else int((completed / total) * 100)

        task_stmt = select(Task).where(and_(Task.id == task_id, Task.is_deleted == False))
        task_result = await db.execute(task_stmt)
        task = task_result.scalar_one_or_none()
        if task:
            task.progress = progress
            await db.commit()
        return progress

    async def update_task_status(self, db: AsyncSession, task_id: int):
        """
        Recompute progress, then advance the task status accordingly:
        100% -> COMPLETED; past deadline and not completed -> EXPIRED;
        started (progress > 0) while PENDING -> ONGOING.
        """
        task = await self.get_task_detail(db, task_id)
        if not task:
            return
        progress = await self.update_task_progress(db, task_id)

        now = datetime.now()
        if progress == 100:
            task.status = TaskStatus.COMPLETED
        elif task.deadline and now > task.deadline and task.status != TaskStatus.COMPLETED:
            task.status = TaskStatus.EXPIRED
        elif progress > 0 and task.status == TaskStatus.PENDING:
            task.status = TaskStatus.ONGOING
        await db.commit()


task_service = TaskService()

View File

@@ -0,0 +1,372 @@
"""陪练服务层"""
import logging
from typing import List, Optional, Dict, Any
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func
from fastapi import HTTPException, status
from app.models.training import (
TrainingScene,
TrainingSession,
TrainingMessage,
TrainingReport,
TrainingSceneStatus,
TrainingSessionStatus,
MessageRole,
MessageType,
)
from app.schemas.training import (
TrainingSceneCreate,
TrainingSceneUpdate,
TrainingSessionCreate,
TrainingSessionUpdate,
TrainingMessageCreate,
TrainingReportCreate,
StartTrainingRequest,
StartTrainingResponse,
EndTrainingRequest,
EndTrainingResponse,
)
from app.services.base_service import BaseService
# from app.services.ai.coze.client import CozeClient
from app.core.config import get_settings
logger = logging.getLogger(__name__)
settings = get_settings()
class TrainingSceneService(BaseService[TrainingScene]):
    """Service for managing sparring-practice training scenes."""

    def __init__(self):
        super().__init__(TrainingScene)

    async def get_active_scenes(
        self,
        db: AsyncSession,
        *,
        category: Optional[str] = None,
        is_public: Optional[bool] = None,
        user_level: Optional[int] = None,
        skip: int = 0,
        limit: int = 20,
    ) -> List[TrainingScene]:
        """Return active, non-deleted scenes, optionally filtered by
        category, visibility and the caller's user level."""
        conditions = [
            self.model.status == TrainingSceneStatus.ACTIVE,
            self.model.is_deleted == False,
        ]
        if category:
            conditions.append(self.model.category == category)
        if is_public is not None:
            conditions.append(self.model.is_public == is_public)
        if user_level is not None:
            # A scene with no required level is open to everyone.
            conditions.append(
                or_(
                    self.model.required_level == None,
                    self.model.required_level <= user_level,
                )
            )
        query = select(self.model).where(and_(*conditions))
        return await self.get_multi(db, skip=skip, limit=limit, query=query)

    async def create_scene(
        self, db: AsyncSession, *, scene_in: TrainingSceneCreate, created_by: int
    ) -> TrainingScene:
        """Create a scene, recording the creator as both creator and updater."""
        return await self.create(
            db, obj_in=scene_in, created_by=created_by, updated_by=created_by
        )

    async def update_scene(
        self,
        db: AsyncSession,
        *,
        scene_id: int,
        scene_in: TrainingSceneUpdate,
        updated_by: int,
    ) -> Optional[TrainingScene]:
        """Update a scene; returns None when it is missing or soft-deleted."""
        existing = await self.get(db, scene_id)
        if existing is None or existing.is_deleted:
            return None
        existing.updated_by = updated_by
        return await self.update(db, db_obj=existing, obj_in=scene_in)
class TrainingSessionService(BaseService[TrainingSession]):
    """
    Training session service.

    Orchestrates the session lifecycle: scene validation, session
    creation, (optional) Coze conversation bootstrap, session teardown
    and report generation.
    """
    def __init__(self):
        super().__init__(TrainingSession)
        # Sibling services used by the session workflow.
        self.scene_service = TrainingSceneService()
        self.message_service = TrainingMessageService()
        self.report_service = TrainingReportService()
        # TODO: replace once the Coze gateway module is implemented
        self._coze_client = None
    @property
    def coze_client(self):
        """Lazily initialize the Coze client.

        Currently always resolves to None (mock mode) because the client
        implementation is commented out; the warning is logged on every
        access while `_coze_client` stays None.
        """
        if self._coze_client is None:
            try:
                # from app.services.ai.coze.client import CozeClient
                # self._coze_client = CozeClient()
                logger.warning("Coze客户端暂未实现使用模拟模式")
                self._coze_client = None
            except ImportError:
                logger.warning("Coze客户端未实现使用模拟模式")
        return self._coze_client
    async def start_training(
        self, db: AsyncSession, *, request: StartTrainingRequest, user_id: int
    ) -> StartTrainingResponse:
        """Start a training session.

        Flow: validate the scene, check the user level, create the
        session row, optionally create a Coze conversation, and return
        the session info plus an optional WebSocket URL.

        Raises:
            HTTPException: 404 when the scene is missing/inactive,
                403 when the user's level is insufficient.
        """
        # Validate the scene
        scene = await self.scene_service.get(db, request.scene_id)
        if not scene or scene.is_deleted or scene.status != TrainingSceneStatus.ACTIVE:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="陪练场景不存在或未激活"
            )
        # Check the user's level
        # TODO: fetch the real level from the User service
        user_level = 1  # temporary mock value
        if scene.required_level and user_level < scene.required_level:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户等级不足")
        # Create the session
        session_data = TrainingSessionCreate(
            scene_id=request.scene_id, session_config=request.config
        )
        session = await self.create(
            db, obj_in=session_data, user_id=user_id, created_by=user_id
        )
        # Initialize the Coze conversation (best-effort: failures are
        # logged and the session proceeds without a Coze conversation).
        coze_conversation_id = None
        if self.coze_client and scene.ai_config:
            try:
                bot_id = scene.ai_config.get("bot_id", settings.coze_training_bot_id)
                if bot_id:
                    # Create the Coze conversation
                    coze_result = await self.coze_client.create_conversation(
                        bot_id=bot_id,
                        user_id=str(user_id),
                        meta_data={
                            "scene_id": scene.id,
                            "scene_name": scene.name,
                            "session_id": session.id,
                        },
                    )
                    coze_conversation_id = coze_result.get("conversation_id")
                    # Persist the Coze conversation id on the session
                    session.coze_conversation_id = coze_conversation_id
                    await db.commit()
            except Exception as e:
                logger.error(f"创建Coze会话失败: {e}")
        # Load the related scene for the response
        await db.refresh(session, ["scene"])
        # Build the WebSocket URL (only when a Coze conversation exists)
        # NOTE(review): host is hard-coded to localhost:8000 — confirm this
        # should come from configuration in deployed environments.
        websocket_url = None
        if coze_conversation_id:
            websocket_url = f"ws://localhost:8000/ws/v1/training/{session.id}"
        return StartTrainingResponse(
            session_id=session.id,
            coze_conversation_id=coze_conversation_id,
            scene=scene,
            websocket_url=websocket_url,
        )
    async def end_training(
        self,
        db: AsyncSession,
        *,
        session_id: int,
        request: EndTrainingRequest,
        user_id: int,
    ) -> EndTrainingResponse:
        """End a training session.

        Marks the session COMPLETED, records its duration and optionally
        generates a report.

        Raises:
            HTTPException: 404 when the session is missing or owned by a
                different user, 400 when it has already ended.
        """
        # Fetch the session (ownership check included)
        session = await self.get(db, session_id)
        if not session or session.user_id != user_id:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="会话不存在")
        if session.status in [
            TrainingSessionStatus.COMPLETED,
            TrainingSessionStatus.CANCELLED,
        ]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="会话已结束")
        # Compute the duration
        # NOTE(review): assumes start_time is naive local time like
        # datetime.now() — confirm there is no tz mismatch.
        end_time = datetime.now()
        duration_seconds = int((end_time - session.start_time).total_seconds())
        # Mark the session completed
        update_data = TrainingSessionUpdate(
            status=TrainingSessionStatus.COMPLETED,
            end_time=end_time,
            duration_seconds=duration_seconds,
        )
        session = await self.update(db, db_obj=session, obj_in=update_data)
        # Generate the report (optional)
        report = None
        if request.generate_report:
            report = await self._generate_report(
                db, session_id=session_id, user_id=user_id
            )
        # Reload relationships for the response
        await db.refresh(session, ["scene"])
        if report:
            await db.refresh(report, ["session"])
        return EndTrainingResponse(session=session, report=report)
    async def get_user_sessions(
        self,
        db: AsyncSession,
        *,
        user_id: int,
        scene_id: Optional[int] = None,
        status: Optional[TrainingSessionStatus] = None,
        skip: int = 0,
        limit: int = 20,
    ) -> List[TrainingSession]:
        """List a user's training sessions, newest first, optionally
        filtered by scene and/or status."""
        query = select(self.model).where(self.model.user_id == user_id)
        if scene_id:
            query = query.where(self.model.scene_id == scene_id)
        if status:
            query = query.where(self.model.status == status)
        query = query.order_by(self.model.created_at.desc())
        return await self.get_multi(db, skip=skip, limit=limit, query=query)
    async def _generate_report(
        self, db: AsyncSession, *, session_id: int, user_id: int
    ) -> Optional[TrainingReport]:
        """Generate a training report (internal helper).

        Currently produces a hard-coded mock report; only the message
        statistics are derived from real session data.
        """
        # Fetch the session messages (used for the statistics block)
        messages = await self.message_service.get_session_messages(
            db, session_id=session_id
        )
        # TODO: call the AI analysis service to generate a real report
        # For now, build a mock report
        report_data = TrainingReportCreate(
            session_id=session_id,
            user_id=user_id,
            overall_score=85.5,
            dimension_scores={"表达能力": 88.0, "逻辑思维": 85.0, "专业知识": 82.0, "应变能力": 87.0},
            strengths=["表达清晰,语言流畅", "能够快速理解问题并作出回应", "展现了良好的专业素养"],
            weaknesses=["部分专业术语使用不够准确", "回答有时过于冗长,需要更加精炼"],
            suggestions=["加强专业知识的学习,特别是术语的准确使用", "练习更加简洁有力的表达方式", "增加实际案例的积累,丰富回答内容"],
            detailed_analysis="整体表现良好,展现了扎实的基础知识和良好的沟通能力...",
            statistics={
                "total_messages": len(messages),
                "user_messages": len(
                    [m for m in messages if m.role == MessageRole.USER]
                ),
                "avg_response_time": 2.5,
                "total_words": 1500,
            },
        )
        return await self.report_service.create(
            db, obj_in=report_data, created_by=user_id
        )
class TrainingMessageService(BaseService[TrainingMessage]):
    """Service for persisting and querying training-session messages."""

    def __init__(self):
        super().__init__(TrainingMessage)

    async def create_message(
        self, db: AsyncSession, *, message_in: TrainingMessageCreate
    ) -> TrainingMessage:
        """Persist a single message."""
        return await self.create(db, obj_in=message_in)

    async def get_session_messages(
        self, db: AsyncSession, *, session_id: int, skip: int = 0, limit: int = 100
    ) -> List[TrainingMessage]:
        """Return a session's messages in chronological order."""
        stmt = select(self.model).where(
            self.model.session_id == session_id
        ).order_by(self.model.created_at)
        return await self.get_multi(db, skip=skip, limit=limit, query=stmt)

    async def save_voice_message(
        self,
        db: AsyncSession,
        *,
        session_id: int,
        role: MessageRole,
        content: str,
        voice_url: str,
        voice_duration: float,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> TrainingMessage:
        """Persist a voice message: transcript text plus audio URL/duration."""
        payload = TrainingMessageCreate(
            session_id=session_id,
            role=role,
            type=MessageType.VOICE,
            content=content,
            voice_url=voice_url,
            voice_duration=voice_duration,
            metadata=metadata,
        )
        return await self.create(db, obj_in=payload)
class TrainingReportService(BaseService[TrainingReport]):
    """Service for training-session reports."""

    def __init__(self):
        super().__init__(TrainingReport)

    async def get_by_session(
        self, db: AsyncSession, *, session_id: int
    ) -> Optional[TrainingReport]:
        """Fetch the report for one session, or None when absent."""
        stmt = select(self.model).where(self.model.session_id == session_id)
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def get_user_reports(
        self, db: AsyncSession, *, user_id: int, skip: int = 0, limit: int = 20
    ) -> List[TrainingReport]:
        """Return all of a user's reports, newest first."""
        stmt = (
            select(self.model)
            .where(self.model.user_id == user_id)
            .order_by(self.model.created_at.desc())
        )
        return await self.get_multi(db, skip=skip, limit=limit, query=stmt)

View File

@@ -0,0 +1,423 @@
"""
用户服务
"""
from datetime import datetime
from typing import Any, Dict, List, Optional
from sqlalchemy import and_, or_, select, func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.exceptions import ConflictError, NotFoundError
from app.core.logger import logger
from app.core.security import get_password_hash, verify_password
from app.models.user import Team, User, user_teams
from app.schemas.user import UserCreate, UserFilter, UserUpdate
from app.services.base_service import BaseService
class UserService(BaseService[User]):
    """
    User service: lookups, creation with uniqueness checks, updates,
    password management, team membership and authentication.
    """
    def __init__(self, db: AsyncSession):
        super().__init__(User)
        # The session is injected once and reused by every method.
        self.db = db
    async def get_by_id(self, user_id: int) -> Optional[User]:
        """Fetch a non-deleted user by primary key (None when absent)."""
        result = await self.db.execute(
            select(User).where(User.id == user_id, User.is_deleted == False)
        )
        return result.scalar_one_or_none()
    async def get_by_username(self, username: str) -> Optional[User]:
        """Fetch a non-deleted user by username (None when absent)."""
        result = await self.db.execute(
            select(User).where(
                User.username == username,
                User.is_deleted == False,
            )
        )
        return result.scalar_one_or_none()
    async def get_by_email(self, email: str) -> Optional[User]:
        """Fetch a non-deleted user by email (None when absent)."""
        result = await self.db.execute(
            select(User).where(
                User.email == email,
                User.is_deleted == False,
            )
        )
        return result.scalar_one_or_none()
    async def get_by_phone(self, phone: str) -> Optional[User]:
        """Fetch a non-deleted user by phone number (None when absent)."""
        result = await self.db.execute(
            select(User).where(
                User.phone == phone,
                User.is_deleted == False,
            )
        )
        return result.scalar_one_or_none()
    async def _check_username_exists_all(self, username: str) -> Optional[User]:
        """
        Check whether a username exists, including soft-deleted users.
        Used at creation time to honor the DB-level unique constraint.
        """
        result = await self.db.execute(
            select(User).where(User.username == username)
        )
        return result.scalar_one_or_none()
    async def _check_email_exists_all(self, email: str) -> Optional[User]:
        """
        Check whether an email exists, including soft-deleted users.
        Used at creation time to honor the DB-level unique constraint.
        """
        result = await self.db.execute(
            select(User).where(User.email == email)
        )
        return result.scalar_one_or_none()
    async def _check_phone_exists_all(self, phone: str) -> Optional[User]:
        """
        Check whether a phone number exists, including soft-deleted users.
        Used at creation time to honor the DB-level unique constraint.
        """
        result = await self.db.execute(
            select(User).where(User.phone == phone)
        )
        return result.scalar_one_or_none()
    async def create_user(
        self,
        *,
        obj_in: UserCreate,
        created_by: Optional[int] = None,
    ) -> User:
        """Create a user.

        Pre-checks username/email/phone uniqueness against ALL users
        (including soft-deleted ones, to avoid unique-key collisions),
        hashes the password, and maps any residual IntegrityError into a
        friendly ConflictError.

        Raises:
            ConflictError: on any uniqueness violation.
        """
        # Check username uniqueness (including soft-deleted users, to
        # prevent a unique-key conflict at insert time)
        existing_user = await self._check_username_exists_all(obj_in.username)
        if existing_user:
            if existing_user.is_deleted:
                raise ConflictError(f"用户名 {obj_in.username} 已被使用(历史用户),请更换其他用户名")
            else:
                raise ConflictError(f"用户名 {obj_in.username} 已存在")
        # Check email uniqueness (including soft-deleted users)
        if obj_in.email:
            existing_email = await self._check_email_exists_all(obj_in.email)
            if existing_email:
                if existing_email.is_deleted:
                    raise ConflictError(f"邮箱 {obj_in.email} 已被使用(历史用户),请更换其他邮箱")
                else:
                    raise ConflictError(f"邮箱 {obj_in.email} 已存在")
        # Check phone uniqueness (including soft-deleted users)
        if obj_in.phone:
            existing_phone = await self._check_phone_exists_all(obj_in.phone)
            if existing_phone:
                if existing_phone.is_deleted:
                    raise ConflictError(f"手机号 {obj_in.phone} 已被使用(历史用户),请更换其他手机号")
                else:
                    raise ConflictError(f"手机号 {obj_in.phone} 已存在")
        # Build the user payload: drop the plaintext password, store a hash.
        user_data = obj_in.model_dump(exclude={"password"})
        user_data["hashed_password"] = get_password_hash(obj_in.password)
        # NOTE: the User model has no created_by column; the creator is
        # recorded in the audit log instead.
        # user_data["created_by"] = created_by
        try:
            # Create the user
            user = await self.create(db=self.db, obj_in=user_data)
        except IntegrityError as e:
            # A concurrent insert can still trip the unique constraint;
            # roll back and translate to a friendly error message.
            await self.db.rollback()
            error_msg = str(e.orig) if e.orig else str(e)
            logger.warning(
                "创建用户时发生唯一键冲突",
                username=obj_in.username,
                email=obj_in.email,
                error=error_msg,
            )
            if "username" in error_msg.lower():
                raise ConflictError(f"用户名 {obj_in.username} 已被占用,请更换其他用户名")
            elif "email" in error_msg.lower():
                raise ConflictError(f"邮箱 {obj_in.email} 已被占用,请更换其他邮箱")
            elif "phone" in error_msg.lower():
                raise ConflictError(f"手机号 {obj_in.phone} 已被占用,请更换其他手机号")
            else:
                raise ConflictError(f"创建用户失败:数据冲突,请检查用户名、邮箱或手机号是否重复")
        # Audit log
        logger.info(
            "用户创建成功",
            user_id=user.id,
            username=user.username,
            role=user.role,
            created_by=created_by,
        )
        return user
    async def update_user(
        self,
        *,
        user_id: int,
        obj_in: UserUpdate,
        updated_by: Optional[int] = None,
    ) -> User:
        """Update a user's profile fields (only fields explicitly set).

        Raises:
            NotFoundError: when the user does not exist.
            ConflictError: when the new email/phone belongs to another
                active user.
        """
        user = await self.get_by_id(user_id)
        if not user:
            raise NotFoundError("用户不存在")
        # If the email changes, ensure it is not taken.
        # NOTE(review): unlike create_user, these checks ignore
        # soft-deleted users, so a DB unique-key conflict with a deleted
        # user's email/phone would surface as a raw IntegrityError.
        if obj_in.email and obj_in.email != user.email:
            if await self.get_by_email(obj_in.email):
                raise ConflictError(f"邮箱 {obj_in.email} 已存在")
        # If the phone changes, ensure it is not taken.
        if obj_in.phone and obj_in.phone != user.phone:
            if await self.get_by_phone(obj_in.phone):
                raise ConflictError(f"手机号 {obj_in.phone} 已存在")
        # Apply the partial update
        update_data = obj_in.model_dump(exclude_unset=True)
        update_data["updated_by"] = updated_by
        user = await self.update(db=self.db, db_obj=user, obj_in=update_data)
        # Audit log
        logger.info(
            "用户更新成功",
            user_id=user.id,
            username=user.username,
            updated_fields=list(update_data.keys()),
            updated_by=updated_by,
        )
        return user
    async def update_password(
        self,
        *,
        user_id: int,
        old_password: str,
        new_password: str,
    ) -> User:
        """Change a user's password after verifying the old one.

        Raises:
            NotFoundError: when the user does not exist.
            ConflictError: when the old password is wrong.
        """
        user = await self.get_by_id(user_id)
        if not user:
            raise NotFoundError("用户不存在")
        # Verify the old password
        if not verify_password(old_password, user.hashed_password):
            raise ConflictError("旧密码错误")
        # Store the new hash and record when it was changed
        update_data = {
            "hashed_password": get_password_hash(new_password),
            "password_changed_at": datetime.now(),
        }
        user = await self.update(db=self.db, db_obj=user, obj_in=update_data)
        # Audit log
        logger.info(
            "用户密码更新成功",
            user_id=user.id,
            username=user.username,
        )
        return user
    async def update_last_login(self, user_id: int) -> None:
        """Record the current time as the user's last login (no-op when
        the user does not exist)."""
        user = await self.get_by_id(user_id)
        if user:
            await self.update(
                db=self.db,
                db_obj=user,
                obj_in={"last_login_at": datetime.now()},
            )
    async def get_users_with_filter(
        self,
        *,
        skip: int = 0,
        limit: int = 100,
        filter_params: UserFilter,
    ) -> tuple[List[User], int]:
        """List users matching the filter, plus the total match count.

        Supports role, active-flag, keyword (username/email/full name,
        substring match) and team membership filters.
        """
        # Base filter: never include soft-deleted users
        filters = [User.is_deleted == False]
        if filter_params.role:
            filters.append(User.role == filter_params.role)
        if filter_params.is_active is not None:
            filters.append(User.is_active == filter_params.is_active)
        if filter_params.keyword:
            keyword = f"%{filter_params.keyword}%"
            filters.append(
                or_(
                    User.username.like(keyword),
                    User.email.like(keyword),
                    User.full_name.like(keyword),
                )
            )
        if filter_params.team_id:
            # Filter by team membership via the association table
            subquery = select(user_teams.c.user_id).where(
                user_teams.c.team_id == filter_params.team_id
            )
            filters.append(User.id.in_(subquery))
        # Page of users
        query = select(User).where(and_(*filters))
        users = await self.get_multi(self.db, skip=skip, limit=limit, query=query)
        # Total count with the same filters
        count_query = select(func.count(User.id)).where(and_(*filters))
        count_result = await self.db.execute(count_query)
        total = count_result.scalar()
        return users, total
    async def add_user_to_team(
        self,
        *,
        user_id: int,
        team_id: int,
        role: str = "member",
    ) -> None:
        """Add a user to a team with the given role.

        Raises:
            NotFoundError: when the user or the team does not exist.
            ConflictError: when the user is already a member.
        """
        # Ensure the user exists
        user = await self.get_by_id(user_id)
        if not user:
            raise NotFoundError("用户不存在")
        # Ensure the team exists (and is not soft-deleted)
        team_result = await self.db.execute(
            select(Team).where(Team.id == team_id, Team.is_deleted == False)
        )
        team = team_result.scalar_one_or_none()
        if not team:
            raise NotFoundError("团队不存在")
        # Reject duplicate membership
        existing = await self.db.execute(
            select(user_teams).where(
                user_teams.c.user_id == user_id,
                user_teams.c.team_id == team_id,
            )
        )
        if existing.first():
            raise ConflictError("用户已在该团队中")
        # Insert the membership row
        await self.db.execute(
            user_teams.insert().values(
                user_id=user_id,
                team_id=team_id,
                role=role,
                joined_at=datetime.now(),
            )
        )
        await self.db.commit()
        # Audit log
        logger.info(
            "用户加入团队",
            user_id=user_id,
            username=user.username,
            team_id=team_id,
            team_name=team.name,
            role=role,
        )
    async def remove_user_from_team(
        self,
        *,
        user_id: int,
        team_id: int,
    ) -> None:
        """Remove a user from a team.

        Raises:
            NotFoundError: when the user is not a member of the team.
        """
        # Delete the association row
        result = await self.db.execute(
            user_teams.delete().where(
                user_teams.c.user_id == user_id,
                user_teams.c.team_id == team_id,
            )
        )
        if result.rowcount == 0:
            raise NotFoundError("用户不在该团队中")
        await self.db.commit()
        # Audit log
        logger.info(
            "用户离开团队",
            user_id=user_id,
            team_id=team_id,
        )
    async def soft_delete(self, *, db_obj: User) -> User:
        """
        Soft-delete a user (sets is_deleted and deleted_at; the row is kept
        so unique identifiers remain reserved).

        Args:
            db_obj: the user object to delete

        Returns:
            The soft-deleted user object.
        """
        db_obj.is_deleted = True
        db_obj.deleted_at = datetime.now()
        self.db.add(db_obj)
        await self.db.commit()
        await self.db.refresh(db_obj)
        logger.info(
            "用户软删除成功",
            user_id=db_obj.id,
            username=db_obj.username,
        )
        return db_obj
    async def authenticate(
        self,
        *,
        username: str,
        password: str,
    ) -> Optional[User]:
        """Authenticate a user by username, email or phone.

        Tries the identifier as a username first, then as an email, then
        as a phone number; verifies the password against the stored hash.

        Returns:
            The user on success, None on any failure (unknown identifier
            or wrong password).
        """
        # Try login by username
        user = await self.get_by_username(username)
        # Fall back to email
        if not user:
            user = await self.get_by_email(username)
        # Fall back to phone number
        if not user:
            user = await self.get_by_phone(username)
        if not user:
            return None
        # Verify the password
        if not verify_password(password, user.hashed_password):
            return None
        return user

View File

@@ -0,0 +1,510 @@
"""
言迹智能工牌API服务
"""
import logging
import random
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
import httpx
from app.core.config import settings
logger = logging.getLogger(__name__)
class YanjiService:
"""言迹智能工牌API服务类"""
    def __init__(self):
        # Connection settings for the Yanji badge API come from app config.
        self.base_url = settings.YANJI_API_BASE
        self.client_id = settings.YANJI_CLIENT_ID
        self.client_secret = settings.YANJI_CLIENT_SECRET
        self.tenant_id = settings.YANJI_TENANT_ID
        self.estate_id = int(settings.YANJI_ESTATE_ID)
        # In-memory OAuth token cache: token value plus absolute expiry time.
        self._access_token: Optional[str] = None
        self._token_expires_at: Optional[datetime] = None
async def get_access_token(self) -> str:
"""
获取或刷新access_token
Returns:
access_token字符串
"""
# 检查缓存的token是否仍然有效提前5分钟刷新
if self._access_token and self._token_expires_at:
if datetime.now() < self._token_expires_at - timedelta(minutes=5):
return self._access_token
# 获取新的token
url = f"{self.base_url}/oauth/token"
params = {
"grant_type": "client_credentials",
"client_id": self.client_id,
"client_secret": self.client_secret,
}
async with httpx.AsyncClient() as client:
response = await client.get(url, params=params, timeout=30.0)
response.raise_for_status()
data = response.json()
self._access_token = data["access_token"]
expires_in = data.get("expires_in", 3600) # 默认1小时
self._token_expires_at = datetime.now() + timedelta(seconds=expires_in)
logger.info(f"言迹API token获取成功有效期至: {self._token_expires_at}")
return self._access_token
async def _request(
self,
method: str,
path: str,
params: Optional[Dict] = None,
json_data: Optional[Dict] = None,
) -> Dict:
"""
统一的HTTP请求方法
Args:
method: HTTP方法GET/POST等
path: API路径
params: Query参数
json_data: Body参数JSON
Returns:
响应数据data字段
Raises:
Exception: API调用失败
"""
token = await self.get_access_token()
url = f"{self.base_url}{path}"
headers = {"Authorization": f"Bearer {token}"}
async with httpx.AsyncClient() as client:
response = await client.request(
method=method,
url=url,
params=params,
json=json_data,
headers=headers,
timeout=60.0,
)
response.raise_for_status()
result = response.json()
# 言迹API: code='0'或code=0表示成功
code = result.get("code")
if str(code) != '0':
error_msg = result.get("msg", "Unknown error")
logger.error(f"言迹API调用失败: {error_msg}, result={result}")
raise Exception(f"言迹API错误: {error_msg}")
# data可能为None返回空字典或空列表由调用方判断
return result.get("data")
async def get_visit_audios(
self, external_visit_ids: List[str]
) -> List[Dict]:
"""
根据来访单ID获取录音信息
Args:
external_visit_ids: 三方来访ID列表最多10个
Returns:
录音信息列表
"""
if not external_visit_ids:
return []
if len(external_visit_ids) > 10:
logger.warning(f"来访单ID数量超过限制截取前10个")
external_visit_ids = external_visit_ids[:10]
data = await self._request(
method="POST",
path="/api/beauty/v1/visit/audios",
json_data={
"estateId": self.estate_id,
"externalVisitIds": external_visit_ids,
},
)
if data is None:
logger.info(f"获取来访录音信息: 无数据")
return []
records = data.get("records", [])
logger.info(f"获取来访录音信息成功: {len(records)}")
return records
async def get_audio_asr_result(self, audio_id: int) -> Dict:
"""
获取录音的ASR分析结果对话文本
Args:
audio_id: 录音ID
Returns:
ASR分析结果包含对话文本数组
"""
data = await self._request(
method="GET",
path="/api/beauty/v1/audio/asr-analysed",
params={"estateId": self.estate_id, "audioId": audio_id},
)
# 检查data是否为None
if data is None:
logger.warning(f"录音ASR结果为None: audio_id={audio_id}")
return {}
# data是一个数组取第一个元素
if isinstance(data, list) and len(data) > 0:
result = data[0]
logger.info(
f"获取录音ASR结果成功: audio_id={audio_id}, "
f"对话数={len(result.get('result', []))}"
)
return result
else:
logger.warning(f"录音ASR结果为空: audio_id={audio_id}")
return {}
async def get_recent_conversations(
self, consultant_phone: str, limit: int = 10
) -> List[Dict]:
"""
获取员工最近N条对话记录
业务逻辑:
1. 通过员工手机号获取录音列表(目前使用模拟数据)
2. 对每个录音获取ASR分析结果
3. 组合返回完整的对话记录
Args:
consultant_phone: 员工手机号
limit: 获取数量默认10条
Returns:
对话记录列表,格式:
[{
"audio_id": 123,
"visit_id": "xxx",
"start_time": "2025-01-15 10:30:00",
"duration": 120000,
"consultant_name": "张三",
"consultant_phone": "13800138000",
"conversation": [
{"role": "consultant", "text": "您好..."},
{"role": "customer", "text": "你好..."}
]
}]
"""
# TODO: 目前言迹API没有直接通过手机号查询录音的接口
# 需要先获取来访单列表,再获取录音
# 这里暂时返回空列表,后续根据实际业务需求补充
logger.warning(
f"获取员工对话记录功能需要额外的业务逻辑支持 "
f"(consultant_phone={consultant_phone}, limit={limit})"
)
# 返回空列表,表示暂未实现
return []
async def get_conversations_by_visit_ids(
self, external_visit_ids: List[str]
) -> List[Dict]:
"""
根据来访单ID列表获取对话记录
Args:
external_visit_ids: 三方来访ID列表
Returns:
对话记录列表
"""
if not external_visit_ids:
return []
# 1. 获取录音信息
audio_records = await self.get_visit_audios(external_visit_ids)
if not audio_records:
logger.info("没有找到录音记录")
return []
# 2. 对每个录音获取ASR分析结果
conversations = []
for audio in audio_records:
audio_id = audio.get("id")
if not audio_id:
continue
try:
asr_result = await self.get_audio_asr_result(audio_id)
# 解析对话文本
conversation_messages = []
for item in asr_result.get("result", []):
role = "consultant" if item.get("role") == -1 else "customer"
conversation_messages.append({
"role": role,
"text": item.get("text", ""),
"begin_time": item.get("beginTime"),
"end_time": item.get("endTime"),
})
# 组合完整对话记录
conversations.append({
"audio_id": audio_id,
"visit_id": audio.get("externalVisitId", ""),
"start_time": audio.get("startTime", ""),
"duration": audio.get("duration", 0),
"consultant_name": audio.get("consultantName", ""),
"consultant_phone": audio.get("consultantPhone", ""),
"conversation": conversation_messages,
})
except Exception as e:
logger.error(f"获取录音ASR结果失败: audio_id={audio_id}, error={e}")
continue
logger.info(f"成功获取{len(conversations)}条对话记录")
return conversations
async def get_audio_list(self, phone: str) -> List[Dict]:
"""
获取员工的录音列表(模拟)
注意言迹API暂时没有提供通过手机号直接查询录音列表的接口
这里使用模拟数据,返回假想的录音列表
Args:
phone: 员工手机号
Returns:
录音信息列表
"""
logger.info(f"获取员工录音列表(模拟): phone={phone}")
# 模拟返回10条录音记录
mock_audios = []
base_time = datetime.now()
for i in range(10):
# 模拟不同时长的录音
durations = [25000, 45000, 180000, 240000, 120000, 90000, 60000, 300000, 420000, 150000]
mock_audios.append({
"id": f"mock_audio_{i+1}",
"externalVisitId": f"visit_{i+1}",
"startTime": (base_time - timedelta(days=i)).strftime("%Y-%m-%d %H:%M:%S"),
"duration": durations[i], # 毫秒
"consultantName": "模拟员工",
"consultantPhone": phone
})
return mock_audios
async def get_employee_conversations_for_analysis(
self,
phone: str,
limit: int = 10
) -> List[Dict[str, Any]]:
"""
获取员工最近N条录音的模拟对话数据用于能力分析
Args:
phone: 员工手机号
limit: 获取数量默认10条
Returns:
对话数据列表,格式:
[{
"audio_id": "mock_audio_1",
"duration_seconds": 25,
"start_time": "2025-10-15 10:30:00",
"dialogue_history": [
{"speaker": "consultant", "content": "您好..."},
{"speaker": "customer", "content": "你好..."}
]
}]
"""
# 1. 获取录音列表
audios = await self.get_audio_list(phone)
if not audios:
logger.warning(f"未找到员工的录音记录: phone={phone}")
return []
# 2. 筛选前limit条
selected_audios = audios[:limit]
# 3. 为每条录音生成模拟对话
conversations = []
for audio in selected_audios:
conversation = self._generate_mock_conversation(audio)
conversations.append(conversation)
logger.info(f"生成模拟对话数据: phone={phone}, count={len(conversations)}")
return conversations
def _generate_mock_conversation(self, audio: Dict) -> Dict:
"""
为录音生成模拟对话数据
根据录音时长选择不同复杂度的对话模板:
- <30秒: 短对话4-6轮
- 30秒-5分钟: 中等对话8-12轮
- >5分钟: 长对话15-20轮完整销售流程
Args:
audio: 录音信息字典
Returns:
对话数据字典
"""
duration = int(audio.get('duration', 60000)) // 1000 # 转换为秒
# 根据时长选择对话模板
if duration < 30:
dialogue = self._short_conversation_template()
elif duration < 300:
dialogue = self._medium_conversation_template()
else:
dialogue = self._long_conversation_template()
return {
"audio_id": audio.get('id'),
"duration_seconds": duration,
"start_time": audio.get('startTime'),
"dialogue_history": dialogue
}
    def _short_conversation_template(self) -> List[Dict]:
        """Short-conversation template (<30s): 4-6 turns.

        Returns one randomly chosen canned dialogue as a list of
        {"speaker": ..., "content": ...} turns.
        """
        templates = [
            [
                {"speaker": "consultant", "content": "您好,欢迎光临曼尼斐绮,请问有什么可以帮到您?"},
                {"speaker": "customer", "content": "你好,我想了解一下面部护理项目"},
                {"speaker": "consultant", "content": "好的,我们有多种面部护理方案,请问您主要关注哪方面呢?"},
                {"speaker": "customer", "content": "主要是想改善皮肤暗沉"},
                {"speaker": "consultant", "content": "明白了,针对皮肤暗沉,我推荐我们的美白焕肤套餐"}
            ],
            [
                {"speaker": "consultant", "content": "您好,请问需要什么帮助吗?"},
                {"speaker": "customer", "content": "我想咨询一下祛斑项目"},
                {"speaker": "consultant", "content": "好的,请问您主要是哪种类型的斑点呢?"},
                {"speaker": "customer", "content": "脸颊两侧有些黄褐斑"},
                {"speaker": "consultant", "content": "了解,我们有专门针对黄褐斑的光子嫩肤项目,效果很不错"}
            ],
            [
                {"speaker": "consultant", "content": "欢迎光临,有什么可以帮您的吗?"},
                {"speaker": "customer", "content": "我想预约做个面部护理"},
                {"speaker": "consultant", "content": "好的,请问您之前做过我们的项目吗?"},
                {"speaker": "customer", "content": "没有,第一次来"},
                {"speaker": "consultant", "content": "那我建议您先做个免费的皮肤检测,帮您制定个性化方案"},
                {"speaker": "customer", "content": "好的,那现在可以吗?"}
            ]
        ]
        # Pick one variant at random for variety across recordings.
        return random.choice(templates)
def _medium_conversation_template(self) -> List[Dict]:
"""中等对话模板30秒-5分钟- 8-12轮对话"""
templates = [
[
{"speaker": "consultant", "content": "您好,欢迎光临曼尼斐绮,我是美容顾问小王,请问怎么称呼您?"},
{"speaker": "customer", "content": "你好,我姓李"},
{"speaker": "consultant", "content": "李女士您好,请问今天是第一次了解我们的项目吗?"},
{"speaker": "customer", "content": "是的,之前在网上看到你们的介绍"},
{"speaker": "consultant", "content": "好的,您对哪方面的美容项目比较感兴趣呢?"},
{"speaker": "customer", "content": "我想改善面部松弛的问题,最近感觉皮肤没有以前紧致了"},
{"speaker": "consultant", "content": "我理解您的困扰。请问您多大年龄?平时有做面部护理吗?"},
{"speaker": "customer", "content": "我35岁平时就是用护肤品没做过专业护理"},
{"speaker": "consultant", "content": "明白了。35岁开始注重抗衰是很及时的。我们有几种方案比如射频紧肤、超声刀提拉还有胶原蛋白再生项目"},
{"speaker": "customer", "content": "这几种有什么区别吗?"},
{"speaker": "consultant", "content": "射频主要是刺激胶原蛋白增生,效果温和持久。超声刀作用更深层,提拉效果更明显但价格稍高。我建议您先做个皮肤检测,看具体适合哪种"},
{"speaker": "customer", "content": "好的,那先做个检测吧"}
],
[
{"speaker": "consultant", "content": "您好,欢迎光临,我是美容顾问晓雯,请问您是第一次来吗?"},
{"speaker": "customer", "content": "是的,朋友推荐过来看看"},
{"speaker": "consultant", "content": "太好了,请问您朋友是做的什么项目呢?"},
{"speaker": "customer", "content": "她做的好像是什么水光针"},
{"speaker": "consultant", "content": "水光针确实是我们很受欢迎的项目。请问您今天主要想了解哪方面呢?"},
{"speaker": "customer", "content": "我主要是皮肤有点粗糙,毛孔也大"},
{"speaker": "consultant", "content": "嗯,针对毛孔粗大和皮肤粗糙,水光针确实有不错的效果。不过我建议先看看您的具体情况"},
{"speaker": "customer", "content": "需要检查吗?"},
{"speaker": "consultant", "content": "是的,我们有专业的皮肤检测仪,可以看到肉眼看不到的皮肤问题,这样制定方案更精准"},
{"speaker": "customer", "content": "好的,那检查一下吧"},
{"speaker": "consultant", "content": "好的请这边来检查大概需要5分钟"}
]
]
return random.choice(templates)
def _long_conversation_template(self) -> List[Dict]:
"""长对话模板(>5分钟- 15-20轮对话完整销售流程"""
templates = [
[
{"speaker": "consultant", "content": "您好,欢迎光临曼尼斐绮,我是资深美容顾问晓雯,请问怎么称呼您?"},
{"speaker": "customer", "content": "你好,我姓陈"},
{"speaker": "consultant", "content": "陈女士您好,看您气色很好,平时应该很注重保养吧?"},
{"speaker": "customer", "content": "还好吧,基本的护肤品会用"},
{"speaker": "consultant", "content": "这样啊。那今天是专程过来了解我们的项目,还是朋友推荐的呢?"},
{"speaker": "customer", "content": "我闺蜜在你们这做过,说效果不错,所以想来看看"},
{"speaker": "consultant", "content": "太好了,请问您闺蜜做的是什么项目呢?"},
{"speaker": "customer", "content": "好像是什么光子嫩肤"},
{"speaker": "consultant", "content": "明白了,光子嫩肤确实是我们的明星项目。不过每个人的皮肤状况不同,我先帮您做个详细的皮肤检测,看看最适合您的方案好吗?"},
{"speaker": "customer", "content": "好的"},
{"speaker": "consultant", "content": "陈女士通过检测我看到您的皮肤主要有三个问题一是T区毛孔粗大二是两颊有轻微色斑三是皮肤缺水。您平时有感觉到这些问题吗"},
{"speaker": "customer", "content": "对,毛孔确实有点大,色斑是最近才发现的"},
{"speaker": "consultant", "content": "嗯,这些问题如果不及时处理会越来越明显。针对您的情况,我建议做一个综合性的美白嫩肤方案"},
{"speaker": "customer", "content": "具体是怎么做的?"},
{"speaker": "consultant", "content": "我们采用光子嫩肤配合水光针的组合疗程。光子嫩肤主要解决色斑和毛孔问题,水光针补水锁水,效果相辅相成"},
{"speaker": "customer", "content": "听起来不错,大概需要多少钱?"},
{"speaker": "consultant", "content": "我们现在正好有活动光子嫩肤单次原价3800水光针单次2600组合套餐优惠后只要5800相当于打了九折"},
{"speaker": "customer", "content": "嗯...还是有点贵"},
{"speaker": "consultant", "content": "我理解您的顾虑。但是陈女士您想想这个价格是一次性投入效果却能维持3-6个月。平均下来每天不到30块钱换来的是皮肤的明显改善"},
{"speaker": "customer", "content": "这倒也是..."},
{"speaker": "consultant", "content": "而且这个活动就到本月底下个月恢复原价的话就要6400了。您今天如果确定的话我还可以帮您申请赠送一次基础补水护理"},
{"speaker": "customer", "content": "那行吧,今天就定了"},
{"speaker": "consultant", "content": "太好了!陈女士您做了个很明智的决定。我现在帮您预约最近的时间,您看周三下午方便吗?"}
],
[
{"speaker": "consultant", "content": "您好,欢迎光临,我是美容顾问小张,请问您贵姓?"},
{"speaker": "customer", "content": "我姓王"},
{"speaker": "consultant", "content": "王女士您好,请坐。今天想了解什么项目呢?"},
{"speaker": "customer", "content": "我想做个面部提升,感觉脸有点下垂了"},
{"speaker": "consultant", "content": "嗯,我看得出来您平时很注重保养。请问您今年多大年龄?"},
{"speaker": "customer", "content": "我42了"},
{"speaker": "consultant", "content": "42岁这个年龄段确实容易出现轻微松弛。您之前有做过抗衰项目吗"},
{"speaker": "customer", "content": "做过几次普通的面部护理,但感觉效果不明显"},
{"speaker": "consultant", "content": "普通护理主要是表层保养,对于松弛问题作用有限。您的情况需要更深层的治疗"},
{"speaker": "customer", "content": "那有什么好的方案吗?"},
{"speaker": "consultant", "content": "针对您的情况,我推荐热玛吉或者超声刀。这两种都是通过热能刺激深层胶原蛋白重组,达到紧致提升的效果"},
{"speaker": "customer", "content": "这两种有什么区别?"},
{"speaker": "consultant", "content": "热玛吉作用在真皮层,效果更自然持久,适合轻中度松弛。超声刀能到达筋膜层,提拉力度更强,适合松弛比较明显的情况"},
{"speaker": "customer", "content": "我的情况适合哪种?"},
{"speaker": "consultant", "content": "从您的面部状况来看,我建议选择热玛吉。您的松弛程度属于轻度,热玛吉的效果会更自然,恢复期也更短"},
{"speaker": "customer", "content": "费用大概多少?"},
{"speaker": "consultant", "content": "热玛吉全脸的话我们的价格是28800元。不过您今天来的时机很好我们正在做周年庆活动可以优惠到23800"},
{"speaker": "customer", "content": "还是挺贵的啊"},
{"speaker": "consultant", "content": "王女士我理解您的感受。但是热玛吉一次治疗效果可以维持2-3年平均每天只要20多块钱。而且这是一次性投入不需要反复做"},
{"speaker": "customer", "content": "效果真的能维持那么久吗?"},
{"speaker": "consultant", "content": "这是有科学依据的。热玛吉刺激的是您自身的胶原蛋白再生,不是外来填充,所以效果持久自然。我们有很多客户都做过,反馈都很好"},
{"speaker": "customer", "content": "那我考虑一下吧"},
{"speaker": "consultant", "content": "可以的。不过这个活动优惠就到本周日,下周就恢复原价了。而且名额有限,您要是确定的话最好尽快预约"},
{"speaker": "customer", "content": "好吧,那我今天就定下来吧"}
]
]
return random.choice(templates)