feat: 初始化考培练系统项目

- 从服务器拉取完整代码
- 按框架规范整理项目结构
- 配置 Drone CI 测试环境部署
- 包含后端(FastAPI)、前端(Vue3)、管理端

技术栈: Vue3 + TypeScript + FastAPI + MySQL
This commit is contained in:
111
2026-01-24 19:33:28 +08:00
commit 998211c483
1197 changed files with 228429 additions and 0 deletions

1
backend/app/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""考培练系统后端应用包"""

View File

@@ -0,0 +1 @@
# API 路由模块

View File

@@ -0,0 +1,497 @@
openapi: 3.0.0
info:
title: 课程管理模块API契约
description: 定义课程管理模块对外提供的所有API接口
version: 1.0.0
servers:
- url: http://localhost:8000/api/v1
description: 本地开发服务器
paths:
/courses:
get:
summary: 获取课程列表
description: 支持分页和多条件筛选
operationId: getCourses
security:
- bearerAuth: []
parameters:
- name: page
in: query
schema:
type: integer
minimum: 1
default: 1
- name: size
in: query
schema:
type: integer
minimum: 1
maximum: 100
default: 20
- name: status
in: query
schema:
type: string
enum: [draft, published, archived]
- name: category
in: query
schema:
type: string
enum: [technology, management, business, general]
- name: is_featured
in: query
schema:
type: boolean
- name: keyword
in: query
schema:
type: string
responses:
"200":
description: 成功获取课程列表
content:
application/json:
schema:
$ref: "#/components/schemas/CoursePageResponse"
"401":
$ref: "#/components/responses/UnauthorizedError"
post:
summary: 创建课程
description: 创建新课程(需要管理员权限)
operationId: createCourse
security:
- bearerAuth: []
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CourseCreate"
responses:
"201":
description: 成功创建课程
content:
application/json:
schema:
$ref: "#/components/schemas/CourseResponse"
"401":
$ref: "#/components/responses/UnauthorizedError"
"403":
$ref: "#/components/responses/ForbiddenError"
"409":
$ref: "#/components/responses/ConflictError"
/courses/{courseId}:
get:
summary: 获取课程详情
operationId: getCourse
security:
- bearerAuth: []
parameters:
- name: courseId
in: path
required: true
schema:
type: integer
responses:
"200":
description: 成功获取课程详情
content:
application/json:
schema:
$ref: "#/components/schemas/CourseResponse"
"401":
$ref: "#/components/responses/UnauthorizedError"
"404":
$ref: "#/components/responses/NotFoundError"
put:
summary: 更新课程
description: 更新课程信息(需要管理员权限)
operationId: updateCourse
security:
- bearerAuth: []
parameters:
- name: courseId
in: path
required: true
schema:
type: integer
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/CourseUpdate"
responses:
"200":
description: 成功更新课程
content:
application/json:
schema:
$ref: "#/components/schemas/CourseResponse"
"401":
$ref: "#/components/responses/UnauthorizedError"
"403":
$ref: "#/components/responses/ForbiddenError"
"404":
$ref: "#/components/responses/NotFoundError"
delete:
summary: 删除课程
description: 软删除课程(需要管理员权限)
operationId: deleteCourse
security:
- bearerAuth: []
parameters:
- name: courseId
in: path
required: true
schema:
type: integer
responses:
"200":
description: 成功删除课程
content:
application/json:
schema:
$ref: "#/components/schemas/DeleteResponse"
"400":
$ref: "#/components/responses/BadRequestError"
"401":
$ref: "#/components/responses/UnauthorizedError"
"403":
$ref: "#/components/responses/ForbiddenError"
"404":
$ref: "#/components/responses/NotFoundError"
/courses/{courseId}/knowledge-points:
get:
summary: 获取课程知识点列表
operationId: getCourseKnowledgePoints
security:
- bearerAuth: []
parameters:
- name: courseId
in: path
required: true
schema:
type: integer
- name: parent_id
in: query
schema:
type: integer
nullable: true
responses:
"200":
description: 成功获取知识点列表
content:
application/json:
schema:
$ref: "#/components/schemas/KnowledgePointListResponse"
"401":
$ref: "#/components/responses/UnauthorizedError"
"404":
$ref: "#/components/responses/NotFoundError"
components:
securitySchemes:
bearerAuth:
type: http
scheme: bearer
bearerFormat: JWT
schemas:
ResponseBase:
type: object
required:
- code
- message
properties:
code:
type: integer
default: 200
message:
type: string
request_id:
type: string
timestamp:
type: string
format: date-time
CourseBase:
type: object
properties:
name:
type: string
minLength: 1
maxLength: 200
description:
type: string
category:
type: string
enum: [technology, management, business, general]
default: general
cover_image:
type: string
maxLength: 500
duration_hours:
type: number
format: float
minimum: 0
difficulty_level:
type: integer
minimum: 1
maximum: 5
tags:
type: array
items:
type: string
sort_order:
type: integer
default: 0
is_featured:
type: boolean
default: false
CourseCreate:
allOf:
- $ref: "#/components/schemas/CourseBase"
- type: object
required:
- name
properties:
status:
type: string
enum: [draft, published, archived]
default: draft
CourseUpdate:
allOf:
- $ref: "#/components/schemas/CourseBase"
- type: object
properties:
status:
type: string
enum: [draft, published, archived]
Course:
allOf:
- $ref: "#/components/schemas/CourseBase"
- type: object
required:
- id
- status
- created_at
- updated_at
properties:
id:
type: integer
status:
type: string
enum: [draft, published, archived]
created_at:
type: string
format: date-time
updated_at:
type: string
format: date-time
published_at:
type: string
format: date-time
nullable: true
publisher_id:
type: integer
nullable: true
created_by:
type: integer
nullable: true
updated_by:
type: integer
nullable: true
CourseResponse:
allOf:
- $ref: "#/components/schemas/ResponseBase"
- type: object
properties:
data:
$ref: "#/components/schemas/Course"
CoursePageResponse:
allOf:
- $ref: "#/components/schemas/ResponseBase"
- type: object
properties:
data:
type: object
required:
- items
- total
- page
- size
- pages
properties:
items:
type: array
items:
$ref: "#/components/schemas/Course"
total:
type: integer
page:
type: integer
size:
type: integer
pages:
type: integer
KnowledgePoint:
type: object
required:
- id
- course_id
- name
- level
- created_at
- updated_at
properties:
id:
type: integer
course_id:
type: integer
name:
type: string
maxLength: 200
description:
type: string
parent_id:
type: integer
nullable: true
level:
type: integer
path:
type: string
nullable: true
sort_order:
type: integer
weight:
type: number
format: float
is_required:
type: boolean
estimated_hours:
type: number
format: float
nullable: true
created_at:
type: string
format: date-time
updated_at:
type: string
format: date-time
KnowledgePointListResponse:
allOf:
- $ref: "#/components/schemas/ResponseBase"
- type: object
properties:
data:
type: array
items:
$ref: "#/components/schemas/KnowledgePoint"
DeleteResponse:
allOf:
- $ref: "#/components/schemas/ResponseBase"
- type: object
properties:
data:
type: boolean
ErrorDetail:
type: object
required:
- message
properties:
message:
type: string
error_code:
type: string
field:
type: string
details:
type: object
responses:
BadRequestError:
description: 请求参数错误
content:
application/json:
schema:
allOf:
- $ref: "#/components/schemas/ResponseBase"
- type: object
properties:
code:
example: 400
detail:
$ref: "#/components/schemas/ErrorDetail"
UnauthorizedError:
description: 未认证
content:
application/json:
schema:
allOf:
- $ref: "#/components/schemas/ResponseBase"
- type: object
properties:
code:
example: 401
detail:
$ref: "#/components/schemas/ErrorDetail"
ForbiddenError:
description: 权限不足
content:
application/json:
schema:
allOf:
- $ref: "#/components/schemas/ResponseBase"
- type: object
properties:
code:
example: 403
detail:
$ref: "#/components/schemas/ErrorDetail"
NotFoundError:
description: 资源不存在
content:
application/json:
schema:
allOf:
- $ref: "#/components/schemas/ResponseBase"
- type: object
properties:
code:
example: 404
detail:
$ref: "#/components/schemas/ErrorDetail"
ConflictError:
description: 资源冲突
content:
application/json:
schema:
allOf:
- $ref: "#/components/schemas/ResponseBase"
- type: object
properties:
code:
example: 409
detail:
$ref: "#/components/schemas/ErrorDetail"

View File

@@ -0,0 +1,105 @@
"""
API v1 版本模块
整合所有 v1 版本的路由
"""
from fastapi import APIRouter
# 先只导入必要的路由
from .coze_gateway import router as coze_router
# 创建 v1 版本的主路由
api_router = APIRouter()
# 包含各个子路由
api_router.include_router(coze_router, tags=["coze"])
# TODO: 逐步添加其他路由
from .auth import router as auth_router
from .courses import router as courses_router
from .users import router as users_router
from .training import router as training_router
from .admin import router as admin_router
from .positions import router as positions_router
from .upload import router as upload_router
from .teams import router as teams_router
from .knowledge_analysis import router as knowledge_analysis_router
from .system import router as system_router
from .sql_executor import router as sql_executor_router
from .exam import router as exam_router
from .practice import router as practice_router
from .course_chat import router as course_chat_router
from .broadcast import router as broadcast_router
from .preview import router as preview_router
from .yanji import router as yanji_router
from .ability import router as ability_router
from .statistics import router as statistics_router
from .team_dashboard import router as team_dashboard_router
from .team_management import router as team_management_router
# Manager 模块路由
from .manager import student_scores_router, student_practice_router
from .system_logs import router as system_logs_router
from .tasks import router as tasks_router
from .endpoints.employee_sync import router as employee_sync_router
from .notifications import router as notifications_router
from .scrm import router as scrm_router
# 管理后台路由
from .admin_portal import router as admin_portal_router
api_router.include_router(auth_router, prefix="/auth", tags=["auth"])
# courses_router 已在内部定义了 prefix="/courses",此处不再额外添加前缀
api_router.include_router(courses_router, tags=["courses"])
api_router.include_router(users_router, prefix="/users", tags=["users"])
# training_router 已在内部定义了 prefix="/training",此处不再额外添加前缀
api_router.include_router(training_router, tags=["training"])
# admin_router 已在内部定义了 prefix="/admin",此处不再额外添加前缀
api_router.include_router(admin_router, tags=["admin"])
api_router.include_router(positions_router, tags=["positions"])
# upload_router 已在内部定义了 prefix="/upload",此处不再额外添加前缀
api_router.include_router(upload_router, tags=["upload"])
api_router.include_router(teams_router, tags=["teams"])
# knowledge_analysis_router 不需要额外前缀,路径已在路由中定义
api_router.include_router(knowledge_analysis_router, tags=["knowledge-analysis"])
# system_router 已在内部定义了 prefix="/system",此处不再额外添加前缀
api_router.include_router(system_router, tags=["system"])
# sql_executor_router SQL 执行器
api_router.include_router(sql_executor_router, prefix="/sql", tags=["sql-executor"])
# exam_router 已在内部定义了 prefix="/exams",此处不再额外添加前缀
api_router.include_router(exam_router, tags=["exams"])
# practice_router 陪练功能路由
api_router.include_router(practice_router, prefix="/practice", tags=["practice"])
# course_chat_router 与课程对话路由
api_router.include_router(course_chat_router, prefix="/course", tags=["course-chat"])
# broadcast_router 播课功能路由不添加prefix路径在router内部定义
api_router.include_router(broadcast_router, tags=["broadcast"])
# preview_router 文件预览路由
api_router.include_router(preview_router, prefix="/preview", tags=["preview"])
# yanji_router 言迹智能工牌路由
api_router.include_router(yanji_router, prefix="/yanji", tags=["yanji"])
# ability_router 能力评估路由
api_router.include_router(ability_router, prefix="/ability", tags=["ability"])
# statistics_router 统计分析路由不添加prefix路径在router内部定义
api_router.include_router(statistics_router, tags=["statistics"])
# team_dashboard_router 团队看板路由不添加prefix路径在router内部定义为/team/dashboard
api_router.include_router(team_dashboard_router, tags=["team-dashboard"])
# team_management_router 团队成员管理路由不添加prefix路径在router内部定义为/team/management
api_router.include_router(team_management_router, tags=["team-management"])
# student_scores_router 学员考试成绩管理路由不添加prefix路径在router内部定义为/manager/student-scores
api_router.include_router(student_scores_router, tags=["manager-student-scores"])
# student_practice_router 学员陪练记录管理路由不添加prefix路径在router内部定义为/manager/student-practice
api_router.include_router(student_practice_router, tags=["manager-student-practice"])
# system_logs_router 系统日志路由不添加prefix路径在router内部定义为/admin/logs
api_router.include_router(system_logs_router, tags=["system-logs"])
# tasks_router 任务管理路由不添加prefix路径在router内部定义为/manager/tasks
api_router.include_router(tasks_router, tags=["tasks"])
# employee_sync_router 员工同步路由
api_router.include_router(employee_sync_router, prefix="/employee-sync", tags=["employee-sync"])
# notifications_router 站内消息通知路由不添加prefix路径在router内部定义为/notifications
api_router.include_router(notifications_router, tags=["notifications"])
# scrm_router SCRM系统对接路由prefix在router内部定义为/scrm
api_router.include_router(scrm_router, tags=["scrm"])
# admin_portal_router SaaS超级管理后台路由prefix在router内部定义为/admin
api_router.include_router(admin_portal_router, tags=["admin-portal"])
__all__ = ["api_router"]

View File

@@ -0,0 +1,187 @@
"""
能力评估API接口
用于智能工牌数据分析、能力评估报告生成等
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.core.deps import get_current_user, get_db
from app.models.user import User
from app.schemas.base import ResponseModel
from app.schemas.ability import AbilityAssessmentResponse, AbilityAssessmentHistory
from app.services.yanji_service import YanjiService
from app.services.ability_assessment_service import get_ability_assessment_service
import logging
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/analyze-yanji", response_model=ResponseModel)
async def analyze_yanji_badge_data(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
分析智能工牌数据生成能力评估和课程推荐
使用 Python 原生 AI 服务实现。
功能说明:
1. 从言迹智能工牌获取员工的最近10条录音记录
2. 分析对话数据进行能力评估6个维度
3. 基于能力短板生成课程推荐3-5门
4. 保存评估记录到数据库
要求:
- 用户必须已绑定手机号(用于匹配言迹数据)
返回:
- assessment_id: 评估记录ID
- total_score: 综合评分0-100
- dimensions: 能力维度列表6个维度
- recommended_courses: 推荐课程列表3-5门
- conversation_count: 分析的对话数量
"""
# 检查用户是否绑定手机号
if not current_user.phone:
logger.warning(f"用户未绑定手机号: user_id={current_user.id}")
raise HTTPException(
status_code=400,
detail="用户未绑定手机号,无法匹配言迹数据"
)
# 获取服务实例
yanji_service = YanjiService()
assessment_service = get_ability_assessment_service()
try:
logger.info(
f"开始分析智能工牌数据: user_id={current_user.id}, "
f"phone={current_user.phone}"
)
# 调用能力评估服务(使用 Python 原生实现)
result = await assessment_service.analyze_yanji_conversations(
user_id=current_user.id,
phone=current_user.phone,
db=db,
yanji_service=yanji_service,
engine="v2" # 固定使用 V2
)
logger.info(
f"智能工牌数据分析完成: user_id={current_user.id}, "
f"assessment_id={result['assessment_id']}, "
f"total_score={result['total_score']}"
)
return ResponseModel(
code=200,
message="智能工牌数据分析完成",
data=result
)
except ValueError as e:
# 业务逻辑错误(如未找到录音记录)
logger.warning(f"智能工牌数据分析失败: {e}")
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
# 系统错误
logger.error(f"分析智能工牌数据失败: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"分析失败: {str(e)}"
)
@router.get("/history", response_model=ResponseModel)
async def get_assessment_history(
limit: int = Query(default=10, ge=1, le=50, description="返回记录数量"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取用户的能力评估历史记录
参数:
- limit: 返回记录数量默认10最大50
返回:
- 评估历史记录列表
"""
assessment_service = get_ability_assessment_service()
try:
history = await assessment_service.get_user_assessment_history(
user_id=current_user.id,
db=db,
limit=limit
)
return ResponseModel(
code=200,
message=f"获取评估历史成功,共{len(history)}",
data={"history": history, "total": len(history)}
)
except Exception as e:
logger.error(f"获取评估历史失败: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"获取评估历史失败: {str(e)}"
)
@router.get("/{assessment_id}", response_model=ResponseModel)
async def get_assessment_detail(
assessment_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取单个评估记录的详细信息
参数:
- assessment_id: 评估记录ID
返回:
- 评估详细信息
"""
assessment_service = get_ability_assessment_service()
try:
detail = await assessment_service.get_assessment_detail(
assessment_id=assessment_id,
db=db
)
# 权限检查:只能查看自己的评估记录
if detail['user_id'] != current_user.id:
raise HTTPException(
status_code=403,
detail="无权访问该评估记录"
)
return ResponseModel(
code=200,
message="获取评估详情成功",
data=detail
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except HTTPException:
raise
except Exception as e:
logger.error(f"获取评估详情失败: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"获取评估详情失败: {str(e)}"
)

509
backend/app/api/v1/admin.py Normal file
View File

@@ -0,0 +1,509 @@
"""
管理员相关API路由
"""
from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from app.core.deps import get_current_active_user as get_current_user, get_db
from app.models.user import User
from app.models.course import Course, CourseStatus
from app.schemas.base import ResponseModel
router = APIRouter(prefix="/admin")
@router.get("/dashboard/stats")
async def get_dashboard_stats(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
) -> ResponseModel:
"""
获取管理员仪表盘统计数据
需要管理员权限
"""
# 权限检查
if current_user.role != "admin":
return ResponseModel(
code=403,
message="权限不足,需要管理员权限"
)
# 用户统计
total_users = await db.scalar(select(func.count(User.id)))
# 计算最近30天的新增用户
thirty_days_ago = datetime.now() - timedelta(days=30)
new_users_count = await db.scalar(
select(func.count(User.id))
.where(User.created_at >= thirty_days_ago)
)
# 计算增长率假设上个月也是30天
sixty_days_ago = datetime.now() - timedelta(days=60)
last_month_users = await db.scalar(
select(func.count(User.id))
.where(User.created_at >= sixty_days_ago)
.where(User.created_at < thirty_days_ago)
)
growth_rate = 0.0
if last_month_users > 0:
growth_rate = ((new_users_count - last_month_users) / last_month_users) * 100
# 课程统计
total_courses = await db.scalar(
select(func.count(Course.id))
.where(Course.status == CourseStatus.PUBLISHED)
)
# TODO: 完成的课程数需要根据用户课程进度表计算
completed_courses = 0 # 暂时设为0
# 考试统计(如果有考试表的话)
total_exams = 0
avg_score = 0.0
pass_rate = "0%"
# 学习时长统计(如果有学习记录表的话)
total_learning_hours = 0
avg_learning_hours = 0.0
active_rate = "0%"
# 构建响应数据
stats = {
"users": {
"total": total_users,
"growth": new_users_count,
"growthRate": f"{growth_rate:.1f}%"
},
"courses": {
"total": total_courses,
"completed": completed_courses,
"completionRate": f"{(completed_courses / total_courses * 100) if total_courses > 0 else 0:.1f}%"
},
"exams": {
"total": total_exams,
"avgScore": avg_score,
"passRate": pass_rate
},
"learning": {
"totalHours": total_learning_hours,
"avgHours": avg_learning_hours,
"activeRate": active_rate
}
}
return ResponseModel(
code=200,
message="获取仪表盘统计数据成功",
data=stats
)
@router.get("/dashboard/user-growth")
async def get_user_growth_data(
days: int = Query(30, description="统计天数", ge=7, le=90),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
) -> ResponseModel:
"""
获取用户增长数据
Args:
days: 统计天数默认30天
需要管理员权限
"""
# 权限检查
if current_user.role != "admin":
return ResponseModel(
code=403,
message="权限不足,需要管理员权限"
)
# 准备日期列表
dates = []
new_users = []
active_users = []
end_date = datetime.now().date()
for i in range(days):
current_date = end_date - timedelta(days=days-1-i)
dates.append(current_date.strftime("%Y-%m-%d"))
# 统计当天新增用户
next_date = current_date + timedelta(days=1)
new_count = await db.scalar(
select(func.count(User.id))
.where(func.date(User.created_at) == current_date)
)
new_users.append(new_count or 0)
# 统计当天活跃用户(有登录记录)
active_count = await db.scalar(
select(func.count(User.id))
.where(func.date(User.last_login_at) == current_date)
)
active_users.append(active_count or 0)
return ResponseModel(
code=200,
message="获取用户增长数据成功",
data={
"dates": dates,
"newUsers": new_users,
"activeUsers": active_users
}
)
@router.get("/dashboard/course-completion")
async def get_course_completion_data(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
) -> ResponseModel:
"""
获取课程完成率数据
需要管理员权限
"""
# 权限检查
if current_user.role != "admin":
return ResponseModel(
code=403,
message="权限不足,需要管理员权限"
)
# 获取所有已发布的课程
courses_result = await db.execute(
select(Course.name, Course.id)
.where(Course.status == CourseStatus.PUBLISHED)
.order_by(Course.sort_order, Course.id)
.limit(10) # 限制显示前10个课程
)
courses = courses_result.all()
course_names = []
completion_rates = []
for course_name, course_id in courses:
course_names.append(course_name)
# TODO: 根据用户课程进度表计算完成率
# 这里暂时生成模拟数据
import random
completion_rate = random.randint(60, 95)
completion_rates.append(completion_rate)
return ResponseModel(
code=200,
message="获取课程完成率数据成功",
data={
"courses": course_names,
"completionRates": completion_rates
}
)
# ===== 岗位管理(最小可用 stub 版本)=====
def _ensure_admin(user: User) -> Optional[ResponseModel]:
if user.role != "admin":
return ResponseModel(code=403, message="权限不足,需要管理员权限")
return None
# 注意positions相关路由已移至positions.py
# _sample_positions函数和所有positions路由已删除避免与positions.py冲突
# ===== 用户批量操作 =====
from pydantic import BaseModel
from app.models.position_member import PositionMember
class BatchUserOperation(BaseModel):
"""批量用户操作请求模型"""
ids: List[int]
action: str # delete, activate, deactivate, change_role, assign_position, assign_team
value: Optional[Any] = None # 角色值、岗位ID、团队ID等
@router.post("/users/batch", response_model=ResponseModel)
async def batch_user_operation(
operation: BatchUserOperation,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
) -> ResponseModel:
"""
批量用户操作
支持的操作类型:
- delete: 批量删除用户(软删除)
- activate: 批量启用用户
- deactivate: 批量禁用用户
- change_role: 批量修改角色(需要 value 参数)
- assign_position: 批量分配岗位(需要 value 参数为岗位ID
- assign_team: 批量分配团队(需要 value 参数为团队ID
权限:需要管理员权限
"""
# 权限检查
if current_user.role != "admin":
return ResponseModel(
code=403,
message="权限不足,需要管理员权限"
)
if not operation.ids:
return ResponseModel(
code=400,
message="请选择要操作的用户"
)
# 不能操作自己
if current_user.id in operation.ids:
return ResponseModel(
code=400,
message="不能对自己执行批量操作"
)
# 获取要操作的用户
result = await db.execute(
select(User).where(User.id.in_(operation.ids), User.is_deleted == False)
)
users = result.scalars().all()
if not users:
return ResponseModel(
code=404,
message="未找到要操作的用户"
)
success_count = 0
failed_count = 0
errors = []
try:
if operation.action == "delete":
# 批量软删除
for user in users:
try:
user.is_deleted = True
user.deleted_at = datetime.now()
success_count += 1
except Exception as e:
failed_count += 1
errors.append(f"删除用户 {user.username} 失败: {str(e)}")
await db.commit()
elif operation.action == "activate":
# 批量启用
for user in users:
try:
user.is_active = True
success_count += 1
except Exception as e:
failed_count += 1
errors.append(f"启用用户 {user.username} 失败: {str(e)}")
await db.commit()
elif operation.action == "deactivate":
# 批量禁用
for user in users:
try:
user.is_active = False
success_count += 1
except Exception as e:
failed_count += 1
errors.append(f"禁用用户 {user.username} 失败: {str(e)}")
await db.commit()
elif operation.action == "change_role":
# 批量修改角色
if not operation.value:
return ResponseModel(
code=400,
message="请指定要修改的角色"
)
valid_roles = ["trainee", "manager", "admin"]
if operation.value not in valid_roles:
return ResponseModel(
code=400,
message=f"无效的角色,可选值: {', '.join(valid_roles)}"
)
for user in users:
try:
user.role = operation.value
success_count += 1
except Exception as e:
failed_count += 1
errors.append(f"修改用户 {user.username} 角色失败: {str(e)}")
await db.commit()
elif operation.action == "assign_position":
# 批量分配岗位
if not operation.value:
return ResponseModel(
code=400,
message="请指定要分配的岗位ID"
)
position_id = int(operation.value)
# 获取岗位信息用于通知
from app.models.position import Position
position_result = await db.execute(
select(Position).where(Position.id == position_id)
)
position = position_result.scalar_one_or_none()
position_name = position.name if position else "未知岗位"
# 记录新分配成功的用户ID用于发送通知
newly_assigned_user_ids = []
for user in users:
try:
# 检查是否已有该岗位
existing = await db.execute(
select(PositionMember).where(
PositionMember.user_id == user.id,
PositionMember.position_id == position_id,
PositionMember.is_deleted == False
)
)
if existing.scalar_one_or_none():
# 已有该岗位,跳过
success_count += 1
continue
# 添加岗位关联PositionMember模型没有created_by字段
member = PositionMember(
position_id=position_id,
user_id=user.id,
joined_at=datetime.now()
)
db.add(member)
newly_assigned_user_ids.append(user.id)
success_count += 1
except Exception as e:
failed_count += 1
errors.append(f"为用户 {user.username} 分配岗位失败: {str(e)}")
await db.commit()
# 发送岗位分配通知给新分配的用户
if newly_assigned_user_ids:
try:
from app.services.notification_service import notification_service
from app.schemas.notification import NotificationBatchCreate, NotificationType
notification_batch = NotificationBatchCreate(
user_ids=newly_assigned_user_ids,
title="岗位分配通知",
content=f"您已被分配到「{position_name}」岗位,请查看相关培训课程。",
type=NotificationType.POSITION_ASSIGN,
related_id=position_id,
related_type="position",
sender_id=current_user.id
)
await notification_service.batch_create_notifications(
db=db,
batch_in=notification_batch
)
except Exception as e:
# 通知发送失败不影响岗位分配结果
import logging
logging.getLogger(__name__).error(f"发送岗位分配通知失败: {str(e)}")
elif operation.action == "assign_team":
# 批量分配团队
if not operation.value:
return ResponseModel(
code=400,
message="请指定要分配的团队ID"
)
from app.models.user import user_teams
team_id = int(operation.value)
for user in users:
try:
# 检查是否已在该团队
existing = await db.execute(
select(user_teams).where(
user_teams.c.user_id == user.id,
user_teams.c.team_id == team_id
)
)
if existing.first():
# 已在该团队,跳过
success_count += 1
continue
# 添加团队关联
await db.execute(
user_teams.insert().values(
user_id=user.id,
team_id=team_id,
role="member",
joined_at=datetime.now()
)
)
success_count += 1
except Exception as e:
failed_count += 1
errors.append(f"为用户 {user.username} 分配团队失败: {str(e)}")
await db.commit()
else:
return ResponseModel(
code=400,
message=f"不支持的操作类型: {operation.action}"
)
# 返回结果
action_names = {
"delete": "删除",
"activate": "启用",
"deactivate": "禁用",
"change_role": "修改角色",
"assign_position": "分配岗位",
"assign_team": "分配团队"
}
action_name = action_names.get(operation.action, operation.action)
return ResponseModel(
code=200,
message=f"批量{action_name}完成:成功 {success_count} 个,失败 {failed_count}",
data={
"success_count": success_count,
"failed_count": failed_count,
"errors": errors
}
)
except Exception as e:
await db.rollback()
return ResponseModel(
code=500,
message=f"批量操作失败: {str(e)}"
)

View File

@@ -0,0 +1,24 @@
"""
SaaS 超级管理后台 API
提供租户管理、配置管理、提示词管理等功能
"""
from fastapi import APIRouter
from .auth import router as auth_router
from .tenants import router as tenants_router
from .configs import router as configs_router
from .prompts import router as prompts_router
from .features import router as features_router
# 创建管理后台主路由
router = APIRouter(prefix="/admin", tags=["管理后台"])
# 注册子路由
router.include_router(auth_router)
router.include_router(tenants_router)
router.include_router(configs_router)
router.include_router(prompts_router)
router.include_router(features_router)

View File

@@ -0,0 +1,277 @@
"""
管理员认证 API
"""
import os
from datetime import datetime, timedelta
from typing import Optional
import jwt
import pymysql
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from passlib.context import CryptContext
from .schemas import (
AdminLoginRequest,
AdminLoginResponse,
AdminUserInfo,
AdminChangePasswordRequest,
ResponseModel,
)
router = APIRouter(prefix="/auth", tags=["管理员认证"])
# 密码加密
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# JWT 配置
SECRET_KEY = os.getenv("ADMIN_JWT_SECRET", "admin-secret-key-kaopeilian-2026")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24
# 安全认证
security = HTTPBearer()
# 管理库连接配置
ADMIN_DB_CONFIG = {
"host": os.getenv("ADMIN_DB_HOST", "prod-mysql"),
"port": int(os.getenv("ADMIN_DB_PORT", "3306")),
"user": os.getenv("ADMIN_DB_USER", "root"),
"password": os.getenv("ADMIN_DB_PASSWORD", "ProdMySQL2025!@#"),
"db": os.getenv("ADMIN_DB_NAME", "kaopeilian_admin"),
"charset": "utf8mb4",
}
def get_db_connection():
"""获取数据库连接"""
return pymysql.connect(**ADMIN_DB_CONFIG, cursorclass=pymysql.cursors.DictCursor)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""获取密码哈希"""
return pwd_context.hash(password)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
"""创建访问令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def decode_access_token(token: str) -> dict:
"""解码访问令牌"""
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
return payload
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token已过期",
)
except jwt.InvalidTokenError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的Token",
)
async def get_current_admin(
credentials: HTTPAuthorizationCredentials = Depends(security)
) -> AdminUserInfo:
"""获取当前登录的管理员"""
token = credentials.credentials
payload = decode_access_token(token)
admin_id = payload.get("sub")
if not admin_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的Token",
)
conn = get_db_connection()
try:
with conn.cursor() as cursor:
cursor.execute(
"""
SELECT id, username, email, full_name, role, is_active, last_login_at
FROM admin_users WHERE id = %s
""",
(admin_id,)
)
admin = cursor.fetchone()
if not admin:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="管理员不存在",
)
if not admin["is_active"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="账户已被禁用",
)
return AdminUserInfo(
id=admin["id"],
username=admin["username"],
email=admin["email"],
full_name=admin["full_name"],
role=admin["role"],
last_login_at=admin["last_login_at"],
)
finally:
conn.close()
async def require_superadmin(
admin: AdminUserInfo = Depends(get_current_admin)
) -> AdminUserInfo:
"""要求超级管理员权限"""
if admin.role != "superadmin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="需要超级管理员权限",
)
return admin
@router.post("/login", response_model=AdminLoginResponse, summary="管理员登录")
async def admin_login(request: Request, login_data: AdminLoginRequest):
"""
管理员登录
- **username**: 用户名
- **password**: 密码
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 查询管理员
cursor.execute(
"""
SELECT id, username, email, full_name, role, password_hash, is_active, last_login_at
FROM admin_users WHERE username = %s
""",
(login_data.username,)
)
admin = cursor.fetchone()
if not admin:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
)
if not admin["is_active"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="账户已被禁用",
)
# 验证密码
if not verify_password(login_data.password, admin["password_hash"]):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
)
# 更新最后登录时间和IP
client_ip = request.client.host if request.client else None
cursor.execute(
"""
UPDATE admin_users
SET last_login_at = NOW(), last_login_ip = %s
WHERE id = %s
""",
(client_ip, admin["id"])
)
conn.commit()
# 创建 Token
access_token = create_access_token(
data={"sub": str(admin["id"]), "username": admin["username"], "role": admin["role"]}
)
return AdminLoginResponse(
access_token=access_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_EXPIRE_HOURS * 3600,
admin_user=AdminUserInfo(
id=admin["id"],
username=admin["username"],
email=admin["email"],
full_name=admin["full_name"],
role=admin["role"],
last_login_at=datetime.now(),
),
)
finally:
conn.close()
@router.get("/me", response_model=AdminUserInfo, summary="获取当前管理员信息")
async def get_me(admin: AdminUserInfo = Depends(get_current_admin)):
"""获取当前登录管理员的信息"""
return admin
@router.post("/change-password", response_model=ResponseModel, summary="修改密码")
async def change_password(
data: AdminChangePasswordRequest,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""
修改当前管理员密码
- **old_password**: 旧密码
- **new_password**: 新密码
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 验证旧密码
cursor.execute(
"SELECT password_hash FROM admin_users WHERE id = %s",
(admin.id,)
)
row = cursor.fetchone()
if not verify_password(data.old_password, row["password_hash"]):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="旧密码错误",
)
# 更新密码
new_hash = get_password_hash(data.new_password)
cursor.execute(
"UPDATE admin_users SET password_hash = %s WHERE id = %s",
(new_hash, admin.id)
)
conn.commit()
return ResponseModel(message="密码修改成功")
finally:
conn.close()
@router.post("/logout", response_model=ResponseModel, summary="退出登录")
async def admin_logout(admin: AdminUserInfo = Depends(get_current_admin)):
"""退出登录(客户端需清除 Token"""
return ResponseModel(message="退出成功")

View File

@@ -0,0 +1,480 @@
"""
配置管理 API
"""
import os
import json
from typing import Optional, List, Dict
import pymysql
from fastapi import APIRouter, Depends, HTTPException, status, Query
from .auth import get_current_admin, get_db_connection, AdminUserInfo
from .schemas import (
ConfigTemplateResponse,
TenantConfigResponse,
TenantConfigCreate,
TenantConfigUpdate,
TenantConfigGroupResponse,
ConfigBatchUpdate,
ResponseModel,
)
router = APIRouter(prefix="/configs", tags=["配置管理"])
# 配置分组显示名称
CONFIG_GROUP_NAMES = {
"database": "数据库配置",
"redis": "Redis配置",
"security": "安全配置",
"coze": "Coze配置",
"ai": "AI服务配置",
"yanji": "言迹工牌配置",
"storage": "文件存储配置",
"basic": "基础配置",
}
def log_operation(cursor, admin: AdminUserInfo, tenant_id: int, tenant_code: str,
operation_type: str, resource_type: str, resource_id: int,
resource_name: str, old_value: dict = None, new_value: dict = None):
"""记录操作日志"""
cursor.execute(
"""
INSERT INTO operation_logs
(admin_user_id, admin_username, tenant_id, tenant_code, operation_type,
resource_type, resource_id, resource_name, old_value, new_value)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(admin.id, admin.username, tenant_id, tenant_code, operation_type,
resource_type, resource_id, resource_name,
json.dumps(old_value, ensure_ascii=False) if old_value else None,
json.dumps(new_value, ensure_ascii=False) if new_value else None)
)
@router.get("/templates", response_model=List[ConfigTemplateResponse], summary="获取配置模板")
async def get_config_templates(
config_group: Optional[str] = Query(None, description="配置分组筛选"),
admin: AdminUserInfo = Depends(get_current_admin),
):
"""
获取配置模板列表
配置模板定义了所有可配置项的元数据
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
if config_group:
cursor.execute(
"""
SELECT * FROM config_templates
WHERE config_group = %s
ORDER BY sort_order, id
""",
(config_group,)
)
else:
cursor.execute(
"SELECT * FROM config_templates ORDER BY config_group, sort_order, id"
)
rows = cursor.fetchall()
result = []
for row in rows:
# 解析 options 字段
options = None
if row.get("options"):
try:
options = json.loads(row["options"])
except:
pass
result.append(ConfigTemplateResponse(
id=row["id"],
config_group=row["config_group"],
config_key=row["config_key"],
display_name=row["display_name"],
description=row["description"],
value_type=row["value_type"],
default_value=row["default_value"],
is_required=row["is_required"],
is_secret=row["is_secret"],
options=options,
sort_order=row["sort_order"],
))
return result
finally:
conn.close()
@router.get("/groups", response_model=List[Dict], summary="获取配置分组列表")
async def get_config_groups(admin: AdminUserInfo = Depends(get_current_admin)):
"""获取配置分组列表"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
cursor.execute(
"""
SELECT config_group, COUNT(*) as count
FROM config_templates
GROUP BY config_group
ORDER BY config_group
"""
)
rows = cursor.fetchall()
return [
{
"group_name": row["config_group"],
"group_display_name": CONFIG_GROUP_NAMES.get(row["config_group"], row["config_group"]),
"config_count": row["count"],
}
for row in rows
]
finally:
conn.close()
@router.get("/tenants/{tenant_id}", response_model=List[TenantConfigGroupResponse], summary="获取租户配置")
async def get_tenant_configs(
tenant_id: int,
config_group: Optional[str] = Query(None, description="配置分组筛选"),
admin: AdminUserInfo = Depends(get_current_admin),
):
"""
获取租户的所有配置
返回按分组整理的配置列表,包含模板信息
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 验证租户存在
cursor.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 查询配置模板和租户配置
group_filter = "AND ct.config_group = %s" if config_group else ""
params = [tenant_id, config_group] if config_group else [tenant_id]
cursor.execute(
f"""
SELECT
ct.config_group,
ct.config_key,
ct.display_name,
ct.description,
ct.value_type,
ct.default_value,
ct.is_required,
ct.is_secret,
ct.sort_order,
tc.id as config_id,
tc.config_value,
tc.is_encrypted,
tc.created_at,
tc.updated_at
FROM config_templates ct
LEFT JOIN tenant_configs tc
ON tc.config_group = ct.config_group
AND tc.config_key = ct.config_key
AND tc.tenant_id = %s
WHERE 1=1 {group_filter}
ORDER BY ct.config_group, ct.sort_order, ct.id
""",
params
)
rows = cursor.fetchall()
# 按分组整理
groups: Dict[str, List] = {}
for row in rows:
group = row["config_group"]
if group not in groups:
groups[group] = []
# 如果是敏感信息且有值,隐藏部分内容
config_value = row["config_value"]
if row["is_secret"] and config_value:
if len(config_value) > 8:
config_value = config_value[:4] + "****" + config_value[-4:]
else:
config_value = "****"
groups[group].append(TenantConfigResponse(
id=row["config_id"] or 0,
config_group=row["config_group"],
config_key=row["config_key"],
config_value=config_value if not row["is_secret"] else row["config_value"],
value_type=row["value_type"],
is_encrypted=row["is_encrypted"] or False,
description=row["description"],
created_at=row["created_at"] or None,
updated_at=row["updated_at"] or None,
display_name=row["display_name"],
is_required=row["is_required"],
is_secret=row["is_secret"],
))
return [
TenantConfigGroupResponse(
group_name=group,
group_display_name=CONFIG_GROUP_NAMES.get(group, group),
configs=configs,
)
for group, configs in groups.items()
]
finally:
conn.close()
@router.put("/tenants/{tenant_id}/{config_group}/{config_key}", response_model=ResponseModel, summary="更新单个配置")
async def update_tenant_config(
tenant_id: int,
config_group: str,
config_key: str,
data: TenantConfigUpdate,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""更新租户的单个配置项"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 验证租户存在
cursor.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 验证配置模板存在
cursor.execute(
"""
SELECT value_type, is_secret FROM config_templates
WHERE config_group = %s AND config_key = %s
""",
(config_group, config_key)
)
template = cursor.fetchone()
if not template:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效的配置项",
)
# 检查是否已有配置
cursor.execute(
"""
SELECT id, config_value FROM tenant_configs
WHERE tenant_id = %s AND config_group = %s AND config_key = %s
""",
(tenant_id, config_group, config_key)
)
existing = cursor.fetchone()
if existing:
# 更新
old_value = existing["config_value"]
cursor.execute(
"""
UPDATE tenant_configs
SET config_value = %s, is_encrypted = %s
WHERE id = %s
""",
(data.config_value, template["is_secret"], existing["id"])
)
else:
# 插入
old_value = None
cursor.execute(
"""
INSERT INTO tenant_configs
(tenant_id, config_group, config_key, config_value, value_type, is_encrypted)
VALUES (%s, %s, %s, %s, %s, %s)
""",
(tenant_id, config_group, config_key, data.config_value,
template["value_type"], template["is_secret"])
)
# 记录操作日志
log_operation(
cursor, admin, tenant_id, tenant["code"],
"update", "config", tenant_id, f"{config_group}.{config_key}",
old_value={"value": old_value} if old_value else None,
new_value={"value": data.config_value}
)
conn.commit()
return ResponseModel(message="配置已更新")
finally:
conn.close()
@router.put("/tenants/{tenant_id}/batch", response_model=ResponseModel, summary="批量更新配置")
async def batch_update_tenant_configs(
tenant_id: int,
data: ConfigBatchUpdate,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""批量更新租户配置"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 验证租户存在
cursor.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
updated_count = 0
for config in data.configs:
# 获取模板信息
cursor.execute(
"""
SELECT value_type, is_secret FROM config_templates
WHERE config_group = %s AND config_key = %s
""",
(config.config_group, config.config_key)
)
template = cursor.fetchone()
if not template:
continue
# 检查是否已有配置
cursor.execute(
"""
SELECT id FROM tenant_configs
WHERE tenant_id = %s AND config_group = %s AND config_key = %s
""",
(tenant_id, config.config_group, config.config_key)
)
existing = cursor.fetchone()
if existing:
cursor.execute(
"""
UPDATE tenant_configs
SET config_value = %s, is_encrypted = %s
WHERE id = %s
""",
(config.config_value, template["is_secret"], existing["id"])
)
else:
cursor.execute(
"""
INSERT INTO tenant_configs
(tenant_id, config_group, config_key, config_value, value_type, is_encrypted)
VALUES (%s, %s, %s, %s, %s, %s)
""",
(tenant_id, config.config_group, config.config_key, config.config_value,
template["value_type"], template["is_secret"])
)
updated_count += 1
# 记录操作日志
log_operation(
cursor, admin, tenant_id, tenant["code"],
"batch_update", "config", tenant_id, f"批量更新 {updated_count} 项配置"
)
conn.commit()
return ResponseModel(message=f"已更新 {updated_count} 项配置")
finally:
conn.close()
@router.delete("/tenants/{tenant_id}/{config_group}/{config_key}", response_model=ResponseModel, summary="删除配置")
async def delete_tenant_config(
tenant_id: int,
config_group: str,
config_key: str,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""删除租户的配置项(恢复为默认值)"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 验证租户存在
cursor.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 删除配置
cursor.execute(
"""
DELETE FROM tenant_configs
WHERE tenant_id = %s AND config_group = %s AND config_key = %s
""",
(tenant_id, config_group, config_key)
)
if cursor.rowcount == 0:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="配置不存在",
)
# 记录操作日志
log_operation(
cursor, admin, tenant_id, tenant["code"],
"delete", "config", tenant_id, f"{config_group}.{config_key}"
)
conn.commit()
return ResponseModel(message="配置已删除,将使用默认值")
finally:
conn.close()
@router.post("/tenants/{tenant_id}/refresh-cache", response_model=ResponseModel, summary="刷新配置缓存")
async def refresh_tenant_config_cache(
tenant_id: int,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""刷新租户的配置缓存"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 获取租户编码
cursor.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 刷新缓存
try:
from app.core.config import DynamicConfig
import asyncio
asyncio.create_task(DynamicConfig.refresh_cache(tenant["code"]))
except Exception as e:
pass # 缓存刷新失败不影响主流程
return ResponseModel(message="缓存刷新请求已发送")
finally:
conn.close()

View File

@@ -0,0 +1,424 @@
"""
功能开关管理 API
"""
import os
import json
from typing import Optional, List, Dict
import pymysql
from fastapi import APIRouter, Depends, HTTPException, status, Query
from .auth import get_current_admin, get_db_connection, AdminUserInfo
from .schemas import (
FeatureSwitchCreate,
FeatureSwitchUpdate,
FeatureSwitchResponse,
FeatureSwitchGroupResponse,
ResponseModel,
)
router = APIRouter(prefix="/features", tags=["功能开关"])
# 功能分组显示名称
FEATURE_GROUP_NAMES = {
"exam": "考试模块",
"practice": "陪练模块",
"broadcast": "播课模块",
"course": "课程模块",
"yanji": "智能工牌模块",
}
def log_operation(cursor, admin: AdminUserInfo, tenant_id: int, tenant_code: str,
operation_type: str, resource_type: str, resource_id: int,
resource_name: str, old_value: dict = None, new_value: dict = None):
"""记录操作日志"""
cursor.execute(
"""
INSERT INTO operation_logs
(admin_user_id, admin_username, tenant_id, tenant_code, operation_type,
resource_type, resource_id, resource_name, old_value, new_value)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(admin.id, admin.username, tenant_id, tenant_code, operation_type,
resource_type, resource_id, resource_name,
json.dumps(old_value, ensure_ascii=False) if old_value else None,
json.dumps(new_value, ensure_ascii=False) if new_value else None)
)
@router.get("/defaults", response_model=List[FeatureSwitchGroupResponse], summary="获取默认功能开关")
async def get_default_features(
admin: AdminUserInfo = Depends(get_current_admin),
):
"""获取全局默认的功能开关配置"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
cursor.execute(
"""
SELECT * FROM feature_switches
WHERE tenant_id IS NULL
ORDER BY feature_group, id
"""
)
rows = cursor.fetchall()
# 按分组整理
groups: Dict[str, List] = {}
for row in rows:
group = row["feature_group"] or "other"
if group not in groups:
groups[group] = []
config = None
if row.get("config"):
try:
config = json.loads(row["config"])
except:
pass
groups[group].append(FeatureSwitchResponse(
id=row["id"],
tenant_id=row["tenant_id"],
feature_code=row["feature_code"],
feature_name=row["feature_name"],
feature_group=row["feature_group"],
is_enabled=row["is_enabled"],
config=config,
description=row["description"],
created_at=row["created_at"],
updated_at=row["updated_at"],
))
return [
FeatureSwitchGroupResponse(
group_name=group,
group_display_name=FEATURE_GROUP_NAMES.get(group, group),
features=features,
)
for group, features in groups.items()
]
finally:
conn.close()
@router.get("/tenants/{tenant_id}", response_model=List[FeatureSwitchGroupResponse], summary="获取租户功能开关")
async def get_tenant_features(
tenant_id: int,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""
获取租户的功能开关配置
返回租户自定义配置和默认配置的合并结果
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 验证租户存在
cursor.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 获取默认配置
cursor.execute(
"""
SELECT * FROM feature_switches
WHERE tenant_id IS NULL
ORDER BY feature_group, id
"""
)
default_rows = cursor.fetchall()
# 获取租户配置
cursor.execute(
"""
SELECT * FROM feature_switches
WHERE tenant_id = %s
""",
(tenant_id,)
)
tenant_rows = cursor.fetchall()
# 合并配置
tenant_features = {row["feature_code"]: row for row in tenant_rows}
groups: Dict[str, List] = {}
for row in default_rows:
group = row["feature_group"] or "other"
if group not in groups:
groups[group] = []
# 使用租户配置覆盖默认配置
effective_row = tenant_features.get(row["feature_code"], row)
config = None
if effective_row.get("config"):
try:
config = json.loads(effective_row["config"])
except:
pass
groups[group].append(FeatureSwitchResponse(
id=effective_row["id"],
tenant_id=effective_row["tenant_id"],
feature_code=effective_row["feature_code"],
feature_name=effective_row["feature_name"],
feature_group=effective_row["feature_group"],
is_enabled=effective_row["is_enabled"],
config=config,
description=effective_row["description"],
created_at=effective_row["created_at"],
updated_at=effective_row["updated_at"],
))
return [
FeatureSwitchGroupResponse(
group_name=group,
group_display_name=FEATURE_GROUP_NAMES.get(group, group),
features=features,
)
for group, features in groups.items()
]
finally:
conn.close()
@router.put("/tenants/{tenant_id}/{feature_code}", response_model=ResponseModel, summary="更新租户功能开关")
async def update_tenant_feature(
tenant_id: int,
feature_code: str,
data: FeatureSwitchUpdate,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""更新租户的功能开关"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 验证租户存在
cursor.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 获取默认配置
cursor.execute(
"""
SELECT * FROM feature_switches
WHERE tenant_id IS NULL AND feature_code = %s
""",
(feature_code,)
)
default_feature = cursor.fetchone()
if not default_feature:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效的功能编码",
)
# 检查租户是否已有配置
cursor.execute(
"""
SELECT id, is_enabled FROM feature_switches
WHERE tenant_id = %s AND feature_code = %s
""",
(tenant_id, feature_code)
)
existing = cursor.fetchone()
if existing:
# 更新
old_enabled = existing["is_enabled"]
update_fields = []
update_values = []
if data.is_enabled is not None:
update_fields.append("is_enabled = %s")
update_values.append(data.is_enabled)
if data.config is not None:
update_fields.append("config = %s")
update_values.append(json.dumps(data.config))
if update_fields:
update_values.append(existing["id"])
cursor.execute(
f"UPDATE feature_switches SET {', '.join(update_fields)} WHERE id = %s",
update_values
)
else:
# 创建租户配置
old_enabled = default_feature["is_enabled"]
cursor.execute(
"""
INSERT INTO feature_switches
(tenant_id, feature_code, feature_name, feature_group, is_enabled, config, description)
VALUES (%s, %s, %s, %s, %s, %s, %s)
""",
(tenant_id, feature_code, default_feature["feature_name"],
default_feature["feature_group"],
data.is_enabled if data.is_enabled is not None else default_feature["is_enabled"],
json.dumps(data.config) if data.config else default_feature["config"],
default_feature["description"])
)
# 记录操作日志
log_operation(
cursor, admin, tenant_id, tenant["code"],
"update", "feature", tenant_id, feature_code,
old_value={"is_enabled": old_enabled},
new_value={"is_enabled": data.is_enabled, "config": data.config}
)
conn.commit()
status_text = "启用" if data.is_enabled else "禁用"
return ResponseModel(message=f"功能 {default_feature['feature_name']}{status_text}")
finally:
conn.close()
@router.delete("/tenants/{tenant_id}/{feature_code}", response_model=ResponseModel, summary="重置租户功能开关")
async def reset_tenant_feature(
tenant_id: int,
feature_code: str,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""重置租户的功能开关为默认值"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 验证租户存在
cursor.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 删除租户配置
cursor.execute(
"""
DELETE FROM feature_switches
WHERE tenant_id = %s AND feature_code = %s
""",
(tenant_id, feature_code)
)
if cursor.rowcount == 0:
return ResponseModel(message="功能配置已是默认值")
# 记录操作日志
log_operation(
cursor, admin, tenant_id, tenant["code"],
"reset", "feature", tenant_id, feature_code
)
conn.commit()
return ResponseModel(message="功能配置已重置为默认值")
finally:
conn.close()
@router.post("/tenants/{tenant_id}/batch", response_model=ResponseModel, summary="批量更新功能开关")
async def batch_update_tenant_features(
tenant_id: int,
features: List[Dict],
admin: AdminUserInfo = Depends(get_current_admin),
):
"""
批量更新租户的功能开关
请求体格式:
[
{"feature_code": "exam_module", "is_enabled": true},
{"feature_code": "practice_voice", "is_enabled": false}
]
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 验证租户存在
cursor.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
updated_count = 0
for feature in features:
feature_code = feature.get("feature_code")
is_enabled = feature.get("is_enabled")
if not feature_code or is_enabled is None:
continue
# 获取默认配置
cursor.execute(
"""
SELECT * FROM feature_switches
WHERE tenant_id IS NULL AND feature_code = %s
""",
(feature_code,)
)
default_feature = cursor.fetchone()
if not default_feature:
continue
# 检查租户是否已有配置
cursor.execute(
"""
SELECT id FROM feature_switches
WHERE tenant_id = %s AND feature_code = %s
""",
(tenant_id, feature_code)
)
existing = cursor.fetchone()
if existing:
cursor.execute(
"UPDATE feature_switches SET is_enabled = %s WHERE id = %s",
(is_enabled, existing["id"])
)
else:
cursor.execute(
"""
INSERT INTO feature_switches
(tenant_id, feature_code, feature_name, feature_group, is_enabled, description)
VALUES (%s, %s, %s, %s, %s, %s)
""",
(tenant_id, feature_code, default_feature["feature_name"],
default_feature["feature_group"], is_enabled, default_feature["description"])
)
updated_count += 1
# 记录操作日志
log_operation(
cursor, admin, tenant_id, tenant["code"],
"batch_update", "feature", tenant_id, f"批量更新 {updated_count} 项功能开关"
)
conn.commit()
return ResponseModel(message=f"已更新 {updated_count} 项功能开关")
finally:
conn.close()

View File

@@ -0,0 +1,637 @@
"""
AI 提示词管理 API
"""
import os
import json
from typing import Optional, List
import pymysql
from fastapi import APIRouter, Depends, HTTPException, status, Query
from .auth import get_current_admin, require_superadmin, get_db_connection, AdminUserInfo
from .schemas import (
AIPromptCreate,
AIPromptUpdate,
AIPromptResponse,
AIPromptVersionResponse,
TenantPromptResponse,
TenantPromptUpdate,
ResponseModel,
)
router = APIRouter(prefix="/prompts", tags=["提示词管理"])
def log_operation(cursor, admin: AdminUserInfo, tenant_id: int, tenant_code: str,
operation_type: str, resource_type: str, resource_id: int,
resource_name: str, old_value: dict = None, new_value: dict = None):
"""记录操作日志"""
cursor.execute(
"""
INSERT INTO operation_logs
(admin_user_id, admin_username, tenant_id, tenant_code, operation_type,
resource_type, resource_id, resource_name, old_value, new_value)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(admin.id, admin.username, tenant_id, tenant_code, operation_type,
resource_type, resource_id, resource_name,
json.dumps(old_value, ensure_ascii=False) if old_value else None,
json.dumps(new_value, ensure_ascii=False) if new_value else None)
)
@router.get("", response_model=List[AIPromptResponse], summary="获取提示词列表")
async def list_prompts(
module: Optional[str] = Query(None, description="模块筛选"),
is_active: Optional[bool] = Query(None, description="是否启用"),
admin: AdminUserInfo = Depends(get_current_admin),
):
"""
获取所有 AI 提示词模板
- **module**: 模块筛选course, exam, practice, ability
- **is_active**: 是否启用
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
conditions = []
params = []
if module:
conditions.append("module = %s")
params.append(module)
if is_active is not None:
conditions.append("is_active = %s")
params.append(is_active)
where_clause = " AND ".join(conditions) if conditions else "1=1"
cursor.execute(
f"""
SELECT * FROM ai_prompts
WHERE {where_clause}
ORDER BY module, id
""",
params
)
rows = cursor.fetchall()
result = []
for row in rows:
# 解析 JSON 字段
variables = None
if row.get("variables"):
try:
variables = json.loads(row["variables"])
except:
pass
output_schema = None
if row.get("output_schema"):
try:
output_schema = json.loads(row["output_schema"])
except:
pass
result.append(AIPromptResponse(
id=row["id"],
code=row["code"],
name=row["name"],
description=row["description"],
module=row["module"],
system_prompt=row["system_prompt"],
user_prompt_template=row["user_prompt_template"],
variables=variables,
output_schema=output_schema,
model_recommendation=row["model_recommendation"],
max_tokens=row["max_tokens"],
temperature=float(row["temperature"]) if row["temperature"] else 0.7,
is_system=row["is_system"],
is_active=row["is_active"],
version=row["version"],
created_at=row["created_at"],
updated_at=row["updated_at"],
))
return result
finally:
conn.close()
@router.get("/{prompt_id}", response_model=AIPromptResponse, summary="获取提示词详情")
async def get_prompt(
prompt_id: int,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""获取提示词详情"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
cursor.execute("SELECT * FROM ai_prompts WHERE id = %s", (prompt_id,))
row = cursor.fetchone()
if not row:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="提示词不存在",
)
# 解析 JSON 字段
variables = None
if row.get("variables"):
try:
variables = json.loads(row["variables"])
except:
pass
output_schema = None
if row.get("output_schema"):
try:
output_schema = json.loads(row["output_schema"])
except:
pass
return AIPromptResponse(
id=row["id"],
code=row["code"],
name=row["name"],
description=row["description"],
module=row["module"],
system_prompt=row["system_prompt"],
user_prompt_template=row["user_prompt_template"],
variables=variables,
output_schema=output_schema,
model_recommendation=row["model_recommendation"],
max_tokens=row["max_tokens"],
temperature=float(row["temperature"]) if row["temperature"] else 0.7,
is_system=row["is_system"],
is_active=row["is_active"],
version=row["version"],
created_at=row["created_at"],
updated_at=row["updated_at"],
)
finally:
conn.close()
@router.post("", response_model=AIPromptResponse, summary="创建提示词")
async def create_prompt(
data: AIPromptCreate,
admin: AdminUserInfo = Depends(require_superadmin),
):
"""
创建新的提示词模板
需要超级管理员权限
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 检查编码是否已存在
cursor.execute("SELECT id FROM ai_prompts WHERE code = %s", (data.code,))
if cursor.fetchone():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="提示词编码已存在",
)
# 创建提示词
cursor.execute(
"""
INSERT INTO ai_prompts
(code, name, description, module, system_prompt, user_prompt_template,
variables, output_schema, model_recommendation, max_tokens, temperature,
is_system, created_by)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, FALSE, %s)
""",
(data.code, data.name, data.description, data.module,
data.system_prompt, data.user_prompt_template,
json.dumps(data.variables) if data.variables else None,
json.dumps(data.output_schema) if data.output_schema else None,
data.model_recommendation, data.max_tokens, data.temperature,
admin.id)
)
prompt_id = cursor.lastrowid
# 记录操作日志
log_operation(
cursor, admin, None, None,
"create", "prompt", prompt_id, data.name,
new_value=data.model_dump()
)
conn.commit()
return await get_prompt(prompt_id, admin)
finally:
conn.close()
@router.put("/{prompt_id}", response_model=AIPromptResponse, summary="更新提示词")
async def update_prompt(
prompt_id: int,
data: AIPromptUpdate,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""
更新提示词模板
更新会自动保存版本历史
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 获取原提示词
cursor.execute("SELECT * FROM ai_prompts WHERE id = %s", (prompt_id,))
old_prompt = cursor.fetchone()
if not old_prompt:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="提示词不存在",
)
# 保存版本历史(如果系统提示词或用户提示词有变化)
if data.system_prompt or data.user_prompt_template:
new_version = old_prompt["version"] + 1
cursor.execute(
"""
INSERT INTO ai_prompt_versions
(prompt_id, version, system_prompt, user_prompt_template, variables,
output_schema, change_summary, created_by)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
""",
(prompt_id, old_prompt["version"],
old_prompt["system_prompt"], old_prompt["user_prompt_template"],
old_prompt["variables"], old_prompt["output_schema"],
f"版本 {old_prompt['version']} 备份",
admin.id)
)
else:
new_version = old_prompt["version"]
# 构建更新语句
update_fields = []
update_values = []
if data.name is not None:
update_fields.append("name = %s")
update_values.append(data.name)
if data.description is not None:
update_fields.append("description = %s")
update_values.append(data.description)
if data.system_prompt is not None:
update_fields.append("system_prompt = %s")
update_values.append(data.system_prompt)
if data.user_prompt_template is not None:
update_fields.append("user_prompt_template = %s")
update_values.append(data.user_prompt_template)
if data.variables is not None:
update_fields.append("variables = %s")
update_values.append(json.dumps(data.variables))
if data.output_schema is not None:
update_fields.append("output_schema = %s")
update_values.append(json.dumps(data.output_schema))
if data.model_recommendation is not None:
update_fields.append("model_recommendation = %s")
update_values.append(data.model_recommendation)
if data.max_tokens is not None:
update_fields.append("max_tokens = %s")
update_values.append(data.max_tokens)
if data.temperature is not None:
update_fields.append("temperature = %s")
update_values.append(data.temperature)
if data.is_active is not None:
update_fields.append("is_active = %s")
update_values.append(data.is_active)
if not update_fields:
return await get_prompt(prompt_id, admin)
# 更新版本号
if data.system_prompt or data.user_prompt_template:
update_fields.append("version = %s")
update_values.append(new_version)
update_fields.append("updated_by = %s")
update_values.append(admin.id)
update_values.append(prompt_id)
cursor.execute(
f"UPDATE ai_prompts SET {', '.join(update_fields)} WHERE id = %s",
update_values
)
# 记录操作日志
log_operation(
cursor, admin, None, None,
"update", "prompt", prompt_id, old_prompt["name"],
old_value={"version": old_prompt["version"]},
new_value=data.model_dump(exclude_unset=True)
)
conn.commit()
return await get_prompt(prompt_id, admin)
finally:
conn.close()
@router.get("/{prompt_id}/versions", response_model=List[AIPromptVersionResponse], summary="获取提示词版本历史")
async def get_prompt_versions(
prompt_id: int,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""获取提示词的版本历史"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
cursor.execute(
"""
SELECT * FROM ai_prompt_versions
WHERE prompt_id = %s
ORDER BY version DESC
""",
(prompt_id,)
)
rows = cursor.fetchall()
result = []
for row in rows:
variables = None
if row.get("variables"):
try:
variables = json.loads(row["variables"])
except:
pass
result.append(AIPromptVersionResponse(
id=row["id"],
prompt_id=row["prompt_id"],
version=row["version"],
system_prompt=row["system_prompt"],
user_prompt_template=row["user_prompt_template"],
variables=variables,
change_summary=row["change_summary"],
created_at=row["created_at"],
))
return result
finally:
conn.close()
@router.post("/{prompt_id}/rollback/{version}", response_model=AIPromptResponse, summary="回滚提示词版本")
async def rollback_prompt_version(
prompt_id: int,
version: int,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""回滚到指定版本的提示词"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 获取指定版本
cursor.execute(
"""
SELECT * FROM ai_prompt_versions
WHERE prompt_id = %s AND version = %s
""",
(prompt_id, version)
)
version_row = cursor.fetchone()
if not version_row:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="版本不存在",
)
# 获取当前提示词
cursor.execute("SELECT * FROM ai_prompts WHERE id = %s", (prompt_id,))
current = cursor.fetchone()
if not current:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="提示词不存在",
)
# 保存当前版本到历史
new_version = current["version"] + 1
cursor.execute(
"""
INSERT INTO ai_prompt_versions
(prompt_id, version, system_prompt, user_prompt_template, variables,
output_schema, change_summary, created_by)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
""",
(prompt_id, current["version"],
current["system_prompt"], current["user_prompt_template"],
current["variables"], current["output_schema"],
f"回滚前备份(版本 {current['version']}",
admin.id)
)
# 回滚
cursor.execute(
"""
UPDATE ai_prompts
SET system_prompt = %s, user_prompt_template = %s, variables = %s,
output_schema = %s, version = %s, updated_by = %s
WHERE id = %s
""",
(version_row["system_prompt"], version_row["user_prompt_template"],
version_row["variables"], version_row["output_schema"],
new_version, admin.id, prompt_id)
)
# 记录操作日志
log_operation(
cursor, admin, None, None,
"rollback", "prompt", prompt_id, current["name"],
old_value={"version": current["version"]},
new_value={"version": new_version, "rollback_from": version}
)
conn.commit()
return await get_prompt(prompt_id, admin)
finally:
conn.close()
@router.get("/tenants/{tenant_id}", response_model=List[TenantPromptResponse], summary="获取租户自定义提示词")
async def get_tenant_prompts(
tenant_id: int,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""获取租户的自定义提示词列表"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
cursor.execute(
"""
SELECT tp.*, ap.code as prompt_code, ap.name as prompt_name
FROM tenant_prompts tp
JOIN ai_prompts ap ON tp.prompt_id = ap.id
WHERE tp.tenant_id = %s
ORDER BY ap.module, ap.id
""",
(tenant_id,)
)
rows = cursor.fetchall()
return [
TenantPromptResponse(
id=row["id"],
tenant_id=row["tenant_id"],
prompt_id=row["prompt_id"],
prompt_code=row["prompt_code"],
prompt_name=row["prompt_name"],
system_prompt=row["system_prompt"],
user_prompt_template=row["user_prompt_template"],
is_active=row["is_active"],
created_at=row["created_at"],
updated_at=row["updated_at"],
)
for row in rows
]
finally:
conn.close()
@router.put("/tenants/{tenant_id}/{prompt_id}", response_model=ResponseModel, summary="更新租户自定义提示词")
async def update_tenant_prompt(
tenant_id: int,
prompt_id: int,
data: TenantPromptUpdate,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""创建或更新租户的自定义提示词"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 验证租户存在
cursor.execute("SELECT code FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 验证提示词存在
cursor.execute("SELECT name FROM ai_prompts WHERE id = %s", (prompt_id,))
prompt = cursor.fetchone()
if not prompt:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="提示词不存在",
)
# 检查是否已有自定义
cursor.execute(
"""
SELECT id FROM tenant_prompts
WHERE tenant_id = %s AND prompt_id = %s
""",
(tenant_id, prompt_id)
)
existing = cursor.fetchone()
if existing:
# 更新
update_fields = []
update_values = []
if data.system_prompt is not None:
update_fields.append("system_prompt = %s")
update_values.append(data.system_prompt)
if data.user_prompt_template is not None:
update_fields.append("user_prompt_template = %s")
update_values.append(data.user_prompt_template)
if data.is_active is not None:
update_fields.append("is_active = %s")
update_values.append(data.is_active)
if update_fields:
update_fields.append("updated_by = %s")
update_values.append(admin.id)
update_values.append(existing["id"])
cursor.execute(
f"UPDATE tenant_prompts SET {', '.join(update_fields)} WHERE id = %s",
update_values
)
else:
# 创建
cursor.execute(
"""
INSERT INTO tenant_prompts
(tenant_id, prompt_id, system_prompt, user_prompt_template, is_active, created_by)
VALUES (%s, %s, %s, %s, %s, %s)
""",
(tenant_id, prompt_id, data.system_prompt, data.user_prompt_template,
data.is_active if data.is_active is not None else True, admin.id)
)
# 记录操作日志
log_operation(
cursor, admin, tenant_id, tenant["code"],
"update", "tenant_prompt", prompt_id, prompt["name"],
new_value=data.model_dump(exclude_unset=True)
)
conn.commit()
return ResponseModel(message="自定义提示词已保存")
finally:
conn.close()
@router.delete("/tenants/{tenant_id}/{prompt_id}", response_model=ResponseModel, summary="删除租户自定义提示词")
async def delete_tenant_prompt(
tenant_id: int,
prompt_id: int,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""删除租户的自定义提示词(恢复使用默认)"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
cursor.execute(
"""
DELETE FROM tenant_prompts
WHERE tenant_id = %s AND prompt_id = %s
""",
(tenant_id, prompt_id)
)
if cursor.rowcount == 0:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="自定义提示词不存在",
)
conn.commit()
return ResponseModel(message="自定义提示词已删除,将使用默认模板")
finally:
conn.close()

View File

@@ -0,0 +1,352 @@
"""
管理后台数据模型
"""
from datetime import datetime
from typing import Optional, List, Any, Dict
from pydantic import BaseModel, Field
# ============================================
# 通用模型
# ============================================
class ResponseModel(BaseModel):
"""通用响应模型"""
code: int = 0
message: str = "success"
data: Optional[Any] = None
class PaginationParams(BaseModel):
"""分页参数"""
page: int = Field(default=1, ge=1)
page_size: int = Field(default=20, ge=1, le=100)
class PaginatedResponse(BaseModel):
"""分页响应"""
items: List[Any]
total: int
page: int
page_size: int
total_pages: int
# ============================================
# 认证相关
# ============================================
class AdminLoginRequest(BaseModel):
"""管理员登录请求"""
username: str = Field(..., min_length=1, max_length=50)
password: str = Field(..., min_length=6)
class AdminLoginResponse(BaseModel):
"""管理员登录响应"""
access_token: str
token_type: str = "bearer"
expires_in: int
admin_user: "AdminUserInfo"
class AdminUserInfo(BaseModel):
"""管理员信息"""
id: int
username: str
email: Optional[str]
full_name: Optional[str]
role: str
last_login_at: Optional[datetime]
class AdminChangePasswordRequest(BaseModel):
"""修改密码请求"""
old_password: str = Field(..., min_length=6)
new_password: str = Field(..., min_length=6)
# ============================================
# 租户相关
# ============================================
class TenantBase(BaseModel):
"""租户基础信息"""
code: str = Field(..., min_length=2, max_length=20, pattern=r'^[a-z0-9_]+$')
name: str = Field(..., min_length=1, max_length=100)
display_name: Optional[str] = Field(None, max_length=200)
domain: str = Field(..., min_length=1, max_length=200)
logo_url: Optional[str] = None
favicon_url: Optional[str] = None
contact_name: Optional[str] = None
contact_phone: Optional[str] = None
contact_email: Optional[str] = None
industry: str = Field(default="medical_beauty")
remarks: Optional[str] = None
class TenantCreate(TenantBase):
"""创建租户请求"""
pass
class TenantUpdate(BaseModel):
"""更新租户请求"""
name: Optional[str] = Field(None, min_length=1, max_length=100)
display_name: Optional[str] = Field(None, max_length=200)
domain: Optional[str] = Field(None, min_length=1, max_length=200)
logo_url: Optional[str] = None
favicon_url: Optional[str] = None
contact_name: Optional[str] = None
contact_phone: Optional[str] = None
contact_email: Optional[str] = None
industry: Optional[str] = None
status: Optional[str] = None
expire_at: Optional[datetime] = None
remarks: Optional[str] = None
class TenantResponse(TenantBase):
"""租户响应"""
id: int
status: str
expire_at: Optional[datetime]
created_at: datetime
updated_at: datetime
config_count: int = 0 # 配置项数量
class Config:
from_attributes = True
class TenantListResponse(BaseModel):
"""租户列表响应"""
items: List[TenantResponse]
total: int
page: int
page_size: int
# ============================================
# 配置相关
# ============================================
class ConfigTemplateResponse(BaseModel):
"""配置模板响应"""
id: int
config_group: str
config_key: str
display_name: str
description: Optional[str]
value_type: str
default_value: Optional[str]
is_required: bool
is_secret: bool
options: Optional[List[str]]
sort_order: int
class TenantConfigBase(BaseModel):
"""租户配置基础"""
config_group: str
config_key: str
config_value: Optional[str] = None
class TenantConfigCreate(TenantConfigBase):
"""创建租户配置请求"""
pass
class TenantConfigUpdate(BaseModel):
"""更新租户配置请求"""
config_value: Optional[str] = None
class TenantConfigResponse(TenantConfigBase):
"""租户配置响应"""
id: int
value_type: str
is_encrypted: bool
description: Optional[str]
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
# 从模板获取的额外信息
display_name: Optional[str] = None
is_required: bool = False
is_secret: bool = False
class Config:
from_attributes = True
class TenantConfigGroupResponse(BaseModel):
"""租户配置分组响应"""
group_name: str
group_display_name: str
configs: List[TenantConfigResponse]
class ConfigBatchUpdate(BaseModel):
"""批量更新配置请求"""
configs: List[TenantConfigCreate]
# ============================================
# 提示词相关
# ============================================
class AIPromptBase(BaseModel):
"""AI提示词基础"""
code: str = Field(..., min_length=1, max_length=50)
name: str = Field(..., min_length=1, max_length=100)
description: Optional[str] = None
module: str
system_prompt: str
user_prompt_template: Optional[str] = None
variables: Optional[List[str]] = None
output_schema: Optional[Dict] = None
model_recommendation: Optional[str] = None
max_tokens: int = 4096
temperature: float = 0.7
class AIPromptCreate(AIPromptBase):
"""创建提示词请求"""
pass
class AIPromptUpdate(BaseModel):
"""更新提示词请求"""
name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = None
system_prompt: Optional[str] = None
user_prompt_template: Optional[str] = None
variables: Optional[List[str]] = None
output_schema: Optional[Dict] = None
model_recommendation: Optional[str] = None
max_tokens: Optional[int] = None
temperature: Optional[float] = None
is_active: Optional[bool] = None
class AIPromptResponse(AIPromptBase):
"""提示词响应"""
id: int
is_system: bool
is_active: bool
version: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class AIPromptVersionResponse(BaseModel):
"""提示词版本响应"""
id: int
prompt_id: int
version: int
system_prompt: str
user_prompt_template: Optional[str]
variables: Optional[List[str]]
change_summary: Optional[str]
created_at: datetime
class Config:
from_attributes = True
class TenantPromptResponse(BaseModel):
"""租户自定义提示词响应"""
id: int
tenant_id: int
prompt_id: int
prompt_code: str
prompt_name: str
system_prompt: Optional[str]
user_prompt_template: Optional[str]
is_active: bool
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class TenantPromptUpdate(BaseModel):
"""更新租户自定义提示词"""
system_prompt: Optional[str] = None
user_prompt_template: Optional[str] = None
is_active: Optional[bool] = None
# ============================================
# 功能开关相关
# ============================================
class FeatureSwitchBase(BaseModel):
"""功能开关基础"""
feature_code: str
feature_name: str
feature_group: Optional[str] = None
is_enabled: bool = True
config: Optional[Dict] = None
description: Optional[str] = None
class FeatureSwitchCreate(FeatureSwitchBase):
"""创建功能开关请求"""
pass
class FeatureSwitchUpdate(BaseModel):
"""更新功能开关请求"""
is_enabled: Optional[bool] = None
config: Optional[Dict] = None
class FeatureSwitchResponse(FeatureSwitchBase):
"""功能开关响应"""
id: int
tenant_id: Optional[int]
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class FeatureSwitchGroupResponse(BaseModel):
"""功能开关分组响应"""
group_name: str
group_display_name: str
features: List[FeatureSwitchResponse]
# ============================================
# 操作日志相关
# ============================================
class OperationLogResponse(BaseModel):
"""操作日志响应"""
id: int
admin_username: Optional[str]
tenant_code: Optional[str]
operation_type: str
resource_type: str
resource_name: Optional[str]
old_value: Optional[Dict]
new_value: Optional[Dict]
ip_address: Optional[str]
created_at: datetime
class Config:
from_attributes = True
# 更新前向引用
AdminLoginResponse.model_rebuild()

View File

@@ -0,0 +1,379 @@
"""
租户管理 API
"""
import os
import json
from datetime import datetime
from typing import Optional, List
import pymysql
from fastapi import APIRouter, Depends, HTTPException, status, Query
from .auth import get_current_admin, require_superadmin, get_db_connection, AdminUserInfo
from .schemas import (
TenantCreate,
TenantUpdate,
TenantResponse,
TenantListResponse,
ResponseModel,
)
router = APIRouter(prefix="/tenants", tags=["租户管理"])
def log_operation(cursor, admin: AdminUserInfo, tenant_id: int, tenant_code: str,
operation_type: str, resource_type: str, resource_id: int,
resource_name: str, old_value: dict = None, new_value: dict = None):
"""记录操作日志"""
cursor.execute(
"""
INSERT INTO operation_logs
(admin_user_id, admin_username, tenant_id, tenant_code, operation_type,
resource_type, resource_id, resource_name, old_value, new_value)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(admin.id, admin.username, tenant_id, tenant_code, operation_type,
resource_type, resource_id, resource_name,
json.dumps(old_value, ensure_ascii=False) if old_value else None,
json.dumps(new_value, ensure_ascii=False) if new_value else None)
)
@router.get("", response_model=TenantListResponse, summary="获取租户列表")
async def list_tenants(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
status: Optional[str] = Query(None, description="状态筛选"),
keyword: Optional[str] = Query(None, description="关键词搜索"),
admin: AdminUserInfo = Depends(get_current_admin),
):
"""
获取租户列表
- **page**: 页码
- **page_size**: 每页数量
- **status**: 状态筛选active, inactive, suspended
- **keyword**: 关键词搜索(匹配名称、编码、域名)
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 构建查询条件
conditions = []
params = []
if status:
conditions.append("t.status = %s")
params.append(status)
if keyword:
conditions.append("(t.name LIKE %s OR t.code LIKE %s OR t.domain LIKE %s)")
params.extend([f"%{keyword}%"] * 3)
where_clause = " AND ".join(conditions) if conditions else "1=1"
# 查询总数
cursor.execute(
f"SELECT COUNT(*) as total FROM tenants t WHERE {where_clause}",
params
)
total = cursor.fetchone()["total"]
# 查询列表
offset = (page - 1) * page_size
cursor.execute(
f"""
SELECT t.*,
(SELECT COUNT(*) FROM tenant_configs tc WHERE tc.tenant_id = t.id) as config_count
FROM tenants t
WHERE {where_clause}
ORDER BY t.id DESC
LIMIT %s OFFSET %s
""",
params + [page_size, offset]
)
rows = cursor.fetchall()
items = [TenantResponse(**row) for row in rows]
return TenantListResponse(
items=items,
total=total,
page=page,
page_size=page_size,
)
finally:
conn.close()
@router.get("/{tenant_id}", response_model=TenantResponse, summary="获取租户详情")
async def get_tenant(
tenant_id: int,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""获取租户详情"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
cursor.execute(
"""
SELECT t.*,
(SELECT COUNT(*) FROM tenant_configs tc WHERE tc.tenant_id = t.id) as config_count
FROM tenants t
WHERE t.id = %s
""",
(tenant_id,)
)
row = cursor.fetchone()
if not row:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
return TenantResponse(**row)
finally:
conn.close()
@router.post("", response_model=TenantResponse, summary="创建租户")
async def create_tenant(
data: TenantCreate,
admin: AdminUserInfo = Depends(require_superadmin),
):
"""
创建新租户
需要超级管理员权限
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 检查编码是否已存在
cursor.execute("SELECT id FROM tenants WHERE code = %s", (data.code,))
if cursor.fetchone():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="租户编码已存在",
)
# 检查域名是否已存在
cursor.execute("SELECT id FROM tenants WHERE domain = %s", (data.domain,))
if cursor.fetchone():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="域名已被使用",
)
# 创建租户
cursor.execute(
"""
INSERT INTO tenants
(code, name, display_name, domain, logo_url, favicon_url,
contact_name, contact_phone, contact_email, industry, remarks, created_by)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(data.code, data.name, data.display_name, data.domain,
data.logo_url, data.favicon_url, data.contact_name,
data.contact_phone, data.contact_email, data.industry,
data.remarks, admin.id)
)
tenant_id = cursor.lastrowid
# 记录操作日志
log_operation(
cursor, admin, tenant_id, data.code,
"create", "tenant", tenant_id, data.name,
new_value=data.model_dump()
)
conn.commit()
# 返回创建的租户
return await get_tenant(tenant_id, admin)
finally:
conn.close()
@router.put("/{tenant_id}", response_model=TenantResponse, summary="更新租户")
async def update_tenant(
tenant_id: int,
data: TenantUpdate,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""更新租户信息"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 获取原租户信息
cursor.execute("SELECT * FROM tenants WHERE id = %s", (tenant_id,))
old_tenant = cursor.fetchone()
if not old_tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 如果更新域名,检查是否已被使用
if data.domain and data.domain != old_tenant["domain"]:
cursor.execute(
"SELECT id FROM tenants WHERE domain = %s AND id != %s",
(data.domain, tenant_id)
)
if cursor.fetchone():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="域名已被使用",
)
# 构建更新语句
update_fields = []
update_values = []
for field, value in data.model_dump(exclude_unset=True).items():
if value is not None:
update_fields.append(f"{field} = %s")
update_values.append(value)
if not update_fields:
return await get_tenant(tenant_id, admin)
update_fields.append("updated_by = %s")
update_values.append(admin.id)
update_values.append(tenant_id)
cursor.execute(
f"UPDATE tenants SET {', '.join(update_fields)} WHERE id = %s",
update_values
)
# 记录操作日志
log_operation(
cursor, admin, tenant_id, old_tenant["code"],
"update", "tenant", tenant_id, old_tenant["name"],
old_value=dict(old_tenant),
new_value=data.model_dump(exclude_unset=True)
)
conn.commit()
return await get_tenant(tenant_id, admin)
finally:
conn.close()
@router.delete("/{tenant_id}", response_model=ResponseModel, summary="删除租户")
async def delete_tenant(
tenant_id: int,
admin: AdminUserInfo = Depends(require_superadmin),
):
"""
删除租户
需要超级管理员权限
警告:此操作将删除租户及其所有配置
"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
# 获取租户信息
cursor.execute("SELECT * FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 记录操作日志
log_operation(
cursor, admin, tenant_id, tenant["code"],
"delete", "tenant", tenant_id, tenant["name"],
old_value=dict(tenant)
)
# 删除租户(级联删除配置)
cursor.execute("DELETE FROM tenants WHERE id = %s", (tenant_id,))
conn.commit()
return ResponseModel(message=f"租户 {tenant['name']} 已删除")
finally:
conn.close()
@router.post("/{tenant_id}/enable", response_model=ResponseModel, summary="启用租户")
async def enable_tenant(
tenant_id: int,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""启用租户"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
cursor.execute(
"UPDATE tenants SET status = 'active', updated_by = %s WHERE id = %s",
(admin.id, tenant_id)
)
if cursor.rowcount == 0:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 获取租户信息并记录日志
cursor.execute("SELECT code, name FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
log_operation(
cursor, admin, tenant_id, tenant["code"],
"enable", "tenant", tenant_id, tenant["name"]
)
conn.commit()
return ResponseModel(message="租户已启用")
finally:
conn.close()
@router.post("/{tenant_id}/disable", response_model=ResponseModel, summary="禁用租户")
async def disable_tenant(
tenant_id: int,
admin: AdminUserInfo = Depends(get_current_admin),
):
"""禁用租户"""
conn = get_db_connection()
try:
with conn.cursor() as cursor:
cursor.execute(
"UPDATE tenants SET status = 'inactive', updated_by = %s WHERE id = %s",
(admin.id, tenant_id)
)
if cursor.rowcount == 0:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="租户不存在",
)
# 获取租户信息并记录日志
cursor.execute("SELECT code, name FROM tenants WHERE id = %s", (tenant_id,))
tenant = cursor.fetchone()
log_operation(
cursor, admin, tenant_id, tenant["code"],
"disable", "tenant", tenant_id, tenant["name"]
)
conn.commit()
return ResponseModel(message="租户已禁用")
finally:
conn.close()

View File

@@ -0,0 +1,158 @@
# 此文件备份了admin.py中的positions相关路由代码
# 这些路由已移至positions.py为避免冲突从admin.py中移除
@router.get("/positions")
async def list_positions(
keyword: Optional[str] = Query(None, description="关键词"),
page: int = Query(1, ge=1),
pageSize: int = Query(20, ge=1, le=100),
current_user: User = Depends(get_current_user),
_db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取岗位列表stub 数据)
返回结构兼容前端data.list/total/page/pageSize
"""
not_admin = _ensure_admin(current_user)
if not_admin:
return not_admin
try:
items = _sample_positions()
if keyword:
kw = keyword.lower()
items = [
p for p in items if kw in (p.get("name", "") + p.get("description", "")).lower()
]
total = len(items)
start = (page - 1) * pageSize
end = start + pageSize
page_items = items[start:end]
return ResponseModel(
code=200,
message="获取岗位列表成功",
data={
"list": page_items,
"total": total,
"page": page,
"pageSize": pageSize,
},
)
except Exception as exc:
# 记录错误堆栈由全局异常中间件处理;此处返回统一结构
return ResponseModel(code=500, message=f"服务器错误:{exc}")
@router.get("/positions/tree")
async def get_position_tree(
current_user: User = Depends(get_current_user),
_db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取岗位树stub 数据)
"""
not_admin = _ensure_admin(current_user)
if not_admin:
return not_admin
try:
items = _sample_positions()
id_to_node: Dict[int, Dict[str, Any]] = {}
for p in items:
node = {**p, "children": []}
id_to_node[p["id"]] = node
roots: List[Dict[str, Any]] = []
for p in items:
parent_id = p.get("parentId")
if parent_id and parent_id in id_to_node:
id_to_node[parent_id]["children"].append(id_to_node[p["id"]])
else:
roots.append(id_to_node[p["id"]])
return ResponseModel(code=200, message="获取岗位树成功", data=roots)
except Exception as exc:
return ResponseModel(code=500, message=f"服务器错误:{exc}")
@router.get("/positions/{position_id}")
async def get_position_detail(
position_id: int,
current_user: User = Depends(get_current_user),
_db: AsyncSession = Depends(get_db),
) -> ResponseModel:
not_admin = _ensure_admin(current_user)
if not_admin:
return not_admin
items = _sample_positions()
for p in items:
if p["id"] == position_id:
return ResponseModel(code=200, message="获取岗位详情成功", data=p)
return ResponseModel(code=404, message="岗位不存在")
@router.get("/positions/{position_id}/check-delete")
async def check_position_delete(
position_id: int,
current_user: User = Depends(get_current_user),
_db: AsyncSession = Depends(get_db),
) -> ResponseModel:
not_admin = _ensure_admin(current_user)
if not_admin:
return not_admin
# stub允许删除非根岗位
deletable = position_id != 1
reason = "根岗位不允许删除" if not deletable else ""
return ResponseModel(code=200, message="检查成功", data={"deletable": deletable, "reason": reason})
@router.post("/positions")
async def create_position(
payload: Dict[str, Any],
current_user: User = Depends(get_current_user),
_db: AsyncSession = Depends(get_db),
) -> ResponseModel:
not_admin = _ensure_admin(current_user)
if not_admin:
return not_admin
# stub直接回显并附带一个伪ID
payload = dict(payload)
payload.setdefault("id", 999)
payload.setdefault("createTime", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
return ResponseModel(code=200, message="创建岗位成功", data=payload)
@router.put("/positions/{position_id}")
async def update_position(
position_id: int,
payload: Dict[str, Any],
current_user: User = Depends(get_current_user),
_db: AsyncSession = Depends(get_db),
) -> ResponseModel:
not_admin = _ensure_admin(current_user)
if not_admin:
return not_admin
# stub直接回显
updated = {"id": position_id, **payload}
return ResponseModel(code=200, message="更新岗位成功", data=updated)
@router.delete("/positions/{position_id}")
async def delete_position(
position_id: int,
current_user: User = Depends(get_current_user),
_db: AsyncSession = Depends(get_db),
) -> ResponseModel:
not_admin = _ensure_admin(current_user)
if not_admin:
return not_admin
# stub直接返回成功
return ResponseModel(code=200, message="删除岗位成功", data={"id": position_id})

156
backend/app/api/v1/auth.py Normal file
View File

@@ -0,0 +1,156 @@
"""
认证 API
"""
from fastapi import APIRouter, Depends, status, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_current_active_user, get_db
from app.core.logger import logger
from app.models.user import User
from app.schemas.auth import LoginRequest, RefreshTokenRequest, Token
from app.schemas.base import ResponseModel
from app.schemas.user import User as UserSchema
from app.services.auth_service import AuthService
from app.services.system_log_service import system_log_service
from app.schemas.system_log import SystemLogCreate
from app.core.exceptions import UnauthorizedError
router = APIRouter()
@router.post("/login", response_model=ResponseModel)
async def login(
login_data: LoginRequest,
request: Request,
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
用户登录
支持使用用户名、邮箱或手机号登录
"""
auth_service = AuthService(db)
try:
user, token = await auth_service.login(
username=login_data.username,
password=login_data.password,
)
# 记录登录成功日志
await system_log_service.create_log(
db,
SystemLogCreate(
level="INFO",
type="security",
message=f"用户 {user.username} 登录成功",
user_id=user.id,
user=user.username,
ip=request.client.host if request.client else None,
path="/api/v1/auth/login",
method="POST",
user_agent=request.headers.get("user-agent")
)
)
return ResponseModel(
message="登录成功",
data={
"user": UserSchema.model_validate(user).model_dump(),
"token": token.model_dump(),
},
)
except UnauthorizedError as e:
# 记录登录失败日志
await system_log_service.create_log(
db,
SystemLogCreate(
level="WARNING",
type="security",
message=f"用户 {login_data.username} 登录失败:密码错误",
user=login_data.username,
ip=request.client.host if request.client else None,
path="/api/v1/auth/login",
method="POST",
user_agent=request.headers.get("user-agent")
)
)
# 不返回 401统一返回 HTTP 200 + 业务失败码,便于前端友好提示
logger.warning("login_failed_wrong_credentials", username=login_data.username)
return ResponseModel(
code=400,
message=str(e) or "用户名或密码错误",
data=None,
)
except Exception as e:
logger.error("login_failed_unexpected", error=str(e))
return ResponseModel(
code=500,
message="登录失败,请稍后重试",
data=None,
)
@router.post("/refresh", response_model=ResponseModel)
async def refresh_token(
refresh_data: RefreshTokenRequest,
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
刷新访问令牌
使用刷新令牌获取新的访问令牌
"""
auth_service = AuthService(db)
token = await auth_service.refresh_token(refresh_data.refresh_token)
return ResponseModel(message="令牌刷新成功", data=token.model_dump())
@router.post("/logout", response_model=ResponseModel)
async def logout(
request: Request,
current_user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
用户登出
注意:客户端需要删除本地存储的令牌
"""
auth_service = AuthService(db)
await auth_service.logout(current_user.id)
# 记录登出日志
await system_log_service.create_log(
db,
SystemLogCreate(
level="INFO",
type="security",
message=f"用户 {current_user.username} 登出",
user_id=current_user.id,
user=current_user.username,
ip=request.client.host if request.client else None,
path="/api/v1/auth/logout",
method="POST",
user_agent=request.headers.get("user-agent")
)
)
return ResponseModel(message="登出成功")
@router.get("/verify", response_model=ResponseModel)
async def verify_token(
current_user: User = Depends(get_current_active_user),
) -> ResponseModel:
"""
验证令牌
用于检查当前令牌是否有效
"""
return ResponseModel(
message="令牌有效",
data={
"user": UserSchema.model_validate(current_user).model_dump(),
},
)

View File

@@ -0,0 +1,145 @@
"""
播课功能 API 接口
"""
import logging
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_db, get_current_user, require_admin_or_manager
from app.schemas.base import ResponseModel
from app.models.course import Course
from app.models.user import User
from app.services.coze_broadcast_service import broadcast_service
logger = logging.getLogger(__name__)
router = APIRouter()
# Schema 定义
class GenerateBroadcastResponse(BaseModel):
"""生成播课响应"""
message: str = Field(..., description="提示信息")
class BroadcastInfo(BaseModel):
"""播课信息"""
has_broadcast: bool = Field(..., description="是否有播课")
mp3_url: Optional[str] = Field(None, description="播课音频URL")
generated_at: Optional[datetime] = Field(None, description="生成时间")
@router.post("/courses/{course_id}/generate-broadcast", response_model=ResponseModel[GenerateBroadcastResponse])
async def generate_broadcast(
course_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_admin_or_manager)
):
"""
触发播课音频生成立即返回Coze工作流会直接写数据库
权限manager、admin
Args:
course_id: 课程ID
db: 数据库会话
current_user: 当前用户
Returns:
启动提示信息
Raises:
HTTPException 404: 课程不存在
"""
logger.info(
f"请求生成播课",
extra={"course_id": course_id, "user_id": current_user.id}
)
# 查询课程
result = await db.execute(
select(Course)
.where(Course.id == course_id)
.where(Course.is_deleted == False)
)
course = result.scalar_one_or_none()
if not course:
logger.warning(f"课程不存在", extra={"course_id": course_id})
raise HTTPException(status_code=404, detail="课程不存在")
# 调用 Coze 工作流(不等待结果,工作流会直接写数据库)
try:
await broadcast_service.trigger_workflow(course_id)
logger.info(
f"播课生成工作流已触发",
extra={"course_id": course_id, "user_id": current_user.id}
)
return ResponseModel(
code=200,
message="播课生成已启动",
data=GenerateBroadcastResponse(
message="播课生成工作流已启动,生成完成后将自动更新"
)
)
except Exception as e:
logger.error(
f"触发播课生成失败",
extra={"course_id": course_id, "error": str(e)}
)
raise HTTPException(status_code=500, detail=f"触发播课生成失败: {str(e)}")
@router.get("/courses/{course_id}/broadcast", response_model=ResponseModel[BroadcastInfo])
async def get_broadcast_info(
course_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取播课信息
权限:所有登录用户
Args:
course_id: 课程ID
db: 数据库会话
current_user: 当前用户
Returns:
播课信息
Raises:
HTTPException 404: 课程不存在
"""
# 查询课程
result = await db.execute(
select(Course)
.where(Course.id == course_id)
.where(Course.is_deleted == False)
)
course = result.scalar_one_or_none()
if not course:
raise HTTPException(status_code=404, detail="课程不存在")
# 构建播课信息
has_broadcast = bool(course.broadcast_audio_url)
return ResponseModel(
code=200,
message="success",
data=BroadcastInfo(
has_broadcast=has_broadcast,
mp3_url=course.broadcast_audio_url if has_broadcast else None,
generated_at=course.broadcast_generated_at if has_broadcast else None
)
)

View File

@@ -0,0 +1,190 @@
"""
与课程对话 API
使用 Python 原生 AI 服务实现
"""
import json
import logging
from typing import Optional, Any
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_db, get_current_user
from app.models.user import User
from app.services.ai.course_chat_service import course_chat_service_v2
router = APIRouter()
logger = logging.getLogger(__name__)
class CourseChatRequest(BaseModel):
"""课程对话请求"""
course_id: int = Field(..., description="课程ID")
query: str = Field(..., description="用户问题")
conversation_id: Optional[str] = Field(None, description="会话ID续接对话时传入")
class ResponseModel(BaseModel):
"""通用响应模型"""
code: int = 200
message: str = "success"
data: Optional[Any] = None
async def _chat_with_course(
request: CourseChatRequest,
current_user: User,
db: AsyncSession
):
"""
Python 原生实现的流式对话
"""
logger.info(
f"用户 {current_user.username} 与课程 {request.course_id} 对话: "
f"{request.query[:50]}..."
)
async def generate_stream():
"""生成 SSE 流"""
try:
async for event_type, data in course_chat_service_v2.chat_stream(
db=db,
course_id=request.course_id,
query=request.query,
user_id=current_user.id,
conversation_id=request.conversation_id
):
if event_type == "conversation_started":
yield f"data: {json.dumps({'event': 'conversation_started', 'conversation_id': data})}\n\n"
logger.info(f"会话已创建: {data}")
elif event_type == "chunk":
yield f"data: {json.dumps({'event': 'message_chunk', 'chunk': data})}\n\n"
elif event_type == "done":
yield f"data: {json.dumps({'event': 'message_end', 'message': data})}\n\n"
logger.info(f"对话完成,总长度: {len(data)}")
elif event_type == "error":
yield f"data: {json.dumps({'event': 'error', 'message': data})}\n\n"
logger.error(f"对话错误: {data}")
except Exception as e:
logger.error(f"流式对话异常: {e}", exc_info=True)
yield f"data: {json.dumps({'event': 'error', 'message': str(e)})}\n\n"
return StreamingResponse(
generate_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
@router.post("/chat")
async def chat_with_course(
request: CourseChatRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
与课程对话(流式响应)
使用 Python 原生 AI 服务实现,支持多轮对话。
"""
return await _chat_with_course(request, current_user, db)
@router.get("/conversations")
async def get_conversations(
course_id: Optional[int] = None,
limit: int = 20,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
获取会话列表
返回当前用户的历史会话列表
"""
try:
conversations = await course_chat_service_v2.get_conversations(
user_id=current_user.id,
course_id=course_id,
limit=limit
)
return ResponseModel(
code=200,
message="获取会话列表成功",
data={
"conversations": conversations,
"total": len(conversations)
}
)
except Exception as e:
logger.error(f"获取会话列表失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取会话列表失败: {str(e)}")
@router.get("/messages")
async def get_messages(
conversation_id: str,
limit: int = 50,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
获取历史消息
返回指定会话的历史消息
"""
try:
messages = await course_chat_service_v2.get_messages(
conversation_id=conversation_id,
user_id=current_user.id,
limit=limit
)
return ResponseModel(
code=200,
message="获取历史消息成功",
data={
"messages": messages,
"total": len(messages)
}
)
except Exception as e:
logger.error(f"获取历史消息失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取历史消息失败: {str(e)}")
@router.get("/engines")
async def list_chat_engines():
"""
获取可用的对话引擎列表
"""
return ResponseModel(
code=200,
message="获取对话引擎列表成功",
data={
"engines": [
{
"id": "native",
"name": "Python 原生实现",
"description": "使用本地 AI 服务4sapi.com + OpenRouter支持流式输出和多轮对话",
"default": True
}
],
"default_engine": "native"
}
)

View File

@@ -0,0 +1,786 @@
"""
课程管理API路由
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, Query, status, BackgroundTasks, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_db, get_current_user, require_admin, require_admin_or_manager, User
from app.core.exceptions import NotFoundError, BadRequestError
from app.core.logger import get_logger
from app.services.system_log_service import system_log_service
from app.schemas.system_log import SystemLogCreate
from app.models.course import CourseStatus, CourseCategory
from app.schemas.base import ResponseModel, PaginationParams, PaginatedResponse
from app.schemas.course import (
CourseCreate,
CourseUpdate,
CourseInDB,
CourseList,
CourseMaterialCreate,
CourseMaterialInDB,
KnowledgePointCreate,
KnowledgePointUpdate,
KnowledgePointInDB,
GrowthPathCreate,
GrowthPathInDB,
CourseExamSettingsCreate,
CourseExamSettingsUpdate,
CourseExamSettingsInDB,
CoursePositionAssignment,
CoursePositionAssignmentInDB,
)
from app.services.course_service import (
course_service,
knowledge_point_service,
growth_path_service,
)
logger = get_logger(__name__)
router = APIRouter(prefix="/courses", tags=["courses"])
@router.get("", response_model=ResponseModel[PaginatedResponse[CourseInDB]])
async def get_courses(
page: int = Query(1, ge=1, description="页码"),
size: int = Query(20, ge=1, le=100, description="每页数量"),
status: Optional[CourseStatus] = Query(None, description="课程状态"),
category: Optional[CourseCategory] = Query(None, description="课程分类"),
is_featured: Optional[bool] = Query(None, description="是否推荐"),
keyword: Optional[str] = Query(None, description="搜索关键词"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
获取课程列表(支持分页和筛选)
- **page**: 页码
- **size**: 每页数量
- **status**: 课程状态筛选
- **category**: 课程分类筛选
- **is_featured**: 是否推荐筛选
- **keyword**: 关键词搜索(搜索名称和描述)
"""
page_params = PaginationParams(page=page, page_size=size)
filters = CourseList(
status=status, category=category, is_featured=is_featured, keyword=keyword
)
result = await course_service.get_course_list(
db, page_params=page_params, filters=filters, user_id=current_user.id
)
return ResponseModel(data=result, message="获取课程列表成功")
@router.post(
"", response_model=ResponseModel[CourseInDB], status_code=status.HTTP_201_CREATED
)
async def create_course(
course_in: CourseCreate,
request: Request,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""
创建课程(需要管理员权限)
- **name**: 课程名称
- **description**: 课程描述
- **category**: 课程分类
- **status**: 课程状态(默认为草稿)
- **cover_image**: 封面图片URL
- **duration_hours**: 课程时长(小时)
- **difficulty_level**: 难度等级1-5
- **tags**: 标签列表
- **is_featured**: 是否推荐
"""
course = await course_service.create_course(
db, course_in=course_in, created_by=current_user.id
)
# 记录课程创建日志
await system_log_service.create_log(
db,
SystemLogCreate(
level="INFO",
type="api",
message=f"创建课程: {course.name}",
user_id=current_user.id,
user=current_user.username,
ip=request.client.host if request.client else None,
path="/api/v1/courses",
method="POST",
user_agent=request.headers.get("user-agent")
)
)
return ResponseModel(data=course, message="创建课程成功")
@router.get("/{course_id}", response_model=ResponseModel[CourseInDB])
async def get_course(
course_id: int,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
获取课程详情
- **course_id**: 课程ID
"""
course = await course_service.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
logger.info(f"查看课程详情 - course_id: {course_id}, user_id: {current_user.id}")
return ResponseModel(data=course, message="获取课程详情成功")
@router.put("/{course_id}", response_model=ResponseModel[CourseInDB])
async def update_course(
course_id: int,
course_in: CourseUpdate,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""
更新课程(需要管理员权限)
- **course_id**: 课程ID
- **course_in**: 更新的课程数据(所有字段都是可选的)
"""
course = await course_service.update_course(
db, course_id=course_id, course_in=course_in, updated_by=current_user.id
)
return ResponseModel(data=course, message="更新课程成功")
@router.delete("/{course_id}", response_model=ResponseModel[bool])
async def delete_course(
course_id: int,
request: Request,
current_user: User = Depends(require_admin_or_manager),
db: AsyncSession = Depends(get_db),
):
"""
删除课程(需要管理员权限)
- **course_id**: 课程ID
说明任意状态均可软删除is_deleted=1请谨慎操作
"""
# 先获取课程信息
course = await course_service.get_by_id(db, course_id)
course_name = course.name if course else f"ID:{course_id}"
success = await course_service.delete_course(
db, course_id=course_id, deleted_by=current_user.id
)
# 记录课程删除日志
if success:
await system_log_service.create_log(
db,
SystemLogCreate(
level="INFO",
type="api",
message=f"删除课程: {course_name}",
user_id=current_user.id,
user=current_user.username,
ip=request.client.host if request.client else None,
path=f"/api/v1/courses/{course_id}",
method="DELETE",
user_agent=request.headers.get("user-agent")
)
)
return ResponseModel(data=success, message="删除课程成功" if success else "删除课程失败")
# 课程资料相关API
@router.post(
"/{course_id}/materials",
response_model=ResponseModel[CourseMaterialInDB],
status_code=status.HTTP_201_CREATED,
)
async def add_course_material(
course_id: int,
material_in: CourseMaterialCreate,
background_tasks: BackgroundTasks,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""
添加课程资料(需要管理员权限)
- **course_id**: 课程ID
- **name**: 资料名称
- **description**: 资料描述
- **file_url**: 文件URL
- **file_type**: 文件类型pdf, doc, docx, ppt, pptx, xls, xlsx, mp4, mp3, zip
- **file_size**: 文件大小(字节)
添加资料后会自动触发知识点分析
"""
material = await course_service.add_course_material(
db, course_id=course_id, material_in=material_in, created_by=current_user.id
)
# 获取课程信息用于知识点分析
course = await course_service.get_by_id(db, course_id)
if course:
# 异步触发知识点分析
from app.services.ai.knowledge_analysis_v2 import knowledge_analysis_service_v2
background_tasks.add_task(
_trigger_knowledge_analysis,
db,
course_id,
material.id,
material.file_url,
course.name,
current_user.id
)
logger.info(
f"资料添加成功,已触发知识点分析 - course_id: {course_id}, material_id: {material.id}, user_id: {current_user.id}"
)
return ResponseModel(data=material, message="添加课程资料成功")
@router.get(
"/{course_id}/materials",
response_model=ResponseModel[List[CourseMaterialInDB]],
)
async def list_course_materials(
course_id: int,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
获取课程资料列表
- **course_id**: 课程ID
"""
materials = await course_service.get_course_materials(db, course_id=course_id)
return ResponseModel(data=materials, message="获取课程资料列表成功")
@router.delete(
"/{course_id}/materials/{material_id}",
response_model=ResponseModel[bool],
)
async def delete_course_material(
course_id: int,
material_id: int,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""
删除课程资料(需要管理员权限)
- **course_id**: 课程ID
- **material_id**: 资料ID
"""
success = await course_service.delete_course_material(
db, course_id=course_id, material_id=material_id, deleted_by=current_user.id
)
return ResponseModel(data=success, message="删除课程资料成功" if success else "删除课程资料失败")
# 知识点相关API
@router.get(
"/{course_id}/knowledge-points",
response_model=ResponseModel[List[KnowledgePointInDB]],
)
async def get_course_knowledge_points(
course_id: int,
material_id: Optional[int] = Query(None, description="资料ID"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
获取课程的知识点列表
- **course_id**: 课程ID
- **material_id**: 资料ID可选用于筛选特定资料的知识点
"""
# 先检查课程是否存在
course = await course_service.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
knowledge_points = await knowledge_point_service.get_knowledge_points_by_course(
db, course_id=course_id, material_id=material_id
)
return ResponseModel(data=knowledge_points, message="获取知识点列表成功")
@router.post(
"/{course_id}/knowledge-points",
response_model=ResponseModel[KnowledgePointInDB],
status_code=status.HTTP_201_CREATED,
)
async def create_knowledge_point(
course_id: int,
point_in: KnowledgePointCreate,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""
创建知识点(需要管理员权限)
- **course_id**: 课程ID
- **name**: 知识点名称
- **description**: 知识点描述
- **parent_id**: 父知识点ID
- **weight**: 权重0-10
- **is_required**: 是否必修
- **estimated_hours**: 预计学习时间(小时)
"""
knowledge_point = await knowledge_point_service.create_knowledge_point(
db, course_id=course_id, point_in=point_in, created_by=current_user.id
)
return ResponseModel(data=knowledge_point, message="创建知识点成功")
@router.put(
"/knowledge-points/{point_id}", response_model=ResponseModel[KnowledgePointInDB]
)
async def update_knowledge_point(
point_id: int,
point_in: KnowledgePointUpdate,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""
更新知识点(需要管理员权限)
- **point_id**: 知识点ID
- **point_in**: 更新的知识点数据(所有字段都是可选的)
"""
knowledge_point = await knowledge_point_service.update_knowledge_point(
db, point_id=point_id, point_in=point_in, updated_by=current_user.id
)
return ResponseModel(data=knowledge_point, message="更新知识点成功")
@router.delete("/knowledge-points/{point_id}", response_model=ResponseModel[bool])
async def delete_knowledge_point(
point_id: int,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""
删除知识点(需要管理员权限)
- **point_id**: 知识点ID
"""
success = await knowledge_point_service.delete(
db, id=point_id, soft=True, deleted_by=current_user.id
)
if success:
logger.warning("删除知识点", knowledge_point_id=point_id, deleted_by=current_user.id)
return ResponseModel(data=success, message="删除知识点成功" if success else "删除知识点失败")
# 资料知识点关联API
@router.get(
"/materials/{material_id}/knowledge-points",
response_model=ResponseModel[List[KnowledgePointInDB]],
)
async def get_material_knowledge_points(
material_id: int,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
获取资料关联的知识点列表
"""
knowledge_points = await course_service.get_material_knowledge_points(
db, material_id=material_id
)
return ResponseModel(data=knowledge_points, message="获取知识点列表成功")
@router.post(
"/materials/{material_id}/knowledge-points",
response_model=ResponseModel[List[KnowledgePointInDB]],
status_code=status.HTTP_201_CREATED,
)
async def add_material_knowledge_points(
material_id: int,
knowledge_point_ids: List[int],
current_user: User = Depends(require_admin_or_manager),
db: AsyncSession = Depends(get_db),
):
"""
为资料添加知识点关联(需要管理员或经理权限)
"""
knowledge_points = await course_service.add_material_knowledge_points(
db, material_id=material_id, knowledge_point_ids=knowledge_point_ids
)
return ResponseModel(data=knowledge_points, message="添加知识点成功")
@router.delete(
"/materials/{material_id}/knowledge-points/{knowledge_point_id}",
response_model=ResponseModel[bool],
)
async def remove_material_knowledge_point(
material_id: int,
knowledge_point_id: int,
current_user: User = Depends(require_admin_or_manager),
db: AsyncSession = Depends(get_db),
):
"""
移除资料的知识点关联(需要管理员或经理权限)
"""
success = await course_service.remove_material_knowledge_point(
db, material_id=material_id, knowledge_point_id=knowledge_point_id
)
return ResponseModel(data=success, message="移除知识点成功" if success else "移除失败")
# 成长路径相关API
@router.post(
"/growth-paths",
response_model=ResponseModel[GrowthPathInDB],
status_code=status.HTTP_201_CREATED,
)
async def create_growth_path(
path_in: GrowthPathCreate,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""
创建成长路径(需要管理员权限)
- **name**: 路径名称
- **description**: 路径描述
- **target_role**: 目标角色
- **courses**: 课程列表包含course_id、order、is_required
- **estimated_duration_days**: 预计完成天数
- **is_active**: 是否启用
"""
growth_path = await growth_path_service.create_growth_path(
db, path_in=path_in, created_by=current_user.id
)
return ResponseModel(data=growth_path, message="创建成长路径成功")
@router.get(
"/growth-paths", response_model=ResponseModel[PaginatedResponse[GrowthPathInDB]]
)
async def get_growth_paths(
page: int = Query(1, ge=1, description="页码"),
size: int = Query(20, ge=1, le=100, description="每页数量"),
is_active: Optional[bool] = Query(None, description="是否启用"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
获取成长路径列表
- **page**: 页码
- **size**: 每页数量
- **is_active**: 是否启用筛选
"""
page_params = PaginationParams(page=page, page_size=size)
filters = []
if is_active is not None:
from app.models.course import GrowthPath
filters.append(GrowthPath.is_active == is_active)
result = await growth_path_service.get_page(
db, page_params=page_params, filters=filters
)
return ResponseModel(data=result, message="获取成长路径列表成功")
# 课程考试设置相关API
@router.get(
"/{course_id}/exam-settings",
response_model=ResponseModel[Optional[CourseExamSettingsInDB]],
)
async def get_course_exam_settings(
course_id: int,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
获取课程的考试设置
- **course_id**: 课程ID
"""
# 检查课程是否存在
course = await course_service.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
# 获取考试设置
from app.services.course_exam_service import course_exam_service
settings = await course_exam_service.get_by_course_id(db, course_id)
# 添加调试日志
if settings:
logger.info(
f"📊 获取考试设置成功 - course_id: {course_id}, "
f"单选: {settings.single_choice_count}, 多选: {settings.multiple_choice_count}, "
f"判断: {settings.true_false_count}, 填空: {settings.fill_blank_count}, "
f"问答: {settings.essay_count}, 难度: {settings.difficulty_level}"
)
else:
logger.warning(f"⚠️ 课程 {course_id} 没有配置考试设置,将使用默认值")
return ResponseModel(data=settings, message="获取考试设置成功")
@router.post(
"/{course_id}/exam-settings",
response_model=ResponseModel[CourseExamSettingsInDB],
status_code=status.HTTP_201_CREATED,
)
async def create_course_exam_settings(
course_id: int,
settings_in: CourseExamSettingsCreate,
current_user: User = Depends(require_admin_or_manager),
db: AsyncSession = Depends(get_db),
):
"""
创建或更新课程的考试设置(需要管理员权限)
- **course_id**: 课程ID
- **settings_in**: 考试设置数据
"""
# 检查课程是否存在
course = await course_service.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
# 创建或更新考试设置
from app.services.course_exam_service import course_exam_service
settings = await course_exam_service.create_or_update(
db, course_id=course_id, settings_in=settings_in, user_id=current_user.id
)
return ResponseModel(data=settings, message="保存考试设置成功")
@router.put(
"/{course_id}/exam-settings",
response_model=ResponseModel[CourseExamSettingsInDB],
)
async def update_course_exam_settings(
course_id: int,
settings_in: CourseExamSettingsUpdate,
current_user: User = Depends(require_admin_or_manager),
db: AsyncSession = Depends(get_db),
):
"""
更新课程的考试设置(需要管理员权限)
- **course_id**: 课程ID
- **settings_in**: 更新的考试设置数据
"""
# 检查课程是否存在
course = await course_service.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
# 更新考试设置
from app.services.course_exam_service import course_exam_service
settings = await course_exam_service.update(
db, course_id=course_id, settings_in=settings_in, user_id=current_user.id
)
return ResponseModel(data=settings, message="更新考试设置成功")
# 课程岗位分配相关API
@router.get(
"/{course_id}/positions",
response_model=ResponseModel[List[CoursePositionAssignmentInDB]],
)
async def get_course_positions(
course_id: int,
course_type: Optional[str] = Query(None, pattern="^(required|optional)$", description="课程类型筛选"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
获取课程的岗位分配列表
- **course_id**: 课程ID
- **course_type**: 课程类型筛选required必修/optional选修
"""
# 检查课程是否存在
course = await course_service.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
# 获取岗位分配列表
from app.services.course_position_service import course_position_service
assignments = await course_position_service.get_course_positions(
db, course_id=course_id, course_type=course_type
)
return ResponseModel(data=assignments, message="获取岗位分配列表成功")
@router.post(
"/{course_id}/positions",
response_model=ResponseModel[List[CoursePositionAssignmentInDB]],
status_code=status.HTTP_201_CREATED,
)
async def assign_course_positions(
course_id: int,
assignments: List[CoursePositionAssignment],
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""
批量分配课程到岗位(需要管理员权限)
- **course_id**: 课程ID
- **assignments**: 岗位分配列表
"""
# 检查课程是否存在
course = await course_service.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
# 批量分配岗位
from app.services.course_position_service import course_position_service
result = await course_position_service.batch_assign_positions(
db, course_id=course_id, assignments=assignments, user_id=current_user.id
)
# 发送课程分配通知给相关岗位的学员
try:
from app.models.position_member import PositionMember
from app.services.notification_service import notification_service
from app.schemas.notification import NotificationBatchCreate, NotificationType
# 获取所有分配岗位的学员ID
position_ids = [a.position_id for a in assignments]
if position_ids:
member_result = await db.execute(
select(PositionMember.user_id).where(
PositionMember.position_id.in_(position_ids),
PositionMember.is_deleted == False
).distinct()
)
user_ids = [row[0] for row in member_result.fetchall()]
if user_ids:
notification_batch = NotificationBatchCreate(
user_ids=user_ids,
title="新课程通知",
content=f"您所在岗位有新课程「{course.name}」已分配,请及时学习。",
type=NotificationType.COURSE_ASSIGN,
related_id=course_id,
related_type="course",
sender_id=current_user.id
)
await notification_service.batch_create_notifications(
db=db,
batch_in=notification_batch
)
except Exception as e:
# 通知发送失败不影响课程分配结果
import logging
logging.getLogger(__name__).error(f"发送课程分配通知失败: {str(e)}")
return ResponseModel(data=result, message="岗位分配成功")
@router.delete(
"/{course_id}/positions/{position_id}",
response_model=ResponseModel[bool],
)
async def remove_course_position(
course_id: int,
position_id: int,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""
移除课程的岗位分配(需要管理员权限)
- **course_id**: 课程ID
- **position_id**: 岗位ID
"""
# 检查课程是否存在
course = await course_service.get_by_id(db, course_id)
if not course:
raise NotFoundError(f"课程ID {course_id} 不存在")
# 移除岗位分配
from app.services.course_position_service import course_position_service
success = await course_position_service.remove_position_assignment(
db, course_id=course_id, position_id=position_id, user_id=current_user.id
)
return ResponseModel(data=success, message="移除岗位分配成功" if success else "移除岗位分配失败")
async def _trigger_knowledge_analysis(
db: AsyncSession,
course_id: int,
material_id: int,
file_url: str,
course_title: str,
user_id: int
):
"""
后台触发知识点分析任务
注意:此函数在后台任务中执行,异常不会影响资料添加的成功响应
"""
try:
from app.services.ai.knowledge_analysis_v2 import knowledge_analysis_service_v2
logger.info(
f"后台知识点分析开始 - course_id: {course_id}, material_id: {material_id}, file_url: {file_url}, user_id: {user_id}"
)
result = await knowledge_analysis_service_v2.analyze_course_material(
db=db,
course_id=course_id,
material_id=material_id,
file_url=file_url,
course_title=course_title,
user_id=user_id
)
logger.info(
f"后台知识点分析完成 - course_id: {course_id}, material_id: {material_id}, knowledge_points_count: {result.get('knowledge_points_count', 0)}, user_id: {user_id}"
)
except FileNotFoundError as e:
# 文件不存在时记录警告,但不记录完整堆栈
logger.warning(
f"后台知识点分析失败(文件不存在) - course_id: {course_id}, material_id: {material_id}, "
f"file_url: {file_url}, error: {str(e)}, user_id: {user_id}"
)
except Exception as e:
# 其他异常记录详细信息
logger.error(
f"后台知识点分析失败 - course_id: {course_id}, material_id: {material_id}, error: {str(e)}",
exc_info=True
)

View File

@@ -0,0 +1,275 @@
"""
Coze 网关 API 路由
提供课程对话和陪练功能的统一接口
"""
import logging
from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sse_starlette.sse import EventSourceResponse
from app.services.ai.coze import (
get_coze_service,
CreateSessionRequest,
SendMessageRequest,
EndSessionRequest,
SessionType,
CozeException,
StreamEventType,
)
logger = logging.getLogger(__name__)
router = APIRouter(tags=["coze-gateway"])
# TODO: 依赖注入获取当前用户
async def get_current_user():
"""获取当前登录用户(临时实现)"""
# 实际应该从 Auth 模块获取
return {"user_id": "test-user-123", "username": "test_user"}
@router.post("/course-chat/sessions")
async def create_course_chat_session(course_id: str, user=Depends(get_current_user)):
"""
创建课程对话会话
- **course_id**: 课程ID
"""
try:
service = get_coze_service()
request = CreateSessionRequest(
session_type=SessionType.COURSE_CHAT,
user_id=user["user_id"],
course_id=course_id,
metadata={"username": user["username"], "course_id": course_id},
)
response = await service.create_session(request)
return {"code": 200, "message": "success", "data": response.dict()}
except CozeException as e:
logger.error(f"创建课程对话会话失败: {e}")
raise HTTPException(
status_code=e.status_code or 500,
detail={"code": e.code, "message": e.message, "details": e.details},
)
except Exception as e:
logger.error(f"未知错误: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"},
)
@router.post("/training/sessions")
async def create_training_session(
training_topic: str = None, user=Depends(get_current_user)
):
"""
创建陪练会话
- **training_topic**: 陪练主题(可选)
"""
try:
service = get_coze_service()
request = CreateSessionRequest(
session_type=SessionType.TRAINING,
user_id=user["user_id"],
training_topic=training_topic,
metadata={"username": user["username"], "training_topic": training_topic},
)
response = await service.create_session(request)
return {"code": 200, "message": "success", "data": response.dict()}
except CozeException as e:
logger.error(f"创建陪练会话失败: {e}")
raise HTTPException(
status_code=e.status_code or 500,
detail={"code": e.code, "message": e.message, "details": e.details},
)
except Exception as e:
logger.error(f"未知错误: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"},
)
@router.post("/training/sessions/{session_id}/end")
async def end_training_session(
session_id: str, request: EndSessionRequest, user=Depends(get_current_user)
):
"""
结束陪练会话
- **session_id**: 会话ID
"""
try:
service = get_coze_service()
response = await service.end_session(session_id, request)
return {"code": 200, "message": "success", "data": response.dict()}
except CozeException as e:
logger.error(f"结束会话失败: {e}")
raise HTTPException(
status_code=e.status_code or 500,
detail={"code": e.code, "message": e.message, "details": e.details},
)
except Exception as e:
logger.error(f"未知错误: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"},
)
@router.post("/chat/messages")
async def send_message(request: SendMessageRequest, user=Depends(get_current_user)):
"""
发送消息(支持流式响应)
- **session_id**: 会话ID
- **content**: 消息内容
- **stream**: 是否流式响应默认True
"""
try:
service = get_coze_service()
if request.stream:
# 流式响应
async def event_generator():
async for event in service.send_message(request):
# 转换为 SSE 格式
if event.event == StreamEventType.MESSAGE_DELTA:
yield {
"event": "message",
"data": {
"type": "delta",
"content": event.content,
"content_type": event.content_type.value,
"message_id": event.message_id,
},
}
elif event.event == StreamEventType.MESSAGE_COMPLETED:
yield {
"event": "message",
"data": {
"type": "completed",
"content": event.content,
"content_type": event.content_type.value,
"message_id": event.message_id,
"usage": event.data.get("usage", {}),
},
}
elif event.event == StreamEventType.ERROR:
yield {"event": "error", "data": {"error": event.error}}
elif event.event == StreamEventType.DONE:
yield {
"event": "done",
"data": {"session_id": event.data.get("session_id")},
}
return EventSourceResponse(event_generator())
else:
# 非流式响应(收集完整响应)
full_content = ""
content_type = None
message_id = None
async for event in service.send_message(request):
if event.event == StreamEventType.MESSAGE_COMPLETED:
full_content = event.content
content_type = event.content_type
message_id = event.message_id
break
return {
"code": 200,
"message": "success",
"data": {
"message_id": message_id,
"content": full_content,
"content_type": content_type.value if content_type else "text",
"role": "assistant",
},
}
except CozeException as e:
logger.error(f"发送消息失败: {e}")
if request.stream:
# 流式响应的错误处理
async def error_generator():
yield {
"event": "error",
"data": {
"code": e.code,
"message": e.message,
"details": e.details,
},
}
return EventSourceResponse(error_generator())
else:
raise HTTPException(
status_code=e.status_code or 500,
detail={"code": e.code, "message": e.message, "details": e.details},
)
except Exception as e:
logger.error(f"未知错误: {e}", exc_info=True)
if request.stream:
async def error_generator():
yield {
"event": "error",
"data": {"code": "INTERNAL_ERROR", "message": "服务器内部错误"},
}
return EventSourceResponse(error_generator())
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"},
)
@router.get("/sessions/{session_id}/messages")
async def get_session_messages(
session_id: str, limit: int = 50, offset: int = 0, user=Depends(get_current_user)
):
"""
获取会话消息历史
- **session_id**: 会话ID
- **limit**: 返回消息数量限制
- **offset**: 偏移量
"""
try:
service = get_coze_service()
messages = await service.get_session_messages(session_id, limit, offset)
return {
"code": 200,
"message": "success",
"data": {
"messages": [msg.dict() for msg in messages],
"total": len(messages),
"limit": limit,
"offset": offset,
},
}
except Exception as e:
logger.error(f"获取消息历史失败: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"code": "INTERNAL_ERROR", "message": "服务器内部错误"},
)

View File

@@ -0,0 +1,236 @@
"""
员工同步API接口
提供从钉钉员工表同步员工数据的功能
"""
from typing import Any, Dict
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.logger import get_logger
from app.core.deps import get_current_user, get_db
from app.services.employee_sync_service import EmployeeSyncService
from app.models.user import User
logger = get_logger(__name__)
router = APIRouter()
@router.post("/sync", summary="执行员工同步")
async def sync_employees(
*,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> Dict[str, Any]:
"""
从钉钉员工表同步在职员工数据到考培练系统
权限要求: 仅管理员可执行
同步内容:
- 创建用户账号(用户名=手机号,初始密码=123456
- 创建部门团队
- 创建岗位并关联用户
- 设置领导为团队负责人
Returns:
同步结果统计
"""
# 权限检查:仅管理员可执行
if current_user.role != 'admin':
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="只有管理员可以执行员工同步"
)
logger.info(f"管理员 {current_user.username} 开始执行员工同步")
try:
async with EmployeeSyncService(db) as sync_service:
stats = await sync_service.sync_employees()
return {
"success": True,
"message": "员工同步完成",
"data": stats
}
except Exception as e:
logger.error(f"员工同步失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"员工同步失败: {str(e)}"
)
@router.get("/preview", summary="预览待同步员工数据")
async def preview_sync_data(
*,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> Dict[str, Any]:
"""
预览待同步的员工数据(不执行实际同步)
权限要求: 仅管理员可查看
Returns:
预览数据,包括员工列表、部门列表、岗位列表等
"""
# 权限检查:仅管理员可查看
if current_user.role != 'admin':
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="只有管理员可以预览员工数据"
)
logger.info(f"管理员 {current_user.username} 预览员工同步数据")
try:
async with EmployeeSyncService(db) as sync_service:
preview_data = await sync_service.preview_sync_data()
return {
"success": True,
"message": "预览数据获取成功",
"data": preview_data
}
except Exception as e:
logger.error(f"预览数据获取失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"预览数据获取失败: {str(e)}"
)
@router.post("/incremental-sync", summary="增量同步员工")
async def incremental_sync_employees(
*,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> Dict[str, Any]:
"""
增量同步钉钉员工数据
功能说明:
- 新增:钉钉有但系统没有的员工
- 删除:系统有但钉钉没有的员工(物理删除)
- 跳过:两边都存在的员工(不修改任何信息)
权限要求: 管理员admin 或 manager可执行
Returns:
同步结果统计
"""
# 权限检查:管理员或经理可执行
if current_user.role not in ['admin', 'manager']:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="只有管理员可以执行员工同步"
)
logger.info(f"用户 {current_user.username} ({current_user.role}) 开始执行增量员工同步")
try:
async with EmployeeSyncService(db) as sync_service:
stats = await sync_service.incremental_sync_employees()
return {
"success": True,
"message": "增量同步完成",
"data": {
"added_count": stats['added_count'],
"deleted_count": stats['deleted_count'],
"skipped_count": stats['skipped_count'],
"added_users": stats['added_users'],
"deleted_users": stats['deleted_users'],
"errors": stats['errors'],
"duration": stats['duration']
}
}
except Exception as e:
logger.error(f"增量同步失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"增量同步失败: {str(e)}"
)
@router.get("/status", summary="查询同步状态")
async def get_sync_status(
*,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> Dict[str, Any]:
"""
查询当前系统的用户、团队、岗位统计信息
Returns:
统计信息
"""
from sqlalchemy import select, func
from app.models.user import User, Team
from app.models.position import Position
try:
# 统计用户数量
user_count_stmt = select(func.count(User.id)).where(User.is_deleted == False)
user_result = await db.execute(user_count_stmt)
total_users = user_result.scalar()
# 统计各角色用户数量
admin_count_stmt = select(func.count(User.id)).where(
User.is_deleted == False,
User.role == 'admin'
)
admin_result = await db.execute(admin_count_stmt)
admin_count = admin_result.scalar()
manager_count_stmt = select(func.count(User.id)).where(
User.is_deleted == False,
User.role == 'manager'
)
manager_result = await db.execute(manager_count_stmt)
manager_count = manager_result.scalar()
trainee_count_stmt = select(func.count(User.id)).where(
User.is_deleted == False,
User.role == 'trainee'
)
trainee_result = await db.execute(trainee_count_stmt)
trainee_count = trainee_result.scalar()
# 统计团队数量
team_count_stmt = select(func.count(Team.id)).where(Team.is_deleted == False)
team_result = await db.execute(team_count_stmt)
total_teams = team_result.scalar()
# 统计岗位数量
position_count_stmt = select(func.count(Position.id)).where(Position.is_deleted == False)
position_result = await db.execute(position_count_stmt)
total_positions = position_result.scalar()
return {
"success": True,
"data": {
"users": {
"total": total_users,
"admin": admin_count,
"manager": manager_count,
"trainee": trainee_count
},
"teams": total_teams,
"positions": total_positions
}
}
except Exception as e:
logger.error(f"查询统计信息失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"查询统计信息失败: {str(e)}"
)

761
backend/app/api/v1/exam.py Normal file
View File

@@ -0,0 +1,761 @@
"""
考试相关API路由
"""
from typing import List, Optional
import json
from datetime import datetime
from fastapi import APIRouter, Depends, Query, HTTPException, status, Request
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.deps import get_db, get_current_user
from app.core.config import get_settings
from app.core.logger import get_logger
from app.models.user import User
from app.models.exam import Exam
from app.models.exam_mistake import ExamMistake
from app.models.position_member import PositionMember
from app.models.position_course import PositionCourse
from app.schemas.base import ResponseModel
from app.schemas.exam import (
StartExamRequest,
StartExamResponse,
SubmitExamRequest,
SubmitExamResponse,
ExamDetailResponse,
ExamRecordResponse,
GenerateExamRequest,
GenerateExamResponse,
JudgeAnswerRequest,
JudgeAnswerResponse,
RecordMistakeRequest,
RecordMistakeResponse,
GetMistakesResponse,
MistakeRecordItem,
# 新增的Schema
ExamReportResponse,
MistakeListResponse,
MistakesStatisticsResponse,
UpdateRoundScoreRequest,
)
from app.services.exam_report_service import ExamReportService, MistakeService
from app.services.course_statistics_service import course_statistics_service
from app.services.system_log_service import system_log_service
from app.schemas.system_log import SystemLogCreate
# V2 原生服务Python 实现
from app.services.ai import exam_generator_service, ExamGeneratorConfig
from app.services.ai.answer_judge_service import answer_judge_service
from app.core.exceptions import ExternalServiceError
logger = get_logger(__name__)
settings = get_settings()
router = APIRouter(prefix="/exams", tags=["考试"])
@router.post("/start", response_model=ResponseModel[StartExamResponse])
async def start_exam(
request: StartExamRequest,
http_request: Request,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""开始考试"""
exam = await ExamService.start_exam(
db=db,
user_id=current_user.id,
course_id=request.course_id,
question_count=request.count,
)
# 异步更新课程学员数统计
try:
await course_statistics_service.update_course_student_count(db, request.course_id)
except Exception as e:
logger.warning(f"更新课程学员数失败: {str(e)}")
# 不影响主流程,只记录警告
# 记录考试开始日志
await system_log_service.create_log(
db,
SystemLogCreate(
level="INFO",
type="api",
message=f"用户 {current_user.username} 开始考试课程ID: {request.course_id}",
user_id=current_user.id,
user=current_user.username,
ip=http_request.client.host if http_request.client else None,
path="/api/v1/exams/start",
method="POST",
user_agent=http_request.headers.get("user-agent")
)
)
return ResponseModel(code=200, data=StartExamResponse(exam_id=exam.id), message="考试开始")
@router.post("/submit", response_model=ResponseModel[SubmitExamResponse])
async def submit_exam(
request: SubmitExamRequest,
http_request: Request,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""提交考试答案"""
result = await ExamService.submit_exam(
db=db, user_id=current_user.id, exam_id=request.exam_id, answers=request.answers
)
# 获取考试记录以获取course_id
exam_stmt = select(Exam).where(Exam.id == request.exam_id)
exam_result = await db.execute(exam_stmt)
exam = exam_result.scalar_one_or_none()
# 异步更新课程学员数统计
if exam and exam.course_id:
try:
await course_statistics_service.update_course_student_count(db, exam.course_id)
except Exception as e:
logger.warning(f"更新课程学员数失败: {str(e)}")
# 不影响主流程,只记录警告
# 记录考试提交日志
await system_log_service.create_log(
db,
SystemLogCreate(
level="INFO",
type="api",
message=f"用户 {current_user.username} 提交考试考试ID: {request.exam_id},得分: {result.get('score', 0)}",
user_id=current_user.id,
user=current_user.username,
ip=http_request.client.host if http_request.client else None,
path="/api/v1/exams/submit",
method="POST",
user_agent=http_request.headers.get("user-agent")
)
)
return ResponseModel(code=200, data=SubmitExamResponse(**result), message="考试提交成功")
@router.get("/mistakes", response_model=ResponseModel[GetMistakesResponse])
async def get_mistakes(
exam_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取错题记录
用于第二、三轮考试时获取上一轮的错题记录
返回的数据可直接序列化为JSON字符串作为mistake_records参数传给考试生成接口
"""
logger.info(f"📋 GET /mistakes 收到请求")
try:
logger.info(f"📋 获取错题记录 - exam_id: {exam_id}, user_id: {current_user.id}")
# 查询指定考试的错题记录
result = await db.execute(
select(ExamMistake).where(
ExamMistake.exam_id == exam_id,
ExamMistake.user_id == current_user.id,
).order_by(ExamMistake.id)
)
mistakes = result.scalars().all()
logger.info(f"✅ 查询到错题记录数量: {len(mistakes)}")
# 转换为响应格式
mistake_items = [
MistakeRecordItem(
id=m.id,
question_id=m.question_id,
knowledge_point_id=m.knowledge_point_id,
question_content=m.question_content,
correct_answer=m.correct_answer,
user_answer=m.user_answer,
created_at=m.created_at,
)
for m in mistakes
]
logger.info(
f"获取错题记录成功 - user_id: {current_user.id}, exam_id: {exam_id}, "
f"count: {len(mistake_items)}"
)
# 返回统一的ResponseModel格式让Pydantic自动处理序列化
return ResponseModel(
code=200,
message="获取错题记录成功",
data=GetMistakesResponse(
mistakes=mistake_items
)
)
except Exception as e:
logger.error(f"获取错题记录失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取错题记录失败: {str(e)}"
)
@router.get("/{exam_id}", response_model=ResponseModel[ExamDetailResponse])
async def get_exam_detail(
exam_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取考试详情"""
exam_data = await ExamService.get_exam_detail(
db=db, user_id=current_user.id, exam_id=exam_id
)
return ResponseModel(code=200, data=ExamDetailResponse(**exam_data), message="获取成功")
@router.get("/records", response_model=ResponseModel[dict])
async def get_exam_records(
page: int = Query(1, ge=1),
size: int = Query(10, ge=1, le=100),
course_id: Optional[int] = Query(None),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取考试记录列表"""
records = await ExamService.get_exam_records(
db=db, user_id=current_user.id, page=page, size=size, course_id=course_id
)
return ResponseModel(code=200, data=records, message="获取成功")
@router.get("/statistics/summary", response_model=ResponseModel[dict])
async def get_exam_statistics(
course_id: Optional[int] = Query(None),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取考试统计信息"""
stats = await ExamService.get_exam_statistics(
db=db, user_id=current_user.id, course_id=course_id
)
return ResponseModel(code=200, data=stats, message="获取成功")
# ==================== 试题生成接口 ====================
@router.post("/generate", response_model=ResponseModel[GenerateExamResponse])
async def generate_exam(
request: GenerateExamRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
生成考试试题
使用 Python 原生 AI 服务实现。
考试轮次说明:
- 第一轮考试mistake_records 传空或不传
- 第二、三轮错题重考mistake_records 传入上一轮错题记录的JSON字符串
"""
try:
# 从用户信息中自动获取岗位ID如果未提供
position_id = request.position_id
if not position_id:
# 1. 首先查询用户已分配的岗位
result = await db.execute(
select(PositionMember).where(
PositionMember.user_id == current_user.id,
PositionMember.is_deleted == False
).limit(1)
)
position_member = result.scalar_one_or_none()
if position_member:
position_id = position_member.position_id
else:
# 2. 如果用户没有岗位,从课程关联的岗位中获取第一个
result = await db.execute(
select(PositionCourse.position_id).where(
PositionCourse.course_id == request.course_id,
PositionCourse.is_deleted == False
).limit(1)
)
course_position = result.scalar_one_or_none()
if course_position:
position_id = course_position
logger.info(f"用户 {current_user.id} 没有分配岗位使用课程关联的岗位ID: {position_id}")
else:
# 3. 如果课程也没有关联岗位,抛出错误
logger.warning(f"用户 {current_user.id} 没有分配岗位,且课程 {request.course_id} 未关联任何岗位")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无法生成试题:用户未分配岗位,且课程未关联任何岗位"
)
# 记录详细的题型设置(用于调试)
logger.info(
f"考试题型设置 - 单选:{request.single_choice_count}, 多选:{request.multiple_choice_count}, "
f"判断:{request.true_false_count}, 填空:{request.fill_blank_count}, 问答:{request.essay_count}, "
f"难度:{request.difficulty_level}"
)
# 调用 Python 原生试题生成服务
logger.info(
f"调用原生试题生成服务 - user_id: {current_user.id}, "
f"course_id: {request.course_id}, position_id: {position_id}"
)
# 构建配置
config = ExamGeneratorConfig(
course_id=request.course_id,
position_id=position_id,
single_choice_count=request.single_choice_count or 0,
multiple_choice_count=request.multiple_choice_count or 0,
true_false_count=request.true_false_count or 0,
fill_blank_count=request.fill_blank_count or 0,
essay_count=request.essay_count or 0,
difficulty_level=request.difficulty_level or 3,
mistake_records=request.mistake_records or "",
)
# 调用原生服务
gen_result = await exam_generator_service.generate_exam(db, config)
if not gen_result.get("success"):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="试题生成服务返回失败"
)
# 将题目列表转为 JSON 字符串(兼容原有前端格式)
result_data = json.dumps(gen_result.get("questions", []), ensure_ascii=False)
logger.info(
f"试题生成完成 - questions: {gen_result.get('total_count')}, "
f"provider: {gen_result.get('ai_provider')}, latency: {gen_result.get('ai_latency_ms')}ms"
)
if result_data is None or result_data == "":
logger.error(f"试题生成未返回有效结果数据")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="试题生成失败: 未返回结果数据"
)
# 创建或复用考试记录
question_count = sum([
request.single_choice_count or 0,
request.multiple_choice_count or 0,
request.true_false_count or 0,
request.fill_blank_count or 0,
request.essay_count or 0
])
# 第一轮创建新的exam记录
if request.current_round == 1:
exam = Exam(
user_id=current_user.id,
course_id=request.course_id,
exam_name=f"课程{request.course_id}考试",
question_count=question_count,
total_score=100.0,
pass_score=60.0,
duration_minutes=60,
status="started",
start_time=datetime.now(),
questions=None,
answers=None,
)
db.add(exam)
await db.commit()
await db.refresh(exam)
logger.info(f"{request.current_round}轮:创建考试记录成功 - exam_id: {exam.id}")
else:
# 第二、三轮复用已有exam记录
if not request.exam_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"{request.current_round}轮考试必须提供exam_id"
)
exam = await db.get(Exam, request.exam_id)
if not exam:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="考试记录不存在"
)
if exam.user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="无权访问此考试记录"
)
logger.info(f"{request.current_round}轮:复用考试记录 - exam_id: {exam.id}")
return ResponseModel(
code=200,
message="试题生成成功",
data=GenerateExamResponse(
result=result_data,
workflow_run_id=f"{gen_result.get('ai_provider')}_{gen_result.get('ai_latency_ms')}ms",
task_id=f"native_{request.course_id}",
exam_id=exam.id,
)
)
except HTTPException:
raise
except Exception as e:
logger.error(f"生成考试试题失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"试题生成失败: {str(e)}"
)
@router.post("/judge-answer", response_model=ResponseModel[JudgeAnswerResponse])
async def judge_answer(
request: JudgeAnswerRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
判断主观题答案
适用于填空题和问答题的答案判断。
使用 Python 原生 AI 服务实现。
"""
try:
logger.info(
f"调用原生答案判断服务 - user_id: {current_user.id}, "
f"question: {request.question[:50]}..."
)
result = await answer_judge_service.judge(
question=request.question,
correct_answer=request.correct_answer,
user_answer=request.user_answer,
analysis=request.analysis,
db=db # 传入 db_session 用于记录调用日志
)
logger.info(
f"答案判断完成 - is_correct: {result.is_correct}, "
f"provider: {result.ai_provider}, latency: {result.ai_latency_ms}ms"
)
return ResponseModel(
code=200,
message="答案判断完成",
data=JudgeAnswerResponse(
is_correct=result.is_correct,
correct_answer=request.correct_answer,
feedback=result.raw_response if not result.is_correct else None,
)
)
except Exception as e:
logger.error(f"答案判断失败: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"答案判断失败: {str(e)}"
)
@router.post("/record-mistake", response_model=ResponseModel[RecordMistakeResponse])
async def record_mistake(
request: RecordMistakeRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
记录错题
当用户答错题目时,立即调用此接口记录到错题表
"""
try:
# 创建错题记录
# 注意knowledge_point_id暂时设置为None避免外键约束失败
mistake = ExamMistake(
user_id=current_user.id,
exam_id=request.exam_id,
question_id=request.question_id,
knowledge_point_id=None, # 暂时设为None避免外键约束
question_content=request.question_content,
correct_answer=request.correct_answer,
user_answer=request.user_answer,
question_type=request.question_type, # 新增:记录题型
)
if request.knowledge_point_id:
logger.info(f"原始knowledge_point_id={request.knowledge_point_id}已设置为NULL待同步生产数据")
db.add(mistake)
await db.commit()
await db.refresh(mistake)
logger.info(
f"记录错题成功 - user_id: {current_user.id}, exam_id: {request.exam_id}, "
f"mistake_id: {mistake.id}"
)
return ResponseModel(
code=200,
message="错题记录成功",
data=RecordMistakeResponse(
id=mistake.id,
created_at=mistake.created_at,
)
)
except Exception as e:
await db.rollback()
logger.error(f"记录错题失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"记录错题失败: {str(e)}"
)
@router.get("/mistakes-debug")
async def get_mistakes_debug(
exam_id: int,
):
"""调试endpoint - 不需要认证"""
logger.info(f"🔍 调试 - exam_id: {exam_id}, type: {type(exam_id)}")
return {"exam_id": exam_id, "type": str(type(exam_id))}
# ==================== 成绩报告和错题本相关接口 ====================
@router.get("/statistics/report", response_model=ResponseModel[ExamReportResponse])
async def get_exam_report(
start_date: Optional[str] = Query(None, description="开始日期(YYYY-MM-DD)"),
end_date: Optional[str] = Query(None, description="结束日期(YYYY-MM-DD)"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取成绩报告汇总
返回包含概览、趋势、科目分析、最近考试记录的完整报告
"""
try:
report_data = await ExamReportService.get_exam_report(
db=db,
user_id=current_user.id,
start_date=start_date,
end_date=end_date
)
return ResponseModel(code=200, data=report_data, message="获取成绩报告成功")
except Exception as e:
logger.error(f"获取成绩报告失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取成绩报告失败: {str(e)}"
)
@router.get("/mistakes/list", response_model=ResponseModel[MistakeListResponse])
async def get_mistakes_list(
exam_id: Optional[int] = Query(None, description="考试ID"),
course_id: Optional[int] = Query(None, description="课程ID"),
question_type: Optional[str] = Query(None, description="题型(single/multiple/judge/blank/essay)"),
search: Optional[str] = Query(None, description="关键词搜索"),
start_date: Optional[str] = Query(None, description="开始日期(YYYY-MM-DD)"),
end_date: Optional[str] = Query(None, description="结束日期(YYYY-MM-DD)"),
page: int = Query(1, ge=1, description="页码"),
size: int = Query(10, ge=1, le=100, description="每页数量"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取错题列表(支持多维度筛选)
- 不传exam_id时返回用户所有错题
- 支持按course_id、question_type、关键词、时间范围筛选
"""
try:
mistakes_data = await MistakeService.get_mistakes_list(
db=db,
user_id=current_user.id,
exam_id=exam_id,
course_id=course_id,
question_type=question_type,
search=search,
start_date=start_date,
end_date=end_date,
page=page,
size=size
)
return ResponseModel(code=200, data=mistakes_data, message="获取错题列表成功")
except Exception as e:
logger.error(f"获取错题列表失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取错题列表失败: {str(e)}"
)
@router.get("/mistakes/statistics", response_model=ResponseModel[MistakesStatisticsResponse])
async def get_mistakes_statistics(
course_id: Optional[int] = Query(None, description="课程ID"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取错题统计数据
返回按课程、题型、时间维度的统计数据
"""
try:
stats_data = await MistakeService.get_mistakes_statistics(
db=db,
user_id=current_user.id,
course_id=course_id
)
return ResponseModel(code=200, data=stats_data, message="获取错题统计成功")
except Exception as e:
logger.error(f"获取错题统计失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取错题统计失败: {str(e)}"
)
@router.put("/{exam_id}/round-score", response_model=ResponseModel[dict])
async def update_round_score(
exam_id: int,
request: UpdateRoundScoreRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
更新某轮的得分
用于前端每轮考试结束后更新对应轮次的得分
"""
try:
# 查询考试记录
exam = await db.get(Exam, exam_id)
if not exam:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="考试记录不存在"
)
# 验证权限
if exam.user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="无权修改此考试记录"
)
# 更新对应轮次的得分
if request.round == 1:
exam.round1_score = request.score
elif request.round == 2:
exam.round2_score = request.score
elif request.round == 3:
exam.round3_score = request.score
# 第三轮默认就是 final
request.is_final = True
# 如果是最终轮次可能是第1/2轮就全对了更新总分和状态
if request.is_final:
exam.score = request.score
exam.status = "submitted"
# 计算是否通过 (pass_score 为空默认 60)
exam.is_passed = request.score >= (exam.pass_score or 60)
# 更新结束时间
from datetime import datetime
exam.end_time = datetime.now()
await db.commit()
logger.info(f"更新轮次得分成功 - exam_id: {exam_id}, round: {request.round}, score: {request.score}")
return ResponseModel(code=200, data={"exam_id": exam_id}, message="更新得分成功")
except HTTPException:
raise
except Exception as e:
await db.rollback()
logger.error(f"更新轮次得分失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"更新轮次得分失败: {str(e)}"
)
@router.put("/mistakes/{mistake_id}/mastered", response_model=ResponseModel)
async def mark_mistake_mastered(
mistake_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
标记错题为已掌握
Args:
mistake_id: 错题记录ID
db: 数据库会话
current_user: 当前用户
Returns:
ResponseModel: 标记结果
"""
try:
# 查询错题记录
stmt = select(ExamMistake).where(
ExamMistake.id == mistake_id,
ExamMistake.user_id == current_user.id
)
result = await db.execute(stmt)
mistake = result.scalar_one_or_none()
if not mistake:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="错题记录不存在或无权访问"
)
# 更新掌握状态
from datetime import datetime as dt
mistake.mastery_status = 'mastered'
mistake.mastered_at = dt.utcnow()
await db.commit()
logger.info(f"标记错题已掌握成功 - mistake_id: {mistake_id}, user_id: {current_user.id}")
return ResponseModel(
code=200,
message="已标记为掌握",
data={"mistake_id": mistake_id, "mastery_status": "mastered"}
)
except HTTPException:
raise
except Exception as e:
await db.rollback()
logger.error(f"标记错题已掌握失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"标记失败: {str(e)}"
)

View File

@@ -0,0 +1,201 @@
"""
知识点分析 API
使用 Python 原生 AI 服务实现
"""
import logging
from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_db, get_current_user
from app.schemas.base import ResponseModel
from app.models.user import User
from app.services.course_service import course_service
from app.services.ai.knowledge_analysis_v2 import knowledge_analysis_service_v2
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/courses/{course_id}/materials/{material_id}/analyze", response_model=ResponseModel[Dict[str, Any]])
async def analyze_material_knowledge_points(
course_id: int,
material_id: int,
background_tasks: BackgroundTasks,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
分析单个资料的知识点
- **course_id**: 课程ID
- **material_id**: 资料ID
使用 Python 原生 AI 服务:
- 本地 AI 服务调用4sapi.com 首选OpenRouter 备选)
- 多层 JSON 解析兜底
- 无外部平台依赖,更稳定
"""
try:
# 验证课程是否存在
course = await course_service.get_by_id(db, course_id)
if not course:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"课程 {course_id} 不存在"
)
# 获取资料信息
materials = await course_service.get_course_materials(db, course_id=course_id)
material = next((m for m in materials if m.id == material_id), None)
if not material:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"资料 {material_id} 不存在"
)
logger.info(
f"准备启动知识点分析 - course_id: {course_id}, material_id: {material_id}, "
f"file_url: {material.file_url}, user_id: {current_user.id}"
)
# 调用 Python 原生知识点分析服务
result = await knowledge_analysis_service_v2.analyze_course_material(
db=db,
course_id=course_id,
material_id=material_id,
file_url=material.file_url,
course_title=course.name,
user_id=current_user.id
)
logger.info(
f"知识点分析完成 - course_id: {course_id}, material_id: {material_id}, "
f"knowledge_points: {result.get('knowledge_points_count', 0)}, "
f"provider: {result.get('ai_provider')}"
)
# 构建响应
response_data = {
"message": "知识点分析完成",
"course_id": course_id,
"material_id": material_id,
"status": result.get("status", "completed"),
"knowledge_points_count": result.get("knowledge_points_count", 0),
"ai_provider": result.get("ai_provider"),
"ai_model": result.get("ai_model"),
"ai_tokens": result.get("ai_tokens"),
"ai_latency_ms": result.get("ai_latency_ms"),
}
return ResponseModel(
data=response_data,
message="知识点分析完成"
)
except HTTPException:
raise
except Exception as e:
logger.error(
f"知识点分析失败 - course_id: {course_id}, material_id: {material_id}, error: {e}",
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"知识点分析失败: {str(e)}"
)
@router.post("/courses/{course_id}/reanalyze", response_model=ResponseModel[Dict[str, Any]])
async def reanalyze_course_materials(
course_id: int,
background_tasks: BackgroundTasks,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
重新分析课程的所有资料
- **course_id**: 课程ID
该接口会重新分析课程下的所有资料,提取知识点
"""
try:
# 验证课程是否存在
course = await course_service.get_by_id(db, course_id)
if not course:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"课程 {course_id} 不存在"
)
# 获取课程资料信息
materials = await course_service.get_course_materials(db, course_id=course_id)
if not materials:
return ResponseModel(
data={
"message": "该课程暂无资料需要分析",
"course_id": course_id,
"status": "stopped",
"materials_count": 0
},
message="无资料需要分析"
)
# 调用 Python 原生知识点分析服务
result = await knowledge_analysis_service_v2.reanalyze_course_materials(
db=db,
course_id=course_id,
course_title=course.name,
user_id=current_user.id
)
return ResponseModel(
data={
"message": "课程资料重新分析完成",
"course_id": course_id,
"status": "completed",
"materials_count": result.get("materials_count", 0),
"success_count": result.get("success_count", 0),
"knowledge_points_count": result.get("knowledge_points_count", 0),
"analysis_results": result.get("analysis_results", [])
},
message="重新分析完成"
)
except HTTPException:
raise
except Exception as e:
logger.error(
f"启动课程资料重新分析失败 - course_id: {course_id}, error: {e}",
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="启动重新分析失败"
)
@router.get("/engines", response_model=ResponseModel[Dict[str, Any]])
async def list_analysis_engines():
"""
获取可用的分析引擎列表
"""
return ResponseModel(
data={
"engines": [
{
"id": "native",
"name": "Python 原生实现",
"description": "使用本地 AI 服务4sapi.com + OpenRouter稳定可靠",
"default": True
}
],
"default_engine": "native"
},
message="获取分析引擎列表成功"
)

View File

@@ -0,0 +1,8 @@
"""
管理员相关API模块
"""
from .student_scores import router as student_scores_router
from .student_practice import router as student_practice_router
__all__ = ["student_scores_router", "student_practice_router"]

View File

@@ -0,0 +1,345 @@
"""
管理员查看学员陪练记录API
"""
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy import and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_current_user, get_db
from app.core.logger import logger
from app.models.position import Position
from app.models.position_member import PositionMember
from app.models.practice import PracticeReport, PracticeSession, PracticeDialogue
from app.models.user import User
from app.schemas.base import PaginatedResponse, ResponseModel
router = APIRouter(prefix="/manager/student-practice", tags=["manager-student-practice"])
@router.get("/", response_model=ResponseModel[PaginatedResponse])
async def get_student_practice_records(
page: int = Query(1, ge=1, description="页码"),
size: int = Query(20, ge=1, le=100, description="每页数量"),
student_name: Optional[str] = Query(None, description="学员姓名搜索"),
position: Optional[str] = Query(None, description="岗位筛选"),
scene_type: Optional[str] = Query(None, description="场景类型筛选"),
result: Optional[str] = Query(None, description="结果筛选: excellent/good/average/needs_improvement"),
start_date: Optional[str] = Query(None, description="开始日期 YYYY-MM-DD"),
end_date: Optional[str] = Query(None, description="结束日期 YYYY-MM-DD"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取所有用户的陪练记录列表管理员和manager可访问
包含所有角色trainee/admin/manager的陪练记录方便测试和全面管理
支持筛选:
- student_name: 按用户姓名模糊搜索
- position: 按岗位筛选
- scene_type: 按场景类型筛选
- result: 按结果筛选(优秀/良好/一般/需改进)
- start_date/end_date: 按日期范围筛选
"""
try:
# 权限检查
if current_user.role not in ['admin', 'manager']:
return ResponseModel(code=403, message="无权访问", data=None)
# 构建基础查询
# 关联User、PracticeReport来获取完整信息
query = (
select(
PracticeSession,
User.full_name.label('student_name'),
User.id.label('student_id'),
PracticeReport.total_score
)
.join(User, PracticeSession.user_id == User.id)
.outerjoin(
PracticeReport,
PracticeSession.session_id == PracticeReport.session_id
)
.where(
# 管理员可以查看所有人的陪练记录(包括其他管理员的),方便测试和全面管理
PracticeSession.status == 'completed', # 只查询已完成的陪练
PracticeSession.is_deleted == False
)
)
# 学员姓名筛选
if student_name:
query = query.where(User.full_name.contains(student_name))
# 岗位筛选
if position:
# 通过position_members关联查询
query = query.join(
PositionMember,
and_(
PositionMember.user_id == User.id,
PositionMember.is_deleted == False
)
).join(
Position,
Position.id == PositionMember.position_id
).where(
Position.name == position
)
# 场景类型筛选
if scene_type:
query = query.where(PracticeSession.scene_type == scene_type)
# 结果筛选(根据分数)
if result:
if result == 'excellent':
query = query.where(PracticeReport.total_score >= 90)
elif result == 'good':
query = query.where(and_(
PracticeReport.total_score >= 80,
PracticeReport.total_score < 90
))
elif result == 'average':
query = query.where(and_(
PracticeReport.total_score >= 70,
PracticeReport.total_score < 80
))
elif result == 'needs_improvement':
query = query.where(PracticeReport.total_score < 70)
# 日期范围筛选
if start_date:
try:
start_dt = datetime.strptime(start_date, '%Y-%m-%d')
query = query.where(PracticeSession.start_time >= start_dt)
except ValueError:
pass
if end_date:
try:
end_dt = datetime.strptime(end_date, '%Y-%m-%d')
end_dt = end_dt.replace(hour=23, minute=59, second=59)
query = query.where(PracticeSession.start_time <= end_dt)
except ValueError:
pass
# 按开始时间倒序
query = query.order_by(PracticeSession.start_time.desc())
# 计算总数
count_query = select(func.count()).select_from(query.subquery())
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# 分页查询
offset = (page - 1) * size
results = await db.execute(query.offset(offset).limit(size))
# 构建响应数据
items = []
for session, student_name, student_id, total_score in results:
# 查询该学员的所有岗位
position_query = (
select(Position.name)
.join(PositionMember, Position.id == PositionMember.position_id)
.where(
PositionMember.user_id == student_id,
PositionMember.is_deleted == False,
Position.is_deleted == False
)
)
position_result = await db.execute(position_query)
positions = position_result.scalars().all()
position_str = ', '.join(positions) if positions else None
# 根据分数计算结果等级
result_level = "needs_improvement"
if total_score:
if total_score >= 90:
result_level = "excellent"
elif total_score >= 80:
result_level = "good"
elif total_score >= 70:
result_level = "average"
items.append({
"id": session.id,
"student_id": student_id,
"student_name": student_name,
"position": position_str, # 所有岗位,逗号分隔
"session_id": session.session_id,
"scene_name": session.scene_name,
"scene_type": session.scene_type,
"duration_seconds": session.duration_seconds,
"round_count": session.turns, # turns字段表示对话轮数
"score": total_score,
"result": result_level,
"practice_time": session.start_time.strftime('%Y-%m-%d %H:%M:%S') if session.start_time else None
})
# 计算分页信息
pages = (total + size - 1) // size
return ResponseModel(
code=200,
message="success",
data=PaginatedResponse(
items=items,
total=total,
page=page,
page_size=size,
pages=pages
)
)
except Exception as e:
logger.error(f"获取学员陪练记录失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取学员陪练记录失败: {str(e)}", data=None)
@router.get("/statistics", response_model=ResponseModel)
async def get_student_practice_statistics(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取学员陪练统计数据
返回:
- total_count: 总陪练次数
- avg_score: 平均评分
- total_duration_hours: 总陪练时长(小时)
- excellent_rate: 优秀率
"""
try:
# 权限检查
if current_user.role not in ['admin', 'manager']:
return ResponseModel(code=403, message="无权访问", data=None)
# 查询所有已完成陪练(包括所有角色)
query = (
select(PracticeSession, PracticeReport.total_score)
.join(User, PracticeSession.user_id == User.id)
.outerjoin(
PracticeReport,
PracticeSession.session_id == PracticeReport.session_id
)
.where(
PracticeSession.status == 'completed',
PracticeSession.is_deleted == False
)
)
result = await db.execute(query)
records = result.all()
if not records:
return ResponseModel(
code=200,
message="success",
data={
"total_count": 0,
"avg_score": 0,
"total_duration_hours": 0,
"excellent_rate": 0
}
)
total_count = len(records)
# 计算总时长(秒转小时)
total_duration_seconds = sum(
session.duration_seconds for session, _ in records if session.duration_seconds
)
total_duration_hours = round(total_duration_seconds / 3600, 1)
# 计算平均分
scores = [score for _, score in records if score is not None]
avg_score = round(sum(scores) / len(scores), 1) if scores else 0
# 计算优秀率(>=90分
excellent = sum(1 for _, score in records if score and score >= 90)
excellent_rate = round((excellent / total_count) * 100, 1) if total_count > 0 else 0
return ResponseModel(
code=200,
message="success",
data={
"total_count": total_count,
"avg_score": avg_score,
"total_duration_hours": total_duration_hours,
"excellent_rate": excellent_rate
}
)
except Exception as e:
logger.error(f"获取学员陪练统计失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取学员陪练统计失败: {str(e)}", data=None)
@router.get("/{session_id}/conversation", response_model=ResponseModel)
async def get_session_conversation(
session_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取指定会话的对话记录
返回对话列表按sequence排序
"""
try:
# 权限检查
if current_user.role not in ['admin', 'manager']:
return ResponseModel(code=403, message="无权访问", data=None)
# 1. 查询会话是否存在
session_query = select(PracticeSession).where(
PracticeSession.session_id == session_id,
PracticeSession.is_deleted == False
)
session_result = await db.execute(session_query)
session = session_result.scalar_one_or_none()
if not session:
return ResponseModel(code=404, message="会话不存在", data=None)
# 2. 查询对话记录
dialogue_query = (
select(PracticeDialogue)
.where(PracticeDialogue.session_id == session_id)
.order_by(PracticeDialogue.sequence)
)
dialogue_result = await db.execute(dialogue_query)
dialogues = dialogue_result.scalars().all()
# 3. 构建响应数据
conversation = []
for dialogue in dialogues:
conversation.append({
"role": dialogue.speaker, # "user" 或 "ai"
"content": dialogue.content,
"timestamp": dialogue.timestamp.strftime('%Y-%m-%d %H:%M:%S') if dialogue.timestamp else None,
"sequence": dialogue.sequence
})
logger.info(f"获取会话对话记录: session_id={session_id}, 对话数={len(conversation)}")
return ResponseModel(
code=200,
message="success",
data={
"session_id": session_id,
"conversation": conversation,
"total_count": len(conversation)
}
)
except Exception as e:
logger.error(f"获取会话对话记录失败: {e}, session_id={session_id}", exc_info=True)
return ResponseModel(code=500, message=f"获取对话记录失败: {str(e)}", data=None)

View File

@@ -0,0 +1,447 @@
"""
管理员查看学员考试成绩API
"""
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, Query
from pydantic import BaseModel
from sqlalchemy import and_, delete, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.deps import get_current_user, get_db
from app.core.logger import logger
from app.models.course import Course
from app.models.exam import Exam
from app.models.exam_mistake import ExamMistake
from app.models.position_member import PositionMember
from app.models.position import Position
from app.models.user import User
from app.schemas.base import PaginatedResponse, ResponseModel
router = APIRouter(prefix="/manager/student-scores", tags=["manager-student-scores"])
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
ids: List[int]
@router.get("/{exam_id}/mistakes", response_model=ResponseModel[PaginatedResponse])
async def get_exam_mistakes(
exam_id: int,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取指定考试的错题记录管理员和manager可访问
"""
try:
# 权限检查
if current_user.role not in ['admin', 'manager']:
return ResponseModel(code=403, message="无权访问", data=None)
# 查询错题记录
query = (
select(ExamMistake)
.options(selectinload(ExamMistake.question))
.where(ExamMistake.exam_id == exam_id)
.order_by(ExamMistake.created_at.desc())
)
result = await db.execute(query)
mistakes = result.scalars().all()
items = []
for mistake in mistakes:
# 获取解析优先从关联题目获取如果是AI生成的题目可能没有关联题目
analysis = ""
if mistake.question and mistake.question.explanation:
analysis = mistake.question.explanation
items.append({
"id": mistake.id,
"question_content": mistake.question_content,
"correct_answer": mistake.correct_answer,
"user_answer": mistake.user_answer,
"question_type": mistake.question_type,
"analysis": analysis,
"created_at": mistake.created_at.strftime('%Y-%m-%d %H:%M:%S') if mistake.created_at else None
})
return ResponseModel(
code=200,
message="success",
data=PaginatedResponse(
items=items,
total=len(items),
page=1,
page_size=len(items),
pages=1
)
)
except Exception as e:
logger.error(f"获取错题记录失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取错题记录失败: {str(e)}", data=None)
@router.get("/", response_model=ResponseModel[PaginatedResponse])
async def get_student_scores(
page: int = Query(1, ge=1, description="页码"),
size: int = Query(20, ge=1, le=100, description="每页数量"),
student_name: Optional[str] = Query(None, description="学员姓名搜索"),
position: Optional[str] = Query(None, description="岗位筛选"),
course_id: Optional[int] = Query(None, description="课程ID筛选"),
score_range: Optional[str] = Query(None, description="成绩范围: excellent/good/pass/fail"),
start_date: Optional[str] = Query(None, description="开始日期 YYYY-MM-DD"),
end_date: Optional[str] = Query(None, description="结束日期 YYYY-MM-DD"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取所有学员的考试成绩列表管理员和manager可访问
支持筛选:
- student_name: 按学员姓名模糊搜索
- position: 按岗位筛选
- course_id: 按课程筛选
- score_range: 按成绩范围筛选excellent>=90, good>=80, pass>=60, fail<60
- start_date/end_date: 按日期范围筛选
"""
try:
# 权限检查
if current_user.role not in ['admin', 'manager']:
return ResponseModel(code=403, message="无权访问", data=None)
# 构建基础查询
# 关联User、Course、ExamMistake来获取完整信息
query = (
select(
Exam,
User.full_name.label('student_name'),
User.id.label('student_id'),
Course.name.label('course_name'),
func.count(ExamMistake.id).label('wrong_count')
)
.join(User, Exam.user_id == User.id)
.join(Course, Exam.course_id == Course.id)
.outerjoin(ExamMistake, and_(
ExamMistake.exam_id == Exam.id,
ExamMistake.user_id == User.id
))
.where(
Exam.status.in_(['completed', 'submitted']) # 只查询已完成的考试
)
.group_by(Exam.id, User.id, User.full_name, Course.id, Course.name)
)
# 学员姓名筛选
if student_name:
query = query.where(User.full_name.contains(student_name))
# 岗位筛选
if position:
# 通过position_members关联查询
query = query.join(
PositionMember,
and_(
PositionMember.user_id == User.id,
PositionMember.is_deleted == False
)
).join(
Position,
Position.id == PositionMember.position_id
).where(
Position.name == position
)
# 课程筛选
if course_id:
query = query.where(Exam.course_id == course_id)
# 成绩范围筛选
if score_range:
score_field = Exam.round1_score # 使用第一轮成绩
if score_range == 'excellent':
query = query.where(score_field >= 90)
elif score_range == 'good':
query = query.where(and_(score_field >= 80, score_field < 90))
elif score_range == 'pass':
query = query.where(and_(score_field >= 60, score_field < 80))
elif score_range == 'fail':
query = query.where(score_field < 60)
# 日期范围筛选
if start_date:
try:
start_dt = datetime.strptime(start_date, '%Y-%m-%d')
query = query.where(Exam.created_at >= start_dt)
except ValueError:
pass
if end_date:
try:
end_dt = datetime.strptime(end_date, '%Y-%m-%d')
end_dt = end_dt.replace(hour=23, minute=59, second=59)
query = query.where(Exam.created_at <= end_dt)
except ValueError:
pass
# 按创建时间倒序
query = query.order_by(Exam.created_at.desc())
# 计算总数
count_query = select(func.count()).select_from(query.subquery())
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# 分页查询
offset = (page - 1) * size
results = await db.execute(query.offset(offset).limit(size))
# 构建响应数据
items = []
for exam, student_name, student_id, course_name, wrong_count in results:
# 查询该学员的所有岗位
position_query = (
select(Position.name)
.join(PositionMember, Position.id == PositionMember.position_id)
.where(
PositionMember.user_id == student_id,
PositionMember.is_deleted == False,
Position.is_deleted == False
)
)
position_result = await db.execute(position_query)
positions = position_result.scalars().all()
position_str = ', '.join(positions) if positions else None
# 计算正确率和用时
accuracy = None
correct_count = None
duration_seconds = None
if exam.question_count and exam.question_count > 0:
correct_count = exam.question_count - wrong_count
accuracy = round((correct_count / exam.question_count) * 100, 1)
if exam.start_time and exam.end_time:
duration_seconds = int((exam.end_time - exam.start_time).total_seconds())
items.append({
"id": exam.id,
"student_id": student_id,
"student_name": student_name,
"position": position_str, # 所有岗位,逗号分隔
"course_id": exam.course_id,
"course_name": course_name,
"exam_type": "assessment", # 简化处理统一为assessment
"score": float(exam.round1_score) if exam.round1_score else 0,
"round1_score": float(exam.round1_score) if exam.round1_score else None,
"round2_score": float(exam.round2_score) if exam.round2_score else None,
"round3_score": float(exam.round3_score) if exam.round3_score else None,
"total_score": float(exam.total_score) if exam.total_score else 100,
"accuracy": accuracy,
"correct_count": correct_count,
"wrong_count": wrong_count,
"total_count": exam.question_count,
"duration_seconds": duration_seconds,
"exam_date": exam.created_at.strftime('%Y-%m-%d %H:%M:%S') if exam.created_at else None
})
# 计算分页信息
pages = (total + size - 1) // size
return ResponseModel(
code=200,
message="success",
data=PaginatedResponse(
items=items,
total=total,
page=page,
page_size=size,
pages=pages
)
)
except Exception as e:
logger.error(f"获取学员考试成绩失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取学员考试成绩失败: {str(e)}", data=None)
@router.get("/statistics", response_model=ResponseModel)
async def get_student_scores_statistics(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取学员考试成绩统计数据
返回:
- total_exams: 总考试次数
- avg_score: 平均分
- pass_rate: 通过率
- excellent_rate: 优秀率
"""
try:
# 权限检查
if current_user.role not in ['admin', 'manager']:
return ResponseModel(code=403, message="无权访问", data=None)
# 查询所有用户的已完成考试
query = (
select(Exam)
.join(User, Exam.user_id == User.id)
.where(
Exam.status.in_(['completed', 'submitted']),
Exam.round1_score.isnot(None)
)
)
result = await db.execute(query)
exams = result.scalars().all()
if not exams:
return ResponseModel(
code=200,
message="success",
data={
"total_exams": 0,
"avg_score": 0,
"pass_rate": 0,
"excellent_rate": 0
}
)
total_exams = len(exams)
total_score = sum(exam.round1_score for exam in exams if exam.round1_score)
avg_score = round(total_score / total_exams, 1) if total_exams > 0 else 0
passed = sum(1 for exam in exams if exam.round1_score and exam.round1_score >= 60)
pass_rate = round((passed / total_exams) * 100, 1) if total_exams > 0 else 0
excellent = sum(1 for exam in exams if exam.round1_score and exam.round1_score >= 90)
excellent_rate = round((excellent / total_exams) * 100, 1) if total_exams > 0 else 0
return ResponseModel(
code=200,
message="success",
data={
"total_exams": total_exams,
"avg_score": avg_score,
"pass_rate": pass_rate,
"excellent_rate": excellent_rate
}
)
except Exception as e:
logger.error(f"获取学员考试成绩统计失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取学员考试成绩统计失败: {str(e)}", data=None)
@router.delete("/{exam_id}", response_model=ResponseModel)
async def delete_exam_record(
exam_id: int,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
删除单条考试记录(管理员可访问)
会同时删除关联的错题记录
"""
try:
# 权限检查 - 仅管理员可删除
if current_user.role != 'admin':
return ResponseModel(code=403, message="无权操作,仅管理员可删除考试记录", data=None)
# 查询考试记录
result = await db.execute(
select(Exam).where(Exam.id == exam_id)
)
exam = result.scalar_one_or_none()
if not exam:
return ResponseModel(code=404, message="考试记录不存在", data=None)
# 删除关联的错题记录
await db.execute(
delete(ExamMistake).where(ExamMistake.exam_id == exam_id)
)
# 删除考试记录
await db.delete(exam)
await db.commit()
logger.info(f"管理员 {current_user.username} 删除了考试记录 {exam_id}")
return ResponseModel(
code=200,
message="考试记录已删除",
data={"deleted_id": exam_id}
)
except Exception as e:
await db.rollback()
logger.error(f"删除考试记录失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"删除考试记录失败: {str(e)}", data=None)
@router.delete("/batch/delete", response_model=ResponseModel)
async def batch_delete_exam_records(
request: BatchDeleteRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
批量删除考试记录(管理员可访问)
会同时删除关联的错题记录
"""
try:
# 权限检查 - 仅管理员可删除
if current_user.role != 'admin':
return ResponseModel(code=403, message="无权操作,仅管理员可删除考试记录", data=None)
if not request.ids:
return ResponseModel(code=400, message="请选择要删除的记录", data=None)
# 查询存在的考试记录
result = await db.execute(
select(Exam.id).where(Exam.id.in_(request.ids))
)
existing_ids = [row[0] for row in result.all()]
if not existing_ids:
return ResponseModel(code=404, message="未找到要删除的记录", data=None)
# 删除关联的错题记录
await db.execute(
delete(ExamMistake).where(ExamMistake.exam_id.in_(existing_ids))
)
# 删除考试记录
await db.execute(
delete(Exam).where(Exam.id.in_(existing_ids))
)
await db.commit()
deleted_count = len(existing_ids)
logger.info(f"管理员 {current_user.username} 批量删除了 {deleted_count} 条考试记录")
return ResponseModel(
code=200,
message=f"成功删除 {deleted_count} 条考试记录",
data={
"deleted_count": deleted_count,
"deleted_ids": existing_ids
}
)
except Exception as e:
await db.rollback()
logger.error(f"批量删除考试记录失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"批量删除考试记录失败: {str(e)}", data=None)

View File

@@ -0,0 +1,255 @@
"""
站内消息通知 API
提供通知的查询、标记已读、删除等功能
"""
import logging
from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_db, get_current_user
from app.models.user import User
from app.schemas.base import ResponseModel
from app.schemas.notification import (
NotificationCreate,
NotificationBatchCreate,
NotificationResponse,
NotificationListResponse,
NotificationCountResponse,
MarkReadRequest,
)
from app.services.notification_service import notification_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/notifications")
@router.get("", response_model=ResponseModel[NotificationListResponse])
async def get_notifications(
is_read: Optional[bool] = Query(None, description="是否已读筛选"),
type: Optional[str] = Query(None, description="通知类型筛选"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取当前用户的通知列表
支持按已读状态和通知类型筛选
"""
try:
skip = (page - 1) * page_size
notifications, total, unread_count = await notification_service.get_user_notifications(
db=db,
user_id=current_user.id,
skip=skip,
limit=page_size,
is_read=is_read,
notification_type=type
)
response_data = NotificationListResponse(
items=notifications,
total=total,
unread_count=unread_count
)
return ResponseModel(
code=200,
message="获取通知列表成功",
data=response_data
)
except Exception as e:
logger.error(f"获取通知列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取通知列表失败: {str(e)}")
@router.get("/unread-count", response_model=ResponseModel[NotificationCountResponse])
async def get_unread_count(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取当前用户的未读通知数量
用于顶部导航栏显示未读消息数
"""
try:
unread_count, total = await notification_service.get_unread_count(
db=db,
user_id=current_user.id
)
return ResponseModel(
code=200,
message="获取未读数量成功",
data=NotificationCountResponse(
unread_count=unread_count,
total=total
)
)
except Exception as e:
logger.error(f"获取未读数量失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取未读数量失败: {str(e)}")
@router.post("/mark-read", response_model=ResponseModel)
async def mark_notifications_read(
request: MarkReadRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
标记通知为已读
- 传入 notification_ids 则标记指定通知
- 不传则标记全部未读通知为已读
"""
try:
updated_count = await notification_service.mark_as_read(
db=db,
user_id=current_user.id,
notification_ids=request.notification_ids
)
return ResponseModel(
code=200,
message=f"成功标记 {updated_count} 条通知为已读",
data={"updated_count": updated_count}
)
except Exception as e:
logger.error(f"标记已读失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"标记已读失败: {str(e)}")
@router.delete("/{notification_id}", response_model=ResponseModel)
async def delete_notification(
notification_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
删除单条通知
只能删除自己的通知
"""
try:
success = await notification_service.delete_notification(
db=db,
user_id=current_user.id,
notification_id=notification_id
)
if not success:
raise HTTPException(status_code=404, detail="通知不存在或无权删除")
return ResponseModel(
code=200,
message="删除通知成功",
data={"deleted": True}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"删除通知失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"删除通知失败: {str(e)}")
# ==================== 管理员接口 ====================
@router.post("/send", response_model=ResponseModel[NotificationResponse])
async def send_notification(
notification_in: NotificationCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
发送单条通知(管理员接口)
向指定用户发送通知
"""
try:
# 权限检查:仅管理员和管理者可发送通知
if current_user.role not in ["admin", "manager"]:
raise HTTPException(status_code=403, detail="无权限发送通知")
# 设置发送者
notification_in.sender_id = current_user.id
notification = await notification_service.create_notification(
db=db,
notification_in=notification_in
)
# 构建响应
response = NotificationResponse(
id=notification.id,
user_id=notification.user_id,
title=notification.title,
content=notification.content,
type=notification.type,
is_read=notification.is_read,
related_id=notification.related_id,
related_type=notification.related_type,
sender_id=notification.sender_id,
sender_name=current_user.full_name,
created_at=notification.created_at,
updated_at=notification.updated_at
)
return ResponseModel(
code=200,
message="发送通知成功",
data=response
)
except HTTPException:
raise
except Exception as e:
logger.error(f"发送通知失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"发送通知失败: {str(e)}")
@router.post("/send-batch", response_model=ResponseModel)
async def send_batch_notifications(
batch_in: NotificationBatchCreate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
批量发送通知(管理员接口)
向多个用户发送相同的通知
"""
try:
# 权限检查:仅管理员和管理者可发送通知
if current_user.role not in ["admin", "manager"]:
raise HTTPException(status_code=403, detail="无权限发送通知")
# 设置发送者
batch_in.sender_id = current_user.id
notifications = await notification_service.batch_create_notifications(
db=db,
batch_in=batch_in
)
return ResponseModel(
code=200,
message=f"成功发送 {len(notifications)} 条通知",
data={"sent_count": len(notifications)}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"批量发送通知失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"批量发送通知失败: {str(e)}")

View File

@@ -0,0 +1,658 @@
"""
岗位管理 API真实数据库
"""
from typing import Optional, List
from fastapi import APIRouter, Depends, Query, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, func
from sqlalchemy.orm import selectinload
import sqlalchemy as sa
from app.core.deps import get_current_active_user as get_current_user, get_db, require_admin, require_admin_or_manager
from app.schemas.base import ResponseModel, PaginationParams, PaginatedResponse
from app.models.position import Position
from app.models.position_member import PositionMember
from app.models.position_course import PositionCourse
from app.models.user import User
from app.models.course import Course
router = APIRouter(prefix="/admin/positions")
@router.get("")
async def list_positions(
pagination: PaginationParams = Depends(),
keyword: Optional[str] = Query(None, description="关键词"),
current_user=Depends(require_admin_or_manager),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""分页获取岗位列表(管理员或经理)。"""
stmt = select(Position).where(Position.is_deleted == False)
if keyword:
like = f"%{keyword}%"
stmt = stmt.where((Position.name.ilike(like)) | (Position.description.ilike(like)))
rows = (await db.execute(stmt)).scalars().all()
total = len(rows)
sliced = rows[pagination.offset : pagination.offset + pagination.limit]
async def to_dict(p: Position) -> dict:
"""将Position对象转换为字典并添加统计数据"""
d = p.__dict__.copy()
d.pop("_sa_instance_state", None)
# 统计岗位成员数量
member_count_result = await db.execute(
select(func.count(PositionMember.id)).where(
and_(
PositionMember.position_id == p.id,
PositionMember.is_deleted == False
)
)
)
d["memberCount"] = member_count_result.scalar() or 0
# 统计必修课程数量
required_count_result = await db.execute(
select(func.count(PositionCourse.id)).where(
and_(
PositionCourse.position_id == p.id,
PositionCourse.course_type == "required",
PositionCourse.is_deleted == False
)
)
)
d["requiredCourses"] = required_count_result.scalar() or 0
# 统计选修课程数量
optional_count_result = await db.execute(
select(func.count(PositionCourse.id)).where(
and_(
PositionCourse.position_id == p.id,
PositionCourse.course_type == "optional",
PositionCourse.is_deleted == False
)
)
)
d["optionalCourses"] = optional_count_result.scalar() or 0
return d
# 为每个岗位添加统计数据(使用异步)
items = []
for p in sliced:
item = await to_dict(p)
items.append(item)
paged = {
"items": items,
"total": total,
"page": pagination.page,
"page_size": pagination.page_size,
"pages": (total + pagination.page_size - 1) // pagination.page_size if pagination.page_size else 1,
}
return ResponseModel(message="获取岗位列表成功", data=paged)
@router.get("/tree")
async def get_position_tree(
current_user=Depends(require_admin_or_manager), db: AsyncSession = Depends(get_db)
) -> ResponseModel:
"""获取岗位树(管理员或经理)。"""
rows = (await db.execute(select(Position).where(Position.is_deleted == False))).scalars().all()
id_to_node = {p.id: {**p.__dict__, "children": []} for p in rows}
roots: List[dict] = []
for p in rows:
node = id_to_node[p.id]
parent_id = p.parent_id
if parent_id and parent_id in id_to_node:
id_to_node[parent_id]["children"].append(node)
else:
roots.append(node)
# 清理 _sa_instance_state
def clean(d: dict):
d.pop("_sa_instance_state", None)
for c in d.get("children", []):
clean(c)
for r in roots:
clean(r)
return ResponseModel(message="获取岗位树成功", data=roots)
@router.post("")
async def create_position(
payload: dict, current_user=Depends(require_admin), db: AsyncSession = Depends(get_db)
) -> ResponseModel:
obj = Position(
name=payload.get("name"),
code=payload.get("code"),
description=payload.get("description"),
parent_id=payload.get("parentId"),
status=payload.get("status", "active"),
skills=payload.get("skills"),
level=payload.get("level"),
sort_order=payload.get("sort_order", 0),
created_by=current_user.id,
)
db.add(obj)
await db.commit()
await db.refresh(obj)
return ResponseModel(message="创建岗位成功", data={"id": obj.id})
@router.put("/{position_id}")
async def update_position(
position_id: int, payload: dict, current_user=Depends(require_admin), db: AsyncSession = Depends(get_db)
) -> ResponseModel:
obj = await db.get(Position, position_id)
if not obj or obj.is_deleted:
return ResponseModel(code=404, message="岗位不存在")
obj.name = payload.get("name", obj.name)
obj.code = payload.get("code", obj.code)
obj.description = payload.get("description", obj.description)
obj.parent_id = payload.get("parentId", obj.parent_id)
obj.status = payload.get("status", obj.status)
obj.skills = payload.get("skills", obj.skills)
obj.level = payload.get("level", obj.level)
obj.sort_order = payload.get("sort_order", obj.sort_order)
obj.updated_by = current_user.id
await db.commit()
await db.refresh(obj)
# 返回更新后的完整数据
data = obj.__dict__.copy()
data.pop("_sa_instance_state", None)
return ResponseModel(message="更新岗位成功", data=data)
@router.get("/{position_id}")
async def get_position_detail(
position_id: int, current_user=Depends(require_admin), db: AsyncSession = Depends(get_db)
) -> ResponseModel:
obj = await db.get(Position, position_id)
if not obj or obj.is_deleted:
return ResponseModel(code=404, message="岗位不存在")
data = obj.__dict__.copy()
data.pop("_sa_instance_state", None)
return ResponseModel(data=data)
@router.get("/{position_id}/check-delete")
async def check_position_delete(
position_id: int, current_user=Depends(require_admin), db: AsyncSession = Depends(get_db)
) -> ResponseModel:
obj = await db.get(Position, position_id)
if not obj or obj.is_deleted:
return ResponseModel(code=404, message="岗位不存在")
# 检查是否有子岗位
child_count_result = await db.execute(
select(func.count(Position.id)).where(
and_(
Position.parent_id == position_id,
Position.is_deleted == False
)
)
)
child_count = child_count_result.scalar() or 0
if child_count > 0:
return ResponseModel(data={
"deletable": False,
"reason": f"该岗位下有 {child_count} 个子岗位,请先删除或移动子岗位"
})
# 检查是否有成员(仅作为提醒,不阻止删除)
member_count_result = await db.execute(
select(func.count(PositionMember.id)).where(
and_(
PositionMember.position_id == position_id,
PositionMember.is_deleted == False
)
)
)
member_count = member_count_result.scalar() or 0
warning = ""
if member_count > 0:
warning = f"注意:该岗位当前有 {member_count} 名成员,删除后这些成员将不再属于此岗位"
return ResponseModel(data={"deletable": True, "reason": "", "warning": warning, "member_count": member_count})
@router.delete("/{position_id}")
async def delete_position(
position_id: int, current_user=Depends(require_admin), db: AsyncSession = Depends(get_db)
) -> ResponseModel:
obj = await db.get(Position, position_id)
if not obj or obj.is_deleted:
return ResponseModel(code=404, message="岗位不存在")
# 检查是否有子岗位
child_count_result = await db.execute(
select(func.count(Position.id)).where(
and_(
Position.parent_id == position_id,
Position.is_deleted == False
)
)
)
child_count = child_count_result.scalar() or 0
if child_count > 0:
return ResponseModel(
code=400,
message=f"该岗位下有 {child_count} 个子岗位,请先删除或移动子岗位"
)
# 软删除岗位成员关联
await db.execute(
sa.update(PositionMember)
.where(PositionMember.position_id == position_id)
.values(is_deleted=True)
)
# 软删除岗位课程关联
await db.execute(
sa.update(PositionCourse)
.where(PositionCourse.position_id == position_id)
.values(is_deleted=True)
)
# 软删除岗位
obj.is_deleted = True
await db.commit()
return ResponseModel(message="岗位已删除")
# ========== 岗位成员管理 API ==========
@router.get("/{position_id}/members")
async def get_position_members(
position_id: int,
pagination: PaginationParams = Depends(),
keyword: Optional[str] = Query(None, description="搜索关键词"),
current_user=Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""获取岗位成员列表"""
# 验证岗位存在
position = await db.get(Position, position_id)
if not position or position.is_deleted:
return ResponseModel(code=404, message="岗位不存在")
# 构建查询
stmt = (
select(PositionMember, User)
.join(User, PositionMember.user_id == User.id)
.where(
and_(
PositionMember.position_id == position_id,
PositionMember.is_deleted == False,
User.is_deleted == False
)
)
)
# 关键词搜索
if keyword:
like = f"%{keyword}%"
stmt = stmt.where(
(User.username.ilike(like)) |
(User.full_name.ilike(like)) |
(User.email.ilike(like))
)
# 执行查询
result = await db.execute(stmt)
rows = result.all()
total = len(rows)
sliced = rows[pagination.offset : pagination.offset + pagination.limit]
# 格式化数据
items = []
for pm, user in sliced:
items.append({
"id": pm.id,
"user_id": user.id,
"username": user.username,
"full_name": user.full_name,
"email": user.email,
"phone": user.phone,
"role": pm.role,
"joined_at": pm.joined_at.isoformat() if pm.joined_at else None,
"user_role": user.role, # 系统角色
"is_active": user.is_active,
})
return ResponseModel(
message="获取成员列表成功",
data={
"items": items,
"total": total,
"page": pagination.page,
"page_size": pagination.page_size,
"pages": (total + pagination.page_size - 1) // pagination.page_size if pagination.page_size else 1,
}
)
@router.post("/{position_id}/members")
async def add_position_members(
position_id: int,
payload: dict,
current_user=Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""批量添加岗位成员"""
# 验证岗位存在
position = await db.get(Position, position_id)
if not position or position.is_deleted:
return ResponseModel(code=404, message="岗位不存在")
user_ids = payload.get("user_ids", [])
if not user_ids:
return ResponseModel(code=400, message="请选择要添加的用户")
# 验证用户存在
users = await db.execute(
select(User).where(
and_(
User.id.in_(user_ids),
User.is_deleted == False
)
)
)
valid_users = {u.id: u for u in users.scalars().all()}
if len(valid_users) != len(user_ids):
invalid_ids = set(user_ids) - set(valid_users.keys())
return ResponseModel(code=400, message=f"部分用户不存在: {invalid_ids}")
# 检查是否已存在
existing = await db.execute(
select(PositionMember).where(
and_(
PositionMember.position_id == position_id,
PositionMember.user_id.in_(user_ids),
PositionMember.is_deleted == False
)
)
)
existing_user_ids = {pm.user_id for pm in existing.scalars().all()}
# 添加新成员
added_count = 0
for user_id in user_ids:
if user_id not in existing_user_ids:
member = PositionMember(
position_id=position_id,
user_id=user_id,
role=payload.get("role")
)
db.add(member)
added_count += 1
await db.commit()
return ResponseModel(
message=f"成功添加 {added_count} 个成员",
data={"added_count": added_count}
)
@router.delete("/{position_id}/members/{user_id}")
async def remove_position_member(
position_id: int,
user_id: int,
current_user=Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""移除岗位成员"""
# 查找成员关系
member = await db.execute(
select(PositionMember).where(
and_(
PositionMember.position_id == position_id,
PositionMember.user_id == user_id,
PositionMember.is_deleted == False
)
)
)
member = member.scalar_one_or_none()
if not member:
return ResponseModel(code=404, message="成员关系不存在")
# 软删除
member.is_deleted = True
await db.commit()
return ResponseModel(message="成员已移除")
# ========== 岗位课程管理 API ==========
@router.get("/{position_id}/courses")
async def get_position_courses(
position_id: int,
course_type: Optional[str] = Query(None, description="课程类型required/optional"),
current_user=Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""获取岗位课程列表"""
# 验证岗位存在
position = await db.get(Position, position_id)
if not position or position.is_deleted:
return ResponseModel(code=404, message="岗位不存在")
# 构建查询
stmt = (
select(PositionCourse, Course)
.join(Course, PositionCourse.course_id == Course.id)
.where(
and_(
PositionCourse.position_id == position_id,
PositionCourse.is_deleted == False,
Course.is_deleted == False
)
)
)
# 课程类型筛选
if course_type:
stmt = stmt.where(PositionCourse.course_type == course_type)
# 按优先级排序
stmt = stmt.order_by(PositionCourse.priority, PositionCourse.id)
# 执行查询
result = await db.execute(stmt)
rows = result.all()
# 格式化数据
items = []
for pc, course in rows:
items.append({
"id": pc.id,
"course_id": course.id,
"course_name": course.name,
"course_description": course.description,
"course_category": course.category,
"course_status": course.status,
"course_duration_hours": course.duration_hours,
"course_difficulty_level": course.difficulty_level,
"course_type": pc.course_type,
"priority": pc.priority,
"created_at": pc.created_at.isoformat() if pc.created_at else None,
})
# 统计
stats = {
"total": len(items),
"required_count": sum(1 for item in items if item["course_type"] == "required"),
"optional_count": sum(1 for item in items if item["course_type"] == "optional"),
}
return ResponseModel(
message="获取课程列表成功",
data={
"items": items,
"stats": stats
}
)
@router.post("/{position_id}/courses")
async def add_position_courses(
position_id: int,
payload: dict,
current_user=Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""批量添加岗位课程"""
# 验证岗位存在
position = await db.get(Position, position_id)
if not position or position.is_deleted:
return ResponseModel(code=404, message="岗位不存在")
course_ids = payload.get("course_ids", [])
if not course_ids:
return ResponseModel(code=400, message="请选择要添加的课程")
course_type = payload.get("course_type", "required")
if course_type not in ["required", "optional"]:
return ResponseModel(code=400, message="课程类型无效")
# 验证课程存在
courses = await db.execute(
select(Course).where(
and_(
Course.id.in_(course_ids),
Course.is_deleted == False
)
)
)
valid_courses = {c.id: c for c in courses.scalars().all()}
if len(valid_courses) != len(course_ids):
invalid_ids = set(course_ids) - set(valid_courses.keys())
return ResponseModel(code=400, message=f"部分课程不存在: {invalid_ids}")
# 检查是否已存在
existing = await db.execute(
select(PositionCourse).where(
and_(
PositionCourse.position_id == position_id,
PositionCourse.course_id.in_(course_ids),
PositionCourse.is_deleted == False
)
)
)
existing_course_ids = {pc.course_id for pc in existing.scalars().all()}
# 获取当前最大优先级
max_priority_result = await db.execute(
select(sa.func.max(PositionCourse.priority)).where(
and_(
PositionCourse.position_id == position_id,
PositionCourse.is_deleted == False
)
)
)
max_priority = max_priority_result.scalar() or 0
# 添加新课程
added_count = 0
for idx, course_id in enumerate(course_ids):
if course_id not in existing_course_ids:
pc = PositionCourse(
position_id=position_id,
course_id=course_id,
course_type=course_type,
priority=max_priority + idx + 1,
)
db.add(pc)
added_count += 1
await db.commit()
return ResponseModel(
message=f"成功添加 {added_count} 门课程",
data={"added_count": added_count}
)
@router.put("/{position_id}/courses/{pc_id}")
async def update_position_course(
position_id: int,
pc_id: int,
payload: dict,
current_user=Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""更新岗位课程设置"""
# 查找课程关系
pc = await db.execute(
select(PositionCourse).where(
and_(
PositionCourse.id == pc_id,
PositionCourse.position_id == position_id,
PositionCourse.is_deleted == False
)
)
)
pc = pc.scalar_one_or_none()
if not pc:
return ResponseModel(code=404, message="课程关系不存在")
# 更新课程类型
if "course_type" in payload:
course_type = payload["course_type"]
if course_type not in ["required", "optional"]:
return ResponseModel(code=400, message="课程类型无效")
pc.course_type = course_type
# 更新优先级
if "priority" in payload:
pc.priority = payload["priority"]
# PositionCourse 未继承审计字段,避免写入不存在字段
await db.commit()
return ResponseModel(message="更新成功")
@router.delete("/{position_id}/courses/{course_id}")
async def remove_position_course(
position_id: int,
course_id: int,
current_user=Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""移除岗位课程"""
# 查找课程关系
pc = await db.execute(
select(PositionCourse).where(
and_(
PositionCourse.position_id == position_id,
PositionCourse.course_id == course_id,
PositionCourse.is_deleted == False
)
)
)
pc = pc.scalar_one_or_none()
if not pc:
return ResponseModel(code=404, message="课程关系不存在")
# 软删除
pc.is_deleted = True
# PositionCourse 未继承审计字段,避免写入不存在字段
await db.commit()
return ResponseModel(message="课程已移除")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,285 @@
"""
文件预览API
提供课程资料的在线预览功能
"""
import logging
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.core.deps import get_db, get_current_user
from app.schemas.base import ResponseModel
from app.core.config import settings
from app.models.user import User
from app.models.course import CourseMaterial
from app.services.document_converter import document_converter
logger = logging.getLogger(__name__)
router = APIRouter()
class PreviewType:
"""预览类型常量
支持格式TXT、Markdown、MDX、PDF、HTML、Excel、Word、CSV、VTT、Properties
"""
PDF = "pdf"
TEXT = "text"
HTML = "html"
EXCEL_HTML = "excel_html" # Excel转HTML预览
VIDEO = "video"
AUDIO = "audio"
IMAGE = "image"
DOWNLOAD = "download"
# 文件类型到预览类型的映射
FILE_TYPE_MAPPING = {
# PDF - 直接预览
'.pdf': PreviewType.PDF,
# 文本 - 直接显示内容
'.txt': PreviewType.TEXT,
'.md': PreviewType.TEXT,
'.mdx': PreviewType.TEXT,
'.csv': PreviewType.TEXT,
'.vtt': PreviewType.TEXT,
'.properties': PreviewType.TEXT,
# HTML - 在iframe中预览
'.html': PreviewType.HTML,
'.htm': PreviewType.HTML,
}
def get_preview_type(file_ext: str) -> str:
"""
根据文件扩展名获取预览类型
Args:
file_ext: 文件扩展名(带点,如 .pdf
Returns:
预览类型
"""
file_ext_lower = file_ext.lower()
# 直接映射的类型
if file_ext_lower in FILE_TYPE_MAPPING:
return FILE_TYPE_MAPPING[file_ext_lower]
# Excel文件使用HTML预览避免分页问题
if file_ext_lower in {'.xlsx', '.xls'}:
return PreviewType.EXCEL_HTML
# 其他Office文档需要转换为PDF预览
if document_converter.is_convertible(file_ext_lower):
return PreviewType.PDF
# 其他类型,只提供下载
return PreviewType.DOWNLOAD
def get_file_path_from_url(file_url: str) -> Optional[Path]:
"""
从文件URL获取本地文件路径
Args:
file_url: 文件URL如 /static/uploads/courses/1/xxx.pdf
Returns:
本地文件路径如果无效返回None
"""
try:
# 移除 /static/uploads/ 前缀
if file_url.startswith('/static/uploads/'):
relative_path = file_url.replace('/static/uploads/', '')
full_path = Path(settings.UPLOAD_PATH) / relative_path
return full_path
return None
except Exception:
return None
@router.get("/material/{material_id}", response_model=ResponseModel[dict])
async def get_material_preview(
material_id: int,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
获取资料预览信息
Args:
material_id: 资料ID
Returns:
预览信息包括预览类型、预览URL等
"""
try:
# 查询资料信息
stmt = select(CourseMaterial).where(
CourseMaterial.id == material_id,
CourseMaterial.is_deleted == False
)
result = await db.execute(stmt)
material = result.scalar_one_or_none()
if not material:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="资料不存在"
)
# TODO: 权限检查 - 确认当前用户是否有权访问该课程的资料
# 可以通过查询 position_courses 表和用户的岗位关系来判断
# 获取文件扩展名
file_ext = Path(material.name).suffix.lower()
# 确定预览类型
preview_type = get_preview_type(file_ext)
logger.info(
f"资料预览请求 - material_id: {material_id}, "
f"file_type: {file_ext}, preview_type: {preview_type}, "
f"user_id: {current_user.id}"
)
# 构建响应数据
response_data = {
"preview_type": preview_type,
"file_name": material.name,
"original_url": material.file_url,
"file_size": material.file_size,
}
# 根据预览类型处理
if preview_type == PreviewType.TEXT:
# 文本类型,读取文件内容
file_path = get_file_path_from_url(material.file_url)
if file_path and file_path.exists():
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
response_data["content"] = content
response_data["preview_url"] = None
except Exception as e:
logger.error(f"读取文本文件失败: {str(e)}")
# 读取失败,改为下载模式
response_data["preview_type"] = PreviewType.DOWNLOAD
response_data["preview_url"] = material.file_url
else:
response_data["preview_type"] = PreviewType.DOWNLOAD
response_data["preview_url"] = material.file_url
elif preview_type == PreviewType.EXCEL_HTML:
# Excel文件转换为HTML预览
file_path = get_file_path_from_url(material.file_url)
if file_path and file_path.exists():
converted_url = document_converter.convert_excel_to_html(
str(file_path),
material.course_id,
material.id
)
if converted_url:
response_data["preview_url"] = converted_url
response_data["preview_type"] = "html" # 前端使用html类型渲染
response_data["is_converted"] = True
else:
logger.warning(f"Excel转HTML失败改为下载模式 - material_id: {material_id}")
response_data["preview_type"] = PreviewType.DOWNLOAD
response_data["preview_url"] = material.file_url
response_data["is_converted"] = False
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="文件不存在"
)
elif preview_type == PreviewType.PDF and document_converter.is_convertible(file_ext):
# Office文档需要转换为PDF
file_path = get_file_path_from_url(material.file_url)
if file_path and file_path.exists():
# 执行转换
converted_url = document_converter.convert_to_pdf(
str(file_path),
material.course_id,
material.id
)
if converted_url:
response_data["preview_url"] = converted_url
response_data["is_converted"] = True
else:
# 转换失败,改为下载模式
logger.warning(f"文档转换失败,改为下载模式 - material_id: {material_id}")
response_data["preview_type"] = PreviewType.DOWNLOAD
response_data["preview_url"] = material.file_url
response_data["is_converted"] = False
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="文件不存在"
)
else:
# 其他类型直接返回原始URL
response_data["preview_url"] = material.file_url
return ResponseModel(data=response_data, message="获取预览信息成功")
except HTTPException:
raise
except Exception as e:
logger.error(f"获取资料预览信息失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="获取预览信息失败"
)
@router.get("/check-converter", response_model=ResponseModel[dict])
async def check_converter_status(
current_user: User = Depends(get_current_user),
):
"""
检查文档转换服务状态(用于调试)
Returns:
转换服务状态信息
"""
try:
import subprocess
# 检查 LibreOffice 是否安装
try:
result = subprocess.run(
['libreoffice', '--version'],
capture_output=True,
text=True,
timeout=5
)
libreoffice_installed = result.returncode == 0
libreoffice_version = result.stdout.strip() if libreoffice_installed else None
except Exception:
libreoffice_installed = False
libreoffice_version = None
return ResponseModel(
data={
"libreoffice_installed": libreoffice_installed,
"libreoffice_version": libreoffice_version,
"supported_formats": list(document_converter.SUPPORTED_FORMATS),
"converted_path": str(document_converter.converted_path),
},
message="转换服务状态检查完成"
)
except Exception as e:
logger.error(f"检查转换服务状态失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="检查转换服务状态失败"
)

311
backend/app/api/v1/scrm.py Normal file
View File

@@ -0,0 +1,311 @@
"""
SCRM 系统对接 API 路由
提供给 SCRM 系统调用的数据查询接口
认证方式Bearer Token (SCRM_API_KEY)
"""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_db, verify_scrm_api_key
from app.services.scrm_service import SCRMService
from app.schemas.scrm import (
EmployeePositionResponse,
EmployeePositionData,
PositionCoursesResponse,
PositionCoursesData,
KnowledgePointSearchRequest,
KnowledgePointSearchResponse,
KnowledgePointSearchData,
KnowledgePointDetailResponse,
KnowledgePointDetailData,
SCRMErrorResponse,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/scrm", tags=["scrm"])
# ==================== 1. 获取员工岗位 ====================
@router.get(
"/employees/{userid}/position",
response_model=EmployeePositionResponse,
summary="获取员工岗位通过userid",
description="根据企微 userid 查询员工在考陪练系统中的岗位信息",
responses={
200: {"model": EmployeePositionResponse, "description": "成功"},
401: {"model": SCRMErrorResponse, "description": "认证失败"},
404: {"model": SCRMErrorResponse, "description": "员工不存在"},
}
)
async def get_employee_position_by_userid(
userid: str,
_: bool = Depends(verify_scrm_api_key),
db: AsyncSession = Depends(get_db)
):
"""
获取员工岗位通过企微userid
- **userid**: 企微员工 userid
"""
service = SCRMService(db)
result = await service.get_employee_position(userid=userid)
if result is None:
raise HTTPException(
status_code=404,
detail={
"code": 404,
"message": "员工不存在",
"data": None
}
)
# 检查是否有多个匹配结果
if result.get("multiple_matches"):
return {
"code": 0,
"message": f"找到 {result['count']} 个匹配的员工,请确认",
"data": result
}
return EmployeePositionResponse(
code=0,
message="success",
data=EmployeePositionData(**result)
)
@router.get(
"/employees/search/by-name",
summary="获取员工岗位(通过姓名搜索)",
description="根据员工姓名查询员工在考陪练系统中的岗位信息,支持精确匹配和模糊匹配",
responses={
200: {"description": "成功"},
401: {"model": SCRMErrorResponse, "description": "认证失败"},
404: {"model": SCRMErrorResponse, "description": "员工不存在"},
}
)
async def get_employee_position_by_name(
name: str = Query(..., description="员工姓名,支持精确匹配和模糊匹配"),
_: bool = Depends(verify_scrm_api_key),
db: AsyncSession = Depends(get_db)
):
"""
获取员工岗位(通过姓名搜索)
- **name**: 员工姓名(必填),优先精确匹配,无结果时模糊匹配
注意:如果有多个同名员工,会返回员工列表供确认
"""
service = SCRMService(db)
result = await service.get_employee_position(name=name)
if result is None:
raise HTTPException(
status_code=404,
detail={
"code": 404,
"message": f"未找到姓名包含 '{name}' 的员工",
"data": None
}
)
# 检查是否有多个匹配结果
if result.get("multiple_matches"):
return {
"code": 0,
"message": f"找到 {result['count']} 个匹配的员工,请确认后使用 employee_id 精确查询",
"data": result
}
return EmployeePositionResponse(
code=0,
message="success",
data=EmployeePositionData(**result)
)
@router.get(
"/employees/by-id/{employee_id}/position",
response_model=EmployeePositionResponse,
summary="获取员工岗位通过员工ID",
description="根据员工ID精确查询员工岗位信息用于多个同名员工时的精确查询",
responses={
200: {"model": EmployeePositionResponse, "description": "成功"},
401: {"model": SCRMErrorResponse, "description": "认证失败"},
404: {"model": SCRMErrorResponse, "description": "员工不存在"},
}
)
async def get_employee_position_by_id(
employee_id: int,
_: bool = Depends(verify_scrm_api_key),
db: AsyncSession = Depends(get_db)
):
"""
获取员工岗位通过员工ID精确查询
- **employee_id**: 员工ID考陪练系统用户ID
适用场景:通过姓名搜索返回多个匹配结果后,使用此接口精确查询
"""
service = SCRMService(db)
result = await service.get_employee_position_by_id(employee_id)
if result is None:
raise HTTPException(
status_code=404,
detail={
"code": 404,
"message": "员工不存在",
"data": None
}
)
return EmployeePositionResponse(
code=0,
message="success",
data=EmployeePositionData(**result)
)
# ==================== 2. 获取岗位课程列表 ====================
@router.get(
"/positions/{position_id}/courses",
response_model=PositionCoursesResponse,
summary="获取岗位课程列表",
description="获取指定岗位的必修/选修课程列表",
responses={
200: {"model": PositionCoursesResponse, "description": "成功"},
401: {"model": SCRMErrorResponse, "description": "认证失败"},
404: {"model": SCRMErrorResponse, "description": "岗位不存在"},
}
)
async def get_position_courses(
position_id: int,
course_type: Optional[str] = Query(
default="all",
description="课程类型required/optional/all",
regex="^(required|optional|all)$"
),
_: bool = Depends(verify_scrm_api_key),
db: AsyncSession = Depends(get_db)
):
"""
获取岗位课程列表
- **position_id**: 岗位ID
- **course_type**: 课程类型筛选required/optional/all默认 all
"""
service = SCRMService(db)
result = await service.get_position_courses(position_id, course_type)
if result is None:
raise HTTPException(
status_code=404,
detail={
"code": 40002,
"message": "position_id 不存在",
"data": None
}
)
return PositionCoursesResponse(
code=0,
message="success",
data=PositionCoursesData(**result)
)
# ==================== 3. 搜索知识点 ====================
@router.post(
"/knowledge-points/search",
response_model=KnowledgePointSearchResponse,
summary="搜索知识点",
description="根据关键词和岗位搜索匹配的知识点",
responses={
200: {"model": KnowledgePointSearchResponse, "description": "成功"},
401: {"model": SCRMErrorResponse, "description": "认证失败"},
400: {"model": SCRMErrorResponse, "description": "请求参数错误"},
}
)
async def search_knowledge_points(
request: KnowledgePointSearchRequest,
_: bool = Depends(verify_scrm_api_key),
db: AsyncSession = Depends(get_db)
):
"""
搜索知识点
- **keywords**: 搜索关键词列表(必填)
- **position_id**: 岗位ID用于优先排序可选
- **course_ids**: 限定课程范围(可选)
- **knowledge_type**: 知识点类型筛选(可选)
- **limit**: 返回数量默认10最大100
"""
service = SCRMService(db)
result = await service.search_knowledge_points(
keywords=request.keywords,
position_id=request.position_id,
course_ids=request.course_ids,
knowledge_type=request.knowledge_type,
limit=request.limit
)
return KnowledgePointSearchResponse(
code=0,
message="success",
data=KnowledgePointSearchData(**result)
)
# ==================== 4. 获取知识点详情 ====================
@router.get(
"/knowledge-points/{knowledge_point_id}",
response_model=KnowledgePointDetailResponse,
summary="获取知识点详情",
description="获取知识点的完整信息",
responses={
200: {"model": KnowledgePointDetailResponse, "description": "成功"},
401: {"model": SCRMErrorResponse, "description": "认证失败"},
404: {"model": SCRMErrorResponse, "description": "知识点不存在"},
}
)
async def get_knowledge_point_detail(
knowledge_point_id: int,
_: bool = Depends(verify_scrm_api_key),
db: AsyncSession = Depends(get_db)
):
"""
获取知识点详情
- **knowledge_point_id**: 知识点ID
"""
service = SCRMService(db)
result = await service.get_knowledge_point_detail(knowledge_point_id)
if result is None:
raise HTTPException(
status_code=404,
detail={
"code": 40003,
"message": "knowledge_point_id 不存在",
"data": None
}
)
return KnowledgePointDetailResponse(
code=0,
message="success",
data=KnowledgePointDetailData(**result)
)

View File

@@ -0,0 +1,363 @@
"""
SQL 执行器 API - 用于内部服务调用
支持执行查询和写入操作的 SQL 语句
"""
import json
from typing import Any, Dict, List, Optional, Union
from datetime import datetime, date
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.engine.result import Result
import structlog
from app.core.deps import get_current_user, get_db
try:
from app.core.simple_auth import get_current_user_simple
except ImportError:
get_current_user_simple = None
from app.core.config import settings
from app.models.user import User
from app.schemas.base import ResponseModel
logger = structlog.get_logger(__name__)
router = APIRouter(tags=["SQL Executor"])
class SQLExecutorRequest:
"""SQL执行请求模型"""
def __init__(self, sql: str, params: Optional[Dict[str, Any]] = None):
self.sql = sql
self.params = params or {}
class DateTimeEncoder(json.JSONEncoder):
"""处理日期时间对象的 JSON 编码器"""
def default(self, obj):
if isinstance(obj, (datetime, date)):
return obj.isoformat()
return super().default(obj)
def serialize_row(row: Any) -> Union[Dict[str, Any], Any]:
"""序列化数据库行结果"""
if hasattr(row, '_mapping'):
# 处理 SQLAlchemy Row 对象
return dict(row._mapping)
elif hasattr(row, '__dict__'):
# 处理 ORM 对象
return {k: v for k, v in row.__dict__.items() if not k.startswith('_')}
else:
# 处理单值结果
return row
@router.post("/execute", response_model=ResponseModel)
async def execute_sql(
request: Dict[str, Any],
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
) -> ResponseModel:
"""
执行 SQL 语句
Args:
request: 包含 sql 和可选的 params 字段
- sql: SQL 语句
- params: 参数字典(可选)
Returns:
执行结果,包括:
- 查询操作:返回数据行
- 写入操作:返回影响的行数
安全说明:
- 需要用户身份验证
- 所有操作都会记录日志
- 建议在生产环境中限制可执行的 SQL 类型
"""
try:
# 提取参数
sql = request.get('sql', '').strip()
params = request.get('params', {})
if not sql:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="SQL 语句不能为空"
)
# 记录 SQL 执行日志
logger.info(
"sql_execution_request",
user_id=current_user.id,
username=current_user.username,
sql_type=sql.split()[0].upper() if sql else "UNKNOWN",
sql_length=len(sql),
has_params=bool(params)
)
# 判断 SQL 类型
sql_upper = sql.upper().strip()
is_select = sql_upper.startswith('SELECT')
is_show = sql_upper.startswith('SHOW')
is_describe = sql_upper.startswith(('DESCRIBE', 'DESC'))
is_query = is_select or is_show or is_describe
# 执行 SQL
try:
result = await db.execute(text(sql), params)
if is_query:
# 查询操作
rows = result.fetchall()
columns = list(result.keys()) if result.keys() else []
# 序列化结果
data = []
for row in rows:
serialized_row = serialize_row(row)
if isinstance(serialized_row, dict):
data.append(serialized_row)
else:
# 单列结果
data.append({columns[0] if columns else 'value': serialized_row})
# 使用自定义编码器处理日期时间
response_data = {
"type": "query",
"columns": columns,
"rows": json.loads(json.dumps(data, cls=DateTimeEncoder)),
"row_count": len(data)
}
logger.info(
"sql_query_success",
user_id=current_user.id,
row_count=len(data),
column_count=len(columns)
)
else:
# 写入操作
await db.commit()
affected_rows = result.rowcount
response_data = {
"type": "execute",
"affected_rows": affected_rows,
"success": True
}
logger.info(
"sql_execute_success",
user_id=current_user.id,
affected_rows=affected_rows
)
return ResponseModel(
code=200,
message="SQL 执行成功",
data=response_data
)
except Exception as e:
# 回滚事务
await db.rollback()
logger.error(
"sql_execution_error",
user_id=current_user.id,
sql_type=sql.split()[0].upper() if sql else "UNKNOWN",
error=str(e),
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"SQL 执行失败: {str(e)}"
)
except HTTPException:
raise
except Exception as e:
logger.error(
"sql_executor_error",
user_id=current_user.id,
error=str(e),
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"处理请求时发生错误: {str(e)}"
)
@router.post("/validate", response_model=ResponseModel)
async def validate_sql(
request: Dict[str, Any],
current_user: User = Depends(get_current_user)
) -> ResponseModel:
"""
验证 SQL 语句的语法(不执行)
Args:
request: 包含 sql 字段的请求
Returns:
验证结果
"""
try:
sql = request.get('sql', '').strip()
if not sql:
return ResponseModel(
code=400,
message="SQL 语句不能为空",
data={"valid": False, "error": "SQL 语句不能为空"}
)
# 基本的 SQL 验证
sql_upper = sql.upper().strip()
# 检查危险操作(可根据需要调整)
dangerous_keywords = ['DROP', 'TRUNCATE', 'DELETE FROM', 'UPDATE']
warnings = []
for keyword in dangerous_keywords:
if keyword in sql_upper:
warnings.append(f"包含危险操作: {keyword}")
return ResponseModel(
code=200,
message="SQL 验证完成",
data={
"valid": True,
"warnings": warnings,
"sql_type": sql_upper.split()[0] if sql_upper else "UNKNOWN"
}
)
except Exception as e:
logger.error(
"sql_validation_error",
user_id=current_user.id,
error=str(e)
)
return ResponseModel(
code=500,
message="SQL 验证失败",
data={"valid": False, "error": str(e)}
)
@router.get("/tables", response_model=ResponseModel)
async def get_tables(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
) -> ResponseModel:
"""
获取数据库中的所有表
Returns:
数据库表列表
"""
try:
result = await db.execute(text("SHOW TABLES"))
tables = [row[0] for row in result.fetchall()]
return ResponseModel(
code=200,
message="获取表列表成功",
data={
"tables": tables,
"count": len(tables)
}
)
except Exception as e:
logger.error(
"get_tables_error",
user_id=current_user.id,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取表列表失败: {str(e)}"
)
@router.get("/table/{table_name}/schema", response_model=ResponseModel)
async def get_table_schema(
table_name: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
) -> ResponseModel:
"""
获取指定表的结构信息
Args:
table_name: 表名
Returns:
表结构信息
"""
try:
# MySQL 的 DESCRIBE 不支持参数化,需要直接拼接
# 但为了安全,先验证表名
if not table_name.replace('_', '').isalnum():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效的表名"
)
result = await db.execute(text(f"DESCRIBE {table_name}"))
columns = []
for row in result.fetchall():
columns.append({
"field": row[0],
"type": row[1],
"null": row[2],
"key": row[3],
"default": row[4],
"extra": row[5]
})
return ResponseModel(
code=200,
message="获取表结构成功",
data={
"table_name": table_name,
"columns": columns,
"column_count": len(columns)
}
)
except Exception as e:
logger.error(
"get_table_schema_error",
user_id=current_user.id,
table_name=table_name,
error=str(e)
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取表结构失败: {str(e)}"
)
# 简化认证版本的端点(如果启用)
if get_current_user_simple:
@router.post("/execute-simple", response_model=ResponseModel)
async def execute_sql_simple(
request: Dict[str, Any],
current_user: User = Depends(get_current_user_simple),
db: AsyncSession = Depends(get_db)
) -> ResponseModel:
"""
执行 SQL 语句(简化认证版本)
支持 API Key 和 Token 两种认证方式,专为内部服务设计。
"""
return await execute_sql(request, current_user, db)

View File

@@ -0,0 +1,5 @@
"""
SQL 执行器 API - 简化认证版本(已删除,功能已整合到主文件)
"""
# 此文件的功能已经整合到 sql_executor.py 中
# 请使用 /api/v1/sql/execute-simple 端点

View File

@@ -0,0 +1,238 @@
"""
统计分析API路由
"""
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_db, get_current_user
from app.models.user import User
from app.schemas.base import ResponseModel
from app.services.statistics_service import StatisticsService
from app.core.logger import get_logger
logger = get_logger(__name__)
router = APIRouter(prefix="/statistics", tags=["statistics"])
@router.get("/key-metrics", response_model=ResponseModel)
async def get_key_metrics(
course_id: Optional[int] = Query(None, description="课程ID不传则统计全部课程"),
period: str = Query("month", description="时间范围: week/month/quarter/halfYear/year"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
获取关键指标
返回:
- learningEfficiency: 学习效率
- knowledgeCoverage: 知识覆盖率
- avgTimePerQuestion: 平均用时
- progressSpeed: 进步速度
"""
try:
metrics = await StatisticsService.get_key_metrics(
db=db,
user_id=current_user.id,
course_id=course_id,
period=period
)
return ResponseModel(
code=200,
message="获取关键指标成功",
data=metrics
)
except Exception as e:
logger.error(f"获取关键指标失败: {e}", exc_info=True)
return ResponseModel(
code=500,
message=f"获取关键指标失败: {str(e)}"
)
@router.get("/score-distribution", response_model=ResponseModel)
async def get_score_distribution(
course_id: Optional[int] = Query(None, description="课程ID不传则统计全部课程"),
period: str = Query("month", description="时间范围: week/month/quarter/halfYear/year"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
获取成绩分布统计
返回各分数段的考试数量:
- excellent: 优秀(90-100)
- good: 良好(80-89)
- medium: 中等(70-79)
- pass: 及格(60-69)
- fail: 不及格(<60)
"""
try:
distribution = await StatisticsService.get_score_distribution(
db=db,
user_id=current_user.id,
course_id=course_id,
period=period
)
return ResponseModel(
code=200,
message="获取成绩分布成功",
data=distribution
)
except Exception as e:
logger.error(f"获取成绩分布失败: {e}", exc_info=True)
return ResponseModel(
code=500,
message=f"获取成绩分布失败: {str(e)}"
)
@router.get("/difficulty-analysis", response_model=ResponseModel)
async def get_difficulty_analysis(
course_id: Optional[int] = Query(None, description="课程ID不传则统计全部课程"),
period: str = Query("month", description="时间范围: week/month/quarter/halfYear/year"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
获取题目难度分析
返回各难度题目的正确率:
- 简单题
- 中等题
- 困难题
- 综合题
- 应用题
"""
try:
analysis = await StatisticsService.get_difficulty_analysis(
db=db,
user_id=current_user.id,
course_id=course_id,
period=period
)
return ResponseModel(
code=200,
message="获取难度分析成功",
data=analysis
)
except Exception as e:
logger.error(f"获取难度分析失败: {e}", exc_info=True)
return ResponseModel(
code=500,
message=f"获取难度分析失败: {str(e)}"
)
@router.get("/knowledge-mastery", response_model=ResponseModel)
async def get_knowledge_mastery(
course_id: Optional[int] = Query(None, description="课程ID不传则统计全部课程"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
获取知识点掌握度
返回知识点列表及其掌握度:
- name: 知识点名称
- mastery: 掌握度0-100
"""
try:
mastery = await StatisticsService.get_knowledge_mastery(
db=db,
user_id=current_user.id,
course_id=course_id
)
return ResponseModel(
code=200,
message="获取知识点掌握度成功",
data=mastery
)
except Exception as e:
logger.error(f"获取知识点掌握度失败: {e}", exc_info=True)
return ResponseModel(
code=500,
message=f"获取知识点掌握度失败: {str(e)}"
)
@router.get("/study-time", response_model=ResponseModel)
async def get_study_time_stats(
course_id: Optional[int] = Query(None, description="课程ID不传则统计全部课程"),
period: str = Query("month", description="时间范围: week/month/quarter/halfYear/year"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
获取学习时长统计
返回学习时长和练习时长的日期分布:
- labels: 日期标签列表
- studyTime: 学习时长列表(小时)
- practiceTime: 练习时长列表(小时)
"""
try:
time_stats = await StatisticsService.get_study_time_stats(
db=db,
user_id=current_user.id,
course_id=course_id,
period=period
)
return ResponseModel(
code=200,
message="获取学习时长统计成功",
data=time_stats
)
except Exception as e:
logger.error(f"获取学习时长统计失败: {e}", exc_info=True)
return ResponseModel(
code=500,
message=f"获取学习时长统计失败: {str(e)}"
)
@router.get("/detail", response_model=ResponseModel)
async def get_detail_data(
course_id: Optional[int] = Query(None, description="课程ID不传则统计全部课程"),
period: str = Query("month", description="时间范围: week/month/quarter/halfYear/year"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
获取详细统计数据(按日期)
返回每日详细统计数据:
- date: 日期
- examCount: 考试次数
- avgScore: 平均分
- studyTime: 学习时长(小时)
- questionCount: 练习题数
- accuracy: 正确率
- improvement: 进步指数
"""
try:
detail = await StatisticsService.get_detail_data(
db=db,
user_id=current_user.id,
course_id=course_id,
period=period
)
return ResponseModel(
code=200,
message="获取详细数据成功",
data=detail
)
except Exception as e:
logger.error(f"获取详细数据失败: {e}", exc_info=True)
return ResponseModel(
code=500,
message=f"获取详细数据失败: {str(e)}"
)

View File

@@ -0,0 +1,139 @@
"""
系统API - 供外部服务回调使用
"""
import logging
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Header
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel, Field
from app.core.deps import get_db
from app.schemas.base import ResponseModel
from app.schemas.course import KnowledgePointCreate
from app.services.course_service import knowledge_point_service, course_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/system")
class KnowledgePointData(BaseModel):
"""知识点数据模型"""
name: str = Field(..., description="知识点名称")
description: str = Field(default="", description="知识点描述")
type: str = Field(default="理论知识", description="知识点类型")
source: int = Field(default=1, description="来源0=手动1=AI分析")
topic_relation: Optional[str] = Field(None, description="与主题的关系描述")
class KnowledgeCallbackRequest(BaseModel):
"""知识点回调请求模型(已弃用,保留向后兼容)"""
course_id: int = Field(..., description="课程ID")
material_id: int = Field(..., description="资料ID")
knowledge_points: List[KnowledgePointData] = Field(..., description="知识点列表")
@router.post("/knowledge", response_model=ResponseModel[Dict[str, Any]])
async def create_knowledge_points_callback(
request: KnowledgeCallbackRequest,
authorization: str = Header(None),
db: AsyncSession = Depends(get_db),
):
"""
创建知识点回调接口(已弃用)
注意:此接口已弃用,知识点分析现使用 Python 原生实现。
保留此接口仅为向后兼容。
"""
try:
# API密钥验证已弃用的接口保留向后兼容
expected_token = "Bearer callback-token-2025"
if authorization != expected_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的授权令牌"
)
# 验证课程是否存在
course = await course_service.get_by_id(db, request.course_id)
if not course:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"课程 {request.course_id} 不存在"
)
# 验证资料是否存在
materials = await course_service.get_course_materials(db, course_id=request.course_id)
material = next((m for m in materials if m.id == request.material_id), None)
if not material:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"资料 {request.material_id} 不存在"
)
# 创建知识点
created_points = []
for kp_data in request.knowledge_points:
try:
knowledge_point_create = KnowledgePointCreate(
name=kp_data.name,
description=kp_data.description,
type=kp_data.type,
source=kp_data.source, # AI分析来源=1
topic_relation=kp_data.topic_relation,
material_id=request.material_id # 关联资料ID
)
# 使用系统用户ID (假设为1或者可以配置)
system_user_id = 1
knowledge_point = await knowledge_point_service.create_knowledge_point(
db=db,
course_id=request.course_id,
point_in=knowledge_point_create,
created_by=system_user_id
)
created_points.append({
"id": knowledge_point.id,
"name": knowledge_point.name,
"description": knowledge_point.description,
"type": knowledge_point.type,
"source": knowledge_point.source,
"material_id": knowledge_point.material_id
})
except Exception as e:
logger.error(
f"创建知识点失败 - name: {kp_data.name}, error: {str(e)}"
)
# 继续处理其他知识点,不因为单个失败而中断
continue
logger.info(
f"知识点回调成功 - course_id: {request.course_id}, material_id: {request.material_id}, created_points: {len(created_points)}"
)
return ResponseModel(
data={
"course_id": request.course_id,
"material_id": request.material_id,
"knowledge_points_count": len(created_points),
"knowledge_points": created_points
},
message=f"成功创建 {len(created_points)} 个知识点"
)
except HTTPException:
raise
except Exception as e:
logger.error(
f"知识点回调处理失败 - course_id: {request.course_id}, material_id: {request.material_id}, error: {str(e)}",
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="知识点创建失败"
)

View File

@@ -0,0 +1,184 @@
"""
系统日志 API
提供日志查询、筛选、详情查看等功能
"""
import logging
from typing import Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_db, get_current_user
from app.models.user import User
from app.schemas.base import ResponseModel
from app.schemas.system_log import (
SystemLogCreate,
SystemLogResponse,
SystemLogQuery,
SystemLogListResponse
)
from app.services.system_log_service import system_log_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/logs")
@router.get("", response_model=ResponseModel[SystemLogListResponse])
async def get_system_logs(
level: Optional[str] = Query(None, description="日志级别筛选"),
type: Optional[str] = Query(None, description="日志类型筛选"),
user: Optional[str] = Query(None, description="用户筛选"),
keyword: Optional[str] = Query(None, description="关键词搜索"),
start_date: Optional[datetime] = Query(None, description="开始日期"),
end_date: Optional[datetime] = Query(None, description="结束日期"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取系统日志列表
支持按级别、类型、用户、关键词、日期范围筛选
仅管理员可访问
"""
try:
# 权限检查:仅管理员可查看系统日志
if current_user.role != "admin":
raise HTTPException(status_code=403, detail="无权限访问系统日志")
# 构建查询参数
query_params = SystemLogQuery(
level=level,
type=type,
user=user,
keyword=keyword,
start_date=start_date,
end_date=end_date,
page=page,
page_size=page_size
)
# 查询日志
logs, total = await system_log_service.get_logs(db, query_params)
# 计算总页数
total_pages = (total + page_size - 1) // page_size
# 转换为响应格式
log_responses = [SystemLogResponse.model_validate(log) for log in logs]
response_data = SystemLogListResponse(
items=log_responses,
total=total,
page=page,
page_size=page_size,
total_pages=total_pages
)
return ResponseModel(
code=200,
message="获取系统日志成功",
data=response_data
)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取系统日志失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取系统日志失败: {str(e)}")
@router.get("/{log_id}", response_model=ResponseModel[SystemLogResponse])
async def get_log_detail(
log_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取日志详情
仅管理员可访问
"""
try:
# 权限检查
if current_user.role != "admin":
raise HTTPException(status_code=403, detail="无权限访问系统日志")
# 查询日志
log = await system_log_service.get_log_by_id(db, log_id)
if not log:
raise HTTPException(status_code=404, detail="日志不存在")
return ResponseModel(
code=200,
message="获取日志详情成功",
data=SystemLogResponse.model_validate(log)
)
except HTTPException:
raise
except Exception as e:
logger.error(f"获取日志详情失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取日志详情失败: {str(e)}")
@router.post("", response_model=ResponseModel[SystemLogResponse])
async def create_system_log(
log_data: SystemLogCreate,
db: AsyncSession = Depends(get_db)
):
"""
创建系统日志内部API供系统各模块调用
注意:此接口不需要用户认证,但应该只供内部调用
"""
try:
log = await system_log_service.create_log(db, log_data)
return ResponseModel(
code=200,
message="创建日志成功",
data=SystemLogResponse.model_validate(log)
)
except Exception as e:
logger.error(f"创建日志失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"创建日志失败: {str(e)}")
@router.delete("/cleanup")
async def cleanup_old_logs(
before_days: int = Query(90, ge=1, description="删除多少天之前的日志"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
清理旧日志
仅管理员可访问
"""
try:
# 权限检查
if current_user.role != "admin":
raise HTTPException(status_code=403, detail="无权限执行此操作")
# 计算截止日期
from datetime import timedelta
before_date = datetime.now() - timedelta(days=before_days)
# 删除旧日志
deleted_count = await system_log_service.delete_logs_before_date(db, before_date)
return ResponseModel(
code=200,
message=f"成功清理 {deleted_count} 条日志",
data={"deleted_count": deleted_count}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"清理日志失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"清理日志失败: {str(e)}")

228
backend/app/api/v1/tasks.py Normal file
View File

@@ -0,0 +1,228 @@
"""
任务管理API
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_db, get_current_user, require_admin_or_manager
from app.schemas.base import ResponseModel, PaginatedResponse
from app.schemas.task import TaskCreate, TaskUpdate, TaskResponse, TaskStatsResponse
from app.services.task_service import task_service
from app.services.system_log_service import system_log_service
from app.schemas.system_log import SystemLogCreate
from app.models.user import User
router = APIRouter(prefix="/manager/tasks", tags=["Tasks"], redirect_slashes=False)
@router.post("", response_model=ResponseModel[TaskResponse], summary="创建任务")
async def create_task(
task_in: TaskCreate,
request: Request,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_admin_or_manager)
):
"""创建新任务"""
task = await task_service.create_task(db, task_in, current_user.id)
# 记录任务创建日志
await system_log_service.create_log(
db,
SystemLogCreate(
level="INFO",
type="api",
message=f"创建任务: {task.title}",
user_id=current_user.id,
user=current_user.username,
ip=request.client.host if request.client else None,
path="/api/v1/manager/tasks",
method="POST",
user_agent=request.headers.get("user-agent")
)
)
# 构建响应
courses = [link.course.name for link in task.course_links]
return ResponseModel(
data=TaskResponse(
id=task.id,
title=task.title,
description=task.description,
priority=task.priority.value,
status=task.status.value,
creator_id=task.creator_id,
deadline=task.deadline,
requirements=task.requirements,
progress=task.progress,
created_at=task.created_at,
updated_at=task.updated_at,
courses=courses,
assigned_count=len(task.assignments),
completed_count=sum(1 for a in task.assignments if a.status.value == "completed")
)
)
@router.get("", response_model=ResponseModel[PaginatedResponse[TaskResponse]], summary="获取任务列表")
async def get_tasks(
status: Optional[str] = Query(None, description="任务状态筛选"),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_admin_or_manager)
):
"""获取任务列表"""
tasks, total = await task_service.get_tasks(db, status, page, page_size)
# 构建响应
items = []
for task in tasks:
# 加载关联数据
task_detail = await task_service.get_task_detail(db, task.id)
if task_detail:
courses = [link.course.name for link in task_detail.course_links]
items.append(TaskResponse(
id=task.id,
title=task.title,
description=task.description,
priority=task.priority.value,
status=task.status.value,
creator_id=task.creator_id,
deadline=task.deadline,
requirements=task.requirements,
progress=task.progress,
created_at=task.created_at,
updated_at=task.updated_at,
courses=courses,
assigned_count=len(task_detail.assignments),
completed_count=sum(1 for a in task_detail.assignments if a.status.value == "completed")
))
return ResponseModel(
data=PaginatedResponse.create(
items=items,
total=total,
page=page,
page_size=page_size
)
)
@router.get("/stats", response_model=ResponseModel[TaskStatsResponse], summary="获取任务统计")
async def get_task_stats(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_admin_or_manager)
):
"""获取任务统计数据"""
stats = await task_service.get_task_stats(db)
return ResponseModel(data=stats)
@router.get("/{task_id}", response_model=ResponseModel[TaskResponse], summary="获取任务详情")
async def get_task(
task_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_admin_or_manager)
):
"""获取任务详情"""
task = await task_service.get_task_detail(db, task_id)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
courses = [link.course.name for link in task.course_links]
return ResponseModel(
data=TaskResponse(
id=task.id,
title=task.title,
description=task.description,
priority=task.priority.value,
status=task.status.value,
creator_id=task.creator_id,
deadline=task.deadline,
requirements=task.requirements,
progress=task.progress,
created_at=task.created_at,
updated_at=task.updated_at,
courses=courses,
assigned_count=len(task.assignments),
completed_count=sum(1 for a in task.assignments if a.status.value == "completed")
)
)
@router.put("/{task_id}", response_model=ResponseModel[TaskResponse], summary="更新任务")
async def update_task(
task_id: int,
task_in: TaskUpdate,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_admin_or_manager)
):
"""更新任务"""
task = await task_service.update_task(db, task_id, task_in)
if not task:
raise HTTPException(status_code=404, detail="任务不存在")
# 自动更新任务进度和状态
await task_service.update_task_status(db, task_id)
# 重新加载详情
task_detail = await task_service.get_task_detail(db, task.id)
courses = [link.course.name for link in task_detail.course_links] if task_detail else []
return ResponseModel(
data=TaskResponse(
id=task.id,
title=task.title,
description=task.description,
priority=task.priority.value,
status=task.status.value,
creator_id=task.creator_id,
deadline=task.deadline,
requirements=task.requirements,
progress=task.progress,
created_at=task.created_at,
updated_at=task.updated_at,
courses=courses,
assigned_count=len(task_detail.assignments) if task_detail else 0,
completed_count=sum(1 for a in task_detail.assignments if a.status.value == "completed") if task_detail else 0
)
)
@router.delete("/{task_id}", response_model=ResponseModel, summary="删除任务")
async def delete_task(
task_id: int,
request: Request,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_admin_or_manager)
):
"""删除任务"""
# 先获取任务信息用于日志
task_detail = await task_service.get_task_detail(db, task_id)
task_title = task_detail.title if task_detail else f"ID:{task_id}"
success = await task_service.delete_task(db, task_id)
if not success:
raise HTTPException(status_code=404, detail="任务不存在")
# 记录任务删除日志
await system_log_service.create_log(
db,
SystemLogCreate(
level="INFO",
type="api",
message=f"删除任务: {task_title}",
user_id=current_user.id,
user=current_user.username,
ip=request.client.host if request.client else None,
path=f"/api/v1/manager/tasks/{task_id}",
method="DELETE",
user_agent=request.headers.get("user-agent")
)
)
return ResponseModel(message="任务已删除")

View File

@@ -0,0 +1,750 @@
"""
团队看板 API 路由
提供团队概览、学习进度、排行榜、动态等数据
"""
import json
from datetime import datetime, timedelta
from typing import Any, Dict, List
from fastapi import APIRouter, Depends
from sqlalchemy import and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_current_active_user as get_current_user, get_db
from app.core.logger import logger
from app.models.course import Course
from app.models.exam import Exam
from app.models.position import Position
from app.models.position_member import PositionMember
from app.models.practice import PracticeReport, PracticeSession
from app.models.user import Team, User, UserTeam
from app.schemas.base import ResponseModel
router = APIRouter(prefix="/team/dashboard", tags=["team-dashboard"])
async def get_accessible_teams(
current_user: User,
db: AsyncSession
) -> List[int]:
"""获取用户可访问的团队ID列表"""
if current_user.role in ['admin', 'manager']:
# 管理员查看所有团队
stmt = select(Team.id).where(Team.is_deleted == False) # noqa: E712
result = await db.execute(stmt)
return [row[0] for row in result.all()]
else:
# 普通用户只查看自己的团队
stmt = select(UserTeam.team_id).where(UserTeam.user_id == current_user.id)
result = await db.execute(stmt)
return [row[0] for row in result.all()]
async def get_team_member_ids(
team_ids: List[int],
db: AsyncSession
) -> List[int]:
"""获取团队成员ID列表"""
if not team_ids:
return []
stmt = select(UserTeam.user_id).where(
UserTeam.team_id.in_(team_ids)
).distinct()
result = await db.execute(stmt)
return [row[0] for row in result.all()]
@router.get("/overview", response_model=ResponseModel)
async def get_team_overview(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取团队概览统计
返回团队总数、成员数、平均学习进度、平均成绩、课程完成率等
"""
try:
# 获取可访问的团队
team_ids = await get_accessible_teams(current_user, db)
# 获取团队成员ID
member_ids = await get_team_member_ids(team_ids, db)
# 统计团队数
team_count = len(team_ids)
# 统计成员数
member_count = len(member_ids)
# 计算平均考试成绩使用round1_score
avg_score = 0.0
if member_ids:
stmt = select(func.avg(Exam.round1_score)).where(
and_(
Exam.user_id.in_(member_ids),
Exam.round1_score.isnot(None),
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(stmt)
avg_score_value = result.scalar()
avg_score = float(avg_score_value) if avg_score_value else 0.0
# 计算平均学习进度(基于考试完成情况)
avg_progress = 0.0
if member_ids:
# 统计每个成员完成的考试数
stmt = select(func.count(Exam.id)).where(
and_(
Exam.user_id.in_(member_ids),
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(stmt)
completed_exams = result.scalar() or 0
# 假设每个成员应完成10个考试计算完成率作为进度
total_expected = member_count * 10
if total_expected > 0:
avg_progress = (completed_exams / total_expected) * 100
# 计算课程完成率
course_completion_rate = 0.0
if member_ids:
# 统计已完成的课程数(有考试记录且成绩>=60
stmt = select(func.count(func.distinct(Exam.course_id))).where(
and_(
Exam.user_id.in_(member_ids),
Exam.round1_score >= 60,
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(stmt)
completed_courses = result.scalar() or 0
# 统计总课程数
stmt = select(func.count(Course.id)).where(
and_(
Course.is_deleted == False, # noqa: E712
Course.status == 'published'
)
)
result = await db.execute(stmt)
total_courses = result.scalar() or 0
if total_courses > 0:
course_completion_rate = (completed_courses / total_courses) * 100
# 趋势数据(暂时返回固定值,后续可实现真实趋势计算)
trends = {
"member_trend": 0,
"progress_trend": 12.3 if avg_progress > 0 else 0,
"score_trend": 5.8 if avg_score > 0 else 0,
"completion_trend": -3.2 if course_completion_rate > 0 else 0
}
data = {
"team_count": team_count,
"member_count": member_count,
"avg_progress": round(avg_progress, 1),
"avg_score": round(avg_score, 1),
"course_completion_rate": round(course_completion_rate, 1),
"trends": trends
}
return ResponseModel(code=200, message="success", data=data)
except Exception as e:
logger.error(f"获取团队概览失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取团队概览失败: {str(e)}", data=None)
@router.get("/progress", response_model=ResponseModel)
async def get_progress_data(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取学习进度数据
返回Top 5成员的8周学习进度数据
"""
try:
# 获取可访问的团队
team_ids = await get_accessible_teams(current_user, db)
member_ids = await get_team_member_ids(team_ids, db)
if not member_ids:
return ResponseModel(
code=200,
message="success",
data={"members": [], "weeks": [], "data": []}
)
# 获取Top 5学习时长最高的成员
stmt = (
select(
User.id,
User.full_name,
func.sum(PracticeSession.duration_seconds).label('total_duration')
)
.join(PracticeSession, PracticeSession.user_id == User.id)
.where(
and_(
User.id.in_(member_ids),
PracticeSession.status == 'completed'
)
)
.group_by(User.id, User.full_name)
.order_by(func.sum(PracticeSession.duration_seconds).desc())
.limit(5)
)
result = await db.execute(stmt)
top_members = result.all()
if not top_members:
# 如果没有陪练记录按考试成绩选择Top 5
stmt = (
select(
User.id,
User.full_name,
func.avg(Exam.round1_score).label('avg_score')
)
.join(Exam, Exam.user_id == User.id)
.where(
and_(
User.id.in_(member_ids),
Exam.round1_score.isnot(None),
Exam.status.in_(['completed', 'submitted'])
)
)
.group_by(User.id, User.full_name)
.order_by(func.avg(Exam.round1_score).desc())
.limit(5)
)
result = await db.execute(stmt)
top_members = result.all()
# 生成周标签
weeks = [f"{i+1}" for i in range(8)]
# 为每个成员生成进度数据
members = []
data = []
for member in top_members:
member_name = member.full_name or f"用户{member.id}"
members.append(member_name)
# 查询该成员8周内的考试完成情况
eight_weeks_ago = datetime.now() - timedelta(weeks=8)
stmt = select(Exam).where(
and_(
Exam.user_id == member.id,
Exam.created_at >= eight_weeks_ago,
Exam.status.in_(['completed', 'submitted'])
)
).order_by(Exam.created_at)
result = await db.execute(stmt)
exams = result.scalars().all()
# 计算每周的进度0-100
values = []
for week in range(8):
week_start = datetime.now() - timedelta(weeks=8-week)
week_end = week_start + timedelta(weeks=1)
# 统计该周完成的考试数
week_exams = [
e for e in exams
if week_start <= e.created_at < week_end
]
# 进度 = 累计完成考试数 * 10假设每个考试代表10%进度)
cumulative_exams = len([e for e in exams if e.created_at < week_end])
progress = min(cumulative_exams * 10, 100)
values.append(progress)
data.append({"name": member_name, "values": values})
return ResponseModel(
code=200,
message="success",
data={"members": members, "weeks": weeks, "data": data}
)
except Exception as e:
logger.error(f"获取学习进度数据失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取学习进度数据失败: {str(e)}", data=None)
@router.get("/course-distribution", response_model=ResponseModel)
async def get_course_distribution(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取课程完成分布
返回已完成、进行中、未开始的课程数量
"""
try:
# 获取可访问的团队
team_ids = await get_accessible_teams(current_user, db)
member_ids = await get_team_member_ids(team_ids, db)
# 统计所有已发布的课程
stmt = select(func.count(Course.id)).where(
and_(
Course.is_deleted == False, # noqa: E712
Course.status == 'published'
)
)
result = await db.execute(stmt)
total_courses = result.scalar() or 0
if not member_ids or total_courses == 0:
return ResponseModel(
code=200,
message="success",
data={"completed": 0, "in_progress": 0, "not_started": 0}
)
# 统计已完成的课程(有及格成绩)
stmt = select(func.count(func.distinct(Exam.course_id))).where(
and_(
Exam.user_id.in_(member_ids),
Exam.round1_score >= 60,
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(stmt)
completed = result.scalar() or 0
# 统计进行中的课程(有考试记录但未及格)
stmt = select(func.count(func.distinct(Exam.course_id))).where(
and_(
Exam.user_id.in_(member_ids),
or_(
Exam.round1_score < 60,
Exam.status == 'started'
)
)
)
result = await db.execute(stmt)
in_progress = result.scalar() or 0
# 未开始 = 总数 - 已完成 - 进行中
not_started = max(0, total_courses - completed - in_progress)
data = {
"completed": completed,
"in_progress": in_progress,
"not_started": not_started
}
return ResponseModel(code=200, message="success", data=data)
except Exception as e:
logger.error(f"获取课程分布失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取课程分布失败: {str(e)}", data=None)
@router.get("/ability-analysis", response_model=ResponseModel)
async def get_ability_analysis(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取能力分析数据
返回团队能力雷达图数据和短板列表
"""
try:
# 获取可访问的团队
team_ids = await get_accessible_teams(current_user, db)
member_ids = await get_team_member_ids(team_ids, db)
if not member_ids:
return ResponseModel(
code=200,
message="success",
data={
"radar_data": {
"dimensions": [],
"values": []
},
"weaknesses": []
}
)
# 查询所有陪练报告的能力维度数据
# 需要通过PracticeSession关联因为PracticeReport没有user_id
stmt = (
select(PracticeReport.ability_dimensions)
.join(PracticeSession, PracticeSession.session_id == PracticeReport.session_id)
.where(PracticeSession.user_id.in_(member_ids))
)
result = await db.execute(stmt)
all_dimensions = result.scalars().all()
if not all_dimensions:
# 如果没有陪练报告,返回默认能力维度
default_dimensions = ["沟通表达", "倾听理解", "需求挖掘", "异议处理", "成交技巧", "客户维护"]
return ResponseModel(
code=200,
message="success",
data={
"radar_data": {
"dimensions": default_dimensions,
"values": [0] * len(default_dimensions)
},
"weaknesses": []
}
)
# 聚合能力数据
ability_scores: Dict[str, List[float]] = {}
# 能力维度名称映射
dimension_name_map = {
"sales_ability": "销售能力",
"service_attitude": "服务态度",
"technical_skills": "技术能力",
"沟通表达": "沟通表达",
"倾听理解": "倾听理解",
"需求挖掘": "需求挖掘",
"异议处理": "异议处理",
"成交技巧": "成交技巧",
"客户维护": "客户维护"
}
for dimensions in all_dimensions:
if dimensions:
# 如果是字符串进行JSON反序列化
if isinstance(dimensions, str):
try:
dimensions = json.loads(dimensions)
except json.JSONDecodeError:
logger.warning(f"无法解析能力维度数据: {dimensions}")
continue
# 处理字典格式:{"sales_ability": 79.0, ...}
if isinstance(dimensions, dict):
for key, score in dimensions.items():
name = dimension_name_map.get(key, key)
if name not in ability_scores:
ability_scores[name] = []
ability_scores[name].append(float(score))
# 处理列表格式:[{"name": "沟通表达", "score": 85}, ...]
elif isinstance(dimensions, list):
for dim in dimensions:
if not isinstance(dim, dict):
logger.warning(f"能力维度项格式错误: {type(dim)}")
continue
name = dim.get('name', '')
score = dim.get('score', 0)
if name:
mapped_name = dimension_name_map.get(name, name)
if mapped_name not in ability_scores:
ability_scores[mapped_name] = []
ability_scores[mapped_name].append(float(score))
else:
logger.warning(f"能力维度数据格式错误: {type(dimensions)}")
# 计算平均分
avg_scores = {
name: sum(scores) / len(scores)
for name, scores in ability_scores.items()
}
# 按固定顺序排列维度(支持多种维度组合)
# 优先使用六维度,如果没有则使用三维度
standard_dimensions_six = ["沟通表达", "倾听理解", "需求挖掘", "异议处理", "成交技巧", "客户维护"]
standard_dimensions_three = ["销售能力", "服务态度", "技术能力"]
# 判断使用哪种维度标准
has_six_dimensions = any(dim in avg_scores for dim in standard_dimensions_six)
has_three_dimensions = any(dim in avg_scores for dim in standard_dimensions_three)
if has_six_dimensions:
standard_dimensions = standard_dimensions_six
elif has_three_dimensions:
standard_dimensions = standard_dimensions_three
else:
# 如果都没有,使用实际数据的维度
standard_dimensions = list(avg_scores.keys())
dimensions = []
values = []
for dim in standard_dimensions:
if dim in avg_scores:
dimensions.append(dim)
values.append(round(avg_scores[dim], 1))
# 找出短板(平均分<80
weaknesses = []
weakness_suggestions = {
# 六维度建议
"异议处理": "建议加强异议处理专项训练,增加实战演练",
"成交技巧": "需要系统学习成交话术和时机把握",
"需求挖掘": "提升提问技巧,深入了解客户需求",
"沟通表达": "加强沟通技巧训练,提升表达能力",
"倾听理解": "培养同理心,提高倾听和理解能力",
"客户维护": "学习客户关系管理,提升服务质量",
# 三维度建议
"销售能力": "建议加强销售技巧训练,提升成交率",
"服务态度": "需要改善服务态度,提高客户满意度",
"技术能力": "建议学习产品知识,提升专业能力"
}
for name, score in avg_scores.items():
if score < 80:
weaknesses.append({
"name": name,
"avg_score": int(score),
"suggestion": weakness_suggestions.get(name, f"建议加强{name}专项训练")
})
# 按分数升序排列
weaknesses.sort(key=lambda x: x['avg_score'])
data = {
"radar_data": {
"dimensions": dimensions,
"values": values
},
"weaknesses": weaknesses
}
return ResponseModel(code=200, message="success", data=data)
except Exception as e:
logger.error(f"获取能力分析失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取能力分析失败: {str(e)}", data=None)
@router.get("/rankings", response_model=ResponseModel)
async def get_rankings(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取排行榜数据
返回学习时长排行和成绩排行Top 5
"""
try:
# 获取可访问的团队
team_ids = await get_accessible_teams(current_user, db)
member_ids = await get_team_member_ids(team_ids, db)
if not member_ids:
return ResponseModel(
code=200,
message="success",
data={
"study_time_ranking": [],
"score_ranking": []
}
)
# 学习时长排行(基于陪练会话)
stmt = (
select(
User.id,
User.full_name,
User.avatar_url,
Position.name.label('position_name'),
func.sum(PracticeSession.duration_seconds).label('total_duration')
)
.join(PracticeSession, PracticeSession.user_id == User.id)
.outerjoin(PositionMember, and_(
PositionMember.user_id == User.id,
PositionMember.is_deleted == False # noqa: E712
))
.outerjoin(Position, Position.id == PositionMember.position_id)
.where(
and_(
User.id.in_(member_ids),
PracticeSession.status == 'completed'
)
)
.group_by(User.id, User.full_name, User.avatar_url, Position.name)
.order_by(func.sum(PracticeSession.duration_seconds).desc())
.limit(5)
)
result = await db.execute(stmt)
study_time_data = result.all()
study_time_ranking = []
for row in study_time_data:
study_time_ranking.append({
"id": row.id,
"name": row.full_name or f"用户{row.id}",
"position": row.position_name or "未分配岗位",
"avatar": row.avatar_url or "",
"study_time": round(row.total_duration / 3600, 1) # 转换为小时
})
# 成绩排行基于考试round1_score
stmt = (
select(
User.id,
User.full_name,
User.avatar_url,
Position.name.label('position_name'),
func.avg(Exam.round1_score).label('avg_score')
)
.join(Exam, Exam.user_id == User.id)
.outerjoin(PositionMember, and_(
PositionMember.user_id == User.id,
PositionMember.is_deleted == False # noqa: E712
))
.outerjoin(Position, Position.id == PositionMember.position_id)
.where(
and_(
User.id.in_(member_ids),
Exam.round1_score.isnot(None),
Exam.status.in_(['completed', 'submitted'])
)
)
.group_by(User.id, User.full_name, User.avatar_url, Position.name)
.order_by(func.avg(Exam.round1_score).desc())
.limit(5)
)
result = await db.execute(stmt)
score_data = result.all()
score_ranking = []
for row in score_data:
score_ranking.append({
"id": row.id,
"name": row.full_name or f"用户{row.id}",
"position": row.position_name or "未分配岗位",
"avatar": row.avatar_url or "",
"avg_score": round(row.avg_score, 1)
})
data = {
"study_time_ranking": study_time_ranking,
"score_ranking": score_ranking
}
return ResponseModel(code=200, message="success", data=data)
except Exception as e:
logger.error(f"获取排行榜失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取排行榜失败: {str(e)}", data=None)
@router.get("/activities", response_model=ResponseModel)
async def get_activities(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取团队学习动态
返回最近20条活动记录考试、陪练等
"""
try:
# 获取可访问的团队
team_ids = await get_accessible_teams(current_user, db)
member_ids = await get_team_member_ids(team_ids, db)
if not member_ids:
return ResponseModel(
code=200,
message="success",
data={"activities": []}
)
activities = []
# 获取最近的考试记录
stmt = (
select(Exam, User.full_name, Course.name.label('course_name'))
.join(User, User.id == Exam.user_id)
.join(Course, Course.id == Exam.course_id)
.where(
and_(
Exam.user_id.in_(member_ids),
Exam.status.in_(['completed', 'submitted'])
)
)
.order_by(Exam.updated_at.desc())
.limit(10)
)
result = await db.execute(stmt)
exam_records = result.all()
for exam, user_name, course_name in exam_records:
score = exam.round1_score or 0
activity_type = "success" if score >= 60 else "danger"
result_type = "success" if score >= 60 else "danger"
result_text = f"成绩:{int(score)}" if score >= 60 else "未通过"
activities.append({
"id": f"exam_{exam.id}",
"user_name": user_name or f"用户{exam.user_id}",
"action": "完成了" if score >= 60 else "参加了",
"target": f"{course_name}》课程考试",
"time": exam.updated_at.strftime("%Y-%m-%d %H:%M"),
"type": activity_type,
"result": {"type": result_type, "text": result_text}
})
# 获取最近的陪练记录
stmt = (
select(PracticeSession, User.full_name, PracticeReport.total_score)
.join(User, User.id == PracticeSession.user_id)
.outerjoin(PracticeReport, PracticeReport.session_id == PracticeSession.session_id)
.where(
and_(
PracticeSession.user_id.in_(member_ids),
PracticeSession.status == 'completed'
)
)
.order_by(PracticeSession.end_time.desc())
.limit(10)
)
result = await db.execute(stmt)
practice_records = result.all()
for session, user_name, total_score in practice_records:
activity_type = "primary"
result_data = None
if total_score:
result_data = {"type": "", "text": f"评分:{int(total_score)}"}
activities.append({
"id": f"practice_{session.id}",
"user_name": user_name or f"用户{session.user_id}",
"action": "参加了",
"target": "AI陪练训练",
"time": session.end_time.strftime("%Y-%m-%d %H:%M") if session.end_time else "",
"type": activity_type,
"result": result_data
})
# 按时间倒序排列取前20条
activities.sort(key=lambda x: x['time'], reverse=True)
activities = activities[:20]
return ResponseModel(
code=200,
message="success",
data={"activities": activities}
)
except Exception as e:
logger.error(f"获取团队动态失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取团队动态失败: {str(e)}", data=None)

View File

@@ -0,0 +1,896 @@
"""
团队成员管理 API 路由
提供团队统计、成员列表、成员详情、学习报告等功能
"""
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import and_, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_current_active_user as get_current_user, get_db
from app.core.logger import logger
from app.models.course import Course
from app.models.exam import Exam
from app.models.position import Position
from app.models.position_course import PositionCourse
from app.models.position_member import PositionMember
from app.models.practice import PracticeReport, PracticeSession
from app.models.user import User, UserTeam
from app.schemas.base import PaginatedResponse, ResponseModel
router = APIRouter(prefix="/team/management", tags=["team-management"])
async def get_accessible_team_member_ids(
current_user: User,
db: AsyncSession
) -> List[int]:
"""获取用户可访问的团队成员ID列表"""
if current_user.role in ['admin', 'manager']:
# 管理员查看所有团队成员
stmt = select(UserTeam.user_id).distinct()
result = await db.execute(stmt)
return [row[0] for row in result.all()]
else:
# 普通用户只查看自己团队的成员
# 1. 先查询用户所在的团队
stmt = select(UserTeam.team_id).where(UserTeam.user_id == current_user.id)
result = await db.execute(stmt)
team_ids = [row[0] for row in result.all()]
if not team_ids:
return []
# 2. 查询这些团队的所有成员
stmt = select(UserTeam.user_id).where(
UserTeam.team_id.in_(team_ids)
).distinct()
result = await db.execute(stmt)
return [row[0] for row in result.all()]
def calculate_member_status(
last_login: Optional[datetime],
last_exam: Optional[datetime],
last_practice: Optional[datetime],
has_ongoing: bool
) -> str:
"""
计算成员活跃状态
Args:
last_login: 最后登录时间
last_exam: 最后考试时间
last_practice: 最后陪练时间
has_ongoing: 是否有进行中的活动
Returns:
状态: active(活跃), learning(学习中), rest(休息)
"""
# 获取最近活跃时间
times = [t for t in [last_login, last_exam, last_practice] if t is not None]
if not times:
return 'rest'
last_active = max(times)
thirty_days_ago = datetime.now() - timedelta(days=30)
# 判断状态
if last_active >= thirty_days_ago:
if has_ongoing:
return 'learning'
else:
return 'active'
else:
return 'rest'
@router.get("/statistics", response_model=ResponseModel)
async def get_team_statistics(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取团队统计数据
返回:团队总人数、活跃成员数、平均学习进度、团队平均分
"""
try:
# 获取可访问的团队成员ID
member_ids = await get_accessible_team_member_ids(current_user, db)
# 团队总人数
team_count = len(member_ids)
if team_count == 0:
return ResponseModel(
code=200,
message="success",
data={
"teamCount": 0,
"activeMembers": 0,
"avgProgress": 0,
"avgScore": 0
}
)
# 统计活跃成员数最近30天有活动
thirty_days_ago = datetime.now() - timedelta(days=30)
# 统计最近30天有登录或有考试或有陪练的用户
active_users_stmt = select(func.count(func.distinct(User.id))).where(
and_(
User.id.in_(member_ids),
or_(
User.last_login_at >= thirty_days_ago,
User.id.in_(
select(Exam.user_id).where(
and_(
Exam.user_id.in_(member_ids),
Exam.created_at >= thirty_days_ago
)
)
),
User.id.in_(
select(PracticeSession.user_id).where(
and_(
PracticeSession.user_id.in_(member_ids),
PracticeSession.start_time >= thirty_days_ago
)
)
)
)
)
)
result = await db.execute(active_users_stmt)
active_members = result.scalar() or 0
# 计算平均学习进度(每个成员的完成课程/应完成课程的平均值)
# 统计每个成员的进度,然后计算平均值
total_progress = 0.0
members_with_courses = 0
for member_id in member_ids:
# 获取该成员岗位分配的课程数
member_courses_stmt = select(
func.count(func.distinct(PositionCourse.course_id))
).select_from(PositionMember).join(
PositionCourse,
PositionCourse.position_id == PositionMember.position_id
).where(
and_(
PositionMember.user_id == member_id,
PositionMember.is_deleted == False # noqa: E712
)
)
result = await db.execute(member_courses_stmt)
member_total_courses = result.scalar() or 0
if member_total_courses > 0:
# 获取该成员已完成(及格)的课程数
member_completed_stmt = select(
func.count(func.distinct(Exam.course_id))
).where(
and_(
Exam.user_id == member_id,
Exam.round1_score >= 60,
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(member_completed_stmt)
member_completed = result.scalar() or 0
# 计算该成员的进度最大100%
member_progress = min((member_completed / member_total_courses) * 100, 100)
total_progress += member_progress
members_with_courses += 1
avg_progress = round(total_progress / members_with_courses, 1) if members_with_courses > 0 else 0.0
# 计算团队平均分使用round1_score
avg_score_stmt = select(func.avg(Exam.round1_score)).where(
and_(
Exam.user_id.in_(member_ids),
Exam.round1_score.isnot(None),
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(avg_score_stmt)
avg_score_value = result.scalar()
avg_score = round(float(avg_score_value), 1) if avg_score_value else 0.0
data = {
"teamCount": team_count,
"activeMembers": active_members,
"avgProgress": avg_progress,
"avgScore": avg_score
}
return ResponseModel(code=200, message="success", data=data)
except Exception as e:
logger.error(f"获取团队统计失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取团队统计失败: {str(e)}", data=None)
@router.get("/members", response_model=ResponseModel[PaginatedResponse])
async def get_team_members(
page: int = Query(1, ge=1, description="页码"),
size: int = Query(20, ge=1, le=100, description="每页数量"),
search_text: Optional[str] = Query(None, description="搜索姓名、岗位"),
status: Optional[str] = Query(None, description="筛选状态: active/learning/rest"),
position: Optional[str] = Query(None, description="筛选岗位"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取团队成员列表(带筛选、搜索、分页)
返回成员基本信息、学习进度、成绩、学习时长等
"""
try:
# 获取可访问的团队成员ID
member_ids = await get_accessible_team_member_ids(current_user, db)
if not member_ids:
return ResponseModel(
code=200,
message="success",
data=PaginatedResponse(
items=[],
total=0,
page=page,
page_size=size,
pages=0
)
)
# 构建基础查询
stmt = select(User).where(
and_(
User.id.in_(member_ids),
User.is_deleted == False # noqa: E712
)
)
# 搜索条件(姓名)
if search_text:
like_pattern = f"%{search_text}%"
stmt = stmt.where(
or_(
User.full_name.ilike(like_pattern),
User.username.ilike(like_pattern)
)
)
# 先获取所有符合条件的用户然后在Python中过滤状态和岗位
result = await db.execute(stmt)
all_users = result.scalars().all()
# 为每个用户计算详细信息
member_list = []
thirty_days_ago = datetime.now() - timedelta(days=30)
for user in all_users:
# 获取用户岗位
position_stmt = select(Position.name).select_from(PositionMember).join(
Position,
Position.id == PositionMember.position_id
).where(
and_(
PositionMember.user_id == user.id,
PositionMember.is_deleted == False # noqa: E712
)
).limit(1)
result = await db.execute(position_stmt)
position_name = result.scalar()
# 如果有岗位筛选且不匹配,跳过
if position and position_name != position:
continue
# 获取最近考试时间
last_exam_stmt = select(func.max(Exam.created_at)).where(
Exam.user_id == user.id
)
result = await db.execute(last_exam_stmt)
last_exam = result.scalar()
# 获取最近陪练时间
last_practice_stmt = select(func.max(PracticeSession.start_time)).where(
PracticeSession.user_id == user.id
)
result = await db.execute(last_practice_stmt)
last_practice = result.scalar()
# 检查是否有进行中的活动
has_ongoing_stmt = select(func.count(Exam.id)).where(
and_(
Exam.user_id == user.id,
Exam.status == 'started'
)
)
result = await db.execute(has_ongoing_stmt)
has_ongoing = (result.scalar() or 0) > 0
# 计算状态
member_status = calculate_member_status(
user.last_login_at,
last_exam,
last_practice,
has_ongoing
)
# 如果有状态筛选且不匹配,跳过
if status and member_status != status:
continue
# 统计学习进度
# 1. 获取岗位分配的课程总数
total_courses_stmt = select(
func.count(func.distinct(PositionCourse.course_id))
).select_from(PositionMember).join(
PositionCourse,
PositionCourse.position_id == PositionMember.position_id
).where(
and_(
PositionMember.user_id == user.id,
PositionMember.is_deleted == False # noqa: E712
)
)
result = await db.execute(total_courses_stmt)
total_courses = result.scalar() or 0
# 2. 统计已完成的考试(及格)
completed_courses_stmt = select(
func.count(func.distinct(Exam.course_id))
).where(
and_(
Exam.user_id == user.id,
Exam.round1_score >= 60,
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(completed_courses_stmt)
completed_courses = result.scalar() or 0
# 3. 计算进度
progress = 0
if total_courses > 0:
progress = int((completed_courses / total_courses) * 100)
# 统计平均成绩
avg_score_stmt = select(func.avg(Exam.round1_score)).where(
and_(
Exam.user_id == user.id,
Exam.round1_score.isnot(None),
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(avg_score_stmt)
avg_score_value = result.scalar()
avg_score = round(float(avg_score_value), 1) if avg_score_value else 0.0
# 统计学习时长(考试时长+陪练时长)
exam_time_stmt = select(
func.coalesce(func.sum(Exam.duration_minutes), 0)
).where(Exam.user_id == user.id)
result = await db.execute(exam_time_stmt)
exam_minutes = float(result.scalar() or 0)
practice_time_stmt = select(
func.coalesce(func.sum(PracticeSession.duration_seconds), 0)
).where(
and_(
PracticeSession.user_id == user.id,
PracticeSession.status == 'completed'
)
)
result = await db.execute(practice_time_stmt)
practice_seconds = float(result.scalar() or 0)
total_hours = round(exam_minutes / 60 + practice_seconds / 3600, 1)
# 获取最近活跃时间
active_times = [t for t in [user.last_login_at, last_exam, last_practice] if t is not None]
last_active = max(active_times).strftime("%Y-%m-%d %H:%M") if active_times else "-"
member_list.append({
"id": user.id,
"name": user.full_name or user.username,
"avatar": user.avatar_url or "",
"position": position_name or "未分配岗位",
"status": member_status,
"progress": progress,
"completedCourses": completed_courses,
"totalCourses": total_courses,
"avgScore": avg_score,
"studyTime": total_hours,
"lastActive": last_active,
"joinTime": user.created_at.strftime("%Y-%m-%d") if user.created_at else "-",
"email": user.email or "",
"phone": user.phone or "",
"passRate": 100 if completed_courses > 0 else 0 # 简化计算
})
# 分页
total = len(member_list)
pages = (total + size - 1) // size if size > 0 else 0
start = (page - 1) * size
end = start + size
items = member_list[start:end]
return ResponseModel(
code=200,
message="success",
data=PaginatedResponse(
items=items,
total=total,
page=page,
page_size=size,
pages=pages
)
)
except Exception as e:
logger.error(f"获取团队成员列表失败: {e}", exc_info=True)
return ResponseModel(
code=500,
message=f"获取团队成员列表失败: {str(e)}",
data=None
)
@router.get("/members/{member_id}/detail", response_model=ResponseModel)
async def get_member_detail(
member_id: int,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取成员详情
返回完整的成员信息和最近学习记录
"""
try:
# 权限检查确保member_id在可访问范围内
accessible_ids = await get_accessible_team_member_ids(current_user, db)
if member_id not in accessible_ids:
return ResponseModel(
code=403,
message="无权访问该成员信息",
data=None
)
# 获取用户基本信息
stmt = select(User).where(
and_(
User.id == member_id,
User.is_deleted == False # noqa: E712
)
)
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user:
return ResponseModel(code=404, message="成员不存在", data=None)
# 获取岗位
position_stmt = select(Position.name).select_from(PositionMember).join(
Position,
Position.id == PositionMember.position_id
).where(
and_(
PositionMember.user_id == user.id,
PositionMember.is_deleted == False # noqa: E712
)
).limit(1)
result = await db.execute(position_stmt)
position_name = result.scalar() or "未分配岗位"
# 计算状态
last_exam_stmt = select(func.max(Exam.created_at)).where(Exam.user_id == user.id)
result = await db.execute(last_exam_stmt)
last_exam = result.scalar()
last_practice_stmt = select(func.max(PracticeSession.start_time)).where(
PracticeSession.user_id == user.id
)
result = await db.execute(last_practice_stmt)
last_practice = result.scalar()
has_ongoing_stmt = select(func.count(Exam.id)).where(
and_(
Exam.user_id == user.id,
Exam.status == 'started'
)
)
result = await db.execute(has_ongoing_stmt)
has_ongoing = (result.scalar() or 0) > 0
member_status = calculate_member_status(
user.last_login_at,
last_exam,
last_practice,
has_ongoing
)
# 统计学习数据
# 学习时长
exam_time_stmt = select(func.coalesce(func.sum(Exam.duration_minutes), 0)).where(
Exam.user_id == user.id
)
result = await db.execute(exam_time_stmt)
exam_minutes = result.scalar() or 0
practice_time_stmt = select(
func.coalesce(func.sum(PracticeSession.duration_seconds), 0)
).where(
and_(
PracticeSession.user_id == user.id,
PracticeSession.status == 'completed'
)
)
result = await db.execute(practice_time_stmt)
practice_seconds = result.scalar() or 0
study_time = round(exam_minutes / 60 + practice_seconds / 3600, 1)
# 完成课程数
completed_courses_stmt = select(
func.count(func.distinct(Exam.course_id))
).where(
and_(
Exam.user_id == user.id,
Exam.round1_score >= 60,
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(completed_courses_stmt)
completed_courses = result.scalar() or 0
# 平均成绩
avg_score_stmt = select(func.avg(Exam.round1_score)).where(
and_(
Exam.user_id == user.id,
Exam.round1_score.isnot(None),
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(avg_score_stmt)
avg_score_value = result.scalar()
avg_score = round(float(avg_score_value), 1) if avg_score_value else 0.0
# 通过率
total_exams_stmt = select(func.count(Exam.id)).where(
and_(
Exam.user_id == user.id,
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(total_exams_stmt)
total_exams = result.scalar() or 0
passed_exams_stmt = select(func.count(Exam.id)).where(
and_(
Exam.user_id == user.id,
Exam.round1_score >= 60,
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(passed_exams_stmt)
passed_exams = result.scalar() or 0
pass_rate = round((passed_exams / total_exams) * 100) if total_exams > 0 else 0
# 获取最近学习记录最近10条考试和陪练
recent_records = []
# 考试记录
exam_records_stmt = (
select(Exam, Course.name.label('course_name'))
.join(Course, Course.id == Exam.course_id)
.where(
and_(
Exam.user_id == user.id,
Exam.status.in_(['completed', 'submitted'])
)
)
.order_by(Exam.updated_at.desc())
.limit(10)
)
result = await db.execute(exam_records_stmt)
exam_records = result.all()
for exam, course_name in exam_records:
score = exam.round1_score or 0
record_type = "success" if score >= 60 else "danger"
recent_records.append({
"id": f"exam_{exam.id}",
"time": exam.updated_at.strftime("%Y-%m-%d %H:%M"),
"content": f"完成《{course_name}》课程考试,成绩:{int(score)}",
"type": record_type
})
# 陪练记录
practice_records_stmt = (
select(PracticeSession)
.where(
and_(
PracticeSession.user_id == user.id,
PracticeSession.status == 'completed'
)
)
.order_by(PracticeSession.end_time.desc())
.limit(5)
)
result = await db.execute(practice_records_stmt)
practice_records = result.scalars().all()
for session in practice_records:
recent_records.append({
"id": f"practice_{session.id}",
"time": session.end_time.strftime("%Y-%m-%d %H:%M") if session.end_time else "",
"content": "参加AI陪练训练",
"type": "primary"
})
# 按时间排序
recent_records.sort(key=lambda x: x['time'], reverse=True)
recent_records = recent_records[:10]
data = {
"id": user.id,
"name": user.full_name or user.username,
"avatar": user.avatar_url or "",
"position": position_name,
"status": member_status,
"joinTime": user.created_at.strftime("%Y-%m-%d") if user.created_at else "-",
"email": user.email or "",
"phone": user.phone or "",
"studyTime": study_time,
"completedCourses": completed_courses,
"avgScore": avg_score,
"passRate": pass_rate,
"recentRecords": recent_records
}
return ResponseModel(code=200, message="success", data=data)
except Exception as e:
logger.error(f"获取成员详情失败: {e}", exc_info=True)
return ResponseModel(
code=500,
message=f"获取成员详情失败: {str(e)}",
data=None
)
@router.get("/members/{member_id}/report", response_model=ResponseModel)
async def get_member_report(
member_id: int,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取成员学习报告
返回学习概览、30天进度趋势、能力评估、详细学习记录
"""
try:
# 权限检查
accessible_ids = await get_accessible_team_member_ids(current_user, db)
if member_id not in accessible_ids:
return ResponseModel(code=403, message="无权访问该成员信息", data=None)
# 获取用户信息
stmt = select(User).where(
and_(
User.id == member_id,
User.is_deleted == False # noqa: E712
)
)
result = await db.execute(stmt)
user = result.scalar_one_or_none()
if not user:
return ResponseModel(code=404, message="成员不存在", data=None)
# 1. 报告概览
# 学习总时长
exam_time_stmt = select(func.coalesce(func.sum(Exam.duration_minutes), 0)).where(
Exam.user_id == user.id
)
result = await db.execute(exam_time_stmt)
exam_minutes = result.scalar() or 0
practice_time_stmt = select(
func.coalesce(func.sum(PracticeSession.duration_seconds), 0)
).where(
and_(
PracticeSession.user_id == user.id,
PracticeSession.status == 'completed'
)
)
result = await db.execute(practice_time_stmt)
practice_seconds = result.scalar() or 0
total_hours = round(exam_minutes / 60 + practice_seconds / 3600, 1)
# 完成课程数
completed_courses_stmt = select(
func.count(func.distinct(Exam.course_id))
).where(
and_(
Exam.user_id == user.id,
Exam.round1_score >= 60,
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(completed_courses_stmt)
completed_courses = result.scalar() or 0
# 平均成绩
avg_score_stmt = select(func.avg(Exam.round1_score)).where(
and_(
Exam.user_id == user.id,
Exam.round1_score.isnot(None),
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(avg_score_stmt)
avg_score_value = result.scalar()
avg_score = round(float(avg_score_value), 1) if avg_score_value else 0.0
# 学习排名(简化:在团队中的排名)
# TODO: 实现真实排名计算
ranking = "第5名"
overview = [
{
"label": "学习总时长",
"value": f"{total_hours}小时",
"icon": "Clock",
"color": "#667eea",
"bgColor": "rgba(102, 126, 234, 0.1)"
},
{
"label": "完成课程",
"value": f"{completed_courses}",
"icon": "CircleCheck",
"color": "#67c23a",
"bgColor": "rgba(103, 194, 58, 0.1)"
},
{
"label": "平均成绩",
"value": f"{avg_score}",
"icon": "Trophy",
"color": "#e6a23c",
"bgColor": "rgba(230, 162, 60, 0.1)"
},
{
"label": "学习排名",
"value": ranking,
"icon": "Medal",
"color": "#f56c6c",
"bgColor": "rgba(245, 108, 108, 0.1)"
}
]
# 2. 30天学习进度趋势
thirty_days_ago = datetime.now() - timedelta(days=30)
dates = []
progress_data = []
for i in range(30):
date = thirty_days_ago + timedelta(days=i)
dates.append(date.strftime("%m-%d"))
# 统计该日期之前完成的考试数
cumulative_exams_stmt = select(func.count(Exam.id)).where(
and_(
Exam.user_id == user.id,
Exam.created_at <= date,
Exam.status.in_(['completed', 'submitted'])
)
)
result = await db.execute(cumulative_exams_stmt)
cumulative = result.scalar() or 0
# 进度 = 累计考试数 * 10简化计算
progress = min(cumulative * 10, 100)
progress_data.append(progress)
# 3. 能力评估(从陪练报告聚合)
ability_stmt = select(PracticeReport.ability_dimensions).where(
PracticeReport.user_id == user.id
)
result = await db.execute(ability_stmt)
all_dimensions = result.scalars().all()
abilities = []
if all_dimensions:
# 聚合能力数据
ability_scores: Dict[str, List[float]] = {}
for dimensions in all_dimensions:
if dimensions:
for dim in dimensions:
name = dim.get('name', '')
score = dim.get('score', 0)
if name:
if name not in ability_scores:
ability_scores[name] = []
ability_scores[name].append(float(score))
# 计算平均分
for name, scores in ability_scores.items():
avg = sum(scores) / len(scores)
description = "表现良好" if avg >= 80 else "需要加强"
abilities.append({
"name": name,
"score": int(avg),
"description": description
})
else:
# 默认能力评估
default_abilities = [
{"name": "沟通表达", "score": 0, "description": "暂无数据"},
{"name": "需求挖掘", "score": 0, "description": "暂无数据"},
{"name": "产品知识", "score": 0, "description": "暂无数据"},
{"name": "成交技巧", "score": 0, "description": "暂无数据"}
]
abilities = default_abilities
# 4. 详细学习记录最近20条
records = []
# 考试记录
exam_records_stmt = (
select(Exam, Course.name.label('course_name'))
.join(Course, Course.id == Exam.course_id)
.where(
and_(
Exam.user_id == user.id,
Exam.status.in_(['completed', 'submitted'])
)
)
.order_by(Exam.updated_at.desc())
.limit(20)
)
result = await db.execute(exam_records_stmt)
exam_records = result.all()
for exam, course_name in exam_records:
score = exam.round1_score or 0
records.append({
"date": exam.updated_at.strftime("%Y-%m-%d"),
"course": course_name,
"duration": exam.duration_minutes or 0,
"score": int(score),
"status": "completed"
})
data = {
"overview": overview,
"progressTrend": {
"dates": dates,
"data": progress_data
},
"abilities": abilities,
"records": records[:20]
}
return ResponseModel(code=200, message="success", data=data)
except Exception as e:
logger.error(f"获取成员学习报告失败: {e}", exc_info=True)
return ResponseModel(
code=500,
message=f"获取成员学习报告失败: {str(e)}",
data=None
)

View File

@@ -0,0 +1,55 @@
"""
团队相关 API 路由
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_current_active_user as get_current_user, get_db
from app.core.logger import logger
from app.models.user import Team
from app.schemas.base import ResponseModel
router = APIRouter(prefix="/teams", tags=["teams"])
@router.get("/", response_model=ResponseModel)
async def list_teams(
keyword: Optional[str] = Query(None, description="按名称或编码模糊搜索"),
current_user=Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取团队列表
任何登录用户均可查询团队列表,用于前端下拉选择。
"""
try:
stmt = select(Team).where(Team.is_deleted == False) # noqa: E712
if keyword:
like = f"%{keyword}%"
stmt = stmt.where(or_(Team.name.ilike(like), Team.code.ilike(like)))
rows: List[Team] = (await db.execute(stmt)).scalars().all()
data = [
{
"id": t.id,
"name": t.name,
"code": t.code,
"team_type": t.team_type,
}
for t in rows
]
return ResponseModel(code=200, message="OK", data=data)
except Exception:
logger.error("查询团队列表失败", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="查询团队列表失败",
)

View File

@@ -0,0 +1,507 @@
"""陪练模块API路由"""
import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_db, get_current_user, require_admin
from app.schemas.base import ResponseModel
from app.schemas.training import (
TrainingSceneCreate,
TrainingSceneUpdate,
TrainingSceneResponse,
TrainingSessionResponse,
TrainingMessageResponse,
TrainingReportResponse,
StartTrainingRequest,
StartTrainingResponse,
EndTrainingRequest,
EndTrainingResponse,
TrainingSceneListQuery,
TrainingSessionListQuery,
PaginatedResponse,
)
from app.services.training_service import (
TrainingSceneService,
TrainingSessionService,
TrainingMessageService,
TrainingReportService,
)
from app.models.training import TrainingSceneStatus, TrainingSessionStatus
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/training", tags=["陪练模块"])
# 服务实例
scene_service = TrainingSceneService()
session_service = TrainingSessionService()
message_service = TrainingMessageService()
report_service = TrainingReportService()
# ========== 陪练场景管理 ==========
@router.get(
"/scenes", response_model=ResponseModel[PaginatedResponse[TrainingSceneResponse]]
)
async def get_training_scenes(
category: Optional[str] = Query(None, description="场景分类"),
status: Optional[TrainingSceneStatus] = Query(None, description="场景状态"),
is_public: Optional[bool] = Query(None, description="是否公开"),
search: Optional[str] = Query(None, description="搜索关键词"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
current_user: dict = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
获取陪练场景列表
- 支持按分类、状态、是否公开筛选
- 支持关键词搜索
- 支持分页
"""
try:
# 计算分页参数
skip = (page - 1) * page_size
# 获取用户等级TODO: 从User服务获取
user_level = 1
# 获取场景列表
scenes = await scene_service.get_active_scenes(
db,
category=category,
is_public=is_public,
user_level=user_level,
skip=skip,
limit=page_size,
)
# 获取总数
from sqlalchemy import select, func, and_
from app.models.training import TrainingScene
count_query = (
select(func.count())
.select_from(TrainingScene)
.where(
and_(
TrainingScene.status == TrainingSceneStatus.ACTIVE,
TrainingScene.is_deleted == False,
)
)
)
if category:
count_query = count_query.where(TrainingScene.category == category)
if is_public is not None:
count_query = count_query.where(TrainingScene.is_public == is_public)
result = await db.execute(count_query)
total = result.scalar_one()
# 计算总页数
pages = (total + page_size - 1) // page_size
return ResponseModel(
data=PaginatedResponse(
items=scenes, total=total, page=page, page_size=page_size, pages=pages
),
message="获取陪练场景列表成功",
)
except Exception as e:
logger.error(f"获取陪练场景列表失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取陪练场景列表失败"
)
@router.get("/scenes/{scene_id}", response_model=ResponseModel[TrainingSceneResponse])
async def get_training_scene(
scene_id: int,
current_user: dict = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""获取陪练场景详情"""
scene = await scene_service.get(db, scene_id)
if not scene or scene.is_deleted:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练场景不存在")
# 检查访问权限
if not scene.is_public and current_user.get("role") != "admin":
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此场景")
return ResponseModel(data=scene, message="获取陪练场景成功")
@router.post("/scenes", response_model=ResponseModel[TrainingSceneResponse])
async def create_training_scene(
scene_in: TrainingSceneCreate,
current_user: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""
创建陪练场景(管理员)
- 需要管理员权限
- 场景默认为草稿状态
"""
try:
scene = await scene_service.create_scene(
db, scene_in=scene_in, created_by=current_user["id"]
)
logger.info(f"管理员 {current_user['id']} 创建了陪练场景: {scene.id}")
return ResponseModel(data=scene, message="创建陪练场景成功")
except Exception as e:
logger.error(f"创建陪练场景失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建陪练场景失败"
)
@router.put("/scenes/{scene_id}", response_model=ResponseModel[TrainingSceneResponse])
async def update_training_scene(
scene_id: int,
scene_in: TrainingSceneUpdate,
current_user: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""更新陪练场景(管理员)"""
scene = await scene_service.update_scene(
db, scene_id=scene_id, scene_in=scene_in, updated_by=current_user["id"]
)
if not scene:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练场景不存在")
logger.info(f"管理员 {current_user['id']} 更新了陪练场景: {scene_id}")
return ResponseModel(data=scene, message="更新陪练场景成功")
@router.delete("/scenes/{scene_id}", response_model=ResponseModel[bool])
async def delete_training_scene(
scene_id: int,
current_user: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
):
"""删除陪练场景(管理员)"""
success = await scene_service.soft_delete(db, id=scene_id)
if not success:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练场景不存在")
logger.info(f"管理员 {current_user['id']} 删除了陪练场景: {scene_id}")
return ResponseModel(data=True, message="删除陪练场景成功")
# ========== 陪练会话管理 ==========
@router.post("/sessions", response_model=ResponseModel[StartTrainingResponse])
async def start_training(
request: StartTrainingRequest,
current_user: dict = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
开始陪练会话
- 需要登录
- 创建会话记录
- 初始化Coze对话如果配置了Bot
- 返回会话信息和WebSocket连接地址如果支持
"""
try:
response = await session_service.start_training(
db, request=request, user_id=current_user["id"]
)
logger.info(f"用户 {current_user['id']} 开始陪练会话: {response.session_id}")
return ResponseModel(data=response, message="开始陪练成功")
except HTTPException:
raise
except Exception as e:
logger.error(f"开始陪练失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="开始陪练失败"
)
@router.post(
"/sessions/{session_id}/end", response_model=ResponseModel[EndTrainingResponse]
)
async def end_training(
session_id: int,
request: EndTrainingRequest,
current_user: dict = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
结束陪练会话
- 需要登录且是会话创建者
- 更新会话状态
- 可选生成陪练报告
"""
try:
response = await session_service.end_training(
db, session_id=session_id, request=request, user_id=current_user["id"]
)
logger.info(f"用户 {current_user['id']} 结束陪练会话: {session_id}")
return ResponseModel(data=response, message="结束陪练成功")
except HTTPException:
raise
except Exception as e:
logger.error(f"结束陪练失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="结束陪练失败"
)
@router.get(
"/sessions",
response_model=ResponseModel[PaginatedResponse[TrainingSessionResponse]],
)
async def get_training_sessions(
scene_id: Optional[int] = Query(None, description="场景ID"),
status: Optional[TrainingSessionStatus] = Query(None, description="会话状态"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
current_user: dict = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""获取用户的陪练会话列表"""
try:
skip = (page - 1) * page_size
sessions = await session_service.get_user_sessions(
db,
user_id=current_user["id"],
scene_id=scene_id,
status=status,
skip=skip,
limit=page_size,
)
# 获取总数
from sqlalchemy import select, func
from app.models.training import TrainingSession
count_query = (
select(func.count())
.select_from(TrainingSession)
.where(TrainingSession.user_id == current_user["id"])
)
if scene_id:
count_query = count_query.where(TrainingSession.scene_id == scene_id)
if status:
count_query = count_query.where(TrainingSession.status == status)
result = await db.execute(count_query)
total = result.scalar_one()
pages = (total + page_size - 1) // page_size
# 加载关联的场景信息
for session in sessions:
await db.refresh(session, ["scene"])
return ResponseModel(
data=PaginatedResponse(
items=sessions, total=total, page=page, page_size=page_size, pages=pages
),
message="获取陪练会话列表成功",
)
except Exception as e:
logger.error(f"获取陪练会话列表失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取陪练会话列表失败"
)
@router.get(
"/sessions/{session_id}", response_model=ResponseModel[TrainingSessionResponse]
)
async def get_training_session(
session_id: int,
current_user: dict = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""获取陪练会话详情"""
session = await session_service.get(db, session_id)
if not session:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练会话不存在")
# 检查访问权限
if session.user_id != current_user["id"] and current_user.get("role") != "admin":
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此会话")
# 加载关联数据
await db.refresh(session, ["scene"])
# 获取消息数量
messages = await message_service.get_session_messages(db, session_id=session_id)
session.message_count = len(messages)
return ResponseModel(data=session, message="获取陪练会话成功")
# ========== 消息管理 ==========
@router.get(
"/sessions/{session_id}/messages",
response_model=ResponseModel[List[TrainingMessageResponse]],
)
async def get_training_messages(
session_id: int,
skip: int = Query(0, ge=0, description="跳过数量"),
limit: int = Query(100, ge=1, le=500, description="返回数量"),
current_user: dict = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""获取陪练会话的消息列表"""
# 验证会话访问权限
session = await session_service.get(db, session_id)
if not session:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练会话不存在")
if session.user_id != current_user["id"] and current_user.get("role") != "admin":
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此会话消息")
messages = await message_service.get_session_messages(
db, session_id=session_id, skip=skip, limit=limit
)
return ResponseModel(data=messages, message="获取消息列表成功")
# ========== 报告管理 ==========
@router.get(
"/reports", response_model=ResponseModel[PaginatedResponse[TrainingReportResponse]]
)
async def get_training_reports(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
current_user: dict = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""获取用户的陪练报告列表"""
try:
skip = (page - 1) * page_size
reports = await report_service.get_user_reports(
db, user_id=current_user["id"], skip=skip, limit=page_size
)
# 获取总数
from sqlalchemy import select, func
from app.models.training import TrainingReport
count_query = (
select(func.count())
.select_from(TrainingReport)
.where(TrainingReport.user_id == current_user["id"])
)
result = await db.execute(count_query)
total = result.scalar_one()
pages = (total + page_size - 1) // page_size
# 加载关联的会话信息
for report in reports:
await db.refresh(report, ["session"])
if report.session:
await db.refresh(report.session, ["scene"])
return ResponseModel(
data=PaginatedResponse(
items=reports, total=total, page=page, page_size=page_size, pages=pages
),
message="获取陪练报告列表成功",
)
except Exception as e:
logger.error(f"获取陪练报告列表失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="获取陪练报告列表失败"
)
@router.get(
"/reports/{report_id}", response_model=ResponseModel[TrainingReportResponse]
)
async def get_training_report(
report_id: int,
current_user: dict = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""获取陪练报告详情"""
report = await report_service.get(db, report_id)
if not report:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练报告不存在")
# 检查访问权限
if report.user_id != current_user["id"] and current_user.get("role") != "admin":
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此报告")
# 加载关联数据
await db.refresh(report, ["session"])
if report.session:
await db.refresh(report.session, ["scene"])
return ResponseModel(data=report, message="获取陪练报告成功")
@router.get(
"/sessions/{session_id}/report",
response_model=ResponseModel[TrainingReportResponse],
)
async def get_session_report(
session_id: int,
current_user: dict = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""根据会话ID获取陪练报告"""
# 验证会话访问权限
session = await session_service.get(db, session_id)
if not session:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="陪练会话不存在")
if session.user_id != current_user["id"] and current_user.get("role") != "admin":
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权访问此会话报告")
# 获取报告
report = await report_service.get_by_session(db, session_id=session_id)
if not report:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="该会话暂无报告")
# 加载关联数据
await db.refresh(report, ["session"])
if report.session:
await db.refresh(report.session, ["scene"])
return ResponseModel(data=report, message="获取会话报告成功")

View File

@@ -0,0 +1,854 @@
openapi: 3.0.0
info:
title: Training Module API
description: 考培练系统陪练模块API契约
version: 1.0.0
servers:
- url: http://localhost:8000/api/v1
description: 本地开发服务器
paths:
/training/scenes:
get:
summary: 获取陪练场景列表
tags:
- 陪练场景
security:
- bearerAuth: []
parameters:
- name: category
in: query
description: 场景分类
schema:
type: string
- name: status
in: query
description: 场景状态
schema:
type: string
enum: [draft, active, inactive]
- name: is_public
in: query
description: 是否公开
schema:
type: boolean
- name: search
in: query
description: 搜索关键词
schema:
type: string
- name: page
in: query
description: 页码
schema:
type: integer
minimum: 1
default: 1
- name: page_size
in: query
description: 每页数量
schema:
type: integer
minimum: 1
maximum: 100
default: 20
responses:
'200':
description: 成功
content:
application/json:
schema:
$ref: '#/components/schemas/PaginatedScenesResponse'
'401':
$ref: '#/components/responses/Unauthorized'
post:
summary: 创建陪练场景(管理员)
tags:
- 陪练场景
security:
- bearerAuth: []
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/TrainingSceneCreate'
responses:
'200':
description: 成功
content:
application/json:
schema:
$ref: '#/components/schemas/TrainingSceneResponse'
'401':
$ref: '#/components/responses/Unauthorized'
'403':
$ref: '#/components/responses/Forbidden'
/training/scenes/{scene_id}:
get:
summary: 获取陪练场景详情
tags:
- 陪练场景
security:
- bearerAuth: []
parameters:
- name: scene_id
in: path
required: true
description: 场景ID
schema:
type: integer
responses:
'200':
description: 成功
content:
application/json:
schema:
$ref: '#/components/schemas/TrainingSceneResponse'
'401':
$ref: '#/components/responses/Unauthorized'
'404':
$ref: '#/components/responses/NotFound'
put:
summary: 更新陪练场景(管理员)
tags:
- 陪练场景
security:
- bearerAuth: []
parameters:
- name: scene_id
in: path
required: true
description: 场景ID
schema:
type: integer
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/TrainingSceneUpdate'
responses:
'200':
description: 成功
content:
application/json:
schema:
$ref: '#/components/schemas/TrainingSceneResponse'
'401':
$ref: '#/components/responses/Unauthorized'
'403':
$ref: '#/components/responses/Forbidden'
'404':
$ref: '#/components/responses/NotFound'
delete:
summary: 删除陪练场景(管理员)
tags:
- 陪练场景
security:
- bearerAuth: []
parameters:
- name: scene_id
in: path
required: true
description: 场景ID
schema:
type: integer
responses:
'200':
description: 成功
content:
application/json:
schema:
type: object
properties:
code:
type: integer
example: 200
message:
type: string
example: "删除陪练场景成功"
data:
type: boolean
example: true
'401':
$ref: '#/components/responses/Unauthorized'
'403':
$ref: '#/components/responses/Forbidden'
'404':
$ref: '#/components/responses/NotFound'
/training/sessions:
post:
summary: 开始陪练会话
tags:
- 陪练会话
security:
- bearerAuth: []
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/StartTrainingRequest'
responses:
'200':
description: 成功
content:
application/json:
schema:
$ref: '#/components/schemas/StartTrainingResponse'
'401':
$ref: '#/components/responses/Unauthorized'
'404':
description: 场景不存在
get:
summary: 获取用户的陪练会话列表
tags:
- 陪练会话
security:
- bearerAuth: []
parameters:
- name: scene_id
in: query
description: 场景ID
schema:
type: integer
- name: status
in: query
description: 会话状态
schema:
type: string
enum: [created, in_progress, completed, cancelled, error]
- name: page
in: query
schema:
type: integer
minimum: 1
default: 1
- name: page_size
in: query
schema:
type: integer
minimum: 1
maximum: 100
default: 20
responses:
'200':
description: 成功
content:
application/json:
schema:
$ref: '#/components/schemas/PaginatedSessionsResponse'
'401':
$ref: '#/components/responses/Unauthorized'
/training/sessions/{session_id}:
get:
summary: 获取陪练会话详情
tags:
- 陪练会话
security:
- bearerAuth: []
parameters:
- name: session_id
in: path
required: true
description: 会话ID
schema:
type: integer
responses:
'200':
description: 成功
content:
application/json:
schema:
$ref: '#/components/schemas/TrainingSessionResponse'
'401':
$ref: '#/components/responses/Unauthorized'
'403':
$ref: '#/components/responses/Forbidden'
'404':
$ref: '#/components/responses/NotFound'
/training/sessions/{session_id}/end:
post:
summary: 结束陪练会话
tags:
- 陪练会话
security:
- bearerAuth: []
parameters:
- name: session_id
in: path
required: true
description: 会话ID
schema:
type: integer
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/EndTrainingRequest'
responses:
'200':
description: 成功
content:
application/json:
schema:
$ref: '#/components/schemas/EndTrainingResponse'
'401':
$ref: '#/components/responses/Unauthorized'
'403':
$ref: '#/components/responses/Forbidden'
'404':
$ref: '#/components/responses/NotFound'
/training/sessions/{session_id}/messages:
get:
summary: 获取陪练会话的消息列表
tags:
- 陪练消息
security:
- bearerAuth: []
parameters:
- name: session_id
in: path
required: true
description: 会话ID
schema:
type: integer
- name: skip
in: query
description: 跳过数量
schema:
type: integer
minimum: 0
default: 0
- name: limit
in: query
description: 返回数量
schema:
type: integer
minimum: 1
maximum: 500
default: 100
responses:
'200':
description: 成功
content:
application/json:
schema:
type: object
properties:
code:
type: integer
message:
type: string
data:
type: array
items:
$ref: '#/components/schemas/TrainingMessage'
'401':
$ref: '#/components/responses/Unauthorized'
'403':
$ref: '#/components/responses/Forbidden'
'404':
$ref: '#/components/responses/NotFound'
/training/reports:
get:
summary: 获取用户的陪练报告列表
tags:
- 陪练报告
security:
- bearerAuth: []
parameters:
- name: page
in: query
schema:
type: integer
minimum: 1
default: 1
- name: page_size
in: query
schema:
type: integer
minimum: 1
maximum: 100
default: 20
responses:
'200':
description: 成功
content:
application/json:
schema:
$ref: '#/components/schemas/PaginatedReportsResponse'
'401':
$ref: '#/components/responses/Unauthorized'
/training/reports/{report_id}:
get:
summary: 获取陪练报告详情
tags:
- 陪练报告
security:
- bearerAuth: []
parameters:
- name: report_id
in: path
required: true
description: 报告ID
schema:
type: integer
responses:
'200':
description: 成功
content:
application/json:
schema:
$ref: '#/components/schemas/TrainingReportResponse'
'401':
$ref: '#/components/responses/Unauthorized'
'403':
$ref: '#/components/responses/Forbidden'
'404':
$ref: '#/components/responses/NotFound'
/training/sessions/{session_id}/report:
get:
summary: 根据会话ID获取陪练报告
tags:
- 陪练报告
security:
- bearerAuth: []
parameters:
- name: session_id
in: path
required: true
description: 会话ID
schema:
type: integer
responses:
'200':
description: 成功
content:
application/json:
schema:
$ref: '#/components/schemas/TrainingReportResponse'
'401':
$ref: '#/components/responses/Unauthorized'
'403':
$ref: '#/components/responses/Forbidden'
'404':
$ref: '#/components/responses/NotFound'
components:
securitySchemes:
bearerAuth:
type: http
scheme: bearer
bearerFormat: JWT
responses:
Unauthorized:
description: 未授权
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
Forbidden:
description: 禁止访问
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
NotFound:
description: 资源未找到
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
schemas:
ErrorResponse:
type: object
properties:
code:
type: integer
message:
type: string
detail:
type: object
BaseResponse:
type: object
properties:
code:
type: integer
example: 200
message:
type: string
example: "success"
request_id:
type: string
PaginationMeta:
type: object
properties:
total:
type: integer
page:
type: integer
page_size:
type: integer
pages:
type: integer
TrainingSceneCreate:
type: object
required:
- name
- category
properties:
name:
type: string
maxLength: 100
description: 场景名称
description:
type: string
description: 场景描述
category:
type: string
maxLength: 50
description: 场景分类
ai_config:
type: object
description: AI配置
prompt_template:
type: string
description: 提示词模板
evaluation_criteria:
type: object
description: 评估标准
is_public:
type: boolean
default: true
description: 是否公开
required_level:
type: integer
description: 所需用户等级
status:
type: string
enum: [draft, active, inactive]
default: draft
TrainingSceneUpdate:
type: object
properties:
name:
type: string
maxLength: 100
description:
type: string
category:
type: string
maxLength: 50
ai_config:
type: object
prompt_template:
type: string
evaluation_criteria:
type: object
status:
type: string
enum: [draft, active, inactive]
is_public:
type: boolean
required_level:
type: integer
TrainingScene:
type: object
properties:
id:
type: integer
name:
type: string
description:
type: string
category:
type: string
ai_config:
type: object
prompt_template:
type: string
evaluation_criteria:
type: object
status:
type: string
enum: [draft, active, inactive]
is_public:
type: boolean
required_level:
type: integer
created_at:
type: string
format: date-time
updated_at:
type: string
format: date-time
TrainingSceneResponse:
allOf:
- $ref: '#/components/schemas/BaseResponse'
- type: object
properties:
data:
$ref: '#/components/schemas/TrainingScene'
PaginatedScenesResponse:
allOf:
- $ref: '#/components/schemas/BaseResponse'
- type: object
properties:
data:
type: object
properties:
items:
type: array
items:
$ref: '#/components/schemas/TrainingScene'
total:
type: integer
page:
type: integer
page_size:
type: integer
pages:
type: integer
StartTrainingRequest:
type: object
required:
- scene_id
properties:
scene_id:
type: integer
description: 场景ID
config:
type: object
description: 会话配置
StartTrainingResponse:
allOf:
- $ref: '#/components/schemas/BaseResponse'
- type: object
properties:
data:
type: object
properties:
session_id:
type: integer
coze_conversation_id:
type: string
scene:
$ref: '#/components/schemas/TrainingScene'
websocket_url:
type: string
EndTrainingRequest:
type: object
properties:
generate_report:
type: boolean
default: true
description: 是否生成报告
EndTrainingResponse:
allOf:
- $ref: '#/components/schemas/BaseResponse'
- type: object
properties:
data:
type: object
properties:
session:
$ref: '#/components/schemas/TrainingSession'
report:
$ref: '#/components/schemas/TrainingReport'
TrainingSession:
type: object
properties:
id:
type: integer
user_id:
type: integer
scene_id:
type: integer
coze_conversation_id:
type: string
start_time:
type: string
format: date-time
end_time:
type: string
format: date-time
duration_seconds:
type: integer
status:
type: string
enum: [created, in_progress, completed, cancelled, error]
session_config:
type: object
total_score:
type: number
evaluation_result:
type: object
scene:
$ref: '#/components/schemas/TrainingScene'
message_count:
type: integer
created_at:
type: string
format: date-time
updated_at:
type: string
format: date-time
TrainingSessionResponse:
allOf:
- $ref: '#/components/schemas/BaseResponse'
- type: object
properties:
data:
$ref: '#/components/schemas/TrainingSession'
PaginatedSessionsResponse:
allOf:
- $ref: '#/components/schemas/BaseResponse'
- type: object
properties:
data:
type: object
properties:
items:
type: array
items:
$ref: '#/components/schemas/TrainingSession'
total:
type: integer
page:
type: integer
page_size:
type: integer
pages:
type: integer
TrainingMessage:
type: object
properties:
id:
type: integer
session_id:
type: integer
role:
type: string
enum: [user, assistant, system]
type:
type: string
enum: [text, voice, system]
content:
type: string
voice_url:
type: string
voice_duration:
type: number
metadata:
type: object
coze_message_id:
type: string
created_at:
type: string
format: date-time
TrainingReport:
type: object
properties:
id:
type: integer
session_id:
type: integer
user_id:
type: integer
overall_score:
type: number
dimension_scores:
type: object
additionalProperties:
type: number
strengths:
type: array
items:
type: string
weaknesses:
type: array
items:
type: string
suggestions:
type: array
items:
type: string
detailed_analysis:
type: string
transcript:
type: string
statistics:
type: object
session:
$ref: '#/components/schemas/TrainingSession'
created_at:
type: string
format: date-time
updated_at:
type: string
format: date-time
TrainingReportResponse:
allOf:
- $ref: '#/components/schemas/BaseResponse'
- type: object
properties:
data:
$ref: '#/components/schemas/TrainingReport'
PaginatedReportsResponse:
allOf:
- $ref: '#/components/schemas/BaseResponse'
- type: object
properties:
data:
type: object
properties:
items:
type: array
items:
$ref: '#/components/schemas/TrainingReport'
total:
type: integer
page:
type: integer
page_size:
type: integer
pages:
type: integer

View File

@@ -0,0 +1,275 @@
"""
文件上传API接口
"""
import os
import shutil
from pathlib import Path
from typing import List, Optional
from datetime import datetime
import hashlib
from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.deps import get_current_user, get_db
from app.models.user import User
from app.models.course import Course
from app.schemas.base import ResponseModel
from app.core.logger import get_logger
logger = get_logger(__name__)
router = APIRouter(prefix="/upload")
# 支持的文件类型和大小限制
# 支持格式TXT、Markdown、MDX、PDF、HTML、Excel、Word、CSV、VTT、Properties
ALLOWED_EXTENSIONS = {
'txt', 'md', 'mdx', 'pdf', 'html', 'htm',
'xlsx', 'xls', 'docx', 'doc', 'csv', 'vtt', 'properties'
}
MAX_FILE_SIZE = 15 * 1024 * 1024 # 15MB
def get_file_extension(filename: str) -> str:
"""获取文件扩展名"""
return filename.rsplit('.', 1)[1].lower() if '.' in filename else ''
def generate_unique_filename(original_filename: str) -> str:
"""生成唯一的文件名"""
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
random_str = hashlib.md5(f"{original_filename}{timestamp}".encode()).hexdigest()[:8]
ext = get_file_extension(original_filename)
return f"{timestamp}_{random_str}.{ext}"
def get_upload_path(file_type: str = "general") -> Path:
"""获取上传路径"""
base_path = Path(settings.UPLOAD_PATH)
upload_path = base_path / file_type
upload_path.mkdir(parents=True, exist_ok=True)
return upload_path
@router.post("/file", response_model=ResponseModel[dict])
async def upload_file(
file: UploadFile = File(...),
file_type: str = "general",
current_user: User = Depends(get_current_user),
):
"""
上传单个文件
- **file**: 要上传的文件
- **file_type**: 文件类型分类general, course, avatar等
返回:
- **file_url**: 文件访问URL
- **file_name**: 原始文件名
- **file_size**: 文件大小
- **file_type**: 文件类型
"""
try:
# 检查文件扩展名
file_ext = get_file_extension(file.filename)
if file_ext not in ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"不支持的文件类型: {file_ext}"
)
# 读取文件内容
contents = await file.read()
file_size = len(contents)
# 检查文件大小
if file_size > MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"文件大小超过限制,最大允许 {MAX_FILE_SIZE // 1024 // 1024}MB"
)
# 生成唯一文件名
unique_filename = generate_unique_filename(file.filename)
# 获取上传路径
upload_path = get_upload_path(file_type)
file_path = upload_path / unique_filename
# 保存文件
with open(file_path, "wb") as f:
f.write(contents)
# 生成文件访问URL
file_url = f"/static/uploads/{file_type}/{unique_filename}"
logger.info(
"文件上传成功",
user_id=current_user.id,
original_filename=file.filename,
saved_filename=unique_filename,
file_size=file_size,
file_type=file_type,
)
return ResponseModel(
data={
"file_url": file_url,
"file_name": file.filename,
"file_size": file_size,
"file_type": file_ext,
},
message="文件上传成功"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"文件上传失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="文件上传失败"
)
@router.post("/course/{course_id}/materials", response_model=ResponseModel[dict])
async def upload_course_material(
course_id: int,
file: UploadFile = File(...),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
上传课程资料
- **course_id**: 课程ID
- **file**: 要上传的文件
返回上传结果包含文件URL等信息
"""
try:
# 验证课程是否存在
from sqlalchemy import select
from app.models.course import Course
stmt = select(Course).where(Course.id == course_id, Course.is_deleted == False)
result = await db.execute(stmt)
course = result.scalar_one_or_none()
if not course:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"课程 {course_id} 不存在"
)
# 检查文件扩展名
file_ext = get_file_extension(file.filename)
if file_ext not in ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"不支持的文件类型: {file_ext}"
)
# 读取文件内容
contents = await file.read()
file_size = len(contents)
# 检查文件大小
if file_size > MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"文件大小超过限制,最大允许 {MAX_FILE_SIZE // 1024 // 1024}MB"
)
# 生成唯一文件名
unique_filename = generate_unique_filename(file.filename)
# 创建课程专属目录
course_upload_path = Path(settings.UPLOAD_PATH) / "courses" / str(course_id)
course_upload_path.mkdir(parents=True, exist_ok=True)
# 保存文件
file_path = course_upload_path / unique_filename
with open(file_path, "wb") as f:
f.write(contents)
# 生成文件访问URL
file_url = f"/static/uploads/courses/{course_id}/{unique_filename}"
logger.info(
"课程资料上传成功",
user_id=current_user.id,
course_id=course_id,
original_filename=file.filename,
saved_filename=unique_filename,
file_size=file_size,
)
return ResponseModel(
data={
"file_url": file_url,
"file_name": file.filename,
"file_size": file_size,
"file_type": file_ext,
},
message="课程资料上传成功"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"课程资料上传失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="课程资料上传失败"
)
@router.delete("/file", response_model=ResponseModel[bool])
async def delete_file(
file_url: str,
current_user: User = Depends(get_current_user),
):
"""
删除已上传的文件
- **file_url**: 文件URL路径
"""
try:
# 解析文件路径
if not file_url.startswith("/static/uploads/"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效的文件URL"
)
# 转换为实际文件路径
relative_path = file_url.replace("/static/uploads/", "")
file_path = Path(settings.UPLOAD_PATH) / relative_path
# 检查文件是否存在
if not file_path.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="文件不存在"
)
# 删除文件
os.remove(file_path)
logger.info(
"文件删除成功",
user_id=current_user.id,
file_url=file_url,
)
return ResponseModel(data=True, message="文件删除成功")
except HTTPException:
raise
except Exception as e:
logger.error(f"文件删除失败: {str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="文件删除失败"
)

474
backend/app/api/v1/users.py Normal file
View File

@@ -0,0 +1,474 @@
"""
用户管理 API
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query, status, Request
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_current_active_user, get_db, require_admin
from app.core.logger import logger
from app.models.user import User
from app.schemas.base import PaginatedResponse, PaginationParams, ResponseModel
from app.schemas.user import User as UserSchema
from app.schemas.user import UserCreate, UserFilter, UserPasswordUpdate, UserUpdate
from app.services.user_service import UserService
from app.services.system_log_service import system_log_service
from app.schemas.system_log import SystemLogCreate
from app.models.exam import Exam, ExamResult
from app.models.training import TrainingSession
from app.models.position_member import PositionMember
from app.models.position import Position
from app.models.course import Course
router = APIRouter()
@router.get("/me", response_model=ResponseModel)
async def get_current_user_info(
current_user: dict = Depends(get_current_active_user),
) -> ResponseModel:
"""
获取当前用户信息
权限:需要登录
"""
return ResponseModel(data=UserSchema.model_validate(current_user))
@router.get("/me/statistics", response_model=ResponseModel)
async def get_current_user_statistics(
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取当前用户学习统计
返回字段:
- learningDays: 学习天数(按陪练会话开始日期去重)
- totalHours: 学习总时长小时取整到1位小数
- practiceQuestions: 练习题数(答题记录条数汇总)
- averageScore: 平均成绩已提交考试的平均分保留1位小数
- examsCompleted: 已完成考试数量
"""
try:
user_id = current_user.id
# 学习天数:按会话开始日期去重
learning_days_stmt = select(func.count(func.distinct(func.date(TrainingSession.start_time)))).where(
TrainingSession.user_id == user_id
)
learning_days = (await db.scalar(learning_days_stmt)) or 0
# 总时长(小时)
total_seconds_stmt = select(func.coalesce(func.sum(TrainingSession.duration_seconds), 0)).where(
TrainingSession.user_id == user_id
)
total_seconds = (await db.scalar(total_seconds_stmt)) or 0
total_hours = round(float(total_seconds) / 3600.0, 1) if total_seconds else 0.0
# 练习题数:用户所有考试的题目总数
practice_questions_stmt = (
select(func.coalesce(func.sum(Exam.question_count), 0))
.where(Exam.user_id == user_id, Exam.status == "completed")
)
practice_questions = (await db.scalar(practice_questions_stmt)) or 0
# 平均成绩:用户已完成考试的平均分
avg_score_stmt = select(func.avg(Exam.score)).where(
Exam.user_id == user_id, Exam.status == "completed"
)
avg_score_val = await db.scalar(avg_score_stmt)
average_score = round(float(avg_score_val), 1) if avg_score_val is not None else 0.0
# 已完成考试数量
exams_completed_stmt = select(func.count(Exam.id)).where(
Exam.user_id == user_id,
Exam.status == "completed"
)
exams_completed = (await db.scalar(exams_completed_stmt)) or 0
return ResponseModel(
data={
"learningDays": int(learning_days),
"totalHours": total_hours,
"practiceQuestions": int(practice_questions),
"averageScore": average_score,
"examsCompleted": int(exams_completed),
}
)
except Exception as e:
logger.error("获取用户学习统计失败", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取用户学习统计失败: {str(e)}")
@router.get("/me/recent-exams", response_model=ResponseModel)
async def get_recent_exams(
limit: int = Query(5, ge=1, le=20, description="返回数量"),
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取当前用户最近的考试记录
返回最近的考试列表,按创建时间降序排列
只返回已完成或已提交的考试不包括started状态
"""
try:
user_id = current_user.id
# 查询最近的考试记录,关联课程表获取课程名称
stmt = (
select(Exam, Course.name.label("course_name"))
.join(Course, Exam.course_id == Course.id)
.where(
Exam.user_id == user_id,
Exam.status.in_(["completed", "submitted"])
)
.order_by(Exam.created_at.desc())
.limit(limit)
)
results = await db.execute(stmt)
rows = results.all()
# 构建返回数据
exams_list = []
for exam, course_name in rows:
exams_list.append({
"id": exam.id,
"title": exam.exam_name,
"courseName": course_name,
"courseId": exam.course_id,
"time": exam.created_at.strftime("%Y-%m-%d %H:%M") if exam.created_at else "",
"questions": exam.question_count or 0,
"status": exam.status,
"score": exam.score
})
return ResponseModel(data=exams_list)
except Exception as e:
logger.error("获取最近考试记录失败", exc_info=True)
raise HTTPException(status_code=500, detail=f"获取最近考试记录失败: {str(e)}")
@router.put("/me", response_model=ResponseModel)
async def update_current_user(
user_in: UserUpdate,
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
更新当前用户信息
权限:需要登录
"""
user_service = UserService(db)
user = await user_service.update_user(
user_id=current_user.id,
obj_in=user_in,
updated_by=current_user.id,
)
return ResponseModel(data=UserSchema.model_validate(user))
@router.put("/me/password", response_model=ResponseModel)
async def update_current_user_password(
password_in: UserPasswordUpdate,
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
更新当前用户密码
权限:需要登录
"""
user_service = UserService(db)
user = await user_service.update_password(
user_id=current_user.id,
old_password=password_in.old_password,
new_password=password_in.new_password,
)
return ResponseModel(message="密码更新成功", data=UserSchema.model_validate(user))
@router.get("/", response_model=ResponseModel)
async def get_users(
pagination: PaginationParams = Depends(),
role: str = Query(None, description="用户角色"),
is_active: bool = Query(None, description="是否激活"),
team_id: int = Query(None, description="团队ID"),
keyword: str = Query(None, description="搜索关键词"),
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取用户列表
权限:需要登录
- 普通用户只能看到激活的用户
- 管理员可以看到所有用户
"""
# 构建筛选条件
filter_params = UserFilter(
role=role,
is_active=is_active,
team_id=team_id,
keyword=keyword,
)
# 普通用户只能看到激活的用户
if current_user.role == "trainee":
filter_params.is_active = True
# 获取用户列表
user_service = UserService(db)
users, total = await user_service.get_users_with_filter(
skip=pagination.offset,
limit=pagination.limit,
filter_params=filter_params,
)
# 构建分页响应
paginated = PaginatedResponse.create(
items=[UserSchema.model_validate(user) for user in users],
total=total,
page=pagination.page,
page_size=pagination.page_size,
)
return ResponseModel(data=paginated.model_dump())
@router.post("/", response_model=ResponseModel, status_code=status.HTTP_201_CREATED)
async def create_user(
user_in: UserCreate,
request: Request,
current_user: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
创建用户
权限:需要管理员权限
"""
user_service = UserService(db)
user = await user_service.create_user(
obj_in=user_in,
created_by=current_user.id,
)
logger.info(
"管理员创建用户",
admin_id=current_user.id,
admin_username=current_user.username,
new_user_id=user.id,
new_username=user.username,
)
# 记录用户创建日志
await system_log_service.create_log(
db,
SystemLogCreate(
level="INFO",
type="user",
message=f"管理员 {current_user.username} 创建用户: {user.username}",
user_id=current_user.id,
user=current_user.username,
ip=request.client.host if request.client else None,
path="/api/v1/users/",
method="POST",
user_agent=request.headers.get("user-agent")
)
)
return ResponseModel(message="用户创建成功", data=UserSchema.model_validate(user))
@router.get("/{user_id}", response_model=ResponseModel)
async def get_user(
user_id: int,
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取用户详情
权限:需要登录
- 普通用户只能查看自己的信息
- 管理员和经理可以查看所有用户信息
"""
# 权限检查
if current_user.role == "trainee" and current_user.id != user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="没有权限查看其他用户信息"
)
# 获取用户
user_service = UserService(db)
user = await user_service.get_by_id(user_id)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
return ResponseModel(data=UserSchema.model_validate(user))
@router.put("/{user_id}", response_model=ResponseModel)
async def update_user(
user_id: int,
user_in: UserUpdate,
current_user: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
更新用户信息
权限:需要管理员权限
"""
user_service = UserService(db)
user = await user_service.update_user(
user_id=user_id,
obj_in=user_in,
updated_by=current_user.id,
)
logger.info(
"管理员更新用户",
admin_id=current_user.id,
admin_username=current_user.username,
updated_user_id=user.id,
updated_username=user.username,
)
return ResponseModel(data=UserSchema.model_validate(user))
@router.delete("/{user_id}", response_model=ResponseModel)
async def delete_user(
user_id: int,
request: Request,
current_user: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
删除用户(软删除)
权限:需要管理员权限
"""
# 不能删除自己
if user_id == current_user.id:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="不能删除自己")
# 获取用户
user_service = UserService(db)
user = await user_service.get_by_id(user_id)
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")
# 软删除
await user_service.soft_delete(db_obj=user)
logger.info(
"管理员删除用户",
admin_id=current_user.id,
admin_username=current_user.username,
deleted_user_id=user.id,
deleted_username=user.username,
)
# 记录用户删除日志
await system_log_service.create_log(
db,
SystemLogCreate(
level="INFO",
type="user",
message=f"管理员 {current_user.username} 删除用户: {user.username}",
user_id=current_user.id,
user=current_user.username,
ip=request.client.host if request.client else None,
path=f"/api/v1/users/{user_id}",
method="DELETE",
user_agent=request.headers.get("user-agent")
)
)
return ResponseModel(message="用户删除成功")
@router.post("/{user_id}/teams/{team_id}", response_model=ResponseModel)
async def add_user_to_team(
user_id: int,
team_id: int,
role: str = Query("member", regex="^(member|leader)$"),
current_user: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
将用户添加到团队
权限:需要管理员权限
"""
user_service = UserService(db)
await user_service.add_user_to_team(
user_id=user_id,
team_id=team_id,
role=role,
)
return ResponseModel(message="用户已添加到团队")
@router.delete("/{user_id}/teams/{team_id}", response_model=ResponseModel)
async def remove_user_from_team(
user_id: int,
team_id: int,
current_user: dict = Depends(require_admin),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
从团队中移除用户
权限:需要管理员权限
"""
user_service = UserService(db)
await user_service.remove_user_from_team(
user_id=user_id,
team_id=team_id,
)
return ResponseModel(message="用户已从团队中移除")
@router.get("/{user_id}/positions", response_model=ResponseModel)
async def get_user_positions(
user_id: int,
current_user: dict = Depends(get_current_active_user),
db: AsyncSession = Depends(get_db),
) -> ResponseModel:
"""
获取用户所属岗位列表(用于前端展示与编辑)
权限:登录即可;普通用户仅能查看自己的信息
返回:[{id,name,code}]
"""
# 权限检查
if current_user.role == "trainee" and current_user.id != user_id:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="没有权限查看其他用户信息")
stmt = (
select(Position)
.join(PositionMember, PositionMember.position_id == Position.id)
.where(PositionMember.user_id == user_id, PositionMember.is_deleted == False, Position.is_deleted == False)
.order_by(Position.id)
)
rows = (await db.execute(stmt)).scalars().all()
data = [
{"id": p.id, "name": p.name, "code": p.code}
for p in rows
]
return ResponseModel(data=data)

120
backend/app/api/v1/yanji.py Normal file
View File

@@ -0,0 +1,120 @@
"""
言迹智能工牌API接口
"""
import logging
from typing import List
from fastapi import APIRouter, Depends, Query
from app.core.deps import get_current_user
from app.models.user import User
from app.schemas.base import ResponseModel
from app.schemas.yanji import (
GetConversationsByVisitIdsResponse,
GetConversationsResponse,
YanjiConversation,
)
from app.services.yanji_service import YanjiService
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/conversations/by-visit-ids", response_model=ResponseModel[GetConversationsByVisitIdsResponse])
async def get_conversations_by_visit_ids(
external_visit_ids: List[str] = Query(
...,
min_length=1,
max_length=10,
description="三方来访单ID列表最多10个",
),
current_user: User = Depends(get_current_user),
):
"""
根据来访单ID获取对话记录ASR转写文字
这是获取对话记录的主要接口,适用于:
1. 已知来访单ID的场景
2. 获取特定对话记录用于AI评分
3. 批量获取多个对话记录
"""
try:
yanji_service = YanjiService()
conversations = await yanji_service.get_conversations_by_visit_ids(
external_visit_ids=external_visit_ids
)
return ResponseModel(
code=200,
message="获取成功",
data=GetConversationsByVisitIdsResponse(
conversations=conversations, total=len(conversations)
),
)
except Exception as e:
logger.error(f"获取对话记录失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取失败: {str(e)}", data=None)
@router.get("/conversations", response_model=ResponseModel[GetConversationsResponse])
async def get_employee_conversations(
consultant_phone: str = Query(..., description="员工手机号"),
limit: int = Query(10, ge=1, le=100, description="获取数量"),
current_user: User = Depends(get_current_user),
):
"""
获取员工最近的对话记录
注意目前此接口功能有限因为言迹API没有直接通过员工手机号查询录音的接口。
推荐使用 /conversations/by-visit-ids 接口。
后续可扩展:
1. 先查询员工的来访单列表
2. 再获取这些来访单的对话记录
"""
try:
yanji_service = YanjiService()
conversations = await yanji_service.get_recent_conversations(
consultant_phone=consultant_phone, limit=limit
)
return ResponseModel(
code=200,
message="获取成功",
data=GetConversationsResponse(
conversations=conversations, total=len(conversations)
),
)
except Exception as e:
logger.error(f"获取员工对话记录失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"获取失败: {str(e)}", data=None)
@router.get("/test-auth")
async def test_yanji_auth(current_user: User = Depends(get_current_user)):
"""
测试言迹API认证
用于验证OAuth2.0认证是否正常工作
"""
try:
yanji_service = YanjiService()
access_token = await yanji_service.get_access_token()
return ResponseModel(
code=200,
message="认证成功",
data={
"access_token": access_token[:20] + "...", # 只显示前20个字符
"base_url": yanji_service.base_url,
},
)
except Exception as e:
logger.error(f"言迹API认证失败: {e}", exc_info=True)
return ResponseModel(code=500, message=f"认证失败: {str(e)}", data=None)

View File

View File

@@ -0,0 +1,49 @@
"""数据库配置"""
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.pool import NullPool
from app.core.config import get_settings
settings = get_settings()
# 创建异步引擎
if settings.DEBUG:
# 开发环境使用 NullPool不需要连接池参数
engine = create_async_engine(
settings.DATABASE_URL,
echo=False,
pool_pre_ping=True,
poolclass=NullPool,
# 确保 MySQL 连接使用 UTF-8 字符集
connect_args={
"charset": "utf8mb4",
"use_unicode": True,
"autocommit": False,
"init_command": "SET character_set_client=utf8mb4, character_set_connection=utf8mb4, character_set_results=utf8mb4, collation_connection=utf8mb4_unicode_ci",
} if "mysql" in settings.DATABASE_URL else {},
)
else:
# 生产环境使用连接池
engine = create_async_engine(
settings.DATABASE_URL,
echo=False,
pool_size=20,
max_overflow=0,
pool_pre_ping=True,
# 确保 MySQL 连接使用 UTF-8 字符集
connect_args={
"charset": "utf8mb4",
"use_unicode": True,
"autocommit": False,
"init_command": "SET character_set_client=utf8mb4, character_set_connection=utf8mb4, character_set_results=utf8mb4, collation_connection=utf8mb4_unicode_ci",
} if "mysql" in settings.DATABASE_URL else {},
)
# 创建异步会话工厂
SessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)

View File

@@ -0,0 +1,3 @@
"""
核心功能模块
"""

323
backend/app/core/config.py Normal file
View File

@@ -0,0 +1,323 @@
"""
系统配置
支持两种配置来源:
1. 环境变量 / .env 文件(传统方式,向后兼容)
2. 数据库 tenant_configs 表(新方式,支持热更新)
配置优先级:数据库 > 环境变量 > 默认值
"""
import os
import json
from functools import lru_cache
from typing import Optional, Any
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""系统配置"""
# 应用基础配置
APP_NAME: str = "KaoPeiLian"
APP_VERSION: str = "1.0.0"
DEBUG: bool = Field(default=True)
# 租户配置(用于多租户部署)
TENANT_CODE: str = Field(default="demo", description="租户编码,如 hua, yy, hl")
# 服务器配置
HOST: str = Field(default="0.0.0.0")
PORT: int = Field(default=8000)
# 数据库配置
DATABASE_URL: Optional[str] = Field(default=None)
MYSQL_HOST: str = Field(default="localhost")
MYSQL_PORT: int = Field(default=3306)
MYSQL_USER: str = Field(default="root")
MYSQL_PASSWORD: str = Field(default="password")
MYSQL_DATABASE: str = Field(default="kaopeilian")
@property
def database_url(self) -> str:
"""构建数据库连接URL"""
if self.DATABASE_URL:
return self.DATABASE_URL
# 使用urllib.parse.quote_plus来正确编码特殊字符
import urllib.parse
password = urllib.parse.quote_plus(self.MYSQL_PASSWORD)
return f"mysql+aiomysql://{self.MYSQL_USER}:{password}@{self.MYSQL_HOST}:{self.MYSQL_PORT}/{self.MYSQL_DATABASE}?charset=utf8mb4"
# Redis配置
REDIS_URL: str = Field(default="redis://localhost:6379/0")
# JWT配置
SECRET_KEY: str = Field(default="your-secret-key-here")
ALGORITHM: str = Field(default="HS256")
ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(default=30)
REFRESH_TOKEN_EXPIRE_DAYS: int = Field(default=7)
# 跨域配置
CORS_ORIGINS: list[str] = Field(
default=[
"http://localhost:3000",
"http://localhost:3001",
"http://localhost:5173",
"http://127.0.0.1:3000",
"http://127.0.0.1:3001",
"http://127.0.0.1:5173",
]
)
@field_validator('CORS_ORIGINS', mode='before')
@classmethod
def parse_cors_origins(cls, v):
"""解析 CORS_ORIGINS 环境变量(支持 JSON 格式字符串)"""
if isinstance(v, str):
try:
return json.loads(v)
except json.JSONDecodeError:
# 如果不是 JSON 格式,尝试按逗号分割
return [origin.strip() for origin in v.split(',')]
return v
# 日志配置
LOG_LEVEL: str = Field(default="INFO")
LOG_FORMAT: str = Field(default="text") # text 或 json
LOG_DIR: str = Field(default="logs")
# 上传配置
UPLOAD_DIR: str = Field(default="uploads")
MAX_UPLOAD_SIZE: int = Field(default=15 * 1024 * 1024) # 15MB
@property
def UPLOAD_PATH(self) -> str:
"""获取上传文件的完整路径"""
import os
return os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), self.UPLOAD_DIR)
# Coze 平台配置(陪练对话、播课等)
COZE_API_BASE: Optional[str] = Field(default="https://api.coze.cn")
COZE_WORKSPACE_ID: Optional[str] = Field(default=None)
COZE_API_TOKEN: Optional[str] = Field(default="pat_Sa5OiuUl0gDflnKstQTToIz0sSMshBV06diX0owOeuI1ZK1xDLH5YZH9fSeuKLIi")
COZE_TRAINING_BOT_ID: Optional[str] = Field(default=None)
COZE_CHAT_BOT_ID: Optional[str] = Field(default=None)
COZE_PRACTICE_BOT_ID: Optional[str] = Field(default="7560643598174683145") # 陪练专用Bot ID
# 播课工作流配置(多租户需在环境变量中覆盖,参见:应用配置清单.md
COZE_BROADCAST_WORKFLOW_ID: str = Field(default="7577983042284486666") # 默认:演示版播课工作流
COZE_BROADCAST_SPACE_ID: str = Field(default="7474971491470688296") # 播课工作流空间ID
COZE_BROADCAST_BOT_ID: Optional[str] = Field(default=None) # 播课工作流专用Bot ID
# OAuth配置可选
COZE_OAUTH_CLIENT_ID: Optional[str] = Field(default=None)
COZE_OAUTH_PUBLIC_KEY_ID: Optional[str] = Field(default=None)
COZE_OAUTH_PRIVATE_KEY_PATH: Optional[str] = Field(default=None)
# WebSocket语音配置
COZE_WS_BASE_URL: str = Field(default="wss://ws.coze.cn")
COZE_AUDIO_FORMAT: str = Field(default="pcm") # 音频格式
COZE_SAMPLE_RATE: int = Field(default=16000) # 采样率Hz
COZE_AUDIO_CHANNELS: int = Field(default=1) # 声道数(单声道)
COZE_AUDIO_BIT_DEPTH: int = Field(default=16) # 位深度
# 服务器公开访问域名
PUBLIC_DOMAIN: str = Field(default="http://aiedu.ireborn.com.cn")
# 言迹智能工牌API配置
YANJI_API_BASE: str = Field(default="https://open.yanjiai.com") # 正式环境
YANJI_CLIENT_ID: str = Field(default="1Fld4LCWt2vpJNG5")
YANJI_CLIENT_SECRET: str = Field(default="XE8w413qNtJBOdWc2aCezV0yMIHpUuTZ")
YANJI_TENANT_ID: str = Field(default="516799409476866048")
YANJI_ESTATE_ID: str = Field(default="516799468310364162")
# SCRM 系统对接 API Key用于内部服务间调用
SCRM_API_KEY: str = Field(default="scrm-kpl-api-key-2026-ruixiaomei")
# AI 服务配置(知识点分析 V2 使用)
# 首选服务商4sapi.com国内优化
AI_PRIMARY_API_KEY: str = Field(default="sk-9yMCXjRGANbacz20kJY8doSNy6Rf446aYwmgGIuIXQ7DAyBw") # 测试阶段 Key
AI_PRIMARY_BASE_URL: str = Field(default="https://4sapi.com/v1")
# 备选服务商OpenRouter模型全稳定性好
AI_FALLBACK_API_KEY: str = Field(default="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") # 测试阶段 Key
AI_FALLBACK_BASE_URL: str = Field(default="https://openrouter.ai/api/v1")
# 默认模型
AI_DEFAULT_MODEL: str = Field(default="gemini-3-flash-preview")
# 请求超时(秒)
AI_TIMEOUT: float = Field(default=120.0)
model_config = {
"env_file": ".env",
"env_file_encoding": "utf-8",
"case_sensitive": True,
"extra": "allow", # 允许额外的环境变量
}
@lru_cache()
def get_settings() -> Settings:
"""获取系统配置(缓存)"""
return Settings()
settings = get_settings()
# ============================================
# 动态配置获取(支持从数据库读取)
# ============================================
class DynamicConfig:
"""
动态配置管理器
用于在运行时从数据库获取配置,支持热更新。
向后兼容:如果数据库不可用,回退到环境变量配置。
"""
_tenant_loader = None
_initialized = False
@classmethod
async def init(cls, redis_url: Optional[str] = None):
"""
初始化动态配置管理器
Args:
redis_url: Redis URL可选用于缓存
"""
if cls._initialized:
return
try:
from app.core.tenant_config import TenantConfigManager
if redis_url:
await TenantConfigManager.init_redis(redis_url)
cls._initialized = True
except Exception as e:
import logging
logging.getLogger(__name__).warning(f"动态配置初始化失败: {e}")
@classmethod
async def get(cls, key: str, default: Any = None, tenant_code: Optional[str] = None) -> Any:
"""
获取配置值
Args:
key: 配置键(如 AI_PRIMARY_API_KEY
default: 默认值
tenant_code: 租户编码(可选,默认使用环境变量中的 TENANT_CODE
Returns:
配置值
"""
# 确定租户编码
if tenant_code is None:
tenant_code = settings.TENANT_CODE
# 配置键到分组的映射
config_mapping = {
# 数据库
"MYSQL_HOST": ("database", "MYSQL_HOST"),
"MYSQL_PORT": ("database", "MYSQL_PORT"),
"MYSQL_USER": ("database", "MYSQL_USER"),
"MYSQL_PASSWORD": ("database", "MYSQL_PASSWORD"),
"MYSQL_DATABASE": ("database", "MYSQL_DATABASE"),
# Redis
"REDIS_HOST": ("redis", "REDIS_HOST"),
"REDIS_PORT": ("redis", "REDIS_PORT"),
"REDIS_DB": ("redis", "REDIS_DB"),
# 安全
"SECRET_KEY": ("security", "SECRET_KEY"),
"CORS_ORIGINS": ("security", "CORS_ORIGINS"),
# Coze
"COZE_PRACTICE_BOT_ID": ("coze", "COZE_PRACTICE_BOT_ID"),
"COZE_BROADCAST_WORKFLOW_ID": ("coze", "COZE_BROADCAST_WORKFLOW_ID"),
"COZE_BROADCAST_SPACE_ID": ("coze", "COZE_BROADCAST_SPACE_ID"),
"COZE_OAUTH_CLIENT_ID": ("coze", "COZE_OAUTH_CLIENT_ID"),
"COZE_OAUTH_PUBLIC_KEY_ID": ("coze", "COZE_OAUTH_PUBLIC_KEY_ID"),
# AI
"AI_PRIMARY_API_KEY": ("ai", "AI_PRIMARY_API_KEY"),
"AI_PRIMARY_BASE_URL": ("ai", "AI_PRIMARY_BASE_URL"),
"AI_FALLBACK_API_KEY": ("ai", "AI_FALLBACK_API_KEY"),
"AI_FALLBACK_BASE_URL": ("ai", "AI_FALLBACK_BASE_URL"),
"AI_DEFAULT_MODEL": ("ai", "AI_DEFAULT_MODEL"),
"AI_TIMEOUT": ("ai", "AI_TIMEOUT"),
# 言迹
"YANJI_CLIENT_ID": ("yanji", "YANJI_CLIENT_ID"),
"YANJI_CLIENT_SECRET": ("yanji", "YANJI_CLIENT_SECRET"),
"YANJI_TENANT_ID": ("yanji", "YANJI_TENANT_ID"),
"YANJI_ESTATE_ID": ("yanji", "YANJI_ESTATE_ID"),
}
# 尝试从数据库获取
if cls._initialized and key in config_mapping:
try:
from app.core.tenant_config import TenantConfigManager
config_group, config_key = config_mapping[key]
loader = TenantConfigManager.get_loader(tenant_code)
value = await loader.get_config(config_group, config_key)
if value is not None:
return value
except Exception:
pass
# 回退到环境变量 / Settings
env_value = getattr(settings, key, None)
if env_value is not None:
return env_value
return default
@classmethod
async def is_feature_enabled(cls, feature_code: str, tenant_code: Optional[str] = None) -> bool:
"""
检查功能是否启用
Args:
feature_code: 功能编码
tenant_code: 租户编码
Returns:
是否启用
"""
if tenant_code is None:
tenant_code = settings.TENANT_CODE
if cls._initialized:
try:
from app.core.tenant_config import TenantConfigManager
loader = TenantConfigManager.get_loader(tenant_code)
return await loader.is_feature_enabled(feature_code)
except Exception:
pass
return True # 默认启用
@classmethod
async def refresh_cache(cls, tenant_code: Optional[str] = None):
"""
刷新配置缓存
Args:
tenant_code: 租户编码(为空则刷新所有)
"""
if not cls._initialized:
return
try:
from app.core.tenant_config import TenantConfigManager
if tenant_code:
await TenantConfigManager.refresh_tenant_cache(tenant_code)
else:
await TenantConfigManager.refresh_all_cache()
except Exception:
pass

View File

@@ -0,0 +1,31 @@
"""
数据库配置
"""
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from .config import settings
# 创建异步引擎
engine = create_async_engine(
settings.database_url,
echo=settings.DEBUG,
pool_pre_ping=True,
pool_size=10,
max_overflow=20,
# 确保 MySQL 连接使用 UTF-8 字符集
connect_args={
"charset": "utf8mb4",
"use_unicode": True,
"autocommit": False,
"init_command": "SET character_set_client=utf8mb4, character_set_connection=utf8mb4, character_set_results=utf8mb4, collation_connection=utf8mb4_unicode_ci",
} if "mysql" in settings.database_url else {},
)
# 创建异步会话工厂
AsyncSessionLocal = sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
)

166
backend/app/core/deps.py Normal file
View File

@@ -0,0 +1,166 @@
"""依赖注入模块"""
from typing import AsyncGenerator, Optional
from sqlalchemy import select
import redis.asyncio as redis
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import AsyncSessionLocal
from app.core.config import get_settings
from app.models.user import User
# JWT Bearer认证
security = HTTPBearer()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""
获取数据库会话
"""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db),
) -> User:
"""
获取当前用户基于JWT
- 从 Authorization Bearer Token 中解析用户ID
- 查询数据库返回完整的 User 对象
- 失败时抛出 401 未授权
"""
from app.core.security import decode_token # 延迟导入避免循环依赖
if not credentials or not credentials.scheme or not credentials.credentials:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="未提供认证信息")
if credentials.scheme.lower() != "bearer":
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="认证方式不支持")
token = credentials.credentials
try:
payload = decode_token(token)
user_id = int(payload.get("sub"))
except Exception:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的令牌")
result = await db.execute(
select(User).where(User.id == user_id, User.is_deleted == False)
)
user = result.scalar_one_or_none()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或已被禁用"
)
return user
async def require_admin(current_user: User = Depends(get_current_user)) -> User:
"""
需要管理员权限
"""
if getattr(current_user, "role", None) != "admin":
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
return current_user
async def require_admin_or_manager(current_user: User = Depends(get_current_user)) -> User:
"""
需要管理者或管理员权限
"""
if getattr(current_user, "role", None) not in ("admin", "manager"):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理者或管理员权限")
return current_user
async def get_optional_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
db: AsyncSession = Depends(get_db),
) -> Optional[User]:
"""
获取可选的当前用户(不强制登录)
"""
if not credentials:
return None
try:
return await get_current_user(credentials, db)
except:
return None
async def get_current_active_user(
current_user: User = Depends(get_current_user),
) -> User:
"""
获取当前活跃用户
"""
# TODO: 检查用户是否被禁用
return current_user
async def verify_scrm_api_key(
credentials: HTTPAuthorizationCredentials = Depends(security),
) -> bool:
"""
验证 SCRM 系统 API Key
用于内部服务间调用认证SCRM 系统通过固定 API Key 访问考陪练数据查询接口
请求头格式: Authorization: Bearer {SCRM_API_KEY}
"""
settings = get_settings()
if not credentials or not credentials.credentials:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未提供认证信息"
)
if credentials.scheme.lower() != "bearer":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="认证方式不支持,需要 Bearer Token"
)
if credentials.credentials != settings.SCRM_API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的 API Key"
)
return True
# Redis 连接池
_redis_pool: Optional[redis.ConnectionPool] = None
async def get_redis() -> AsyncGenerator[redis.Redis, None]:
"""
获取 Redis 连接
"""
global _redis_pool
if _redis_pool is None:
settings = get_settings()
_redis_pool = redis.ConnectionPool.from_url(
settings.REDIS_URL, encoding="utf-8", decode_responses=True
)
client = redis.Redis(connection_pool=_redis_pool)
try:
yield client
finally:
await client.close()

View File

@@ -0,0 +1,28 @@
"""
应用生命周期事件处理
"""
from app.core.logger import logger
async def startup_handler():
"""应用启动时执行的任务"""
logger.info("执行启动任务...")
# TODO: 初始化数据库连接池
# TODO: 初始化Redis连接
# TODO: 初始化AI平台客户端
# TODO: 加载缓存数据
logger.info("启动任务完成")
async def shutdown_handler():
"""应用关闭时执行的任务"""
logger.info("执行关闭任务...")
# TODO: 关闭数据库连接池
# TODO: 关闭Redis连接
# TODO: 清理临时文件
# TODO: 保存应用状态
logger.info("关闭任务完成")

View File

@@ -0,0 +1,89 @@
"""统一异常定义"""
from typing import Optional, Dict, Any
from fastapi import HTTPException, status
class BusinessError(HTTPException):
"""业务异常基类"""
def __init__(
self,
message: str,
code: int = status.HTTP_400_BAD_REQUEST,
error_code: Optional[str] = None,
detail: Optional[Dict[str, Any]] = None,
):
super().__init__(
status_code=code,
detail={
"message": message,
"error_code": error_code or f"ERR_{code}",
"detail": detail,
},
)
self.message = message
self.code = code
self.error_code = error_code
class BadRequestError(BusinessError):
"""400 错误请求"""
def __init__(self, message: str = "错误的请求", **kwargs):
super().__init__(message, status.HTTP_400_BAD_REQUEST, **kwargs)
class UnauthorizedError(BusinessError):
"""401 未授权"""
def __init__(self, message: str = "未授权", **kwargs):
super().__init__(message, status.HTTP_401_UNAUTHORIZED, **kwargs)
class ForbiddenError(BusinessError):
"""403 禁止访问"""
def __init__(self, message: str = "禁止访问", **kwargs):
super().__init__(message, status.HTTP_403_FORBIDDEN, **kwargs)
class NotFoundError(BusinessError):
"""404 未找到"""
def __init__(self, message: str = "资源未找到", **kwargs):
super().__init__(message, status.HTTP_404_NOT_FOUND, **kwargs)
class ConflictError(BusinessError):
"""409 冲突"""
def __init__(self, message: str = "资源冲突", **kwargs):
super().__init__(message, status.HTTP_409_CONFLICT, **kwargs)
class ValidationError(BusinessError):
"""422 验证错误"""
def __init__(self, message: str = "验证失败", **kwargs):
super().__init__(message, status.HTTP_422_UNPROCESSABLE_ENTITY, **kwargs)
class InternalServerError(BusinessError):
"""500 内部服务器错误"""
def __init__(self, message: str = "内部服务器错误", **kwargs):
super().__init__(message, status.HTTP_500_INTERNAL_SERVER_ERROR, **kwargs)
class InsufficientPermissionsError(ForbiddenError):
"""权限不足"""
def __init__(self, message: str = "权限不足", **kwargs):
super().__init__(message, error_code="INSUFFICIENT_PERMISSIONS", **kwargs)
class ExternalServiceError(BusinessError):
"""外部服务错误"""
def __init__(self, message: str = "外部服务异常", **kwargs):
super().__init__(message, status.HTTP_502_BAD_GATEWAY, error_code="EXTERNAL_SERVICE_ERROR", **kwargs)

View File

@@ -0,0 +1,76 @@
"""
日志配置
"""
import logging
import sys
from typing import Any
import structlog
from structlog.stdlib import LoggerFactory
from app.core.config import get_settings
settings = get_settings()
def setup_logging():
"""
配置日志系统
"""
# 设置日志级别
log_level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)
# 配置标准库日志
logging.basicConfig(
format="%(message)s",
stream=sys.stdout,
level=log_level,
)
# 配置处理器
processors = [
structlog.stdlib.filter_by_level,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
]
# 根据配置选择输出格式
if getattr(settings, "LOG_FORMAT", "text") == "json":
processors.append(structlog.processors.JSONRenderer())
else:
processors.append(structlog.dev.ConsoleRenderer())
# 配置 structlog
structlog.configure(
processors=processors,
context_class=dict,
logger_factory=LoggerFactory(),
cache_logger_on_first_use=True,
)
# 设置日志
setup_logging()
# 获取日志器
def get_logger(name: str = __name__) -> Any:
"""
获取日志器
Args:
name: 日志器名称
Returns:
日志器实例
"""
return structlog.get_logger(name)
# 默认日志器
logger = get_logger("app")

View File

@@ -0,0 +1,64 @@
"""
中间件定义
"""
import time
import uuid
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.logger import logger
class RequestIDMiddleware(BaseHTTPMiddleware):
"""请求ID中间件"""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 生成请求ID
request_id = str(uuid.uuid4())
# 将请求ID添加到request状态
request.state.request_id = request_id
# 记录请求开始
start_time = time.time()
# 处理请求
response = await call_next(request)
# 计算处理时间
process_time = time.time() - start_time
# 添加响应头
response.headers["X-Request-ID"] = request_id
response.headers["X-Process-Time"] = str(process_time)
# 记录请求日志
logger.info(
"HTTP请求",
method=request.method,
url=str(request.url),
status_code=response.status_code,
process_time=process_time,
request_id=request_id,
)
return response
class GlobalContextMiddleware(BaseHTTPMiddleware):
"""全局上下文中间件"""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 设置追踪ID用于分布式追踪
trace_id = request.headers.get("X-Trace-ID", str(uuid.uuid4()))
request.state.trace_id = trace_id
# 处理请求
response = await call_next(request)
# 添加追踪ID到响应头
response.headers["X-Trace-ID"] = trace_id
return response

44
backend/app/core/redis.py Normal file
View File

@@ -0,0 +1,44 @@
"""
Redis连接管理
"""
from typing import Optional
from redis import asyncio as aioredis
from app.core.config import settings
from app.core.logger import logger
# 全局Redis连接实例
redis_client: Optional[aioredis.Redis] = None
async def init_redis() -> aioredis.Redis:
"""初始化Redis连接"""
global redis_client
try:
redis_client = await aioredis.from_url(
settings.REDIS_URL, encoding="utf-8", decode_responses=True
)
# 测试连接
await redis_client.ping()
logger.info("Redis连接成功", url=settings.REDIS_URL)
return redis_client
except Exception as e:
logger.error("Redis连接失败", error=str(e), url=settings.REDIS_URL)
raise
async def close_redis():
"""关闭Redis连接"""
global redis_client
if redis_client:
await redis_client.close()
logger.info("Redis连接已关闭")
redis_client = None
def get_redis_client() -> aioredis.Redis:
"""获取Redis客户端实例"""
if not redis_client:
raise RuntimeError("Redis client not initialized")
return redis_client

View File

@@ -0,0 +1,72 @@
"""
安全相关功能
"""
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Union
import bcrypt
from jose import JWTError, jwt
from .config import settings
def create_access_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
) -> str:
"""创建访问令牌"""
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode = {"exp": expire, "sub": str(subject), "type": "access"}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
) -> str:
"""创建刷新令牌"""
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def decode_token(token: str) -> Dict[str, Any]:
"""解码令牌"""
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
return payload
except JWTError:
raise ValueError("Invalid token")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
return bcrypt.checkpw(
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
)
def get_password_hash(password: str) -> str:
"""生成密码哈希"""
salt = bcrypt.gensalt()
hashed_password = bcrypt.hashpw(password.encode("utf-8"), salt)
return hashed_password.decode("utf-8")

View File

@@ -0,0 +1,81 @@
"""
简化认证中间件 - 支持 API Key 和长期 Token
用于内部服务间调用
"""
from typing import Optional
from fastapi import HTTPException, Header, status
from app.models.user import User
# 配置 API Keys用于内部服务调用
API_KEYS = {
"internal-service-2025-kaopeilian": {
"service": "internal",
"user_id": 1,
"username": "internal_service",
"role": "admin"
}
}
# 长期有效的 Token用于内部服务调用
LONG_TERM_TOKENS = {
"permanent-token-for-internal-2025": {
"service": "internal",
"user_id": 1,
"username": "internal_service",
"role": "admin"
}
}
def get_current_user_by_api_key(
x_api_key: Optional[str] = Header(None),
authorization: Optional[str] = Header(None)
) -> Optional[User]:
"""
通过 API Key 或长期 Token 获取用户
支持两种方式:
1. X-API-Key: internal-service-2025-kaopeilian
2. Authorization: Bearer permanent-token-for-internal-2025
"""
# 方式1检查 API Key
if x_api_key and x_api_key in API_KEYS:
api_key_info = API_KEYS[x_api_key]
# 创建一个虚拟用户对象
user = User()
user.id = api_key_info["user_id"]
user.username = api_key_info["username"]
user.role = api_key_info["role"]
return user
# 方式2检查长期 Token
if authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if token in LONG_TERM_TOKENS:
token_info = LONG_TERM_TOKENS[token]
user = User()
user.id = token_info["user_id"]
user.username = token_info["username"]
user.role = token_info["role"]
return user
return None
def get_current_user_simple(
x_api_key: Optional[str] = Header(None),
authorization: Optional[str] = Header(None)
) -> User:
"""
简化的用户认证依赖项
"""
# 尝试 API Key 或长期 Token 认证
user = get_current_user_by_api_key(x_api_key, authorization)
if user:
return user
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)

View File

@@ -0,0 +1,421 @@
"""
租户配置加载器
功能:
1. 从数据库 tenant_configs 表加载租户配置
2. 支持 Redis 缓存
3. 数据库不可用时回退到环境变量
4. 支持配置热更新
"""
import os
import json
import logging
from typing import Optional, Dict, Any
from functools import lru_cache
import aiomysql
import redis.asyncio as redis
logger = logging.getLogger(__name__)
# ============================================
# 平台管理库连接配置
#
# 注意:敏感信息必须通过环境变量传递,禁止硬编码
# 参考:瑞小美系统技术栈标准与字符标准.md - 敏感信息管理
# ============================================
ADMIN_DB_CONFIG = {
"host": os.getenv("ADMIN_DB_HOST", "prod-mysql"),
"port": int(os.getenv("ADMIN_DB_PORT", "3306")),
"user": os.getenv("ADMIN_DB_USER", "root"),
"password": os.getenv("ADMIN_DB_PASSWORD"), # 必须从环境变量获取
"db": os.getenv("ADMIN_DB_NAME", "kaopeilian_admin"),
"charset": "utf8mb4",
}
# 校验必填环境变量
if not ADMIN_DB_CONFIG["password"]:
logger.warning(
"ADMIN_DB_PASSWORD 环境变量未设置,租户配置加载功能将不可用。"
"请在 .env.admin 文件中配置此变量。"
)
# Redis 缓存配置
CACHE_PREFIX = "tenant_config:"
CACHE_TTL = 300 # 5分钟缓存
class TenantConfigLoader:
"""租户配置加载器"""
def __init__(self, tenant_code: str, redis_client: Optional[redis.Redis] = None):
"""
初始化租户配置加载器
Args:
tenant_code: 租户编码(如 hua, yy, hl
redis_client: Redis 客户端(可选)
"""
self.tenant_code = tenant_code
self.redis_client = redis_client
self._config_cache: Dict[str, Any] = {}
self._tenant_id: Optional[int] = None
async def get_config(self, config_group: str, config_key: str, default: Any = None) -> Any:
"""
获取配置项
优先级:
1. 内存缓存
2. Redis 缓存
3. 数据库
4. 环境变量
5. 默认值
Args:
config_group: 配置分组database, redis, coze, ai, yanji, security
config_key: 配置键
default: 默认值
Returns:
配置值
"""
cache_key = f"{config_group}.{config_key}"
# 1. 内存缓存
if cache_key in self._config_cache:
return self._config_cache[cache_key]
# 2. Redis 缓存
if self.redis_client:
try:
redis_key = f"{CACHE_PREFIX}{self.tenant_code}:{cache_key}"
cached_value = await self.redis_client.get(redis_key)
if cached_value:
value = json.loads(cached_value)
self._config_cache[cache_key] = value
return value
except Exception as e:
logger.warning(f"Redis 缓存读取失败: {e}")
# 3. 数据库
try:
value = await self._get_from_database(config_group, config_key)
if value is not None:
self._config_cache[cache_key] = value
# 写入 Redis 缓存
if self.redis_client:
try:
redis_key = f"{CACHE_PREFIX}{self.tenant_code}:{cache_key}"
await self.redis_client.setex(
redis_key,
CACHE_TTL,
json.dumps(value)
)
except Exception as e:
logger.warning(f"Redis 缓存写入失败: {e}")
return value
except Exception as e:
logger.warning(f"数据库配置读取失败: {e}")
# 4. 环境变量
env_value = os.getenv(config_key)
if env_value is not None:
return env_value
# 5. 默认值
return default
async def _get_from_database(self, config_group: str, config_key: str) -> Optional[Any]:
"""从数据库获取配置"""
conn = None
try:
conn = await aiomysql.connect(**ADMIN_DB_CONFIG)
async with conn.cursor(aiomysql.DictCursor) as cursor:
# 获取租户 ID
if self._tenant_id is None:
await cursor.execute(
"SELECT id FROM tenants WHERE code = %s AND status = 'active'",
(self.tenant_code,)
)
row = await cursor.fetchone()
if row:
self._tenant_id = row['id']
else:
return None
# 获取配置值
await cursor.execute(
"""
SELECT config_value, value_type, is_encrypted
FROM tenant_configs
WHERE tenant_id = %s AND config_group = %s AND config_key = %s
""",
(self._tenant_id, config_group, config_key)
)
row = await cursor.fetchone()
if row:
return self._parse_value(row['config_value'], row['value_type'], row['is_encrypted'])
# 如果租户没有配置,获取默认值
await cursor.execute(
"""
SELECT default_value, value_type
FROM config_templates
WHERE config_group = %s AND config_key = %s
""",
(config_group, config_key)
)
row = await cursor.fetchone()
if row and row['default_value']:
return self._parse_value(row['default_value'], row['value_type'], False)
return None
finally:
if conn:
conn.close()
def _parse_value(self, value: str, value_type: str, is_encrypted: bool) -> Any:
"""解析配置值"""
if value is None:
return None
# TODO: 如果是加密值,先解密
if is_encrypted:
# 这里可以实现解密逻辑
pass
if value_type == 'int':
return int(value)
elif value_type == 'bool':
return value.lower() in ('true', '1', 'yes')
elif value_type == 'json':
return json.loads(value)
elif value_type == 'float':
return float(value)
else:
return value
async def get_all_configs(self) -> Dict[str, Any]:
"""获取租户的所有配置"""
configs = {}
conn = None
try:
conn = await aiomysql.connect(**ADMIN_DB_CONFIG)
async with conn.cursor(aiomysql.DictCursor) as cursor:
# 获取租户 ID
await cursor.execute(
"SELECT id FROM tenants WHERE code = %s AND status = 'active'",
(self.tenant_code,)
)
row = await cursor.fetchone()
if not row:
return configs
tenant_id = row['id']
# 获取所有配置
await cursor.execute(
"""
SELECT config_group, config_key, config_value, value_type, is_encrypted
FROM tenant_configs
WHERE tenant_id = %s
""",
(tenant_id,)
)
rows = await cursor.fetchall()
for row in rows:
key = f"{row['config_group']}.{row['config_key']}"
configs[key] = self._parse_value(
row['config_value'],
row['value_type'],
row['is_encrypted']
)
return configs
finally:
if conn:
conn.close()
async def refresh_cache(self):
"""刷新缓存"""
self._config_cache.clear()
if self.redis_client:
try:
# 删除该租户的所有缓存
pattern = f"{CACHE_PREFIX}{self.tenant_code}:*"
cursor = 0
while True:
cursor, keys = await self.redis_client.scan(cursor, match=pattern, count=100)
if keys:
await self.redis_client.delete(*keys)
if cursor == 0:
break
except Exception as e:
logger.warning(f"Redis 缓存刷新失败: {e}")
async def is_feature_enabled(self, feature_code: str) -> bool:
"""
检查功能是否启用
Args:
feature_code: 功能编码
Returns:
是否启用
"""
conn = None
try:
conn = await aiomysql.connect(**ADMIN_DB_CONFIG)
async with conn.cursor(aiomysql.DictCursor) as cursor:
# 获取租户 ID
if self._tenant_id is None:
await cursor.execute(
"SELECT id FROM tenants WHERE code = %s AND status = 'active'",
(self.tenant_code,)
)
row = await cursor.fetchone()
if row:
self._tenant_id = row['id']
# 先查租户级别的配置
if self._tenant_id:
await cursor.execute(
"""
SELECT is_enabled FROM feature_switches
WHERE tenant_id = %s AND feature_code = %s
""",
(self._tenant_id, feature_code)
)
row = await cursor.fetchone()
if row:
return bool(row['is_enabled'])
# 再查全局默认配置
await cursor.execute(
"""
SELECT is_enabled FROM feature_switches
WHERE tenant_id IS NULL AND feature_code = %s
""",
(feature_code,)
)
row = await cursor.fetchone()
if row:
return bool(row['is_enabled'])
return True # 默认启用
except Exception as e:
logger.warning(f"功能开关查询失败: {e}, 默认启用")
return True
finally:
if conn:
conn.close()
class TenantConfigManager:
"""租户配置管理器(单例)"""
_instances: Dict[str, TenantConfigLoader] = {}
_redis_client: Optional[redis.Redis] = None
@classmethod
async def init_redis(cls, redis_url: str):
"""初始化 Redis 连接"""
try:
cls._redis_client = redis.from_url(redis_url)
await cls._redis_client.ping()
logger.info("TenantConfigManager Redis 连接成功")
except Exception as e:
logger.warning(f"TenantConfigManager Redis 连接失败: {e}")
cls._redis_client = None
@classmethod
def get_loader(cls, tenant_code: str) -> TenantConfigLoader:
"""获取租户配置加载器"""
if tenant_code not in cls._instances:
cls._instances[tenant_code] = TenantConfigLoader(
tenant_code,
cls._redis_client
)
return cls._instances[tenant_code]
@classmethod
async def refresh_tenant_cache(cls, tenant_code: str):
"""刷新指定租户的缓存"""
if tenant_code in cls._instances:
await cls._instances[tenant_code].refresh_cache()
@classmethod
async def refresh_all_cache(cls):
"""刷新所有租户的缓存"""
for loader in cls._instances.values():
await loader.refresh_cache()
# ============================================
# 辅助函数
# ============================================
def get_tenant_code_from_domain(domain: str) -> str:
"""
从域名提取租户编码
Examples:
hua.ireborn.com.cn -> hua
yy.ireborn.com.cn -> yy
aiedu.ireborn.com.cn -> demo
"""
if not domain:
return "demo"
# 移除 https:// 或 http://
domain = domain.replace("https://", "").replace("http://", "")
# 获取子域名
parts = domain.split(".")
if len(parts) >= 3:
subdomain = parts[0]
# 特殊处理
if subdomain == "aiedu":
return "demo"
return subdomain
return "demo"
async def get_tenant_config(tenant_code: str, config_group: str, config_key: str, default: Any = None) -> Any:
"""
快捷函数:获取租户配置
Args:
tenant_code: 租户编码
config_group: 配置分组
config_key: 配置键
default: 默认值
Returns:
配置值
"""
loader = TenantConfigManager.get_loader(tenant_code)
return await loader.get_config(config_group, config_key, default)
async def is_tenant_feature_enabled(tenant_code: str, feature_code: str) -> bool:
"""
快捷函数:检查租户功能是否启用
Args:
tenant_code: 租户编码
feature_code: 功能编码
Returns:
是否启用
"""
loader = TenantConfigManager.get_loader(tenant_code)
return await loader.is_feature_enabled(feature_code)

140
backend/app/main.py Normal file
View File

@@ -0,0 +1,140 @@
"""考培练系统后端主应用"""
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
import json
import os
from app.core.config import get_settings
from app.api.v1 import api_router
# 配置日志
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
settings = get_settings()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
# 启动时执行
logger.info(f"启动 {settings.APP_NAME} v{settings.APP_VERSION}")
# 初始化 Redis
try:
from app.core.redis import init_redis, close_redis
await init_redis()
logger.info("Redis 初始化成功")
except Exception as e:
logger.warning(f"Redis 初始化失败(非致命): {e}")
yield
# 关闭时执行
try:
from app.core.redis import close_redis
await close_redis()
logger.info("Redis 连接已关闭")
except Exception as e:
logger.warning(f"关闭 Redis 连接失败: {e}")
logger.info("应用关闭")
# 自定义 JSON 响应类,确保中文正确编码
class UTF8JSONResponse(JSONResponse):
def render(self, content) -> bytes:
return json.dumps(
content,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
).encode("utf-8")
# 创建FastAPI应用
app = FastAPI(
title=settings.APP_NAME,
version=settings.APP_VERSION,
description="考培练系统后端API",
lifespan=lifespan,
# 确保响应正确的 UTF-8 编码
default_response_class=UTF8JSONResponse,
)
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 健康检查端点
@app.get("/health")
async def health_check():
"""健康检查"""
return {
"status": "healthy",
"service": settings.APP_NAME,
"version": settings.APP_VERSION,
}
# 根路径
@app.get("/")
async def root():
"""根路径"""
return {
"message": f"欢迎使用{settings.APP_NAME}",
"version": settings.APP_VERSION,
"docs": "/docs",
}
# 注册路由
app.include_router(api_router, prefix="/api/v1")
# 挂载静态文件目录
# 创建上传目录(如果不存在)
upload_path = settings.UPLOAD_PATH
os.makedirs(upload_path, exist_ok=True)
# 挂载上传文件目录为静态文件服务
app.mount("/static/uploads", StaticFiles(directory=upload_path), name="uploads")
# 全局异常处理
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
"""全局异常处理"""
logger.error(f"未处理的异常: {exc}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"code": 500,
"message": "内部服务器错误",
"detail": str(exc) if settings.DEBUG else None,
},
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host=settings.HOST,
port=settings.PORT,
reload=settings.DEBUG,
log_level=settings.LOG_LEVEL.lower(),
)
# 测试热重载 - Fri Sep 26 03:37:07 CST 2025

View File

@@ -0,0 +1,49 @@
"""数据库模型包"""
from app.models.base import Base, BaseModel
from app.models.user import User
from app.models.course import Course, CourseMaterial, KnowledgePoint, GrowthPath
from app.models.training import (
TrainingScene,
TrainingSession,
TrainingMessage,
TrainingReport,
)
from app.models.exam import Exam, Question, ExamResult
from app.models.exam_mistake import ExamMistake
from app.models.position import Position
from app.models.position_member import PositionMember
from app.models.position_course import PositionCourse
from app.models.practice import PracticeScene, PracticeSession, PracticeDialogue, PracticeReport
from app.models.system_log import SystemLog
from app.models.task import Task, TaskCourse, TaskAssignment
from app.models.notification import Notification
__all__ = [
"Base",
"BaseModel",
"User",
"Course",
"CourseMaterial",
"KnowledgePoint",
"GrowthPath",
"TrainingScene",
"TrainingSession",
"TrainingMessage",
"TrainingReport",
"Exam",
"Question",
"ExamResult",
"ExamMistake",
"Position",
"PositionMember",
"PositionCourse",
"PracticeScene",
"PracticeSession",
"PracticeDialogue",
"PracticeReport",
"SystemLog",
"Task",
"TaskCourse",
"TaskAssignment",
"Notification",
]

View File

@@ -0,0 +1,64 @@
"""
能力评估模型
用于存储智能工牌数据分析、练习报告等产生的能力评估结果
"""
from sqlalchemy import Column, Integer, String, DateTime, JSON, ForeignKey, Text
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.models.base import Base
class AbilityAssessment(Base):
"""能力评估历史表"""
__tablename__ = "ability_assessments"
id = Column(Integer, primary_key=True, index=True, comment='主键ID')
user_id = Column(
Integer,
ForeignKey('users.id', ondelete='CASCADE'),
nullable=False,
comment='用户ID'
)
source_type = Column(
String(50),
nullable=False,
comment='数据来源: yanji_badge(智能工牌), practice_report(练习报告), manual(手动评估)'
)
source_id = Column(
String(100),
comment='来源记录ID如录音ID列表逗号分隔'
)
total_score = Column(
Integer,
comment='综合评分(0-100)'
)
ability_dimensions = Column(
JSON,
nullable=False,
comment='6个能力维度评分JSON数组'
)
recommended_courses = Column(
JSON,
comment='推荐课程列表JSON数组'
)
conversation_count = Column(
Integer,
comment='分析的对话数量'
)
analyzed_at = Column(
DateTime,
server_default=func.now(),
comment='分析时间'
)
created_at = Column(
DateTime,
server_default=func.now(),
comment='创建时间'
)
# 关系
# user = relationship("User", back_populates="ability_assessments")
def __repr__(self):
return f"<AbilityAssessment(id={self.id}, user_id={self.user_id}, total_score={self.total_score})>"

View File

@@ -0,0 +1,47 @@
"""基础模型定义"""
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, DateTime, Integer, Boolean, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column
# 创建基础模型类
Base = declarative_base()
class BaseModel(Base):
"""
基础模型类,所有模型都应继承此类
包含通用字段id, created_at, updated_at
时区使用北京时间Asia/Shanghai, UTC+8
"""
__abstract__ = True
__allow_unmapped__ = True # SQLAlchemy 2.0 兼容性
id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
created_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), nullable=False, comment="创建时间(北京时间)"
)
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), onupdate=func.now(), nullable=False, comment="更新时间(北京时间)"
)
class SoftDeleteMixin:
"""软删除混入类"""
__allow_unmapped__ = True
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
deleted_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
class AuditMixin:
"""审计字段混入类"""
__allow_unmapped__ = True
created_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
updated_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)

View File

@@ -0,0 +1,270 @@
"""
课程相关数据库模型
"""
from enum import Enum
from typing import List, Optional
from datetime import datetime
from sqlalchemy import (
String,
Text,
Integer,
Boolean,
ForeignKey,
Enum as SQLEnum,
Float,
JSON,
DateTime,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel, SoftDeleteMixin, AuditMixin
class CourseStatus(str, Enum):
"""课程状态枚举"""
DRAFT = "draft" # 草稿
PUBLISHED = "published" # 已发布
ARCHIVED = "archived" # 已归档
class CourseCategory(str, Enum):
"""课程分类枚举"""
TECHNOLOGY = "technology" # 技术
MANAGEMENT = "management" # 管理
BUSINESS = "business" # 业务
GENERAL = "general" # 通用
class Course(BaseModel, SoftDeleteMixin, AuditMixin):
"""
课程表
"""
__tablename__ = "courses"
# 基本信息
name: Mapped[str] = mapped_column(String(200), nullable=False, comment="课程名称")
description: Mapped[Optional[str]] = mapped_column(
Text, nullable=True, comment="课程描述"
)
category: Mapped[CourseCategory] = mapped_column(
SQLEnum(
CourseCategory,
values_callable=lambda enum_cls: [e.value for e in enum_cls],
validate_strings=True,
),
default=CourseCategory.GENERAL,
nullable=False,
comment="课程分类",
)
status: Mapped[CourseStatus] = mapped_column(
SQLEnum(
CourseStatus,
values_callable=lambda enum_cls: [e.value for e in enum_cls],
validate_strings=True,
),
default=CourseStatus.DRAFT,
nullable=False,
comment="课程状态",
)
# 课程详情
cover_image: Mapped[Optional[str]] = mapped_column(
String(500), nullable=True, comment="封面图片URL"
)
duration_hours: Mapped[Optional[float]] = mapped_column(
Float, nullable=True, comment="课程时长(小时)"
)
difficulty_level: Mapped[Optional[int]] = mapped_column(
Integer, nullable=True, comment="难度等级(1-5)"
)
tags: Mapped[Optional[List[str]]] = mapped_column(
JSON, nullable=True, comment="标签列表"
)
# 发布信息
published_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True, comment="发布时间"
)
publisher_id: Mapped[Optional[int]] = mapped_column(
Integer, nullable=True, comment="发布人ID"
)
# 播课信息
# 播课功能Coze工作流直接写数据库
broadcast_audio_url: Mapped[Optional[str]] = mapped_column(
String(500), nullable=True, comment="播课音频URL"
)
broadcast_generated_at: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True, comment="播课生成时间"
)
# 排序和权重
sort_order: Mapped[int] = mapped_column(
Integer, default=0, nullable=False, comment="排序顺序"
)
is_featured: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=False, comment="是否推荐"
)
# 统计信息
student_count: Mapped[int] = mapped_column(
Integer, default=0, nullable=False, comment="学习人数"
)
is_new: Mapped[bool] = mapped_column(
Boolean, default=True, nullable=False, comment="是否新课程"
)
# 资料下载设置
allow_download: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=False, comment="是否允许下载资料"
)
# 关联关系
materials: Mapped[List["CourseMaterial"]] = relationship(
"CourseMaterial", back_populates="course"
)
knowledge_points: Mapped[List["KnowledgePoint"]] = relationship(
"KnowledgePoint", back_populates="course"
)
# 岗位分配关系(通过关联表)
position_assignments = relationship("PositionCourse", back_populates="course", cascade="all, delete-orphan")
exams = relationship("Exam", back_populates="course")
questions = relationship("Question", back_populates="course")
class CourseMaterial(BaseModel, SoftDeleteMixin, AuditMixin):
"""
课程资料表
"""
__tablename__ = "course_materials"
course_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("courses.id", ondelete="CASCADE"),
nullable=False,
comment="课程ID",
)
name: Mapped[str] = mapped_column(String(200), nullable=False, comment="资料名称")
description: Mapped[Optional[str]] = mapped_column(
Text, nullable=True, comment="资料描述"
)
file_url: Mapped[str] = mapped_column(String(500), nullable=False, comment="文件URL")
file_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="文件类型")
file_size: Mapped[int] = mapped_column(Integer, nullable=False, comment="文件大小(字节)")
# 排序
sort_order: Mapped[int] = mapped_column(
Integer, default=0, nullable=False, comment="排序顺序"
)
# 关联关系
course: Mapped["Course"] = relationship("Course", back_populates="materials")
# 关联的知识点(直接关联)
knowledge_points: Mapped[List["KnowledgePoint"]] = relationship(
"KnowledgePoint", back_populates="material"
)
class KnowledgePoint(BaseModel, SoftDeleteMixin, AuditMixin):
"""
知识点表
"""
__tablename__ = "knowledge_points"
course_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("courses.id", ondelete="CASCADE"),
nullable=False,
comment="课程ID",
)
material_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("course_materials.id", ondelete="CASCADE"),
nullable=False,
comment="关联资料ID",
)
name: Mapped[str] = mapped_column(String(200), nullable=False, comment="知识点名称")
description: Mapped[Optional[str]] = mapped_column(
Text, nullable=True, comment="知识点描述"
)
type: Mapped[str] = mapped_column(
String(50), default="概念定义", nullable=False, comment="知识点类型"
)
source: Mapped[int] = mapped_column(
Integer, default=0, nullable=False, comment="来源0=手动1=AI分析"
)
topic_relation: Mapped[Optional[str]] = mapped_column(
String(200), nullable=True, comment="与主题的关系描述"
)
# 关联关系
course: Mapped["Course"] = relationship("Course", back_populates="knowledge_points")
material: Mapped["CourseMaterial"] = relationship("CourseMaterial")
class GrowthPath(BaseModel, SoftDeleteMixin):
"""
成长路径表
"""
__tablename__ = "growth_paths"
name: Mapped[str] = mapped_column(String(200), nullable=False, comment="路径名称")
description: Mapped[Optional[str]] = mapped_column(
Text, nullable=True, comment="路径描述"
)
target_role: Mapped[Optional[str]] = mapped_column(
String(100), nullable=True, comment="目标角色"
)
# 路径配置
courses: Mapped[Optional[List[dict]]] = mapped_column(
JSON, nullable=True, comment="课程列表[{course_id, order, is_required}]"
)
# 预计时长
estimated_duration_days: Mapped[Optional[int]] = mapped_column(
Integer, nullable=True, comment="预计完成天数"
)
# 状态
is_active: Mapped[bool] = mapped_column(
Boolean, default=True, nullable=False, comment="是否启用"
)
sort_order: Mapped[int] = mapped_column(
Integer, default=0, nullable=False, comment="排序顺序"
)
class MaterialKnowledgePoint(BaseModel, SoftDeleteMixin):
"""
资料知识点关联表
"""
__tablename__ = "material_knowledge_points"
material_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("course_materials.id", ondelete="CASCADE"),
nullable=False,
comment="资料ID",
)
knowledge_point_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("knowledge_points.id", ondelete="CASCADE"),
nullable=False,
comment="知识点ID",
)
sort_order: Mapped[int] = mapped_column(
Integer, default=0, nullable=False, comment="排序顺序"
)
is_ai_generated: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=False, comment="是否AI生成"
)

View File

@@ -0,0 +1,34 @@
"""
课程考试设置模型
"""
from sqlalchemy import Column, Integer, ForeignKey, Boolean
from sqlalchemy.orm import relationship
from app.models.base import BaseModel, SoftDeleteMixin, AuditMixin
class CourseExamSettings(BaseModel, SoftDeleteMixin, AuditMixin):
"""课程考试设置表"""
__tablename__ = "course_exam_settings"
course_id = Column(Integer, ForeignKey("courses.id"), unique=True, nullable=False, comment="课程ID")
# 题型数量设置
single_choice_count = Column(Integer, default=4, nullable=False, comment="单选题数量")
multiple_choice_count = Column(Integer, default=2, nullable=False, comment="多选题数量")
true_false_count = Column(Integer, default=1, nullable=False, comment="判断题数量")
fill_blank_count = Column(Integer, default=2, nullable=False, comment="填空题数量")
essay_count = Column(Integer, default=1, nullable=False, comment="问答题数量")
# 考试参数设置
duration_minutes = Column(Integer, default=10, nullable=False, comment="考试时长(分钟)")
difficulty_level = Column(Integer, default=3, nullable=False, comment="难度系数(1-5)")
passing_score = Column(Integer, default=60, nullable=False, comment="及格分数")
# 其他设置
is_enabled = Column(Boolean, default=True, nullable=False, comment="是否启用")
show_answer_immediately = Column(Boolean, default=False, nullable=False, comment="是否立即显示答案")
allow_retake = Column(Boolean, default=True, nullable=False, comment="是否允许重考")
max_retake_times = Column(Integer, default=3, nullable=True, comment="最大重考次数")
# 关系
course = relationship("Course", backref="exam_settings", uselist=False)

153
backend/app/models/exam.py Normal file
View File

@@ -0,0 +1,153 @@
"""
考试相关模型定义
"""
from datetime import datetime
from typing import List, Optional
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON, Float, func
from sqlalchemy.orm import relationship, Mapped, mapped_column
from app.models.base import BaseModel
class Exam(BaseModel):
"""考试记录模型"""
__tablename__ = "exams"
__allow_unmapped__ = True
id: Mapped[int] = mapped_column(Integer, primary_key=True)
user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True
)
course_id: Mapped[int] = mapped_column(
Integer, ForeignKey("courses.id"), nullable=False, index=True
)
# 考试信息
exam_name: Mapped[str] = mapped_column(String(255), nullable=False)
question_count: Mapped[int] = mapped_column(Integer, default=10)
total_score: Mapped[float] = mapped_column(Float, default=100.0)
pass_score: Mapped[float] = mapped_column(Float, default=60.0)
# 考试时间
start_time: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), comment="开始时间(北京时间)")
end_time: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True, comment="结束时间(北京时间)")
duration_minutes: Mapped[int] = mapped_column(Integer, default=60) # 考试时长(分钟)
# 考试结果
score: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
# 三轮考试得分
round1_score: Mapped[Optional[float]] = mapped_column(Float, nullable=True, comment="第一轮得分")
round2_score: Mapped[Optional[float]] = mapped_column(Float, nullable=True, comment="第二轮得分")
round3_score: Mapped[Optional[float]] = mapped_column(Float, nullable=True, comment="第三轮得分")
is_passed: Mapped[Optional[bool]] = mapped_column(nullable=True)
# 考试状态: started, submitted, timeout
status: Mapped[str] = mapped_column(String(20), default="started", index=True)
# 考试数据JSON格式存储题目和答案
questions: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
answers: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
# 关系
user = relationship("User", back_populates="exams")
course = relationship("Course", back_populates="exams")
results = relationship("ExamResult", back_populates="exam")
def __repr__(self):
return f"<Exam(id={self.id}, user_id={self.user_id}, course_id={self.course_id}, status={self.status})>"
class Question(BaseModel):
"""题目模型"""
__tablename__ = "questions"
__allow_unmapped__ = True
id: Mapped[int] = mapped_column(Integer, primary_key=True)
course_id: Mapped[int] = mapped_column(
Integer, ForeignKey("courses.id"), nullable=False, index=True
)
# 题目类型: single_choice, multiple_choice, true_false, fill_blank, essay
question_type: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
# 题目内容
title: Mapped[str] = mapped_column(Text, nullable=False)
content: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
# 选项JSON格式适用于选择题
options: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
# 答案
correct_answer: Mapped[str] = mapped_column(Text, nullable=False)
explanation: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
# 分值
score: Mapped[float] = mapped_column(Float, default=10.0)
# 难度等级: easy, medium, hard
difficulty: Mapped[str] = mapped_column(String(10), default="medium", index=True)
# 标签JSON格式
tags: Mapped[Optional[list]] = mapped_column(JSON, nullable=True)
# 使用统计
usage_count: Mapped[int] = mapped_column(Integer, default=0)
correct_count: Mapped[int] = mapped_column(Integer, default=0)
# 状态
is_active: Mapped[bool] = mapped_column(default=True, index=True)
# 关系
course = relationship("Course", back_populates="questions")
def __repr__(self):
return f"<Question(id={self.id}, course_id={self.course_id}, type={self.question_type})>"
class ExamResult(BaseModel):
"""考试结果详情模型"""
__tablename__ = "exam_results"
__allow_unmapped__ = True
id: Mapped[int] = mapped_column(Integer, primary_key=True)
exam_id: Mapped[int] = mapped_column(
Integer, ForeignKey("exams.id"), nullable=False, index=True
)
question_id: Mapped[int] = mapped_column(
Integer, ForeignKey("questions.id"), nullable=False
)
# 用户答案
user_answer: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
# 是否正确
is_correct: Mapped[bool] = mapped_column(default=False)
# 得分
score: Mapped[float] = mapped_column(Float, default=0.0)
# 答题时长(秒)
answer_time: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
# 关系
exam = relationship("Exam", back_populates="results")
question = relationship("Question")
def __repr__(self):
return f"<ExamResult(id={self.id}, exam_id={self.exam_id}, question_id={self.question_id}, is_correct={self.is_correct})>"
# 在模型文件末尾添加关系定义
# 需要在User模型中添加
# exams = relationship("Exam", back_populates="user")
# 需要在Course模型中添加
# exams = relationship("Exam", back_populates="course")
# questions = relationship("Question", back_populates="course")
# 需要在Exam模型中添加
# results = relationship("ExamResult", back_populates="exam")

View File

@@ -0,0 +1,43 @@
"""
错题记录模型
"""
from sqlalchemy import Column, Integer, ForeignKey, Text, DateTime, func
from sqlalchemy.orm import relationship
from datetime import datetime
from app.models.base import BaseModel
class ExamMistake(BaseModel):
"""错题记录表"""
__tablename__ = "exam_mistakes"
# 核心关联字段
user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True, comment="用户ID")
exam_id = Column(Integer, ForeignKey("exams.id", ondelete="CASCADE"), nullable=False, index=True, comment="考试ID")
question_id = Column(Integer, ForeignKey("questions.id", ondelete="SET NULL"), nullable=True, index=True, comment="题目IDAI生成的题目可能为空")
knowledge_point_id = Column(Integer, ForeignKey("knowledge_points.id", ondelete="SET NULL"), nullable=True, index=True, comment="关联的知识点ID")
# 题目核心信息
question_content = Column(Text, nullable=False, comment="题目内容")
correct_answer = Column(Text, nullable=False, comment="正确答案")
user_answer = Column(Text, nullable=True, comment="用户答案")
question_type = Column(Text, nullable=True, index=True, comment="题型(single/multiple/judge/blank/essay)")
# 掌握状态和统计字段
mastery_status = Column(Text, nullable=False, default='unmastered', index=True, comment="掌握状态: unmastered-未掌握, mastered-已掌握")
difficulty = Column(Text, nullable=False, default='medium', index=True, comment="题目难度: easy-简单, medium-中等, hard-困难")
wrong_count = Column(Integer, nullable=False, default=1, comment="错误次数统计")
mastered_at = Column(DateTime, nullable=True, comment="标记掌握时间")
# 审计字段继承自BaseModel但这里重写以匹配数据库实际结构
created_at = Column(DateTime, nullable=False, server_default=func.now(), comment="创建时间(北京时间)")
updated_at = Column(DateTime, nullable=False, server_default=func.now(), onupdate=func.now(), comment="更新时间(北京时间)")
# 关系
user = relationship("User", backref="exam_mistakes")
exam = relationship("Exam", backref="mistakes")
question = relationship("Question", backref="mistake_records")
knowledge_point = relationship("KnowledgePoint", backref="mistake_records")
def __repr__(self):
return f"<ExamMistake(id={self.id}, user_id={self.user_id}, exam_id={self.exam_id})>"

View File

@@ -0,0 +1,106 @@
"""
站内消息通知模型
用于记录用户的站内消息通知
"""
from datetime import datetime
from typing import Optional
from sqlalchemy import String, Text, Integer, Boolean, Index, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
class Notification(BaseModel):
"""
站内消息通知模型
用于存储发送给用户的各类站内通知消息,如:
- 岗位分配通知
- 课程分配通知
- 考试提醒通知
- 系统公告通知
"""
__tablename__ = "notifications"
# 接收用户ID外键关联到users表
user_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="接收用户ID"
)
# 通知标题
title: Mapped[str] = mapped_column(
String(200),
nullable=False,
comment="通知标题"
)
# 通知内容
content: Mapped[Optional[str]] = mapped_column(
Text,
nullable=True,
comment="通知内容"
)
# 通知类型
# position_assign: 岗位分配
# course_assign: 课程分配
# exam_remind: 考试提醒
# task_assign: 任务分配
# system: 系统通知
type: Mapped[str] = mapped_column(
String(50),
nullable=False,
default="system",
index=True,
comment="通知类型position_assign/course_assign/exam_remind/task_assign/system"
)
# 是否已读
is_read: Mapped[bool] = mapped_column(
Boolean,
default=False,
nullable=False,
index=True,
comment="是否已读"
)
# 关联数据ID可选如岗位ID、课程ID等
related_id: Mapped[Optional[int]] = mapped_column(
Integer,
nullable=True,
comment="关联数据ID岗位ID/课程ID等"
)
# 关联数据类型可选如position、course等
related_type: Mapped[Optional[str]] = mapped_column(
String(50),
nullable=True,
comment="关联数据类型"
)
# 发送者ID可选系统通知时为空
sender_id: Mapped[Optional[int]] = mapped_column(
Integer,
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
comment="发送者用户ID"
)
# 关联关系
user = relationship("User", foreign_keys=[user_id], backref="notifications")
sender = relationship("User", foreign_keys=[sender_id])
# 创建索引以优化查询性能
__table_args__ = (
Index('idx_notifications_user_read', 'user_id', 'is_read'),
Index('idx_notifications_user_created', 'user_id', 'created_at'),
Index('idx_notifications_type', 'type'),
)
def __repr__(self):
return f"<Notification(id={self.id}, user_id={self.user_id}, title={self.title}, is_read={self.is_read})>"

View File

@@ -0,0 +1,54 @@
"""
岗位Position数据模型
"""
from typing import Optional
from sqlalchemy import String, Integer, Text, ForeignKey, Boolean, JSON
from sqlalchemy.orm import Mapped, mapped_column, relationship
from typing import Optional, List
from .base import BaseModel, SoftDeleteMixin, AuditMixin
class Position(BaseModel, SoftDeleteMixin, AuditMixin):
"""
岗位表
字段说明:
- name: 岗位名称
- code: 岗位编码(唯一),用于稳定引用
- description: 岗位描述
- parent_id: 上级岗位ID支持树形结构
- status: 状态active/inactive
"""
__tablename__ = "positions"
__allow_unmapped__ = True
name: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
code: Mapped[str] = mapped_column(String(100), nullable=False, unique=True, index=True)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
parent_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey("positions.id", ondelete="SET NULL"))
status: Mapped[str] = mapped_column(String(20), default="active", nullable=False)
# 新增字段
skills: Mapped[Optional[List]] = mapped_column(JSON, nullable=True, comment="核心技能")
level: Mapped[Optional[str]] = mapped_column(String(20), nullable=True, comment="岗位等级")
sort_order: Mapped[Optional[int]] = mapped_column(Integer, default=0, nullable=True, comment="排序")
# 关系
parent: Mapped[Optional["Position"]] = relationship(
"Position", remote_side="Position.id", backref="children", lazy="selectin"
)
# 成员关系(通过关联表)
members = relationship("PositionMember", back_populates="position", cascade="all, delete-orphan")
# 课程关系(通过关联表)
courses = relationship("PositionCourse", back_populates="position", cascade="all, delete-orphan")
def __repr__(self) -> str:
return f"<Position(id={self.id}, name={self.name}, code={self.code}, status={self.status})>"

View File

@@ -0,0 +1,28 @@
"""
岗位课程关联模型
"""
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean, Enum, UniqueConstraint
from sqlalchemy.orm import relationship
from app.models.base import BaseModel, SoftDeleteMixin
class PositionCourse(BaseModel, SoftDeleteMixin):
"""岗位课程关联表"""
__tablename__ = "position_courses"
# 添加唯一约束:同一岗位下同一课程只能有一条有效记录
__table_args__ = (
UniqueConstraint('position_id', 'course_id', 'is_deleted', name='uix_position_course'),
)
position_id = Column(Integer, ForeignKey("positions.id"), nullable=False, comment="岗位ID")
course_id = Column(Integer, ForeignKey("courses.id"), nullable=False, comment="课程ID")
# 课程类型required必修、optional选修
course_type = Column(String(20), default="required", nullable=False, comment="课程类型")
priority = Column(Integer, default=0, comment="优先级/排序")
# 关系
position = relationship("Position", back_populates="courses")
course = relationship("Course", back_populates="position_assignments")

View File

@@ -0,0 +1,26 @@
"""
岗位成员关联模型
"""
from datetime import datetime
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean, UniqueConstraint, func
from sqlalchemy.orm import relationship
from app.models.base import BaseModel, SoftDeleteMixin
class PositionMember(BaseModel, SoftDeleteMixin):
"""岗位成员关联表"""
__tablename__ = "position_members"
# 添加唯一约束:同一岗位下同一用户只能有一条有效记录
__table_args__ = (
UniqueConstraint('position_id', 'user_id', 'is_deleted', name='uix_position_user'),
)
position_id = Column(Integer, ForeignKey("positions.id"), nullable=False, comment="岗位ID")
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, comment="用户ID")
role = Column(String(50), comment="成员角色(预留字段)")
joined_at = Column(DateTime, server_default=func.now(), comment="加入时间(北京时间)")
# 关系
position = relationship("Position", back_populates="members")
user = relationship("User", back_populates="position_memberships")

View File

@@ -0,0 +1,109 @@
"""
陪练场景模型
"""
from sqlalchemy import Column, Integer, String, Text, JSON, DECIMAL, Boolean, DateTime, ForeignKey
from sqlalchemy.sql import func
from app.models.base import Base
class PracticeScene(Base):
"""陪练场景模型"""
__tablename__ = "practice_scenes"
id = Column(Integer, primary_key=True, index=True, comment="场景ID")
name = Column(String(200), nullable=False, comment="场景名称")
description = Column(Text, comment="场景描述")
type = Column(String(50), nullable=False, index=True, comment="场景类型: phone/face/complaint/after-sales/product-intro")
difficulty = Column(String(50), nullable=False, index=True, comment="难度等级: beginner/junior/intermediate/senior/expert")
status = Column(String(20), default="active", index=True, comment="状态: active/inactive")
background = Column(Text, comment="场景背景设定")
ai_role = Column(Text, comment="AI角色描述")
objectives = Column(JSON, comment="练习目标数组")
keywords = Column(JSON, comment="关键词数组")
duration = Column(Integer, default=10, comment="预计时长(分钟)")
usage_count = Column(Integer, default=0, comment="使用次数")
rating = Column(DECIMAL(3, 1), default=0.0, comment="评分")
# 审计字段
created_by = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), comment="创建人ID")
updated_by = Column(Integer, ForeignKey("users.id", ondelete="SET NULL"), comment="更新人ID")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
# 软删除字段
is_deleted = Column(Boolean, default=False, index=True, comment="是否删除")
deleted_at = Column(DateTime, comment="删除时间")
def __repr__(self):
return f"<PracticeScene(id={self.id}, name='{self.name}', type='{self.type}', difficulty='{self.difficulty}')>"
class PracticeSession(Base):
"""陪练会话模型"""
__tablename__ = "practice_sessions"
id = Column(Integer, primary_key=True, index=True, comment="会话ID")
session_id = Column(String(50), unique=True, nullable=False, index=True, comment="会话唯一标识")
user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True, comment="学员ID")
scene_id = Column(Integer, ForeignKey("practice_scenes.id", ondelete="SET NULL"), comment="场景ID")
scene_name = Column(String(200), comment="场景名称")
scene_type = Column(String(50), comment="场景类型")
conversation_id = Column(String(100), comment="Coze对话ID")
# 会话时间信息
start_time = Column(DateTime, nullable=False, index=True, comment="开始时间")
end_time = Column(DateTime, comment="结束时间")
duration_seconds = Column(Integer, default=0, comment="时长(秒)")
turns = Column(Integer, default=0, comment="对话轮次")
status = Column(String(20), default="in_progress", index=True, comment="状态: in_progress/completed/canceled")
# 审计字段
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
is_deleted = Column(Boolean, default=False, comment="是否删除")
def __repr__(self):
return f"<PracticeSession(session_id='{self.session_id}', scene='{self.scene_name}', status='{self.status}')>"
class PracticeDialogue(Base):
"""陪练对话记录模型"""
__tablename__ = "practice_dialogues"
id = Column(Integer, primary_key=True, index=True, comment="对话ID")
session_id = Column(String(50), nullable=False, index=True, comment="会话ID")
speaker = Column(String(20), nullable=False, comment="说话人: user/ai")
content = Column(Text, nullable=False, comment="对话内容")
timestamp = Column(DateTime, nullable=False, comment="时间戳")
sequence = Column(Integer, nullable=False, comment="顺序号")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
def __repr__(self):
return f"<PracticeDialogue(session_id='{self.session_id}', speaker='{self.speaker}', seq={self.sequence})>"
class PracticeReport(Base):
"""陪练分析报告模型"""
__tablename__ = "practice_reports"
id = Column(Integer, primary_key=True, index=True, comment="报告ID")
session_id = Column(String(50), unique=True, nullable=False, index=True, comment="会话ID")
# AI分析结果
total_score = Column(Integer, comment="综合得分0-100")
score_breakdown = Column(JSON, comment="分数细分")
ability_dimensions = Column(JSON, comment="能力维度")
dialogue_review = Column(JSON, comment="对话复盘")
suggestions = Column(JSON, comment="改进建议")
# AI分析元数据
workflow_run_id = Column(String(100), comment="AI分析运行ID")
task_id = Column(String(100), comment="AI分析任务ID")
created_at = Column(DateTime, server_default=func.now(), comment="创建时间")
updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now(), comment="更新时间")
def __repr__(self):
return f"<PracticeReport(session_id='{self.session_id}', total_score={self.total_score})>"

View File

@@ -0,0 +1,60 @@
"""
系统日志模型
用于记录系统操作、错误、安全事件等日志信息
"""
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, Index
from sqlalchemy.orm import Mapped, mapped_column
from app.models.base import BaseModel
class SystemLog(BaseModel):
"""
系统日志模型
记录系统各类操作日志
"""
__tablename__ = "system_logs"
# 日志级别: debug, info, warning, error
level: Mapped[str] = mapped_column(String(20), nullable=False, index=True)
# 日志类型: system, user, api, error, security
type: Mapped[str] = mapped_column(String(50), nullable=False, index=True)
# 操作用户(可能为空,如系统自动操作)
user: Mapped[str] = mapped_column(String(100), nullable=True, index=True)
# 用户ID可能为空
user_id: Mapped[int] = mapped_column(Integer, nullable=True, index=True)
# IP地址
ip: Mapped[str] = mapped_column(String(100), nullable=True)
# 日志消息
message: Mapped[str] = mapped_column(Text, nullable=False)
# User Agent
user_agent: Mapped[str] = mapped_column(String(500), nullable=True)
# 请求路径API路径
path: Mapped[str] = mapped_column(String(500), nullable=True, index=True)
# 请求方法
method: Mapped[str] = mapped_column(String(10), nullable=True)
# 额外数据JSON格式可存储详细信息
extra_data: Mapped[str] = mapped_column(Text, nullable=True)
# 创建索引以优化查询性能
__table_args__ = (
Index('idx_system_logs_created_at', 'created_at'),
Index('idx_system_logs_level_type', 'level', 'type'),
Index('idx_system_logs_user_created', 'user', 'created_at'),
)
def __repr__(self):
return f"<SystemLog(id={self.id}, level={self.level}, type={self.type}, user={self.user})>"

100
backend/app/models/task.py Normal file
View File

@@ -0,0 +1,100 @@
"""
任务相关模型
"""
from datetime import datetime
from typing import List, Optional
from sqlalchemy import Column, Integer, String, Text, DateTime, Enum as SQLEnum, JSON, Boolean, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
from enum import Enum
class TaskPriority(str, Enum):
"""任务优先级"""
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
class TaskStatus(str, Enum):
"""任务状态"""
PENDING = "pending" # 待开始
ONGOING = "ongoing" # 进行中
COMPLETED = "completed" # 已完成
EXPIRED = "expired" # 已过期
class AssignmentStatus(str, Enum):
"""分配状态"""
NOT_STARTED = "not_started"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
class Task(BaseModel):
"""任务表"""
__tablename__ = "tasks"
title: Mapped[str] = mapped_column(String(200), nullable=False, comment="任务标题")
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True, comment="任务描述")
priority: Mapped[TaskPriority] = mapped_column(
SQLEnum(TaskPriority, values_callable=lambda x: [e.value for e in x]),
default=TaskPriority.MEDIUM,
nullable=False,
comment="优先级"
)
status: Mapped[TaskStatus] = mapped_column(
SQLEnum(TaskStatus, values_callable=lambda x: [e.value for e in x]),
default=TaskStatus.PENDING,
nullable=False,
comment="任务状态"
)
creator_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, comment="创建人ID")
deadline: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True, comment="截止时间")
requirements: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, comment="任务要求配置")
progress: Mapped[int] = mapped_column(Integer, default=0, nullable=False, comment="完成进度")
is_deleted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
# 关系
creator = relationship("User", backref="created_tasks", foreign_keys=[creator_id])
course_links = relationship("TaskCourse", back_populates="task", cascade="all, delete-orphan")
assignments = relationship("TaskAssignment", back_populates="task", cascade="all, delete-orphan")
class TaskCourse(BaseModel):
"""任务课程关联表"""
__tablename__ = "task_courses"
task_id: Mapped[int] = mapped_column(Integer, ForeignKey("tasks.id"), nullable=False, comment="任务ID")
course_id: Mapped[int] = mapped_column(Integer, ForeignKey("courses.id"), nullable=False, comment="课程ID")
# 关系
task = relationship("Task", back_populates="course_links")
course = relationship("Course")
class TaskAssignment(BaseModel):
"""任务分配表"""
__tablename__ = "task_assignments"
task_id: Mapped[int] = mapped_column(Integer, ForeignKey("tasks.id"), nullable=False, comment="任务ID")
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, comment="分配用户ID")
team_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, comment="团队ID")
status: Mapped[AssignmentStatus] = mapped_column(
SQLEnum(AssignmentStatus, values_callable=lambda x: [e.value for e in x]),
default=AssignmentStatus.NOT_STARTED,
nullable=False,
comment="完成状态"
)
progress: Mapped[int] = mapped_column(Integer, default=0, nullable=False, comment="个人完成进度")
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True, comment="完成时间")
# 关系
task = relationship("Task", back_populates="assignments")
user = relationship("User")
__all__ = ["Task", "TaskCourse", "TaskAssignment", "TaskPriority", "TaskStatus", "AssignmentStatus"]

View File

@@ -0,0 +1,263 @@
"""陪练模块数据模型"""
from datetime import datetime
from typing import Optional
from enum import Enum
from sqlalchemy import (
Column,
String,
Integer,
ForeignKey,
Text,
JSON,
Enum as SQLEnum,
Float,
Boolean,
DateTime,
func,
)
from sqlalchemy.orm import relationship, Mapped, mapped_column
from app.models.base import BaseModel, SoftDeleteMixin, AuditMixin
class TrainingSceneStatus(str, Enum):
"""陪练场景状态枚举"""
DRAFT = "draft" # 草稿
ACTIVE = "active" # 已激活
INACTIVE = "inactive" # 已停用
class TrainingSessionStatus(str, Enum):
"""陪练会话状态枚举"""
CREATED = "created" # 已创建
IN_PROGRESS = "in_progress" # 进行中
COMPLETED = "completed" # 已完成
CANCELLED = "cancelled" # 已取消
ERROR = "error" # 异常结束
class MessageType(str, Enum):
"""消息类型枚举"""
TEXT = "text" # 文本消息
VOICE = "voice" # 语音消息
SYSTEM = "system" # 系统消息
class MessageRole(str, Enum):
"""消息角色枚举"""
USER = "user" # 用户
ASSISTANT = "assistant" # AI助手
SYSTEM = "system" # 系统
class TrainingScene(BaseModel, SoftDeleteMixin, AuditMixin):
"""
陪练场景模型
定义不同的陪练场景,如面试训练、演讲训练等
"""
__tablename__ = "training_scenes"
__allow_unmapped__ = True
# 基础信息
name: Mapped[str] = mapped_column(String(100), nullable=False, comment="场景名称")
description: Mapped[Optional[str]] = mapped_column(
Text, nullable=True, comment="场景描述"
)
category: Mapped[str] = mapped_column(String(50), nullable=False, comment="场景分类")
# 配置信息
ai_config: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True, comment="AI配置如Coze Bot ID等"
)
prompt_template: Mapped[Optional[str]] = mapped_column(
Text, nullable=True, comment="提示词模板"
)
evaluation_criteria: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True, comment="评估标准"
)
# 状态和权限
status: Mapped[TrainingSceneStatus] = mapped_column(
SQLEnum(TrainingSceneStatus),
default=TrainingSceneStatus.DRAFT,
nullable=False,
comment="场景状态",
)
is_public: Mapped[bool] = mapped_column(
Boolean, default=True, nullable=False, comment="是否公开"
)
required_level: Mapped[Optional[int]] = mapped_column(
Integer, nullable=True, comment="所需用户等级"
)
# 关联
sessions: Mapped[list["TrainingSession"]] = relationship(
"TrainingSession", back_populates="scene", cascade="all, delete-orphan"
)
class TrainingSession(BaseModel, AuditMixin):
"""
陪练会话模型
记录每次陪练会话的信息
"""
__tablename__ = "training_sessions"
__allow_unmapped__ = True
# 基础信息
user_id: Mapped[int] = mapped_column(
Integer, nullable=False, index=True, comment="用户ID"
)
scene_id: Mapped[int] = mapped_column(
Integer, ForeignKey("training_scenes.id"), nullable=False, comment="场景ID"
)
# 会话信息
coze_conversation_id: Mapped[Optional[str]] = mapped_column(
String(100), nullable=True, comment="Coze会话ID"
)
start_time: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), nullable=False, comment="开始时间(北京时间)"
)
end_time: Mapped[Optional[datetime]] = mapped_column(
DateTime, nullable=True, comment="结束时间(北京时间)"
)
duration_seconds: Mapped[Optional[int]] = mapped_column(
Integer, nullable=True, comment="持续时长(秒)"
)
# 状态和配置
status: Mapped[TrainingSessionStatus] = mapped_column(
SQLEnum(TrainingSessionStatus),
default=TrainingSessionStatus.CREATED,
nullable=False,
comment="会话状态",
)
session_config: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True, comment="会话配置"
)
# 评估信息
total_score: Mapped[Optional[float]] = mapped_column(
Float, nullable=True, comment="总分"
)
evaluation_result: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True, comment="评估结果详情"
)
# 关联
scene: Mapped["TrainingScene"] = relationship(
"TrainingScene", back_populates="sessions"
)
messages: Mapped[list["TrainingMessage"]] = relationship(
"TrainingMessage",
back_populates="session",
cascade="all, delete-orphan",
order_by="TrainingMessage.created_at",
)
report: Mapped[Optional["TrainingReport"]] = relationship(
"TrainingReport", back_populates="session", uselist=False
)
class TrainingMessage(BaseModel):
"""
陪练消息模型
记录会话中的每条消息
"""
__tablename__ = "training_messages"
__allow_unmapped__ = True
# 基础信息
session_id: Mapped[int] = mapped_column(
Integer, ForeignKey("training_sessions.id"), nullable=False, comment="会话ID"
)
# 消息内容
role: Mapped[MessageRole] = mapped_column(
SQLEnum(MessageRole), nullable=False, comment="消息角色"
)
type: Mapped[MessageType] = mapped_column(
SQLEnum(MessageType), nullable=False, comment="消息类型"
)
content: Mapped[str] = mapped_column(Text, nullable=False, comment="消息内容")
# 语音消息相关
voice_url: Mapped[Optional[str]] = mapped_column(
String(500), nullable=True, comment="语音文件URL"
)
voice_duration: Mapped[Optional[float]] = mapped_column(
Float, nullable=True, comment="语音时长(秒)"
)
# 元数据
message_metadata: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True, comment="消息元数据"
)
coze_message_id: Mapped[Optional[str]] = mapped_column(
String(100), nullable=True, comment="Coze消息ID"
)
# 关联
session: Mapped["TrainingSession"] = relationship(
"TrainingSession", back_populates="messages"
)
class TrainingReport(BaseModel, AuditMixin):
"""
陪练报告模型
存储陪练会话的分析报告
"""
__tablename__ = "training_reports"
__allow_unmapped__ = True
# 基础信息
session_id: Mapped[int] = mapped_column(
Integer,
ForeignKey("training_sessions.id"),
unique=True,
nullable=False,
comment="会话ID",
)
user_id: Mapped[int] = mapped_column(
Integer, nullable=False, index=True, comment="用户ID"
)
# 评分信息
overall_score: Mapped[float] = mapped_column(Float, nullable=False, comment="总体得分")
dimension_scores: Mapped[dict] = mapped_column(
JSON, nullable=False, comment="各维度得分"
)
# 分析内容
strengths: Mapped[list[str]] = mapped_column(JSON, nullable=False, comment="优势点")
weaknesses: Mapped[list[str]] = mapped_column(JSON, nullable=False, comment="待改进点")
suggestions: Mapped[list[str]] = mapped_column(JSON, nullable=False, comment="改进建议")
# 详细内容
detailed_analysis: Mapped[Optional[str]] = mapped_column(
Text, nullable=True, comment="详细分析"
)
transcript: Mapped[Optional[str]] = mapped_column(
Text, nullable=True, comment="对话文本记录"
)
# 统计信息
statistics: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True, comment="统计数据"
)
# 关联
session: Mapped["TrainingSession"] = relationship(
"TrainingSession", back_populates="report"
)

171
backend/app/models/user.py Normal file
View File

@@ -0,0 +1,171 @@
"""
用户相关数据模型
"""
from datetime import datetime
from typing import List, Optional
from sqlalchemy import (
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
Table,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from .base import Base, BaseModel, SoftDeleteMixin
# 用户-团队关联表(用于多对多关系)
user_teams = Table(
"user_teams",
BaseModel.metadata,
Column(
"user_id", Integer, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
),
Column(
"team_id", Integer, ForeignKey("teams.id", ondelete="CASCADE"), primary_key=True
),
Column("role", String(50), default="member", nullable=False), # member, leader
Column("joined_at", DateTime, server_default=func.now(), nullable=False),
UniqueConstraint("user_id", "team_id", name="uq_user_team"),
)
class UserTeam(Base):
"""用户团队关联模型(用于直接查询关联表)"""
__allow_unmapped__ = True
__table__ = user_teams # 重用已定义的表
# 定义列映射不需要id因为使用复合主键
user_id: Mapped[int]
team_id: Mapped[int]
role: Mapped[str]
joined_at: Mapped[datetime]
def __repr__(self) -> str:
return f"<UserTeam(user_id={self.user_id}, team_id={self.team_id}, role={self.role})>"
class User(BaseModel, SoftDeleteMixin):
"""用户模型"""
__allow_unmapped__ = True
__tablename__ = "users"
# 基础信息
username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
email: Mapped[Optional[str]] = mapped_column(String(100), unique=True, nullable=True)
phone: Mapped[Optional[str]] = mapped_column(String(20), unique=True, nullable=True)
hashed_password: Mapped[str] = mapped_column(
"password_hash", String(200), nullable=False
)
# 个人信息
full_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
avatar_url: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
bio: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
# 性别: male/female可扩展
gender: Mapped[Optional[str]] = mapped_column(String(10), nullable=True)
# 学校
school: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
# 专业
major: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)
# 企微员工userid用于SCRM系统对接
wework_userid: Mapped[Optional[str]] = mapped_column(String(64), unique=True, nullable=True, comment="企微员工userid")
# 系统角色admin, manager, trainee
role: Mapped[str] = mapped_column(String(20), default="trainee", nullable=False)
# 账号状态
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
is_verified: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
# 时间记录
last_login_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
password_changed_at: Mapped[Optional[datetime]] = mapped_column(
DateTime, nullable=True
)
# 关联关系
teams: Mapped[List["Team"]] = relationship(
"Team",
secondary=user_teams,
back_populates="members",
lazy="selectin",
)
exams = relationship("Exam", back_populates="user")
# 岗位关系(通过关联表)
position_memberships = relationship("PositionMember", back_populates="user", cascade="all, delete-orphan")
def __repr__(self) -> str:
return f"<User(id={self.id}, username={self.username}, role={self.role})>"
class Team(BaseModel, SoftDeleteMixin):
"""团队模型"""
__allow_unmapped__ = True
__tablename__ = "teams"
# 基础信息
name: Mapped[str] = mapped_column(String(100), unique=True, nullable=False)
code: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
# 团队类型department, project, study_group
team_type: Mapped[str] = mapped_column(
String(50), default="department", nullable=False
)
# 状态
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
# 团队负责人
leader_id: Mapped[Optional[int]] = mapped_column(
Integer, ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
# 父团队(支持层级结构)
parent_id: Mapped[Optional[int]] = mapped_column(
Integer, ForeignKey("teams.id", ondelete="CASCADE"), nullable=True
)
# 关联关系
members: Mapped[List["User"]] = relationship(
"User",
secondary=user_teams,
back_populates="teams",
lazy="selectin",
)
leader: Mapped[Optional["User"]] = relationship(
"User",
foreign_keys=[leader_id],
lazy="selectin",
)
parent: Mapped[Optional["Team"]] = relationship(
"Team",
remote_side="Team.id",
foreign_keys=[parent_id],
lazy="selectin",
)
children: Mapped[List["Team"]] = relationship(
"Team",
back_populates="parent",
lazy="selectin",
)
def __repr__(self) -> str:
return f"<Team(id={self.id}, name={self.name}, code={self.code})>"

View File

@@ -0,0 +1 @@
"""Pydantic模式包"""

View File

@@ -0,0 +1,50 @@
"""
能力评估相关的Pydantic Schema
"""
from pydantic import BaseModel, Field
from typing import List, Optional
from datetime import datetime
class AbilityDimension(BaseModel):
"""能力维度评分"""
name: str = Field(..., description="能力维度名称")
score: int = Field(..., ge=0, le=100, description="评分(0-100)")
feedback: str = Field(..., description="反馈建议")
class CourseRecommendation(BaseModel):
"""课程推荐"""
course_id: int = Field(..., description="课程ID")
course_name: str = Field(..., description="课程名称")
recommendation_reason: str = Field(..., description="推荐理由")
priority: str = Field(..., description="优先级: high/medium/low")
match_score: int = Field(..., ge=0, le=100, description="匹配度(0-100)")
class AbilityAssessmentResponse(BaseModel):
"""能力评估响应"""
assessment_id: int = Field(..., description="评估记录ID")
total_score: int = Field(..., ge=0, le=100, description="综合评分")
dimensions: List[AbilityDimension] = Field(..., description="能力维度列表")
recommended_courses: List[CourseRecommendation] = Field(..., description="推荐课程列表")
conversation_count: int = Field(..., description="分析的对话数量")
analyzed_at: Optional[datetime] = Field(None, description="分析时间")
class AbilityAssessmentHistory(BaseModel):
"""能力评估历史记录"""
id: int
user_id: int
source_type: str
source_id: Optional[str]
total_score: Optional[int]
ability_dimensions: List[AbilityDimension]
recommended_courses: Optional[List[CourseRecommendation]]
conversation_count: Optional[int]
analyzed_at: datetime
created_at: datetime
class Config:
from_attributes = True

View File

@@ -0,0 +1,35 @@
"""
认证相关 Schema
"""
from pydantic import EmailStr, Field
from .base import BaseSchema
class LoginRequest(BaseSchema):
"""登录请求"""
username: str = Field(..., description="用户名/邮箱/手机号")
password: str = Field(..., min_length=6)
class Token(BaseSchema):
"""令牌响应"""
access_token: str
refresh_token: str
token_type: str = "bearer"
class TokenPayload(BaseSchema):
"""令牌载荷"""
sub: str # 用户ID
type: str # access 或 refresh
exp: int # 过期时间
class RefreshTokenRequest(BaseSchema):
"""刷新令牌请求"""
refresh_token: str

View File

@@ -0,0 +1,73 @@
"""基础响应模式"""
from typing import Generic, TypeVar, Optional, Any, List
from pydantic import BaseModel, Field
from datetime import datetime
DataT = TypeVar("DataT")
class ResponseModel(BaseModel, Generic[DataT]):
"""
统一响应格式模型
"""
code: int = Field(default=200, description="响应状态码")
message: str = Field(default="success", description="响应消息")
data: Optional[DataT] = Field(default=None, description="响应数据")
request_id: Optional[str] = Field(default=None, description="请求ID")
class BaseSchema(BaseModel):
"""基础模式"""
class Config:
from_attributes = True # Pydantic V2
json_encoders = {datetime: lambda v: v.isoformat()}
class TimestampMixin(BaseModel):
"""时间戳混入"""
created_at: datetime
updated_at: datetime
class IDMixin(BaseModel):
"""ID混入"""
id: int
class PaginationParams(BaseModel):
"""分页参数"""
page: int = Field(default=1, ge=1, description="页码")
page_size: int = Field(default=20, ge=1, le=100, description="每页数量")
@property
def offset(self) -> int:
"""计算偏移量"""
return (self.page - 1) * self.page_size
@property
def limit(self) -> int:
"""计算限制数量"""
return self.page_size
class PaginatedResponse(BaseModel, Generic[DataT]):
"""分页响应模型"""
items: list[DataT] = Field(default_factory=list, description="数据列表")
total: int = Field(default=0, description="总数量")
page: int = Field(default=1, description="当前页码")
page_size: int = Field(default=20, description="每页数量")
pages: int = Field(default=1, description="总页数")
@classmethod
def create(cls, items: list[DataT], total: int, page: int, page_size: int):
"""创建分页响应"""
pages = (total + page_size - 1) // page_size if page_size > 0 else 1
return cls(
items=items, total=total, page=page, page_size=page_size, pages=pages
)

View File

@@ -0,0 +1,364 @@
"""
课程相关的数据验证模型
"""
from typing import Optional, List
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field, ConfigDict, field_validator
from app.models.course import CourseStatus, CourseCategory
class CourseBase(BaseModel):
"""
课程基础模型
"""
name: str = Field(..., min_length=1, max_length=200, description="课程名称")
description: Optional[str] = Field(None, description="课程描述")
category: CourseCategory = Field(default=CourseCategory.GENERAL, description="课程分类")
cover_image: Optional[str] = Field(None, max_length=500, description="封面图片URL")
duration_hours: Optional[float] = Field(None, ge=0, description="课程时长(小时)")
difficulty_level: Optional[int] = Field(None, ge=1, le=5, description="难度等级(1-5)")
tags: Optional[List[str]] = Field(default_factory=list, description="标签列表")
sort_order: int = Field(default=0, description="排序顺序")
is_featured: bool = Field(default=False, description="是否推荐")
allow_download: bool = Field(default=False, description="是否允许下载资料")
@field_validator("category", mode="before")
@classmethod
def normalize_category(cls, v):
"""允许使用枚举的名称或值(忽略大小写)。空字符串使用默认值。"""
if isinstance(v, CourseCategory):
return v
if isinstance(v, str):
s = v.strip()
# 空字符串使用默认值
if not s:
return CourseCategory.GENERAL
# 优先按值匹配technology 等)
try:
return CourseCategory(s.lower())
except Exception:
pass
# 再按名称匹配TECHNOLOGY 等)
try:
return CourseCategory[s.upper()]
except Exception:
pass
return v
class CourseCreate(CourseBase):
"""
创建课程模型
"""
status: CourseStatus = Field(default=CourseStatus.DRAFT, description="课程状态")
class CourseUpdate(BaseModel):
"""
更新课程模型
"""
name: Optional[str] = Field(None, min_length=1, max_length=200, description="课程名称")
description: Optional[str] = Field(None, description="课程描述")
category: Optional[CourseCategory] = Field(None, description="课程分类")
status: Optional[CourseStatus] = Field(None, description="课程状态")
cover_image: Optional[str] = Field(None, max_length=500, description="封面图片URL")
duration_hours: Optional[float] = Field(None, ge=0, description="课程时长(小时)")
difficulty_level: Optional[int] = Field(None, ge=1, le=5, description="难度等级(1-5)")
tags: Optional[List[str]] = Field(None, description="标签列表")
sort_order: Optional[int] = Field(None, description="排序顺序")
is_featured: Optional[bool] = Field(None, description="是否推荐")
allow_download: Optional[bool] = Field(None, description="是否允许下载资料")
@field_validator("category", mode="before")
@classmethod
def normalize_category_update(cls, v):
if v is None:
return v
if isinstance(v, CourseCategory):
return v
if isinstance(v, str):
s = v.strip()
if not s: # 空字符串视为None不更新
return None
try:
return CourseCategory(s.lower())
except Exception:
pass
try:
return CourseCategory[s.upper()]
except Exception:
pass
return v
class CourseInDB(CourseBase):
"""
数据库中的课程模型
"""
model_config = ConfigDict(from_attributes=True)
id: int = Field(..., description="课程ID")
status: CourseStatus = Field(..., description="课程状态")
created_at: datetime = Field(..., description="创建时间")
updated_at: datetime = Field(..., description="更新时间")
published_at: Optional[datetime] = Field(None, description="发布时间")
publisher_id: Optional[int] = Field(None, description="发布人ID")
created_by: Optional[int] = Field(None, description="创建人ID")
updated_by: Optional[int] = Field(None, description="更新人ID")
# 用户岗位相关的课程类型(必修/选修非数据库字段由API动态计算
course_type: Optional[str] = Field(None, description="课程类型required=必修, optional=选修")
class CourseList(BaseModel):
"""
课程列表查询参数
"""
status: Optional[CourseStatus] = Field(None, description="课程状态")
category: Optional[CourseCategory] = Field(None, description="课程分类")
is_featured: Optional[bool] = Field(None, description="是否推荐")
keyword: Optional[str] = Field(None, description="搜索关键词")
# 课程资料相关模型
class CourseMaterialBase(BaseModel):
"""
课程资料基础模型
"""
name: str = Field(..., min_length=1, max_length=200, description="资料名称")
description: Optional[str] = Field(None, description="资料描述")
sort_order: int = Field(default=0, description="排序顺序")
class CourseMaterialCreate(CourseMaterialBase):
"""
创建课程资料模型
"""
file_url: str = Field(..., max_length=500, description="文件URL")
file_type: str = Field(..., max_length=50, description="文件类型")
file_size: int = Field(..., gt=0, description="文件大小(字节)")
@field_validator("file_type")
def validate_file_type(cls, v):
"""验证文件类型
支持格式TXT、Markdown、MDX、PDF、HTML、Excel、Word、CSV、VTT、Properties
"""
allowed_types = [
"txt", "md", "mdx", "pdf", "html", "htm",
"xlsx", "xls", "docx", "doc", "csv", "vtt", "properties"
]
file_ext = v.lower()
if file_ext not in allowed_types:
raise ValueError(f"不支持的文件类型: {v}。允许的类型: TXT、Markdown、MDX、PDF、HTML、Excel、Word、CSV、VTT、Properties")
return file_ext
class CourseMaterialInDB(CourseMaterialBase):
"""
数据库中的课程资料模型
"""
model_config = ConfigDict(from_attributes=True)
id: int = Field(..., description="资料ID")
course_id: int = Field(..., description="课程ID")
file_url: str = Field(..., description="文件URL")
file_type: str = Field(..., description="文件类型")
file_size: int = Field(..., description="文件大小(字节)")
created_at: datetime = Field(..., description="创建时间")
updated_at: datetime = Field(..., description="更新时间")
# 知识点相关模型
class KnowledgePointBase(BaseModel):
"""
知识点基础模型
"""
name: str = Field(..., min_length=1, max_length=200, description="知识点名称")
description: Optional[str] = Field(None, description="知识点描述")
type: str = Field(default="理论知识", description="知识点类型")
source: int = Field(default=0, description="来源0=手动1=AI分析")
topic_relation: Optional[str] = Field(None, description="与主题的关系描述")
class KnowledgePointCreate(KnowledgePointBase):
"""
创建知识点模型
"""
material_id: int = Field(..., description="关联资料ID必填")
class KnowledgePointUpdate(BaseModel):
"""
更新知识点模型
"""
name: Optional[str] = Field(None, min_length=1, max_length=200, description="知识点名称")
description: Optional[str] = Field(None, description="知识点描述")
type: Optional[str] = Field(None, description="知识点类型")
source: Optional[int] = Field(None, description="来源0=手动1=AI分析")
topic_relation: Optional[str] = Field(None, description="与主题的关系描述")
material_id: int = Field(..., description="关联资料ID必填")
class KnowledgePointInDB(KnowledgePointBase):
"""
数据库中的知识点模型
"""
model_config = ConfigDict(from_attributes=True)
id: int = Field(..., description="知识点ID")
course_id: int = Field(..., description="课程ID")
material_id: int = Field(..., description="关联资料ID")
created_at: datetime = Field(..., description="创建时间")
updated_at: datetime = Field(..., description="更新时间")
class KnowledgePointTree(KnowledgePointInDB):
"""
知识点树形结构
"""
children: List["KnowledgePointTree"] = Field(
default_factory=list, description="子知识点"
)
# 成长路径相关模型
class GrowthPathCourse(BaseModel):
"""
成长路径中的课程
"""
course_id: int = Field(..., description="课程ID")
order: int = Field(..., ge=0, description="排序")
is_required: bool = Field(default=True, description="是否必修")
class GrowthPathBase(BaseModel):
"""
成长路径基础模型
"""
name: str = Field(..., min_length=1, max_length=200, description="路径名称")
description: Optional[str] = Field(None, description="路径描述")
target_role: Optional[str] = Field(None, max_length=100, description="目标角色")
courses: List[GrowthPathCourse] = Field(default_factory=list, description="课程列表")
estimated_duration_days: Optional[int] = Field(None, ge=1, description="预计完成天数")
is_active: bool = Field(default=True, description="是否启用")
sort_order: int = Field(default=0, description="排序顺序")
class GrowthPathCreate(GrowthPathBase):
"""
创建成长路径模型
"""
pass
class GrowthPathInDB(GrowthPathBase):
"""
数据库中的成长路径模型
"""
model_config = ConfigDict(from_attributes=True)
id: int = Field(..., description="路径ID")
created_at: datetime = Field(..., description="创建时间")
updated_at: datetime = Field(..., description="更新时间")
# 课程考试设置相关Schema
class CourseExamSettingsBase(BaseModel):
"""
课程考试设置基础模型
"""
single_choice_count: int = Field(default=4, ge=0, le=50, description="单选题数量")
multiple_choice_count: int = Field(default=2, ge=0, le=30, description="多选题数量")
true_false_count: int = Field(default=1, ge=0, le=20, description="判断题数量")
fill_blank_count: int = Field(default=2, ge=0, le=10, description="填空题数量")
essay_count: int = Field(default=1, ge=0, le=10, description="问答题数量")
duration_minutes: int = Field(default=10, ge=10, le=180, description="考试时长(分钟)")
difficulty_level: int = Field(default=3, ge=1, le=5, description="难度系数(1-5)")
passing_score: int = Field(default=60, ge=0, le=100, description="及格分数")
is_enabled: bool = Field(default=True, description="是否启用")
show_answer_immediately: bool = Field(default=False, description="是否立即显示答案")
allow_retake: bool = Field(default=True, description="是否允许重考")
max_retake_times: Optional[int] = Field(None, ge=1, le=10, description="最大重考次数")
class CourseExamSettingsCreate(CourseExamSettingsBase):
"""
创建课程考试设置模型
"""
pass
class CourseExamSettingsUpdate(BaseModel):
"""
更新课程考试设置模型
"""
single_choice_count: Optional[int] = Field(None, ge=0, le=50, description="单选题数量")
multiple_choice_count: Optional[int] = Field(None, ge=0, le=30, description="多选题数量")
true_false_count: Optional[int] = Field(None, ge=0, le=20, description="判断题数量")
fill_blank_count: Optional[int] = Field(None, ge=0, le=10, description="填空题数量")
essay_count: Optional[int] = Field(None, ge=0, le=10, description="问答题数量")
duration_minutes: Optional[int] = Field(None, ge=10, le=180, description="考试时长(分钟)")
difficulty_level: Optional[int] = Field(None, ge=1, le=5, description="难度系数(1-5)")
passing_score: Optional[int] = Field(None, ge=0, le=100, description="及格分数")
is_enabled: Optional[bool] = Field(None, description="是否启用")
show_answer_immediately: Optional[bool] = Field(None, description="是否立即显示答案")
allow_retake: Optional[bool] = Field(None, description="是否允许重考")
max_retake_times: Optional[int] = Field(None, ge=1, le=10, description="最大重考次数")
class CourseExamSettingsInDB(CourseExamSettingsBase):
"""
数据库中的课程考试设置模型
"""
model_config = ConfigDict(from_attributes=True)
id: int = Field(..., description="设置ID")
course_id: int = Field(..., description="课程ID")
created_at: datetime = Field(..., description="创建时间")
updated_at: datetime = Field(..., description="更新时间")
# 岗位分配相关Schema
class CoursePositionAssignment(BaseModel):
"""
课程岗位分配模型
"""
position_id: int = Field(..., description="岗位ID")
course_type: str = Field(default="required", pattern="^(required|optional)$", description="课程类型required必修/optional选修")
priority: int = Field(default=0, description="优先级/排序")
class CoursePositionAssignmentInDB(CoursePositionAssignment):
"""
数据库中的课程岗位分配模型
"""
model_config = ConfigDict(from_attributes=True)
id: int = Field(..., description="分配ID")
course_id: int = Field(..., description="课程ID")
position_name: Optional[str] = Field(None, description="岗位名称")
position_description: Optional[str] = Field(None, description="岗位描述")
member_count: Optional[int] = Field(None, description="岗位成员数")

316
backend/app/schemas/exam.py Normal file
View File

@@ -0,0 +1,316 @@
"""
考试相关的Schema定义
"""
from typing import List, Optional, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field
class StartExamRequest(BaseModel):
"""开始考试请求"""
course_id: int = Field(..., description="课程ID")
count: int = Field(10, ge=1, le=100, description="题目数量")
class StartExamResponse(BaseModel):
"""开始考试响应"""
exam_id: int = Field(..., description="考试ID")
class ExamAnswer(BaseModel):
"""考试答案"""
question_id: str = Field(..., description="题目ID")
answer: str = Field(..., description="答案")
class SubmitExamRequest(BaseModel):
"""提交考试请求"""
exam_id: int = Field(..., description="考试ID")
answers: List[ExamAnswer] = Field(..., description="答案列表")
class SubmitExamResponse(BaseModel):
"""提交考试响应"""
exam_id: int = Field(..., description="考试ID")
total_score: float = Field(..., description="总分")
pass_score: float = Field(..., description="及格分")
is_passed: bool = Field(..., description="是否通过")
correct_count: int = Field(..., description="正确题数")
total_count: int = Field(..., description="总题数")
accuracy: float = Field(..., description="正确率")
class QuestionInfo(BaseModel):
"""题目信息"""
id: str = Field(..., description="题目ID")
type: str = Field(..., description="题目类型")
title: str = Field(..., description="题目标题")
content: Optional[str] = Field(None, description="题目内容")
options: Optional[Dict[str, Any]] = Field(None, description="选项")
score: float = Field(..., description="分值")
class ExamResultInfo(BaseModel):
"""答题结果信息"""
question_id: int = Field(..., description="题目ID")
user_answer: Optional[str] = Field(None, description="用户答案")
is_correct: bool = Field(..., description="是否正确")
score: float = Field(..., description="得分")
class ExamDetailResponse(BaseModel):
"""考试详情响应"""
id: int = Field(..., description="考试ID")
course_id: int = Field(..., description="课程ID")
exam_name: str = Field(..., description="考试名称")
question_count: int = Field(..., description="题目数量")
total_score: float = Field(..., description="总分")
pass_score: float = Field(..., description="及格分")
start_time: Optional[str] = Field(None, description="开始时间")
end_time: Optional[str] = Field(None, description="结束时间")
duration_minutes: int = Field(..., description="考试时长(分钟)")
status: str = Field(..., description="考试状态")
score: Optional[float] = Field(None, description="得分")
is_passed: Optional[bool] = Field(None, description="是否通过")
questions: Optional[Dict[str, Any]] = Field(None, description="题目数据")
results: Optional[List[ExamResultInfo]] = Field(None, description="答题结果")
answers: Optional[Dict[str, Any]] = Field(None, description="用户答案")
class ExamRecordInfo(BaseModel):
"""考试记录信息"""
id: int = Field(..., description="考试ID")
course_id: int = Field(..., description="课程ID")
exam_name: str = Field(..., description="考试名称")
question_count: int = Field(..., description="题目数量")
total_score: float = Field(..., description="总分")
score: Optional[float] = Field(None, description="得分")
is_passed: Optional[bool] = Field(None, description="是否通过")
status: str = Field(..., description="考试状态")
start_time: Optional[str] = Field(None, description="开始时间")
end_time: Optional[str] = Field(None, description="结束时间")
created_at: str = Field(..., description="创建时间")
# 新增统计字段
accuracy: Optional[float] = Field(None, description="正确率(%)")
correct_count: Optional[int] = Field(None, description="正确题数")
wrong_count: Optional[int] = Field(None, description="错题数")
duration_seconds: Optional[int] = Field(None, description="考试用时(秒)")
course_name: Optional[str] = Field(None, description="课程名称")
question_type_stats: Optional[List[Dict[str, Any]]] = Field(None, description="分题型统计")
class ExamRecordResponse(BaseModel):
"""考试记录列表响应"""
items: List[ExamRecordInfo] = Field(..., description="考试记录列表")
total: int = Field(..., description="总数")
page: int = Field(..., description="当前页")
size: int = Field(..., description="每页数量")
pages: int = Field(..., description="总页数")
# ==================== AI服务响应Schema ====================
class MistakeRecord(BaseModel):
"""错题记录详情"""
question_id: Optional[int] = Field(None, description="题目ID")
knowledge_point_id: Optional[int] = Field(None, description="知识点ID")
question_content: str = Field(..., description="题目内容")
correct_answer: str = Field(..., description="正确答案")
user_answer: str = Field(..., description="用户答案")
class GenerateExamRequest(BaseModel):
"""生成考试试题请求"""
course_id: int = Field(..., description="课程ID")
position_id: Optional[int] = Field(None, description="岗位ID,如果不提供则从用户信息中自动获取")
current_round: int = Field(1, ge=1, le=3, description="当前轮次(1/2/3)")
exam_id: Optional[int] = Field(None, description="已存在的exam_id(第2、3轮传入)")
mistake_records: Optional[str] = Field(None, description="错题记录JSON字符串,第一轮不传此参数,第二三轮传入上一轮错题的JSON字符串")
single_choice_count: int = Field(4, ge=0, le=50, description="单选题数量")
multiple_choice_count: int = Field(2, ge=0, le=30, description="多选题数量")
true_false_count: int = Field(1, ge=0, le=20, description="判断题数量")
fill_blank_count: int = Field(2, ge=0, le=10, description="填空题数量")
essay_count: int = Field(1, ge=0, le=10, description="问答题数量")
difficulty_level: int = Field(3, ge=1, le=5, description="难度系数(1-5)")
class GenerateExamResponse(BaseModel):
"""生成考试试题响应"""
result: str = Field(..., description="试题JSON数组(字符串格式)")
workflow_run_id: Optional[str] = Field(None, description="AI服务调用ID")
task_id: Optional[str] = Field(None, description="任务ID")
exam_id: int = Field(..., description="考试ID真实的数据库ID")
class JudgeAnswerRequest(BaseModel):
"""判断主观题答案请求"""
question: str = Field(..., description="题目内容")
correct_answer: str = Field(..., description="标准答案")
user_answer: str = Field(..., description="用户提交的答案")
analysis: str = Field(..., description="正确答案的解析(来源于试题生成器)")
class JudgeAnswerResponse(BaseModel):
"""判断主观题答案响应"""
is_correct: bool = Field(..., description="是否正确")
correct_answer: str = Field(..., description="标准答案")
feedback: Optional[str] = Field(None, description="判断反馈信息")
class RecordMistakeRequest(BaseModel):
"""记录错题请求"""
exam_id: int = Field(..., description="考试ID")
question_id: Optional[int] = Field(None, description="题目ID(AI生成的题目可能为空)")
knowledge_point_id: Optional[int] = Field(None, description="知识点ID")
question_content: str = Field(..., description="题目内容")
correct_answer: str = Field(..., description="正确答案")
user_answer: str = Field(..., description="用户答案")
question_type: Optional[str] = Field(None, description="题型(single/multiple/judge/blank/essay)")
class RecordMistakeResponse(BaseModel):
"""记录错题响应"""
id: int = Field(..., description="错题记录ID")
created_at: datetime = Field(..., description="创建时间")
class MistakeRecordItem(BaseModel):
"""错题记录项"""
id: int = Field(..., description="错题记录ID")
question_id: Optional[int] = Field(None, description="题目ID")
knowledge_point_id: Optional[int] = Field(None, description="知识点ID")
question_content: str = Field(..., description="题目内容")
correct_answer: str = Field(..., description="正确答案")
user_answer: str = Field(..., description="用户答案")
created_at: datetime = Field(..., description="创建时间")
class GetMistakesResponse(BaseModel):
"""获取错题记录响应"""
mistakes: List[MistakeRecordItem] = Field(..., description="错题列表")
# ==================== 成绩报告和错题本相关Schema ====================
class RoundScores(BaseModel):
"""三轮得分"""
round1: Optional[float] = Field(None, description="第一轮得分")
round2: Optional[float] = Field(None, description="第二轮得分")
round3: Optional[float] = Field(None, description="第三轮得分")
class ExamReportOverview(BaseModel):
"""成绩报告概览"""
avg_score: float = Field(..., description="平均成绩(基于round1_score)")
total_exams: int = Field(..., description="考试总数")
pass_rate: float = Field(..., description="及格率")
total_questions: int = Field(..., description="答题总数")
class ExamTrendItem(BaseModel):
"""成绩趋势项"""
date: str = Field(..., description="日期(YYYY-MM-DD)")
avg_score: float = Field(..., description="平均分")
class SubjectStatItem(BaseModel):
"""科目统计项"""
course_id: int = Field(..., description="课程ID")
course_name: str = Field(..., description="课程名称")
avg_score: float = Field(..., description="平均分")
exam_count: int = Field(..., description="考试次数")
max_score: float = Field(..., description="最高分")
min_score: float = Field(..., description="最低分")
pass_rate: float = Field(..., description="及格率")
class RecentExamItem(BaseModel):
"""最近考试记录项"""
id: int = Field(..., description="考试ID")
course_id: int = Field(..., description="课程ID")
course_name: str = Field(..., description="课程名称")
score: Optional[float] = Field(None, description="最终得分")
total_score: float = Field(..., description="总分")
is_passed: Optional[bool] = Field(None, description="是否通过")
duration_seconds: Optional[int] = Field(None, description="考试用时(秒)")
start_time: str = Field(..., description="开始时间")
end_time: Optional[str] = Field(None, description="结束时间")
round_scores: RoundScores = Field(..., description="三轮得分")
class ExamReportResponse(BaseModel):
"""成绩报告响应"""
overview: ExamReportOverview = Field(..., description="概览数据")
trends: List[ExamTrendItem] = Field(..., description="趋势数据")
subjects: List[SubjectStatItem] = Field(..., description="科目分析")
recent_exams: List[RecentExamItem] = Field(..., description="最近考试记录")
class MistakeListItem(BaseModel):
"""错题列表项"""
id: int = Field(..., description="错题记录ID")
exam_id: int = Field(..., description="考试ID")
course_id: int = Field(..., description="课程ID")
course_name: str = Field(..., description="课程名称")
question_content: str = Field(..., description="题目内容")
correct_answer: str = Field(..., description="正确答案")
user_answer: str = Field(..., description="用户答案")
question_type: Optional[str] = Field(None, description="题型")
knowledge_point_id: Optional[int] = Field(None, description="知识点ID")
knowledge_point_name: Optional[str] = Field(None, description="知识点名称")
created_at: datetime = Field(..., description="创建时间")
class MistakeListResponse(BaseModel):
"""错题列表响应"""
items: List[MistakeListItem] = Field(..., description="错题列表")
total: int = Field(..., description="总数")
page: int = Field(..., description="当前页")
size: int = Field(..., description="每页数量")
pages: int = Field(..., description="总页数")
class MistakeByCourse(BaseModel):
"""按课程统计错题"""
course_id: int = Field(..., description="课程ID")
course_name: str = Field(..., description="课程名称")
count: int = Field(..., description="错题数量")
class MistakeByType(BaseModel):
"""按题型统计错题"""
type: str = Field(..., description="题型代码")
type_name: str = Field(..., description="题型名称")
count: int = Field(..., description="错题数量")
class MistakeByTime(BaseModel):
"""按时间统计错题"""
week: int = Field(..., description="最近一周")
month: int = Field(..., description="最近一月")
quarter: int = Field(..., description="最近三月")
class MistakesStatisticsResponse(BaseModel):
"""错题统计响应"""
total: int = Field(..., description="错题总数")
by_course: List[MistakeByCourse] = Field(..., description="按课程统计")
by_type: List[MistakeByType] = Field(..., description="按题型统计")
by_time: MistakeByTime = Field(..., description="按时间统计")
class UpdateRoundScoreRequest(BaseModel):
"""更新轮次得分请求"""
round: int = Field(..., ge=1, le=3, description="轮次(1/2/3)")
score: float = Field(..., ge=0, le=100, description="得分")
is_final: bool = Field(False, description="是否为最终轮次(如果是,则同时更新总分和状态)")

View File

@@ -0,0 +1,102 @@
"""
站内消息通知相关的数据验证模型
"""
from typing import Optional, List
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field, ConfigDict
class NotificationType(str, Enum):
"""通知类型枚举"""
POSITION_ASSIGN = "position_assign" # 岗位分配
COURSE_ASSIGN = "course_assign" # 课程分配
EXAM_REMIND = "exam_remind" # 考试提醒
TASK_ASSIGN = "task_assign" # 任务分配
SYSTEM = "system" # 系统通知
class NotificationBase(BaseModel):
"""
通知基础模型
"""
title: str = Field(..., min_length=1, max_length=200, description="通知标题")
content: Optional[str] = Field(None, description="通知内容")
type: NotificationType = Field(default=NotificationType.SYSTEM, description="通知类型")
related_id: Optional[int] = Field(None, description="关联数据ID")
related_type: Optional[str] = Field(None, max_length=50, description="关联数据类型")
class NotificationCreate(NotificationBase):
"""
创建通知模型
"""
user_id: int = Field(..., description="接收用户ID")
sender_id: Optional[int] = Field(None, description="发送者用户ID")
class NotificationBatchCreate(BaseModel):
"""
批量创建通知模型(发送给多个用户)
"""
user_ids: List[int] = Field(..., min_length=1, description="接收用户ID列表")
title: str = Field(..., min_length=1, max_length=200, description="通知标题")
content: Optional[str] = Field(None, description="通知内容")
type: NotificationType = Field(default=NotificationType.SYSTEM, description="通知类型")
related_id: Optional[int] = Field(None, description="关联数据ID")
related_type: Optional[str] = Field(None, max_length=50, description="关联数据类型")
sender_id: Optional[int] = Field(None, description="发送者用户ID")
class NotificationUpdate(BaseModel):
"""
更新通知模型
"""
is_read: Optional[bool] = Field(None, description="是否已读")
class NotificationInDB(NotificationBase):
"""
数据库中的通知模型
"""
model_config = ConfigDict(from_attributes=True)
id: int
user_id: int
is_read: bool
sender_id: Optional[int] = None
created_at: datetime
updated_at: datetime
class NotificationResponse(NotificationInDB):
"""
通知响应模型(可扩展发送者信息)
"""
sender_name: Optional[str] = Field(None, description="发送者姓名")
class NotificationListResponse(BaseModel):
"""
通知列表响应模型
"""
items: List[NotificationResponse]
total: int
unread_count: int
class NotificationCountResponse(BaseModel):
"""
未读通知数量响应模型
"""
unread_count: int
total: int
class MarkReadRequest(BaseModel):
"""
标记已读请求模型
"""
notification_ids: Optional[List[int]] = Field(None, description="通知ID列表为空则标记全部已读")

View File

@@ -0,0 +1,318 @@
"""
陪练功能相关Schema定义
"""
from typing import Optional, List
from datetime import datetime
from pydantic import BaseModel, Field, field_validator
# ==================== 枚举类型 ====================
class SceneType:
"""场景类型枚举"""
PHONE = "phone" # 电话销售
FACE = "face" # 面对面销售
COMPLAINT = "complaint" # 客户投诉
AFTER_SALES = "after-sales" # 售后服务
PRODUCT_INTRO = "product-intro" # 产品介绍
class Difficulty:
"""难度等级枚举"""
BEGINNER = "beginner" # 入门
JUNIOR = "junior" # 初级
INTERMEDIATE = "intermediate" # 中级
SENIOR = "senior" # 高级
EXPERT = "expert" # 专家
class SceneStatus:
"""场景状态枚举"""
ACTIVE = "active" # 启用
INACTIVE = "inactive" # 禁用
# ==================== 场景Schema ====================
class PracticeSceneBase(BaseModel):
"""陪练场景基础Schema"""
name: str = Field(..., max_length=200, description="场景名称")
description: Optional[str] = Field(None, description="场景描述")
type: str = Field(..., description="场景类型: phone/face/complaint/after-sales/product-intro")
difficulty: str = Field(..., description="难度等级: beginner/junior/intermediate/senior/expert")
status: str = Field(default="active", description="状态: active/inactive")
background: str = Field(..., description="场景背景设定")
ai_role: str = Field(..., description="AI角色描述")
objectives: List[str] = Field(..., description="练习目标数组")
keywords: Optional[List[str]] = Field(default=None, description="关键词数组")
duration: int = Field(default=10, ge=1, le=120, description="预计时长(分钟)")
@field_validator('type')
@classmethod
def validate_type(cls, v):
"""验证场景类型"""
valid_types = ['phone', 'face', 'complaint', 'after-sales', 'product-intro']
if v not in valid_types:
raise ValueError(f"场景类型必须是: {', '.join(valid_types)}")
return v
@field_validator('difficulty')
@classmethod
def validate_difficulty(cls, v):
"""验证难度等级"""
valid_difficulties = ['beginner', 'junior', 'intermediate', 'senior', 'expert']
if v not in valid_difficulties:
raise ValueError(f"难度等级必须是: {', '.join(valid_difficulties)}")
return v
@field_validator('status')
@classmethod
def validate_status(cls, v):
"""验证状态"""
valid_statuses = ['active', 'inactive']
if v not in valid_statuses:
raise ValueError(f"状态必须是: {', '.join(valid_statuses)}")
return v
@field_validator('objectives')
@classmethod
def validate_objectives(cls, v):
"""验证练习目标"""
if not v or len(v) < 1:
raise ValueError("至少需要1个练习目标")
if len(v) > 10:
raise ValueError("练习目标不能超过10个")
return v
class PracticeSceneCreate(PracticeSceneBase):
"""创建陪练场景Schema"""
pass
class PracticeSceneUpdate(BaseModel):
"""更新陪练场景Schema所有字段可选"""
name: Optional[str] = Field(None, max_length=200, description="场景名称")
description: Optional[str] = Field(None, description="场景描述")
type: Optional[str] = Field(None, description="场景类型")
difficulty: Optional[str] = Field(None, description="难度等级")
status: Optional[str] = Field(None, description="状态")
background: Optional[str] = Field(None, description="场景背景设定")
ai_role: Optional[str] = Field(None, description="AI角色描述")
objectives: Optional[List[str]] = Field(None, description="练习目标数组")
keywords: Optional[List[str]] = Field(None, description="关键词数组")
duration: Optional[int] = Field(None, ge=1, le=120, description="预计时长(分钟)")
class PracticeSceneResponse(PracticeSceneBase):
"""陪练场景响应Schema"""
id: int
usage_count: int
rating: float
created_by: Optional[int] = None
updated_by: Optional[int] = None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# ==================== 对话Schema ====================
class StartPracticeRequest(BaseModel):
"""开始陪练对话请求Schema"""
# 场景信息(首次消息必填,后续消息可选)
scene_id: Optional[int] = Field(None, description="场景ID可选")
scene_name: Optional[str] = Field(None, description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述")
scene_background: Optional[str] = Field(None, description="场景背景")
scene_ai_role: Optional[str] = Field(None, description="AI角色")
scene_objectives: Optional[List[str]] = Field(None, description="练习目标")
scene_keywords: Optional[List[str]] = Field(None, description="关键词")
# 对话信息
user_message: str = Field(..., description="用户消息")
conversation_id: Optional[str] = Field(None, description="对话ID续接对话时必填")
is_first: bool = Field(..., description="是否首次消息")
@field_validator('scene_name')
@classmethod
def validate_scene_name_for_first(cls, v, info):
"""首次消息时场景名称必填"""
if info.data.get('is_first') and not v:
raise ValueError("首次消息时场景名称必填")
return v
@field_validator('scene_background')
@classmethod
def validate_scene_background_for_first(cls, v, info):
"""首次消息时场景背景必填"""
if info.data.get('is_first') and not v:
raise ValueError("首次消息时场景背景必填")
return v
@field_validator('scene_ai_role')
@classmethod
def validate_scene_ai_role_for_first(cls, v, info):
"""首次消息时AI角色必填"""
if info.data.get('is_first') and not v:
raise ValueError("首次消息时AI角色必填")
return v
@field_validator('scene_objectives')
@classmethod
def validate_scene_objectives_for_first(cls, v, info):
"""首次消息时练习目标必填"""
if info.data.get('is_first') and (not v or len(v) == 0):
raise ValueError("首次消息时练习目标必填")
return v
class InterruptPracticeRequest(BaseModel):
"""中断对话请求Schema"""
conversation_id: str = Field(..., description="对话ID")
chat_id: str = Field(..., description="聊天ID")
class ConversationInfo(BaseModel):
"""对话信息Schema"""
id: str = Field(..., description="对话ID")
name: str = Field(..., description="对话名称")
created_at: int = Field(..., description="创建时间(时间戳)")
class ConversationsResponse(BaseModel):
"""对话列表响应Schema"""
items: List[ConversationInfo]
has_more: bool
page: int
size: int
# ==================== 场景提取Schema ====================
class ExtractSceneRequest(BaseModel):
"""提取场景请求Schema"""
course_id: int = Field(..., description="课程ID")
class ExtractedSceneData(BaseModel):
"""提取的场景数据Schema"""
name: str = Field(..., description="场景名称")
description: str = Field(..., description="场景描述")
type: str = Field(..., description="场景类型")
difficulty: str = Field(..., description="难度等级")
background: str = Field(..., description="场景背景")
ai_role: str = Field(..., description="AI角色描述")
objectives: List[str] = Field(..., description="练习目标数组")
keywords: Optional[List[str]] = Field(default=[], description="关键词数组")
class ExtractSceneResponse(BaseModel):
"""提取场景响应Schema"""
scene: ExtractedSceneData = Field(..., description="场景数据")
workflow_run_id: str = Field(..., description="工作流运行ID")
task_id: str = Field(..., description="任务ID")
# ==================== 陪练会话Schema ====================
class PracticeSessionCreate(BaseModel):
"""创建陪练会话请求Schema"""
scene_id: Optional[int] = Field(None, description="场景ID")
scene_name: str = Field(..., description="场景名称")
scene_type: Optional[str] = Field(None, description="场景类型")
conversation_id: Optional[str] = Field(None, description="Coze对话ID")
class PracticeSessionResponse(BaseModel):
"""陪练会话响应Schema"""
id: int
session_id: str
user_id: int
scene_id: Optional[int]
scene_name: str
scene_type: Optional[str]
conversation_id: Optional[str]
start_time: datetime
end_time: Optional[datetime]
duration_seconds: int
turns: int
status: str
created_at: datetime
class Config:
from_attributes = True
class SaveDialogueRequest(BaseModel):
"""保存对话记录请求Schema"""
session_id: str = Field(..., description="会话ID")
speaker: str = Field(..., description="说话人: user/ai")
content: str = Field(..., description="对话内容")
sequence: int = Field(..., ge=1, description="顺序号从1开始")
class PracticeDialogueResponse(BaseModel):
"""对话记录响应Schema"""
id: int
session_id: str
speaker: str
content: str
timestamp: datetime
sequence: int
class Config:
from_attributes = True
# ==================== 分析报告Schema ====================
class ScoreBreakdownItem(BaseModel):
"""分数细分项"""
name: str
score: int = Field(..., ge=0, le=100)
description: str
class AbilityDimensionItem(BaseModel):
"""能力维度项"""
name: str
score: int = Field(..., ge=0, le=100)
feedback: str
class DialogueReviewItem(BaseModel):
"""对话复盘项"""
speaker: str
time: str
content: str
tags: List[str] = Field(default_factory=list)
comment: str = Field(default="")
class SuggestionItem(BaseModel):
"""改进建议项"""
title: str
content: str
example: Optional[str] = None
class PracticeAnalysisResult(BaseModel):
"""陪练分析结果Schema"""
total_score: int = Field(..., ge=0, le=100, description="综合得分")
score_breakdown: List[ScoreBreakdownItem] = Field(..., description="分数细分")
ability_dimensions: List[AbilityDimensionItem] = Field(..., description="能力维度")
dialogue_review: List[DialogueReviewItem] = Field(..., description="对话复盘")
suggestions: List[SuggestionItem] = Field(..., description="改进建议")
class PracticeReportResponse(BaseModel):
"""陪练报告响应Schema"""
session_info: PracticeSessionResponse
analysis: PracticeAnalysisResult
class Config:
from_attributes = True

128
backend/app/schemas/scrm.py Normal file
View File

@@ -0,0 +1,128 @@
"""
SCRM 系统对接 API Schema 定义
用于 SCRM 系统调用考陪练系统的数据查询接口
"""
from typing import List, Optional
from pydantic import BaseModel, Field
from datetime import datetime
# ==================== 通用响应 ====================
class SCRMBaseResponse(BaseModel):
"""SCRM API 通用响应基类"""
code: int = Field(default=0, description="响应码0=成功")
message: str = Field(default="success", description="响应消息")
# ==================== 1. 获取员工岗位 ====================
class PositionInfo(BaseModel):
"""岗位信息"""
position_id: int = Field(..., description="岗位ID")
position_name: str = Field(..., description="岗位名称")
is_primary: bool = Field(default=True, description="是否主岗位")
joined_at: Optional[str] = Field(None, description="加入时间")
class EmployeePositionData(BaseModel):
"""员工岗位数据"""
employee_id: int = Field(..., description="员工ID")
userid: Optional[str] = Field(None, description="企微员工userid可能为空")
name: str = Field(..., description="员工姓名")
positions: List[PositionInfo] = Field(default=[], description="岗位列表")
class EmployeePositionResponse(SCRMBaseResponse):
"""获取员工岗位响应"""
data: Optional[EmployeePositionData] = None
# ==================== 2. 获取岗位课程 ====================
class CourseInfo(BaseModel):
"""课程信息"""
course_id: int = Field(..., description="课程ID")
course_name: str = Field(..., description="课程名称")
course_type: str = Field(..., description="课程类型required/optional")
priority: int = Field(default=0, description="优先级")
knowledge_point_count: int = Field(default=0, description="知识点数量")
class PositionCoursesData(BaseModel):
"""岗位课程数据"""
position_id: int = Field(..., description="岗位ID")
position_name: str = Field(..., description="岗位名称")
courses: List[CourseInfo] = Field(default=[], description="课程列表")
class PositionCoursesResponse(SCRMBaseResponse):
"""获取岗位课程响应"""
data: Optional[PositionCoursesData] = None
# ==================== 3. 搜索知识点 ====================
class KnowledgePointSearchRequest(BaseModel):
"""搜索知识点请求"""
keywords: List[str] = Field(..., min_length=1, description="搜索关键词列表")
position_id: Optional[int] = Field(None, description="岗位ID用于优先排序")
course_ids: Optional[List[int]] = Field(None, description="限定课程范围")
knowledge_type: Optional[str] = Field(None, description="知识点类型筛选")
limit: int = Field(default=10, ge=1, le=100, description="返回数量")
class KnowledgePointBrief(BaseModel):
"""知识点简要信息"""
knowledge_point_id: int = Field(..., description="知识点ID")
name: str = Field(..., description="知识点名称")
course_id: int = Field(..., description="课程ID")
course_name: str = Field(..., description="课程名称")
type: str = Field(..., description="知识点类型")
relevance_score: float = Field(default=1.0, description="相关度分数")
class KnowledgePointSearchData(BaseModel):
"""知识点搜索结果数据"""
total: int = Field(..., description="匹配总数")
items: List[KnowledgePointBrief] = Field(default=[], description="知识点列表")
class KnowledgePointSearchResponse(SCRMBaseResponse):
"""搜索知识点响应"""
data: Optional[KnowledgePointSearchData] = None
# ==================== 4. 获取知识点详情 ====================
class KnowledgePointDetailData(BaseModel):
"""知识点详情数据"""
knowledge_point_id: int = Field(..., description="知识点ID")
name: str = Field(..., description="知识点名称")
course_id: int = Field(..., description="课程ID")
course_name: str = Field(..., description="课程名称")
type: str = Field(..., description="知识点类型")
content: str = Field(..., description="知识点完整内容description")
material_id: Optional[int] = Field(None, description="关联的课程资料ID")
material_type: Optional[str] = Field(None, description="资料文件类型")
material_url: Optional[str] = Field(None, description="资料文件URL")
topic_relation: Optional[str] = Field(None, description="与主题的关系描述")
source: int = Field(default=0, description="来源0=手动创建1=AI分析生成")
created_at: Optional[str] = Field(None, description="创建时间")
class KnowledgePointDetailResponse(SCRMBaseResponse):
"""获取知识点详情响应"""
data: Optional[KnowledgePointDetailData] = None
# ==================== 错误响应 ====================
class SCRMErrorResponse(SCRMBaseResponse):
"""错误响应"""
code: int = Field(..., description="错误码")
message: str = Field(..., description="错误消息")
data: None = None

View File

@@ -0,0 +1,59 @@
"""
系统日志 Schema
"""
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
class SystemLogBase(BaseModel):
"""系统日志基础Schema"""
level: str = Field(..., description="日志级别: debug, info, warning, error")
type: str = Field(..., description="日志类型: system, user, api, error, security")
user: Optional[str] = Field(None, description="操作用户")
user_id: Optional[int] = Field(None, description="用户ID")
ip: Optional[str] = Field(None, description="IP地址")
message: str = Field(..., description="日志消息")
user_agent: Optional[str] = Field(None, description="User Agent")
path: Optional[str] = Field(None, description="请求路径")
method: Optional[str] = Field(None, description="请求方法")
extra_data: Optional[str] = Field(None, description="额外数据JSON格式")
class SystemLogCreate(SystemLogBase):
"""创建系统日志Schema"""
pass
class SystemLogResponse(SystemLogBase):
"""系统日志响应Schema"""
id: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class SystemLogQuery(BaseModel):
"""系统日志查询参数"""
level: Optional[str] = Field(None, description="日志级别筛选")
type: Optional[str] = Field(None, description="日志类型筛选")
user: Optional[str] = Field(None, description="用户筛选")
keyword: Optional[str] = Field(None, description="关键词搜索搜索message字段")
start_date: Optional[datetime] = Field(None, description="开始日期")
end_date: Optional[datetime] = Field(None, description="结束日期")
page: int = Field(1, ge=1, description="页码")
page_size: int = Field(20, ge=1, le=100, description="每页数量")
class SystemLogListResponse(BaseModel):
"""系统日志列表响应"""
items: list[SystemLogResponse]
total: int
page: int
page_size: int
total_pages: int

View File

@@ -0,0 +1,67 @@
"""
任务相关Schema
"""
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
class TaskBase(BaseModel):
"""任务基础Schema"""
title: str = Field(..., description="任务标题")
description: Optional[str] = Field(None, description="任务描述")
priority: str = Field("medium", description="优先级(low/medium/high)")
deadline: Optional[datetime] = Field(None, description="截止时间")
requirements: Optional[dict] = Field(None, description="任务要求配置")
course_ids: List[int] = Field(default_factory=list, description="关联课程ID列表")
user_ids: List[int] = Field(default_factory=list, description="分配用户ID列表")
class TaskCreate(TaskBase):
"""创建任务"""
pass
class TaskUpdate(BaseModel):
"""更新任务"""
title: Optional[str] = None
description: Optional[str] = None
priority: Optional[str] = None
status: Optional[str] = None
deadline: Optional[datetime] = None
requirements: Optional[dict] = None
progress: Optional[int] = None
class TaskResponse(BaseModel):
"""任务响应"""
id: int
title: str
description: Optional[str]
priority: str
status: str
creator_id: int
deadline: Optional[datetime]
requirements: Optional[dict]
progress: int
created_at: datetime
updated_at: datetime
# 扩展字段
courses: List[str] = Field(default_factory=list, description="课程名称列表")
assigned_count: int = Field(0, description="分配人数")
completed_count: int = Field(0, description="完成人数")
class Config:
from_attributes = True
class TaskStatsResponse(BaseModel):
"""任务统计响应"""
total: int = Field(0, description="总任务数")
ongoing: int = Field(0, description="进行中")
completed: int = Field(0, description="已完成")
expired: int = Field(0, description="已过期")
avg_completion_rate: float = Field(0.0, description="平均完成率")

View File

@@ -0,0 +1,260 @@
"""陪练模块Pydantic模式"""
from typing import Optional, List, Dict, Any, Generic, TypeVar
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict
# 定义泛型类型变量
DataT = TypeVar("DataT")
from app.models.training import (
TrainingSceneStatus,
TrainingSessionStatus,
MessageType,
MessageRole,
)
from app.schemas.base import BaseSchema, TimestampMixin, IDMixin
# ========== 陪练场景相关 ==========
class TrainingSceneBase(BaseSchema):
"""陪练场景基础模式"""
name: str = Field(..., max_length=100, description="场景名称")
description: Optional[str] = Field(None, description="场景描述")
category: str = Field(..., max_length=50, description="场景分类")
ai_config: Optional[Dict[str, Any]] = Field(None, description="AI配置")
prompt_template: Optional[str] = Field(None, description="提示词模板")
evaluation_criteria: Optional[Dict[str, Any]] = Field(None, description="评估标准")
is_public: bool = Field(True, description="是否公开")
required_level: Optional[int] = Field(None, description="所需用户等级")
class TrainingSceneCreate(TrainingSceneBase):
"""创建陪练场景模式"""
status: TrainingSceneStatus = Field(
default=TrainingSceneStatus.DRAFT, description="场景状态"
)
class TrainingSceneUpdate(BaseSchema):
"""更新陪练场景模式"""
name: Optional[str] = Field(None, max_length=100)
description: Optional[str] = None
category: Optional[str] = Field(None, max_length=50)
ai_config: Optional[Dict[str, Any]] = None
prompt_template: Optional[str] = None
evaluation_criteria: Optional[Dict[str, Any]] = None
status: Optional[TrainingSceneStatus] = None
is_public: Optional[bool] = None
required_level: Optional[int] = None
class TrainingSceneInDB(TrainingSceneBase, IDMixin, TimestampMixin):
"""数据库中的陪练场景模式"""
status: TrainingSceneStatus
is_deleted: bool = False
created_by: Optional[int] = None
updated_by: Optional[int] = None
class TrainingSceneResponse(TrainingSceneInDB):
"""陪练场景响应模式"""
pass
# ========== 陪练会话相关 ==========
class TrainingSessionBase(BaseSchema):
"""陪练会话基础模式"""
scene_id: int = Field(..., description="场景ID")
session_config: Optional[Dict[str, Any]] = Field(None, description="会话配置")
class TrainingSessionCreate(TrainingSessionBase):
"""创建陪练会话模式"""
pass
class TrainingSessionUpdate(BaseSchema):
"""更新陪练会话模式"""
status: Optional[TrainingSessionStatus] = None
end_time: Optional[datetime] = None
duration_seconds: Optional[int] = None
total_score: Optional[float] = None
evaluation_result: Optional[Dict[str, Any]] = None
class TrainingSessionInDB(TrainingSessionBase, IDMixin, TimestampMixin):
"""数据库中的陪练会话模式"""
user_id: int
coze_conversation_id: Optional[str] = None
start_time: datetime
end_time: Optional[datetime] = None
duration_seconds: Optional[int] = None
status: TrainingSessionStatus
total_score: Optional[float] = None
evaluation_result: Optional[Dict[str, Any]] = None
created_by: Optional[int] = None
updated_by: Optional[int] = None
class TrainingSessionResponse(TrainingSessionInDB):
"""陪练会话响应模式"""
scene: Optional["TrainingSceneResponse"] = None
message_count: Optional[int] = Field(None, description="消息数量")
# ========== 消息相关 ==========
class TrainingMessageBase(BaseSchema):
"""陪练消息基础模式"""
role: MessageRole = Field(..., description="消息角色")
type: MessageType = Field(..., description="消息类型")
content: str = Field(..., description="消息内容")
voice_url: Optional[str] = Field(None, max_length=500, description="语音文件URL")
voice_duration: Optional[float] = Field(None, description="语音时长(秒)")
metadata: Optional[Dict[str, Any]] = Field(None, description="消息元数据")
class TrainingMessageCreate(TrainingMessageBase):
"""创建陪练消息模式"""
session_id: int = Field(..., description="会话ID")
coze_message_id: Optional[str] = Field(None, max_length=100, description="Coze消息ID")
class TrainingMessageInDB(TrainingMessageBase, IDMixin, TimestampMixin):
"""数据库中的陪练消息模式"""
session_id: int
coze_message_id: Optional[str] = None
class TrainingMessageResponse(TrainingMessageInDB):
"""陪练消息响应模式"""
pass
# ========== 报告相关 ==========
class TrainingReportBase(BaseSchema):
"""陪练报告基础模式"""
overall_score: float = Field(..., ge=0, le=100, description="总体得分")
dimension_scores: Dict[str, float] = Field(..., description="各维度得分")
strengths: List[str] = Field(..., description="优势点")
weaknesses: List[str] = Field(..., description="待改进点")
suggestions: List[str] = Field(..., description="改进建议")
detailed_analysis: Optional[str] = Field(None, description="详细分析")
transcript: Optional[str] = Field(None, description="对话文本记录")
statistics: Optional[Dict[str, Any]] = Field(None, description="统计数据")
class TrainingReportCreate(TrainingReportBase):
"""创建陪练报告模式"""
session_id: int = Field(..., description="会话ID")
user_id: int = Field(..., description="用户ID")
class TrainingReportInDB(TrainingReportBase, IDMixin, TimestampMixin):
"""数据库中的陪练报告模式"""
session_id: int
user_id: int
created_by: Optional[int] = None
updated_by: Optional[int] = None
class TrainingReportResponse(TrainingReportInDB):
"""陪练报告响应模式"""
session: Optional[TrainingSessionResponse] = None
# ========== 会话操作相关 ==========
class StartTrainingRequest(BaseSchema):
"""开始陪练请求"""
scene_id: int = Field(..., description="场景ID")
config: Optional[Dict[str, Any]] = Field(None, description="会话配置")
class StartTrainingResponse(BaseSchema):
"""开始陪练响应"""
session_id: int = Field(..., description="会话ID")
coze_conversation_id: Optional[str] = Field(None, description="Coze会话ID")
scene: TrainingSceneResponse = Field(..., description="场景信息")
websocket_url: Optional[str] = Field(None, description="WebSocket连接URL")
class EndTrainingRequest(BaseSchema):
"""结束陪练请求"""
generate_report: bool = Field(True, description="是否生成报告")
class EndTrainingResponse(BaseSchema):
"""结束陪练响应"""
session: TrainingSessionResponse = Field(..., description="会话信息")
report: Optional[TrainingReportResponse] = Field(None, description="陪练报告")
# ========== 列表查询相关 ==========
class TrainingSceneListQuery(BaseSchema):
"""陪练场景列表查询参数"""
category: Optional[str] = Field(None, description="场景分类")
status: Optional[TrainingSceneStatus] = Field(None, description="场景状态")
is_public: Optional[bool] = Field(None, description="是否公开")
search: Optional[str] = Field(None, description="搜索关键词")
page: int = Field(1, ge=1, description="页码")
page_size: int = Field(20, ge=1, le=100, description="每页数量")
class TrainingSessionListQuery(BaseSchema):
"""陪练会话列表查询参数"""
scene_id: Optional[int] = Field(None, description="场景ID")
status: Optional[TrainingSessionStatus] = Field(None, description="会话状态")
start_date: Optional[datetime] = Field(None, description="开始日期")
end_date: Optional[datetime] = Field(None, description="结束日期")
page: int = Field(1, ge=1, description="页码")
page_size: int = Field(20, ge=1, le=100, description="每页数量")
class PaginatedResponse(BaseModel, Generic[DataT]):
"""分页响应模式"""
items: List[DataT] = Field(..., description="数据列表")
total: int = Field(..., description="总数量")
page: int = Field(..., description="当前页码")
page_size: int = Field(..., description="每页数量")
pages: int = Field(..., description="总页数")
# 更新前向引用
TrainingSessionResponse.model_rebuild()
TrainingReportResponse.model_rebuild()

154
backend/app/schemas/user.py Normal file
View File

@@ -0,0 +1,154 @@
"""
用户相关 Schema
"""
from datetime import datetime
from typing import List, Optional
from pydantic import EmailStr, Field, field_validator
from .base import BaseSchema
class UserBase(BaseSchema):
"""用户基础信息"""
username: str = Field(..., min_length=3, max_length=50)
email: Optional[EmailStr] = None
phone: Optional[str] = Field(None, pattern=r"^1[3-9]\d{9}$")
full_name: Optional[str] = Field(None, max_length=100)
avatar_url: Optional[str] = None
bio: Optional[str] = None
role: str = Field(default="trainee", pattern="^(admin|manager|trainee)$")
gender: Optional[str] = Field(None, pattern="^(male|female)$")
school: Optional[str] = Field(None, max_length=100)
major: Optional[str] = Field(None, max_length=100)
class UserCreate(UserBase):
"""创建用户"""
password: str = Field(..., min_length=6, max_length=100)
@field_validator("password")
def validate_password(cls, v):
if len(v) < 6:
raise ValueError("密码长度至少为6位")
return v
class UserUpdate(BaseSchema):
"""更新用户"""
email: Optional[EmailStr] = None
phone: Optional[str] = Field(None, pattern=r"^1[3-9]\d{9}$")
full_name: Optional[str] = Field(None, max_length=100)
avatar_url: Optional[str] = None
bio: Optional[str] = None
role: Optional[str] = Field(None, pattern="^(admin|manager|trainee)$")
is_active: Optional[bool] = None
gender: Optional[str] = Field(None, pattern="^(male|female)$")
school: Optional[str] = Field(None, max_length=100)
major: Optional[str] = Field(None, max_length=100)
class UserPasswordUpdate(BaseSchema):
"""更新密码"""
old_password: str
new_password: str = Field(..., min_length=6, max_length=100)
class UserInDBBase(UserBase):
"""数据库中的用户基础信息"""
id: int
is_active: bool
is_verified: bool
created_at: datetime
updated_at: datetime
last_login_at: Optional[datetime] = None
class User(UserInDBBase):
"""用户信息(不含敏感数据)"""
teams: List["TeamBasic"] = []
class UserWithPassword(UserInDBBase):
"""用户信息(含密码)"""
hashed_password: str
# Team Schemas
class TeamBase(BaseSchema):
"""团队基础信息"""
name: str = Field(..., min_length=2, max_length=100)
code: str = Field(..., min_length=2, max_length=50)
description: Optional[str] = None
team_type: str = Field(
default="department", pattern="^(department|project|study_group)$"
)
class TeamCreate(TeamBase):
"""创建团队"""
leader_id: Optional[int] = None
parent_id: Optional[int] = None
class TeamUpdate(BaseSchema):
"""更新团队"""
name: Optional[str] = Field(None, min_length=2, max_length=100)
description: Optional[str] = None
leader_id: Optional[int] = None
is_active: Optional[bool] = None
class TeamBasic(BaseSchema):
"""团队基本信息"""
id: int
name: str
code: str
team_type: str
class Team(TeamBase):
"""团队完整信息"""
id: int
is_active: bool
leader_id: Optional[int] = None
parent_id: Optional[int] = None
created_at: datetime
updated_at: datetime
member_count: Optional[int] = 0
class TeamWithMembers(Team):
"""团队信息(含成员)"""
members: List[User] = []
leader: Optional[User] = None
# 避免循环引用
UserBase.model_rebuild()
User.model_rebuild()
Team.model_rebuild()
# Filter schemas
class UserFilter(BaseSchema):
"""用户筛选条件"""
role: Optional[str] = Field(None, pattern="^(admin|manager|trainee)$")
is_active: Optional[bool] = None
team_id: Optional[int] = None
keyword: Optional[str] = None # 搜索用户名、邮箱、姓名

View File

@@ -0,0 +1,61 @@
"""
言迹智能工牌相关Schema定义
"""
from typing import List, Optional
from pydantic import BaseModel, Field
class ConversationMessage(BaseModel):
"""单条对话消息"""
role: str = Field(..., description="角色consultant=销售人员customer=客户")
text: str = Field(..., description="对话文本内容")
begin_time: Optional[str] = Field(None, description="开始时间偏移量(毫秒)")
end_time: Optional[str] = Field(None, description="结束时间偏移量(毫秒)")
class YanjiConversation(BaseModel):
"""完整的对话记录"""
audio_id: int = Field(..., description="录音ID")
visit_id: str = Field(..., description="来访单ID")
start_time: str = Field(..., description="录音开始时间")
duration: int = Field(..., description="录音时长(毫秒)")
consultant_name: str = Field(..., description="销售人员姓名")
consultant_phone: str = Field(..., description="销售人员手机号")
conversation: List[ConversationMessage] = Field(..., description="对话内容列表")
class GetConversationsByVisitIdsRequest(BaseModel):
"""根据来访单ID获取对话记录请求"""
external_visit_ids: List[str] = Field(
...,
min_length=1,
max_length=10,
description="三方来访单ID列表最多10个",
)
class GetConversationsByVisitIdsResponse(BaseModel):
"""获取对话记录响应"""
conversations: List[YanjiConversation] = Field(..., description="对话记录列表")
total: int = Field(..., description="总数量")
class GetConversationsRequest(BaseModel):
"""获取员工对话记录请求"""
consultant_phone: str = Field(..., description="员工手机号")
limit: int = Field(default=10, ge=1, le=100, description="获取数量")
class GetConversationsResponse(BaseModel):
"""获取员工对话记录响应"""
conversations: List[YanjiConversation] = Field(..., description="对话记录列表")
total: int = Field(..., description="总数量")

View File

@@ -0,0 +1 @@
"""业务逻辑服务包"""

View File

@@ -0,0 +1,272 @@
"""
能力评估服务
用于分析用户对话数据,生成能力评估报告和课程推荐
使用 Python 原生实现
"""
import json
import logging
from typing import Dict, Any, List, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.ability import AbilityAssessment
from app.services.ai import ability_analysis_service
logger = logging.getLogger(__name__)
class AbilityAssessmentService:
"""能力评估服务类"""
async def analyze_yanji_conversations(
self,
user_id: int,
phone: str,
db: AsyncSession,
yanji_service,
engine: Literal["v2"] = "v2"
) -> Dict[str, Any]:
"""
分析言迹对话并生成能力评估及课程推荐
Args:
user_id: 用户ID
phone: 用户手机号(用于获取言迹数据)
db: 数据库会话
yanji_service: 言迹服务实例
engine: 引擎类型v2=Python原生
Returns:
评估结果字典,包含:
- assessment_id: 评估记录ID
- total_score: 综合评分
- dimensions: 能力维度列表
- recommended_courses: 推荐课程列表
- conversation_count: 分析的对话数量
Raises:
ValueError: 未找到员工的录音记录
Exception: API调用失败或其他错误
"""
logger.info(f"开始分析言迹对话: user_id={user_id}, phone={phone}, engine={engine}")
# 1. 获取员工对话数据最多10条录音
conversations = await yanji_service.get_employee_conversations_for_analysis(
phone=phone,
limit=10
)
if not conversations:
logger.warning(f"未找到员工的录音记录: user_id={user_id}, phone={phone}")
raise ValueError("未找到该员工的录音记录")
# 2. 合并所有对话历史
all_dialogues = []
for conv in conversations:
all_dialogues.extend(conv['dialogue_history'])
logger.info(
f"准备分析: user_id={user_id}, "
f"对话数={len(conversations)}, "
f"总轮次={len(all_dialogues)}"
)
used_engine = "v2"
# Python 原生实现
logger.info(f"调用原生能力分析服务")
# 将对话历史格式化为文本
dialogue_text = self._format_dialogues_for_analysis(all_dialogues)
# 调用原生服务
result = await ability_analysis_service.analyze(
db=db,
user_id=user_id,
dialogue_history=dialogue_text
)
if not result.success:
raise Exception(f"能力分析失败: {result.error}")
# 转换为兼容格式
analysis_result = {
"analysis": {
"total_score": result.total_score,
"ability_dimensions": [
{"name": d.name, "score": d.score, "feedback": d.feedback}
for d in result.ability_dimensions
],
"course_recommendations": [
{
"course_id": c.course_id,
"course_name": c.course_name,
"recommendation_reason": c.recommendation_reason,
"priority": c.priority,
"match_score": c.match_score,
}
for c in result.course_recommendations
]
}
}
logger.info(
f"能力分析完成 - total_score: {result.total_score}, "
f"provider: {result.ai_provider}, latency: {result.ai_latency_ms}ms"
)
# 4. 提取结果
analysis = analysis_result.get('analysis', {})
ability_dims = analysis.get('ability_dimensions', [])
course_recs = analysis.get('course_recommendations', [])
total_score = analysis.get('total_score')
logger.info(
f"分析完成 (engine={used_engine}): total_score={total_score}, "
f"dimensions={len(ability_dims)}, courses={len(course_recs)}"
)
# 5. 保存能力评估记录到数据库
assessment = AbilityAssessment(
user_id=user_id,
source_type='yanji_badge',
source_id=','.join([str(c['audio_id']) for c in conversations]),
total_score=total_score,
ability_dimensions=ability_dims,
recommended_courses=course_recs,
conversation_count=len(conversations)
)
db.add(assessment)
await db.commit()
await db.refresh(assessment)
logger.info(
f"评估记录已保存: assessment_id={assessment.id}, "
f"user_id={user_id}, total_score={total_score}"
)
# 6. 返回评估结果
return {
"assessment_id": assessment.id,
"total_score": total_score,
"dimensions": ability_dims,
"recommended_courses": course_recs,
"conversation_count": len(conversations),
"analyzed_at": assessment.analyzed_at,
"engine": used_engine,
}
def _format_dialogues_for_analysis(self, dialogues: List[Dict[str, Any]]) -> str:
"""
将对话历史列表格式化为文本
Args:
dialogues: 对话历史列表,每项包含 speaker, content 等字段
Returns:
格式化后的对话文本
"""
lines = []
for i, d in enumerate(dialogues, 1):
speaker = d.get('speaker', 'unknown')
content = d.get('content', '')
# 统一说话者标识
if speaker in ['consultant', 'employee', 'user', '员工']:
speaker_label = '员工'
elif speaker in ['customer', 'client', '顾客', '客户']:
speaker_label = '顾客'
else:
speaker_label = speaker
lines.append(f"[{i}] {speaker_label}: {content}")
return '\n'.join(lines)
async def get_user_assessment_history(
self,
user_id: int,
db: AsyncSession,
limit: int = 10
) -> List[Dict[str, Any]]:
"""
获取用户的能力评估历史记录
Args:
user_id: 用户ID
db: 数据库会话
limit: 返回记录数量限制
Returns:
评估历史记录列表
"""
stmt = (
select(AbilityAssessment)
.where(AbilityAssessment.user_id == user_id)
.order_by(AbilityAssessment.analyzed_at.desc())
.limit(limit)
)
result = await db.execute(stmt)
assessments = result.scalars().all()
history = []
for assessment in assessments:
history.append({
"id": assessment.id,
"source_type": assessment.source_type,
"total_score": assessment.total_score,
"ability_dimensions": assessment.ability_dimensions,
"recommended_courses": assessment.recommended_courses,
"conversation_count": assessment.conversation_count,
"analyzed_at": assessment.analyzed_at.isoformat() if assessment.analyzed_at else None,
"created_at": assessment.created_at.isoformat() if assessment.created_at else None
})
logger.info(f"获取评估历史: user_id={user_id}, count={len(history)}")
return history
async def get_assessment_detail(
self,
assessment_id: int,
db: AsyncSession
) -> Dict[str, Any]:
"""
获取单个评估记录的详细信息
Args:
assessment_id: 评估记录ID
db: 数据库会话
Returns:
评估详细信息
Raises:
ValueError: 评估记录不存在
"""
stmt = select(AbilityAssessment).where(AbilityAssessment.id == assessment_id)
result = await db.execute(stmt)
assessment = result.scalar_one_or_none()
if not assessment:
raise ValueError(f"评估记录不存在: assessment_id={assessment_id}")
return {
"id": assessment.id,
"user_id": assessment.user_id,
"source_type": assessment.source_type,
"source_id": assessment.source_id,
"total_score": assessment.total_score,
"ability_dimensions": assessment.ability_dimensions,
"recommended_courses": assessment.recommended_courses,
"conversation_count": assessment.conversation_count,
"analyzed_at": assessment.analyzed_at.isoformat() if assessment.analyzed_at else None,
"created_at": assessment.created_at.isoformat() if assessment.created_at else None
}
def get_ability_assessment_service() -> AbilityAssessmentService:
"""获取能力评估服务实例(依赖注入)"""
return AbilityAssessmentService()

View File

@@ -0,0 +1,151 @@
"""
AI 服务模块
包含:
- AIService: 本地 AI 服务(支持 4sapi + OpenRouter 降级)
- LLM JSON Parser: 大模型 JSON 输出解析器
- KnowledgeAnalysisServiceV2: 知识点分析服务Python 原生实现)
- ExamGeneratorService: 试题生成服务Python 原生实现)
- CourseChatServiceV2: 课程对话服务Python 原生实现)
- PracticeSceneService: 陪练场景准备服务Python 原生实现)
- AbilityAnalysisService: 智能工牌能力分析服务Python 原生实现)
- AnswerJudgeService: 答案判断服务Python 原生实现)
- PracticeAnalysisService: 陪练分析报告服务Python 原生实现)
"""
from .ai_service import (
AIService,
AIResponse,
AIConfig,
AIServiceError,
AIProvider,
DEFAULT_MODEL,
MODEL_ANALYSIS,
MODEL_CREATIVE,
MODEL_IMAGE_GEN,
quick_chat,
)
from .llm_json_parser import (
parse_llm_json,
parse_with_fallback,
safe_json_loads,
clean_llm_output,
diagnose_json_error,
validate_json_schema,
ParseResult,
JSONParseError,
JSONUnrecoverableError,
)
from .knowledge_analysis_v2 import (
KnowledgeAnalysisServiceV2,
knowledge_analysis_service_v2,
)
from .exam_generator_service import (
ExamGeneratorService,
ExamGeneratorConfig,
exam_generator_service,
generate_exam,
)
from .course_chat_service import (
CourseChatServiceV2,
course_chat_service_v2,
)
from .practice_scene_service import (
PracticeSceneService,
PracticeScene,
PracticeSceneResult,
practice_scene_service,
prepare_practice_knowledge,
)
from .ability_analysis_service import (
AbilityAnalysisService,
AbilityAnalysisResult,
AbilityDimension,
CourseRecommendation,
ability_analysis_service,
)
from .answer_judge_service import (
AnswerJudgeService,
JudgeResult,
answer_judge_service,
judge_answer,
)
from .practice_analysis_service import (
PracticeAnalysisService,
PracticeAnalysisResult,
ScoreBreakdownItem,
AbilityDimensionItem,
DialogueAnnotation,
Suggestion,
practice_analysis_service,
analyze_practice_session,
)
__all__ = [
# AI Service
"AIService",
"AIResponse",
"AIConfig",
"AIServiceError",
"AIProvider",
"DEFAULT_MODEL",
"MODEL_ANALYSIS",
"MODEL_CREATIVE",
"MODEL_IMAGE_GEN",
"quick_chat",
# JSON Parser
"parse_llm_json",
"parse_with_fallback",
"safe_json_loads",
"clean_llm_output",
"diagnose_json_error",
"validate_json_schema",
"ParseResult",
"JSONParseError",
"JSONUnrecoverableError",
# Knowledge Analysis V2
"KnowledgeAnalysisServiceV2",
"knowledge_analysis_service_v2",
# Exam Generator V2
"ExamGeneratorService",
"ExamGeneratorConfig",
"exam_generator_service",
"generate_exam",
# Course Chat V2
"CourseChatServiceV2",
"course_chat_service_v2",
# Practice Scene V2
"PracticeSceneService",
"PracticeScene",
"PracticeSceneResult",
"practice_scene_service",
"prepare_practice_knowledge",
# Ability Analysis V2
"AbilityAnalysisService",
"AbilityAnalysisResult",
"AbilityDimension",
"CourseRecommendation",
"ability_analysis_service",
# Answer Judge V2
"AnswerJudgeService",
"JudgeResult",
"answer_judge_service",
"judge_answer",
# Practice Analysis V2
"PracticeAnalysisService",
"PracticeAnalysisResult",
"ScoreBreakdownItem",
"AbilityDimensionItem",
"DialogueAnnotation",
"Suggestion",
"practice_analysis_service",
"analyze_practice_session",
]

View File

@@ -0,0 +1,479 @@
"""
智能工牌能力分析与课程推荐服务 - Python 原生实现
功能:
- 分析员工与顾客的对话记录
- 评估多维度能力得分
- 基于能力短板推荐课程
提供稳定可靠的能力分析和课程推荐能力。
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import ExternalServiceError
from .ai_service import AIService, AIResponse
from .llm_json_parser import parse_with_fallback, clean_llm_output
from .prompts.ability_analysis_prompts import (
SYSTEM_PROMPT,
USER_PROMPT,
ABILITY_ANALYSIS_SCHEMA,
ABILITY_DIMENSIONS,
)
logger = logging.getLogger(__name__)
# ==================== 数据结构 ====================
@dataclass
class AbilityDimension:
"""能力维度评分"""
name: str
score: float
feedback: str
@dataclass
class CourseRecommendation:
"""课程推荐"""
course_id: int
course_name: str
recommendation_reason: str
priority: str # high, medium, low
match_score: float
@dataclass
class AbilityAnalysisResult:
"""能力分析结果"""
success: bool
total_score: float = 0.0
ability_dimensions: List[AbilityDimension] = field(default_factory=list)
course_recommendations: List[CourseRecommendation] = field(default_factory=list)
ai_provider: str = ""
ai_model: str = ""
ai_tokens: int = 0
ai_latency_ms: int = 0
error: str = ""
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"success": self.success,
"total_score": self.total_score,
"ability_dimensions": [
{"name": d.name, "score": d.score, "feedback": d.feedback}
for d in self.ability_dimensions
],
"course_recommendations": [
{
"course_id": c.course_id,
"course_name": c.course_name,
"recommendation_reason": c.recommendation_reason,
"priority": c.priority,
"match_score": c.match_score,
}
for c in self.course_recommendations
],
"ai_provider": self.ai_provider,
"ai_model": self.ai_model,
"ai_tokens": self.ai_tokens,
"ai_latency_ms": self.ai_latency_ms,
"error": self.error,
}
@dataclass
class UserPositionInfo:
"""用户岗位信息"""
position_id: int
position_name: str
code: str
description: str
skills: Optional[Dict[str, Any]]
level: str
status: str
@dataclass
class CourseInfo:
"""课程信息"""
id: int
name: str
description: str
category: str
tags: Optional[List[str]]
difficulty_level: int
duration_hours: float
# ==================== 服务类 ====================
class AbilityAnalysisService:
"""
智能工牌能力分析服务
使用 Python 原生实现。
使用示例:
```python
service = AbilityAnalysisService()
result = await service.analyze(
db=db_session,
user_id=1,
dialogue_history="顾客:你好,我想了解一下你们的服务..."
)
print(result.total_score)
print(result.course_recommendations)
```
"""
def __init__(self):
"""初始化服务"""
self.ai_service = AIService(module_code="ability_analysis")
async def analyze(
self,
db: AsyncSession,
user_id: int,
dialogue_history: str
) -> AbilityAnalysisResult:
"""
分析员工能力并推荐课程
Args:
db: 数据库会话(支持多租户,每个租户传入各自的会话)
user_id: 用户ID
dialogue_history: 对话记录
Returns:
AbilityAnalysisResult 分析结果
"""
try:
logger.info(f"开始能力分析 - user_id: {user_id}")
# 1. 验证输入
if not dialogue_history or not dialogue_history.strip():
return AbilityAnalysisResult(
success=False,
error="对话记录不能为空"
)
# 2. 查询用户岗位信息
user_positions = await self._get_user_positions(db, user_id)
user_info_str = self._format_user_info(user_positions)
logger.info(f"用户岗位信息: {len(user_positions)} 个岗位")
# 3. 查询所有可选课程
courses = await self._get_published_courses(db)
courses_str = self._format_courses(courses)
logger.info(f"可选课程: {len(courses)}")
# 4. 调用 AI 分析
ai_response = await self._call_ai_analysis(
dialogue_history=dialogue_history,
user_info=user_info_str,
courses=courses_str
)
logger.info(
f"AI 分析完成 - provider: {ai_response.provider}, "
f"tokens: {ai_response.total_tokens}, latency: {ai_response.latency_ms}ms"
)
# 5. 解析 JSON 结果
analysis_data = self._parse_analysis_result(ai_response.content, courses)
# 6. 构建返回结果
result = AbilityAnalysisResult(
success=True,
total_score=analysis_data.get("total_score", 0),
ability_dimensions=[
AbilityDimension(
name=d.get("name", ""),
score=d.get("score", 0),
feedback=d.get("feedback", "")
)
for d in analysis_data.get("ability_dimensions", [])
],
course_recommendations=[
CourseRecommendation(
course_id=c.get("course_id", 0),
course_name=c.get("course_name", ""),
recommendation_reason=c.get("recommendation_reason", ""),
priority=c.get("priority", "medium"),
match_score=c.get("match_score", 0)
)
for c in analysis_data.get("course_recommendations", [])
],
ai_provider=ai_response.provider,
ai_model=ai_response.model,
ai_tokens=ai_response.total_tokens,
ai_latency_ms=ai_response.latency_ms,
)
logger.info(
f"能力分析完成 - user_id: {user_id}, total_score: {result.total_score}, "
f"recommendations: {len(result.course_recommendations)}"
)
return result
except Exception as e:
logger.error(
f"能力分析失败 - user_id: {user_id}, error: {e}",
exc_info=True
)
return AbilityAnalysisResult(
success=False,
error=str(e)
)
async def _get_user_positions(
self,
db: AsyncSession,
user_id: int
) -> List[UserPositionInfo]:
"""
查询用户的岗位信息
获取用户基本信息
"""
query = text("""
SELECT
p.id as position_id,
p.name as position_name,
p.code,
p.description,
p.skills,
p.level,
p.status
FROM positions p
INNER JOIN position_members pm ON p.id = pm.position_id
WHERE pm.user_id = :user_id
AND pm.is_deleted = 0
AND p.is_deleted = 0
""")
result = await db.execute(query, {"user_id": user_id})
rows = result.fetchall()
positions = []
for row in rows:
# 解析 skills JSON
skills = None
if row.skills:
if isinstance(row.skills, str):
try:
skills = json.loads(row.skills)
except json.JSONDecodeError:
skills = None
else:
skills = row.skills
positions.append(UserPositionInfo(
position_id=row.position_id,
position_name=row.position_name,
code=row.code or "",
description=row.description or "",
skills=skills,
level=row.level or "",
status=row.status or ""
))
return positions
async def _get_published_courses(self, db: AsyncSession) -> List[CourseInfo]:
"""
查询所有已发布的课程
获取所有课程列表
"""
query = text("""
SELECT
id,
name,
description,
category,
tags,
difficulty_level,
duration_hours
FROM courses
WHERE status = 'published'
AND is_deleted = FALSE
ORDER BY sort_order
""")
result = await db.execute(query)
rows = result.fetchall()
courses = []
for row in rows:
# 解析 tags JSON
tags = None
if row.tags:
if isinstance(row.tags, str):
try:
tags = json.loads(row.tags)
except json.JSONDecodeError:
tags = None
else:
tags = row.tags
courses.append(CourseInfo(
id=row.id,
name=row.name,
description=row.description or "",
category=row.category or "",
tags=tags,
difficulty_level=row.difficulty_level or 3,
duration_hours=row.duration_hours or 0
))
return courses
def _format_user_info(self, positions: List[UserPositionInfo]) -> str:
"""格式化用户岗位信息为文本"""
if not positions:
return "暂无岗位信息"
lines = []
for p in positions:
info = f"- 岗位:{p.position_name}{p.code}"
if p.level:
info += f",级别:{p.level}"
if p.description:
info += f"\n 描述:{p.description}"
if p.skills:
skills_str = json.dumps(p.skills, ensure_ascii=False)
info += f"\n 核心技能:{skills_str}"
lines.append(info)
return "\n".join(lines)
def _format_courses(self, courses: List[CourseInfo]) -> str:
"""格式化课程列表为文本"""
if not courses:
return "暂无可选课程"
lines = []
for c in courses:
info = f"- ID: {c.id}, 课程名称: {c.name}"
if c.category:
info += f", 分类: {c.category}"
if c.difficulty_level:
info += f", 难度: {c.difficulty_level}"
if c.duration_hours:
info += f", 时长: {c.duration_hours}小时"
if c.description:
# 截断过长的描述
desc = c.description[:100] + "..." if len(c.description) > 100 else c.description
info += f"\n 描述: {desc}"
lines.append(info)
return "\n".join(lines)
async def _call_ai_analysis(
self,
dialogue_history: str,
user_info: str,
courses: str
) -> AIResponse:
"""调用 AI 进行能力分析"""
# 构建用户消息
user_message = USER_PROMPT.format(
dialogue_history=dialogue_history,
user_info=user_info,
courses=courses
)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_message}
]
# 调用 AI自动支持 4sapi → OpenRouter 降级)
response = await self.ai_service.chat(
messages=messages,
temperature=0.7, # 保持一定创意性
prompt_name="ability_analysis"
)
return response
def _parse_analysis_result(
self,
ai_output: str,
courses: List[CourseInfo]
) -> Dict[str, Any]:
"""
解析 AI 输出的分析结果 JSON
使用 LLM JSON Parser 进行多层兜底解析
"""
# 先清洗输出
cleaned_output, rules = clean_llm_output(ai_output)
if rules:
logger.debug(f"AI 输出已清洗: {rules}")
# 使用带 Schema 校验的解析
parsed = parse_with_fallback(
cleaned_output,
schema=ABILITY_ANALYSIS_SCHEMA,
default={"analysis": {}},
validate_schema=True,
on_error="default"
)
# 提取 analysis 部分
analysis = parsed.get("analysis", {})
# 后处理:验证课程推荐的有效性
valid_course_ids = {c.id for c in courses}
valid_recommendations = []
for rec in analysis.get("course_recommendations", []):
course_id = rec.get("course_id")
if course_id in valid_course_ids:
valid_recommendations.append(rec)
else:
logger.warning(f"推荐的课程ID不存在: {course_id}")
analysis["course_recommendations"] = valid_recommendations
# 确保能力维度完整
existing_dims = {d.get("name") for d in analysis.get("ability_dimensions", [])}
for dim_name in ABILITY_DIMENSIONS:
if dim_name not in existing_dims:
logger.warning(f"缺少能力维度: {dim_name},使用默认值")
analysis.setdefault("ability_dimensions", []).append({
"name": dim_name,
"score": 70,
"feedback": "暂无具体评价"
})
return analysis
# ==================== 全局实例 ====================
ability_analysis_service = AbilityAnalysisService()

View File

@@ -0,0 +1,747 @@
"""
本地 AI 服务 - 遵循瑞小美 AI 接入规范
功能:
- 支持 4sapi.com首选和 OpenRouter备选自动降级
- 统一的请求/响应格式
- 调用日志记录
"""
import json
import logging
import time
from dataclasses import dataclass, field
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from enum import Enum
import httpx
logger = logging.getLogger(__name__)
class AIProvider(Enum):
"""AI 服务商"""
PRIMARY = "4sapi" # 首选4sapi.com
FALLBACK = "openrouter" # 备选OpenRouter
@dataclass
class AIResponse:
"""AI 响应结果"""
content: str # AI 回复内容
model: str = "" # 使用的模型
provider: str = "" # 实际使用的服务商
input_tokens: int = 0 # 输入 token 数
output_tokens: int = 0 # 输出 token 数
total_tokens: int = 0 # 总 token 数
cost: float = 0.0 # 费用(美元)
latency_ms: int = 0 # 响应延迟(毫秒)
raw_response: Dict[str, Any] = field(default_factory=dict) # 原始响应
images: List[str] = field(default_factory=list) # 图像生成结果
annotations: Dict[str, Any] = field(default_factory=dict) # PDF 解析注释
@dataclass
class AIConfig:
"""AI 服务配置"""
primary_api_key: str # 通用 KeyGemini/DeepSeek 等)
anthropic_api_key: str = "" # Claude 专属 Key
primary_base_url: str = "https://4sapi.com/v1"
fallback_api_key: str = ""
fallback_base_url: str = "https://openrouter.ai/api/v1"
default_model: str = "claude-opus-4-5-20251101-thinking" # 默认使用最强模型
timeout: float = 120.0
max_retries: int = 2
# Claude 模型列表(需要使用 anthropic_api_key
CLAUDE_MODELS = [
"claude-opus-4-5-20251101-thinking",
"claude-opus-4-5-20251101",
"claude-sonnet-4-20250514",
"claude-3-opus",
"claude-3-sonnet",
"claude-3-haiku",
]
def is_claude_model(model: str) -> bool:
"""判断是否为 Claude 模型"""
model_lower = model.lower()
return any(claude in model_lower for claude in ["claude", "anthropic"])
# 模型名称映射4sapi -> OpenRouter
MODEL_MAPPING = {
# 4sapi 使用简短名称OpenRouter 使用完整路径
"gemini-3-flash-preview": "google/gemini-3-flash-preview",
"gemini-3-pro-preview": "google/gemini-3-pro-preview",
"claude-opus-4-5-20251101-thinking": "anthropic/claude-opus-4.5",
"gemini-2.5-flash-image-preview": "google/gemini-2.0-flash-exp:free",
}
# 反向映射OpenRouter -> 4sapi
MODEL_MAPPING_REVERSE = {v: k for k, v in MODEL_MAPPING.items()}
class AIServiceError(Exception):
"""AI 服务错误"""
def __init__(self, message: str, provider: str = "", status_code: int = 0):
super().__init__(message)
self.provider = provider
self.status_code = status_code
class AIService:
"""
本地 AI 服务
遵循瑞小美 AI 接入规范:
- 首选 4sapi.com失败自动降级到 OpenRouter
- 统一的响应格式
- 自动模型名称转换
使用示例:
```python
ai = AIService(module_code="knowledge_analysis")
response = await ai.chat(
messages=[
{"role": "system", "content": "你是助手"},
{"role": "user", "content": "你好"}
],
prompt_name="greeting"
)
print(response.content)
```
"""
def __init__(
self,
module_code: str = "default",
config: Optional[AIConfig] = None,
db_session: Any = None
):
"""
初始化 AI 服务
配置加载优先级(遵循瑞小美 AI 接入规范):
1. 显式传入的 config 参数
2. 数据库 ai_config 表(推荐)
3. 环境变量fallback
Args:
module_code: 模块标识,用于统计
config: AI 配置None 则从数据库/环境变量读取
db_session: 数据库会话,用于记录调用日志和读取配置
"""
self.module_code = module_code
self.db_session = db_session
self.config = config or self._load_config(db_session)
logger.info(f"AIService 初始化: module={module_code}, primary={self.config.primary_base_url}")
def _load_config(self, db_session: Any) -> AIConfig:
"""
加载配置
配置加载优先级(遵循瑞小美 AI 接入规范):
1. 管理库 tenant_configs 表(推荐,通过 DynamicConfig
2. 环境变量fallback
Args:
db_session: 数据库会话(可选,用于日志记录)
Returns:
AIConfig 配置对象
"""
# 优先从管理库加载(同步方式)
try:
config = self._load_config_from_admin_db()
if config:
logger.info("✅ AI 配置已从管理库tenant_configs加载")
return config
except Exception as e:
logger.debug(f"从管理库加载 AI 配置失败: {e}")
# Fallback 到环境变量
logger.info("AI 配置从环境变量加载")
return self._load_config_from_env()
def _load_config_from_admin_db(self) -> Optional[AIConfig]:
"""
从管理库 tenant_configs 表加载配置
使用同步方式直接查询 kaopeilian_admin.tenant_configs 表
Returns:
AIConfig 配置对象,如果无数据则返回 None
"""
import os
# 获取当前租户编码
tenant_code = os.getenv("TENANT_CODE", "demo")
# 获取管理库连接信息
admin_db_host = os.getenv("ADMIN_DB_HOST", "prod-mysql")
admin_db_port = int(os.getenv("ADMIN_DB_PORT", "3306"))
admin_db_user = os.getenv("ADMIN_DB_USER", "root")
admin_db_password = os.getenv("ADMIN_DB_PASSWORD", "")
admin_db_name = os.getenv("ADMIN_DB_NAME", "kaopeilian_admin")
if not admin_db_password:
logger.debug("ADMIN_DB_PASSWORD 未配置,跳过管理库配置加载")
return None
try:
from sqlalchemy import create_engine, text
import urllib.parse
# 构建连接 URL
encoded_password = urllib.parse.quote_plus(admin_db_password)
admin_db_url = f"mysql+pymysql://{admin_db_user}:{encoded_password}@{admin_db_host}:{admin_db_port}/{admin_db_name}?charset=utf8mb4"
engine = create_engine(admin_db_url, pool_pre_ping=True)
with engine.connect() as conn:
# 1. 获取租户 ID
result = conn.execute(
text("SELECT id FROM tenants WHERE code = :code AND status = 'active'"),
{"code": tenant_code}
)
row = result.fetchone()
if not row:
logger.debug(f"租户 {tenant_code} 不存在或未激活")
engine.dispose()
return None
tenant_id = row[0]
# 2. 获取 AI 配置
result = conn.execute(
text("""
SELECT config_key, config_value
FROM tenant_configs
WHERE tenant_id = :tenant_id AND config_group = 'ai'
"""),
{"tenant_id": tenant_id}
)
rows = result.fetchall()
engine.dispose()
if not rows:
logger.debug(f"租户 {tenant_code} 无 AI 配置")
return None
# 转换为字典
config_dict = {row[0]: row[1] for row in rows}
# 检查必要的配置是否存在
primary_key = config_dict.get("AI_PRIMARY_API_KEY", "")
if not primary_key:
logger.warning(f"租户 {tenant_code} 的 AI_PRIMARY_API_KEY 为空")
return None
logger.info(f"✅ 从管理库加载租户 {tenant_code} 的 AI 配置成功")
return AIConfig(
primary_api_key=primary_key,
anthropic_api_key=config_dict.get("AI_ANTHROPIC_API_KEY", ""),
primary_base_url=config_dict.get("AI_PRIMARY_BASE_URL", "https://4sapi.com/v1"),
fallback_api_key=config_dict.get("AI_FALLBACK_API_KEY", ""),
fallback_base_url=config_dict.get("AI_FALLBACK_BASE_URL", "https://openrouter.ai/api/v1"),
default_model=config_dict.get("AI_DEFAULT_MODEL", "claude-opus-4-5-20251101-thinking"),
timeout=float(config_dict.get("AI_TIMEOUT", "120")),
)
except Exception as e:
logger.debug(f"从管理库读取 AI 配置异常: {e}")
return None
def _load_config_from_env(self) -> AIConfig:
"""
从环境变量加载配置
⚠️ 强制要求(遵循瑞小美 AI 接入规范):
- 禁止在代码中硬编码 API Key
- 必须通过环境变量配置 Key
必须配置的环境变量:
- AI_PRIMARY_API_KEY: 通用 Key用于 Gemini/DeepSeek 等)
- AI_ANTHROPIC_API_KEY: Claude 专属 Key
"""
import os
primary_api_key = os.getenv("AI_PRIMARY_API_KEY", "")
anthropic_api_key = os.getenv("AI_ANTHROPIC_API_KEY", "")
# 检查必要的 Key 是否已配置
if not primary_api_key:
logger.warning("⚠️ AI_PRIMARY_API_KEY 未配置AI 服务可能无法正常工作")
if not anthropic_api_key:
logger.warning("⚠️ AI_ANTHROPIC_API_KEY 未配置Claude 模型调用将失败")
return AIConfig(
# 通用 KeyGemini/DeepSeek 等非 Anthropic 模型)
primary_api_key=primary_api_key,
# Claude 专属 Key
anthropic_api_key=anthropic_api_key,
primary_base_url=os.getenv("AI_PRIMARY_BASE_URL", "https://4sapi.com/v1"),
fallback_api_key=os.getenv("AI_FALLBACK_API_KEY", ""),
fallback_base_url=os.getenv("AI_FALLBACK_BASE_URL", "https://openrouter.ai/api/v1"),
# 默认模型:遵循"优先最强"原则,使用 Claude Opus 4.5
default_model=os.getenv("AI_DEFAULT_MODEL", "claude-opus-4-5-20251101-thinking"),
timeout=float(os.getenv("AI_TIMEOUT", "120")),
)
def _convert_model_name(self, model: str, provider: AIProvider) -> str:
"""
转换模型名称以匹配服务商格式
Args:
model: 原始模型名称
provider: 目标服务商
Returns:
转换后的模型名称
"""
if provider == AIProvider.FALLBACK:
# 4sapi -> OpenRouter
return MODEL_MAPPING.get(model, f"google/{model}" if "/" not in model else model)
else:
# OpenRouter -> 4sapi
return MODEL_MAPPING_REVERSE.get(model, model.split("/")[-1] if "/" in model else model)
async def chat(
self,
messages: List[Dict[str, str]],
model: Optional[str] = None,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
prompt_name: str = "default",
**kwargs
) -> AIResponse:
"""
文本聊天
Args:
messages: 消息列表 [{"role": "system/user/assistant", "content": "..."}]
model: 模型名称None 使用默认模型
temperature: 温度参数
max_tokens: 最大输出 token 数
prompt_name: 提示词名称,用于统计
**kwargs: 其他参数
Returns:
AIResponse 响应对象
"""
model = model or self.config.default_model
# 构建请求体
payload = {
"model": model,
"messages": messages,
"temperature": temperature,
}
if max_tokens:
payload["max_tokens"] = max_tokens
# 首选服务商
try:
return await self._call_provider(
provider=AIProvider.PRIMARY,
endpoint="/chat/completions",
payload=payload,
prompt_name=prompt_name
)
except AIServiceError as e:
logger.warning(f"首选服务商调用失败: {e}, 尝试降级到备选服务商")
# 如果没有备选 API Key直接抛出异常
if not self.config.fallback_api_key:
raise
# 降级到备选服务商
# 转换模型名称
fallback_model = self._convert_model_name(model, AIProvider.FALLBACK)
payload["model"] = fallback_model
return await self._call_provider(
provider=AIProvider.FALLBACK,
endpoint="/chat/completions",
payload=payload,
prompt_name=prompt_name
)
async def chat_stream(
self,
messages: List[Dict[str, str]],
model: Optional[str] = None,
temperature: float = 0.7,
max_tokens: Optional[int] = None,
prompt_name: str = "default",
**kwargs
) -> AsyncGenerator[str, None]:
"""
流式文本聊天
Args:
messages: 消息列表 [{"role": "system/user/assistant", "content": "..."}]
model: 模型名称None 使用默认模型
temperature: 温度参数
max_tokens: 最大输出 token 数
prompt_name: 提示词名称,用于统计
**kwargs: 其他参数
Yields:
str: 文本块(逐字返回)
"""
model = model or self.config.default_model
# 构建请求体
payload = {
"model": model,
"messages": messages,
"temperature": temperature,
"stream": True,
}
if max_tokens:
payload["max_tokens"] = max_tokens
# 首选服务商
try:
async for chunk in self._call_provider_stream(
provider=AIProvider.PRIMARY,
endpoint="/chat/completions",
payload=payload,
prompt_name=prompt_name
):
yield chunk
return
except AIServiceError as e:
logger.warning(f"首选服务商流式调用失败: {e}, 尝试降级到备选服务商")
# 如果没有备选 API Key直接抛出异常
if not self.config.fallback_api_key:
raise
# 降级到备选服务商
# 转换模型名称
fallback_model = self._convert_model_name(model, AIProvider.FALLBACK)
payload["model"] = fallback_model
async for chunk in self._call_provider_stream(
provider=AIProvider.FALLBACK,
endpoint="/chat/completions",
payload=payload,
prompt_name=prompt_name
):
yield chunk
async def _call_provider_stream(
self,
provider: AIProvider,
endpoint: str,
payload: Dict[str, Any],
prompt_name: str
) -> AsyncGenerator[str, None]:
"""
流式调用指定服务商
Args:
provider: 服务商
endpoint: API 端点
payload: 请求体
prompt_name: 提示词名称
Yields:
str: 文本块
"""
# 获取配置
if provider == AIProvider.PRIMARY:
base_url = self.config.primary_base_url
# 根据模型选择 API KeyClaude 用专属 Key其他用通用 Key
model = payload.get("model", "")
if is_claude_model(model) and self.config.anthropic_api_key:
api_key = self.config.anthropic_api_key
logger.debug(f"[Stream] 使用 Claude 专属 Key 调用模型: {model}")
else:
api_key = self.config.primary_api_key
else:
api_key = self.config.fallback_api_key
base_url = self.config.fallback_base_url
url = f"{base_url.rstrip('/')}{endpoint}"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
# OpenRouter 需要额外的 header
if provider == AIProvider.FALLBACK:
headers["HTTP-Referer"] = "https://kaopeilian.ireborn.com.cn"
headers["X-Title"] = "KaoPeiLian"
start_time = time.time()
try:
timeout = httpx.Timeout(self.config.timeout, connect=10.0)
async with httpx.AsyncClient(timeout=timeout) as client:
logger.info(f"流式调用 AI 服务: provider={provider.value}, model={payload.get('model')}")
async with client.stream("POST", url, json=payload, headers=headers) as response:
# 检查响应状态
if response.status_code != 200:
error_text = await response.aread()
logger.error(f"AI 服务流式返回错误: status={response.status_code}, body={error_text[:500]}")
raise AIServiceError(
f"API 流式请求失败: HTTP {response.status_code}",
provider=provider.value,
status_code=response.status_code
)
# 处理 SSE 流
async for line in response.aiter_lines():
if not line or not line.strip():
continue
# 解析 SSE 数据行
if line.startswith("data: "):
data_str = line[6:] # 移除 "data: " 前缀
# 检查是否是结束标记
if data_str.strip() == "[DONE]":
logger.info(f"流式响应完成: provider={provider.value}")
return
try:
event_data = json.loads(data_str)
# 提取 delta 内容
choices = event_data.get("choices", [])
if choices:
delta = choices[0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError as e:
logger.debug(f"解析流式数据失败: {e} - 数据: {data_str[:100]}")
continue
latency_ms = int((time.time() - start_time) * 1000)
logger.info(f"流式调用完成: provider={provider.value}, latency={latency_ms}ms")
except httpx.TimeoutException:
latency_ms = int((time.time() - start_time) * 1000)
logger.error(f"AI 服务流式超时: provider={provider.value}, latency={latency_ms}ms")
raise AIServiceError(f"流式请求超时({self.config.timeout}秒)", provider=provider.value)
except httpx.RequestError as e:
logger.error(f"AI 服务流式网络错误: provider={provider.value}, error={e}")
raise AIServiceError(f"流式网络错误: {e}", provider=provider.value)
async def _call_provider(
self,
provider: AIProvider,
endpoint: str,
payload: Dict[str, Any],
prompt_name: str
) -> AIResponse:
"""
调用指定服务商
Args:
provider: 服务商
endpoint: API 端点
payload: 请求体
prompt_name: 提示词名称
Returns:
AIResponse 响应对象
"""
# 获取配置
if provider == AIProvider.PRIMARY:
base_url = self.config.primary_base_url
# 根据模型选择 API KeyClaude 用专属 Key其他用通用 Key
model = payload.get("model", "")
if is_claude_model(model) and self.config.anthropic_api_key:
api_key = self.config.anthropic_api_key
logger.debug(f"使用 Claude 专属 Key 调用模型: {model}")
else:
api_key = self.config.primary_api_key
else:
api_key = self.config.fallback_api_key
base_url = self.config.fallback_base_url
url = f"{base_url.rstrip('/')}{endpoint}"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
# OpenRouter 需要额外的 header
if provider == AIProvider.FALLBACK:
headers["HTTP-Referer"] = "https://kaopeilian.ireborn.com.cn"
headers["X-Title"] = "KaoPeiLian"
start_time = time.time()
try:
async with httpx.AsyncClient(timeout=self.config.timeout) as client:
logger.info(f"调用 AI 服务: provider={provider.value}, model={payload.get('model')}")
response = await client.post(url, json=payload, headers=headers)
latency_ms = int((time.time() - start_time) * 1000)
# 检查响应状态
if response.status_code != 200:
error_text = response.text
logger.error(f"AI 服务返回错误: status={response.status_code}, body={error_text[:500]}")
raise AIServiceError(
f"API 请求失败: HTTP {response.status_code}",
provider=provider.value,
status_code=response.status_code
)
data = response.json()
# 解析响应
ai_response = self._parse_response(data, provider, latency_ms)
# 记录日志
logger.info(
f"AI 调用成功: provider={provider.value}, model={ai_response.model}, "
f"tokens={ai_response.total_tokens}, latency={latency_ms}ms"
)
# 保存到数据库(如果有 session
await self._log_call(prompt_name, ai_response)
return ai_response
except httpx.TimeoutException:
latency_ms = int((time.time() - start_time) * 1000)
logger.error(f"AI 服务超时: provider={provider.value}, latency={latency_ms}ms")
raise AIServiceError(f"请求超时({self.config.timeout}秒)", provider=provider.value)
except httpx.RequestError as e:
logger.error(f"AI 服务网络错误: provider={provider.value}, error={e}")
raise AIServiceError(f"网络错误: {e}", provider=provider.value)
def _parse_response(
self,
data: Dict[str, Any],
provider: AIProvider,
latency_ms: int
) -> AIResponse:
"""解析 API 响应"""
# 提取内容
choices = data.get("choices", [])
if not choices:
raise AIServiceError("响应中没有 choices")
message = choices[0].get("message", {})
content = message.get("content", "")
# 提取 usage
usage = data.get("usage", {})
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
total_tokens = usage.get("total_tokens", input_tokens + output_tokens)
# 提取费用(如果有)
cost = usage.get("total_cost", 0.0)
return AIResponse(
content=content,
model=data.get("model", ""),
provider=provider.value,
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
cost=cost,
latency_ms=latency_ms,
raw_response=data
)
async def _log_call(self, prompt_name: str, response: AIResponse) -> None:
"""记录调用日志到数据库"""
if not self.db_session:
return
try:
# TODO: 实现调用日志记录
# 可以参考 ai_call_logs 表结构
pass
except Exception as e:
logger.warning(f"记录 AI 调用日志失败: {e}")
async def analyze_document(
self,
content: str,
prompt: str,
model: Optional[str] = None,
prompt_name: str = "document_analysis"
) -> AIResponse:
"""
分析文档内容
Args:
content: 文档内容
prompt: 分析提示词
model: 模型名称
prompt_name: 提示词名称
Returns:
AIResponse 响应对象
"""
messages = [
{"role": "user", "content": f"{prompt}\n\n文档内容:\n{content}"}
]
return await self.chat(
messages=messages,
model=model,
temperature=0.1, # 文档分析使用低温度
prompt_name=prompt_name
)
# 便捷函数
async def quick_chat(
messages: List[Dict[str, str]],
model: Optional[str] = None,
module_code: str = "quick"
) -> str:
"""
快速聊天,返回纯文本
Args:
messages: 消息列表
model: 模型名称
module_code: 模块标识
Returns:
AI 回复的文本内容
"""
ai = AIService(module_code=module_code)
response = await ai.chat(messages, model=model)
return response.content
# 模型常量(遵循瑞小美 AI 接入规范)
# 按优先级排序:首选 > 标准 > 快速
MODEL_PRIMARY = "claude-opus-4-5-20251101-thinking" # 🥇 首选:所有任务首先尝试
MODEL_STANDARD = "gemini-3-pro-preview" # 🥈 标准Claude 失败后降级
MODEL_FAST = "gemini-3-flash-preview" # 🥉 快速:最终保底
MODEL_IMAGE = "gemini-2.5-flash-image-preview" # 🖼️ 图像生成专用
MODEL_VIDEO = "veo3.1-pro" # 🎬 视频生成专用
# 兼容旧代码的别名
DEFAULT_MODEL = MODEL_PRIMARY # 默认使用最强模型
MODEL_ANALYSIS = MODEL_PRIMARY
MODEL_CREATIVE = MODEL_STANDARD
MODEL_IMAGE_GEN = MODEL_IMAGE

View File

@@ -0,0 +1,197 @@
"""
答案判断服务 - Python 原生实现
功能:
- 判断填空题与问答题的答案是否正确
- 通过 AI 语义理解比对用户答案与标准答案
提供稳定可靠的答案判断能力。
"""
import logging
from dataclasses import dataclass
from typing import Any, Optional
from .ai_service import AIService, AIResponse
from .prompts.answer_judge_prompts import (
SYSTEM_PROMPT,
USER_PROMPT,
CORRECT_KEYWORDS,
INCORRECT_KEYWORDS,
)
logger = logging.getLogger(__name__)
@dataclass
class JudgeResult:
"""判断结果"""
is_correct: bool
raw_response: str
ai_provider: str = ""
ai_model: str = ""
ai_tokens: int = 0
ai_latency_ms: int = 0
class AnswerJudgeService:
"""
答案判断服务
使用 Python 原生实现。
使用示例:
```python
service = AnswerJudgeService()
result = await service.judge(
db=db_session, # 传入 db_session 用于记录调用日志
question="玻尿酸的主要作用是什么?",
correct_answer="补水保湿、填充塑形",
user_answer="保湿和塑形",
analysis="玻尿酸具有补水保湿和填充塑形两大功能"
)
print(result.is_correct) # True
```
"""
MODULE_CODE = "answer_judge"
async def judge(
self,
question: str,
correct_answer: str,
user_answer: str,
analysis: str = "",
db: Any = None # 数据库会话,用于记录 AI 调用日志
) -> JudgeResult:
"""
判断答案是否正确
Args:
question: 题目内容
correct_answer: 标准答案
user_answer: 用户答案
analysis: 答案解析(可选)
db: 数据库会话,用于记录调用日志(符合 AI 接入规范)
Returns:
JudgeResult 判断结果
"""
try:
logger.info(
f"开始判断答案 - question: {question[:50]}..., "
f"user_answer: {user_answer[:50]}..."
)
# 创建 AIService 实例(传入 db_session 用于记录调用日志)
ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
# 构建提示词
user_prompt = USER_PROMPT.format(
question=question,
correct_answer=correct_answer,
user_answer=user_answer,
analysis=analysis or ""
)
# 调用 AI
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_prompt}
]
ai_response = await ai_service.chat(
messages=messages,
temperature=0.1, # 低温度,确保输出稳定
prompt_name="answer_judge"
)
logger.info(
f"AI 判断完成 - provider: {ai_response.provider}, "
f"response: {ai_response.content}, "
f"latency: {ai_response.latency_ms}ms"
)
# 解析 AI 输出
is_correct = self._parse_judge_result(ai_response.content)
logger.info(f"答案判断结果: {is_correct}")
return JudgeResult(
is_correct=is_correct,
raw_response=ai_response.content,
ai_provider=ai_response.provider,
ai_model=ai_response.model,
ai_tokens=ai_response.total_tokens,
ai_latency_ms=ai_response.latency_ms,
)
except Exception as e:
logger.error(f"答案判断失败: {e}", exc_info=True)
# 出错时默认返回错误,保守处理
return JudgeResult(
is_correct=False,
raw_response=f"判断失败: {e}",
)
def _parse_judge_result(self, ai_output: str) -> bool:
"""
解析 AI 输出的判断结果
Args:
ai_output: AI 返回的文本
Returns:
bool: True 表示正确False 表示错误
"""
# 清洗输出
output = ai_output.strip().lower()
# 检查是否包含正确关键词
for keyword in CORRECT_KEYWORDS:
if keyword.lower() in output:
return True
# 检查是否包含错误关键词
for keyword in INCORRECT_KEYWORDS:
if keyword.lower() in output:
return False
# 无法识别时,默认返回错误(保守处理)
logger.warning(f"无法解析判断结果,默认返回错误: {ai_output}")
return False
# ==================== 全局实例 ====================
answer_judge_service = AnswerJudgeService()
# ==================== 便捷函数 ====================
async def judge_answer(
question: str,
correct_answer: str,
user_answer: str,
analysis: str = ""
) -> bool:
"""
便捷函数:判断答案是否正确
Args:
question: 题目内容
correct_answer: 标准答案
user_answer: 用户答案
analysis: 答案解析
Returns:
bool: True 表示正确False 表示错误
"""
result = await answer_judge_service.judge(
question=question,
correct_answer=correct_answer,
user_answer=user_answer,
analysis=analysis
)
return result.is_correct

View File

@@ -0,0 +1,757 @@
"""
课程对话服务 V2 - Python 原生实现
功能:
- 查询课程知识点作为知识库
- 调用 AI 进行对话
- 支持流式输出
- 多轮对话历史管理Redis 缓存)
提供稳定可靠的课程对话能力。
"""
import json
import logging
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import ExternalServiceError
from .ai_service import AIService
from .prompts.course_chat_prompts import (
SYSTEM_PROMPT,
USER_PROMPT,
KNOWLEDGE_ITEM_TEMPLATE,
CONVERSATION_WINDOW_SIZE,
CONVERSATION_TTL,
MAX_KNOWLEDGE_POINTS,
MAX_KNOWLEDGE_BASE_LENGTH,
DEFAULT_CHAT_MODEL,
DEFAULT_TEMPERATURE,
)
logger = logging.getLogger(__name__)
# 会话索引 Redis key 前缀/后缀
CONVERSATION_INDEX_PREFIX = "course_chat:user:"
CONVERSATION_INDEX_SUFFIX = ":conversations"
# 会话元数据 key 前缀
CONVERSATION_META_PREFIX = "course_chat:meta:"
# 会话索引过期时间(与会话数据一致)
CONVERSATION_INDEX_TTL = CONVERSATION_TTL
class CourseChatServiceV2:
"""
课程对话服务 V2
使用 Python 原生实现。
使用示例:
```python
service = CourseChatServiceV2()
# 非流式对话
response = await service.chat(
db=db_session,
course_id=1,
query="什么是玻尿酸?",
user_id=1,
conversation_id=None
)
# 流式对话
async for chunk in service.chat_stream(
db=db_session,
course_id=1,
query="什么是玻尿酸?",
user_id=1,
conversation_id=None
):
print(chunk, end="", flush=True)
```
"""
# Redis key 前缀
CONVERSATION_KEY_PREFIX = "course_chat:conversation:"
# 模块标识
MODULE_CODE = "course_chat"
def __init__(self):
"""初始化服务AIService 在方法中动态创建,以传入 db_session"""
pass
async def chat(
self,
db: AsyncSession,
course_id: int,
query: str,
user_id: int,
conversation_id: Optional[str] = None
) -> Dict[str, Any]:
"""
与课程对话(非流式)
Args:
db: 数据库会话
course_id: 课程ID
query: 用户问题
user_id: 用户ID
conversation_id: 会话ID续接对话时传入
Returns:
包含 answer、conversation_id 等字段的字典
"""
try:
logger.info(
f"开始课程对话 V2 - course_id: {course_id}, user_id: {user_id}, "
f"conversation_id: {conversation_id}"
)
# 1. 获取课程知识点
knowledge_base = await self._get_course_knowledge(db, course_id)
if not knowledge_base:
logger.warning(f"课程 {course_id} 没有知识点,使用空知识库")
knowledge_base = "(该课程暂无知识点内容)"
# 2. 获取或创建会话ID
is_new_conversation = False
if not conversation_id:
conversation_id = self._generate_conversation_id(user_id, course_id)
is_new_conversation = True
logger.info(f"创建新会话: {conversation_id}")
# 3. 构建消息列表
messages = await self._build_messages(
knowledge_base=knowledge_base,
query=query,
user_id=user_id,
conversation_id=conversation_id
)
# 4. 创建 AIService 并调用(传入 db_session 以记录调用日志)
ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
response = await ai_service.chat(
messages=messages,
model=DEFAULT_CHAT_MODEL,
temperature=DEFAULT_TEMPERATURE,
prompt_name="course_chat"
)
answer = response.content
# 5. 保存对话历史
await self._save_conversation_history(
conversation_id=conversation_id,
user_message=query,
assistant_message=answer
)
# 6. 更新会话索引
if is_new_conversation:
await self._add_to_conversation_index(user_id, conversation_id, course_id)
else:
await self._update_conversation_index(user_id, conversation_id)
logger.info(
f"课程对话完成 - course_id: {course_id}, conversation_id: {conversation_id}, "
f"provider: {response.provider}, tokens: {response.total_tokens}"
)
return {
"success": True,
"answer": answer,
"conversation_id": conversation_id,
"ai_provider": response.provider,
"ai_model": response.model,
"ai_tokens": response.total_tokens,
"ai_latency_ms": response.latency_ms,
}
except Exception as e:
logger.error(
f"课程对话失败 - course_id: {course_id}, user_id: {user_id}, error: {e}",
exc_info=True
)
raise ExternalServiceError(f"课程对话失败: {e}")
async def chat_stream(
self,
db: AsyncSession,
course_id: int,
query: str,
user_id: int,
conversation_id: Optional[str] = None
) -> AsyncGenerator[Tuple[str, Optional[str]], None]:
"""
与课程对话(流式输出)
Args:
db: 数据库会话
course_id: 课程ID
query: 用户问题
user_id: 用户ID
conversation_id: 会话ID续接对话时传入
Yields:
Tuple[str, Optional[str]]: (事件类型, 数据)
- ("conversation_started", conversation_id): 会话开始
- ("chunk", text): 文本块
- ("end", None): 结束
- ("error", message): 错误
"""
full_answer = ""
try:
logger.info(
f"开始流式课程对话 V2 - course_id: {course_id}, user_id: {user_id}, "
f"conversation_id: {conversation_id}"
)
# 1. 获取课程知识点
knowledge_base = await self._get_course_knowledge(db, course_id)
if not knowledge_base:
logger.warning(f"课程 {course_id} 没有知识点,使用空知识库")
knowledge_base = "(该课程暂无知识点内容)"
# 2. 获取或创建会话ID
is_new_conversation = False
if not conversation_id:
conversation_id = self._generate_conversation_id(user_id, course_id)
is_new_conversation = True
logger.info(f"创建新会话: {conversation_id}")
# 3. 发送会话开始事件(如果是新会话)
if is_new_conversation:
yield ("conversation_started", conversation_id)
# 4. 构建消息列表
messages = await self._build_messages(
knowledge_base=knowledge_base,
query=query,
user_id=user_id,
conversation_id=conversation_id
)
# 5. 创建 AIService 并流式调用(传入 db_session 以记录调用日志)
ai_service = AIService(module_code=self.MODULE_CODE, db_session=db)
async for chunk in ai_service.chat_stream(
messages=messages,
model=DEFAULT_CHAT_MODEL,
temperature=DEFAULT_TEMPERATURE,
prompt_name="course_chat"
):
full_answer += chunk
yield ("chunk", chunk)
# 6. 发送结束事件
yield ("end", None)
# 7. 保存对话历史
await self._save_conversation_history(
conversation_id=conversation_id,
user_message=query,
assistant_message=full_answer
)
# 8. 更新会话索引
if is_new_conversation:
await self._add_to_conversation_index(user_id, conversation_id, course_id)
else:
await self._update_conversation_index(user_id, conversation_id)
logger.info(
f"流式课程对话完成 - course_id: {course_id}, conversation_id: {conversation_id}, "
f"answer_length: {len(full_answer)}"
)
except Exception as e:
logger.error(
f"流式课程对话失败 - course_id: {course_id}, user_id: {user_id}, error: {e}",
exc_info=True
)
yield ("error", str(e))
async def _get_course_knowledge(
self,
db: AsyncSession,
course_id: int
) -> str:
"""
获取课程知识点,构建知识库文本
Args:
db: 数据库会话
course_id: 课程ID
Returns:
知识库文本
"""
try:
# 查询知识点(课程知识点查询)
query = text("""
SELECT kp.name, kp.description
FROM knowledge_points kp
INNER JOIN course_materials cm ON kp.material_id = cm.id
WHERE kp.course_id = :course_id
AND kp.is_deleted = 0
AND cm.is_deleted = 0
ORDER BY kp.id
LIMIT :limit
""")
result = await db.execute(
query,
{"course_id": course_id, "limit": MAX_KNOWLEDGE_POINTS}
)
rows = result.fetchall()
if not rows:
logger.warning(f"课程 {course_id} 没有关联的知识点")
return ""
# 构建知识库文本
knowledge_items = []
total_length = 0
for row in rows:
name = row[0] or ""
description = row[1] or ""
item = KNOWLEDGE_ITEM_TEMPLATE.format(
name=name,
description=description
)
# 检查是否超过长度限制
if total_length + len(item) > MAX_KNOWLEDGE_BASE_LENGTH:
logger.warning(
f"知识库文本已达到最大长度限制 {MAX_KNOWLEDGE_BASE_LENGTH}"
f"停止添加更多知识点"
)
break
knowledge_items.append(item)
total_length += len(item)
knowledge_base = "\n".join(knowledge_items)
logger.info(
f"获取课程知识点成功 - course_id: {course_id}, "
f"count: {len(knowledge_items)}, length: {len(knowledge_base)}"
)
return knowledge_base
except Exception as e:
logger.error(f"获取课程知识点失败: {e}")
raise
async def _build_messages(
self,
knowledge_base: str,
query: str,
user_id: int,
conversation_id: str
) -> List[Dict[str, str]]:
"""
构建消息列表(包含历史对话)
Args:
knowledge_base: 知识库文本
query: 当前用户问题
user_id: 用户ID
conversation_id: 会话ID
Returns:
消息列表
"""
messages = []
# 1. 系统提示词
system_content = SYSTEM_PROMPT.format(knowledge_base=knowledge_base)
messages.append({"role": "system", "content": system_content})
# 2. 获取历史对话
history = await self._get_conversation_history(conversation_id)
# 限制历史窗口大小
if len(history) > CONVERSATION_WINDOW_SIZE * 2:
history = history[-(CONVERSATION_WINDOW_SIZE * 2):]
# 添加历史消息
messages.extend(history)
# 3. 当前用户问题
user_content = USER_PROMPT.format(query=query)
messages.append({"role": "user", "content": user_content})
logger.debug(
f"构建消息列表 - total: {len(messages)}, history: {len(history)}"
)
return messages
def _generate_conversation_id(self, user_id: int, course_id: int) -> str:
"""生成会话ID"""
unique_id = uuid.uuid4().hex[:8]
return f"conv_{user_id}_{course_id}_{unique_id}"
async def _get_conversation_history(
self,
conversation_id: str
) -> List[Dict[str, str]]:
"""
从 Redis 获取会话历史
Args:
conversation_id: 会话ID
Returns:
消息列表 [{"role": "user/assistant", "content": "..."}]
"""
try:
from app.core.redis import get_redis_client
redis = get_redis_client()
key = f"{self.CONVERSATION_KEY_PREFIX}{conversation_id}"
data = await redis.get(key)
if not data:
return []
history = json.loads(data)
return history
except RuntimeError:
# Redis 未初始化,返回空历史
logger.warning("Redis 未初始化,无法获取会话历史")
return []
except Exception as e:
logger.warning(f"获取会话历史失败: {e}")
return []
async def _save_conversation_history(
self,
conversation_id: str,
user_message: str,
assistant_message: str
) -> None:
"""
保存对话历史到 Redis
Args:
conversation_id: 会话ID
user_message: 用户消息
assistant_message: AI 回复
"""
try:
from app.core.redis import get_redis_client
redis = get_redis_client()
key = f"{self.CONVERSATION_KEY_PREFIX}{conversation_id}"
# 获取现有历史
history = await self._get_conversation_history(conversation_id)
# 添加新消息
history.append({"role": "user", "content": user_message})
history.append({"role": "assistant", "content": assistant_message})
# 限制历史长度
max_messages = CONVERSATION_WINDOW_SIZE * 2
if len(history) > max_messages:
history = history[-max_messages:]
# 保存到 Redis
await redis.setex(
key,
CONVERSATION_TTL,
json.dumps(history, ensure_ascii=False)
)
logger.debug(
f"保存会话历史成功 - conversation_id: {conversation_id}, "
f"messages: {len(history)}"
)
except RuntimeError:
# Redis 未初始化,跳过保存
logger.warning("Redis 未初始化,无法保存会话历史")
except Exception as e:
logger.warning(f"保存会话历史失败: {e}")
async def get_conversation_messages(
self,
conversation_id: str,
user_id: int
) -> List[Dict[str, Any]]:
"""
获取会话的历史消息
Args:
conversation_id: 会话ID
user_id: 用户ID用于权限验证
Returns:
消息列表
"""
# 验证会话ID是否属于该用户
if not conversation_id.startswith(f"conv_{user_id}_"):
logger.warning(
f"用户 {user_id} 尝试访问不属于自己的会话: {conversation_id}"
)
return []
history = await self._get_conversation_history(conversation_id)
# 格式化返回数据
messages = []
for i, msg in enumerate(history):
messages.append({
"id": i,
"role": msg["role"],
"content": msg["content"],
})
return messages
async def _add_to_conversation_index(
self,
user_id: int,
conversation_id: str,
course_id: int
) -> None:
"""
将会话添加到用户索引
Args:
user_id: 用户ID
conversation_id: 会话ID
course_id: 课程ID
"""
try:
from app.core.redis import get_redis_client
redis = get_redis_client()
# 1. 添加到用户的会话索引Sorted Setscore 为时间戳)
index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
timestamp = time.time()
await redis.zadd(index_key, {conversation_id: timestamp})
await redis.expire(index_key, CONVERSATION_INDEX_TTL)
# 2. 保存会话元数据
meta_key = f"{CONVERSATION_META_PREFIX}{conversation_id}"
meta_data = {
"conversation_id": conversation_id,
"user_id": user_id,
"course_id": course_id,
"created_at": timestamp,
"updated_at": timestamp,
}
await redis.setex(
meta_key,
CONVERSATION_INDEX_TTL,
json.dumps(meta_data, ensure_ascii=False)
)
logger.debug(
f"会话已添加到索引 - user_id: {user_id}, conversation_id: {conversation_id}"
)
except RuntimeError:
logger.warning("Redis 未初始化,无法添加会话索引")
except Exception as e:
logger.warning(f"添加会话索引失败: {e}")
async def _update_conversation_index(
self,
user_id: int,
conversation_id: str
) -> None:
"""
更新会话的最后活跃时间
Args:
user_id: 用户ID
conversation_id: 会话ID
"""
try:
from app.core.redis import get_redis_client
redis = get_redis_client()
# 更新索引中的时间戳
index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
timestamp = time.time()
await redis.zadd(index_key, {conversation_id: timestamp})
await redis.expire(index_key, CONVERSATION_INDEX_TTL)
# 更新元数据中的 updated_at
meta_key = f"{CONVERSATION_META_PREFIX}{conversation_id}"
meta_data = await redis.get(meta_key)
if meta_data:
meta = json.loads(meta_data)
meta["updated_at"] = timestamp
await redis.setex(
meta_key,
CONVERSATION_INDEX_TTL,
json.dumps(meta, ensure_ascii=False)
)
logger.debug(
f"会话索引已更新 - user_id: {user_id}, conversation_id: {conversation_id}"
)
except RuntimeError:
logger.warning("Redis 未初始化,无法更新会话索引")
except Exception as e:
logger.warning(f"更新会话索引失败: {e}")
async def list_user_conversations(
self,
user_id: int,
limit: int = 20
) -> List[Dict[str, Any]]:
"""
获取用户的会话列表
Args:
user_id: 用户ID
limit: 返回数量限制
Returns:
会话列表,按更新时间倒序
"""
try:
from app.core.redis import get_redis_client
redis = get_redis_client()
# 1. 从索引获取最近的会话ID列表倒序
index_key = f"{CONVERSATION_INDEX_PREFIX}{user_id}{CONVERSATION_INDEX_SUFFIX}"
conversation_ids = await redis.zrevrange(index_key, 0, limit - 1)
if not conversation_ids:
logger.debug(f"用户 {user_id} 没有会话记录")
return []
# 2. 获取每个会话的元数据和最后消息
conversations = []
for conv_id in conversation_ids:
# 确保是字符串
if isinstance(conv_id, bytes):
conv_id = conv_id.decode('utf-8')
# 获取元数据
meta_key = f"{CONVERSATION_META_PREFIX}{conv_id}"
meta_data = await redis.get(meta_key)
if meta_data:
if isinstance(meta_data, bytes):
meta_data = meta_data.decode('utf-8')
meta = json.loads(meta_data)
else:
# 从 conversation_id 解析 course_id
# 格式: conv_{user_id}_{course_id}_{uuid}
parts = conv_id.split('_')
course_id = int(parts[2]) if len(parts) >= 3 else 0
meta = {
"conversation_id": conv_id,
"user_id": user_id,
"course_id": course_id,
"created_at": time.time(),
"updated_at": time.time(),
}
# 获取最后一条消息作为预览
history = await self._get_conversation_history(conv_id)
last_message = ""
if history:
# 获取最后一条 assistant 消息
for msg in reversed(history):
if msg["role"] == "assistant":
last_message = msg["content"][:100] # 截取前100字符
if len(msg["content"]) > 100:
last_message += "..."
break
conversations.append({
"id": conv_id,
"course_id": meta.get("course_id"),
"created_at": meta.get("created_at"),
"updated_at": meta.get("updated_at"),
"last_message": last_message,
"message_count": len(history),
})
logger.info(f"获取用户会话列表 - user_id: {user_id}, count: {len(conversations)}")
return conversations
except RuntimeError:
logger.warning("Redis 未初始化,无法获取会话列表")
return []
except Exception as e:
logger.warning(f"获取会话列表失败: {e}")
return []
# 别名方法,供 API 层调用
async def get_conversations(
self,
user_id: int,
course_id: Optional[int] = None,
limit: int = 20
) -> List[Dict[str, Any]]:
"""
获取用户的会话列表(别名方法)
Args:
user_id: 用户ID
course_id: 课程ID可选用于过滤
limit: 返回数量限制
Returns:
会话列表
"""
conversations = await self.list_user_conversations(user_id, limit)
# 如果指定了 course_id进行过滤
if course_id is not None:
conversations = [
c for c in conversations
if c.get("course_id") == course_id
]
return conversations
async def get_messages(
self,
conversation_id: str,
user_id: int,
limit: int = 50
) -> List[Dict[str, Any]]:
"""
获取会话历史消息(别名方法)
Args:
conversation_id: 会话ID
user_id: 用户ID用于权限验证
limit: 返回数量限制
Returns:
消息列表
"""
messages = await self.get_conversation_messages(conversation_id, limit)
return messages
# 创建全局实例
course_chat_service_v2 = CourseChatServiceV2()

View File

@@ -0,0 +1,61 @@
"""
Coze AI 服务模块
"""
from .client import get_coze_client, get_auth_manager, get_bot_config, get_workspace_id
from .service import get_coze_service, CozeService
from .models import (
SessionType,
MessageRole,
ContentType,
StreamEventType,
CozeSession,
CozeMessage,
StreamEvent,
CreateSessionRequest,
CreateSessionResponse,
SendMessageRequest,
EndSessionRequest,
EndSessionResponse,
)
from .exceptions import (
CozeException,
CozeAuthError,
CozeAPIError,
CozeRateLimitError,
CozeTimeoutError,
CozeStreamError,
map_coze_error_to_exception,
)
__all__ = [
# Client
"get_coze_client",
"get_auth_manager",
"get_bot_config",
"get_workspace_id",
# Service
"get_coze_service",
"CozeService",
# Models
"SessionType",
"MessageRole",
"ContentType",
"StreamEventType",
"CozeSession",
"CozeMessage",
"StreamEvent",
"CreateSessionRequest",
"CreateSessionResponse",
"SendMessageRequest",
"EndSessionRequest",
"EndSessionResponse",
# Exceptions
"CozeException",
"CozeAuthError",
"CozeAPIError",
"CozeRateLimitError",
"CozeTimeoutError",
"CozeStreamError",
"map_coze_error_to_exception",
]

View File

@@ -0,0 +1,203 @@
"""
Coze AI 客户端管理
负责管理 Coze API 的认证和客户端实例
"""
from functools import lru_cache
from typing import Optional, Dict, Any
import logging
from pathlib import Path
from cozepy import Coze, TokenAuth, JWTAuth, COZE_CN_BASE_URL
from app.core.config import get_settings
logger = logging.getLogger(__name__)
class CozeAuthManager:
"""Coze 认证管理器"""
def __init__(self):
self.settings = get_settings()
self._client: Optional[Coze] = None
def _create_pat_auth(self) -> TokenAuth:
"""创建个人访问令牌认证"""
if not self.settings.COZE_API_TOKEN:
raise ValueError("COZE_API_TOKEN 未配置")
return TokenAuth(token=self.settings.COZE_API_TOKEN)
def _create_oauth_auth(self) -> JWTAuth:
"""创建 OAuth 认证"""
if not all(
[
self.settings.COZE_OAUTH_CLIENT_ID,
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
]
):
raise ValueError("OAuth 配置不完整")
# 读取私钥
private_key_path = Path(self.settings.COZE_OAUTH_PRIVATE_KEY_PATH)
if not private_key_path.exists():
raise FileNotFoundError(f"私钥文件不存在: {private_key_path}")
with open(private_key_path, "r") as f:
private_key = f.read()
try:
return JWTAuth(
client_id=self.settings.COZE_OAUTH_CLIENT_ID,
private_key=private_key,
public_key_id=self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL, # 使用中国区API
)
except Exception as e:
logger.error(f"创建 OAuth 认证失败: {e}")
raise
def get_client(self, force_new: bool = False) -> Coze:
"""
获取 Coze 客户端实例
Args:
force_new: 是否强制创建新客户端用于长时间运行的请求避免token过期
认证优先级:
1. OAuth推荐配置完整时使用自动刷新token
2. PAT仅当OAuth未配置时使用注意PAT会过期
"""
if self._client is not None and not force_new:
return self._client
auth = None
auth_type = None
# 检查 OAuth 配置是否完整
oauth_configured = all([
self.settings.COZE_OAUTH_CLIENT_ID,
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
])
if oauth_configured:
# OAuth 配置完整,必须使用 OAuth不fallback到PAT
try:
auth = self._create_oauth_auth()
auth_type = "OAuth"
logger.info("使用 OAuth 认证")
except Exception as e:
# OAuth 配置完整但创建失败直接抛出异常不fallback到可能过期的PAT
logger.error(f"OAuth 认证创建失败: {e}")
raise ValueError(f"OAuth 认证失败,请检查私钥文件和配置: {e}")
else:
# OAuth 未配置,使用 PAT
if self.settings.COZE_API_TOKEN:
auth = self._create_pat_auth()
auth_type = "PAT"
logger.warning("使用 PAT 认证注意PAT会过期建议配置OAuth")
else:
raise ValueError("Coze 认证未配置:需要配置 OAuth 或 PAT Token")
# 创建客户端
client = Coze(
auth=auth, base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL
)
logger.debug(f"Coze客户端创建成功认证方式: {auth_type}, force_new: {force_new}")
# 只有非强制创建时才缓存
if not force_new:
self._client = client
return client
def reset(self):
"""重置客户端实例"""
self._client = None
def get_oauth_token(self) -> str:
"""
获取OAuth JWT Token用于前端直连
Returns:
JWT token字符串
"""
if not all([
self.settings.COZE_OAUTH_CLIENT_ID,
self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
self.settings.COZE_OAUTH_PRIVATE_KEY_PATH,
]):
raise ValueError("OAuth 配置不完整")
# 读取私钥
private_key_path = Path(self.settings.COZE_OAUTH_PRIVATE_KEY_PATH)
if not private_key_path.exists():
raise FileNotFoundError(f"私钥文件不存在: {private_key_path}")
with open(private_key_path, "r") as f:
private_key = f.read()
# 创建JWTAuth实例必须指定中国区base_url
jwt_auth = JWTAuth(
client_id=self.settings.COZE_OAUTH_CLIENT_ID,
private_key=private_key,
public_key_id=self.settings.COZE_OAUTH_PUBLIC_KEY_ID,
base_url=self.settings.COZE_API_BASE or COZE_CN_BASE_URL, # 使用中国区API
)
# 获取tokenJWTAuth内部会自动生成
# JWTAuth.token属性返回已签名的JWT
return jwt_auth.token
@lru_cache()
def get_auth_manager() -> CozeAuthManager:
"""获取认证管理器单例"""
return CozeAuthManager()
def get_coze_client(force_new: bool = False) -> Coze:
"""
获取 Coze 客户端
Args:
force_new: 是否强制创建新客户端(用于工作流等长时间运行的请求)
"""
return get_auth_manager().get_client(force_new=force_new)
def get_workspace_id() -> str:
"""获取工作空间 ID"""
settings = get_settings()
if not settings.COZE_WORKSPACE_ID:
raise ValueError("COZE_WORKSPACE_ID 未配置")
return settings.COZE_WORKSPACE_ID
def get_bot_config(session_type: str) -> Dict[str, Any]:
"""
根据会话类型获取 Bot 配置
Args:
session_type: 会话类型 (course_chat 或 training)
Returns:
包含 bot_id 等配置的字典
"""
settings = get_settings()
if session_type == "course_chat":
bot_id = settings.COZE_CHAT_BOT_ID
if not bot_id:
raise ValueError("COZE_CHAT_BOT_ID 未配置")
elif session_type == "training":
bot_id = settings.COZE_TRAINING_BOT_ID
if not bot_id:
raise ValueError("COZE_TRAINING_BOT_ID 未配置")
else:
raise ValueError(f"不支持的会话类型: {session_type}")
return {"bot_id": bot_id, "workspace_id": settings.COZE_WORKSPACE_ID}

View File

@@ -0,0 +1,44 @@
"""Coze客户端临时模拟等Agent-Coze实现后替换"""
import logging
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)
class CozeClient:
"""
Coze客户端模拟类
TODO: 等Agent-Coze模块实现后这个类将被真实的Coze网关客户端替换
"""
async def create_conversation(
self, bot_id: str, user_id: str, meta_data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""创建会话(模拟)"""
logger.info(f"模拟创建Coze会话: bot_id={bot_id}, user_id={user_id}")
# 返回模拟的会话信息
return {
"conversation_id": f"mock_conversation_{user_id}_{bot_id[:8]}",
"bot_id": bot_id,
"status": "active",
}
async def send_message(
self, conversation_id: str, content: str, message_type: str = "text"
) -> Dict[str, Any]:
"""发送消息(模拟)"""
logger.info(f"模拟发送消息到会话 {conversation_id}: {content[:50]}...")
# 返回模拟的消息响应
return {
"message_id": f"mock_msg_{conversation_id[:8]}",
"content": f"这是对'{content[:30]}...'的模拟回复",
"role": "assistant",
}
async def end_conversation(self, conversation_id: str) -> Dict[str, Any]:
"""结束会话(模拟)"""
logger.info(f"模拟结束会话: {conversation_id}")
return {"status": "completed", "conversation_id": conversation_id}

View File

@@ -0,0 +1,101 @@
"""
Coze 服务异常定义
"""
from typing import Optional, Dict, Any
class CozeException(Exception):
"""Coze 服务基础异常"""
def __init__(
self,
message: str,
code: Optional[str] = None,
status_code: Optional[int] = None,
details: Optional[Dict[str, Any]] = None,
):
super().__init__(message)
self.message = message
self.code = code
self.status_code = status_code
self.details = details or {}
class CozeAuthError(CozeException):
"""认证异常"""
pass
class CozeAPIError(CozeException):
"""API 调用异常"""
pass
class CozeRateLimitError(CozeException):
"""速率限制异常"""
pass
class CozeTimeoutError(CozeException):
"""超时异常"""
pass
class CozeStreamError(CozeException):
"""流式响应异常"""
pass
def map_coze_error_to_exception(error: Exception) -> CozeException:
"""
将 Coze SDK 错误映射为统一异常
Args:
error: 原始异常
Returns:
CozeException: 映射后的异常
"""
error_message = str(error)
# 根据错误消息判断错误类型
if (
"authentication" in error_message.lower()
or "unauthorized" in error_message.lower()
):
return CozeAuthError(
message="Coze 认证失败",
code="COZE_AUTH_ERROR",
status_code=401,
details={"original_error": error_message},
)
if "rate limit" in error_message.lower():
return CozeRateLimitError(
message="Coze API 速率限制",
code="COZE_RATE_LIMIT",
status_code=429,
details={"original_error": error_message},
)
if "timeout" in error_message.lower():
return CozeTimeoutError(
message="Coze API 调用超时",
code="COZE_TIMEOUT",
status_code=504,
details={"original_error": error_message},
)
# 默认映射为 API 错误
return CozeAPIError(
message="Coze API 调用失败",
code="COZE_API_ERROR",
status_code=500,
details={"original_error": error_message},
)

Some files were not shown because too many files have changed in this diff Show More