- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
113 lines
3.4 KiB
Python
113 lines
3.4 KiB
Python
"""基础服务类"""
|
||
from typing import TypeVar, Generic, Type, Optional, List, Any
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy import select, func
|
||
from sqlalchemy.sql import Select
|
||
|
||
from app.models.base import BaseModel
|
||
|
||
# Type variable for the model class a BaseService is parameterized with;
# bound to the project's BaseModel so only BaseModel subclasses are accepted.
ModelType = TypeVar("ModelType", bound=BaseModel)
|
||
|
||
|
||
class BaseService(Generic[ModelType]):
    """Base service class providing generic async CRUD operations.

    A service is constructed with a model class and exposes ``get``,
    ``get_multi``, ``count``, ``create``, ``update``, ``delete`` and
    ``soft_delete``. Every write method commits the given session itself.
    """

    def __init__(self, model: Type[ModelType]):
        # The model class every query in this service targets.
        self.model = model

    @staticmethod
    def _input_to_dict(obj_in: Any, **dump_kwargs: Any) -> dict:
        """Return a fresh dict of field values taken from ``obj_in``.

        Objects exposing ``model_dump`` (Pydantic v2 style) are dumped with
        ``dump_kwargs``; anything else is copied with ``dict(...)``. Callers may
        therefore mutate the result without touching the object they received.
        """
        if hasattr(obj_in, "model_dump"):
            return obj_in.model_dump(**dump_kwargs)
        return dict(obj_in)

    async def get(self, db: AsyncSession, id: int) -> Optional[ModelType]:
        """Return the object whose ``id`` column equals ``id``, or ``None``.

        Requires the model to expose an ``id`` column. No soft-delete filter is
        applied, so objects flagged by ``soft_delete`` are still returned.
        """
        result = await db.execute(select(self.model).where(self.model.id == id))
        return result.scalar_one_or_none()

    async def get_by_id(self, db: AsyncSession, id: int) -> Optional[ModelType]:
        """Alias of ``get``, kept for compatibility with older callers."""
        return await self.get(db, id)

    async def get_multi(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        query: Optional[Select] = None,
    ) -> List[ModelType]:
        """Return one page of objects.

        Args:
            db: Session used to run the query.
            skip: Number of rows skipped (applied as ``OFFSET``).
            limit: Maximum number of rows returned (applied as ``LIMIT``).
            query: Optional pre-built ``SELECT``; defaults to selecting every
                row of the model.

        Returns:
            A list with the first column of each result row.
        """
        if query is None:
            query = select(self.model)

        result = await db.execute(query.offset(skip).limit(limit))
        # ScalarResult.all() is typed as a Sequence; materialize a real list to
        # match the declared return type.
        return list(result.scalars().all())

    async def count(self, db: AsyncSession, *, query: Optional[Select] = None) -> int:
        """Count rows.

        Without ``query`` this counts every row of the model's table (no
        soft-delete filter). With ``query``, it counts the rows that query
        produces by wrapping it as a subquery, so its filters and joins apply.
        """
        if query is None:
            query = select(func.count()).select_from(self.model)
        else:
            query = select(func.count()).select_from(query.subquery())

        result = await db.execute(query)
        return result.scalar_one()

    async def create(self, db: AsyncSession, *, obj_in: Any, **kwargs) -> ModelType:
        """Create, commit and return a new object.

        Args:
            db: Session the object is added to; it is committed here.
            obj_in: An object with ``model_dump`` or a mapping of field values.
            **kwargs: Extra field values; they override same-named keys
                from ``obj_in``.

        Returns:
            The new instance, refreshed from the database after the commit.
        """
        # Always work on a copy so merging kwargs never mutates a dict the
        # caller passed in as obj_in.
        create_data = self._input_to_dict(obj_in)
        create_data.update(kwargs)

        db_obj = self.model(**create_data)
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self, db: AsyncSession, *, db_obj: ModelType, obj_in: Any, **kwargs
    ) -> ModelType:
        """Apply field values to ``db_obj``, commit, and return it refreshed.

        Args:
            db: Session used to persist the change; it is committed here.
            db_obj: The instance to modify in place.
            obj_in: An object with ``model_dump`` (only explicitly set fields
                are used, via ``exclude_unset=True``) or a mapping of values.
            **kwargs: Extra values such as audit fields (e.g. ``updated_by``);
                they override same-named keys from ``obj_in``.

        Returns:
            ``db_obj`` after the commit and refresh.
        """
        # Copy first so the caller's dict is never modified by the merge below.
        update_data = self._input_to_dict(obj_in, exclude_unset=True)
        update_data.update(kwargs)

        for field, value in update_data.items():
            setattr(db_obj, field, value)

        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def delete(self, db: AsyncSession, *, id: int) -> bool:
        """Hard-delete the object with the given ``id`` and commit.

        Returns:
            ``True`` if the object existed and was deleted, else ``False``.
        """
        obj = await self.get(db, id)
        if obj:
            await db.delete(obj)
            await db.commit()
            return True
        return False

    async def soft_delete(self, db: AsyncSession, *, id: int) -> bool:
        """Mark the object with the given ``id`` as deleted instead of removing it.

        Sets ``is_deleted = True`` and, when the model has a ``deleted_at``
        attribute, stamps it with the current time, then commits.

        Returns:
            ``True`` if the object exists and has an ``is_deleted`` attribute;
            ``False`` otherwise, in which case nothing is changed or committed.
        """
        from datetime import datetime

        obj = await self.get(db, id)
        if obj and hasattr(obj, "is_deleted"):
            obj.is_deleted = True
            if hasattr(obj, "deleted_at"):
                # NOTE(review): datetime.now() is naive local time — confirm it
                # matches how deleted_at is stored elsewhere.
                obj.deleted_at = datetime.now()
            db.add(obj)
            await db.commit()
            return True
        return False
|