"""Project category routes.

Implements CRUD operations for categories, including a tree view of the
parent/child hierarchy.
"""

from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.database import get_db
from app.models.category import Category
from app.schemas.common import ResponseModel, PaginatedResponse, PaginatedData, ErrorCode
from app.schemas.category import (
    CategoryCreate,
    CategoryUpdate,
    CategoryResponse,
    CategoryTreeResponse,
)

router = APIRouter()

# NOTE: FastAPI publishes an endpoint's docstring as its OpenAPI description
# unless `description=` is given. Each route below passes the original
# description text explicitly so the published API docs stay unchanged while
# the docstrings carry maintainer-facing detail.


def _apply_category_filters(stmt, parent_id: Optional[int], is_active: Optional[bool]):
    """Return *stmt* narrowed by the optional list filters.

    Shared by the page query and the count query so both always apply the
    same conditions. A filter whose value is ``None`` is not applied.
    """
    if parent_id is not None:
        stmt = stmt.where(Category.parent_id == parent_id)
    if is_active is not None:
        stmt = stmt.where(Category.is_active == is_active)
    return stmt


async def _get_category_or_404(db: AsyncSession, category_id: int) -> Category:
    """Load the category with *category_id* or raise a 404 ``HTTPException``."""
    result = await db.execute(select(Category).where(Category.id == category_id))
    category = result.scalar_one_or_none()
    if category is None:
        raise HTTPException(
            status_code=404,
            detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"},
        )
    return category


async def _ensure_parent_exists(db: AsyncSession, parent_id: int) -> None:
    """Raise a 400 ``HTTPException`` when no category has id *parent_id*."""
    result = await db.execute(select(Category.id).where(Category.id == parent_id))
    if result.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=400,
            detail={"code": ErrorCode.PARAM_ERROR, "message": "父分类不存在"},
        )


async def _ensure_not_descendant(
    db: AsyncSession, category_id: int, new_parent_id: int
) -> None:
    """Raise a 400 ``HTTPException`` if *new_parent_id* is under *category_id*.

    Walks from the proposed parent up through its ancestors. Meeting
    *category_id* on the way means the move would make the category its own
    ancestor, i.e. create a cycle.
    """
    seen = set()  # guards against looping forever on already-cyclic data
    current: Optional[int] = new_parent_id
    while current is not None and current not in seen:
        if current == category_id:
            raise HTTPException(
                status_code=400,
                detail={
                    "code": ErrorCode.PARAM_ERROR,
                    "message": "不能将子分类设为父分类",
                },
            )
        seen.add(current)
        result = await db.execute(
            select(Category.parent_id).where(Category.id == current)
        )
        # None both for "no such row" and for "row has no parent"; either
        # way the walk is finished.
        current = result.scalar_one_or_none()


@router.get(
    "",
    response_model=PaginatedResponse[CategoryResponse],
    description="获取项目分类列表",
)
async def get_categories(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    parent_id: Optional[int] = Query(None, description="父分类ID筛选"),
    is_active: Optional[bool] = Query(None, description="是否启用筛选"),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of categories, ordered by ``sort_order`` then ``id``.

    Args:
        page: 1-based page number.
        page_size: Number of items per page (1-100).
        parent_id: When given, only categories with this parent are returned.
        is_active: When given, only categories with this active flag are
            returned.
        db: Database session.

    Returns:
        A ``PaginatedResponse`` holding the page items, the total number of
        matching rows and the derived page count.
    """
    page_query = (
        _apply_category_filters(select(Category), parent_id, is_active)
        .order_by(Category.sort_order, Category.id)
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    categories = (await db.execute(page_query)).scalars().all()

    # The total is counted with the same filters but without offset/limit.
    count_query = _apply_category_filters(
        select(func.count(Category.id)), parent_id, is_active
    )
    total = (await db.execute(count_query)).scalar() or 0

    return PaginatedResponse(
        data=PaginatedData(
            items=[CategoryResponse.model_validate(c) for c in categories],
            total=total,
            page=page,
            page_size=page_size,
            # Ceiling division; page_size >= 1 is enforced by Query above.
            total_pages=(total + page_size - 1) // page_size,
        )
    )


@router.get(
    "/tree",
    response_model=ResponseModel[List[CategoryTreeResponse]],
    description="获取分类树形结构",
)
async def get_category_tree(
    is_active: Optional[bool] = Query(True, description="是否只返回启用的分类"),
    db: AsyncSession = Depends(get_db),
):
    """Return the top-level categories with their children loaded.

    Args:
        is_active: When not ``None``, only top-level categories with this
            active flag are selected. Defaults to ``True``.
        db: Database session.

    Returns:
        A ``ResponseModel`` whose data is a list of ``CategoryTreeResponse``,
        one per top-level category, ordered by ``sort_order`` then ``id``.
    """
    query = select(Category).options(selectinload(Category.children))

    # NOTE(review): this filter is applied to the selected (root) rows only;
    # children arrive through the relationship and are not filtered here.
    # Confirm whether inactive children are meant to appear in the tree.
    if is_active is not None:
        query = query.where(Category.is_active == is_active)

    # Roots only: a top-level category has no parent.
    query = query.where(Category.parent_id.is_(None))
    query = query.order_by(Category.sort_order, Category.id)

    result = await db.execute(query)
    categories = result.scalars().all()

    # NOTE(review): selectinload above covers one level of `children`. Whether
    # deeper levels are available during validation depends on how
    # Category.children is configured on the model - confirm for trees deeper
    # than two levels.
    return ResponseModel(
        data=[CategoryTreeResponse.model_validate(c) for c in categories]
    )


@router.get(
    "/{category_id}",
    response_model=ResponseModel[CategoryResponse],
    description="获取单个分类详情",
)
async def get_category(
    category_id: int,
    db: AsyncSession = Depends(get_db),
):
    """Return a single category by id.

    Args:
        category_id: Id of the category to fetch.
        db: Database session.

    Raises:
        HTTPException: 404 when the category does not exist.
    """
    category = await _get_category_or_404(db, category_id)
    return ResponseModel(data=CategoryResponse.model_validate(category))


@router.post(
    "",
    response_model=ResponseModel[CategoryResponse],
    description="创建项目分类",
)
async def create_category(
    data: CategoryCreate,
    db: AsyncSession = Depends(get_db),
):
    """Create a category.

    Args:
        data: Fields for the new category.
        db: Database session.

    Raises:
        HTTPException: 400 when ``data.parent_id`` is given but no such
            category exists.
    """
    # `is not None` rather than truthiness, matching update_category: an id
    # of 0 must be validated too instead of silently skipping the check.
    if data.parent_id is not None:
        await _ensure_parent_exists(db, data.parent_id)

    category = Category(**data.model_dump())
    db.add(category)
    # Flush to send the INSERT, then refresh so database-generated values
    # are loaded before the response is serialized.
    await db.flush()
    await db.refresh(category)

    return ResponseModel(
        message="创建成功", data=CategoryResponse.model_validate(category)
    )


@router.put(
    "/{category_id}",
    response_model=ResponseModel[CategoryResponse],
    description="更新项目分类",
)
async def update_category(
    category_id: int,
    data: CategoryUpdate,
    db: AsyncSession = Depends(get_db),
):
    """Update a category with the fields explicitly set in *data*.

    Args:
        category_id: Id of the category to update.
        data: Partial update; fields not set by the client are left alone.
        db: Database session.

    Raises:
        HTTPException: 404 when the category does not exist; 400 when the new
            parent is the category itself, does not exist, or is one of the
            category's own descendants.
    """
    category = await _get_category_or_404(db, category_id)

    # Validate the parent only when it is actually changing to another id.
    if data.parent_id is not None and data.parent_id != category.parent_id:
        if data.parent_id == category_id:
            raise HTTPException(
                status_code=400,
                detail={
                    "code": ErrorCode.PARAM_ERROR,
                    "message": "不能将自己设为父分类",
                },
            )
        await _ensure_parent_exists(db, data.parent_id)
        # Moving a category beneath its own descendant would create a cycle
        # and cut the whole subtree off from every root.
        await _ensure_not_descendant(db, category_id, data.parent_id)

    for field, value in data.model_dump(exclude_unset=True).items():
        setattr(category, field, value)

    await db.flush()
    await db.refresh(category)

    return ResponseModel(
        message="更新成功", data=CategoryResponse.model_validate(category)
    )


@router.delete(
    "/{category_id}",
    response_model=ResponseModel,
    description="删除项目分类",
)
async def delete_category(
    category_id: int,
    db: AsyncSession = Depends(get_db),
):
    """Delete a category that has no child categories.

    Args:
        category_id: Id of the category to delete.
        db: Database session.

    Raises:
        HTTPException: 404 when the category does not exist; 400 when it
            still has child categories.
    """
    category = await _get_category_or_404(db, category_id)

    children_result = await db.execute(
        select(func.count(Category.id)).where(Category.parent_id == category_id)
    )
    if (children_result.scalar() or 0) > 0:
        raise HTTPException(
            status_code=400,
            detail={
                "code": ErrorCode.NOT_ALLOWED,
                "message": "该分类下有子分类,无法删除",
            },
        )

    await db.delete(category)
    # Flush like create/update do, so a database error caused by the DELETE
    # surfaces inside this handler.
    await db.flush()

    return ResponseModel(message="删除成功")