"""
Course-to-position assignment service.
"""
from typing import List, Optional

from sqlalchemy import select, and_, delete, func
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.logger import get_logger
from app.models.position_course import PositionCourse
from app.models.position import Position
from app.models.position_member import PositionMember
from app.schemas.course import CoursePositionAssignment, CoursePositionAssignmentInDB
from app.services.base_service import BaseService

logger = get_logger(__name__)


class CoursePositionService(BaseService[PositionCourse]):
    """Service that manages which positions a course is assigned to."""

    def __init__(self):
        super().__init__(PositionCourse)

    async def _count_members_by_position(
        self,
        db: AsyncSession,
        position_ids: List[int],
    ) -> dict:
        """
        Count non-deleted members for each given position in a single query.

        Args:
            db: database session
            position_ids: position IDs whose members should be counted

        Returns:
            Mapping of position_id -> member count. Positions that have no
            non-deleted members do not appear in the mapping.
        """
        if not position_ids:
            return {}

        stmt = (
            select(PositionMember.position_id, func.count(PositionMember.id))
            .where(
                and_(
                    PositionMember.position_id.in_(position_ids),
                    # SQL expression, not a Python comparison.
                    PositionMember.is_deleted == False,  # noqa: E712
                )
            )
            .group_by(PositionMember.position_id)
        )
        rows = await db.execute(stmt)
        return {position_id: count for position_id, count in rows.all()}

    async def get_course_positions(
        self,
        db: AsyncSession,
        course_id: int,
        course_type: Optional[str] = None,
    ) -> List[CoursePositionAssignmentInDB]:
        """
        List the non-deleted position assignments of a course.

        Args:
            db: database session
            course_id: course ID
            course_type: optional course-type filter; ignored when falsy

        Returns:
            Assignments ordered by priority, then id. Each one carries the
            position's name/description (None when the position is not
            loaded) and the number of non-deleted members of that position.
        """
        conditions = [
            PositionCourse.course_id == course_id,
            PositionCourse.is_deleted == False,  # noqa: E712
        ]
        if course_type:
            conditions.append(PositionCourse.course_type == course_type)

        stmt = (
            select(PositionCourse)
            .options(selectinload(PositionCourse.position))
            .where(and_(*conditions))
            .order_by(PositionCourse.priority, PositionCourse.id)
        )
        result = await db.execute(stmt)
        assignments = result.scalars().all()

        # Fetch all member counts with one grouped query instead of one
        # COUNT query per assignment (avoids the N+1 pattern).
        position_ids = list({a.position_id for a in assignments if a.position_id})
        member_counts = await self._count_members_by_position(db, position_ids)

        result_list = []
        for assignment in assignments:
            position = assignment.position
            result_list.append(
                CoursePositionAssignmentInDB(
                    id=assignment.id,
                    course_id=assignment.course_id,
                    position_id=assignment.position_id,
                    course_type=assignment.course_type,
                    priority=assignment.priority,
                    position_name=position.name if position else None,
                    position_description=position.description if position else None,
                    # Missing position_id or no members -> 0.
                    member_count=member_counts.get(assignment.position_id, 0),
                )
            )
        return result_list

    async def batch_assign_positions(
        self,
        db: AsyncSession,
        course_id: int,
        assignments: List[CoursePositionAssignment],
        user_id: int,
    ) -> List[CoursePositionAssignmentInDB]:
        """
        Assign a course to several positions at once (upsert per position).

        An existing non-deleted assignment for the same (course, position)
        pair has its course_type and priority updated. Otherwise a new
        assignment row is created. All changes are committed together.

        Args:
            db: database session
            course_id: course ID
            assignments: position assignments to apply
            user_id: ID of the acting user (used only for logging here)

        Returns:
            The course's full, freshly queried assignment list.
        """
        for assignment in assignments:
            # A Result can only be consumed once, so keep the fetched row.
            result = await db.execute(
                select(PositionCourse).where(
                    PositionCourse.course_id == course_id,
                    PositionCourse.position_id == assignment.position_id,
                    PositionCourse.is_deleted == False,  # noqa: E712
                )
            )
            existing_assignment = result.scalar_one_or_none()

            if existing_assignment:
                # Already assigned: update type and priority in place.
                existing_assignment.course_type = assignment.course_type
                existing_assignment.priority = assignment.priority
                # PositionCourse does not inherit AuditMixin, so audit fields
                # are deliberately not written here.
            else:
                db.add(
                    PositionCourse(
                        course_id=course_id,
                        position_id=assignment.position_id,
                        course_type=assignment.course_type,
                        priority=assignment.priority,
                    )
                )

        await db.commit()

        logger.info("批量分配课程到岗位成功", course_id=course_id, count=len(assignments), user_id=user_id)

        # get_course_positions re-queries every row (eager-loading `position`),
        # so refreshing each committed object individually beforehand would
        # only add redundant round-trips.
        return await self.get_course_positions(db, course_id)

    async def remove_position_assignment(
        self,
        db: AsyncSession,
        course_id: int,
        position_id: int,
        user_id: int,
    ) -> bool:
        """
        Soft-delete the course's assignment to a position.

        Args:
            db: database session
            course_id: course ID
            position_id: position ID
            user_id: ID of the acting user, recorded in ``deleted_by``

        Returns:
            True if a non-deleted assignment was found and soft-deleted,
            False if there was nothing to remove.
        """
        stmt = select(PositionCourse).where(
            PositionCourse.course_id == course_id,
            PositionCourse.position_id == position_id,
            PositionCourse.is_deleted == False,  # noqa: E712
        )
        result = await db.execute(stmt)
        assignment = result.scalar_one_or_none()

        if not assignment:
            return False

        # Soft delete.
        assignment.is_deleted = True
        # NOTE(review): batch_assign_positions says PositionCourse does not
        # inherit AuditMixin; confirm the model really maps a `deleted_by`
        # column, otherwise this assignment is not persisted.
        assignment.deleted_by = user_id
        await db.commit()

        logger.info("移除课程岗位分配成功", course_id=course_id, position_id=position_id, user_id=user_id)
        return True


# Shared service instance.
course_position_service = CoursePositionService()