feat: 新增告警、成本、配额、微信模块及缓存服务
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
- 新增告警模块 (alerts): 告警规则配置与触发 - 新增成本管理模块 (cost): 成本统计与分析 - 新增配额模块 (quota): 配额管理与限制 - 新增微信模块 (wechat): 微信相关功能接口 - 新增缓存服务 (cache): Redis 缓存封装 - 新增请求日志中间件 (request_logger) - 新增异常处理和链路追踪中间件 - 更新 dashboard 前端展示 - 更新 SDK stats_client 功能
This commit is contained in:
@@ -1,4 +1,22 @@
|
||||
"""业务服务"""
|
||||
from .crypto import encrypt_value, decrypt_value
|
||||
from .cache import CacheService, get_cache, get_redis_client
|
||||
from .wechat import WechatService, get_wechat_service_by_id
|
||||
from .alert import AlertService
|
||||
from .cost import CostCalculator, calculate_cost
|
||||
from .quota import QuotaService, check_quota_middleware
|
||||
|
||||
__all__ = ["encrypt_value", "decrypt_value"]
|
||||
__all__ = [
|
||||
"encrypt_value",
|
||||
"decrypt_value",
|
||||
"CacheService",
|
||||
"get_cache",
|
||||
"get_redis_client",
|
||||
"WechatService",
|
||||
"get_wechat_service_by_id",
|
||||
"AlertService",
|
||||
"CostCalculator",
|
||||
"calculate_cost",
|
||||
"QuotaService",
|
||||
"check_quota_middleware"
|
||||
]
|
||||
|
||||
455
backend/app/services/alert.py
Normal file
455
backend/app/services/alert.py
Normal file
@@ -0,0 +1,455 @@
|
||||
"""告警服务"""
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from ..models.alert import AlertRule, AlertRecord, NotificationChannel
|
||||
from ..models.stats import AICallEvent
|
||||
from ..models.logs import PlatformLog
|
||||
from .cache import get_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AlertService:
|
||||
"""告警服务
|
||||
|
||||
提供告警规则检测、告警记录管理、通知发送等功能
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self._cache = get_cache()
|
||||
|
||||
async def check_all_rules(self) -> List[AlertRecord]:
|
||||
"""检查所有启用的告警规则
|
||||
|
||||
Returns:
|
||||
触发的告警记录列表
|
||||
"""
|
||||
rules = self.db.query(AlertRule).filter(AlertRule.status == 1).all()
|
||||
triggered_alerts = []
|
||||
|
||||
for rule in rules:
|
||||
try:
|
||||
alert = await self.check_rule(rule)
|
||||
if alert:
|
||||
triggered_alerts.append(alert)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check rule {rule.id}: {e}")
|
||||
|
||||
return triggered_alerts
|
||||
|
||||
async def check_rule(self, rule: AlertRule) -> Optional[AlertRecord]:
|
||||
"""检查单个告警规则
|
||||
|
||||
Args:
|
||||
rule: 告警规则
|
||||
|
||||
Returns:
|
||||
触发的告警记录或None
|
||||
"""
|
||||
# 检查冷却期
|
||||
if self._is_in_cooldown(rule):
|
||||
logger.debug(f"Rule {rule.id} is in cooldown")
|
||||
return None
|
||||
|
||||
# 检查每日告警次数限制
|
||||
if self._exceeds_daily_limit(rule):
|
||||
logger.debug(f"Rule {rule.id} exceeds daily limit")
|
||||
return None
|
||||
|
||||
# 根据规则类型检查
|
||||
metric_value = None
|
||||
threshold_value = None
|
||||
triggered = False
|
||||
|
||||
condition = rule.condition or {}
|
||||
|
||||
if rule.rule_type == 'error_rate':
|
||||
triggered, metric_value, threshold_value = self._check_error_rate(rule, condition)
|
||||
elif rule.rule_type == 'call_count':
|
||||
triggered, metric_value, threshold_value = self._check_call_count(rule, condition)
|
||||
elif rule.rule_type == 'token_usage':
|
||||
triggered, metric_value, threshold_value = self._check_token_usage(rule, condition)
|
||||
elif rule.rule_type == 'cost_threshold':
|
||||
triggered, metric_value, threshold_value = self._check_cost_threshold(rule, condition)
|
||||
elif rule.rule_type == 'latency':
|
||||
triggered, metric_value, threshold_value = self._check_latency(rule, condition)
|
||||
|
||||
if triggered:
|
||||
alert = self._create_alert_record(rule, metric_value, threshold_value)
|
||||
return alert
|
||||
|
||||
return None
|
||||
|
||||
def _is_in_cooldown(self, rule: AlertRule) -> bool:
|
||||
"""检查规则是否在冷却期"""
|
||||
cache_key = f"alert:cooldown:{rule.id}"
|
||||
return self._cache.exists(cache_key)
|
||||
|
||||
def _set_cooldown(self, rule: AlertRule):
|
||||
"""设置规则冷却期"""
|
||||
cache_key = f"alert:cooldown:{rule.id}"
|
||||
self._cache.set(cache_key, "1", ttl=rule.cooldown_minutes * 60)
|
||||
|
||||
def _exceeds_daily_limit(self, rule: AlertRule) -> bool:
|
||||
"""检查是否超过每日告警次数限制"""
|
||||
today = datetime.now().date()
|
||||
count = self.db.query(func.count(AlertRecord.id)).filter(
|
||||
AlertRecord.rule_id == rule.id,
|
||||
func.date(AlertRecord.created_at) == today
|
||||
).scalar()
|
||||
return count >= rule.max_alerts_per_day
|
||||
|
||||
def _check_error_rate(self, rule: AlertRule, condition: dict) -> tuple:
|
||||
"""检查错误率"""
|
||||
window_minutes = self._parse_window(condition.get('window', '5m'))
|
||||
threshold = condition.get('threshold', 10) # 错误次数阈值
|
||||
operator = condition.get('operator', '>')
|
||||
|
||||
since = datetime.now() - timedelta(minutes=window_minutes)
|
||||
|
||||
query = self.db.query(func.count(AICallEvent.id)).filter(
|
||||
AICallEvent.created_at >= since,
|
||||
AICallEvent.status == 'error'
|
||||
)
|
||||
|
||||
# 应用作用范围
|
||||
if rule.scope_type == 'tenant' and rule.scope_value:
|
||||
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
|
||||
elif rule.scope_type == 'app' and rule.scope_value:
|
||||
query = query.filter(AICallEvent.app_code == rule.scope_value)
|
||||
|
||||
error_count = query.scalar() or 0
|
||||
triggered = self._compare(error_count, threshold, operator)
|
||||
|
||||
return triggered, str(error_count), str(threshold)
|
||||
|
||||
def _check_call_count(self, rule: AlertRule, condition: dict) -> tuple:
|
||||
"""检查调用次数"""
|
||||
window_minutes = self._parse_window(condition.get('window', '1h'))
|
||||
threshold = condition.get('threshold', 1000)
|
||||
operator = condition.get('operator', '>')
|
||||
|
||||
since = datetime.now() - timedelta(minutes=window_minutes)
|
||||
|
||||
query = self.db.query(func.count(AICallEvent.id)).filter(
|
||||
AICallEvent.created_at >= since
|
||||
)
|
||||
|
||||
if rule.scope_type == 'tenant' and rule.scope_value:
|
||||
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
|
||||
elif rule.scope_type == 'app' and rule.scope_value:
|
||||
query = query.filter(AICallEvent.app_code == rule.scope_value)
|
||||
|
||||
call_count = query.scalar() or 0
|
||||
triggered = self._compare(call_count, threshold, operator)
|
||||
|
||||
return triggered, str(call_count), str(threshold)
|
||||
|
||||
def _check_token_usage(self, rule: AlertRule, condition: dict) -> tuple:
|
||||
"""检查Token使用量"""
|
||||
window_minutes = self._parse_window(condition.get('window', '1d'))
|
||||
threshold = condition.get('threshold', 100000)
|
||||
operator = condition.get('operator', '>')
|
||||
|
||||
since = datetime.now() - timedelta(minutes=window_minutes)
|
||||
|
||||
query = self.db.query(
|
||||
func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0)
|
||||
).filter(
|
||||
AICallEvent.created_at >= since
|
||||
)
|
||||
|
||||
if rule.scope_type == 'tenant' and rule.scope_value:
|
||||
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
|
||||
elif rule.scope_type == 'app' and rule.scope_value:
|
||||
query = query.filter(AICallEvent.app_code == rule.scope_value)
|
||||
|
||||
token_usage = query.scalar() or 0
|
||||
triggered = self._compare(token_usage, threshold, operator)
|
||||
|
||||
return triggered, str(token_usage), str(threshold)
|
||||
|
||||
def _check_cost_threshold(self, rule: AlertRule, condition: dict) -> tuple:
|
||||
"""检查费用阈值"""
|
||||
window_minutes = self._parse_window(condition.get('window', '1d'))
|
||||
threshold = condition.get('threshold', 100) # 费用阈值(元)
|
||||
operator = condition.get('operator', '>')
|
||||
|
||||
since = datetime.now() - timedelta(minutes=window_minutes)
|
||||
|
||||
query = self.db.query(
|
||||
func.coalesce(func.sum(AICallEvent.cost), 0)
|
||||
).filter(
|
||||
AICallEvent.created_at >= since
|
||||
)
|
||||
|
||||
if rule.scope_type == 'tenant' and rule.scope_value:
|
||||
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
|
||||
elif rule.scope_type == 'app' and rule.scope_value:
|
||||
query = query.filter(AICallEvent.app_code == rule.scope_value)
|
||||
|
||||
total_cost = float(query.scalar() or 0)
|
||||
triggered = self._compare(total_cost, threshold, operator)
|
||||
|
||||
return triggered, f"¥{total_cost:.2f}", f"¥{threshold:.2f}"
|
||||
|
||||
def _check_latency(self, rule: AlertRule, condition: dict) -> tuple:
|
||||
"""检查延迟"""
|
||||
window_minutes = self._parse_window(condition.get('window', '5m'))
|
||||
threshold = condition.get('threshold', 5000) # 延迟阈值(ms)
|
||||
operator = condition.get('operator', '>')
|
||||
percentile = condition.get('percentile', 'avg') # avg, p95, p99, max
|
||||
|
||||
since = datetime.now() - timedelta(minutes=window_minutes)
|
||||
|
||||
query = self.db.query(AICallEvent.latency_ms).filter(
|
||||
AICallEvent.created_at >= since,
|
||||
AICallEvent.latency_ms.isnot(None)
|
||||
)
|
||||
|
||||
if rule.scope_type == 'tenant' and rule.scope_value:
|
||||
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
|
||||
elif rule.scope_type == 'app' and rule.scope_value:
|
||||
query = query.filter(AICallEvent.app_code == rule.scope_value)
|
||||
|
||||
latencies = [r.latency_ms for r in query.all()]
|
||||
|
||||
if not latencies:
|
||||
return False, "0", str(threshold)
|
||||
|
||||
if percentile == 'avg':
|
||||
metric = sum(latencies) / len(latencies)
|
||||
elif percentile == 'max':
|
||||
metric = max(latencies)
|
||||
elif percentile == 'p95':
|
||||
latencies.sort()
|
||||
idx = int(len(latencies) * 0.95)
|
||||
metric = latencies[idx] if idx < len(latencies) else latencies[-1]
|
||||
elif percentile == 'p99':
|
||||
latencies.sort()
|
||||
idx = int(len(latencies) * 0.99)
|
||||
metric = latencies[idx] if idx < len(latencies) else latencies[-1]
|
||||
else:
|
||||
metric = sum(latencies) / len(latencies)
|
||||
|
||||
triggered = self._compare(metric, threshold, operator)
|
||||
|
||||
return triggered, f"{metric:.0f}ms", f"{threshold}ms"
|
||||
|
||||
def _parse_window(self, window: str) -> int:
|
||||
"""解析时间窗口字符串为分钟数"""
|
||||
if window.endswith('m'):
|
||||
return int(window[:-1])
|
||||
elif window.endswith('h'):
|
||||
return int(window[:-1]) * 60
|
||||
elif window.endswith('d'):
|
||||
return int(window[:-1]) * 60 * 24
|
||||
else:
|
||||
return int(window)
|
||||
|
||||
def _compare(self, value: float, threshold: float, operator: str) -> bool:
|
||||
"""比较值与阈值"""
|
||||
if operator == '>':
|
||||
return value > threshold
|
||||
elif operator == '>=':
|
||||
return value >= threshold
|
||||
elif operator == '<':
|
||||
return value < threshold
|
||||
elif operator == '<=':
|
||||
return value <= threshold
|
||||
elif operator == '==':
|
||||
return value == threshold
|
||||
elif operator == '!=':
|
||||
return value != threshold
|
||||
return False
|
||||
|
||||
def _create_alert_record(
|
||||
self,
|
||||
rule: AlertRule,
|
||||
metric_value: str,
|
||||
threshold_value: str
|
||||
) -> AlertRecord:
|
||||
"""创建告警记录"""
|
||||
title = f"[{rule.priority.upper()}] {rule.name}"
|
||||
message = f"规则 '{rule.name}' 触发告警\n当前值: {metric_value}\n阈值: {threshold_value}"
|
||||
|
||||
if rule.scope_type == 'tenant':
|
||||
message += f"\n租户: {rule.scope_value}"
|
||||
elif rule.scope_type == 'app':
|
||||
message += f"\n应用: {rule.scope_value}"
|
||||
|
||||
alert = AlertRecord(
|
||||
rule_id=rule.id,
|
||||
rule_name=rule.name,
|
||||
alert_type=rule.rule_type,
|
||||
severity=self._priority_to_severity(rule.priority),
|
||||
title=title,
|
||||
message=message,
|
||||
tenant_id=rule.scope_value if rule.scope_type == 'tenant' else None,
|
||||
app_code=rule.scope_value if rule.scope_type == 'app' else None,
|
||||
metric_value=metric_value,
|
||||
threshold_value=threshold_value,
|
||||
notification_status='pending'
|
||||
)
|
||||
|
||||
self.db.add(alert)
|
||||
self.db.commit()
|
||||
self.db.refresh(alert)
|
||||
|
||||
# 设置冷却期
|
||||
self._set_cooldown(rule)
|
||||
|
||||
logger.info(f"Alert triggered: {title}")
|
||||
|
||||
return alert
|
||||
|
||||
def _priority_to_severity(self, priority: str) -> str:
|
||||
"""将优先级转换为严重程度"""
|
||||
mapping = {
|
||||
'low': 'info',
|
||||
'medium': 'warning',
|
||||
'high': 'error',
|
||||
'critical': 'critical'
|
||||
}
|
||||
return mapping.get(priority, 'warning')
|
||||
|
||||
async def send_notification(self, alert: AlertRecord, rule: AlertRule) -> bool:
|
||||
"""发送告警通知
|
||||
|
||||
Args:
|
||||
alert: 告警记录
|
||||
rule: 告警规则
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
if not rule.notification_channels:
|
||||
alert.notification_status = 'skipped'
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
results = []
|
||||
success = True
|
||||
|
||||
for channel_config in rule.notification_channels:
|
||||
try:
|
||||
result = await self._send_to_channel(channel_config, alert)
|
||||
results.append(result)
|
||||
if not result.get('success'):
|
||||
success = False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send notification: {e}")
|
||||
results.append({'success': False, 'error': str(e)})
|
||||
success = False
|
||||
|
||||
alert.notification_status = 'sent' if success else 'failed'
|
||||
alert.notification_result = results
|
||||
alert.notified_at = datetime.now()
|
||||
self.db.commit()
|
||||
|
||||
return success
|
||||
|
||||
async def _send_to_channel(self, channel_config: dict, alert: AlertRecord) -> dict:
|
||||
"""发送到指定渠道"""
|
||||
channel_type = channel_config.get('type')
|
||||
|
||||
if channel_type == 'wechat_bot':
|
||||
return await self._send_wechat_bot(channel_config, alert)
|
||||
elif channel_type == 'webhook':
|
||||
return await self._send_webhook(channel_config, alert)
|
||||
else:
|
||||
return {'success': False, 'error': f'Unsupported channel type: {channel_type}'}
|
||||
|
||||
async def _send_wechat_bot(self, config: dict, alert: AlertRecord) -> dict:
|
||||
"""发送到企微机器人"""
|
||||
webhook = config.get('webhook')
|
||||
if not webhook:
|
||||
return {'success': False, 'error': 'Missing webhook URL'}
|
||||
|
||||
# 构建消息
|
||||
content = f"**{alert.title}**\n\n{alert.message}\n\n时间: {alert.created_at}"
|
||||
|
||||
payload = {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {
|
||||
"content": content
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.post(webhook, json=payload)
|
||||
result = response.json()
|
||||
|
||||
if result.get('errcode', 0) == 0:
|
||||
return {'success': True, 'channel': 'wechat_bot'}
|
||||
else:
|
||||
return {'success': False, 'error': result.get('errmsg')}
|
||||
except Exception as e:
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def _send_webhook(self, config: dict, alert: AlertRecord) -> dict:
|
||||
"""发送到Webhook"""
|
||||
url = config.get('url')
|
||||
if not url:
|
||||
return {'success': False, 'error': 'Missing webhook URL'}
|
||||
|
||||
payload = {
|
||||
"alert_id": alert.id,
|
||||
"title": alert.title,
|
||||
"message": alert.message,
|
||||
"severity": alert.severity,
|
||||
"alert_type": alert.alert_type,
|
||||
"metric_value": alert.metric_value,
|
||||
"threshold_value": alert.threshold_value,
|
||||
"created_at": alert.created_at.isoformat()
|
||||
}
|
||||
|
||||
headers = config.get('headers', {})
|
||||
method = config.get('method', 'POST')
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
if method.upper() == 'POST':
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
else:
|
||||
response = await client.get(url, params=payload, headers=headers)
|
||||
|
||||
if response.status_code < 400:
|
||||
return {'success': True, 'channel': 'webhook', 'status': response.status_code}
|
||||
else:
|
||||
return {'success': False, 'error': f'HTTP {response.status_code}'}
|
||||
except Exception as e:
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
def acknowledge_alert(self, alert_id: int, acknowledged_by: str) -> Optional[AlertRecord]:
|
||||
"""确认告警"""
|
||||
alert = self.db.query(AlertRecord).filter(AlertRecord.id == alert_id).first()
|
||||
if not alert:
|
||||
return None
|
||||
|
||||
alert.status = 'acknowledged'
|
||||
alert.acknowledged_by = acknowledged_by
|
||||
alert.acknowledged_at = datetime.now()
|
||||
self.db.commit()
|
||||
|
||||
return alert
|
||||
|
||||
def resolve_alert(self, alert_id: int) -> Optional[AlertRecord]:
|
||||
"""解决告警"""
|
||||
alert = self.db.query(AlertRecord).filter(AlertRecord.id == alert_id).first()
|
||||
if not alert:
|
||||
return None
|
||||
|
||||
alert.status = 'resolved'
|
||||
alert.resolved_at = datetime.now()
|
||||
self.db.commit()
|
||||
|
||||
return alert
|
||||
309
backend/app/services/cache.py
Normal file
309
backend/app/services/cache.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""Redis缓存服务"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Any, Union
|
||||
from functools import lru_cache
|
||||
|
||||
try:
|
||||
import redis
|
||||
from redis import Redis
|
||||
REDIS_AVAILABLE = True
|
||||
except ImportError:
|
||||
REDIS_AVAILABLE = False
|
||||
Redis = None
|
||||
|
||||
from ..config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局Redis连接池
|
||||
_redis_pool: Optional[Any] = None
|
||||
_redis_client: Optional[Any] = None
|
||||
|
||||
|
||||
def get_redis_client() -> Optional[Any]:
|
||||
"""获取Redis客户端单例"""
|
||||
global _redis_pool, _redis_client
|
||||
|
||||
if not REDIS_AVAILABLE:
|
||||
logger.warning("Redis module not installed, cache disabled")
|
||||
return None
|
||||
|
||||
if _redis_client is not None:
|
||||
return _redis_client
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
try:
|
||||
_redis_pool = redis.ConnectionPool.from_url(
|
||||
settings.REDIS_URL,
|
||||
max_connections=20,
|
||||
decode_responses=True
|
||||
)
|
||||
_redis_client = Redis(connection_pool=_redis_pool)
|
||||
# 测试连接
|
||||
_redis_client.ping()
|
||||
logger.info(f"Redis connected: {settings.REDIS_URL}")
|
||||
return _redis_client
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis connection failed: {e}, cache disabled")
|
||||
_redis_client = None
|
||||
return None
|
||||
|
||||
|
||||
class CacheService:
|
||||
"""缓存服务
|
||||
|
||||
提供统一的缓存接口,支持Redis和内存回退
|
||||
|
||||
使用示例:
|
||||
cache = CacheService()
|
||||
|
||||
# 设置缓存
|
||||
cache.set("user:123", {"name": "test"}, ttl=3600)
|
||||
|
||||
# 获取缓存
|
||||
user = cache.get("user:123")
|
||||
|
||||
# 删除缓存
|
||||
cache.delete("user:123")
|
||||
"""
|
||||
|
||||
def __init__(self, prefix: Optional[str] = None):
|
||||
"""初始化缓存服务
|
||||
|
||||
Args:
|
||||
prefix: 键前缀,默认使用配置中的REDIS_PREFIX
|
||||
"""
|
||||
settings = get_settings()
|
||||
self.prefix = prefix or settings.REDIS_PREFIX
|
||||
self._client = get_redis_client()
|
||||
|
||||
# 内存回退缓存(当Redis不可用时使用)
|
||||
self._memory_cache: dict = {}
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Redis是否可用"""
|
||||
return self._client is not None
|
||||
|
||||
def _make_key(self, key: str) -> str:
|
||||
"""生成完整的缓存键"""
|
||||
return f"{self.prefix}{key}"
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
"""获取缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
缓存值或默认值
|
||||
"""
|
||||
full_key = self._make_key(key)
|
||||
|
||||
if self._client:
|
||||
try:
|
||||
value = self._client.get(full_key)
|
||||
if value is None:
|
||||
return default
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return value
|
||||
except Exception as e:
|
||||
logger.error(f"Cache get error: {e}")
|
||||
return default
|
||||
else:
|
||||
# 内存回退
|
||||
return self._memory_cache.get(full_key, default)
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: Optional[int] = None,
|
||||
nx: bool = False
|
||||
) -> bool:
|
||||
"""设置缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
value: 缓存值
|
||||
ttl: 过期时间(秒)
|
||||
nx: 只在键不存在时设置
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
full_key = self._make_key(key)
|
||||
|
||||
# 序列化值
|
||||
if isinstance(value, (dict, list)):
|
||||
serialized = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
serialized = str(value) if value is not None else ""
|
||||
|
||||
if self._client:
|
||||
try:
|
||||
if nx:
|
||||
result = self._client.set(full_key, serialized, ex=ttl, nx=True)
|
||||
else:
|
||||
result = self._client.set(full_key, serialized, ex=ttl)
|
||||
return bool(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Cache set error: {e}")
|
||||
return False
|
||||
else:
|
||||
# 内存回退(不支持TTL和NX)
|
||||
if nx and full_key in self._memory_cache:
|
||||
return False
|
||||
self._memory_cache[full_key] = value
|
||||
return True
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""删除缓存
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
full_key = self._make_key(key)
|
||||
|
||||
if self._client:
|
||||
try:
|
||||
return bool(self._client.delete(full_key))
|
||||
except Exception as e:
|
||||
logger.error(f"Cache delete error: {e}")
|
||||
return False
|
||||
else:
|
||||
return self._memory_cache.pop(full_key, None) is not None
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
"""检查键是否存在
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
full_key = self._make_key(key)
|
||||
|
||||
if self._client:
|
||||
try:
|
||||
return bool(self._client.exists(full_key))
|
||||
except Exception as e:
|
||||
logger.error(f"Cache exists error: {e}")
|
||||
return False
|
||||
else:
|
||||
return full_key in self._memory_cache
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
"""获取键的剩余过期时间
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示键不存在
|
||||
"""
|
||||
full_key = self._make_key(key)
|
||||
|
||||
if self._client:
|
||||
try:
|
||||
return self._client.ttl(full_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Cache ttl error: {e}")
|
||||
return -2
|
||||
else:
|
||||
return -1 if full_key in self._memory_cache else -2
|
||||
|
||||
def incr(self, key: str, amount: int = 1) -> int:
|
||||
"""递增计数器
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
amount: 递增量
|
||||
|
||||
Returns:
|
||||
递增后的值
|
||||
"""
|
||||
full_key = self._make_key(key)
|
||||
|
||||
if self._client:
|
||||
try:
|
||||
return self._client.incrby(full_key, amount)
|
||||
except Exception as e:
|
||||
logger.error(f"Cache incr error: {e}")
|
||||
return 0
|
||||
else:
|
||||
current = self._memory_cache.get(full_key, 0)
|
||||
new_value = int(current) + amount
|
||||
self._memory_cache[full_key] = new_value
|
||||
return new_value
|
||||
|
||||
def expire(self, key: str, ttl: int) -> bool:
|
||||
"""设置键的过期时间
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
ttl: 过期时间(秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
full_key = self._make_key(key)
|
||||
|
||||
if self._client:
|
||||
try:
|
||||
return bool(self._client.expire(full_key, ttl))
|
||||
except Exception as e:
|
||||
logger.error(f"Cache expire error: {e}")
|
||||
return False
|
||||
else:
|
||||
return full_key in self._memory_cache
|
||||
|
||||
def clear_prefix(self, prefix: str) -> int:
|
||||
"""删除指定前缀的所有键
|
||||
|
||||
Args:
|
||||
prefix: 键前缀
|
||||
|
||||
Returns:
|
||||
删除的键数量
|
||||
"""
|
||||
full_prefix = self._make_key(prefix)
|
||||
|
||||
if self._client:
|
||||
try:
|
||||
keys = self._client.keys(f"{full_prefix}*")
|
||||
if keys:
|
||||
return self._client.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Cache clear_prefix error: {e}")
|
||||
return 0
|
||||
else:
|
||||
count = 0
|
||||
keys_to_delete = [k for k in self._memory_cache if k.startswith(full_prefix)]
|
||||
for k in keys_to_delete:
|
||||
del self._memory_cache[k]
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
# 全局缓存实例
|
||||
_cache_instance: Optional[CacheService] = None
|
||||
|
||||
|
||||
def get_cache() -> CacheService:
|
||||
"""获取全局缓存实例"""
|
||||
global _cache_instance
|
||||
if _cache_instance is None:
|
||||
_cache_instance = CacheService()
|
||||
return _cache_instance
|
||||
420
backend/app/services/cost.py
Normal file
420
backend/app/services/cost.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""费用计算服务"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Optional, Dict, List
|
||||
from functools import lru_cache
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from ..models.pricing import ModelPricing, TenantBilling
|
||||
from ..models.stats import AICallEvent
|
||||
from .cache import get_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CostCalculator:
|
||||
"""费用计算器
|
||||
|
||||
使用示例:
|
||||
calculator = CostCalculator(db)
|
||||
|
||||
# 计算单次调用费用
|
||||
cost = calculator.calculate_cost("gpt-4", input_tokens=100, output_tokens=200)
|
||||
|
||||
# 生成月度账单
|
||||
billing = calculator.generate_monthly_billing("qiqi", "2026-01")
|
||||
"""
|
||||
|
||||
# 默认模型价格(当数据库中无配置时使用)
|
||||
DEFAULT_PRICING = {
|
||||
# OpenAI
|
||||
"gpt-4": {"input": 0.21, "output": 0.42}, # 元/1K tokens
|
||||
"gpt-4-turbo": {"input": 0.07, "output": 0.21},
|
||||
"gpt-4o": {"input": 0.035, "output": 0.105},
|
||||
"gpt-4o-mini": {"input": 0.00105, "output": 0.0042},
|
||||
"gpt-3.5-turbo": {"input": 0.0035, "output": 0.014},
|
||||
|
||||
# Anthropic
|
||||
"claude-3-opus": {"input": 0.105, "output": 0.525},
|
||||
"claude-3-sonnet": {"input": 0.021, "output": 0.105},
|
||||
"claude-3-haiku": {"input": 0.00175, "output": 0.00875},
|
||||
"claude-3.5-sonnet": {"input": 0.021, "output": 0.105},
|
||||
|
||||
# 国内模型
|
||||
"qwen-max": {"input": 0.02, "output": 0.06},
|
||||
"qwen-plus": {"input": 0.004, "output": 0.012},
|
||||
"qwen-turbo": {"input": 0.002, "output": 0.006},
|
||||
"glm-4": {"input": 0.01, "output": 0.01},
|
||||
"glm-4-flash": {"input": 0.0001, "output": 0.0001},
|
||||
"deepseek-chat": {"input": 0.001, "output": 0.002},
|
||||
"deepseek-coder": {"input": 0.001, "output": 0.002},
|
||||
|
||||
# 默认
|
||||
"default": {"input": 0.01, "output": 0.03}
|
||||
}
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self._cache = get_cache()
|
||||
self._pricing_cache: Dict[str, ModelPricing] = {}
|
||||
|
||||
def get_model_pricing(self, model_name: str) -> Optional[ModelPricing]:
|
||||
"""获取模型价格配置
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
|
||||
Returns:
|
||||
ModelPricing实例或None
|
||||
"""
|
||||
# 尝试从缓存获取
|
||||
cache_key = f"pricing:{model_name}"
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached:
|
||||
return self._dict_to_pricing(cached)
|
||||
|
||||
# 从数据库查询
|
||||
pricing = self.db.query(ModelPricing).filter(
|
||||
ModelPricing.model_name == model_name,
|
||||
ModelPricing.status == 1
|
||||
).first()
|
||||
|
||||
if pricing:
|
||||
# 缓存1小时
|
||||
self._cache.set(cache_key, self._pricing_to_dict(pricing), ttl=3600)
|
||||
return pricing
|
||||
|
||||
return None
|
||||
|
||||
def _pricing_to_dict(self, pricing: ModelPricing) -> dict:
|
||||
return {
|
||||
"model_name": pricing.model_name,
|
||||
"input_price_per_1k": str(pricing.input_price_per_1k),
|
||||
"output_price_per_1k": str(pricing.output_price_per_1k),
|
||||
"fixed_price_per_call": str(pricing.fixed_price_per_call),
|
||||
"pricing_type": pricing.pricing_type
|
||||
}
|
||||
|
||||
def _dict_to_pricing(self, d: dict) -> ModelPricing:
|
||||
pricing = ModelPricing()
|
||||
pricing.model_name = d.get("model_name")
|
||||
pricing.input_price_per_1k = Decimal(d.get("input_price_per_1k", "0"))
|
||||
pricing.output_price_per_1k = Decimal(d.get("output_price_per_1k", "0"))
|
||||
pricing.fixed_price_per_call = Decimal(d.get("fixed_price_per_call", "0"))
|
||||
pricing.pricing_type = d.get("pricing_type", "token")
|
||||
return pricing
|
||||
|
||||
def calculate_cost(
|
||||
self,
|
||||
model_name: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
call_count: int = 1
|
||||
) -> Decimal:
|
||||
"""计算调用费用
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
input_tokens: 输入token数
|
||||
output_tokens: 输出token数
|
||||
call_count: 调用次数
|
||||
|
||||
Returns:
|
||||
费用(元)
|
||||
"""
|
||||
# 尝试获取数据库配置
|
||||
pricing = self.get_model_pricing(model_name)
|
||||
|
||||
if pricing:
|
||||
if pricing.pricing_type == 'call':
|
||||
return pricing.fixed_price_per_call * call_count
|
||||
elif pricing.pricing_type == 'hybrid':
|
||||
token_cost = (
|
||||
pricing.input_price_per_1k * Decimal(input_tokens) / 1000 +
|
||||
pricing.output_price_per_1k * Decimal(output_tokens) / 1000
|
||||
)
|
||||
call_cost = pricing.fixed_price_per_call * call_count
|
||||
return token_cost + call_cost
|
||||
else: # token
|
||||
return (
|
||||
pricing.input_price_per_1k * Decimal(input_tokens) / 1000 +
|
||||
pricing.output_price_per_1k * Decimal(output_tokens) / 1000
|
||||
)
|
||||
|
||||
# 使用默认价格
|
||||
default_prices = self.DEFAULT_PRICING.get(model_name) or self.DEFAULT_PRICING.get("default")
|
||||
input_price = Decimal(str(default_prices["input"]))
|
||||
output_price = Decimal(str(default_prices["output"]))
|
||||
|
||||
return (
|
||||
input_price * Decimal(input_tokens) / 1000 +
|
||||
output_price * Decimal(output_tokens) / 1000
|
||||
)
|
||||
|
||||
def calculate_event_cost(self, event: AICallEvent) -> Decimal:
|
||||
"""计算单个事件的费用
|
||||
|
||||
Args:
|
||||
event: AI调用事件
|
||||
|
||||
Returns:
|
||||
费用(元)
|
||||
"""
|
||||
return self.calculate_cost(
|
||||
model_name=event.model or "default",
|
||||
input_tokens=event.input_tokens or 0,
|
||||
output_tokens=event.output_tokens or 0
|
||||
)
|
||||
|
||||
def update_event_costs(self, start_date: str = None, end_date: str = None) -> int:
|
||||
"""批量更新事件费用
|
||||
|
||||
对于cost为0或NULL的事件,重新计算费用
|
||||
|
||||
Args:
|
||||
start_date: 开始日期,格式 YYYY-MM-DD
|
||||
end_date: 结束日期,格式 YYYY-MM-DD
|
||||
|
||||
Returns:
|
||||
更新的记录数
|
||||
"""
|
||||
query = self.db.query(AICallEvent).filter(
|
||||
(AICallEvent.cost == None) | (AICallEvent.cost == 0)
|
||||
)
|
||||
|
||||
if start_date:
|
||||
query = query.filter(AICallEvent.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
|
||||
|
||||
events = query.all()
|
||||
updated = 0
|
||||
|
||||
for event in events:
|
||||
try:
|
||||
cost = self.calculate_event_cost(event)
|
||||
event.cost = cost
|
||||
updated += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to calculate cost for event {event.id}: {e}")
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"Updated {updated} event costs")
|
||||
|
||||
return updated
|
||||
|
||||
def generate_monthly_billing(
|
||||
self,
|
||||
tenant_id: str,
|
||||
billing_month: str
|
||||
) -> TenantBilling:
|
||||
"""生成月度账单
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
billing_month: 账单月份,格式 YYYY-MM
|
||||
|
||||
Returns:
|
||||
TenantBilling实例
|
||||
"""
|
||||
# 检查是否已存在
|
||||
existing = self.db.query(TenantBilling).filter(
|
||||
TenantBilling.tenant_id == tenant_id,
|
||||
TenantBilling.billing_month == billing_month
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
billing = existing
|
||||
else:
|
||||
billing = TenantBilling(
|
||||
tenant_id=tenant_id,
|
||||
billing_month=billing_month
|
||||
)
|
||||
self.db.add(billing)
|
||||
|
||||
# 计算统计数据
|
||||
start_date = f"{billing_month}-01"
|
||||
year, month = billing_month.split("-")
|
||||
if int(month) == 12:
|
||||
end_date = f"{int(year)+1}-01-01"
|
||||
else:
|
||||
end_date = f"{year}-{int(month)+1:02d}-01"
|
||||
|
||||
# 聚合查询
|
||||
stats = self.db.query(
|
||||
func.count(AICallEvent.id).label('total_calls'),
|
||||
func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'),
|
||||
func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'),
|
||||
func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost')
|
||||
).filter(
|
||||
AICallEvent.tenant_id == tenant_id,
|
||||
AICallEvent.created_at >= start_date,
|
||||
AICallEvent.created_at < end_date
|
||||
).first()
|
||||
|
||||
billing.total_calls = stats.total_calls or 0
|
||||
billing.total_input_tokens = int(stats.total_input or 0)
|
||||
billing.total_output_tokens = int(stats.total_output or 0)
|
||||
billing.total_cost = stats.total_cost or Decimal("0")
|
||||
|
||||
# 按模型统计
|
||||
model_stats = self.db.query(
|
||||
AICallEvent.model,
|
||||
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
|
||||
).filter(
|
||||
AICallEvent.tenant_id == tenant_id,
|
||||
AICallEvent.created_at >= start_date,
|
||||
AICallEvent.created_at < end_date
|
||||
).group_by(AICallEvent.model).all()
|
||||
|
||||
billing.cost_by_model = {
|
||||
m.model or "unknown": float(m.cost) for m in model_stats
|
||||
}
|
||||
|
||||
# 按应用统计
|
||||
app_stats = self.db.query(
|
||||
AICallEvent.app_code,
|
||||
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
|
||||
).filter(
|
||||
AICallEvent.tenant_id == tenant_id,
|
||||
AICallEvent.created_at >= start_date,
|
||||
AICallEvent.created_at < end_date
|
||||
).group_by(AICallEvent.app_code).all()
|
||||
|
||||
billing.cost_by_app = {
|
||||
a.app_code or "unknown": float(a.cost) for a in app_stats
|
||||
}
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(billing)
|
||||
|
||||
return billing
|
||||
|
||||
def get_cost_summary(
|
||||
self,
|
||||
tenant_id: str = None,
|
||||
start_date: str = None,
|
||||
end_date: str = None
|
||||
) -> Dict:
|
||||
"""获取费用汇总
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID(可选)
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
|
||||
Returns:
|
||||
费用汇总字典
|
||||
"""
|
||||
query = self.db.query(
|
||||
func.count(AICallEvent.id).label('total_calls'),
|
||||
func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'),
|
||||
func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'),
|
||||
func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost')
|
||||
)
|
||||
|
||||
if tenant_id:
|
||||
query = query.filter(AICallEvent.tenant_id == tenant_id)
|
||||
if start_date:
|
||||
query = query.filter(AICallEvent.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
|
||||
|
||||
stats = query.first()
|
||||
|
||||
return {
|
||||
"total_calls": stats.total_calls or 0,
|
||||
"total_input_tokens": int(stats.total_input or 0),
|
||||
"total_output_tokens": int(stats.total_output or 0),
|
||||
"total_cost": float(stats.total_cost or 0)
|
||||
}
|
||||
|
||||
def get_cost_by_tenant(
|
||||
self,
|
||||
start_date: str = None,
|
||||
end_date: str = None
|
||||
) -> List[Dict]:
|
||||
"""按租户统计费用
|
||||
|
||||
Returns:
|
||||
租户费用列表
|
||||
"""
|
||||
query = self.db.query(
|
||||
AICallEvent.tenant_id,
|
||||
func.count(AICallEvent.id).label('calls'),
|
||||
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
|
||||
)
|
||||
|
||||
if start_date:
|
||||
query = query.filter(AICallEvent.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
|
||||
|
||||
results = query.group_by(AICallEvent.tenant_id).order_by(
|
||||
func.sum(AICallEvent.cost).desc()
|
||||
).all()
|
||||
|
||||
return [
|
||||
{
|
||||
"tenant_id": r.tenant_id,
|
||||
"calls": r.calls,
|
||||
"cost": float(r.cost)
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
def get_cost_by_model(
|
||||
self,
|
||||
tenant_id: str = None,
|
||||
start_date: str = None,
|
||||
end_date: str = None
|
||||
) -> List[Dict]:
|
||||
"""按模型统计费用
|
||||
|
||||
Returns:
|
||||
模型费用列表
|
||||
"""
|
||||
query = self.db.query(
|
||||
AICallEvent.model,
|
||||
func.count(AICallEvent.id).label('calls'),
|
||||
func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('input_tokens'),
|
||||
func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('output_tokens'),
|
||||
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
|
||||
)
|
||||
|
||||
if tenant_id:
|
||||
query = query.filter(AICallEvent.tenant_id == tenant_id)
|
||||
if start_date:
|
||||
query = query.filter(AICallEvent.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
|
||||
|
||||
results = query.group_by(AICallEvent.model).order_by(
|
||||
func.sum(AICallEvent.cost).desc()
|
||||
).all()
|
||||
|
||||
return [
|
||||
{
|
||||
"model": r.model or "unknown",
|
||||
"calls": r.calls,
|
||||
"input_tokens": int(r.input_tokens),
|
||||
"output_tokens": int(r.output_tokens),
|
||||
"cost": float(r.cost)
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def calculate_cost(
|
||||
db: Session,
|
||||
model_name: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0
|
||||
) -> Decimal:
|
||||
"""快速计算费用"""
|
||||
calculator = CostCalculator(db)
|
||||
return calculator.calculate_cost(model_name, input_tokens, output_tokens)
|
||||
346
backend/app/services/quota.py
Normal file
346
backend/app/services/quota.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""配额管理服务"""
|
||||
import logging
|
||||
from datetime import datetime, date, timedelta
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
|
||||
from ..models.tenant import Tenant, Subscription
|
||||
from ..models.stats import AICallEvent
|
||||
from .cache import get_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaConfig:
|
||||
"""配额配置"""
|
||||
daily_calls: int = 0 # 每日调用限制,0表示无限制
|
||||
daily_tokens: int = 0 # 每日Token限制
|
||||
monthly_calls: int = 0 # 每月调用限制
|
||||
monthly_tokens: int = 0 # 每月Token限制
|
||||
monthly_cost: float = 0 # 每月费用限制(元)
|
||||
concurrent_calls: int = 0 # 并发调用限制
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaUsage:
|
||||
"""配额使用情况"""
|
||||
daily_calls: int = 0
|
||||
daily_tokens: int = 0
|
||||
monthly_calls: int = 0
|
||||
monthly_tokens: int = 0
|
||||
monthly_cost: float = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCheckResult:
|
||||
"""配额检查结果"""
|
||||
allowed: bool
|
||||
reason: Optional[str] = None
|
||||
quota_type: Optional[str] = None
|
||||
limit: int = 0
|
||||
used: int = 0
|
||||
remaining: int = 0
|
||||
|
||||
|
||||
class QuotaService:
|
||||
"""配额管理服务
|
||||
|
||||
使用示例:
|
||||
quota_service = QuotaService(db)
|
||||
|
||||
# 检查配额
|
||||
result = quota_service.check_quota("qiqi", "tools")
|
||||
if not result.allowed:
|
||||
raise HTTPException(status_code=429, detail=result.reason)
|
||||
|
||||
# 获取使用情况
|
||||
usage = quota_service.get_usage("qiqi", "tools")
|
||||
"""
|
||||
|
||||
# 默认配额(当无订阅配置时使用)
|
||||
DEFAULT_QUOTA = QuotaConfig(
|
||||
daily_calls=1000,
|
||||
daily_tokens=100000,
|
||||
monthly_calls=30000,
|
||||
monthly_tokens=3000000,
|
||||
monthly_cost=100
|
||||
)
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self._cache = get_cache()
|
||||
|
||||
def get_subscription(self, tenant_id: str, app_code: str) -> Optional[Subscription]:
|
||||
"""获取租户订阅配置"""
|
||||
return self.db.query(Subscription).filter(
|
||||
Subscription.tenant_id == tenant_id,
|
||||
Subscription.app_code == app_code,
|
||||
Subscription.status == 'active'
|
||||
).first()
|
||||
|
||||
def get_quota_config(self, tenant_id: str, app_code: str) -> QuotaConfig:
|
||||
"""获取配额配置
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
app_code: 应用代码
|
||||
|
||||
Returns:
|
||||
QuotaConfig实例
|
||||
"""
|
||||
# 尝试从缓存获取
|
||||
cache_key = f"quota:config:{tenant_id}:{app_code}"
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached:
|
||||
return QuotaConfig(**cached)
|
||||
|
||||
# 从订阅表获取
|
||||
subscription = self.get_subscription(tenant_id, app_code)
|
||||
|
||||
if subscription and subscription.quota:
|
||||
quota = subscription.quota
|
||||
config = QuotaConfig(
|
||||
daily_calls=quota.get('daily_calls', 0),
|
||||
daily_tokens=quota.get('daily_tokens', 0),
|
||||
monthly_calls=quota.get('monthly_calls', 0),
|
||||
monthly_tokens=quota.get('monthly_tokens', 0),
|
||||
monthly_cost=quota.get('monthly_cost', 0),
|
||||
concurrent_calls=quota.get('concurrent_calls', 0)
|
||||
)
|
||||
else:
|
||||
config = self.DEFAULT_QUOTA
|
||||
|
||||
# 缓存5分钟
|
||||
self._cache.set(cache_key, config.__dict__, ttl=300)
|
||||
|
||||
return config
|
||||
|
||||
def get_usage(self, tenant_id: str, app_code: str) -> QuotaUsage:
|
||||
"""获取配额使用情况
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
app_code: 应用代码
|
||||
|
||||
Returns:
|
||||
QuotaUsage实例
|
||||
"""
|
||||
today = date.today()
|
||||
month_start = today.replace(day=1)
|
||||
|
||||
# 今日使用量
|
||||
daily_stats = self.db.query(
|
||||
func.count(AICallEvent.id).label('calls'),
|
||||
func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0).label('tokens')
|
||||
).filter(
|
||||
AICallEvent.tenant_id == tenant_id,
|
||||
AICallEvent.app_code == app_code,
|
||||
func.date(AICallEvent.created_at) == today
|
||||
).first()
|
||||
|
||||
# 本月使用量
|
||||
monthly_stats = self.db.query(
|
||||
func.count(AICallEvent.id).label('calls'),
|
||||
func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0).label('tokens'),
|
||||
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
|
||||
).filter(
|
||||
AICallEvent.tenant_id == tenant_id,
|
||||
AICallEvent.app_code == app_code,
|
||||
func.date(AICallEvent.created_at) >= month_start
|
||||
).first()
|
||||
|
||||
return QuotaUsage(
|
||||
daily_calls=daily_stats.calls or 0,
|
||||
daily_tokens=int(daily_stats.tokens or 0),
|
||||
monthly_calls=monthly_stats.calls or 0,
|
||||
monthly_tokens=int(monthly_stats.tokens or 0),
|
||||
monthly_cost=float(monthly_stats.cost or 0)
|
||||
)
|
||||
|
||||
def check_quota(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_code: str,
|
||||
estimated_tokens: int = 0
|
||||
) -> QuotaCheckResult:
|
||||
"""检查配额是否足够
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
app_code: 应用代码
|
||||
estimated_tokens: 预估Token消耗
|
||||
|
||||
Returns:
|
||||
QuotaCheckResult实例
|
||||
"""
|
||||
config = self.get_quota_config(tenant_id, app_code)
|
||||
usage = self.get_usage(tenant_id, app_code)
|
||||
|
||||
# 检查日调用次数
|
||||
if config.daily_calls > 0:
|
||||
if usage.daily_calls >= config.daily_calls:
|
||||
return QuotaCheckResult(
|
||||
allowed=False,
|
||||
reason=f"已达到每日调用限制 ({config.daily_calls} 次)",
|
||||
quota_type="daily_calls",
|
||||
limit=config.daily_calls,
|
||||
used=usage.daily_calls,
|
||||
remaining=0
|
||||
)
|
||||
|
||||
# 检查日Token限制
|
||||
if config.daily_tokens > 0:
|
||||
if usage.daily_tokens + estimated_tokens > config.daily_tokens:
|
||||
return QuotaCheckResult(
|
||||
allowed=False,
|
||||
reason=f"已达到每日Token限制 ({config.daily_tokens:,})",
|
||||
quota_type="daily_tokens",
|
||||
limit=config.daily_tokens,
|
||||
used=usage.daily_tokens,
|
||||
remaining=max(0, config.daily_tokens - usage.daily_tokens)
|
||||
)
|
||||
|
||||
# 检查月调用次数
|
||||
if config.monthly_calls > 0:
|
||||
if usage.monthly_calls >= config.monthly_calls:
|
||||
return QuotaCheckResult(
|
||||
allowed=False,
|
||||
reason=f"已达到每月调用限制 ({config.monthly_calls} 次)",
|
||||
quota_type="monthly_calls",
|
||||
limit=config.monthly_calls,
|
||||
used=usage.monthly_calls,
|
||||
remaining=0
|
||||
)
|
||||
|
||||
# 检查月Token限制
|
||||
if config.monthly_tokens > 0:
|
||||
if usage.monthly_tokens + estimated_tokens > config.monthly_tokens:
|
||||
return QuotaCheckResult(
|
||||
allowed=False,
|
||||
reason=f"已达到每月Token限制 ({config.monthly_tokens:,})",
|
||||
quota_type="monthly_tokens",
|
||||
limit=config.monthly_tokens,
|
||||
used=usage.monthly_tokens,
|
||||
remaining=max(0, config.monthly_tokens - usage.monthly_tokens)
|
||||
)
|
||||
|
||||
# 检查月费用限制
|
||||
if config.monthly_cost > 0:
|
||||
if usage.monthly_cost >= config.monthly_cost:
|
||||
return QuotaCheckResult(
|
||||
allowed=False,
|
||||
reason=f"已达到每月费用限制 (¥{config.monthly_cost:.2f})",
|
||||
quota_type="monthly_cost",
|
||||
limit=int(config.monthly_cost * 100), # 转为分
|
||||
used=int(usage.monthly_cost * 100),
|
||||
remaining=max(0, int((config.monthly_cost - usage.monthly_cost) * 100))
|
||||
)
|
||||
|
||||
# 所有检查通过
|
||||
return QuotaCheckResult(
|
||||
allowed=True,
|
||||
quota_type="daily_calls",
|
||||
limit=config.daily_calls,
|
||||
used=usage.daily_calls,
|
||||
remaining=max(0, config.daily_calls - usage.daily_calls) if config.daily_calls > 0 else -1
|
||||
)
|
||||
|
||||
def get_quota_summary(self, tenant_id: str, app_code: str) -> Dict[str, Any]:
|
||||
"""获取配额汇总信息
|
||||
|
||||
Returns:
|
||||
包含配额配置和使用情况的字典
|
||||
"""
|
||||
config = self.get_quota_config(tenant_id, app_code)
|
||||
usage = self.get_usage(tenant_id, app_code)
|
||||
|
||||
def calc_percentage(used: int, limit: int) -> float:
|
||||
if limit <= 0:
|
||||
return 0
|
||||
return min(100, round(used / limit * 100, 1))
|
||||
|
||||
return {
|
||||
"config": {
|
||||
"daily_calls": config.daily_calls,
|
||||
"daily_tokens": config.daily_tokens,
|
||||
"monthly_calls": config.monthly_calls,
|
||||
"monthly_tokens": config.monthly_tokens,
|
||||
"monthly_cost": config.monthly_cost
|
||||
},
|
||||
"usage": {
|
||||
"daily_calls": usage.daily_calls,
|
||||
"daily_tokens": usage.daily_tokens,
|
||||
"monthly_calls": usage.monthly_calls,
|
||||
"monthly_tokens": usage.monthly_tokens,
|
||||
"monthly_cost": round(usage.monthly_cost, 2)
|
||||
},
|
||||
"percentage": {
|
||||
"daily_calls": calc_percentage(usage.daily_calls, config.daily_calls),
|
||||
"daily_tokens": calc_percentage(usage.daily_tokens, config.daily_tokens),
|
||||
"monthly_calls": calc_percentage(usage.monthly_calls, config.monthly_calls),
|
||||
"monthly_tokens": calc_percentage(usage.monthly_tokens, config.monthly_tokens),
|
||||
"monthly_cost": calc_percentage(int(usage.monthly_cost * 100), int(config.monthly_cost * 100))
|
||||
}
|
||||
}
|
||||
|
||||
def update_quota(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_code: str,
|
||||
quota_config: Dict[str, Any]
|
||||
) -> Subscription:
|
||||
"""更新配额配置
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
app_code: 应用代码
|
||||
quota_config: 配额配置字典
|
||||
|
||||
Returns:
|
||||
更新后的Subscription实例
|
||||
"""
|
||||
subscription = self.get_subscription(tenant_id, app_code)
|
||||
|
||||
if not subscription:
|
||||
# 创建新订阅
|
||||
subscription = Subscription(
|
||||
tenant_id=tenant_id,
|
||||
app_code=app_code,
|
||||
start_date=date.today(),
|
||||
quota=quota_config,
|
||||
status='active'
|
||||
)
|
||||
self.db.add(subscription)
|
||||
else:
|
||||
# 更新现有订阅
|
||||
subscription.quota = quota_config
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(subscription)
|
||||
|
||||
# 清除缓存
|
||||
cache_key = f"quota:config:{tenant_id}:{app_code}"
|
||||
self._cache.delete(cache_key)
|
||||
|
||||
return subscription
|
||||
|
||||
|
||||
def check_quota_middleware(
|
||||
db: Session,
|
||||
tenant_id: str,
|
||||
app_code: str,
|
||||
estimated_tokens: int = 0
|
||||
) -> QuotaCheckResult:
|
||||
"""配额检查中间件函数
|
||||
|
||||
可在路由中使用:
|
||||
result = check_quota_middleware(db, "qiqi", "tools")
|
||||
if not result.allowed:
|
||||
raise HTTPException(status_code=429, detail=result.reason)
|
||||
"""
|
||||
service = QuotaService(db)
|
||||
return service.check_quota(tenant_id, app_code, estimated_tokens)
|
||||
371
backend/app/services/wechat.py
Normal file
371
backend/app/services/wechat.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""企业微信服务"""
|
||||
import hashlib
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
|
||||
from ..config import get_settings
|
||||
from .cache import get_cache
|
||||
from .crypto import decrypt_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@dataclass
|
||||
class WechatConfig:
|
||||
"""企业微信应用配置"""
|
||||
corp_id: str
|
||||
agent_id: str
|
||||
secret: str
|
||||
|
||||
|
||||
class WechatService:
|
||||
"""企业微信服务
|
||||
|
||||
提供access_token获取、JS-SDK签名、OAuth2等功能
|
||||
|
||||
使用示例:
|
||||
wechat = WechatService(corp_id="wwxxxx", agent_id="1000001", secret="xxx")
|
||||
|
||||
# 获取access_token
|
||||
token = await wechat.get_access_token()
|
||||
|
||||
# 获取JS-SDK签名
|
||||
signature = await wechat.get_jssdk_signature("https://example.com/page")
|
||||
"""
|
||||
|
||||
# 企业微信API基础URL
|
||||
BASE_URL = "https://qyapi.weixin.qq.com"
|
||||
|
||||
def __init__(self, corp_id: str, agent_id: str, secret: str):
|
||||
"""初始化企业微信服务
|
||||
|
||||
Args:
|
||||
corp_id: 企业ID
|
||||
agent_id: 应用AgentId
|
||||
secret: 应用Secret(明文)
|
||||
"""
|
||||
self.corp_id = corp_id
|
||||
self.agent_id = agent_id
|
||||
self.secret = secret
|
||||
self._cache = get_cache()
|
||||
|
||||
@classmethod
|
||||
def from_wechat_app(cls, wechat_app) -> "WechatService":
|
||||
"""从TenantWechatApp模型创建服务实例
|
||||
|
||||
Args:
|
||||
wechat_app: TenantWechatApp数据库模型
|
||||
|
||||
Returns:
|
||||
WechatService实例
|
||||
"""
|
||||
secret = ""
|
||||
if wechat_app.secret_encrypted:
|
||||
try:
|
||||
secret = decrypt_config(wechat_app.secret_encrypted)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt secret: {e}")
|
||||
|
||||
return cls(
|
||||
corp_id=wechat_app.corp_id,
|
||||
agent_id=wechat_app.agent_id,
|
||||
secret=secret
|
||||
)
|
||||
|
||||
def _cache_key(self, key_type: str) -> str:
|
||||
"""生成缓存键"""
|
||||
return f"wechat:{self.corp_id}:{self.agent_id}:{key_type}"
|
||||
|
||||
async def get_access_token(self, force_refresh: bool = False) -> Optional[str]:
|
||||
"""获取access_token
|
||||
|
||||
企业微信access_token有效期7200秒,需要缓存
|
||||
|
||||
Args:
|
||||
force_refresh: 是否强制刷新
|
||||
|
||||
Returns:
|
||||
access_token或None
|
||||
"""
|
||||
cache_key = self._cache_key("access_token")
|
||||
|
||||
# 尝试从缓存获取
|
||||
if not force_refresh:
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached:
|
||||
logger.debug(f"Access token from cache: {cached[:20]}...")
|
||||
return cached
|
||||
|
||||
# 从企业微信API获取
|
||||
url = f"{self.BASE_URL}/cgi-bin/gettoken"
|
||||
params = {
|
||||
"corpid": self.corp_id,
|
||||
"corpsecret": self.secret
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get(url, params=params)
|
||||
result = response.json()
|
||||
|
||||
if result.get("errcode", 0) != 0:
|
||||
logger.error(f"Get access_token failed: {result}")
|
||||
return None
|
||||
|
||||
access_token = result.get("access_token")
|
||||
expires_in = result.get("expires_in", 7200)
|
||||
|
||||
# 缓存,提前200秒过期以确保安全
|
||||
self._cache.set(
|
||||
cache_key,
|
||||
access_token,
|
||||
ttl=min(expires_in - 200, settings.WECHAT_ACCESS_TOKEN_EXPIRE)
|
||||
)
|
||||
|
||||
logger.info(f"Got new access_token for {self.corp_id}")
|
||||
return access_token
|
||||
except Exception as e:
|
||||
logger.error(f"Get access_token error: {e}")
|
||||
return None
|
||||
|
||||
async def get_jsapi_ticket(self, force_refresh: bool = False) -> Optional[str]:
|
||||
"""获取jsapi_ticket
|
||||
|
||||
用于生成JS-SDK签名
|
||||
|
||||
Args:
|
||||
force_refresh: 是否强制刷新
|
||||
|
||||
Returns:
|
||||
jsapi_ticket或None
|
||||
"""
|
||||
cache_key = self._cache_key("jsapi_ticket")
|
||||
|
||||
# 尝试从缓存获取
|
||||
if not force_refresh:
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached:
|
||||
logger.debug(f"JSAPI ticket from cache: {cached[:20]}...")
|
||||
return cached
|
||||
|
||||
# 先获取access_token
|
||||
access_token = await self.get_access_token()
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
# 获取jsapi_ticket
|
||||
url = f"{self.BASE_URL}/cgi-bin/get_jsapi_ticket"
|
||||
params = {"access_token": access_token}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get(url, params=params)
|
||||
result = response.json()
|
||||
|
||||
if result.get("errcode", 0) != 0:
|
||||
logger.error(f"Get jsapi_ticket failed: {result}")
|
||||
return None
|
||||
|
||||
ticket = result.get("ticket")
|
||||
expires_in = result.get("expires_in", 7200)
|
||||
|
||||
# 缓存
|
||||
self._cache.set(
|
||||
cache_key,
|
||||
ticket,
|
||||
ttl=min(expires_in - 200, settings.WECHAT_JSAPI_TICKET_EXPIRE)
|
||||
)
|
||||
|
||||
logger.info(f"Got new jsapi_ticket for {self.corp_id}")
|
||||
return ticket
|
||||
except Exception as e:
|
||||
logger.error(f"Get jsapi_ticket error: {e}")
|
||||
return None
|
||||
|
||||
async def get_jssdk_signature(
|
||||
self,
|
||||
url: str,
|
||||
noncestr: Optional[str] = None,
|
||||
timestamp: Optional[int] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""生成JS-SDK签名
|
||||
|
||||
Args:
|
||||
url: 当前页面URL(不含#及其后面部分)
|
||||
noncestr: 随机字符串,可选
|
||||
timestamp: 时间戳,可选
|
||||
|
||||
Returns:
|
||||
签名信息字典,包含signature, noncestr, timestamp, appId等
|
||||
"""
|
||||
ticket = await self.get_jsapi_ticket()
|
||||
if not ticket:
|
||||
return None
|
||||
|
||||
# 生成随机字符串和时间戳
|
||||
if noncestr is None:
|
||||
import secrets
|
||||
noncestr = secrets.token_hex(8)
|
||||
if timestamp is None:
|
||||
timestamp = int(time.time())
|
||||
|
||||
# 构建签名字符串
|
||||
sign_str = f"jsapi_ticket={ticket}&noncestr={noncestr}×tamp={timestamp}&url={url}"
|
||||
|
||||
# SHA1签名
|
||||
signature = hashlib.sha1(sign_str.encode()).hexdigest()
|
||||
|
||||
return {
|
||||
"appId": self.corp_id,
|
||||
"agentId": self.agent_id,
|
||||
"timestamp": timestamp,
|
||||
"nonceStr": noncestr,
|
||||
"signature": signature,
|
||||
"url": url
|
||||
}
|
||||
|
||||
def get_oauth2_url(
|
||||
self,
|
||||
redirect_uri: str,
|
||||
scope: str = "snsapi_base",
|
||||
state: str = ""
|
||||
) -> str:
|
||||
"""生成OAuth2授权URL
|
||||
|
||||
Args:
|
||||
redirect_uri: 授权后重定向的URL
|
||||
scope: 应用授权作用域
|
||||
- snsapi_base: 静默授权,只能获取成员基础信息
|
||||
- snsapi_privateinfo: 手动授权,可获取成员详细信息
|
||||
state: 重定向后会带上state参数
|
||||
|
||||
Returns:
|
||||
OAuth2授权URL
|
||||
"""
|
||||
import urllib.parse
|
||||
|
||||
encoded_uri = urllib.parse.quote(redirect_uri, safe='')
|
||||
|
||||
url = (
|
||||
f"https://open.weixin.qq.com/connect/oauth2/authorize"
|
||||
f"?appid={self.corp_id}"
|
||||
f"&redirect_uri={encoded_uri}"
|
||||
f"&response_type=code"
|
||||
f"&scope={scope}"
|
||||
f"&state={state}"
|
||||
f"&agentid={self.agent_id}"
|
||||
f"#wechat_redirect"
|
||||
)
|
||||
|
||||
return url
|
||||
|
||||
async def get_user_info_by_code(self, code: str) -> Optional[Dict[str, Any]]:
|
||||
"""通过OAuth2 code获取用户信息
|
||||
|
||||
Args:
|
||||
code: OAuth2回调返回的code
|
||||
|
||||
Returns:
|
||||
用户信息字典,包含UserId, DeviceId等
|
||||
"""
|
||||
access_token = await self.get_access_token()
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
url = f"{self.BASE_URL}/cgi-bin/auth/getuserinfo"
|
||||
params = {
|
||||
"access_token": access_token,
|
||||
"code": code
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get(url, params=params)
|
||||
result = response.json()
|
||||
|
||||
if result.get("errcode", 0) != 0:
|
||||
logger.error(f"Get user info by code failed: {result}")
|
||||
return None
|
||||
|
||||
return {
|
||||
"user_id": result.get("userid") or result.get("UserId"),
|
||||
"device_id": result.get("deviceid") or result.get("DeviceId"),
|
||||
"open_id": result.get("openid") or result.get("OpenId"),
|
||||
"external_userid": result.get("external_userid"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Get user info by code error: {e}")
|
||||
return None
|
||||
|
||||
async def get_user_detail(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取成员详细信息
|
||||
|
||||
Args:
|
||||
user_id: 成员UserID
|
||||
|
||||
Returns:
|
||||
成员详细信息
|
||||
"""
|
||||
access_token = await self.get_access_token()
|
||||
if not access_token:
|
||||
return None
|
||||
|
||||
url = f"{self.BASE_URL}/cgi-bin/user/get"
|
||||
params = {
|
||||
"access_token": access_token,
|
||||
"userid": user_id
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
response = await client.get(url, params=params)
|
||||
result = response.json()
|
||||
|
||||
if result.get("errcode", 0) != 0:
|
||||
logger.error(f"Get user detail failed: {result}")
|
||||
return None
|
||||
|
||||
return {
|
||||
"userid": result.get("userid"),
|
||||
"name": result.get("name"),
|
||||
"department": result.get("department"),
|
||||
"position": result.get("position"),
|
||||
"mobile": result.get("mobile"),
|
||||
"email": result.get("email"),
|
||||
"avatar": result.get("avatar"),
|
||||
"status": result.get("status"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Get user detail error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_wechat_service_by_id(
|
||||
wechat_app_id: int,
|
||||
db_session
|
||||
) -> Optional[WechatService]:
|
||||
"""根据企微应用ID获取服务实例
|
||||
|
||||
Args:
|
||||
wechat_app_id: platform_tenant_wechat_apps表的ID
|
||||
db_session: 数据库session
|
||||
|
||||
Returns:
|
||||
WechatService实例或None
|
||||
"""
|
||||
from ..models.tenant_wechat_app import TenantWechatApp
|
||||
|
||||
wechat_app = db_session.query(TenantWechatApp).filter(
|
||||
TenantWechatApp.id == wechat_app_id,
|
||||
TenantWechatApp.status == 1
|
||||
).first()
|
||||
|
||||
if not wechat_app:
|
||||
return None
|
||||
|
||||
return WechatService.from_wechat_app(wechat_app)
|
||||
Reference in New Issue
Block a user