"""Base service class providing generic async CRUD helpers."""

from datetime import datetime
from typing import TypeVar, Generic, Type, Optional, List, Any

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from sqlalchemy.sql import Select

from app.models.base import BaseModel

# Type parameter for the concrete model class a service instance manages.
ModelType = TypeVar("ModelType", bound=BaseModel)


class BaseService(Generic[ModelType]):
    """Generic service exposing common CRUD operations for one model class.

    Every mutating method (create / update / delete / soft_delete) commits
    the session it is given before returning.
    """

    def __init__(self, model: Type[ModelType]):
        # The model class all queries in this service are built from.
        self.model = model

    @staticmethod
    def _obj_in_to_dict(obj_in: Any, **dump_kwargs: Any) -> dict:
        """Return a fresh dict of field values taken from ``obj_in``.

        Objects exposing ``model_dump`` (e.g. Pydantic models) are dumped with
        ``dump_kwargs``; ``None`` yields an empty dict; anything else is copied
        with ``dict()``. A copy is always returned so that merging extra
        keyword fields never mutates a mapping owned by the caller.
        """
        if hasattr(obj_in, "model_dump"):
            return obj_in.model_dump(**dump_kwargs)
        if obj_in is None:
            return {}
        return dict(obj_in)

    async def get(self, db: AsyncSession, id: int) -> Optional[ModelType]:
        """Return the object whose ``id`` column equals ``id``, or ``None``."""
        result = await db.execute(select(self.model).where(self.model.id == id))
        return result.scalar_one_or_none()

    async def get_by_id(self, db: AsyncSession, id: int) -> Optional[ModelType]:
        """Alias of :meth:`get`, kept for compatibility with older callers."""
        return await self.get(db, id)

    async def get_multi(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        query: Optional[Select] = None,
    ) -> List[ModelType]:
        """Return up to ``limit`` objects, skipping the first ``skip``.

        Args:
            db: Session used to run the query.
            skip: Number of rows passed to ``OFFSET``.
            limit: Maximum number of rows passed to ``LIMIT``.
            query: Optional pre-built select (e.g. with filters); defaults to
                selecting every row of ``self.model``. No ordering is added,
                so callers wanting stable pagination should order ``query``.

        Returns:
            A list of model instances.
        """
        if query is None:
            query = select(self.model)
        result = await db.execute(query.offset(skip).limit(limit))
        return list(result.scalars().all())

    async def count(self, db: AsyncSession, *, query: Optional[Select] = None) -> int:
        """Return the number of rows in the table, or produced by ``query``.

        When ``query`` is given it is wrapped as a subquery and its rows are
        counted, so any filters (and any offset/limit) on it are honored.
        """
        if query is None:
            count_query = select(func.count()).select_from(self.model)
        else:
            count_query = select(func.count()).select_from(query.subquery())
        result = await db.execute(count_query)
        return result.scalar_one()

    async def create(self, db: AsyncSession, *, obj_in: Any, **kwargs) -> ModelType:
        """Create, commit and return a new object.

        Args:
            db: Session to add the object to; it is committed.
            obj_in: Field values — an object with ``model_dump`` or a mapping.
            **kwargs: Extra fields merged on top of ``obj_in`` (they win on
                key collisions), e.g. audit fields.

        Returns:
            The new instance, refreshed from the database after commit.
        """
        create_data = self._obj_in_to_dict(obj_in)
        # Merge extra parameters into our own copy, never the caller's mapping.
        create_data.update(kwargs)
        db_obj = self.model(**create_data)
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self, db: AsyncSession, *, db_obj: ModelType, obj_in: Any, **kwargs
    ) -> ModelType:
        """Apply field updates to ``db_obj``, commit and return it.

        Args:
            db: Session used to persist the change; it is committed.
            db_obj: The instance to modify in place.
            obj_in: Field values — an object with ``model_dump`` (dumped with
                ``exclude_unset=True`` so only explicitly set fields apply) or
                a mapping.
            **kwargs: Extra fields (e.g. audit fields such as ``updated_by``)
                merged on top of ``obj_in``.

        Returns:
            ``db_obj``, refreshed from the database after commit.
        """
        update_data = self._obj_in_to_dict(obj_in, exclude_unset=True)
        # Merge extra parameters into our own copy, never the caller's mapping.
        update_data.update(kwargs)
        for field, value in update_data.items():
            setattr(db_obj, field, value)
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def delete(self, db: AsyncSession, *, id: int) -> bool:
        """Hard-delete the object with the given ``id``.

        Returns:
            ``True`` if an object was found and deleted, ``False`` otherwise.
        """
        obj = await self.get(db, id)
        if obj is None:
            return False
        await db.delete(obj)
        await db.commit()
        return True

    async def soft_delete(self, db: AsyncSession, *, id: int) -> bool:
        """Mark the object with the given ``id`` as deleted without removing it.

        Sets ``is_deleted = True`` and, when the model has it, ``deleted_at``.

        Returns:
            ``True`` on success; ``False`` if no object matches ``id`` or the
            model has no ``is_deleted`` attribute.
        """
        obj = await self.get(db, id)
        if obj is None or not hasattr(obj, "is_deleted"):
            return False
        obj.is_deleted = True
        if hasattr(obj, "deleted_at"):
            # Naive local time, as before. NOTE(review): confirm whether the
            # deleted_at column expects timezone-aware values.
            obj.deleted_at = datetime.now()
        db.add(obj)
        await db.commit()
        return True