feat: 新增告警、成本、配额、微信模块及缓存服务
All checks were successful
continuous-integration/drone/push Build is passing

- 新增告警模块 (alerts): 告警规则配置与触发
- 新增成本管理模块 (cost): 成本统计与分析
- 新增配额模块 (quota): 配额管理与限制
- 新增微信模块 (wechat): 微信相关功能接口
- 新增缓存服务 (cache): Redis 缓存封装
- 新增请求日志中间件 (request_logger)
- 新增异常处理和链路追踪中间件
- 更新 dashboard 前端展示
- 更新 SDK stats_client 功能
This commit is contained in:
111
2026-01-24 16:53:47 +08:00
parent eab2533c36
commit 6c6c48cf71
29 changed files with 4607 additions and 41 deletions

View File

@@ -1,4 +1,22 @@
"""业务服务"""
from .crypto import encrypt_value, decrypt_value
from .cache import CacheService, get_cache, get_redis_client
from .wechat import WechatService, get_wechat_service_by_id
from .alert import AlertService
from .cost import CostCalculator, calculate_cost
from .quota import QuotaService, check_quota_middleware
__all__ = ["encrypt_value", "decrypt_value"]
__all__ = [
"encrypt_value",
"decrypt_value",
"CacheService",
"get_cache",
"get_redis_client",
"WechatService",
"get_wechat_service_by_id",
"AlertService",
"CostCalculator",
"calculate_cost",
"QuotaService",
"check_quota_middleware"
]

View File

@@ -0,0 +1,455 @@
"""告警服务"""
import logging
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
import httpx
from sqlalchemy.orm import Session
from sqlalchemy import func
from ..models.alert import AlertRule, AlertRecord, NotificationChannel
from ..models.stats import AICallEvent
from ..models.logs import PlatformLog
from .cache import get_cache
logger = logging.getLogger(__name__)
class AlertService:
"""告警服务
提供告警规则检测、告警记录管理、通知发送等功能
"""
def __init__(self, db: Session):
self.db = db
self._cache = get_cache()
async def check_all_rules(self) -> List[AlertRecord]:
"""检查所有启用的告警规则
Returns:
触发的告警记录列表
"""
rules = self.db.query(AlertRule).filter(AlertRule.status == 1).all()
triggered_alerts = []
for rule in rules:
try:
alert = await self.check_rule(rule)
if alert:
triggered_alerts.append(alert)
except Exception as e:
logger.error(f"Failed to check rule {rule.id}: {e}")
return triggered_alerts
async def check_rule(self, rule: AlertRule) -> Optional[AlertRecord]:
"""检查单个告警规则
Args:
rule: 告警规则
Returns:
触发的告警记录或None
"""
# 检查冷却期
if self._is_in_cooldown(rule):
logger.debug(f"Rule {rule.id} is in cooldown")
return None
# 检查每日告警次数限制
if self._exceeds_daily_limit(rule):
logger.debug(f"Rule {rule.id} exceeds daily limit")
return None
# 根据规则类型检查
metric_value = None
threshold_value = None
triggered = False
condition = rule.condition or {}
if rule.rule_type == 'error_rate':
triggered, metric_value, threshold_value = self._check_error_rate(rule, condition)
elif rule.rule_type == 'call_count':
triggered, metric_value, threshold_value = self._check_call_count(rule, condition)
elif rule.rule_type == 'token_usage':
triggered, metric_value, threshold_value = self._check_token_usage(rule, condition)
elif rule.rule_type == 'cost_threshold':
triggered, metric_value, threshold_value = self._check_cost_threshold(rule, condition)
elif rule.rule_type == 'latency':
triggered, metric_value, threshold_value = self._check_latency(rule, condition)
if triggered:
alert = self._create_alert_record(rule, metric_value, threshold_value)
return alert
return None
def _is_in_cooldown(self, rule: AlertRule) -> bool:
"""检查规则是否在冷却期"""
cache_key = f"alert:cooldown:{rule.id}"
return self._cache.exists(cache_key)
def _set_cooldown(self, rule: AlertRule):
"""设置规则冷却期"""
cache_key = f"alert:cooldown:{rule.id}"
self._cache.set(cache_key, "1", ttl=rule.cooldown_minutes * 60)
def _exceeds_daily_limit(self, rule: AlertRule) -> bool:
"""检查是否超过每日告警次数限制"""
today = datetime.now().date()
count = self.db.query(func.count(AlertRecord.id)).filter(
AlertRecord.rule_id == rule.id,
func.date(AlertRecord.created_at) == today
).scalar()
return count >= rule.max_alerts_per_day
def _check_error_rate(self, rule: AlertRule, condition: dict) -> tuple:
"""检查错误率"""
window_minutes = self._parse_window(condition.get('window', '5m'))
threshold = condition.get('threshold', 10) # 错误次数阈值
operator = condition.get('operator', '>')
since = datetime.now() - timedelta(minutes=window_minutes)
query = self.db.query(func.count(AICallEvent.id)).filter(
AICallEvent.created_at >= since,
AICallEvent.status == 'error'
)
# 应用作用范围
if rule.scope_type == 'tenant' and rule.scope_value:
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
elif rule.scope_type == 'app' and rule.scope_value:
query = query.filter(AICallEvent.app_code == rule.scope_value)
error_count = query.scalar() or 0
triggered = self._compare(error_count, threshold, operator)
return triggered, str(error_count), str(threshold)
def _check_call_count(self, rule: AlertRule, condition: dict) -> tuple:
"""检查调用次数"""
window_minutes = self._parse_window(condition.get('window', '1h'))
threshold = condition.get('threshold', 1000)
operator = condition.get('operator', '>')
since = datetime.now() - timedelta(minutes=window_minutes)
query = self.db.query(func.count(AICallEvent.id)).filter(
AICallEvent.created_at >= since
)
if rule.scope_type == 'tenant' and rule.scope_value:
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
elif rule.scope_type == 'app' and rule.scope_value:
query = query.filter(AICallEvent.app_code == rule.scope_value)
call_count = query.scalar() or 0
triggered = self._compare(call_count, threshold, operator)
return triggered, str(call_count), str(threshold)
def _check_token_usage(self, rule: AlertRule, condition: dict) -> tuple:
"""检查Token使用量"""
window_minutes = self._parse_window(condition.get('window', '1d'))
threshold = condition.get('threshold', 100000)
operator = condition.get('operator', '>')
since = datetime.now() - timedelta(minutes=window_minutes)
query = self.db.query(
func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0)
).filter(
AICallEvent.created_at >= since
)
if rule.scope_type == 'tenant' and rule.scope_value:
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
elif rule.scope_type == 'app' and rule.scope_value:
query = query.filter(AICallEvent.app_code == rule.scope_value)
token_usage = query.scalar() or 0
triggered = self._compare(token_usage, threshold, operator)
return triggered, str(token_usage), str(threshold)
def _check_cost_threshold(self, rule: AlertRule, condition: dict) -> tuple:
"""检查费用阈值"""
window_minutes = self._parse_window(condition.get('window', '1d'))
threshold = condition.get('threshold', 100) # 费用阈值(元)
operator = condition.get('operator', '>')
since = datetime.now() - timedelta(minutes=window_minutes)
query = self.db.query(
func.coalesce(func.sum(AICallEvent.cost), 0)
).filter(
AICallEvent.created_at >= since
)
if rule.scope_type == 'tenant' and rule.scope_value:
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
elif rule.scope_type == 'app' and rule.scope_value:
query = query.filter(AICallEvent.app_code == rule.scope_value)
total_cost = float(query.scalar() or 0)
triggered = self._compare(total_cost, threshold, operator)
return triggered, f"¥{total_cost:.2f}", f"¥{threshold:.2f}"
def _check_latency(self, rule: AlertRule, condition: dict) -> tuple:
"""检查延迟"""
window_minutes = self._parse_window(condition.get('window', '5m'))
threshold = condition.get('threshold', 5000) # 延迟阈值(ms)
operator = condition.get('operator', '>')
percentile = condition.get('percentile', 'avg') # avg, p95, p99, max
since = datetime.now() - timedelta(minutes=window_minutes)
query = self.db.query(AICallEvent.latency_ms).filter(
AICallEvent.created_at >= since,
AICallEvent.latency_ms.isnot(None)
)
if rule.scope_type == 'tenant' and rule.scope_value:
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
elif rule.scope_type == 'app' and rule.scope_value:
query = query.filter(AICallEvent.app_code == rule.scope_value)
latencies = [r.latency_ms for r in query.all()]
if not latencies:
return False, "0", str(threshold)
if percentile == 'avg':
metric = sum(latencies) / len(latencies)
elif percentile == 'max':
metric = max(latencies)
elif percentile == 'p95':
latencies.sort()
idx = int(len(latencies) * 0.95)
metric = latencies[idx] if idx < len(latencies) else latencies[-1]
elif percentile == 'p99':
latencies.sort()
idx = int(len(latencies) * 0.99)
metric = latencies[idx] if idx < len(latencies) else latencies[-1]
else:
metric = sum(latencies) / len(latencies)
triggered = self._compare(metric, threshold, operator)
return triggered, f"{metric:.0f}ms", f"{threshold}ms"
def _parse_window(self, window: str) -> int:
"""解析时间窗口字符串为分钟数"""
if window.endswith('m'):
return int(window[:-1])
elif window.endswith('h'):
return int(window[:-1]) * 60
elif window.endswith('d'):
return int(window[:-1]) * 60 * 24
else:
return int(window)
def _compare(self, value: float, threshold: float, operator: str) -> bool:
"""比较值与阈值"""
if operator == '>':
return value > threshold
elif operator == '>=':
return value >= threshold
elif operator == '<':
return value < threshold
elif operator == '<=':
return value <= threshold
elif operator == '==':
return value == threshold
elif operator == '!=':
return value != threshold
return False
def _create_alert_record(
self,
rule: AlertRule,
metric_value: str,
threshold_value: str
) -> AlertRecord:
"""创建告警记录"""
title = f"[{rule.priority.upper()}] {rule.name}"
message = f"规则 '{rule.name}' 触发告警\n当前值: {metric_value}\n阈值: {threshold_value}"
if rule.scope_type == 'tenant':
message += f"\n租户: {rule.scope_value}"
elif rule.scope_type == 'app':
message += f"\n应用: {rule.scope_value}"
alert = AlertRecord(
rule_id=rule.id,
rule_name=rule.name,
alert_type=rule.rule_type,
severity=self._priority_to_severity(rule.priority),
title=title,
message=message,
tenant_id=rule.scope_value if rule.scope_type == 'tenant' else None,
app_code=rule.scope_value if rule.scope_type == 'app' else None,
metric_value=metric_value,
threshold_value=threshold_value,
notification_status='pending'
)
self.db.add(alert)
self.db.commit()
self.db.refresh(alert)
# 设置冷却期
self._set_cooldown(rule)
logger.info(f"Alert triggered: {title}")
return alert
def _priority_to_severity(self, priority: str) -> str:
"""将优先级转换为严重程度"""
mapping = {
'low': 'info',
'medium': 'warning',
'high': 'error',
'critical': 'critical'
}
return mapping.get(priority, 'warning')
async def send_notification(self, alert: AlertRecord, rule: AlertRule) -> bool:
"""发送告警通知
Args:
alert: 告警记录
rule: 告警规则
Returns:
是否发送成功
"""
if not rule.notification_channels:
alert.notification_status = 'skipped'
self.db.commit()
return True
results = []
success = True
for channel_config in rule.notification_channels:
try:
result = await self._send_to_channel(channel_config, alert)
results.append(result)
if not result.get('success'):
success = False
except Exception as e:
logger.error(f"Failed to send notification: {e}")
results.append({'success': False, 'error': str(e)})
success = False
alert.notification_status = 'sent' if success else 'failed'
alert.notification_result = results
alert.notified_at = datetime.now()
self.db.commit()
return success
async def _send_to_channel(self, channel_config: dict, alert: AlertRecord) -> dict:
"""发送到指定渠道"""
channel_type = channel_config.get('type')
if channel_type == 'wechat_bot':
return await self._send_wechat_bot(channel_config, alert)
elif channel_type == 'webhook':
return await self._send_webhook(channel_config, alert)
else:
return {'success': False, 'error': f'Unsupported channel type: {channel_type}'}
async def _send_wechat_bot(self, config: dict, alert: AlertRecord) -> dict:
"""发送到企微机器人"""
webhook = config.get('webhook')
if not webhook:
return {'success': False, 'error': 'Missing webhook URL'}
# 构建消息
content = f"**{alert.title}**\n\n{alert.message}\n\n时间: {alert.created_at}"
payload = {
"msgtype": "markdown",
"markdown": {
"content": content
}
}
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.post(webhook, json=payload)
result = response.json()
if result.get('errcode', 0) == 0:
return {'success': True, 'channel': 'wechat_bot'}
else:
return {'success': False, 'error': result.get('errmsg')}
except Exception as e:
return {'success': False, 'error': str(e)}
async def _send_webhook(self, config: dict, alert: AlertRecord) -> dict:
"""发送到Webhook"""
url = config.get('url')
if not url:
return {'success': False, 'error': 'Missing webhook URL'}
payload = {
"alert_id": alert.id,
"title": alert.title,
"message": alert.message,
"severity": alert.severity,
"alert_type": alert.alert_type,
"metric_value": alert.metric_value,
"threshold_value": alert.threshold_value,
"created_at": alert.created_at.isoformat()
}
headers = config.get('headers', {})
method = config.get('method', 'POST')
try:
async with httpx.AsyncClient(timeout=10) as client:
if method.upper() == 'POST':
response = await client.post(url, json=payload, headers=headers)
else:
response = await client.get(url, params=payload, headers=headers)
if response.status_code < 400:
return {'success': True, 'channel': 'webhook', 'status': response.status_code}
else:
return {'success': False, 'error': f'HTTP {response.status_code}'}
except Exception as e:
return {'success': False, 'error': str(e)}
def acknowledge_alert(self, alert_id: int, acknowledged_by: str) -> Optional[AlertRecord]:
"""确认告警"""
alert = self.db.query(AlertRecord).filter(AlertRecord.id == alert_id).first()
if not alert:
return None
alert.status = 'acknowledged'
alert.acknowledged_by = acknowledged_by
alert.acknowledged_at = datetime.now()
self.db.commit()
return alert
def resolve_alert(self, alert_id: int) -> Optional[AlertRecord]:
"""解决告警"""
alert = self.db.query(AlertRecord).filter(AlertRecord.id == alert_id).first()
if not alert:
return None
alert.status = 'resolved'
alert.resolved_at = datetime.now()
self.db.commit()
return alert

View File

@@ -0,0 +1,309 @@
"""Redis缓存服务"""
import json
import logging
from typing import Optional, Any, Union
from functools import lru_cache
try:
import redis
from redis import Redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
Redis = None
from ..config import get_settings
logger = logging.getLogger(__name__)
# 全局Redis连接池
_redis_pool: Optional[Any] = None
_redis_client: Optional[Any] = None
def get_redis_client() -> Optional[Any]:
"""获取Redis客户端单例"""
global _redis_pool, _redis_client
if not REDIS_AVAILABLE:
logger.warning("Redis module not installed, cache disabled")
return None
if _redis_client is not None:
return _redis_client
settings = get_settings()
try:
_redis_pool = redis.ConnectionPool.from_url(
settings.REDIS_URL,
max_connections=20,
decode_responses=True
)
_redis_client = Redis(connection_pool=_redis_pool)
# 测试连接
_redis_client.ping()
logger.info(f"Redis connected: {settings.REDIS_URL}")
return _redis_client
except Exception as e:
logger.warning(f"Redis connection failed: {e}, cache disabled")
_redis_client = None
return None
class CacheService:
"""缓存服务
提供统一的缓存接口支持Redis和内存回退
使用示例:
cache = CacheService()
# 设置缓存
cache.set("user:123", {"name": "test"}, ttl=3600)
# 获取缓存
user = cache.get("user:123")
# 删除缓存
cache.delete("user:123")
"""
def __init__(self, prefix: Optional[str] = None):
"""初始化缓存服务
Args:
prefix: 键前缀默认使用配置中的REDIS_PREFIX
"""
settings = get_settings()
self.prefix = prefix or settings.REDIS_PREFIX
self._client = get_redis_client()
# 内存回退缓存当Redis不可用时使用
self._memory_cache: dict = {}
@property
def is_available(self) -> bool:
"""Redis是否可用"""
return self._client is not None
def _make_key(self, key: str) -> str:
"""生成完整的缓存键"""
return f"{self.prefix}{key}"
def get(self, key: str, default: Any = None) -> Any:
"""获取缓存值
Args:
key: 缓存键
default: 默认值
Returns:
缓存值或默认值
"""
full_key = self._make_key(key)
if self._client:
try:
value = self._client.get(full_key)
if value is None:
return default
# 尝试解析JSON
try:
return json.loads(value)
except (json.JSONDecodeError, TypeError):
return value
except Exception as e:
logger.error(f"Cache get error: {e}")
return default
else:
# 内存回退
return self._memory_cache.get(full_key, default)
def set(
self,
key: str,
value: Any,
ttl: Optional[int] = None,
nx: bool = False
) -> bool:
"""设置缓存值
Args:
key: 缓存键
value: 缓存值
ttl: 过期时间(秒)
nx: 只在键不存在时设置
Returns:
是否设置成功
"""
full_key = self._make_key(key)
# 序列化值
if isinstance(value, (dict, list)):
serialized = json.dumps(value, ensure_ascii=False)
else:
serialized = str(value) if value is not None else ""
if self._client:
try:
if nx:
result = self._client.set(full_key, serialized, ex=ttl, nx=True)
else:
result = self._client.set(full_key, serialized, ex=ttl)
return bool(result)
except Exception as e:
logger.error(f"Cache set error: {e}")
return False
else:
# 内存回退不支持TTL和NX
if nx and full_key in self._memory_cache:
return False
self._memory_cache[full_key] = value
return True
def delete(self, key: str) -> bool:
"""删除缓存
Args:
key: 缓存键
Returns:
是否删除成功
"""
full_key = self._make_key(key)
if self._client:
try:
return bool(self._client.delete(full_key))
except Exception as e:
logger.error(f"Cache delete error: {e}")
return False
else:
return self._memory_cache.pop(full_key, None) is not None
def exists(self, key: str) -> bool:
"""检查键是否存在
Args:
key: 缓存键
Returns:
是否存在
"""
full_key = self._make_key(key)
if self._client:
try:
return bool(self._client.exists(full_key))
except Exception as e:
logger.error(f"Cache exists error: {e}")
return False
else:
return full_key in self._memory_cache
def ttl(self, key: str) -> int:
"""获取键的剩余过期时间
Args:
key: 缓存键
Returns:
剩余秒数,-1表示永不过期-2表示键不存在
"""
full_key = self._make_key(key)
if self._client:
try:
return self._client.ttl(full_key)
except Exception as e:
logger.error(f"Cache ttl error: {e}")
return -2
else:
return -1 if full_key in self._memory_cache else -2
def incr(self, key: str, amount: int = 1) -> int:
"""递增计数器
Args:
key: 缓存键
amount: 递增量
Returns:
递增后的值
"""
full_key = self._make_key(key)
if self._client:
try:
return self._client.incrby(full_key, amount)
except Exception as e:
logger.error(f"Cache incr error: {e}")
return 0
else:
current = self._memory_cache.get(full_key, 0)
new_value = int(current) + amount
self._memory_cache[full_key] = new_value
return new_value
def expire(self, key: str, ttl: int) -> bool:
"""设置键的过期时间
Args:
key: 缓存键
ttl: 过期时间(秒)
Returns:
是否设置成功
"""
full_key = self._make_key(key)
if self._client:
try:
return bool(self._client.expire(full_key, ttl))
except Exception as e:
logger.error(f"Cache expire error: {e}")
return False
else:
return full_key in self._memory_cache
def clear_prefix(self, prefix: str) -> int:
"""删除指定前缀的所有键
Args:
prefix: 键前缀
Returns:
删除的键数量
"""
full_prefix = self._make_key(prefix)
if self._client:
try:
keys = self._client.keys(f"{full_prefix}*")
if keys:
return self._client.delete(*keys)
return 0
except Exception as e:
logger.error(f"Cache clear_prefix error: {e}")
return 0
else:
count = 0
keys_to_delete = [k for k in self._memory_cache if k.startswith(full_prefix)]
for k in keys_to_delete:
del self._memory_cache[k]
count += 1
return count
# 全局缓存实例
_cache_instance: Optional[CacheService] = None
def get_cache() -> CacheService:
"""获取全局缓存实例"""
global _cache_instance
if _cache_instance is None:
_cache_instance = CacheService()
return _cache_instance

View File

@@ -0,0 +1,420 @@
"""费用计算服务"""
import logging
from datetime import datetime
from decimal import Decimal
from typing import Optional, Dict, List
from functools import lru_cache
from sqlalchemy.orm import Session
from sqlalchemy import func
from ..models.pricing import ModelPricing, TenantBilling
from ..models.stats import AICallEvent
from .cache import get_cache
logger = logging.getLogger(__name__)
class CostCalculator:
"""费用计算器
使用示例:
calculator = CostCalculator(db)
# 计算单次调用费用
cost = calculator.calculate_cost("gpt-4", input_tokens=100, output_tokens=200)
# 生成月度账单
billing = calculator.generate_monthly_billing("qiqi", "2026-01")
"""
# 默认模型价格(当数据库中无配置时使用)
DEFAULT_PRICING = {
# OpenAI
"gpt-4": {"input": 0.21, "output": 0.42}, # 元/1K tokens
"gpt-4-turbo": {"input": 0.07, "output": 0.21},
"gpt-4o": {"input": 0.035, "output": 0.105},
"gpt-4o-mini": {"input": 0.00105, "output": 0.0042},
"gpt-3.5-turbo": {"input": 0.0035, "output": 0.014},
# Anthropic
"claude-3-opus": {"input": 0.105, "output": 0.525},
"claude-3-sonnet": {"input": 0.021, "output": 0.105},
"claude-3-haiku": {"input": 0.00175, "output": 0.00875},
"claude-3.5-sonnet": {"input": 0.021, "output": 0.105},
# 国内模型
"qwen-max": {"input": 0.02, "output": 0.06},
"qwen-plus": {"input": 0.004, "output": 0.012},
"qwen-turbo": {"input": 0.002, "output": 0.006},
"glm-4": {"input": 0.01, "output": 0.01},
"glm-4-flash": {"input": 0.0001, "output": 0.0001},
"deepseek-chat": {"input": 0.001, "output": 0.002},
"deepseek-coder": {"input": 0.001, "output": 0.002},
# 默认
"default": {"input": 0.01, "output": 0.03}
}
def __init__(self, db: Session):
self.db = db
self._cache = get_cache()
self._pricing_cache: Dict[str, ModelPricing] = {}
def get_model_pricing(self, model_name: str) -> Optional[ModelPricing]:
"""获取模型价格配置
Args:
model_name: 模型名称
Returns:
ModelPricing实例或None
"""
# 尝试从缓存获取
cache_key = f"pricing:{model_name}"
cached = self._cache.get(cache_key)
if cached:
return self._dict_to_pricing(cached)
# 从数据库查询
pricing = self.db.query(ModelPricing).filter(
ModelPricing.model_name == model_name,
ModelPricing.status == 1
).first()
if pricing:
# 缓存1小时
self._cache.set(cache_key, self._pricing_to_dict(pricing), ttl=3600)
return pricing
return None
def _pricing_to_dict(self, pricing: ModelPricing) -> dict:
return {
"model_name": pricing.model_name,
"input_price_per_1k": str(pricing.input_price_per_1k),
"output_price_per_1k": str(pricing.output_price_per_1k),
"fixed_price_per_call": str(pricing.fixed_price_per_call),
"pricing_type": pricing.pricing_type
}
def _dict_to_pricing(self, d: dict) -> ModelPricing:
pricing = ModelPricing()
pricing.model_name = d.get("model_name")
pricing.input_price_per_1k = Decimal(d.get("input_price_per_1k", "0"))
pricing.output_price_per_1k = Decimal(d.get("output_price_per_1k", "0"))
pricing.fixed_price_per_call = Decimal(d.get("fixed_price_per_call", "0"))
pricing.pricing_type = d.get("pricing_type", "token")
return pricing
def calculate_cost(
self,
model_name: str,
input_tokens: int = 0,
output_tokens: int = 0,
call_count: int = 1
) -> Decimal:
"""计算调用费用
Args:
model_name: 模型名称
input_tokens: 输入token数
output_tokens: 输出token数
call_count: 调用次数
Returns:
费用(元)
"""
# 尝试获取数据库配置
pricing = self.get_model_pricing(model_name)
if pricing:
if pricing.pricing_type == 'call':
return pricing.fixed_price_per_call * call_count
elif pricing.pricing_type == 'hybrid':
token_cost = (
pricing.input_price_per_1k * Decimal(input_tokens) / 1000 +
pricing.output_price_per_1k * Decimal(output_tokens) / 1000
)
call_cost = pricing.fixed_price_per_call * call_count
return token_cost + call_cost
else: # token
return (
pricing.input_price_per_1k * Decimal(input_tokens) / 1000 +
pricing.output_price_per_1k * Decimal(output_tokens) / 1000
)
# 使用默认价格
default_prices = self.DEFAULT_PRICING.get(model_name) or self.DEFAULT_PRICING.get("default")
input_price = Decimal(str(default_prices["input"]))
output_price = Decimal(str(default_prices["output"]))
return (
input_price * Decimal(input_tokens) / 1000 +
output_price * Decimal(output_tokens) / 1000
)
def calculate_event_cost(self, event: AICallEvent) -> Decimal:
"""计算单个事件的费用
Args:
event: AI调用事件
Returns:
费用(元)
"""
return self.calculate_cost(
model_name=event.model or "default",
input_tokens=event.input_tokens or 0,
output_tokens=event.output_tokens or 0
)
def update_event_costs(self, start_date: str = None, end_date: str = None) -> int:
"""批量更新事件费用
对于cost为0或NULL的事件重新计算费用
Args:
start_date: 开始日期,格式 YYYY-MM-DD
end_date: 结束日期,格式 YYYY-MM-DD
Returns:
更新的记录数
"""
query = self.db.query(AICallEvent).filter(
(AICallEvent.cost == None) | (AICallEvent.cost == 0)
)
if start_date:
query = query.filter(AICallEvent.created_at >= start_date)
if end_date:
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
events = query.all()
updated = 0
for event in events:
try:
cost = self.calculate_event_cost(event)
event.cost = cost
updated += 1
except Exception as e:
logger.error(f"Failed to calculate cost for event {event.id}: {e}")
self.db.commit()
logger.info(f"Updated {updated} event costs")
return updated
def generate_monthly_billing(
self,
tenant_id: str,
billing_month: str
) -> TenantBilling:
"""生成月度账单
Args:
tenant_id: 租户ID
billing_month: 账单月份,格式 YYYY-MM
Returns:
TenantBilling实例
"""
# 检查是否已存在
existing = self.db.query(TenantBilling).filter(
TenantBilling.tenant_id == tenant_id,
TenantBilling.billing_month == billing_month
).first()
if existing:
billing = existing
else:
billing = TenantBilling(
tenant_id=tenant_id,
billing_month=billing_month
)
self.db.add(billing)
# 计算统计数据
start_date = f"{billing_month}-01"
year, month = billing_month.split("-")
if int(month) == 12:
end_date = f"{int(year)+1}-01-01"
else:
end_date = f"{year}-{int(month)+1:02d}-01"
# 聚合查询
stats = self.db.query(
func.count(AICallEvent.id).label('total_calls'),
func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'),
func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'),
func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost')
).filter(
AICallEvent.tenant_id == tenant_id,
AICallEvent.created_at >= start_date,
AICallEvent.created_at < end_date
).first()
billing.total_calls = stats.total_calls or 0
billing.total_input_tokens = int(stats.total_input or 0)
billing.total_output_tokens = int(stats.total_output or 0)
billing.total_cost = stats.total_cost or Decimal("0")
# 按模型统计
model_stats = self.db.query(
AICallEvent.model,
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
).filter(
AICallEvent.tenant_id == tenant_id,
AICallEvent.created_at >= start_date,
AICallEvent.created_at < end_date
).group_by(AICallEvent.model).all()
billing.cost_by_model = {
m.model or "unknown": float(m.cost) for m in model_stats
}
# 按应用统计
app_stats = self.db.query(
AICallEvent.app_code,
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
).filter(
AICallEvent.tenant_id == tenant_id,
AICallEvent.created_at >= start_date,
AICallEvent.created_at < end_date
).group_by(AICallEvent.app_code).all()
billing.cost_by_app = {
a.app_code or "unknown": float(a.cost) for a in app_stats
}
self.db.commit()
self.db.refresh(billing)
return billing
def get_cost_summary(
self,
tenant_id: str = None,
start_date: str = None,
end_date: str = None
) -> Dict:
"""获取费用汇总
Args:
tenant_id: 租户ID可选
start_date: 开始日期
end_date: 结束日期
Returns:
费用汇总字典
"""
query = self.db.query(
func.count(AICallEvent.id).label('total_calls'),
func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'),
func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'),
func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost')
)
if tenant_id:
query = query.filter(AICallEvent.tenant_id == tenant_id)
if start_date:
query = query.filter(AICallEvent.created_at >= start_date)
if end_date:
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
stats = query.first()
return {
"total_calls": stats.total_calls or 0,
"total_input_tokens": int(stats.total_input or 0),
"total_output_tokens": int(stats.total_output or 0),
"total_cost": float(stats.total_cost or 0)
}
def get_cost_by_tenant(
self,
start_date: str = None,
end_date: str = None
) -> List[Dict]:
"""按租户统计费用
Returns:
租户费用列表
"""
query = self.db.query(
AICallEvent.tenant_id,
func.count(AICallEvent.id).label('calls'),
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
)
if start_date:
query = query.filter(AICallEvent.created_at >= start_date)
if end_date:
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
results = query.group_by(AICallEvent.tenant_id).order_by(
func.sum(AICallEvent.cost).desc()
).all()
return [
{
"tenant_id": r.tenant_id,
"calls": r.calls,
"cost": float(r.cost)
}
for r in results
]
def get_cost_by_model(
self,
tenant_id: str = None,
start_date: str = None,
end_date: str = None
) -> List[Dict]:
"""按模型统计费用
Returns:
模型费用列表
"""
query = self.db.query(
AICallEvent.model,
func.count(AICallEvent.id).label('calls'),
func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('input_tokens'),
func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('output_tokens'),
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
)
if tenant_id:
query = query.filter(AICallEvent.tenant_id == tenant_id)
if start_date:
query = query.filter(AICallEvent.created_at >= start_date)
if end_date:
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
results = query.group_by(AICallEvent.model).order_by(
func.sum(AICallEvent.cost).desc()
).all()
return [
{
"model": r.model or "unknown",
"calls": r.calls,
"input_tokens": int(r.input_tokens),
"output_tokens": int(r.output_tokens),
"cost": float(r.cost)
}
for r in results
]
# 便捷函数
def calculate_cost(
db: Session,
model_name: str,
input_tokens: int = 0,
output_tokens: int = 0
) -> Decimal:
"""快速计算费用"""
calculator = CostCalculator(db)
return calculator.calculate_cost(model_name, input_tokens, output_tokens)

View File

@@ -0,0 +1,346 @@
"""配额管理服务"""
import logging
from datetime import datetime, date, timedelta
from typing import Optional, Dict, Any, Tuple
from dataclasses import dataclass
from sqlalchemy.orm import Session
from sqlalchemy import func
from ..models.tenant import Tenant, Subscription
from ..models.stats import AICallEvent
from .cache import get_cache
logger = logging.getLogger(__name__)
@dataclass
class QuotaConfig:
"""配额配置"""
daily_calls: int = 0 # 每日调用限制0表示无限制
daily_tokens: int = 0 # 每日Token限制
monthly_calls: int = 0 # 每月调用限制
monthly_tokens: int = 0 # 每月Token限制
monthly_cost: float = 0 # 每月费用限制(元)
concurrent_calls: int = 0 # 并发调用限制
@dataclass
class QuotaUsage:
"""配额使用情况"""
daily_calls: int = 0
daily_tokens: int = 0
monthly_calls: int = 0
monthly_tokens: int = 0
monthly_cost: float = 0
@dataclass
class QuotaCheckResult:
"""配额检查结果"""
allowed: bool
reason: Optional[str] = None
quota_type: Optional[str] = None
limit: int = 0
used: int = 0
remaining: int = 0
class QuotaService:
"""配额管理服务
使用示例:
quota_service = QuotaService(db)
# 检查配额
result = quota_service.check_quota("qiqi", "tools")
if not result.allowed:
raise HTTPException(status_code=429, detail=result.reason)
# 获取使用情况
usage = quota_service.get_usage("qiqi", "tools")
"""
# 默认配额(当无订阅配置时使用)
DEFAULT_QUOTA = QuotaConfig(
daily_calls=1000,
daily_tokens=100000,
monthly_calls=30000,
monthly_tokens=3000000,
monthly_cost=100
)
def __init__(self, db: Session):
self.db = db
self._cache = get_cache()
def get_subscription(self, tenant_id: str, app_code: str) -> Optional[Subscription]:
"""获取租户订阅配置"""
return self.db.query(Subscription).filter(
Subscription.tenant_id == tenant_id,
Subscription.app_code == app_code,
Subscription.status == 'active'
).first()
def get_quota_config(self, tenant_id: str, app_code: str) -> QuotaConfig:
"""获取配额配置
Args:
tenant_id: 租户ID
app_code: 应用代码
Returns:
QuotaConfig实例
"""
# 尝试从缓存获取
cache_key = f"quota:config:{tenant_id}:{app_code}"
cached = self._cache.get(cache_key)
if cached:
return QuotaConfig(**cached)
# 从订阅表获取
subscription = self.get_subscription(tenant_id, app_code)
if subscription and subscription.quota:
quota = subscription.quota
config = QuotaConfig(
daily_calls=quota.get('daily_calls', 0),
daily_tokens=quota.get('daily_tokens', 0),
monthly_calls=quota.get('monthly_calls', 0),
monthly_tokens=quota.get('monthly_tokens', 0),
monthly_cost=quota.get('monthly_cost', 0),
concurrent_calls=quota.get('concurrent_calls', 0)
)
else:
config = self.DEFAULT_QUOTA
# 缓存5分钟
self._cache.set(cache_key, config.__dict__, ttl=300)
return config
def get_usage(self, tenant_id: str, app_code: str) -> QuotaUsage:
"""获取配额使用情况
Args:
tenant_id: 租户ID
app_code: 应用代码
Returns:
QuotaUsage实例
"""
today = date.today()
month_start = today.replace(day=1)
# 今日使用量
daily_stats = self.db.query(
func.count(AICallEvent.id).label('calls'),
func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0).label('tokens')
).filter(
AICallEvent.tenant_id == tenant_id,
AICallEvent.app_code == app_code,
func.date(AICallEvent.created_at) == today
).first()
# 本月使用量
monthly_stats = self.db.query(
func.count(AICallEvent.id).label('calls'),
func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0).label('tokens'),
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
).filter(
AICallEvent.tenant_id == tenant_id,
AICallEvent.app_code == app_code,
func.date(AICallEvent.created_at) >= month_start
).first()
return QuotaUsage(
daily_calls=daily_stats.calls or 0,
daily_tokens=int(daily_stats.tokens or 0),
monthly_calls=monthly_stats.calls or 0,
monthly_tokens=int(monthly_stats.tokens or 0),
monthly_cost=float(monthly_stats.cost or 0)
)
def check_quota(
self,
tenant_id: str,
app_code: str,
estimated_tokens: int = 0
) -> QuotaCheckResult:
"""检查配额是否足够
Args:
tenant_id: 租户ID
app_code: 应用代码
estimated_tokens: 预估Token消耗
Returns:
QuotaCheckResult实例
"""
config = self.get_quota_config(tenant_id, app_code)
usage = self.get_usage(tenant_id, app_code)
# 检查日调用次数
if config.daily_calls > 0:
if usage.daily_calls >= config.daily_calls:
return QuotaCheckResult(
allowed=False,
reason=f"已达到每日调用限制 ({config.daily_calls} 次)",
quota_type="daily_calls",
limit=config.daily_calls,
used=usage.daily_calls,
remaining=0
)
# 检查日Token限制
if config.daily_tokens > 0:
if usage.daily_tokens + estimated_tokens > config.daily_tokens:
return QuotaCheckResult(
allowed=False,
reason=f"已达到每日Token限制 ({config.daily_tokens:,})",
quota_type="daily_tokens",
limit=config.daily_tokens,
used=usage.daily_tokens,
remaining=max(0, config.daily_tokens - usage.daily_tokens)
)
# 检查月调用次数
if config.monthly_calls > 0:
if usage.monthly_calls >= config.monthly_calls:
return QuotaCheckResult(
allowed=False,
reason=f"已达到每月调用限制 ({config.monthly_calls} 次)",
quota_type="monthly_calls",
limit=config.monthly_calls,
used=usage.monthly_calls,
remaining=0
)
# 检查月Token限制
if config.monthly_tokens > 0:
if usage.monthly_tokens + estimated_tokens > config.monthly_tokens:
return QuotaCheckResult(
allowed=False,
reason=f"已达到每月Token限制 ({config.monthly_tokens:,})",
quota_type="monthly_tokens",
limit=config.monthly_tokens,
used=usage.monthly_tokens,
remaining=max(0, config.monthly_tokens - usage.monthly_tokens)
)
# 检查月费用限制
if config.monthly_cost > 0:
if usage.monthly_cost >= config.monthly_cost:
return QuotaCheckResult(
allowed=False,
reason=f"已达到每月费用限制 (¥{config.monthly_cost:.2f})",
quota_type="monthly_cost",
limit=int(config.monthly_cost * 100), # 转为分
used=int(usage.monthly_cost * 100),
remaining=max(0, int((config.monthly_cost - usage.monthly_cost) * 100))
)
# 所有检查通过
return QuotaCheckResult(
allowed=True,
quota_type="daily_calls",
limit=config.daily_calls,
used=usage.daily_calls,
remaining=max(0, config.daily_calls - usage.daily_calls) if config.daily_calls > 0 else -1
)
def get_quota_summary(self, tenant_id: str, app_code: str) -> Dict[str, Any]:
"""获取配额汇总信息
Returns:
包含配额配置和使用情况的字典
"""
config = self.get_quota_config(tenant_id, app_code)
usage = self.get_usage(tenant_id, app_code)
def calc_percentage(used: int, limit: int) -> float:
if limit <= 0:
return 0
return min(100, round(used / limit * 100, 1))
return {
"config": {
"daily_calls": config.daily_calls,
"daily_tokens": config.daily_tokens,
"monthly_calls": config.monthly_calls,
"monthly_tokens": config.monthly_tokens,
"monthly_cost": config.monthly_cost
},
"usage": {
"daily_calls": usage.daily_calls,
"daily_tokens": usage.daily_tokens,
"monthly_calls": usage.monthly_calls,
"monthly_tokens": usage.monthly_tokens,
"monthly_cost": round(usage.monthly_cost, 2)
},
"percentage": {
"daily_calls": calc_percentage(usage.daily_calls, config.daily_calls),
"daily_tokens": calc_percentage(usage.daily_tokens, config.daily_tokens),
"monthly_calls": calc_percentage(usage.monthly_calls, config.monthly_calls),
"monthly_tokens": calc_percentage(usage.monthly_tokens, config.monthly_tokens),
"monthly_cost": calc_percentage(int(usage.monthly_cost * 100), int(config.monthly_cost * 100))
}
}
def update_quota(
self,
tenant_id: str,
app_code: str,
quota_config: Dict[str, Any]
) -> Subscription:
"""更新配额配置
Args:
tenant_id: 租户ID
app_code: 应用代码
quota_config: 配额配置字典
Returns:
更新后的Subscription实例
"""
subscription = self.get_subscription(tenant_id, app_code)
if not subscription:
# 创建新订阅
subscription = Subscription(
tenant_id=tenant_id,
app_code=app_code,
start_date=date.today(),
quota=quota_config,
status='active'
)
self.db.add(subscription)
else:
# 更新现有订阅
subscription.quota = quota_config
self.db.commit()
self.db.refresh(subscription)
# 清除缓存
cache_key = f"quota:config:{tenant_id}:{app_code}"
self._cache.delete(cache_key)
return subscription
def check_quota_middleware(
db: Session,
tenant_id: str,
app_code: str,
estimated_tokens: int = 0
) -> QuotaCheckResult:
"""配额检查中间件函数
可在路由中使用:
result = check_quota_middleware(db, "qiqi", "tools")
if not result.allowed:
raise HTTPException(status_code=429, detail=result.reason)
"""
service = QuotaService(db)
return service.check_quota(tenant_id, app_code, estimated_tokens)

View File

@@ -0,0 +1,371 @@
"""企业微信服务"""
import hashlib
import time
import logging
from typing import Optional, Dict, Any
from dataclasses import dataclass
import httpx
from ..config import get_settings
from .cache import get_cache
from .crypto import decrypt_config
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class WechatConfig:
"""企业微信应用配置"""
corp_id: str
agent_id: str
secret: str
class WechatService:
"""企业微信服务
提供access_token获取、JS-SDK签名、OAuth2等功能
使用示例:
wechat = WechatService(corp_id="wwxxxx", agent_id="1000001", secret="xxx")
# 获取access_token
token = await wechat.get_access_token()
# 获取JS-SDK签名
signature = await wechat.get_jssdk_signature("https://example.com/page")
"""
# 企业微信API基础URL
BASE_URL = "https://qyapi.weixin.qq.com"
def __init__(self, corp_id: str, agent_id: str, secret: str):
"""初始化企业微信服务
Args:
corp_id: 企业ID
agent_id: 应用AgentId
secret: 应用Secret明文
"""
self.corp_id = corp_id
self.agent_id = agent_id
self.secret = secret
self._cache = get_cache()
@classmethod
def from_wechat_app(cls, wechat_app) -> "WechatService":
"""从TenantWechatApp模型创建服务实例
Args:
wechat_app: TenantWechatApp数据库模型
Returns:
WechatService实例
"""
secret = ""
if wechat_app.secret_encrypted:
try:
secret = decrypt_config(wechat_app.secret_encrypted)
except Exception as e:
logger.error(f"Failed to decrypt secret: {e}")
return cls(
corp_id=wechat_app.corp_id,
agent_id=wechat_app.agent_id,
secret=secret
)
def _cache_key(self, key_type: str) -> str:
"""生成缓存键"""
return f"wechat:{self.corp_id}:{self.agent_id}:{key_type}"
async def get_access_token(self, force_refresh: bool = False) -> Optional[str]:
"""获取access_token
企业微信access_token有效期7200秒需要缓存
Args:
force_refresh: 是否强制刷新
Returns:
access_token或None
"""
cache_key = self._cache_key("access_token")
# 尝试从缓存获取
if not force_refresh:
cached = self._cache.get(cache_key)
if cached:
logger.debug(f"Access token from cache: {cached[:20]}...")
return cached
# 从企业微信API获取
url = f"{self.BASE_URL}/cgi-bin/gettoken"
params = {
"corpid": self.corp_id,
"corpsecret": self.secret
}
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(url, params=params)
result = response.json()
if result.get("errcode", 0) != 0:
logger.error(f"Get access_token failed: {result}")
return None
access_token = result.get("access_token")
expires_in = result.get("expires_in", 7200)
# 缓存提前200秒过期以确保安全
self._cache.set(
cache_key,
access_token,
ttl=min(expires_in - 200, settings.WECHAT_ACCESS_TOKEN_EXPIRE)
)
logger.info(f"Got new access_token for {self.corp_id}")
return access_token
except Exception as e:
logger.error(f"Get access_token error: {e}")
return None
async def get_jsapi_ticket(self, force_refresh: bool = False) -> Optional[str]:
"""获取jsapi_ticket
用于生成JS-SDK签名
Args:
force_refresh: 是否强制刷新
Returns:
jsapi_ticket或None
"""
cache_key = self._cache_key("jsapi_ticket")
# 尝试从缓存获取
if not force_refresh:
cached = self._cache.get(cache_key)
if cached:
logger.debug(f"JSAPI ticket from cache: {cached[:20]}...")
return cached
# 先获取access_token
access_token = await self.get_access_token()
if not access_token:
return None
# 获取jsapi_ticket
url = f"{self.BASE_URL}/cgi-bin/get_jsapi_ticket"
params = {"access_token": access_token}
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(url, params=params)
result = response.json()
if result.get("errcode", 0) != 0:
logger.error(f"Get jsapi_ticket failed: {result}")
return None
ticket = result.get("ticket")
expires_in = result.get("expires_in", 7200)
# 缓存
self._cache.set(
cache_key,
ticket,
ttl=min(expires_in - 200, settings.WECHAT_JSAPI_TICKET_EXPIRE)
)
logger.info(f"Got new jsapi_ticket for {self.corp_id}")
return ticket
except Exception as e:
logger.error(f"Get jsapi_ticket error: {e}")
return None
async def get_jssdk_signature(
self,
url: str,
noncestr: Optional[str] = None,
timestamp: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""生成JS-SDK签名
Args:
url: 当前页面URL不含#及其后面部分)
noncestr: 随机字符串,可选
timestamp: 时间戳,可选
Returns:
签名信息字典包含signature, noncestr, timestamp, appId等
"""
ticket = await self.get_jsapi_ticket()
if not ticket:
return None
# 生成随机字符串和时间戳
if noncestr is None:
import secrets
noncestr = secrets.token_hex(8)
if timestamp is None:
timestamp = int(time.time())
# 构建签名字符串
sign_str = f"jsapi_ticket={ticket}&noncestr={noncestr}&timestamp={timestamp}&url={url}"
# SHA1签名
signature = hashlib.sha1(sign_str.encode()).hexdigest()
return {
"appId": self.corp_id,
"agentId": self.agent_id,
"timestamp": timestamp,
"nonceStr": noncestr,
"signature": signature,
"url": url
}
def get_oauth2_url(
self,
redirect_uri: str,
scope: str = "snsapi_base",
state: str = ""
) -> str:
"""生成OAuth2授权URL
Args:
redirect_uri: 授权后重定向的URL
scope: 应用授权作用域
- snsapi_base: 静默授权,只能获取成员基础信息
- snsapi_privateinfo: 手动授权,可获取成员详细信息
state: 重定向后会带上state参数
Returns:
OAuth2授权URL
"""
import urllib.parse
encoded_uri = urllib.parse.quote(redirect_uri, safe='')
url = (
f"https://open.weixin.qq.com/connect/oauth2/authorize"
f"?appid={self.corp_id}"
f"&redirect_uri={encoded_uri}"
f"&response_type=code"
f"&scope={scope}"
f"&state={state}"
f"&agentid={self.agent_id}"
f"#wechat_redirect"
)
return url
async def get_user_info_by_code(self, code: str) -> Optional[Dict[str, Any]]:
"""通过OAuth2 code获取用户信息
Args:
code: OAuth2回调返回的code
Returns:
用户信息字典包含UserId, DeviceId等
"""
access_token = await self.get_access_token()
if not access_token:
return None
url = f"{self.BASE_URL}/cgi-bin/auth/getuserinfo"
params = {
"access_token": access_token,
"code": code
}
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(url, params=params)
result = response.json()
if result.get("errcode", 0) != 0:
logger.error(f"Get user info by code failed: {result}")
return None
return {
"user_id": result.get("userid") or result.get("UserId"),
"device_id": result.get("deviceid") or result.get("DeviceId"),
"open_id": result.get("openid") or result.get("OpenId"),
"external_userid": result.get("external_userid"),
}
except Exception as e:
logger.error(f"Get user info by code error: {e}")
return None
async def get_user_detail(self, user_id: str) -> Optional[Dict[str, Any]]:
"""获取成员详细信息
Args:
user_id: 成员UserID
Returns:
成员详细信息
"""
access_token = await self.get_access_token()
if not access_token:
return None
url = f"{self.BASE_URL}/cgi-bin/user/get"
params = {
"access_token": access_token,
"userid": user_id
}
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(url, params=params)
result = response.json()
if result.get("errcode", 0) != 0:
logger.error(f"Get user detail failed: {result}")
return None
return {
"userid": result.get("userid"),
"name": result.get("name"),
"department": result.get("department"),
"position": result.get("position"),
"mobile": result.get("mobile"),
"email": result.get("email"),
"avatar": result.get("avatar"),
"status": result.get("status"),
}
except Exception as e:
logger.error(f"Get user detail error: {e}")
return None
async def get_wechat_service_by_id(
wechat_app_id: int,
db_session
) -> Optional[WechatService]:
"""根据企微应用ID获取服务实例
Args:
wechat_app_id: platform_tenant_wechat_apps表的ID
db_session: 数据库session
Returns:
WechatService实例或None
"""
from ..models.tenant_wechat_app import TenantWechatApp
wechat_app = db_session.query(TenantWechatApp).filter(
TenantWechatApp.id == wechat_app_id,
TenantWechatApp.status == 1
).first()
if not wechat_app:
return None
return WechatService.from_wechat_app(wechat_app)