fix: 修复 notification_service 导入错误
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
添加 NotificationServiceAdapter 适配器类,兼容 API 层调用方式 导出 notification_service 单例实例
This commit is contained in:
@@ -417,3 +417,186 @@ class NotificationService:
|
||||
def get_notification_service(db: AsyncSession) -> NotificationService:
|
||||
"""获取通知服务实例"""
|
||||
return NotificationService(db)
|
||||
|
||||
|
||||
class NotificationServiceAdapter:
|
||||
"""
|
||||
通知服务适配器
|
||||
|
||||
提供静态方法接口,兼容 API 层的调用方式
|
||||
每个方法接受 db 参数,内部创建 NotificationService 实例
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def get_user_notifications(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 20,
|
||||
is_read: Optional[bool] = None,
|
||||
notification_type: Optional[str] = None,
|
||||
) -> tuple[List[Notification], int, int]:
|
||||
"""获取用户通知列表"""
|
||||
from sqlalchemy import func
|
||||
|
||||
# 构建查询条件
|
||||
conditions = [Notification.user_id == user_id]
|
||||
if is_read is not None:
|
||||
conditions.append(Notification.is_read == is_read)
|
||||
if notification_type:
|
||||
conditions.append(Notification.type == notification_type)
|
||||
|
||||
# 查询总数
|
||||
total_stmt = select(func.count(Notification.id)).where(and_(*conditions))
|
||||
total = await db.scalar(total_stmt) or 0
|
||||
|
||||
# 查询未读数
|
||||
unread_stmt = select(func.count(Notification.id)).where(
|
||||
and_(Notification.user_id == user_id, Notification.is_read == False)
|
||||
)
|
||||
unread_count = await db.scalar(unread_stmt) or 0
|
||||
|
||||
# 查询列表
|
||||
list_stmt = (
|
||||
select(Notification)
|
||||
.where(and_(*conditions))
|
||||
.order_by(Notification.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(list_stmt)
|
||||
notifications = list(result.scalars().all())
|
||||
|
||||
return notifications, total, unread_count
|
||||
|
||||
@staticmethod
|
||||
async def get_unread_count(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
) -> tuple[int, int]:
|
||||
"""获取未读通知数量"""
|
||||
from sqlalchemy import func
|
||||
|
||||
# 未读数
|
||||
unread_stmt = select(func.count(Notification.id)).where(
|
||||
and_(Notification.user_id == user_id, Notification.is_read == False)
|
||||
)
|
||||
unread_count = await db.scalar(unread_stmt) or 0
|
||||
|
||||
# 总数
|
||||
total_stmt = select(func.count(Notification.id)).where(
|
||||
Notification.user_id == user_id
|
||||
)
|
||||
total = await db.scalar(total_stmt) or 0
|
||||
|
||||
return unread_count, total
|
||||
|
||||
@staticmethod
|
||||
async def mark_as_read(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
notification_ids: Optional[List[int]] = None,
|
||||
) -> int:
|
||||
"""标记通知为已读"""
|
||||
from sqlalchemy import update
|
||||
|
||||
if notification_ids:
|
||||
# 标记指定通知
|
||||
stmt = (
|
||||
update(Notification)
|
||||
.where(
|
||||
and_(
|
||||
Notification.user_id == user_id,
|
||||
Notification.id.in_(notification_ids),
|
||||
Notification.is_read == False
|
||||
)
|
||||
)
|
||||
.values(is_read=True, updated_at=datetime.now())
|
||||
)
|
||||
else:
|
||||
# 标记所有未读
|
||||
stmt = (
|
||||
update(Notification)
|
||||
.where(
|
||||
and_(
|
||||
Notification.user_id == user_id,
|
||||
Notification.is_read == False
|
||||
)
|
||||
)
|
||||
.values(is_read=True, updated_at=datetime.now())
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
return result.rowcount
|
||||
|
||||
@staticmethod
|
||||
async def delete_notification(
|
||||
db: AsyncSession,
|
||||
user_id: int,
|
||||
notification_id: int,
|
||||
) -> bool:
|
||||
"""删除通知"""
|
||||
from sqlalchemy import delete
|
||||
|
||||
stmt = delete(Notification).where(
|
||||
and_(
|
||||
Notification.id == notification_id,
|
||||
Notification.user_id == user_id
|
||||
)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
return result.rowcount > 0
|
||||
|
||||
@staticmethod
|
||||
async def create_notification(
|
||||
db: AsyncSession,
|
||||
notification_in: Any,
|
||||
) -> Notification:
|
||||
"""创建单条通知"""
|
||||
notification = Notification(
|
||||
user_id=notification_in.user_id,
|
||||
title=notification_in.title,
|
||||
content=notification_in.content,
|
||||
type=getattr(notification_in, 'type', 'system'),
|
||||
related_id=getattr(notification_in, 'related_id', None),
|
||||
related_type=getattr(notification_in, 'related_type', None),
|
||||
sender_id=getattr(notification_in, 'sender_id', None),
|
||||
is_read=False,
|
||||
)
|
||||
db.add(notification)
|
||||
await db.commit()
|
||||
await db.refresh(notification)
|
||||
return notification
|
||||
|
||||
@staticmethod
|
||||
async def batch_create_notifications(
|
||||
db: AsyncSession,
|
||||
batch_in: Any,
|
||||
) -> List[Notification]:
|
||||
"""批量创建通知"""
|
||||
notifications = []
|
||||
for user_id in batch_in.user_ids:
|
||||
notification = Notification(
|
||||
user_id=user_id,
|
||||
title=batch_in.title,
|
||||
content=batch_in.content,
|
||||
type=getattr(batch_in, 'type', 'system'),
|
||||
related_id=getattr(batch_in, 'related_id', None),
|
||||
related_type=getattr(batch_in, 'related_type', None),
|
||||
sender_id=getattr(batch_in, 'sender_id', None),
|
||||
is_read=False,
|
||||
)
|
||||
db.add(notification)
|
||||
notifications.append(notification)
|
||||
|
||||
await db.commit()
|
||||
for n in notifications:
|
||||
await db.refresh(n)
|
||||
|
||||
return notifications
|
||||
|
||||
|
||||
# 全局单例适配器
|
||||
notification_service = NotificationServiceAdapter()
|
||||
|
||||
Reference in New Issue
Block a user