""" 通知推送服务 支持钉钉、企业微信、站内消息等多种渠道 """ import os import json import logging from datetime import datetime, timedelta from typing import Optional, List, Dict, Any import httpx from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, and_ from app.models.user import User from app.models.notification import Notification logger = logging.getLogger(__name__) class NotificationChannel: """通知渠道基类""" async def send( self, user_id: int, title: str, content: str, **kwargs ) -> bool: """ 发送通知 Args: user_id: 用户ID title: 通知标题 content: 通知内容 Returns: 是否发送成功 """ raise NotImplementedError class DingtalkChannel(NotificationChannel): """ 钉钉通知渠道 使用钉钉工作通知 API 发送消息 文档: https://open.dingtalk.com/document/orgapp/asynchronous-sending-of-enterprise-session-messages """ def __init__( self, app_key: Optional[str] = None, app_secret: Optional[str] = None, agent_id: Optional[str] = None, ): self.app_key = app_key or os.getenv("DINGTALK_APP_KEY") self.app_secret = app_secret or os.getenv("DINGTALK_APP_SECRET") self.agent_id = agent_id or os.getenv("DINGTALK_AGENT_ID") self._access_token = None self._token_expires_at = None async def _get_access_token(self) -> str: """获取钉钉访问令牌""" if ( self._access_token and self._token_expires_at and datetime.now() < self._token_expires_at ): return self._access_token url = "https://oapi.dingtalk.com/gettoken" params = { "appkey": self.app_key, "appsecret": self.app_secret, } async with httpx.AsyncClient() as client: response = await client.get(url, params=params, timeout=10.0) result = response.json() if result.get("errcode") == 0: self._access_token = result["access_token"] self._token_expires_at = datetime.now() + timedelta(seconds=7000) return self._access_token else: raise Exception(f"获取钉钉Token失败: {result.get('errmsg')}") async def send( self, user_id: int, title: str, content: str, dingtalk_user_id: Optional[str] = None, **kwargs ) -> bool: """发送钉钉工作通知""" if not all([self.app_key, self.app_secret, self.agent_id]): logger.warning("钉钉配置不完整,跳过发送") return False if not dingtalk_user_id: logger.warning(f"用户 {user_id} 没有绑定钉钉ID") return False try: access_token = await self._get_access_token() url = f"https://oapi.dingtalk.com/topapi/message/corpconversation/asyncsend_v2?access_token={access_token}" # 构建消息体 msg = { "agent_id": self.agent_id, "userid_list": dingtalk_user_id, "msg": { "msgtype": "text", "text": { "content": f"{title}\n\n{content}" } } } async with httpx.AsyncClient() as client: response = await client.post(url, json=msg, timeout=10.0) result = response.json() if result.get("errcode") == 0: logger.info(f"钉钉消息发送成功: user_id={user_id}") return True else: logger.error(f"钉钉消息发送失败: {result.get('errmsg')}") return False except Exception as e: logger.error(f"钉钉消息发送异常: {str(e)}") return False class WeworkChannel(NotificationChannel): """ 企业微信通知渠道 使用企业微信应用消息 API 文档: https://developer.work.weixin.qq.com/document/path/90236 """ def __init__( self, corp_id: Optional[str] = None, corp_secret: Optional[str] = None, agent_id: Optional[str] = None, ): self.corp_id = corp_id or os.getenv("WEWORK_CORP_ID") self.corp_secret = corp_secret or os.getenv("WEWORK_CORP_SECRET") self.agent_id = agent_id or os.getenv("WEWORK_AGENT_ID") self._access_token = None self._token_expires_at = None async def _get_access_token(self) -> str: """获取企业微信访问令牌""" if ( self._access_token and self._token_expires_at and datetime.now() < self._token_expires_at ): return self._access_token url = "https://qyapi.weixin.qq.com/cgi-bin/gettoken" params = { "corpid": self.corp_id, "corpsecret": self.corp_secret, } async with httpx.AsyncClient() as client: response = await client.get(url, params=params, timeout=10.0) result = response.json() if result.get("errcode") == 0: self._access_token = result["access_token"] self._token_expires_at = datetime.now() + timedelta(seconds=7000) return self._access_token else: raise Exception(f"获取企微Token失败: {result.get('errmsg')}") async def send( self, user_id: int, title: str, content: str, wework_user_id: Optional[str] = None, **kwargs ) -> bool: """发送企业微信应用消息""" if not all([self.corp_id, self.corp_secret, self.agent_id]): logger.warning("企业微信配置不完整,跳过发送") return False if not wework_user_id: logger.warning(f"用户 {user_id} 没有绑定企业微信ID") return False try: access_token = await self._get_access_token() url = f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={access_token}" # 构建消息体 msg = { "touser": wework_user_id, "msgtype": "text", "agentid": int(self.agent_id), "text": { "content": f"{title}\n\n{content}" } } async with httpx.AsyncClient() as client: response = await client.post(url, json=msg, timeout=10.0) result = response.json() if result.get("errcode") == 0: logger.info(f"企微消息发送成功: user_id={user_id}") return True else: logger.error(f"企微消息发送失败: {result.get('errmsg')}") return False except Exception as e: logger.error(f"企微消息发送异常: {str(e)}") return False class InAppChannel(NotificationChannel): """站内消息通道""" def __init__(self, db: AsyncSession): self.db = db async def send( self, user_id: int, title: str, content: str, notification_type: str = "system", **kwargs ) -> bool: """创建站内消息""" try: notification = Notification( user_id=user_id, title=title, content=content, type=notification_type, is_read=False, ) self.db.add(notification) await self.db.commit() logger.info(f"站内消息创建成功: user_id={user_id}") return True except Exception as e: logger.error(f"站内消息创建失败: {str(e)}") return False class NotificationService: """ 通知服务 统一管理多渠道通知发送 """ def __init__(self, db: AsyncSession): self.db = db self.channels = { "dingtalk": DingtalkChannel(), "wework": WeworkChannel(), "inapp": InAppChannel(db), } async def send_notification( self, user_id: int, title: str, content: str, channels: Optional[List[str]] = None, **kwargs ) -> Dict[str, bool]: """ 发送通知 Args: user_id: 用户ID title: 通知标题 content: 通知内容 channels: 发送渠道列表,默认全部发送 Returns: 各渠道发送结果 """ # 获取用户信息 user = await self._get_user(user_id) if not user: return {"error": "用户不存在"} # 准备用户渠道标识 user_channels = { "dingtalk_user_id": getattr(user, "dingtalk_id", None), "wework_user_id": getattr(user, "wework_userid", None), } # 确定发送渠道 target_channels = channels or ["inapp"] # 默认只发站内消息 results = {} for channel_name in target_channels: if channel_name in self.channels: channel = self.channels[channel_name] success = await channel.send( user_id=user_id, title=title, content=content, **user_channels, **kwargs ) results[channel_name] = success return results async def send_learning_reminder( self, user_id: int, course_name: str, days_inactive: int = 3, ) -> Dict[str, bool]: """发送学习提醒""" title = "📚 学习提醒" content = f"您已有 {days_inactive} 天没有学习《{course_name}》课程了,快来继续学习吧!" return await self.send_notification( user_id=user_id, title=title, content=content, channels=["inapp", "dingtalk", "wework"], notification_type="learning_reminder", ) async def send_task_deadline_reminder( self, user_id: int, task_name: str, deadline: datetime, ) -> Dict[str, bool]: """发送任务截止提醒""" days_left = (deadline - datetime.now()).days title = "⏰ 任务截止提醒" content = f"任务《{task_name}》将于 {deadline.strftime('%Y-%m-%d %H:%M')} 截止,还有 {days_left} 天,请尽快完成!" return await self.send_notification( user_id=user_id, title=title, content=content, channels=["inapp", "dingtalk", "wework"], notification_type="task_deadline", ) async def send_exam_reminder( self, user_id: int, exam_name: str, exam_time: datetime, ) -> Dict[str, bool]: """发送考试提醒""" title = "📝 考试提醒" content = f"考试《{exam_name}》将于 {exam_time.strftime('%Y-%m-%d %H:%M')} 开始,请提前做好准备!" return await self.send_notification( user_id=user_id, title=title, content=content, channels=["inapp", "dingtalk", "wework"], notification_type="exam_reminder", ) async def send_weekly_report( self, user_id: int, study_time: int, courses_completed: int, exams_passed: int, ) -> Dict[str, bool]: """发送周学习报告""" title = "📊 本周学习报告" content = ( f"本周学习总结:\n" f"• 学习时长:{study_time // 60} 分钟\n" f"• 完成课程:{courses_completed} 门\n" f"• 通过考试:{exams_passed} 次\n\n" f"继续加油!💪" ) return await self.send_notification( user_id=user_id, title=title, content=content, channels=["inapp", "dingtalk", "wework"], notification_type="weekly_report", ) async def _get_user(self, user_id: int) -> Optional[User]: """获取用户信息""" result = await self.db.execute( select(User).where(User.id == user_id) ) return result.scalar_one_or_none() # 便捷函数 def get_notification_service(db: AsyncSession) -> NotificationService: """获取通知服务实例""" return NotificationService(db) class NotificationServiceAdapter: """ 通知服务适配器 提供静态方法接口,兼容 API 层的调用方式 每个方法接受 db 参数,内部创建 NotificationService 实例 """ @staticmethod async def get_user_notifications( db: AsyncSession, user_id: int, skip: int = 0, limit: int = 20, is_read: Optional[bool] = None, notification_type: Optional[str] = None, ) -> tuple[List[Notification], int, int]: """获取用户通知列表""" from sqlalchemy import func # 构建查询条件 conditions = [Notification.user_id == user_id] if is_read is not None: conditions.append(Notification.is_read == is_read) if notification_type: conditions.append(Notification.type == notification_type) # 查询总数 total_stmt = select(func.count(Notification.id)).where(and_(*conditions)) total = await db.scalar(total_stmt) or 0 # 查询未读数 unread_stmt = select(func.count(Notification.id)).where( and_(Notification.user_id == user_id, Notification.is_read == False) ) unread_count = await db.scalar(unread_stmt) or 0 # 查询列表 list_stmt = ( select(Notification) .where(and_(*conditions)) .order_by(Notification.created_at.desc()) .offset(skip) .limit(limit) ) result = await db.execute(list_stmt) notifications = list(result.scalars().all()) return notifications, total, unread_count @staticmethod async def get_unread_count( db: AsyncSession, user_id: int, ) -> tuple[int, int]: """获取未读通知数量""" from sqlalchemy import func # 未读数 unread_stmt = select(func.count(Notification.id)).where( and_(Notification.user_id == user_id, Notification.is_read == False) ) unread_count = await db.scalar(unread_stmt) or 0 # 总数 total_stmt = select(func.count(Notification.id)).where( Notification.user_id == user_id ) total = await db.scalar(total_stmt) or 0 return unread_count, total @staticmethod async def mark_as_read( db: AsyncSession, user_id: int, notification_ids: Optional[List[int]] = None, ) -> int: """标记通知为已读""" from sqlalchemy import update if notification_ids: # 标记指定通知 stmt = ( update(Notification) .where( and_( Notification.user_id == user_id, Notification.id.in_(notification_ids), Notification.is_read == False ) ) .values(is_read=True, updated_at=datetime.now()) ) else: # 标记所有未读 stmt = ( update(Notification) .where( and_( Notification.user_id == user_id, Notification.is_read == False ) ) .values(is_read=True, updated_at=datetime.now()) ) result = await db.execute(stmt) await db.commit() return result.rowcount @staticmethod async def delete_notification( db: AsyncSession, user_id: int, notification_id: int, ) -> bool: """删除通知""" from sqlalchemy import delete stmt = delete(Notification).where( and_( Notification.id == notification_id, Notification.user_id == user_id ) ) result = await db.execute(stmt) await db.commit() return result.rowcount > 0 @staticmethod async def create_notification( db: AsyncSession, notification_in: Any, ) -> Notification: """创建单条通知""" notification = Notification( user_id=notification_in.user_id, title=notification_in.title, content=notification_in.content, type=getattr(notification_in, 'type', 'system'), related_id=getattr(notification_in, 'related_id', None), related_type=getattr(notification_in, 'related_type', None), sender_id=getattr(notification_in, 'sender_id', None), is_read=False, ) db.add(notification) await db.commit() await db.refresh(notification) return notification @staticmethod async def batch_create_notifications( db: AsyncSession, batch_in: Any, ) -> List[Notification]: """批量创建通知""" notifications = [] for user_id in batch_in.user_ids: notification = Notification( user_id=user_id, title=batch_in.title, content=batch_in.content, type=getattr(batch_in, 'type', 'system'), related_id=getattr(batch_in, 'related_id', None), related_type=getattr(batch_in, 'related_type', None), sender_id=getattr(batch_in, 'sender_id', None), is_read=False, ) db.add(notification) notifications.append(notification) await db.commit() for n in notifications: await db.refresh(n) return notifications # 全局单例适配器 notification_service = NotificationServiceAdapter()