- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
331 lines
9.6 KiB
Python
331 lines
9.6 KiB
Python
"""
|
||
站内消息通知服务
|
||
提供通知的CRUD操作和业务逻辑
|
||
"""
|
||
from typing import List, Optional, Tuple
|
||
from sqlalchemy import select, and_, desc, func, update
|
||
from sqlalchemy.orm import selectinload
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.logger import get_logger
|
||
from app.models.notification import Notification
|
||
from app.models.user import User
|
||
from app.schemas.notification import (
|
||
NotificationCreate,
|
||
NotificationBatchCreate,
|
||
NotificationResponse,
|
||
NotificationType,
|
||
)
|
||
from app.services.base_service import BaseService
|
||
|
||
# Module-level logger for this service, obtained from the project's
# app.core.logger factory. Call sites pass structured key=value fields.
logger = get_logger(__name__)
|
||
|
||
|
||
class NotificationService(BaseService[Notification]):
    """
    In-app notification service.

    Provides creation, querying, mark-as-read and deletion of notifications.
    Every write method commits on the ``AsyncSession`` it is given.
    """

    def __init__(self):
        super().__init__(Notification)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_type(notification_type):
        """Return the value to store in ``Notification.type``.

        A ``NotificationType`` enum member is unwrapped to its ``.value``.
        Anything else (e.g. an already-raw value) is passed through unchanged.
        """
        if isinstance(notification_type, NotificationType):
            return notification_type.value
        return notification_type

    @staticmethod
    def _unread_conditions(user_id: int):
        """Return the filter conditions selecting one user's unread notifications."""
        # `== False` (not `is False`) is required: SQLAlchemy overloads `==`
        # on columns to build a SQL comparison expression.
        return (
            Notification.user_id == user_id,
            Notification.is_read == False,  # noqa: E712
        )

    @staticmethod
    async def _count(db: AsyncSession, *conditions) -> int:
        """Return the number of notifications matching all ``conditions`` (ANDed)."""
        stmt = select(func.count()).select_from(Notification).where(and_(*conditions))
        result = await db.execute(stmt)
        return result.scalar_one()

    @staticmethod
    async def _get_sender_names(db: AsyncSession, notifications: List[Notification]) -> dict:
        """Map sender user id -> ``User.full_name`` for the given notifications.

        Notifications without a (truthy) ``sender_id`` are skipped. When there
        are no senders, an empty dict is returned without querying.
        """
        # De-duplicate: one sender commonly sends many notifications, and the
        # IN (...) list should not repeat ids. Sorted for a deterministic statement.
        sender_ids = sorted({n.sender_id for n in notifications if n.sender_id})
        if not sender_ids:
            return {}
        stmt = select(User.id, User.full_name).where(User.id.in_(sender_ids))
        result = await db.execute(stmt)
        return {row[0]: row[1] for row in result.fetchall()}

    @staticmethod
    def _to_response(notification: Notification, sender_names: dict) -> NotificationResponse:
        """Build the response model for one notification."""
        return NotificationResponse(
            id=notification.id,
            user_id=notification.user_id,
            title=notification.title,
            content=notification.content,
            type=notification.type,
            is_read=notification.is_read,
            related_id=notification.related_id,
            related_type=notification.related_type,
            sender_id=notification.sender_id,
            # Senders without an id are never in the map, so .get() yields None for them.
            sender_name=sender_names.get(notification.sender_id),
            created_at=notification.created_at,
            updated_at=notification.updated_at,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def create_notification(
        self,
        db: AsyncSession,
        notification_in: NotificationCreate
    ) -> Notification:
        """
        Create a single notification and commit it.

        Args:
            db: database session
            notification_in: notification creation payload

        Returns:
            The created notification, refreshed after commit so
            database-generated fields such as ``id`` are loaded.
        """
        notification = Notification(
            user_id=notification_in.user_id,
            title=notification_in.title,
            content=notification_in.content,
            type=self._normalize_type(notification_in.type),
            related_id=notification_in.related_id,
            related_type=notification_in.related_type,
            sender_id=notification_in.sender_id,
            is_read=False,
        )

        db.add(notification)
        await db.commit()
        # Reload after commit so generated columns (id, timestamps) are populated.
        await db.refresh(notification)

        logger.info(
            "创建通知成功",
            notification_id=notification.id,
            user_id=notification_in.user_id,
            # Log the stored value rather than the raw enum object.
            type=notification.type,
        )

        return notification

    async def batch_create_notifications(
        self,
        db: AsyncSession,
        batch_in: NotificationBatchCreate
    ) -> List[Notification]:
        """
        Create the same notification for several users in one commit.

        Args:
            db: database session
            batch_in: batch payload; one notification is created per entry
                in ``batch_in.user_ids``

        Returns:
            The created notifications, in ``batch_in.user_ids`` order, each
            refreshed after the commit.
        """
        notification_type = self._normalize_type(batch_in.type)

        notifications = [
            Notification(
                user_id=user_id,
                title=batch_in.title,
                content=batch_in.content,
                type=notification_type,
                related_id=batch_in.related_id,
                related_type=batch_in.related_type,
                sender_id=batch_in.sender_id,
                is_read=False,
            )
            for user_id in batch_in.user_ids
        ]
        db.add_all(notifications)

        await db.commit()

        # Reload every object after the commit (one SELECT per notification)
        # so generated columns are available to the caller.
        for notification in notifications:
            await db.refresh(notification)

        logger.info(
            "批量创建通知成功",
            count=len(notifications),
            user_ids=batch_in.user_ids,
            type=notification_type,
        )

        return notifications

    async def get_user_notifications(
        self,
        db: AsyncSession,
        user_id: int,
        skip: int = 0,
        limit: int = 20,
        is_read: Optional[bool] = None,
        notification_type: Optional[str] = None
    ) -> Tuple[List[NotificationResponse], int, int]:
        """
        Return one page of a user's notifications, newest first.

        Args:
            db: database session
            user_id: owner of the notifications
            skip: number of rows to skip (offset)
            limit: maximum number of rows to return
            is_read: if not None, only notifications with this read state
            notification_type: if truthy, only notifications of this type

        Returns:
            ``(responses, total, unread_count)``. ``total`` applies the same
            ``is_read`` / ``notification_type`` filters as the page.
            ``unread_count`` is over ALL of the user's notifications and
            ignores those filters.
        """
        conditions = [Notification.user_id == user_id]
        if is_read is not None:
            conditions.append(Notification.is_read == is_read)
        if notification_type:
            conditions.append(Notification.type == notification_type)

        stmt = (
            select(Notification)
            .where(and_(*conditions))
            .order_by(desc(Notification.created_at))
            .offset(skip)
            .limit(limit)
        )
        result = await db.execute(stmt)
        notifications = result.scalars().all()

        total = await self._count(db, *conditions)
        unread_count = await self._count(db, *self._unread_conditions(user_id))

        # Sender display names are fetched in one extra query for the whole page.
        sender_names = await self._get_sender_names(db, notifications)

        responses = [self._to_response(n, sender_names) for n in notifications]
        return responses, total, unread_count

    async def get_unread_count(
        self,
        db: AsyncSession,
        user_id: int
    ) -> Tuple[int, int]:
        """
        Count a user's notifications.

        Args:
            db: database session
            user_id: owner of the notifications

        Returns:
            ``(unread_count, total)``
        """
        unread_count = await self._count(db, *self._unread_conditions(user_id))
        total = await self._count(db, Notification.user_id == user_id)
        return unread_count, total

    async def mark_as_read(
        self,
        db: AsyncSession,
        user_id: int,
        notification_ids: Optional[List[int]] = None
    ) -> int:
        """
        Mark a user's unread notifications as read and commit.

        Args:
            db: database session
            user_id: owner of the notifications; other users' rows are never touched
            notification_ids: ids to mark. ``None`` OR an empty list marks
                ALL of the user's unread notifications.

        Returns:
            Number of rows updated (``rowcount`` of the UPDATE).
        """
        conditions = list(self._unread_conditions(user_id))

        # NOTE: an empty list is falsy and therefore means "mark all". Callers
        # forwarding client input should make sure that is what they intend.
        if notification_ids:
            conditions.append(Notification.id.in_(notification_ids))

        stmt = update(Notification).where(and_(*conditions)).values(is_read=True)

        result = await db.execute(stmt)
        await db.commit()

        updated_count = result.rowcount

        logger.info(
            "标记通知已读",
            user_id=user_id,
            notification_ids=notification_ids,
            updated_count=updated_count,
        )

        return updated_count

    async def delete_notification(
        self,
        db: AsyncSession,
        user_id: int,
        notification_id: int
    ) -> bool:
        """
        Delete one notification owned by ``user_id`` and commit.

        Args:
            db: database session
            user_id: owner; a notification belonging to another user is not deleted
            notification_id: id of the notification to delete

        Returns:
            True if a matching notification was deleted, False if none was found.
        """
        stmt = select(Notification).where(
            and_(
                Notification.id == notification_id,
                Notification.user_id == user_id,
            )
        )
        result = await db.execute(stmt)
        notification = result.scalar_one_or_none()

        if notification is None:
            return False

        await db.delete(notification)
        await db.commit()

        logger.info(
            "删除通知成功",
            notification_id=notification_id,
            user_id=user_id,
        )
        return True
|
||
|
||
|
||
# 创建服务实例
|
||
# Module-level service instance, constructed once when this module is imported.
notification_service = NotificationService()
|
||
|