""" 任务服务 """ from typing import List, Optional from datetime import datetime from sqlalchemy import select, func, and_ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from app.models.task import Task, TaskCourse, TaskAssignment, TaskStatus, AssignmentStatus from app.models.course import Course from app.schemas.task import TaskCreate, TaskUpdate, TaskStatsResponse from app.services.base_service import BaseService class TaskService(BaseService[Task]): """任务服务""" def __init__(self): super().__init__(Task) async def create_task(self, db: AsyncSession, task_in: TaskCreate, creator_id: int) -> Task: """创建任务""" # 创建任务 task = Task( title=task_in.title, description=task_in.description, priority=task_in.priority, deadline=task_in.deadline, requirements=task_in.requirements, creator_id=creator_id, status=TaskStatus.PENDING ) db.add(task) await db.flush() # 关联课程 for course_id in task_in.course_ids: task_course = TaskCourse(task_id=task.id, course_id=course_id) db.add(task_course) # 分配用户 for user_id in task_in.user_ids: assignment = TaskAssignment(task_id=task.id, user_id=user_id) db.add(assignment) await db.commit() await db.refresh(task) return task async def get_tasks( self, db: AsyncSession, status: Optional[str] = None, page: int = 1, page_size: int = 20 ) -> (List[Task], int): """获取任务列表""" stmt = select(Task).where(Task.is_deleted == False) if status: stmt = stmt.where(Task.status == status) stmt = stmt.order_by(Task.created_at.desc()) # 获取总数 count_stmt = select(func.count()).select_from(Task).where(Task.is_deleted == False) if status: count_stmt = count_stmt.where(Task.status == status) total = (await db.execute(count_stmt)).scalar_one() # 分页 stmt = stmt.offset((page - 1) * page_size).limit(page_size) result = await db.execute(stmt) tasks = result.scalars().all() return tasks, total async def get_task_detail(self, db: AsyncSession, task_id: int) -> Optional[Task]: """获取任务详情""" stmt = select(Task).where( and_(Task.id == task_id, Task.is_deleted == False) ).options( joinedload(Task.course_links).joinedload(TaskCourse.course), joinedload(Task.assignments) ) result = await db.execute(stmt) return result.unique().scalar_one_or_none() async def update_task(self, db: AsyncSession, task_id: int, task_in: TaskUpdate) -> Optional[Task]: """更新任务""" stmt = select(Task).where(and_(Task.id == task_id, Task.is_deleted == False)) result = await db.execute(stmt) task = result.scalar_one_or_none() if not task: return None update_data = task_in.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(task, field, value) await db.commit() await db.refresh(task) return task async def delete_task(self, db: AsyncSession, task_id: int) -> bool: """删除任务(软删除)""" stmt = select(Task).where(and_(Task.id == task_id, Task.is_deleted == False)) result = await db.execute(stmt) task = result.scalar_one_or_none() if not task: return False task.is_deleted = True await db.commit() return True async def get_task_stats(self, db: AsyncSession) -> TaskStatsResponse: """获取任务统计""" # 总任务数 total_stmt = select(func.count()).select_from(Task).where(Task.is_deleted == False) total = (await db.execute(total_stmt)).scalar_one() # 各状态任务数 status_stmt = select( Task.status, func.count(Task.id) ).where(Task.is_deleted == False).group_by(Task.status) status_result = await db.execute(status_stmt) status_counts = dict(status_result.all()) # 平均完成率 avg_stmt = select(func.avg(Task.progress)).where( and_(Task.is_deleted == False, Task.status != TaskStatus.EXPIRED) ) avg_completion = (await db.execute(avg_stmt)).scalar_one() or 0.0 return TaskStatsResponse( total=total, ongoing=status_counts.get(TaskStatus.ONGOING.value, 0), completed=status_counts.get(TaskStatus.COMPLETED.value, 0), expired=status_counts.get(TaskStatus.EXPIRED.value, 0), avg_completion_rate=round(avg_completion, 1) ) async def update_task_progress(self, db: AsyncSession, task_id: int) -> int: """ 更新任务进度 计算已完成的分配数占总分配数的百分比 """ # 统计总分配数和完成数 stmt = select( func.count(TaskAssignment.id).label('total'), func.sum( func.case( (TaskAssignment.status == AssignmentStatus.COMPLETED, 1), else_=0 ) ).label('completed') ).where(TaskAssignment.task_id == task_id) result = (await db.execute(stmt)).first() total = result.total or 0 completed = result.completed or 0 if total == 0: progress = 0 else: progress = int((completed / total) * 100) # 更新任务进度 task_stmt = select(Task).where(and_(Task.id == task_id, Task.is_deleted == False)) task_result = await db.execute(task_stmt) task = task_result.scalar_one_or_none() if task: task.progress = progress await db.commit() return progress async def update_task_status(self, db: AsyncSession, task_id: int): """ 更新任务状态 根据进度和截止时间自动更新任务状态 """ task = await self.get_task_detail(db, task_id) if not task: return # 计算并更新进度 progress = await self.update_task_progress(db, task_id) # 自动更新状态 now = datetime.now() if progress == 100: # 完全完成 task.status = TaskStatus.COMPLETED elif task.deadline and now > task.deadline and task.status != TaskStatus.COMPLETED: # 已过期且未完成 task.status = TaskStatus.EXPIRED elif progress > 0 and task.status == TaskStatus.PENDING: # 已开始但未完成 task.status = TaskStatus.ONGOING await db.commit() task_service = TaskService()