All checks were successful
continuous-integration/drone/push Build is passing
- create_task 和 get_tasks 现在使用 selectinload 预加载关联关系 - 避免懒加载导致的 MissingGreenlet 错误
228 lines
7.5 KiB
Python
228 lines
7.5 KiB
Python
"""
|
|
任务服务
|
|
"""
|
|
from datetime import datetime
from typing import List, Optional, Tuple

from sqlalchemy import select, func, and_, case
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, selectinload

from app.models.task import Task, TaskCourse, TaskAssignment, TaskStatus, AssignmentStatus
from app.models.course import Course
from app.schemas.task import TaskCreate, TaskUpdate, TaskStatsResponse
from app.services.base_service import BaseService
|
|
|
|
|
|
class TaskService(BaseService[Task]):
    """Task service: CRUD, listing, statistics and progress/status upkeep for ``Task``.

    Every method receives the ``AsyncSession`` explicitly. Relationships that callers
    read (``course_links`` -> ``course`` and ``assignments``) are loaded eagerly with
    ``selectinload``, because implicit lazy loading under asyncio raises
    ``MissingGreenlet``.
    """

    def __init__(self):
        super().__init__(Task)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _relation_options():
        """Return loader options that eagerly load a task's courses and assignments.

        ``selectinload`` issues one extra SELECT per collection instead of the
        row-multiplying JOIN that two collection ``joinedload``s would produce.
        """
        return (
            selectinload(Task.course_links).selectinload(TaskCourse.course),
            selectinload(Task.assignments),
        )

    async def _get_active_task(
        self, db: AsyncSession, task_id: int, *, with_relations: bool = False
    ) -> Optional[Task]:
        """Fetch a task that is not soft-deleted, or ``None`` if there is none.

        Args:
            db: Async database session.
            task_id: Primary key of the task.
            with_relations: Also eagerly load courses and assignments.
        """
        stmt = select(Task).where(
            # ``== False`` (not ``is False``) so SQLAlchemy builds a SQL comparison.
            and_(Task.id == task_id, Task.is_deleted == False)  # noqa: E712
        )
        if with_relations:
            stmt = stmt.options(*self._relation_options())
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def _load_with_relations(self, db: AsyncSession, task_id: int) -> Task:
        """Re-select a task after a commit, with its relationships eagerly loaded.

        ``populate_existing`` overwrites whatever state the identity map still holds
        for this task (e.g. attributes expired by the commit), so the returned object
        is fully loaded and can be serialized without implicit IO.

        Raises:
            sqlalchemy.exc.NoResultFound: if no row with ``task_id`` exists.
        """
        stmt = (
            select(Task)
            .where(Task.id == task_id)
            .options(*self._relation_options())
            .execution_options(populate_existing=True)
        )
        result = await db.execute(stmt)
        return result.scalar_one()

    async def _calculate_progress(self, db: AsyncSession, task_id: int) -> int:
        """Return the completed share of a task's assignments as a floored 0-100 int.

        Returns 0 when the task has no assignments. Does not write anything.
        """
        stmt = select(
            func.count(TaskAssignment.id).label("total"),
            func.sum(
                case((TaskAssignment.status == AssignmentStatus.COMPLETED, 1), else_=0)
            ).label("completed"),
        ).where(TaskAssignment.task_id == task_id)

        # An aggregate without GROUP BY yields exactly one row; SUM is NULL when no
        # assignment rows match, hence the ``or 0``.
        row = (await db.execute(stmt)).one()
        total = int(row.total or 0)
        completed = int(row.completed or 0)

        if total == 0:
            return 0
        # Integer arithmetic on purpose: int((29 / 100) * 100) == 28 due to float rounding.
        return completed * 100 // total

    # ------------------------------------------------------------------ public API

    async def create_task(self, db: AsyncSession, task_in: TaskCreate, creator_id: int) -> Task:
        """Create a task, link its courses and assign its users, then commit.

        Args:
            db: Async database session; committed by this method.
            task_in: Creation payload (task fields plus ``course_ids`` / ``user_ids``).
            creator_id: ID of the user creating the task.

        Returns:
            The new task with ``course_links`` (and each link's ``course``) and
            ``assignments`` eagerly loaded.
        """
        task = Task(
            title=task_in.title,
            description=task_in.description,
            priority=task_in.priority,
            deadline=task_in.deadline,
            requirements=task_in.requirements,
            creator_id=creator_id,
            status=TaskStatus.PENDING,
        )
        db.add(task)
        # Flush so the database assigns the primary key used by the link rows.
        await db.flush()
        # Read the id now: after commit() the instance may be expired, and touching an
        # expired attribute under asyncio triggers implicit IO (MissingGreenlet).
        task_id = task.id

        db.add_all(
            [TaskCourse(task_id=task_id, course_id=cid) for cid in (task_in.course_ids or [])]
        )
        db.add_all(
            [TaskAssignment(task_id=task_id, user_id=uid) for uid in (task_in.user_ids or [])]
        )
        await db.commit()

        return await self._load_with_relations(db, task_id)

    async def get_tasks(
        self,
        db: AsyncSession,
        status: Optional[str] = None,
        page: int = 1,
        page_size: int = 20,
    ) -> Tuple[List[Task], int]:
        """List non-deleted tasks, newest first, one page at a time.

        Args:
            db: Async database session.
            status: Optional status to filter by; falsy means "all statuses".
            page: 1-based page number; values below 1 are treated as page 1.
            page_size: Maximum number of tasks per page.

        Returns:
            ``(tasks, total)`` where ``tasks`` is the requested page with relationships
            eagerly loaded and ``total`` counts all tasks matching the filter.
        """
        filters = [Task.is_deleted == False]  # noqa: E712 - SQL comparison, not identity
        if status:
            filters.append(Task.status == status)

        count_stmt = select(func.count()).select_from(Task).where(and_(*filters))
        total = (await db.execute(count_stmt)).scalar_one()

        stmt = (
            select(Task)
            .where(and_(*filters))
            .options(*self._relation_options())
            # id tiebreaker keeps pagination stable when created_at values collide.
            .order_by(Task.created_at.desc(), Task.id.desc())
            .offset(max(page - 1, 0) * page_size)
            .limit(page_size)
        )
        result = await db.execute(stmt)
        return list(result.scalars().all()), total

    async def get_task_detail(self, db: AsyncSession, task_id: int) -> Optional[Task]:
        """Return a non-deleted task with courses and assignments loaded, or ``None``."""
        return await self._get_active_task(db, task_id, with_relations=True)

    async def update_task(self, db: AsyncSession, task_id: int, task_in: TaskUpdate) -> Optional[Task]:
        """Apply the fields explicitly set in ``task_in`` to a non-deleted task.

        Returns:
            The updated task with relationships eagerly loaded, or ``None`` if no
            non-deleted task has ``task_id``.
        """
        task = await self._get_active_task(db, task_id)
        if task is None:
            return None

        # exclude_unset: only fields the caller actually provided are written.
        # NOTE(review): assumes every TaskUpdate field maps to a Task attribute — confirm.
        for field, value in task_in.model_dump(exclude_unset=True).items():
            setattr(task, field, value)

        await db.commit()
        # Re-select rather than refresh() so relationships are loaded as well.
        return await self._load_with_relations(db, task_id)

    async def delete_task(self, db: AsyncSession, task_id: int) -> bool:
        """Soft-delete a task by setting ``is_deleted``.

        Returns:
            ``True`` if a non-deleted task was found and marked, otherwise ``False``.
        """
        task = await self._get_active_task(db, task_id)
        if task is None:
            return False

        task.is_deleted = True
        await db.commit()
        return True

    async def get_task_stats(self, db: AsyncSession) -> TaskStatsResponse:
        """Aggregate counts per status and the mean progress of non-deleted tasks.

        The average excludes EXPIRED tasks and is rounded to one decimal place.
        """
        not_deleted = Task.is_deleted == False  # noqa: E712

        total_stmt = select(func.count()).select_from(Task).where(not_deleted)
        total = (await db.execute(total_stmt)).scalar_one()

        status_stmt = (
            select(Task.status, func.count(Task.id))
            .where(not_deleted)
            .group_by(Task.status)
        )
        # The grouped key may be an enum member (SQLAlchemy Enum column) or a raw
        # string; normalize to the plain value so the ``.value`` lookups below match.
        status_counts = {
            getattr(task_status, "value", task_status): count
            for task_status, count in (await db.execute(status_stmt)).all()
        }

        avg_stmt = select(func.avg(Task.progress)).where(
            and_(not_deleted, Task.status != TaskStatus.EXPIRED)
        )
        avg_completion = (await db.execute(avg_stmt)).scalar_one()

        return TaskStatsResponse(
            total=total,
            ongoing=status_counts.get(TaskStatus.ONGOING.value, 0),
            completed=status_counts.get(TaskStatus.COMPLETED.value, 0),
            expired=status_counts.get(TaskStatus.EXPIRED.value, 0),
            # AVG is NULL when nothing matches and may be a Decimal on some backends.
            avg_completion_rate=round(float(avg_completion or 0.0), 1),
        )

    async def update_task_progress(self, db: AsyncSession, task_id: int) -> int:
        """Recompute and store a task's progress (percent of completed assignments).

        The stored value is only written (and committed) if the task exists and is
        not soft-deleted; the computed percentage is returned either way.
        """
        progress = await self._calculate_progress(db, task_id)

        task = await self._get_active_task(db, task_id)
        if task is not None:
            task.progress = progress
            await db.commit()

        return progress

    async def update_task_status(self, db: AsyncSession, task_id: int) -> None:
        """Recompute a task's progress and derive its status from progress and deadline.

        Rules, in priority order:
          * progress == 100                          -> COMPLETED
          * deadline passed and status not COMPLETED -> EXPIRED
          * progress > 0 and status still PENDING    -> ONGOING

        Progress and status are written in a single commit, so the task's attributes
        are never read after a commit has expired them. Missing or soft-deleted tasks
        are ignored.
        """
        task = await self._get_active_task(db, task_id)
        if task is None:
            return

        progress = await self._calculate_progress(db, task_id)
        task.progress = progress

        deadline = task.deadline
        # Use the deadline's own tzinfo: comparing naive and aware datetimes raises
        # TypeError, and tzinfo=None keeps the naive local-time comparison.
        is_overdue = deadline is not None and datetime.now(deadline.tzinfo) > deadline

        if progress == 100:
            task.status = TaskStatus.COMPLETED
        elif is_overdue and task.status != TaskStatus.COMPLETED:
            task.status = TaskStatus.EXPIRED
        elif progress > 0 and task.status == TaskStatus.PENDING:
            task.status = TaskStatus.ONGOING

        await db.commit()
|
|
|
|
|
|
# Shared module-level instance; every TaskService method receives the AsyncSession as an argument.
task_service = TaskService()
|
|
|