diff --git a/backend/app/services/task_service.py b/backend/app/services/task_service.py index 6e6aa61..450ac1f 100644 --- a/backend/app/services/task_service.py +++ b/backend/app/services/task_service.py @@ -5,7 +5,7 @@ from typing import List, Optional from datetime import datetime from sqlalchemy import select, func, and_, case from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import joinedload, selectinload from app.models.task import Task, TaskCourse, TaskAssignment, TaskStatus, AssignmentStatus from app.models.course import Course from app.schemas.task import TaskCreate, TaskUpdate, TaskStatsResponse @@ -44,7 +44,14 @@ class TaskService(BaseService[Task]): db.add(assignment) await db.commit() - await db.refresh(task) + + # 重新查询并加载关联关系(避免懒加载问题) + stmt = select(Task).where(Task.id == task.id).options( + selectinload(Task.course_links).selectinload(TaskCourse.course), + selectinload(Task.assignments) + ) + result = await db.execute(stmt) + task = result.scalar_one() return task async def get_tasks( @@ -61,6 +68,12 @@ class TaskService(BaseService[Task]): stmt = stmt.where(Task.status == status) stmt = stmt.order_by(Task.created_at.desc()) + + # 加载关联关系 + stmt = stmt.options( + selectinload(Task.course_links).selectinload(TaskCourse.course), + selectinload(Task.assignments) + ) # 获取总数 count_stmt = select(func.count()).select_from(Task).where(Task.is_deleted == False) @@ -71,7 +84,7 @@ class TaskService(BaseService[Task]): # 分页 stmt = stmt.offset((page - 1) * page_size).limit(page_size) result = await db.execute(stmt) - tasks = result.scalars().all() + tasks = result.unique().scalars().all() return tasks, total