fix: 修复 notification_service 导入错误
All checks were successful
continuous-integration/drone/push Build is passing

添加 NotificationServiceAdapter 适配器类,兼容 API 层调用方式
导出 notification_service 单例实例
This commit is contained in:
yuliang_guo
2026-01-30 15:02:09 +08:00
parent b1e6ca20fd
commit a2429329df

View File

@@ -417,3 +417,186 @@ class NotificationService:
def get_notification_service(db: AsyncSession) -> NotificationService: def get_notification_service(db: AsyncSession) -> NotificationService:
"""获取通知服务实例""" """获取通知服务实例"""
return NotificationService(db) return NotificationService(db)
class NotificationServiceAdapter:
"""
通知服务适配器
提供静态方法接口,兼容 API 层的调用方式
每个方法接受 db 参数,内部创建 NotificationService 实例
"""
@staticmethod
async def get_user_notifications(
db: AsyncSession,
user_id: int,
skip: int = 0,
limit: int = 20,
is_read: Optional[bool] = None,
notification_type: Optional[str] = None,
) -> tuple[List[Notification], int, int]:
"""获取用户通知列表"""
from sqlalchemy import func
# 构建查询条件
conditions = [Notification.user_id == user_id]
if is_read is not None:
conditions.append(Notification.is_read == is_read)
if notification_type:
conditions.append(Notification.type == notification_type)
# 查询总数
total_stmt = select(func.count(Notification.id)).where(and_(*conditions))
total = await db.scalar(total_stmt) or 0
# 查询未读数
unread_stmt = select(func.count(Notification.id)).where(
and_(Notification.user_id == user_id, Notification.is_read == False)
)
unread_count = await db.scalar(unread_stmt) or 0
# 查询列表
list_stmt = (
select(Notification)
.where(and_(*conditions))
.order_by(Notification.created_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(list_stmt)
notifications = list(result.scalars().all())
return notifications, total, unread_count
@staticmethod
async def get_unread_count(
db: AsyncSession,
user_id: int,
) -> tuple[int, int]:
"""获取未读通知数量"""
from sqlalchemy import func
# 未读数
unread_stmt = select(func.count(Notification.id)).where(
and_(Notification.user_id == user_id, Notification.is_read == False)
)
unread_count = await db.scalar(unread_stmt) or 0
# 总数
total_stmt = select(func.count(Notification.id)).where(
Notification.user_id == user_id
)
total = await db.scalar(total_stmt) or 0
return unread_count, total
@staticmethod
async def mark_as_read(
db: AsyncSession,
user_id: int,
notification_ids: Optional[List[int]] = None,
) -> int:
"""标记通知为已读"""
from sqlalchemy import update
if notification_ids:
# 标记指定通知
stmt = (
update(Notification)
.where(
and_(
Notification.user_id == user_id,
Notification.id.in_(notification_ids),
Notification.is_read == False
)
)
.values(is_read=True, updated_at=datetime.now())
)
else:
# 标记所有未读
stmt = (
update(Notification)
.where(
and_(
Notification.user_id == user_id,
Notification.is_read == False
)
)
.values(is_read=True, updated_at=datetime.now())
)
result = await db.execute(stmt)
await db.commit()
return result.rowcount
@staticmethod
async def delete_notification(
db: AsyncSession,
user_id: int,
notification_id: int,
) -> bool:
"""删除通知"""
from sqlalchemy import delete
stmt = delete(Notification).where(
and_(
Notification.id == notification_id,
Notification.user_id == user_id
)
)
result = await db.execute(stmt)
await db.commit()
return result.rowcount > 0
@staticmethod
async def create_notification(
db: AsyncSession,
notification_in: Any,
) -> Notification:
"""创建单条通知"""
notification = Notification(
user_id=notification_in.user_id,
title=notification_in.title,
content=notification_in.content,
type=getattr(notification_in, 'type', 'system'),
related_id=getattr(notification_in, 'related_id', None),
related_type=getattr(notification_in, 'related_type', None),
sender_id=getattr(notification_in, 'sender_id', None),
is_read=False,
)
db.add(notification)
await db.commit()
await db.refresh(notification)
return notification
@staticmethod
async def batch_create_notifications(
db: AsyncSession,
batch_in: Any,
) -> List[Notification]:
"""批量创建通知"""
notifications = []
for user_id in batch_in.user_ids:
notification = Notification(
user_id=user_id,
title=batch_in.title,
content=batch_in.content,
type=getattr(batch_in, 'type', 'system'),
related_id=getattr(batch_in, 'related_id', None),
related_type=getattr(batch_in, 'related_type', None),
sender_id=getattr(batch_in, 'sender_id', None),
is_read=False,
)
db.add(notification)
notifications.append(notification)
await db.commit()
for n in notifications:
await db.refresh(n)
return notifications
# 全局单例适配器
notification_service = NotificationServiceAdapter()