""" 站内消息通知服务 提供通知的CRUD操作和业务逻辑 """ from typing import List, Optional, Tuple from sqlalchemy import select, and_, desc, func, update from sqlalchemy.orm import selectinload from sqlalchemy.ext.asyncio import AsyncSession from app.core.logger import get_logger from app.models.notification import Notification from app.models.user import User from app.schemas.notification import ( NotificationCreate, NotificationBatchCreate, NotificationResponse, NotificationType, ) from app.services.base_service import BaseService logger = get_logger(__name__) class NotificationService(BaseService[Notification]): """ 站内消息通知服务 提供通知的创建、查询、标记已读等功能 """ def __init__(self): super().__init__(Notification) async def create_notification( self, db: AsyncSession, notification_in: NotificationCreate ) -> Notification: """ 创建单个通知 Args: db: 数据库会话 notification_in: 通知创建数据 Returns: 创建的通知对象 """ notification = Notification( user_id=notification_in.user_id, title=notification_in.title, content=notification_in.content, type=notification_in.type.value if isinstance(notification_in.type, NotificationType) else notification_in.type, related_id=notification_in.related_id, related_type=notification_in.related_type, sender_id=notification_in.sender_id, is_read=False ) db.add(notification) await db.commit() await db.refresh(notification) logger.info( "创建通知成功", notification_id=notification.id, user_id=notification_in.user_id, type=notification_in.type ) return notification async def batch_create_notifications( self, db: AsyncSession, batch_in: NotificationBatchCreate ) -> List[Notification]: """ 批量创建通知(发送给多个用户) Args: db: 数据库会话 batch_in: 批量通知创建数据 Returns: 创建的通知列表 """ notifications = [] notification_type = batch_in.type.value if isinstance(batch_in.type, NotificationType) else batch_in.type for user_id in batch_in.user_ids: notification = Notification( user_id=user_id, title=batch_in.title, content=batch_in.content, type=notification_type, related_id=batch_in.related_id, related_type=batch_in.related_type, sender_id=batch_in.sender_id, is_read=False ) notifications.append(notification) db.add(notification) await db.commit() # 刷新所有对象 for notification in notifications: await db.refresh(notification) logger.info( "批量创建通知成功", count=len(notifications), user_ids=batch_in.user_ids, type=batch_in.type ) return notifications async def get_user_notifications( self, db: AsyncSession, user_id: int, skip: int = 0, limit: int = 20, is_read: Optional[bool] = None, notification_type: Optional[str] = None ) -> Tuple[List[NotificationResponse], int, int]: """ 获取用户的通知列表 Args: db: 数据库会话 user_id: 用户ID skip: 跳过数量 limit: 返回数量 is_read: 是否已读筛选 notification_type: 通知类型筛选 Returns: (通知列表, 总数, 未读数) """ # 构建基础查询条件 conditions = [Notification.user_id == user_id] if is_read is not None: conditions.append(Notification.is_read == is_read) if notification_type: conditions.append(Notification.type == notification_type) # 查询通知列表(带发送者信息) stmt = ( select(Notification) .where(and_(*conditions)) .order_by(desc(Notification.created_at)) .offset(skip) .limit(limit) ) result = await db.execute(stmt) notifications = result.scalars().all() # 统计总数 count_stmt = select(func.count()).select_from(Notification).where(and_(*conditions)) total_result = await db.execute(count_stmt) total = total_result.scalar_one() # 统计未读数 unread_stmt = ( select(func.count()) .select_from(Notification) .where(and_(Notification.user_id == user_id, Notification.is_read == False)) ) unread_result = await db.execute(unread_stmt) unread_count = unread_result.scalar_one() # 获取发送者信息 sender_ids = [n.sender_id for n in notifications if n.sender_id] sender_names = {} if sender_ids: sender_stmt = select(User.id, User.full_name).where(User.id.in_(sender_ids)) sender_result = await db.execute(sender_stmt) sender_names = {row[0]: row[1] for row in sender_result.fetchall()} # 构建响应 responses = [] for notification in notifications: response = NotificationResponse( id=notification.id, user_id=notification.user_id, title=notification.title, content=notification.content, type=notification.type, is_read=notification.is_read, related_id=notification.related_id, related_type=notification.related_type, sender_id=notification.sender_id, sender_name=sender_names.get(notification.sender_id) if notification.sender_id else None, created_at=notification.created_at, updated_at=notification.updated_at ) responses.append(response) return responses, total, unread_count async def get_unread_count( self, db: AsyncSession, user_id: int ) -> Tuple[int, int]: """ 获取用户未读通知数量 Args: db: 数据库会话 user_id: 用户ID Returns: (未读数, 总数) """ # 统计未读数 unread_stmt = ( select(func.count()) .select_from(Notification) .where(and_(Notification.user_id == user_id, Notification.is_read == False)) ) unread_result = await db.execute(unread_stmt) unread_count = unread_result.scalar_one() # 统计总数 total_stmt = ( select(func.count()) .select_from(Notification) .where(Notification.user_id == user_id) ) total_result = await db.execute(total_stmt) total = total_result.scalar_one() return unread_count, total async def mark_as_read( self, db: AsyncSession, user_id: int, notification_ids: Optional[List[int]] = None ) -> int: """ 标记通知为已读 Args: db: 数据库会话 user_id: 用户ID notification_ids: 通知ID列表,为空则标记全部 Returns: 更新的数量 """ conditions = [ Notification.user_id == user_id, Notification.is_read == False ] if notification_ids: conditions.append(Notification.id.in_(notification_ids)) stmt = ( update(Notification) .where(and_(*conditions)) .values(is_read=True) ) result = await db.execute(stmt) await db.commit() updated_count = result.rowcount logger.info( "标记通知已读", user_id=user_id, notification_ids=notification_ids, updated_count=updated_count ) return updated_count async def delete_notification( self, db: AsyncSession, user_id: int, notification_id: int ) -> bool: """ 删除通知 Args: db: 数据库会话 user_id: 用户ID notification_id: 通知ID Returns: 是否删除成功 """ stmt = select(Notification).where( and_( Notification.id == notification_id, Notification.user_id == user_id ) ) result = await db.execute(stmt) notification = result.scalar_one_or_none() if notification: await db.delete(notification) await db.commit() logger.info( "删除通知成功", notification_id=notification_id, user_id=user_id ) return True return False # 创建服务实例 notification_service = NotificationService()