From a2429329df3c463c50a692c54ffbc1f92d1b63f0 Mon Sep 17 00:00:00 2001 From: yuliang_guo Date: Fri, 30 Jan 2026 15:02:09 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20notification=5Fserv?= =?UTF-8?q?ice=20=E5=AF=BC=E5=85=A5=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加 NotificationServiceAdapter 适配器类,兼容 API 层调用方式 导出 notification_service 单例实例 --- backend/app/services/notification_service.py | 183 +++++++++++++++++++ 1 file changed, 183 insertions(+) diff --git a/backend/app/services/notification_service.py b/backend/app/services/notification_service.py index 3cec0b6..1d6cf41 100644 --- a/backend/app/services/notification_service.py +++ b/backend/app/services/notification_service.py @@ -417,3 +417,186 @@ class NotificationService: def get_notification_service(db: AsyncSession) -> NotificationService: """获取通知服务实例""" return NotificationService(db) + + +class NotificationServiceAdapter: + """ + 通知服务适配器 + + 提供静态方法接口,兼容 API 层的调用方式 + 每个方法接受 db 参数,内部创建 NotificationService 实例 + """ + + @staticmethod + async def get_user_notifications( + db: AsyncSession, + user_id: int, + skip: int = 0, + limit: int = 20, + is_read: Optional[bool] = None, + notification_type: Optional[str] = None, + ) -> tuple[List[Notification], int, int]: + """获取用户通知列表""" + from sqlalchemy import func + + # 构建查询条件 + conditions = [Notification.user_id == user_id] + if is_read is not None: + conditions.append(Notification.is_read == is_read) + if notification_type: + conditions.append(Notification.type == notification_type) + + # 查询总数 + total_stmt = select(func.count(Notification.id)).where(and_(*conditions)) + total = await db.scalar(total_stmt) or 0 + + # 查询未读数 + unread_stmt = select(func.count(Notification.id)).where( + and_(Notification.user_id == user_id, Notification.is_read == False) + ) + unread_count = await db.scalar(unread_stmt) or 0 + + # 查询列表 + list_stmt = ( + select(Notification) + .where(and_(*conditions)) + .order_by(Notification.created_at.desc()) + .offset(skip) + .limit(limit) + ) + result = await db.execute(list_stmt) + notifications = list(result.scalars().all()) + + return notifications, total, unread_count + + @staticmethod + async def get_unread_count( + db: AsyncSession, + user_id: int, + ) -> tuple[int, int]: + """获取未读通知数量""" + from sqlalchemy import func + + # 未读数 + unread_stmt = select(func.count(Notification.id)).where( + and_(Notification.user_id == user_id, Notification.is_read == False) + ) + unread_count = await db.scalar(unread_stmt) or 0 + + # 总数 + total_stmt = select(func.count(Notification.id)).where( + Notification.user_id == user_id + ) + total = await db.scalar(total_stmt) or 0 + + return unread_count, total + + @staticmethod + async def mark_as_read( + db: AsyncSession, + user_id: int, + notification_ids: Optional[List[int]] = None, + ) -> int: + """标记通知为已读""" + from sqlalchemy import update + + if notification_ids: + # 标记指定通知 + stmt = ( + update(Notification) + .where( + and_( + Notification.user_id == user_id, + Notification.id.in_(notification_ids), + Notification.is_read == False + ) + ) + .values(is_read=True, updated_at=datetime.now()) + ) + else: + # 标记所有未读 + stmt = ( + update(Notification) + .where( + and_( + Notification.user_id == user_id, + Notification.is_read == False + ) + ) + .values(is_read=True, updated_at=datetime.now()) + ) + + result = await db.execute(stmt) + await db.commit() + return result.rowcount + + @staticmethod + async def delete_notification( + db: AsyncSession, + user_id: int, + notification_id: int, + ) -> bool: + """删除通知""" + from sqlalchemy import delete + + stmt = delete(Notification).where( + and_( + Notification.id == notification_id, + Notification.user_id == user_id + ) + ) + result = await db.execute(stmt) + await db.commit() + return result.rowcount > 0 + + @staticmethod + async def create_notification( + db: AsyncSession, + notification_in: Any, + ) -> Notification: + """创建单条通知""" + notification = Notification( + user_id=notification_in.user_id, + title=notification_in.title, + content=notification_in.content, + type=getattr(notification_in, 'type', 'system'), + related_id=getattr(notification_in, 'related_id', None), + related_type=getattr(notification_in, 'related_type', None), + sender_id=getattr(notification_in, 'sender_id', None), + is_read=False, + ) + db.add(notification) + await db.commit() + await db.refresh(notification) + return notification + + @staticmethod + async def batch_create_notifications( + db: AsyncSession, + batch_in: Any, + ) -> List[Notification]: + """批量创建通知""" + notifications = [] + for user_id in batch_in.user_ids: + notification = Notification( + user_id=user_id, + title=batch_in.title, + content=batch_in.content, + type=getattr(batch_in, 'type', 'system'), + related_id=getattr(batch_in, 'related_id', None), + related_type=getattr(batch_in, 'related_type', None), + sender_id=getattr(batch_in, 'sender_id', None), + is_read=False, + ) + db.add(notification) + notifications.append(notification) + + await db.commit() + for n in notifications: + await db.refresh(n) + + return notifications + + +# 全局单例适配器 +notification_service = NotificationServiceAdapter()