All checks were successful
continuous-integration/drone/push Build is passing
- 新增告警模块 (alerts): 告警规则配置与触发 - 新增成本管理模块 (cost): 成本统计与分析 - 新增配额模块 (quota): 配额管理与限制 - 新增微信模块 (wechat): 微信相关功能接口 - 新增缓存服务 (cache): Redis 缓存封装 - 新增请求日志中间件 (request_logger) - 新增异常处理和链路追踪中间件 - 更新 dashboard 前端展示 - 更新 SDK stats_client 功能
330 lines
12 KiB
Python
"""AI统计上报客户端"""
|
||
import os
|
||
import json
|
||
import asyncio
|
||
import logging
|
||
import threading
|
||
from datetime import datetime
|
||
from decimal import Decimal
|
||
from typing import Optional, List
|
||
from dataclasses import dataclass, asdict
|
||
from pathlib import Path
|
||
|
||
try:
|
||
import httpx
|
||
HTTPX_AVAILABLE = True
|
||
except ImportError:
|
||
HTTPX_AVAILABLE = False
|
||
|
||
from .trace import get_trace_id, get_tenant_id, get_user_id
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class AICallEvent:
    """A single AI invocation event reported to the stats platform.

    Optional fields left as ``None`` (``user_id``, ``trace_id``,
    ``event_time``) are backfilled from the ambient trace context or the
    current time in ``__post_init__``.
    """
    tenant_id: int                         # owning tenant
    app_code: str                          # reporting application code
    module_code: str                       # module within the application
    prompt_name: str                       # prompt that triggered the call
    model: str                             # model name, e.g. "gpt-4"
    input_tokens: int = 0                  # prompt tokens consumed
    output_tokens: int = 0                 # completion tokens produced
    cost: Decimal = Decimal("0")           # monetary cost of the call
    latency_ms: int = 0                    # end-to-end latency in milliseconds
    status: str = "success"                # "success" or "error"
    user_id: Optional[int] = None          # defaults to the context user id
    trace_id: Optional[str] = None         # defaults to the context trace id
    event_time: Optional[datetime] = None  # defaults to datetime.now()

    def __post_init__(self):
        """Backfill optional fields from the current context."""
        if self.event_time is None:
            self.event_time = datetime.now()
        if self.trace_id is None:
            self.trace_id = get_trace_id()
        if self.user_id is None:
            self.user_id = get_user_id()

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dict.

        ``cost`` is rendered as a string to avoid Decimal serialization
        issues; ``event_time`` becomes an ISO-8601 string.
        """
        return {
            "tenant_id": self.tenant_id,
            "app_code": self.app_code,
            "module_code": self.module_code,
            "prompt_name": self.prompt_name,
            "model": self.model,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cost": str(self.cost),
            "latency_ms": self.latency_ms,
            "status": self.status,
            "user_id": self.user_id,
            "trace_id": self.trace_id,
            "event_time": self.event_time.isoformat() if self.event_time else None
        }
|
||
|
||
|
||
class StatsClient:
    """Statistics reporting client for AI call events.

    Events are buffered and sent to the platform in batches; when the
    platform is unreachable, events are persisted to a local file and
    retried on the next client startup.

    Example:
        stats = StatsClient(tenant_id=1, app_code="011-ai-interview")

        # Report an AI call
        stats.report_ai_call(
            module_code="interview",
            prompt_name="generate_question",
            model="gpt-4",
            input_tokens=100,
            output_tokens=200,
            latency_ms=1500
        )
    """

    # File used to persist events that could not be delivered.
    FAILED_EVENTS_FILE = ".platform_failed_events.json"

    def __init__(
        self,
        tenant_id: int,
        app_code: str,
        platform_url: Optional[str] = None,
        api_key: Optional[str] = None,
        local_only: bool = False,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        timeout: float = 10.0
    ):
        """Create a stats client.

        Args:
            tenant_id: Tenant the events belong to.
            app_code: Application code attached to every event.
            platform_url: Platform base URL (falls back to $PLATFORM_URL).
            api_key: API key (falls back to $PLATFORM_API_KEY).
            local_only: Force local logging instead of remote reporting.
            max_retries: HTTP attempts per batch before persisting to disk.
            retry_delay: Base delay between retries; grows linearly.
            timeout: HTTP timeout in seconds.
        """
        self.tenant_id = tenant_id
        self.app_code = app_code
        self.platform_url = platform_url or os.getenv("PLATFORM_URL", "")
        self.api_key = api_key or os.getenv("PLATFORM_API_KEY", "")
        # Degrade to local mode when no URL is configured or httpx is missing.
        self.local_only = local_only or not self.platform_url or not HTTPX_AVAILABLE
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.timeout = timeout

        # Batch buffer, flushed when it reaches _buffer_size.
        self._buffer: List[AICallEvent] = []
        self._buffer_size = 10  # auto-flush threshold
        self._lock = threading.Lock()
        # Strong references to in-flight async send tasks so the event loop
        # cannot garbage-collect them before they complete (asyncio only
        # keeps weak references to tasks).
        self._pending_tasks: set = set()

        # On startup, try to resend events that failed previously.
        if not self.local_only:
            self._retry_failed_events()

    def report_ai_call(
        self,
        module_code: str,
        prompt_name: str,
        model: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: Decimal = Decimal("0"),
        latency_ms: int = 0,
        status: str = "success",
        user_id: Optional[int] = None,
        flush: bool = False
    ) -> AICallEvent:
        """Record an AI call event.

        Args:
            module_code: Module code.
            prompt_name: Prompt name.
            model: Model name.
            input_tokens: Input token count.
            output_tokens: Output token count.
            cost: Cost of the call.
            latency_ms: Latency in milliseconds.
            status: Call status ("success"/"error").
            user_id: User id (optional; defaults to the trace context).
            flush: Send immediately instead of waiting for the buffer.

        Returns:
            The created event object.
        """
        event = AICallEvent(
            tenant_id=self.tenant_id,
            app_code=self.app_code,
            module_code=module_code,
            prompt_name=prompt_name,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost=cost,
            latency_ms=latency_ms,
            status=status,
            user_id=user_id
        )

        with self._lock:
            self._buffer.append(event)
            should_flush = flush or len(self._buffer) >= self._buffer_size

        # flush() re-acquires the lock, so it must be called outside it.
        if should_flush:
            self.flush()

        return event

    def flush(self):
        """Send all buffered events."""
        # Drain the buffer under the lock, then send without holding it.
        with self._lock:
            if not self._buffer:
                return
            events = self._buffer.copy()
            self._buffer.clear()

        if self.local_only:
            # Local mode: log only.
            for event in events:
                logger.info(f"[STATS] {event.app_code}/{event.module_code}: "
                            f"{event.prompt_name} - {event.input_tokens}+{event.output_tokens} tokens")
        else:
            # Remote reporting.
            self._send_to_platform(events)

    def _report_endpoint(self):
        """Return (url, headers) for the batch report API."""
        url = f"{self.platform_url.rstrip('/')}/api/stats/report/batch"
        headers = {"X-API-Key": self.api_key, "Content-Type": "application/json"}
        return url, headers

    def _handle_response(self, response, events: List[AICallEvent]) -> bool:
        """Log the platform response; return True when the batch was accepted."""
        if response.status_code == 200:
            result = response.json()
            logger.debug(f"Stats reported successfully: {result.get('count', len(events))} events")
            return True
        logger.warning(f"Stats report failed with status {response.status_code}: {response.text}")
        return False

    def _send_to_platform(self, events: List[AICallEvent]):
        """Send events to the platform, async if an event loop is running."""
        if not HTTPX_AVAILABLE:
            logger.warning("httpx not installed, falling back to local mode")
            return

        # Convert events into a serializable payload.
        payload = {"events": [e.to_dict() for e in events]}

        try:
            asyncio.get_running_loop()
            # Already in an async context: schedule a task and keep a strong
            # reference until it finishes.
            task = asyncio.create_task(self._send_async(payload, events))
            self._pending_tasks.add(task)
            task.add_done_callback(self._pending_tasks.discard)
        except RuntimeError:
            # No running event loop: fall back to a blocking send.
            self._send_sync(payload, events)

    def _send_sync(self, payload: dict, events: List[AICallEvent]):
        """Synchronously POST a batch, with retries and failure persistence."""
        import time  # local import: only needed on the sync path

        url, headers = self._report_endpoint()

        for attempt in range(self.max_retries):
            try:
                with httpx.Client(timeout=self.timeout) as client:
                    response = client.post(url, json=payload, headers=headers)
                    if self._handle_response(response, events):
                        return
            except httpx.TimeoutException:
                logger.warning(f"Stats report timeout (attempt {attempt + 1}/{self.max_retries})")
            except httpx.RequestError as e:
                logger.warning(f"Stats report request error (attempt {attempt + 1}/{self.max_retries}): {e}")
            except Exception as e:
                # Unexpected failure: retrying is unlikely to help.
                logger.error(f"Stats report unexpected error: {e}")
                break

            # Linearly growing backoff before the next attempt.
            if attempt < self.max_retries - 1:
                time.sleep(self.retry_delay * (attempt + 1))

        # All attempts failed: keep the events on disk for a later retry.
        self._persist_failed_events(events)

    async def _send_async(self, payload: dict, events: List[AICallEvent]):
        """Asynchronously POST a batch, with retries and failure persistence."""
        url, headers = self._report_endpoint()

        for attempt in range(self.max_retries):
            try:
                async with httpx.AsyncClient(timeout=self.timeout) as client:
                    response = await client.post(url, json=payload, headers=headers)
                    if self._handle_response(response, events):
                        return
            except httpx.TimeoutException:
                logger.warning(f"Stats report timeout (attempt {attempt + 1}/{self.max_retries})")
            except httpx.RequestError as e:
                logger.warning(f"Stats report request error (attempt {attempt + 1}/{self.max_retries}): {e}")
            except Exception as e:
                # Unexpected failure: retrying is unlikely to help.
                logger.error(f"Stats report unexpected error: {e}")
                break

            # Linearly growing backoff before the next attempt.
            if attempt < self.max_retries - 1:
                await asyncio.sleep(self.retry_delay * (attempt + 1))

        # All attempts failed: keep the events on disk for a later retry.
        self._persist_failed_events(events)

    def _persist_failed_events(self, events: List[AICallEvent]):
        """Append undelivered events to the local fallback file (capped at 1000)."""
        try:
            failed_file = Path(self.FAILED_EVENTS_FILE)
            existing = []

            if failed_file.exists():
                try:
                    existing = json.loads(failed_file.read_text())
                except (json.JSONDecodeError, IOError):
                    # Corrupt backlog: start over rather than lose new events too.
                    existing = []

            # Append the newly failed events.
            for event in events:
                existing.append(event.to_dict())

            # Cap the backlog so the file cannot grow without bound.
            if len(existing) > 1000:
                existing = existing[-1000:]

            failed_file.write_text(json.dumps(existing, ensure_ascii=False, indent=2))
            logger.info(f"Persisted {len(events)} failed events to {self.FAILED_EVENTS_FILE}")
        except Exception as e:
            logger.error(f"Failed to persist events: {e}")

    def _retry_failed_events(self):
        """Best-effort resend of events persisted by a previous run."""
        try:
            failed_file = Path(self.FAILED_EVENTS_FILE)
            if not failed_file.exists():
                return

            events_data = json.loads(failed_file.read_text())
            if not events_data:
                return

            logger.info(f"Retrying {len(events_data)} previously failed events")

            # Single attempt; events stay on disk if it fails.
            payload = {"events": events_data}
            url, headers = self._report_endpoint()

            try:
                with httpx.Client(timeout=self.timeout) as client:
                    response = client.post(url, json=payload, headers=headers)
                    if response.status_code == 200:
                        # Delivered: the backlog file is no longer needed.
                        failed_file.unlink()
                        logger.info(f"Successfully sent {len(events_data)} previously failed events")
            except Exception as e:
                logger.warning(f"Failed to retry events: {e}")
        except Exception as e:
            logger.error(f"Error loading failed events: {e}")

    def __del__(self):
        """Flush any remaining buffered events on destruction (best effort)."""
        try:
            if self._buffer:
                self.flush()
        except Exception:
            pass  # never raise from a destructor
|