All checks were successful
continuous-integration/drone/push Build is passing
添加 NotificationServiceAdapter 适配器类,兼容 API 层调用方式 导出 notification_service 单例实例
603 lines
19 KiB
Python
603 lines
19 KiB
Python
"""
|
|
通知推送服务
|
|
支持钉钉、企业微信、站内消息等多种渠道
|
|
"""
|
|
import os
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, List, Dict, Any
|
|
import httpx
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, and_
|
|
|
|
from app.models.user import User
|
|
from app.models.notification import Notification
|
|
|
|
# Module-level logger used by every channel and service class in this file.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class NotificationChannel:
    """Base class for a notification delivery channel.

    Concrete channels override :meth:`send`; this base version only raises.
    """

    async def send(
        self,
        user_id: int,
        title: str,
        content: str,
        **options
    ) -> bool:
        """Deliver one notification to a user.

        Args:
            user_id: ID of the recipient user.
            title: Notification title.
            content: Notification body.
            **options: Channel-specific extra arguments.

        Returns:
            True when delivery succeeded, False otherwise.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
|
class DingtalkChannel(NotificationChannel):
    """DingTalk notification channel.

    Sends plain-text "work notifications" through the DingTalk
    ``asyncsend_v2`` API.
    Docs: https://open.dingtalk.com/document/orgapp/asynchronous-sending-of-enterprise-session-messages

    Any credential not passed explicitly is read from the
    ``DINGTALK_APP_KEY`` / ``DINGTALK_APP_SECRET`` / ``DINGTALK_AGENT_ID``
    environment variables.
    """

    # Access tokens shared by all instances: (app_key, app_secret) ->
    # (token, expires_at). NotificationService creates fresh channel objects
    # for every service instance, so a per-instance cache alone would request
    # a new token on almost every send.
    _token_cache: Dict[tuple, tuple] = {}

    # Seconds subtracted from the server-reported lifetime so the token is
    # refreshed before it really expires (7200 - 200 = 7000 by default).
    _TOKEN_EXPIRY_MARGIN = 200

    # errcodes treated as "token rejected" (invalid / expired).
    # NOTE(review): assumed to match the legacy oapi error table - confirm.
    _TOKEN_ERRCODES = (40014, 42001)

    def __init__(
        self,
        app_key: Optional[str] = None,
        app_secret: Optional[str] = None,
        agent_id: Optional[str] = None,
    ):
        self.app_key = app_key or os.getenv("DINGTALK_APP_KEY")
        self.app_secret = app_secret or os.getenv("DINGTALK_APP_SECRET")
        self.agent_id = agent_id or os.getenv("DINGTALK_AGENT_ID")
        # Mirror of the shared cache entry for this instance's credentials.
        self._access_token = None
        self._token_expires_at = None

    def _token_cache_key(self) -> tuple:
        """Return the key identifying this app's token in the shared cache."""
        return (self.app_key, self.app_secret)

    def _invalidate_token(self) -> None:
        """Drop the cached token so the next call fetches a fresh one."""
        self._access_token = None
        self._token_expires_at = None
        self._token_cache.pop(self._token_cache_key(), None)

    async def _get_access_token(self) -> str:
        """Return a valid DingTalk access token, fetching one when needed.

        Raises:
            httpx.HTTPError: on request failure or a non-2xx HTTP status.
            Exception: when DingTalk answers with a non-zero ``errcode``.
        """
        cached = self._token_cache.get(self._token_cache_key())
        if cached and datetime.now() < cached[1]:
            self._access_token, self._token_expires_at = cached
            return self._access_token

        url = "https://oapi.dingtalk.com/gettoken"
        params = {
            "appkey": self.app_key,
            "appsecret": self.app_secret,
        }
        async with httpx.AsyncClient() as client:
            response = await client.get(url, params=params, timeout=10.0)
            response.raise_for_status()
            result = response.json()

        if result.get("errcode") != 0:
            raise Exception(f"获取钉钉Token失败: {result.get('errmsg')}")

        lifetime = int(result.get("expires_in", 7200)) - self._TOKEN_EXPIRY_MARGIN
        self._access_token = result["access_token"]
        self._token_expires_at = datetime.now() + timedelta(seconds=max(lifetime, 0))
        self._token_cache[self._token_cache_key()] = (
            self._access_token,
            self._token_expires_at,
        )
        return self._access_token

    async def send(
        self,
        user_id: int,
        title: str,
        content: str,
        dingtalk_user_id: Optional[str] = None,
        **kwargs
    ) -> bool:
        """Send a DingTalk work notification (best effort).

        Args:
            user_id: Internal user ID (used for logging).
            title: Notification title (first line of the text message).
            content: Notification body.
            dingtalk_user_id: The recipient's DingTalk user ID.
            **kwargs: Ignored; accepted so a shared argument set can be
                passed to every channel.

        Returns:
            True when DingTalk accepted the message (``errcode == 0``).
            Returns False on missing config, missing recipient, an API error
            or any exception. Exceptions are logged, never raised.
        """
        if not all([self.app_key, self.app_secret, self.agent_id]):
            logger.warning("钉钉配置不完整,跳过发送")
            return False

        if not dingtalk_user_id:
            logger.warning(f"用户 {user_id} 没有绑定钉钉ID")
            return False

        try:
            access_token = await self._get_access_token()

            url = "https://oapi.dingtalk.com/topapi/message/corpconversation/asyncsend_v2"
            msg = {
                "agent_id": self.agent_id,
                "userid_list": dingtalk_user_id,
                "msg": {
                    "msgtype": "text",
                    "text": {
                        "content": f"{title}\n\n{content}"
                    }
                }
            }

            async with httpx.AsyncClient() as client:
                # Passing the token via params lets httpx URL-encode it.
                response = await client.post(
                    url,
                    params={"access_token": access_token},
                    json=msg,
                    timeout=10.0,
                )
                response.raise_for_status()
                result = response.json()

            if result.get("errcode") == 0:
                logger.info(f"钉钉消息发送成功: user_id={user_id}")
                return True

            # A rejected token would otherwise stay in the shared cache until
            # it expires, failing every send in the meantime.
            if result.get("errcode") in self._TOKEN_ERRCODES:
                self._invalidate_token()
            logger.error(f"钉钉消息发送失败: {result.get('errmsg')}")
            return False
        except Exception as e:
            # Deliberate best-effort: log with traceback and report failure.
            logger.exception(f"钉钉消息发送异常: {e}")
            return False
|
|
|
|
|
|
class WeworkChannel(NotificationChannel):
    """WeCom (WeChat Work) notification channel.

    Sends plain-text application messages through the WeCom ``message/send``
    API.
    Docs: https://developer.work.weixin.qq.com/document/path/90236

    Any credential not passed explicitly is read from the
    ``WEWORK_CORP_ID`` / ``WEWORK_CORP_SECRET`` / ``WEWORK_AGENT_ID``
    environment variables.
    """

    # Access tokens shared by all instances: (corp_id, corp_secret) ->
    # (token, expires_at). NotificationService creates fresh channel objects
    # for every service instance, and WeCom rate-limits frequent ``gettoken``
    # calls, so tokens must outlive a single channel object.
    _token_cache: Dict[tuple, tuple] = {}

    # Seconds subtracted from the server-reported lifetime so the token is
    # refreshed before it really expires (7200 - 200 = 7000 by default).
    _TOKEN_EXPIRY_MARGIN = 200

    # WeCom errcodes for an invalid (40014) or expired (42001) access token.
    _TOKEN_ERRCODES = (40014, 42001)

    # Response fields listing recipients the message did not reach. WeCom
    # still answers errcode 0 in that case.
    _UNDELIVERED_FIELDS = ("invaliduser", "unlicenseduser")

    def __init__(
        self,
        corp_id: Optional[str] = None,
        corp_secret: Optional[str] = None,
        agent_id: Optional[str] = None,
    ):
        self.corp_id = corp_id or os.getenv("WEWORK_CORP_ID")
        self.corp_secret = corp_secret or os.getenv("WEWORK_CORP_SECRET")
        self.agent_id = agent_id or os.getenv("WEWORK_AGENT_ID")
        # Mirror of the shared cache entry for this instance's credentials.
        self._access_token = None
        self._token_expires_at = None

    def _token_cache_key(self) -> tuple:
        """Return the key identifying this app's token in the shared cache."""
        return (self.corp_id, self.corp_secret)

    def _invalidate_token(self) -> None:
        """Drop the cached token so the next call fetches a fresh one."""
        self._access_token = None
        self._token_expires_at = None
        self._token_cache.pop(self._token_cache_key(), None)

    @classmethod
    def _undelivered_recipients(cls, result: Dict[str, Any]) -> Dict[str, Any]:
        """Return the non-empty "undelivered recipient" fields of a send response."""
        return {
            field: result[field]
            for field in cls._UNDELIVERED_FIELDS
            if result.get(field)
        }

    async def _get_access_token(self) -> str:
        """Return a valid WeCom access token, fetching one when needed.

        Raises:
            httpx.HTTPError: on request failure or a non-2xx HTTP status.
            Exception: when WeCom answers with a non-zero ``errcode``.
        """
        cached = self._token_cache.get(self._token_cache_key())
        if cached and datetime.now() < cached[1]:
            self._access_token, self._token_expires_at = cached
            return self._access_token

        url = "https://qyapi.weixin.qq.com/cgi-bin/gettoken"
        params = {
            "corpid": self.corp_id,
            "corpsecret": self.corp_secret,
        }
        async with httpx.AsyncClient() as client:
            response = await client.get(url, params=params, timeout=10.0)
            response.raise_for_status()
            result = response.json()

        if result.get("errcode") != 0:
            raise Exception(f"获取企微Token失败: {result.get('errmsg')}")

        lifetime = int(result.get("expires_in", 7200)) - self._TOKEN_EXPIRY_MARGIN
        self._access_token = result["access_token"]
        self._token_expires_at = datetime.now() + timedelta(seconds=max(lifetime, 0))
        self._token_cache[self._token_cache_key()] = (
            self._access_token,
            self._token_expires_at,
        )
        return self._access_token

    async def send(
        self,
        user_id: int,
        title: str,
        content: str,
        wework_user_id: Optional[str] = None,
        **kwargs
    ) -> bool:
        """Send a WeCom application text message (best effort).

        Args:
            user_id: Internal user ID (used for logging).
            title: Notification title (first line of the text message).
            content: Notification body.
            wework_user_id: The recipient's WeCom user ID.
            **kwargs: Ignored; accepted so a shared argument set can be
                passed to every channel.

        Returns:
            True only when WeCom accepted the message and did not report the
            recipient as invalid or unlicensed. Returns False on missing
            config, missing recipient, an API error or any exception.
            Exceptions are logged, never raised.
        """
        if not all([self.corp_id, self.corp_secret, self.agent_id]):
            logger.warning("企业微信配置不完整,跳过发送")
            return False

        if not wework_user_id:
            logger.warning(f"用户 {user_id} 没有绑定企业微信ID")
            return False

        try:
            access_token = await self._get_access_token()

            url = "https://qyapi.weixin.qq.com/cgi-bin/message/send"
            msg = {
                "touser": wework_user_id,
                "msgtype": "text",
                # The API expects a numeric agent id; the env var is a string.
                "agentid": int(self.agent_id),
                "text": {
                    "content": f"{title}\n\n{content}"
                }
            }

            async with httpx.AsyncClient() as client:
                # Passing the token via params lets httpx URL-encode it.
                response = await client.post(
                    url,
                    params={"access_token": access_token},
                    json=msg,
                    timeout=10.0,
                )
                response.raise_for_status()
                result = response.json()

            if result.get("errcode") != 0:
                # A rejected token would otherwise stay in the shared cache
                # until it expires, failing every send in the meantime.
                if result.get("errcode") in self._TOKEN_ERRCODES:
                    self._invalidate_token()
                logger.error(f"企微消息发送失败: {result.get('errmsg')}")
                return False

            # errcode 0 does not guarantee delivery: an unknown or unlicensed
            # recipient is reported in these fields instead.
            undelivered = self._undelivered_recipients(result)
            if undelivered:
                logger.error(f"企微消息未送达: user_id={user_id}, {undelivered}")
                return False

            logger.info(f"企微消息发送成功: user_id={user_id}")
            return True
        except Exception as e:
            # Deliberate best-effort: log with traceback and report failure.
            logger.exception(f"企微消息发送异常: {e}")
            return False
|
|
|
|
|
|
class InAppChannel(NotificationChannel):
    """In-app message channel: stores the notification as a ``Notification`` row."""

    def __init__(self, db: AsyncSession):
        # Session used to persist notifications; send() commits on it.
        self.db = db

    async def send(
        self,
        user_id: int,
        title: str,
        content: str,
        notification_type: str = "system",
        **kwargs
    ) -> bool:
        """Create and commit an unread in-app notification (best effort).

        Args:
            user_id: Recipient user ID.
            title: Notification title.
            content: Notification body.
            notification_type: Value stored in ``Notification.type``.
            **kwargs: Ignored; accepted so a shared argument set can be
                passed to every channel.

        Returns:
            True when the row was committed, False on any error. Errors are
            logged, never raised.
        """
        try:
            notification = Notification(
                user_id=user_id,
                title=title,
                content=content,
                type=notification_type,
                is_read=False,
            )
            self.db.add(notification)
            await self.db.commit()
        except Exception as e:
            logger.exception(f"站内消息创建失败: {e}")
            # Without a rollback the (shared) session stays in a failed
            # state, and every later query on it raises.
            try:
                await self.db.rollback()
            except Exception:
                logger.exception("站内消息回滚失败")
            return False

        logger.info(f"站内消息创建成功: user_id={user_id}")
        return True
|
|
|
|
|
|
class NotificationService:
    """Notification service.

    Dispatches a notification to one or more channels (in-app, DingTalk,
    WeCom) and provides ready-made reminder and report messages.
    """

    # Channels targeted by the built-in reminder / report helpers.
    REMINDER_CHANNELS = ("inapp", "dingtalk", "wework")

    def __init__(self, db: AsyncSession):
        self.db = db
        self.channels = {
            "dingtalk": DingtalkChannel(),
            "wework": WeworkChannel(),
            "inapp": InAppChannel(db),
        }

    async def send_notification(
        self,
        user_id: int,
        title: str,
        content: str,
        channels: Optional[List[str]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Send a notification over the requested channels.

        Args:
            user_id: Recipient user ID.
            title: Notification title.
            content: Notification body.
            channels: Channel names to use. ``None`` or an empty list sends
                only the in-app message. Unknown names are skipped with a
                warning.
            **kwargs: Extra arguments forwarded to every channel's ``send``
                (e.g. ``notification_type``). ``dingtalk_user_id`` /
                ``wework_user_id`` given here override the values read from
                the user record.

        Returns:
            A mapping of channel name to send success (bool), or
            ``{"error": "用户不存在"}`` when the user does not exist.
        """
        user = await self._get_user(user_id)
        if not user:
            return {"error": "用户不存在"}

        # Per-channel recipient IDs from the user record. The caller's kwargs
        # are merged last so they win. A single merged dict also avoids the
        # TypeError that ``**a, **b`` raises when both contain the same key.
        send_kwargs = {
            "dingtalk_user_id": getattr(user, "dingtalk_id", None),
            "wework_user_id": getattr(user, "wework_userid", None),
            **kwargs,
        }

        results: Dict[str, Any] = {}
        for channel_name in channels or ["inapp"]:  # default: in-app only
            channel = self.channels.get(channel_name)
            if channel is None:
                logger.warning(f"未知的通知渠道: {channel_name}")
                continue
            results[channel_name] = await channel.send(
                user_id=user_id,
                title=title,
                content=content,
                **send_kwargs,
            )

        return results

    async def _send_to_reminder_channels(
        self,
        user_id: int,
        title: str,
        content: str,
        notification_type: str,
    ) -> Dict[str, Any]:
        """Send ``title``/``content`` over all REMINDER_CHANNELS with the given type."""
        return await self.send_notification(
            user_id=user_id,
            title=title,
            content=content,
            channels=list(self.REMINDER_CHANNELS),
            notification_type=notification_type,
        )

    async def send_learning_reminder(
        self,
        user_id: int,
        course_name: str,
        days_inactive: int = 3,
    ) -> Dict[str, Any]:
        """Remind a user who has not studied a course for ``days_inactive`` days."""
        title = "📚 学习提醒"
        content = f"您已有 {days_inactive} 天没有学习《{course_name}》课程了,快来继续学习吧!"
        return await self._send_to_reminder_channels(
            user_id, title, content, "learning_reminder"
        )

    async def send_task_deadline_reminder(
        self,
        user_id: int,
        task_name: str,
        deadline: datetime,
    ) -> Dict[str, Any]:
        """Remind a user that a task is due at ``deadline``.

        ``deadline`` may be naive or timezone-aware. The current time is taken
        in the same timezone, so the subtraction is always valid.
        """
        # datetime.now(None) is naive local time, so naive deadlines behave as
        # before, while aware deadlines no longer raise TypeError.
        now = datetime.now(deadline.tzinfo)
        # Clamp so an overdue task does not report a negative day count.
        days_left = max((deadline - now).days, 0)
        title = "⏰ 任务截止提醒"
        content = f"任务《{task_name}》将于 {deadline.strftime('%Y-%m-%d %H:%M')} 截止,还有 {days_left} 天,请尽快完成!"
        return await self._send_to_reminder_channels(
            user_id, title, content, "task_deadline"
        )

    async def send_exam_reminder(
        self,
        user_id: int,
        exam_name: str,
        exam_time: datetime,
    ) -> Dict[str, Any]:
        """Remind a user of an upcoming exam starting at ``exam_time``."""
        title = "📝 考试提醒"
        content = f"考试《{exam_name}》将于 {exam_time.strftime('%Y-%m-%d %H:%M')} 开始,请提前做好准备!"
        return await self._send_to_reminder_channels(
            user_id, title, content, "exam_reminder"
        )

    async def send_weekly_report(
        self,
        user_id: int,
        study_time: int,
        courses_completed: int,
        exams_passed: int,
    ) -> Dict[str, Any]:
        """Send a weekly learning summary.

        ``study_time`` is divided by 60 and shown in minutes; presumably it
        is in seconds - TODO confirm with callers.
        """
        title = "📊 本周学习报告"
        content = (
            f"本周学习总结:\n"
            f"• 学习时长:{study_time // 60} 分钟\n"
            f"• 完成课程:{courses_completed} 门\n"
            f"• 通过考试:{exams_passed} 次\n\n"
            f"继续加油!💪"
        )
        return await self._send_to_reminder_channels(
            user_id, title, content, "weekly_report"
        )

    async def _get_user(self, user_id: int) -> Optional[User]:
        """Load the user with ``user_id``, or None if there is no such user."""
        result = await self.db.execute(
            select(User).where(User.id == user_id)
        )
        return result.scalar_one_or_none()
|
|
|
|
|
|
# Convenience factory

def get_notification_service(db: AsyncSession) -> NotificationService:
    """Build a NotificationService bound to ``db``.

    Every call constructs a new service (and new channel objects).
    """
    service = NotificationService(db)
    return service
|
|
|
|
|
|
class NotificationServiceAdapter:
    """Notification service adapter.

    Exposes static coroutine methods in the calling style used by the API
    layer. Every method takes the session as ``db`` and queries the
    ``Notification`` model directly.
    """

    @staticmethod
    async def _count_unread(db: AsyncSession, user_id: int) -> int:
        """Return the number of unread notifications belonging to ``user_id``."""
        from sqlalchemy import func

        stmt = select(func.count(Notification.id)).where(
            and_(
                Notification.user_id == user_id,
                Notification.is_read == False,  # noqa: E712 - SQL expression
            )
        )
        return await db.scalar(stmt) or 0

    @staticmethod
    async def _commit(db: AsyncSession) -> None:
        """Commit ``db``; on failure roll back first so the session stays usable, then re-raise."""
        try:
            await db.commit()
        except Exception:
            await db.rollback()
            raise

    @staticmethod
    async def get_user_notifications(
        db: AsyncSession,
        user_id: int,
        skip: int = 0,
        limit: int = 20,
        is_read: Optional[bool] = None,
        notification_type: Optional[str] = None,
    ) -> tuple[List[Notification], int, int]:
        """Return one page of a user's notifications, newest first.

        Args:
            db: Database session.
            user_id: Owner of the notifications.
            skip: Number of rows to skip (negative values are treated as 0).
            limit: Maximum rows to return (negative values are treated as 0).
            is_read: If not None, only rows with this read state.
            notification_type: If given, only rows of this type.

        Returns:
            ``(notifications, total, unread_count)``. ``total`` counts rows
            matching the filters; ``unread_count`` counts all of the user's
            unread notifications, ignoring the filters.
        """
        from sqlalchemy import func

        conditions = [Notification.user_id == user_id]
        if is_read is not None:
            conditions.append(Notification.is_read == is_read)
        if notification_type:
            conditions.append(Notification.type == notification_type)

        total_stmt = select(func.count(Notification.id)).where(and_(*conditions))
        total = await db.scalar(total_stmt) or 0

        unread_count = await NotificationServiceAdapter._count_unread(db, user_id)

        list_stmt = (
            select(Notification)
            .where(and_(*conditions))
            # id breaks ties between equal created_at values so consecutive
            # pages neither repeat nor skip rows.
            .order_by(Notification.created_at.desc(), Notification.id.desc())
            .offset(max(skip, 0))
            .limit(max(limit, 0))
        )
        result = await db.execute(list_stmt)
        notifications = list(result.scalars().all())

        return notifications, total, unread_count

    @staticmethod
    async def get_unread_count(
        db: AsyncSession,
        user_id: int,
    ) -> tuple[int, int]:
        """Return ``(unread_count, total)`` for the user's notifications."""
        from sqlalchemy import func

        unread_count = await NotificationServiceAdapter._count_unread(db, user_id)

        total_stmt = select(func.count(Notification.id)).where(
            Notification.user_id == user_id
        )
        total = await db.scalar(total_stmt) or 0

        return unread_count, total

    @staticmethod
    async def mark_as_read(
        db: AsyncSession,
        user_id: int,
        notification_ids: Optional[List[int]] = None,
    ) -> int:
        """Mark the user's unread notifications as read and commit.

        Args:
            db: Database session.
            user_id: Owner of the notifications.
            notification_ids: IDs to mark. A falsy value (None *or* an empty
                list) marks every unread notification of the user.
                NOTE(review): an empty list therefore means "all", not
                "none" - confirm callers expect this.

        Returns:
            Number of rows updated.
        """
        from sqlalchemy import update

        conditions = [
            Notification.user_id == user_id,
            Notification.is_read == False,  # noqa: E712 - SQL expression
        ]
        if notification_ids:
            conditions.append(Notification.id.in_(notification_ids))

        stmt = (
            update(Notification)
            .where(and_(*conditions))
            .values(is_read=True, updated_at=datetime.now())
        )
        result = await db.execute(stmt)
        await NotificationServiceAdapter._commit(db)
        return result.rowcount

    @staticmethod
    async def delete_notification(
        db: AsyncSession,
        user_id: int,
        notification_id: int,
    ) -> bool:
        """Delete one of the user's notifications; return True if a row was deleted."""
        from sqlalchemy import delete

        # Filtering on user_id too prevents deleting another user's rows.
        stmt = delete(Notification).where(
            and_(
                Notification.id == notification_id,
                Notification.user_id == user_id,
            )
        )
        result = await db.execute(stmt)
        await NotificationServiceAdapter._commit(db)
        return result.rowcount > 0

    @staticmethod
    async def create_notification(
        db: AsyncSession,
        notification_in: Any,
    ) -> Notification:
        """Create, commit and return one unread notification.

        ``notification_in`` must provide ``user_id``, ``title`` and
        ``content``. ``type`` (default ``'system'``), ``related_id``,
        ``related_type`` and ``sender_id`` are optional attributes.
        """
        notification = Notification(
            user_id=notification_in.user_id,
            title=notification_in.title,
            content=notification_in.content,
            type=getattr(notification_in, 'type', 'system'),
            related_id=getattr(notification_in, 'related_id', None),
            related_type=getattr(notification_in, 'related_type', None),
            sender_id=getattr(notification_in, 'sender_id', None),
            is_read=False,
        )
        db.add(notification)
        await NotificationServiceAdapter._commit(db)
        await db.refresh(notification)
        return notification

    @staticmethod
    async def batch_create_notifications(
        db: AsyncSession,
        batch_in: Any,
    ) -> List[Notification]:
        """Create the same unread notification for every ID in ``batch_in.user_ids``.

        All rows are committed in a single transaction. The optional fields
        are read from ``batch_in`` as in :meth:`create_notification`.
        """
        # Column values shared by every row of the batch.
        shared_fields = {
            "title": batch_in.title,
            "content": batch_in.content,
            "type": getattr(batch_in, 'type', 'system'),
            "related_id": getattr(batch_in, 'related_id', None),
            "related_type": getattr(batch_in, 'related_type', None),
            "sender_id": getattr(batch_in, 'sender_id', None),
            "is_read": False,
        }
        notifications = [
            Notification(user_id=user_id, **shared_fields)
            for user_id in batch_in.user_ids
        ]
        db.add_all(notifications)
        await NotificationServiceAdapter._commit(db)

        # Reload each row after commit (presumably to load generated
        # columns / expired attributes - TODO confirm expire_on_commit).
        for n in notifications:
            await db.refresh(n)

        return notifications
|
|
|
|
|
|
# Module-wide adapter instance exported for the API layer. The adapter only
# has static methods, so one shared instance is enough.
notification_service: NotificationServiceAdapter = NotificationServiceAdapter()
|