fix: 修复任务服务SQLAlchemy异步加载错误
All checks were successful
continuous-integration/drone/push Build is passing

- create_task 和 get_tasks 现在使用 selectinload 预加载关联关系
- 避免懒加载导致的 MissingGreenlet 错误
This commit is contained in:
yuliang_guo
2026-01-31 18:31:07 +08:00
parent eca0ed8c9d
commit fc9775e61f

View File

@@ -5,7 +5,7 @@ from typing import List, Optional
from datetime import datetime from datetime import datetime
from sqlalchemy import select, func, and_, case from sqlalchemy import select, func, and_, case
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload, selectinload
from app.models.task import Task, TaskCourse, TaskAssignment, TaskStatus, AssignmentStatus from app.models.task import Task, TaskCourse, TaskAssignment, TaskStatus, AssignmentStatus
from app.models.course import Course from app.models.course import Course
from app.schemas.task import TaskCreate, TaskUpdate, TaskStatsResponse from app.schemas.task import TaskCreate, TaskUpdate, TaskStatsResponse
@@ -44,7 +44,14 @@ class TaskService(BaseService[Task]):
db.add(assignment) db.add(assignment)
await db.commit() await db.commit()
await db.refresh(task)
# 重新查询并加载关联关系(避免懒加载问题)
stmt = select(Task).where(Task.id == task.id).options(
selectinload(Task.course_links).selectinload(TaskCourse.course),
selectinload(Task.assignments)
)
result = await db.execute(stmt)
task = result.scalar_one()
return task return task
async def get_tasks( async def get_tasks(
@@ -61,6 +68,12 @@ class TaskService(BaseService[Task]):
stmt = stmt.where(Task.status == status) stmt = stmt.where(Task.status == status)
stmt = stmt.order_by(Task.created_at.desc()) stmt = stmt.order_by(Task.created_at.desc())
# 加载关联关系
stmt = stmt.options(
selectinload(Task.course_links).selectinload(TaskCourse.course),
selectinload(Task.assignments)
)
# 获取总数 # 获取总数
count_stmt = select(func.count()).select_from(Task).where(Task.is_deleted == False) count_stmt = select(func.count()).select_from(Task).where(Task.is_deleted == False)
@@ -71,7 +84,7 @@ class TaskService(BaseService[Task]):
# 分页 # 分页
stmt = stmt.offset((page - 1) * page_size).limit(page_size) stmt = stmt.offset((page - 1) * page_size).limit(page_size)
result = await db.execute(stmt) result = await db.execute(stmt)
tasks = result.scalars().all() tasks = result.unique().scalars().all()
return tasks, total return tasks, total