Files
012-kaopeilian/backend/app/services/base_service.py
111 998211c483 feat: 初始化考培练系统项目
- 从服务器拉取完整代码
- 按框架规范整理项目结构
- 配置 Drone CI 测试环境部署
- 包含后端(FastAPI)、前端(Vue3)、管理端

技术栈: Vue3 + TypeScript + FastAPI + MySQL
2026-01-24 19:33:28 +08:00

113 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""基础服务类"""
from typing import TypeVar, Generic, Type, Optional, List, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from sqlalchemy.sql import Select
from app.models.base import BaseModel
# Type variable for the model class a service is bound to; restricted to
# subclasses of BaseModel so the generic service stays type-safe per model.
ModelType = TypeVar("ModelType", bound=BaseModel)
class BaseService(Generic[ModelType]):
    """Generic base service providing common CRUD operations for one model.

    An instance is bound to a single model class; every query helper below
    targets that class. All write methods (``create``, ``update``,
    ``delete``, ``soft_delete``) commit on the session they are given.
    """

    def __init__(self, model: Type[ModelType]):
        """Bind the service to ``model``, the class all queries operate on."""
        self.model = model

    @staticmethod
    def _to_payload(obj_in: Any, *, exclude_unset: bool = False) -> dict:
        """Return a fresh dict of field values taken from ``obj_in``.

        Objects exposing ``model_dump`` are dumped through it (passing
        ``exclude_unset=True`` only when requested); anything else is copied
        with ``dict(...)``. The result is always a new dict, so callers may
        mutate it without touching the object that was passed in.
        """
        if hasattr(obj_in, "model_dump"):
            if exclude_unset:
                return obj_in.model_dump(exclude_unset=True)
            return obj_in.model_dump()
        # Copy instead of aliasing: the payload is updated with extra kwargs
        # afterwards and that must not leak into the caller's dict.
        return dict(obj_in)

    async def get(self, db: AsyncSession, id: int) -> Optional[ModelType]:
        """Fetch a single object by primary key ``id``.

        Args:
            db: Async session used to run the query.
            id: Value compared against ``self.model.id``.

        Returns:
            The matching instance, or ``None`` when no row matches.
        """
        result = await db.execute(select(self.model).where(self.model.id == id))
        return result.scalar_one_or_none()

    async def get_by_id(self, db: AsyncSession, id: int) -> Optional[ModelType]:
        """Alias of :meth:`get`, kept for compatibility with older callers."""
        return await self.get(db, id)

    async def get_multi(
        self,
        db: AsyncSession,
        *,
        skip: int = 0,
        limit: int = 100,
        query: Optional[Select] = None,
    ) -> List[ModelType]:
        """Fetch a page of objects.

        Args:
            db: Async session used to run the query.
            skip: Number of rows to skip (applied as ``OFFSET``).
            limit: Maximum number of rows to return (applied as ``LIMIT``).
            query: Optional pre-built ``Select``; when omitted, a plain
                ``select(self.model)`` is used. Offset and limit are applied
                on top of it.

        Returns:
            A list with the scalar (first-column) value of each row.
        """
        stmt = select(self.model) if query is None else query
        result = await db.execute(stmt.offset(skip).limit(limit))
        # ``all()`` is typed as a Sequence; materialise a real list so the
        # return value matches the declared ``List[ModelType]``.
        return list(result.scalars().all())

    async def count(self, db: AsyncSession, *, query: Optional[Select] = None) -> int:
        """Count rows.

        Args:
            db: Async session used to run the query.
            query: Optional ``Select`` whose result rows should be counted.
                When omitted, all rows of ``self.model`` are counted.

        Returns:
            The row count.
        """
        if query is None:
            count_stmt = select(func.count()).select_from(self.model)
        else:
            # Wrap the caller's query as a subquery and count its rows.
            count_stmt = select(func.count()).select_from(query.subquery())
        result = await db.execute(count_stmt)
        return result.scalar_one()

    async def create(self, db: AsyncSession, *, obj_in: Any, **kwargs) -> ModelType:
        """Create and persist a new object.

        Args:
            db: Async session used for the insert; committed by this method.
            obj_in: Source of field values: an object exposing
                ``model_dump()`` or a mapping of field names to values.
                It is never modified.
            **kwargs: Extra field values merged over those from ``obj_in``.

        Returns:
            The newly created instance, refreshed from the database.
        """
        create_data = self._to_payload(obj_in)
        # Merge extra keyword arguments; they override values from obj_in.
        create_data.update(kwargs)

        db_obj = self.model(**create_data)
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def update(
        self, db: AsyncSession, *, db_obj: ModelType, obj_in: Any, **kwargs
    ) -> ModelType:
        """Update an existing object and persist the change.

        Args:
            db: Async session used for the update; committed by this method.
            db_obj: The instance to modify.
            obj_in: Source of new values: an object exposing
                ``model_dump(exclude_unset=True)`` or a mapping of field
                names to values. It is never modified.
            **kwargs: Extra field values merged over those from ``obj_in``
                (e.g. audit fields such as ``updated_by``).

        Returns:
            ``db_obj`` after commit, refreshed from the database.
        """
        update_data = self._to_payload(obj_in, exclude_unset=True)
        # Merge extra keyword arguments (e.g. audit fields like updated_by).
        update_data.update(kwargs)

        for field, value in update_data.items():
            setattr(db_obj, field, value)

        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    async def delete(self, db: AsyncSession, *, id: int) -> bool:
        """Delete the object with the given ``id``.

        Returns:
            ``True`` if an object was found and deleted, ``False`` if no
            object with that id exists.
        """
        obj = await self.get(db, id)
        # Explicit None check: an instance is "found" regardless of how its
        # class defines truthiness.
        if obj is None:
            return False
        await db.delete(obj)
        await db.commit()
        return True

    async def soft_delete(self, db: AsyncSession, *, id: int) -> bool:
        """Mark the object with the given ``id`` as deleted without removing it.

        Sets ``is_deleted = True`` and, when the object has a ``deleted_at``
        attribute, stamps it with the current time.

        Returns:
            ``True`` if the object was found and flagged; ``False`` if it
            does not exist or has no ``is_deleted`` attribute.
        """
        from datetime import datetime

        obj = await self.get(db, id)
        if obj is None or not hasattr(obj, "is_deleted"):
            return False

        obj.is_deleted = True
        if hasattr(obj, "deleted_at"):
            # NOTE(review): naive local time, as in the original code —
            # confirm the column is not expected to hold UTC/aware values.
            obj.deleted_at = datetime.now()
        db.add(obj)
        await db.commit()
        return True