feat: 新增告警、成本、配额、微信模块及缓存服务
All checks were successful
continuous-integration/drone/push Build is passing

- 新增告警模块 (alerts): 告警规则配置与触发
- 新增成本管理模块 (cost): 成本统计与分析
- 新增配额模块 (quota): 配额管理与限制
- 新增微信模块 (wechat): 微信相关功能接口
- 新增缓存服务 (cache): Redis 缓存封装
- 新增请求日志中间件 (request_logger)
- 新增异常处理和链路追踪中间件
- 更新 dashboard 前端展示
- 更新 SDK stats_client 功能
This commit is contained in:
111
2026-01-24 16:53:47 +08:00
parent eab2533c36
commit 6c6c48cf71
29 changed files with 4607 additions and 41 deletions

View File

@@ -2,6 +2,7 @@
import os
from functools import lru_cache
from pydantic_settings import BaseSettings
from typing import Optional
class Settings(BaseSettings):
@@ -14,6 +15,10 @@ class Settings(BaseSettings):
# 数据库
DATABASE_URL: str = "mysql+pymysql://scrm_reader:ScrmReader2024Pass@47.107.71.55:3306/new_qiqi"
# Redis
REDIS_URL: str = "redis://localhost:6379/0"
REDIS_PREFIX: str = "platform:"
# API Key内部服务调用
API_KEY: str = "platform_api_key_2026"
@@ -29,6 +34,10 @@ class Settings(BaseSettings):
# 配置加密密钥
CONFIG_ENCRYPT_KEY: str = "platform_config_key_32bytes!!"
# 企业微信配置
WECHAT_ACCESS_TOKEN_EXPIRE: int = 7000 # access_token缓存时间(秒)企微有效期7200秒
WECHAT_JSAPI_TICKET_EXPIRE: int = 7000 # jsapi_ticket缓存时间(秒)
class Config:
env_file = ".env"
extra = "ignore"

View File

@@ -1,4 +1,5 @@
"""平台服务入口"""
import logging
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@@ -9,6 +10,15 @@ from .routers.tenants import router as tenants_router
from .routers.tenant_apps import router as tenant_apps_router
from .routers.tenant_wechat_apps import router as tenant_wechat_apps_router
from .routers.apps import router as apps_router
from .routers.wechat import router as wechat_router
from .routers.alerts import router as alerts_router
from .routers.cost import router as cost_router
from .routers.quota import router as quota_router
from .middleware import TraceMiddleware, setup_exception_handlers, RequestLoggerMiddleware
from .middleware.trace import setup_logging
# 配置日志(包含 TraceID
setup_logging(level=logging.INFO, include_trace=True)
settings = get_settings()
@@ -18,6 +28,20 @@ app = FastAPI(
description="平台基础设施服务 - 统计/日志/配置管理"
)
# 配置统一异常处理
setup_exception_handlers(app)
# 中间件按添加的反序执行,所以:
# 1. CORS 最后添加,最先执行
# 2. TraceMiddleware 在 RequestLoggerMiddleware 之后添加,这样先执行
# 3. RequestLoggerMiddleware 最先添加,最后执行(此时 trace_id 已设置)
# 请求日志中间件(自动记录到数据库)
app.add_middleware(RequestLoggerMiddleware, app_code="000-platform")
# TraceID 追踪中间件
app.add_middleware(TraceMiddleware, log_requests=True)
# CORS
app.add_middleware(
CORSMiddleware,
@@ -25,6 +49,7 @@ app.add_middleware(
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["X-Trace-ID", "X-Response-Time"]
)
# 注册路由
@@ -37,6 +62,10 @@ app.include_router(apps_router, prefix="/api")
app.include_router(stats_router, prefix="/api")
app.include_router(logs_router, prefix="/api")
app.include_router(config_router, prefix="/api")
app.include_router(wechat_router, prefix="/api")
app.include_router(alerts_router, prefix="/api")
app.include_router(cost_router, prefix="/api")
app.include_router(quota_router, prefix="/api")
@app.get("/")

View File

@@ -0,0 +1,19 @@
"""
中间件模块
提供:
- TraceID 追踪
- 统一异常处理
- 请求日志记录
"""
from .trace import TraceMiddleware, get_trace_id, set_trace_id
from .exception_handler import setup_exception_handlers
from .request_logger import RequestLoggerMiddleware
__all__ = [
"TraceMiddleware",
"get_trace_id",
"set_trace_id",
"setup_exception_handlers",
"RequestLoggerMiddleware"
]

View File

@@ -0,0 +1,128 @@
"""
统一异常处理
捕获所有异常,返回统一格式的错误响应,包含 TraceID。
"""
import logging
import traceback
from typing import Union
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
from .trace import get_trace_id
logger = logging.getLogger(__name__)
class ErrorCode:
"""错误码常量"""
BAD_REQUEST = "BAD_REQUEST"
UNAUTHORIZED = "UNAUTHORIZED"
FORBIDDEN = "FORBIDDEN"
NOT_FOUND = "NOT_FOUND"
VALIDATION_ERROR = "VALIDATION_ERROR"
RATE_LIMITED = "RATE_LIMITED"
INTERNAL_ERROR = "INTERNAL_ERROR"
SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
GATEWAY_ERROR = "GATEWAY_ERROR"
STATUS_TO_ERROR_CODE = {
400: ErrorCode.BAD_REQUEST,
401: ErrorCode.UNAUTHORIZED,
403: ErrorCode.FORBIDDEN,
404: ErrorCode.NOT_FOUND,
422: ErrorCode.VALIDATION_ERROR,
429: ErrorCode.RATE_LIMITED,
500: ErrorCode.INTERNAL_ERROR,
502: ErrorCode.GATEWAY_ERROR,
503: ErrorCode.SERVICE_UNAVAILABLE,
}
def create_error_response(
status_code: int,
code: str,
message: str,
trace_id: str = None,
details: dict = None
) -> JSONResponse:
"""创建统一格式的错误响应"""
if trace_id is None:
trace_id = get_trace_id()
error_body = {
"code": code,
"message": message,
"trace_id": trace_id
}
if details:
error_body["details"] = details
return JSONResponse(
status_code=status_code,
content={"success": False, "error": error_body},
headers={"X-Trace-ID": trace_id}
)
async def http_exception_handler(request: Request, exc: Union[HTTPException, StarletteHTTPException]):
"""处理 HTTP 异常"""
trace_id = get_trace_id()
status_code = exc.status_code
error_code = STATUS_TO_ERROR_CODE.get(status_code, ErrorCode.INTERNAL_ERROR)
message = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
logger.warning(f"[{trace_id}] HTTP {status_code}: {message}")
return create_error_response(
status_code=status_code,
code=error_code,
message=message,
trace_id=trace_id
)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""处理请求验证错误"""
trace_id = get_trace_id()
errors = exc.errors()
error_messages = [f"{'.'.join(str(l) for l in e['loc'])}: {e['msg']}" for e in errors]
logger.warning(f"[{trace_id}] 验证错误: {error_messages}")
return create_error_response(
status_code=422,
code=ErrorCode.VALIDATION_ERROR,
message="请求参数验证失败",
trace_id=trace_id,
details={"validation_errors": error_messages}
)
async def generic_exception_handler(request: Request, exc: Exception):
"""处理所有未捕获的异常"""
trace_id = get_trace_id()
logger.error(f"[{trace_id}] 未捕获异常: {type(exc).__name__}: {exc}")
logger.error(f"[{trace_id}] 堆栈:\n{traceback.format_exc()}")
return create_error_response(
status_code=500,
code=ErrorCode.INTERNAL_ERROR,
message="服务器内部错误,请稍后重试",
trace_id=trace_id
)
def setup_exception_handlers(app: FastAPI):
"""配置 FastAPI 应用的异常处理器"""
app.add_exception_handler(HTTPException, http_exception_handler)
app.add_exception_handler(StarletteHTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(Exception, generic_exception_handler)
logger.info("异常处理器已配置")

View File

@@ -0,0 +1,190 @@
"""
请求日志中间件
自动将所有请求记录到数据库 platform_logs 表
"""
import time
import logging
from typing import Optional, Set
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from .trace import get_trace_id
from ..database import SessionLocal
from ..models.logs import PlatformLog
logger = logging.getLogger(__name__)
class RequestLoggerMiddleware(BaseHTTPMiddleware):
"""请求日志中间件
自动记录所有请求到数据库,便于后续查询和分析
使用示例:
app.add_middleware(RequestLoggerMiddleware, app_code="000-platform")
"""
# 默认排除的路径(不记录这些请求)
DEFAULT_EXCLUDE_PATHS: Set[str] = {
"/",
"/docs",
"/redoc",
"/openapi.json",
"/api/health",
"/api/health/",
"/favicon.ico",
}
def __init__(
self,
app,
app_code: str = "platform",
exclude_paths: Optional[Set[str]] = None,
log_request_body: bool = False,
log_response_body: bool = False,
max_body_length: int = 1000
):
"""初始化中间件
Args:
app: FastAPI应用
app_code: 应用代码,记录到日志中
exclude_paths: 排除的路径集合,这些路径不记录日志
log_request_body: 是否记录请求体
log_response_body: 是否记录响应体
max_body_length: 记录体的最大长度
"""
super().__init__(app)
self.app_code = app_code
self.exclude_paths = exclude_paths or self.DEFAULT_EXCLUDE_PATHS
self.log_request_body = log_request_body
self.log_response_body = log_response_body
self.max_body_length = max_body_length
async def dispatch(self, request: Request, call_next) -> Response:
path = request.url.path
# 检查是否排除
if self._should_exclude(path):
return await call_next(request)
trace_id = get_trace_id()
method = request.method
start_time = time.time()
# 获取客户端IP
client_ip = self._get_client_ip(request)
# 获取租户ID从查询参数
tenant_id = request.query_params.get("tid") or request.query_params.get("tenant_id")
# 执行请求
response = None
error_message = None
status_code = 500
try:
response = await call_next(request)
status_code = response.status_code
except Exception as e:
error_message = str(e)
raise
finally:
duration_ms = int((time.time() - start_time) * 1000)
# 异步写入数据库(不阻塞响应)
try:
self._save_log(
trace_id=trace_id,
method=method,
path=path,
status_code=status_code,
duration_ms=duration_ms,
ip_address=client_ip,
tenant_id=tenant_id,
error_message=error_message
)
except Exception as e:
logger.error(f"Failed to save request log: {e}")
return response
def _should_exclude(self, path: str) -> bool:
"""检查路径是否应排除"""
# 精确匹配
if path in self.exclude_paths:
return True
# 前缀匹配(静态文件等)
exclude_prefixes = ["/static/", "/assets/", "/_next/"]
for prefix in exclude_prefixes:
if path.startswith(prefix):
return True
return False
def _get_client_ip(self, request: Request) -> str:
"""获取客户端真实IP"""
# 优先从代理头获取
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# 直连IP
if request.client:
return request.client.host
return "unknown"
def _save_log(
self,
trace_id: str,
method: str,
path: str,
status_code: int,
duration_ms: int,
ip_address: str,
tenant_id: Optional[str] = None,
error_message: Optional[str] = None
):
"""保存日志到数据库"""
from datetime import datetime
# 使用独立的数据库会话
db = SessionLocal()
try:
# 转换 tenant_id 为整数(如果是数字字符串)
tenant_id_int = None
if tenant_id:
try:
tenant_id_int = int(tenant_id)
except (ValueError, TypeError):
tenant_id_int = None
log_entry = PlatformLog(
log_type="request",
level="error" if status_code >= 500 else ("warn" if status_code >= 400 else "info"),
app_code=self.app_code,
tenant_id=tenant_id_int,
trace_id=trace_id,
message=f"{method} {path}" + (f" - {error_message}" if error_message else ""),
path=path,
method=method,
status_code=status_code,
duration_ms=duration_ms,
log_time=datetime.now(), # 必须设置 log_time
context={"ip": ip_address} # ip_address 放到 context 中
)
db.add(log_entry)
db.commit()
except Exception as e:
logger.error(f"Database error saving log: {e}")
db.rollback()
finally:
db.close()

View File

@@ -0,0 +1,114 @@
"""
TraceID 追踪中间件
为每个请求生成唯一的 TraceID用于日志追踪和问题排查。
功能:
- 自动生成 TraceID或从请求头获取
- 注入到响应头 X-Trace-ID
- 提供上下文变量供日志使用
- 支持请求耗时统计
"""
import time
import uuid
import logging
from contextvars import ContextVar
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
logger = logging.getLogger(__name__)
# 上下文变量存储当前请求的 TraceID
_trace_id_var: ContextVar[Optional[str]] = ContextVar("trace_id", default=None)
# 请求头名称
TRACE_ID_HEADER = "X-Trace-ID"
REQUEST_ID_HEADER = "X-Request-ID"
def get_trace_id() -> str:
"""获取当前请求的 TraceID"""
trace_id = _trace_id_var.get()
return trace_id if trace_id else "no-trace"
def set_trace_id(trace_id: str) -> None:
"""设置当前请求的 TraceID"""
_trace_id_var.set(trace_id)
def generate_trace_id() -> str:
"""生成新的 TraceID格式: 时间戳-随机8位"""
timestamp = int(time.time())
random_part = uuid.uuid4().hex[:8]
return f"{timestamp}-{random_part}"
class TraceMiddleware(BaseHTTPMiddleware):
"""TraceID 追踪中间件"""
def __init__(self, app, log_requests: bool = True):
super().__init__(app)
self.log_requests = log_requests
async def dispatch(self, request: Request, call_next) -> Response:
# 从请求头获取 TraceID或生成新的
trace_id = (
request.headers.get(TRACE_ID_HEADER) or
request.headers.get(REQUEST_ID_HEADER) or
generate_trace_id()
)
set_trace_id(trace_id)
start_time = time.time()
method = request.method
path = request.url.path
if self.log_requests:
logger.info(f"[{trace_id}] --> {method} {path}")
try:
response = await call_next(request)
duration_ms = int((time.time() - start_time) * 1000)
response.headers[TRACE_ID_HEADER] = trace_id
response.headers["X-Response-Time"] = f"{duration_ms}ms"
if self.log_requests:
logger.info(f"[{trace_id}] <-- {response.status_code} ({duration_ms}ms)")
return response
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.error(f"[{trace_id}] !!! 请求异常: {e} ({duration_ms}ms)")
raise
class TraceLogFilter(logging.Filter):
"""日志过滤器:自动添加 TraceID"""
def filter(self, record):
record.trace_id = get_trace_id()
return True
def setup_logging(level: int = logging.INFO, include_trace: bool = True):
"""配置日志格式"""
if include_trace:
format_str = "%(asctime)s [%(trace_id)s] %(levelname)s %(name)s: %(message)s"
else:
format_str = "%(asctime)s %(levelname)s %(name)s: %(message)s"
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(format_str, datefmt="%Y-%m-%d %H:%M:%S"))
if include_trace:
handler.addFilter(TraceLogFilter())
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.handlers = [handler]

View File

@@ -5,6 +5,8 @@ from .tenant_wechat_app import TenantWechatApp
from .app import App
from .stats import AICallEvent, TenantUsageDaily
from .logs import PlatformLog
from .alert import AlertRule, AlertRecord, NotificationChannel
from .pricing import ModelPricing, TenantBilling
__all__ = [
"Tenant",
@@ -15,5 +17,10 @@ __all__ = [
"App",
"AICallEvent",
"TenantUsageDaily",
"PlatformLog"
"PlatformLog",
"AlertRule",
"AlertRecord",
"NotificationChannel",
"ModelPricing",
"TenantBilling"
]

108
backend/app/models/alert.py Normal file
View File

@@ -0,0 +1,108 @@
"""告警相关模型"""
from datetime import datetime
from sqlalchemy import Column, Integer, BigInteger, String, Text, Enum, SmallInteger, JSON, TIMESTAMP
from ..database import Base
class AlertRule(Base):
"""告警规则表"""
__tablename__ = "platform_alert_rules"
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(100), nullable=False) # 规则名称
description = Column(Text) # 规则描述
# 规则类型
rule_type = Column(Enum(
'error_rate', # 错误率告警
'call_count', # 调用次数告警
'token_usage', # Token使用量告警
'cost_threshold', # 费用阈值告警
'latency', # 延迟告警
'custom' # 自定义告警
), nullable=False)
# 作用范围
scope_type = Column(Enum('global', 'tenant', 'app'), default='global') # 作用范围类型
scope_value = Column(String(100)) # 作用范围值如租户ID或应用代码
# 告警条件
condition = Column(JSON, nullable=False) # 告警条件配置
# 示例: {"metric": "error_count", "operator": ">", "threshold": 10, "window": "5m"}
# 通知配置
notification_channels = Column(JSON) # 通知渠道列表
# 示例: [{"type": "wechat_bot", "webhook": "https://..."}, {"type": "email", "to": ["a@b.com"]}]
# 告警限制
cooldown_minutes = Column(Integer, default=30) # 冷却时间(分钟),避免重复告警
max_alerts_per_day = Column(Integer, default=10) # 每天最大告警次数
# 状态
status = Column(SmallInteger, default=1) # 0-禁用 1-启用
priority = Column(Enum('low', 'medium', 'high', 'critical'), default='medium') # 优先级
created_at = Column(TIMESTAMP, default=datetime.now)
updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now)
class AlertRecord(Base):
"""告警记录表"""
__tablename__ = "platform_alert_records"
id = Column(BigInteger, primary_key=True, autoincrement=True)
rule_id = Column(Integer, nullable=False, index=True) # 关联的规则ID
rule_name = Column(String(100)) # 规则名称(冗余,便于查询)
# 告警信息
alert_type = Column(String(50), nullable=False) # 告警类型
severity = Column(Enum('info', 'warning', 'error', 'critical'), default='warning') # 严重程度
title = Column(String(200), nullable=False) # 告警标题
message = Column(Text) # 告警详情
# 上下文
tenant_id = Column(String(50), index=True) # 相关租户
app_code = Column(String(50)) # 相关应用
metric_value = Column(String(100)) # 触发告警的指标值
threshold_value = Column(String(100)) # 阈值
# 通知状态
notification_status = Column(Enum('pending', 'sent', 'failed', 'skipped'), default='pending')
notification_result = Column(JSON) # 通知结果
notified_at = Column(TIMESTAMP) # 通知时间
# 处理状态
status = Column(Enum('active', 'acknowledged', 'resolved', 'ignored'), default='active')
acknowledged_by = Column(String(100)) # 确认人
acknowledged_at = Column(TIMESTAMP) # 确认时间
resolved_at = Column(TIMESTAMP) # 解决时间
created_at = Column(TIMESTAMP, default=datetime.now, index=True)
class NotificationChannel(Base):
"""通知渠道配置表"""
__tablename__ = "platform_notification_channels"
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(100), nullable=False) # 渠道名称
channel_type = Column(Enum(
'wechat_bot', # 企微机器人
'email', # 邮件
'sms', # 短信
'webhook', # Webhook
'dingtalk' # 钉钉
), nullable=False)
# 渠道配置
config = Column(JSON, nullable=False)
# wechat_bot: {"webhook": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx"}
# email: {"smtp_host": "...", "smtp_port": 465, "username": "...", "password_encrypted": "..."}
# webhook: {"url": "https://...", "method": "POST", "headers": {...}}
# 状态
status = Column(SmallInteger, default=1) # 0-禁用 1-启用
created_at = Column(TIMESTAMP, default=datetime.now)
updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now)

View File

@@ -0,0 +1,70 @@
"""费用计算相关模型"""
from datetime import datetime
from decimal import Decimal
from sqlalchemy import Column, Integer, String, Text, DECIMAL, SmallInteger, JSON, TIMESTAMP
from ..database import Base
class ModelPricing(Base):
"""模型价格配置表"""
__tablename__ = "platform_model_pricing"
id = Column(Integer, primary_key=True, autoincrement=True)
# 模型标识
model_name = Column(String(100), nullable=False, unique=True) # 模型名称,如 gpt-4, claude-3-opus
provider = Column(String(50)) # 提供商,如 openai, anthropic, 4sapi
display_name = Column(String(100)) # 显示名称
# 价格配置(单位:元/1K tokens
input_price_per_1k = Column(DECIMAL(10, 6), default=0) # 输入价格
output_price_per_1k = Column(DECIMAL(10, 6), default=0) # 输出价格
# 或固定价格(每次调用)
fixed_price_per_call = Column(DECIMAL(10, 6), default=0)
# 计费方式
pricing_type = Column(String(20), default='token') # token / call / hybrid
# 备注
description = Column(Text)
status = Column(SmallInteger, default=1) # 0-禁用 1-启用
created_at = Column(TIMESTAMP, default=datetime.now)
updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now)
class TenantBilling(Base):
"""租户账单表(月度汇总)"""
__tablename__ = "platform_tenant_billing"
id = Column(Integer, primary_key=True, autoincrement=True)
tenant_id = Column(String(50), nullable=False, index=True)
billing_month = Column(String(7), nullable=False) # 格式: YYYY-MM
# 使用量统计
total_calls = Column(Integer, default=0) # 总调用次数
total_input_tokens = Column(Integer, default=0) # 总输入token
total_output_tokens = Column(Integer, default=0) # 总输出token
# 费用统计
total_cost = Column(DECIMAL(12, 4), default=0) # 总费用
# 按模型分类的费用明细
cost_by_model = Column(JSON) # {"gpt-4": 10.5, "claude-3": 5.2}
# 按应用分类的费用明细
cost_by_app = Column(JSON) # {"tools": 8.0, "interview": 7.7}
# 状态
status = Column(String(20), default='pending') # pending / confirmed / paid
created_at = Column(TIMESTAMP, default=datetime.now)
updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now)
class Config:
# 联合唯一索引
__table_args__ = (
{'mysql_charset': 'utf8mb4'}
)

View File

@@ -0,0 +1,430 @@
"""告警管理路由"""
from typing import Optional, List
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy import desc, func
from ..database import get_db
from ..models.alert import AlertRule, AlertRecord, NotificationChannel
from ..services.alert import AlertService
from .auth import get_current_user, require_operator
from ..models.user import User
router = APIRouter(prefix="/alerts", tags=["告警管理"])
# ============= Schemas =============
class AlertRuleCreate(BaseModel):
name: str
description: Optional[str] = None
rule_type: str
scope_type: str = "global"
scope_value: Optional[str] = None
condition: dict
notification_channels: Optional[List[dict]] = None
cooldown_minutes: int = 30
max_alerts_per_day: int = 10
priority: str = "medium"
class AlertRuleUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
condition: Optional[dict] = None
notification_channels: Optional[List[dict]] = None
cooldown_minutes: Optional[int] = None
max_alerts_per_day: Optional[int] = None
priority: Optional[str] = None
status: Optional[int] = None
class NotificationChannelCreate(BaseModel):
name: str
channel_type: str
config: dict
class NotificationChannelUpdate(BaseModel):
name: Optional[str] = None
config: Optional[dict] = None
status: Optional[int] = None
# ============= Alert Rules API =============
@router.get("/rules")
async def list_alert_rules(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
rule_type: Optional[str] = None,
status: Optional[int] = None,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取告警规则列表"""
query = db.query(AlertRule)
if rule_type:
query = query.filter(AlertRule.rule_type == rule_type)
if status is not None:
query = query.filter(AlertRule.status == status)
total = query.count()
rules = query.order_by(desc(AlertRule.created_at)).offset((page - 1) * size).limit(size).all()
return {
"total": total,
"page": page,
"size": size,
"items": [format_rule(r) for r in rules]
}
@router.get("/rules/{rule_id}")
async def get_alert_rule(
rule_id: int,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取告警规则详情"""
rule = db.query(AlertRule).filter(AlertRule.id == rule_id).first()
if not rule:
raise HTTPException(status_code=404, detail="告警规则不存在")
return format_rule(rule)
@router.post("/rules")
async def create_alert_rule(
data: AlertRuleCreate,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""创建告警规则"""
rule = AlertRule(
name=data.name,
description=data.description,
rule_type=data.rule_type,
scope_type=data.scope_type,
scope_value=data.scope_value,
condition=data.condition,
notification_channels=data.notification_channels,
cooldown_minutes=data.cooldown_minutes,
max_alerts_per_day=data.max_alerts_per_day,
priority=data.priority,
status=1
)
db.add(rule)
db.commit()
db.refresh(rule)
return {"success": True, "id": rule.id}
@router.put("/rules/{rule_id}")
async def update_alert_rule(
rule_id: int,
data: AlertRuleUpdate,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""更新告警规则"""
rule = db.query(AlertRule).filter(AlertRule.id == rule_id).first()
if not rule:
raise HTTPException(status_code=404, detail="告警规则不存在")
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(rule, key, value)
db.commit()
return {"success": True}
@router.delete("/rules/{rule_id}")
async def delete_alert_rule(
rule_id: int,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""删除告警规则"""
rule = db.query(AlertRule).filter(AlertRule.id == rule_id).first()
if not rule:
raise HTTPException(status_code=404, detail="告警规则不存在")
db.delete(rule)
db.commit()
return {"success": True}
# ============= Alert Records API =============
@router.get("/records")
async def list_alert_records(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
status: Optional[str] = None,
severity: Optional[str] = None,
alert_type: Optional[str] = None,
tenant_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取告警记录列表"""
query = db.query(AlertRecord)
if status:
query = query.filter(AlertRecord.status == status)
if severity:
query = query.filter(AlertRecord.severity == severity)
if alert_type:
query = query.filter(AlertRecord.alert_type == alert_type)
if tenant_id:
query = query.filter(AlertRecord.tenant_id == tenant_id)
if start_date:
query = query.filter(AlertRecord.created_at >= start_date)
if end_date:
query = query.filter(AlertRecord.created_at <= end_date + " 23:59:59")
total = query.count()
records = query.order_by(desc(AlertRecord.created_at)).offset((page - 1) * size).limit(size).all()
return {
"total": total,
"page": page,
"size": size,
"items": [format_record(r) for r in records]
}
@router.get("/records/summary")
async def get_alert_summary(
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取告警摘要统计"""
today = datetime.now().date()
week_start = today - timedelta(days=7)
# 今日告警数
today_count = db.query(func.count(AlertRecord.id)).filter(
func.date(AlertRecord.created_at) == today
).scalar()
# 本周告警数
week_count = db.query(func.count(AlertRecord.id)).filter(
func.date(AlertRecord.created_at) >= week_start
).scalar()
# 活跃告警数
active_count = db.query(func.count(AlertRecord.id)).filter(
AlertRecord.status == 'active'
).scalar()
# 按严重程度统计
severity_stats = db.query(
AlertRecord.severity,
func.count(AlertRecord.id)
).filter(
func.date(AlertRecord.created_at) >= week_start
).group_by(AlertRecord.severity).all()
return {
"today_count": today_count,
"week_count": week_count,
"active_count": active_count,
"by_severity": {s: c for s, c in severity_stats}
}
@router.get("/records/{record_id}")
async def get_alert_record(
record_id: int,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取告警记录详情"""
record = db.query(AlertRecord).filter(AlertRecord.id == record_id).first()
if not record:
raise HTTPException(status_code=404, detail="告警记录不存在")
return format_record(record)
@router.post("/records/{record_id}/acknowledge")
async def acknowledge_alert(
record_id: int,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""确认告警"""
service = AlertService(db)
record = service.acknowledge_alert(record_id, user.username)
if not record:
raise HTTPException(status_code=404, detail="告警记录不存在")
return {"success": True}
@router.post("/records/{record_id}/resolve")
async def resolve_alert(
record_id: int,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""解决告警"""
service = AlertService(db)
record = service.resolve_alert(record_id)
if not record:
raise HTTPException(status_code=404, detail="告警记录不存在")
return {"success": True}
# ============= Check Alerts API =============
@router.post("/check")
async def trigger_alert_check(
background_tasks: BackgroundTasks,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""手动触发告警检查"""
service = AlertService(db)
alerts = await service.check_all_rules()
# 异步发送通知
for alert in alerts:
rule = db.query(AlertRule).filter(AlertRule.id == alert.rule_id).first()
if rule:
background_tasks.add_task(service.send_notification, alert, rule)
return {
"success": True,
"triggered_count": len(alerts),
"alerts": [format_record(a) for a in alerts]
}
# ============= Notification Channels API =============
@router.get("/channels")
async def list_notification_channels(
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取通知渠道列表"""
channels = db.query(NotificationChannel).order_by(desc(NotificationChannel.created_at)).all()
return [format_channel(c) for c in channels]
@router.post("/channels")
async def create_notification_channel(
data: NotificationChannelCreate,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""创建通知渠道"""
channel = NotificationChannel(
name=data.name,
channel_type=data.channel_type,
config=data.config,
status=1
)
db.add(channel)
db.commit()
db.refresh(channel)
return {"success": True, "id": channel.id}
@router.put("/channels/{channel_id}")
async def update_notification_channel(
channel_id: int,
data: NotificationChannelUpdate,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""更新通知渠道"""
channel = db.query(NotificationChannel).filter(NotificationChannel.id == channel_id).first()
if not channel:
raise HTTPException(status_code=404, detail="通知渠道不存在")
update_data = data.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(channel, key, value)
db.commit()
return {"success": True}
@router.delete("/channels/{channel_id}")
async def delete_notification_channel(
channel_id: int,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""删除通知渠道"""
channel = db.query(NotificationChannel).filter(NotificationChannel.id == channel_id).first()
if not channel:
raise HTTPException(status_code=404, detail="通知渠道不存在")
db.delete(channel)
db.commit()
return {"success": True}
# ============= Helper Functions =============
def format_rule(rule: AlertRule) -> dict:
return {
"id": rule.id,
"name": rule.name,
"description": rule.description,
"rule_type": rule.rule_type,
"scope_type": rule.scope_type,
"scope_value": rule.scope_value,
"condition": rule.condition,
"notification_channels": rule.notification_channels,
"cooldown_minutes": rule.cooldown_minutes,
"max_alerts_per_day": rule.max_alerts_per_day,
"priority": rule.priority,
"status": rule.status,
"created_at": rule.created_at,
"updated_at": rule.updated_at
}
def format_record(record: AlertRecord) -> dict:
return {
"id": record.id,
"rule_id": record.rule_id,
"rule_name": record.rule_name,
"alert_type": record.alert_type,
"severity": record.severity,
"title": record.title,
"message": record.message,
"tenant_id": record.tenant_id,
"app_code": record.app_code,
"metric_value": record.metric_value,
"threshold_value": record.threshold_value,
"notification_status": record.notification_status,
"status": record.status,
"acknowledged_by": record.acknowledged_by,
"acknowledged_at": record.acknowledged_at,
"resolved_at": record.resolved_at,
"created_at": record.created_at
}
def format_channel(channel: NotificationChannel) -> dict:
return {
"id": channel.id,
"name": channel.name,
"channel_type": channel.channel_type,
"config": channel.config,
"status": channel.status,
"created_at": channel.created_at,
"updated_at": channel.updated_at
}

333
backend/app/routers/cost.py Normal file
View File

@@ -0,0 +1,333 @@
"""费用管理路由"""
from typing import Optional, List
from decimal import Decimal
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy import desc
from ..database import get_db
from ..models.pricing import ModelPricing, TenantBilling
from ..services.cost import CostCalculator
from .auth import get_current_user, require_operator
from ..models.user import User
router = APIRouter(prefix="/cost", tags=["费用管理"])
# ============= Schemas =============
class ModelPricingCreate(BaseModel):
model_name: str
provider: Optional[str] = None
display_name: Optional[str] = None
input_price_per_1k: float = 0
output_price_per_1k: float = 0
fixed_price_per_call: float = 0
pricing_type: str = "token"
description: Optional[str] = None
class ModelPricingUpdate(BaseModel):
provider: Optional[str] = None
display_name: Optional[str] = None
input_price_per_1k: Optional[float] = None
output_price_per_1k: Optional[float] = None
fixed_price_per_call: Optional[float] = None
pricing_type: Optional[str] = None
description: Optional[str] = None
status: Optional[int] = None
class CostCalculateRequest(BaseModel):
model_name: str
input_tokens: int = 0
output_tokens: int = 0
# ============= Model Pricing API =============
@router.get("/pricing")
async def list_model_pricing(
page: int = Query(1, ge=1),
size: int = Query(50, ge=1, le=100),
provider: Optional[str] = None,
status: Optional[int] = None,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取模型价格配置列表"""
query = db.query(ModelPricing)
if provider:
query = query.filter(ModelPricing.provider == provider)
if status is not None:
query = query.filter(ModelPricing.status == status)
total = query.count()
items = query.order_by(ModelPricing.model_name).offset((page - 1) * size).limit(size).all()
return {
"total": total,
"page": page,
"size": size,
"items": [format_pricing(p) for p in items]
}
@router.get("/pricing/{pricing_id}")
async def get_model_pricing(
pricing_id: int,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取模型价格详情"""
pricing = db.query(ModelPricing).filter(ModelPricing.id == pricing_id).first()
if not pricing:
raise HTTPException(status_code=404, detail="模型价格配置不存在")
return format_pricing(pricing)
@router.post("/pricing")
async def create_model_pricing(
data: ModelPricingCreate,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""创建模型价格配置"""
# 检查是否已存在
existing = db.query(ModelPricing).filter(ModelPricing.model_name == data.model_name).first()
if existing:
raise HTTPException(status_code=400, detail="该模型价格配置已存在")
pricing = ModelPricing(
model_name=data.model_name,
provider=data.provider,
display_name=data.display_name,
input_price_per_1k=Decimal(str(data.input_price_per_1k)),
output_price_per_1k=Decimal(str(data.output_price_per_1k)),
fixed_price_per_call=Decimal(str(data.fixed_price_per_call)),
pricing_type=data.pricing_type,
description=data.description,
status=1
)
db.add(pricing)
db.commit()
db.refresh(pricing)
return {"success": True, "id": pricing.id}
@router.put("/pricing/{pricing_id}")
async def update_model_pricing(
pricing_id: int,
data: ModelPricingUpdate,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""更新模型价格配置"""
pricing = db.query(ModelPricing).filter(ModelPricing.id == pricing_id).first()
if not pricing:
raise HTTPException(status_code=404, detail="模型价格配置不存在")
update_data = data.model_dump(exclude_unset=True)
# 转换价格字段
for field in ['input_price_per_1k', 'output_price_per_1k', 'fixed_price_per_call']:
if field in update_data and update_data[field] is not None:
update_data[field] = Decimal(str(update_data[field]))
for key, value in update_data.items():
setattr(pricing, key, value)
db.commit()
return {"success": True}
@router.delete("/pricing/{pricing_id}")
async def delete_model_pricing(
pricing_id: int,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""删除模型价格配置"""
pricing = db.query(ModelPricing).filter(ModelPricing.id == pricing_id).first()
if not pricing:
raise HTTPException(status_code=404, detail="模型价格配置不存在")
db.delete(pricing)
db.commit()
return {"success": True}
# ============= Cost Calculation API =============
@router.post("/calculate")
async def calculate_cost(
request: CostCalculateRequest,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""计算调用费用"""
calculator = CostCalculator(db)
cost = calculator.calculate_cost(
model_name=request.model_name,
input_tokens=request.input_tokens,
output_tokens=request.output_tokens
)
return {
"model": request.model_name,
"input_tokens": request.input_tokens,
"output_tokens": request.output_tokens,
"cost": float(cost)
}
@router.get("/summary")
async def get_cost_summary(
tenant_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取费用汇总"""
calculator = CostCalculator(db)
return calculator.get_cost_summary(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date
)
@router.get("/by-tenant")
async def get_cost_by_tenant(
start_date: Optional[str] = None,
end_date: Optional[str] = None,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""按租户统计费用"""
calculator = CostCalculator(db)
return calculator.get_cost_by_tenant(
start_date=start_date,
end_date=end_date
)
@router.get("/by-model")
async def get_cost_by_model(
tenant_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""按模型统计费用"""
calculator = CostCalculator(db)
return calculator.get_cost_by_model(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date
)
# ============= Billing API =============
@router.get("/billing")
async def list_billing(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
tenant_id: Optional[str] = None,
billing_month: Optional[str] = None,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取账单列表"""
query = db.query(TenantBilling)
if tenant_id:
query = query.filter(TenantBilling.tenant_id == tenant_id)
if billing_month:
query = query.filter(TenantBilling.billing_month == billing_month)
total = query.count()
items = query.order_by(desc(TenantBilling.billing_month)).offset((page - 1) * size).limit(size).all()
return {
"total": total,
"page": page,
"size": size,
"items": [format_billing(b) for b in items]
}
@router.post("/billing/generate")
async def generate_billing(
tenant_id: str = Query(...),
billing_month: str = Query(..., description="格式: YYYY-MM"),
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""生成月度账单"""
calculator = CostCalculator(db)
billing = calculator.generate_monthly_billing(tenant_id, billing_month)
return {
"success": True,
"billing": format_billing(billing)
}
@router.post("/recalculate")
async def recalculate_costs(
start_date: Optional[str] = None,
end_date: Optional[str] = None,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""重新计算事件费用"""
calculator = CostCalculator(db)
updated = calculator.update_event_costs(start_date, end_date)
return {
"success": True,
"updated_count": updated
}
# ============= Helper Functions =============
def format_pricing(pricing: ModelPricing) -> dict:
return {
"id": pricing.id,
"model_name": pricing.model_name,
"provider": pricing.provider,
"display_name": pricing.display_name,
"input_price_per_1k": float(pricing.input_price_per_1k or 0),
"output_price_per_1k": float(pricing.output_price_per_1k or 0),
"fixed_price_per_call": float(pricing.fixed_price_per_call or 0),
"pricing_type": pricing.pricing_type,
"description": pricing.description,
"status": pricing.status,
"created_at": pricing.created_at,
"updated_at": pricing.updated_at
}
def format_billing(billing: TenantBilling) -> dict:
return {
"id": billing.id,
"tenant_id": billing.tenant_id,
"billing_month": billing.billing_month,
"total_calls": billing.total_calls,
"total_input_tokens": billing.total_input_tokens,
"total_output_tokens": billing.total_output_tokens,
"total_cost": float(billing.total_cost or 0),
"cost_by_model": billing.cost_by_model,
"cost_by_app": billing.cost_by_app,
"status": billing.status,
"created_at": billing.created_at,
"updated_at": billing.updated_at
}

View File

@@ -1,6 +1,10 @@
"""日志路由"""
import csv
import io
from typing import Optional
from datetime import datetime
from fastapi import APIRouter, Depends, Header, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from sqlalchemy import desc
@@ -13,6 +17,14 @@ from ..services.auth import decode_token
router = APIRouter(prefix="/logs", tags=["logs"])
settings = get_settings()
# 尝试导入openpyxl
try:
from openpyxl import Workbook
from openpyxl.styles import Font, Alignment, PatternFill
OPENPYXL_AVAILABLE = True
except ImportError:
OPENPYXL_AVAILABLE = False
def get_current_user_optional(authorization: Optional[str] = Header(None)):
"""可选的用户认证"""
@@ -113,3 +125,154 @@ async def query_logs(
for item in items
]
}
@router.get("/export")
async def export_logs(
format: str = Query("csv", description="导出格式: csv 或 excel"),
log_type: Optional[str] = None,
level: Optional[str] = None,
app_code: Optional[str] = None,
tenant_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
limit: int = Query(10000, ge=1, le=100000, description="最大导出记录数"),
db: Session = Depends(get_db),
user = Depends(get_current_user_optional)
):
"""导出日志
支持CSV和Excel格式最多导出10万条记录
"""
query = db.query(PlatformLog)
if log_type:
query = query.filter(PlatformLog.log_type == log_type)
if level:
query = query.filter(PlatformLog.level == level)
if app_code:
query = query.filter(PlatformLog.app_code == app_code)
if tenant_id:
query = query.filter(PlatformLog.tenant_id == tenant_id)
if start_date:
query = query.filter(PlatformLog.log_time >= start_date)
if end_date:
query = query.filter(PlatformLog.log_time <= end_date + " 23:59:59")
items = query.order_by(desc(PlatformLog.log_time)).limit(limit).all()
if format.lower() == "excel":
return export_excel(items)
else:
return export_csv(items)
def export_csv(logs: list) -> StreamingResponse:
"""导出为CSV格式"""
output = io.StringIO()
writer = csv.writer(output)
# 写入表头
headers = [
"ID", "类型", "级别", "应用", "租户", "Trace ID",
"消息", "路径", "方法", "状态码", "耗时(ms)",
"IP地址", "时间"
]
writer.writerow(headers)
# 写入数据
for log in logs:
writer.writerow([
log.id,
log.log_type,
log.level,
log.app_code or "",
log.tenant_id or "",
log.trace_id or "",
log.message or "",
log.path or "",
log.method or "",
log.status_code or "",
log.duration_ms or "",
log.ip_address or "",
str(log.log_time) if log.log_time else ""
])
output.seek(0)
# 生成文件名
filename = f"logs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
return StreamingResponse(
iter([output.getvalue()]),
media_type="text/csv",
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Type": "text/csv; charset=utf-8-sig"
}
)
def export_excel(logs: list) -> StreamingResponse:
"""导出为Excel格式"""
if not OPENPYXL_AVAILABLE:
raise HTTPException(status_code=400, detail="Excel导出功能不可用请安装openpyxl")
wb = Workbook()
ws = wb.active
ws.title = "日志导出"
# 表头样式
header_font = Font(bold=True, color="FFFFFF")
header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
header_alignment = Alignment(horizontal="center", vertical="center")
# 写入表头
headers = [
"ID", "类型", "级别", "应用", "租户", "Trace ID",
"消息", "路径", "方法", "状态码", "耗时(ms)",
"IP地址", "时间"
]
for col, header in enumerate(headers, 1):
cell = ws.cell(row=1, column=col, value=header)
cell.font = header_font
cell.fill = header_fill
cell.alignment = header_alignment
# 写入数据
for row, log in enumerate(logs, 2):
ws.cell(row=row, column=1, value=log.id)
ws.cell(row=row, column=2, value=log.log_type)
ws.cell(row=row, column=3, value=log.level)
ws.cell(row=row, column=4, value=log.app_code or "")
ws.cell(row=row, column=5, value=log.tenant_id or "")
ws.cell(row=row, column=6, value=log.trace_id or "")
ws.cell(row=row, column=7, value=log.message or "")
ws.cell(row=row, column=8, value=log.path or "")
ws.cell(row=row, column=9, value=log.method or "")
ws.cell(row=row, column=10, value=log.status_code or "")
ws.cell(row=row, column=11, value=log.duration_ms or "")
ws.cell(row=row, column=12, value=log.ip_address or "")
ws.cell(row=row, column=13, value=str(log.log_time) if log.log_time else "")
# 调整列宽
column_widths = [8, 10, 10, 12, 12, 36, 50, 30, 8, 10, 10, 15, 20]
for col, width in enumerate(column_widths, 1):
ws.column_dimensions[chr(64 + col)].width = width
# 保存到内存
output = io.BytesIO()
wb.save(output)
output.seek(0)
# 生成文件名
filename = f"logs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx"
return StreamingResponse(
iter([output.getvalue()]),
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
headers={
"Content-Disposition": f'attachment; filename="{filename}"'
}
)

View File

@@ -0,0 +1,264 @@
"""配额管理路由"""
from typing import Optional, Dict, Any
from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy import desc
from ..database import get_db
from ..models.tenant import Subscription
from ..services.quota import QuotaService
from .auth import get_current_user, require_operator
from ..models.user import User
router = APIRouter(prefix="/quota", tags=["配额管理"])
# ============= Schemas =============
class QuotaConfigUpdate(BaseModel):
daily_calls: int = 0
daily_tokens: int = 0
monthly_calls: int = 0
monthly_tokens: int = 0
monthly_cost: float = 0
concurrent_calls: int = 0
class SubscriptionCreate(BaseModel):
tenant_id: str
app_code: str
start_date: Optional[str] = None
end_date: Optional[str] = None
quota: QuotaConfigUpdate
class SubscriptionUpdate(BaseModel):
start_date: Optional[str] = None
end_date: Optional[str] = None
quota: Optional[QuotaConfigUpdate] = None
status: Optional[str] = None
# ============= Quota Check API =============
@router.get("/check")
async def check_quota(
tenant_id: str = Query(..., alias="tid"),
app_code: str = Query(..., alias="aid"),
estimated_tokens: int = Query(0),
db: Session = Depends(get_db)
):
"""检查配额是否足够
用于调用前检查,返回是否允许继续调用
"""
service = QuotaService(db)
result = service.check_quota(tenant_id, app_code, estimated_tokens)
return {
"allowed": result.allowed,
"reason": result.reason,
"quota_type": result.quota_type,
"limit": result.limit,
"used": result.used,
"remaining": result.remaining
}
@router.get("/summary")
async def get_quota_summary(
tenant_id: str = Query(...),
app_code: str = Query(...),
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取配额使用汇总"""
service = QuotaService(db)
return service.get_quota_summary(tenant_id, app_code)
@router.get("/usage")
async def get_quota_usage(
tenant_id: str = Query(...),
app_code: str = Query(...),
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取配额使用情况"""
service = QuotaService(db)
usage = service.get_usage(tenant_id, app_code)
return {
"daily_calls": usage.daily_calls,
"daily_tokens": usage.daily_tokens,
"monthly_calls": usage.monthly_calls,
"monthly_tokens": usage.monthly_tokens,
"monthly_cost": round(usage.monthly_cost, 2)
}
# ============= Subscription API =============
@router.get("/subscriptions")
async def list_subscriptions(
page: int = Query(1, ge=1),
size: int = Query(20, ge=1, le=100),
tenant_id: Optional[str] = None,
app_code: Optional[str] = None,
status: Optional[str] = None,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取订阅列表"""
query = db.query(Subscription)
if tenant_id:
query = query.filter(Subscription.tenant_id == tenant_id)
if app_code:
query = query.filter(Subscription.app_code == app_code)
if status:
query = query.filter(Subscription.status == status)
total = query.count()
items = query.order_by(desc(Subscription.created_at)).offset((page - 1) * size).limit(size).all()
return {
"total": total,
"page": page,
"size": size,
"items": [format_subscription(s) for s in items]
}
@router.get("/subscriptions/{subscription_id}")
async def get_subscription(
subscription_id: int,
user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取订阅详情"""
subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first()
if not subscription:
raise HTTPException(status_code=404, detail="订阅不存在")
return format_subscription(subscription)
@router.post("/subscriptions")
async def create_subscription(
data: SubscriptionCreate,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""创建订阅"""
# 检查是否已存在
existing = db.query(Subscription).filter(
Subscription.tenant_id == data.tenant_id,
Subscription.app_code == data.app_code,
Subscription.status == 'active'
).first()
if existing:
raise HTTPException(status_code=400, detail="该租户应用已有活跃订阅")
subscription = Subscription(
tenant_id=data.tenant_id,
app_code=data.app_code,
start_date=data.start_date or date.today(),
end_date=data.end_date,
quota=data.quota.model_dump() if data.quota else {},
status='active'
)
db.add(subscription)
db.commit()
db.refresh(subscription)
return {"success": True, "id": subscription.id}
@router.put("/subscriptions/{subscription_id}")
async def update_subscription(
subscription_id: int,
data: SubscriptionUpdate,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""更新订阅"""
subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first()
if not subscription:
raise HTTPException(status_code=404, detail="订阅不存在")
if data.start_date:
subscription.start_date = data.start_date
if data.end_date:
subscription.end_date = data.end_date
if data.quota:
subscription.quota = data.quota.model_dump()
if data.status:
subscription.status = data.status
db.commit()
# 清除缓存
service = QuotaService(db)
cache_key = f"quota:config:{subscription.tenant_id}:{subscription.app_code}"
service._cache.delete(cache_key)
return {"success": True}
@router.delete("/subscriptions/{subscription_id}")
async def delete_subscription(
subscription_id: int,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""删除订阅"""
subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first()
if not subscription:
raise HTTPException(status_code=404, detail="订阅不存在")
db.delete(subscription)
db.commit()
return {"success": True}
@router.put("/subscriptions/{subscription_id}/quota")
async def update_quota(
subscription_id: int,
data: QuotaConfigUpdate,
user: User = Depends(require_operator),
db: Session = Depends(get_db)
):
"""更新订阅配额"""
subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first()
if not subscription:
raise HTTPException(status_code=404, detail="订阅不存在")
subscription.quota = data.model_dump()
db.commit()
# 清除缓存
service = QuotaService(db)
cache_key = f"quota:config:{subscription.tenant_id}:{subscription.app_code}"
service._cache.delete(cache_key)
return {"success": True}
# ============= Helper Functions =============
def format_subscription(subscription: Subscription) -> dict:
return {
"id": subscription.id,
"tenant_id": subscription.tenant_id,
"app_code": subscription.app_code,
"start_date": str(subscription.start_date) if subscription.start_date else None,
"end_date": str(subscription.end_date) if subscription.end_date else None,
"quota": subscription.quota or {},
"status": subscription.status,
"created_at": subscription.created_at,
"updated_at": subscription.updated_at
}

View File

@@ -0,0 +1,264 @@
"""企业微信JS-SDK路由"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ..database import get_db
from ..models.tenant_app import TenantApp
from ..models.tenant_wechat_app import TenantWechatApp
from ..services.wechat import WechatService, get_wechat_service_by_id
router = APIRouter(prefix="/wechat", tags=["企业微信"])
class JssdkSignatureRequest(BaseModel):
"""JS-SDK签名请求"""
url: str # 当前页面URL不含#及其后面部分)
class JssdkSignatureResponse(BaseModel):
"""JS-SDK签名响应"""
appId: str
agentId: str
timestamp: int
nonceStr: str
signature: str
class OAuth2UrlRequest(BaseModel):
"""OAuth2授权URL请求"""
redirect_uri: str
scope: str = "snsapi_base"
state: str = ""
class UserInfoRequest(BaseModel):
"""用户信息请求"""
code: str
@router.post("/jssdk/signature")
async def get_jssdk_signature(
request: JssdkSignatureRequest,
tenant_id: str = Query(..., alias="tid"),
app_code: str = Query(..., alias="aid"),
db: Session = Depends(get_db)
):
"""获取JS-SDK签名
用于前端初始化企业微信JS-SDK
Args:
request: 包含当前页面URL
tenant_id: 租户ID
app_code: 应用代码
Returns:
JS-SDK签名信息
"""
# 查找租户应用配置
tenant_app = db.query(TenantApp).filter(
TenantApp.tenant_id == tenant_id,
TenantApp.app_code == app_code,
TenantApp.status == 1
).first()
if not tenant_app:
raise HTTPException(status_code=404, detail="租户应用配置不存在")
if not tenant_app.wechat_app_id:
raise HTTPException(status_code=400, detail="该应用未配置企业微信")
# 获取企微服务
wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
if not wechat_service:
raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
# 生成签名
signature_data = await wechat_service.get_jssdk_signature(request.url)
if not signature_data:
raise HTTPException(status_code=500, detail="获取JS-SDK签名失败")
return signature_data
@router.get("/jssdk/signature")
async def get_jssdk_signature_get(
url: str = Query(..., description="当前页面URL"),
tenant_id: str = Query(..., alias="tid"),
app_code: str = Query(..., alias="aid"),
db: Session = Depends(get_db)
):
"""获取JS-SDK签名GET方式
方便前端JSONP调用
"""
# 查找租户应用配置
tenant_app = db.query(TenantApp).filter(
TenantApp.tenant_id == tenant_id,
TenantApp.app_code == app_code,
TenantApp.status == 1
).first()
if not tenant_app:
raise HTTPException(status_code=404, detail="租户应用配置不存在")
if not tenant_app.wechat_app_id:
raise HTTPException(status_code=400, detail="该应用未配置企业微信")
# 获取企微服务
wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
if not wechat_service:
raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
# 生成签名
signature_data = await wechat_service.get_jssdk_signature(url)
if not signature_data:
raise HTTPException(status_code=500, detail="获取JS-SDK签名失败")
return signature_data
@router.post("/oauth2/url")
async def get_oauth2_url(
request: OAuth2UrlRequest,
tenant_id: str = Query(..., alias="tid"),
app_code: str = Query(..., alias="aid"),
db: Session = Depends(get_db)
):
"""获取OAuth2授权URL
用于企业微信内网页获取用户身份
"""
# 查找租户应用配置
tenant_app = db.query(TenantApp).filter(
TenantApp.tenant_id == tenant_id,
TenantApp.app_code == app_code,
TenantApp.status == 1
).first()
if not tenant_app:
raise HTTPException(status_code=404, detail="租户应用配置不存在")
if not tenant_app.wechat_app_id:
raise HTTPException(status_code=400, detail="该应用未配置企业微信")
# 获取企微服务
wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
if not wechat_service:
raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
# 生成OAuth2 URL
oauth_url = wechat_service.get_oauth2_url(
redirect_uri=request.redirect_uri,
scope=request.scope,
state=request.state
)
return {"url": oauth_url}
@router.post("/oauth2/userinfo")
async def get_user_info(
request: UserInfoRequest,
tenant_id: str = Query(..., alias="tid"),
app_code: str = Query(..., alias="aid"),
db: Session = Depends(get_db)
):
"""通过OAuth2 code获取用户信息
在OAuth2回调后用code换取用户信息
"""
# 查找租户应用配置
tenant_app = db.query(TenantApp).filter(
TenantApp.tenant_id == tenant_id,
TenantApp.app_code == app_code,
TenantApp.status == 1
).first()
if not tenant_app:
raise HTTPException(status_code=404, detail="租户应用配置不存在")
if not tenant_app.wechat_app_id:
raise HTTPException(status_code=400, detail="该应用未配置企业微信")
# 获取企微服务
wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
if not wechat_service:
raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
# 获取用户信息
user_info = await wechat_service.get_user_info_by_code(request.code)
if not user_info:
raise HTTPException(status_code=400, detail="获取用户信息失败code可能已过期")
return user_info
@router.get("/oauth2/userinfo")
async def get_user_info_get(
code: str = Query(..., description="OAuth2回调的code"),
tenant_id: str = Query(..., alias="tid"),
app_code: str = Query(..., alias="aid"),
db: Session = Depends(get_db)
):
"""通过OAuth2 code获取用户信息GET方式"""
# 查找租户应用配置
tenant_app = db.query(TenantApp).filter(
TenantApp.tenant_id == tenant_id,
TenantApp.app_code == app_code,
TenantApp.status == 1
).first()
if not tenant_app:
raise HTTPException(status_code=404, detail="租户应用配置不存在")
if not tenant_app.wechat_app_id:
raise HTTPException(status_code=400, detail="该应用未配置企业微信")
# 获取企微服务
wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
if not wechat_service:
raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
# 获取用户信息
user_info = await wechat_service.get_user_info_by_code(code)
if not user_info:
raise HTTPException(status_code=400, detail="获取用户信息失败code可能已过期")
return user_info
@router.get("/user/{user_id}")
async def get_user_detail(
user_id: str,
tenant_id: str = Query(..., alias="tid"),
app_code: str = Query(..., alias="aid"),
db: Session = Depends(get_db)
):
"""获取企业微信成员详细信息"""
# 查找租户应用配置
tenant_app = db.query(TenantApp).filter(
TenantApp.tenant_id == tenant_id,
TenantApp.app_code == app_code,
TenantApp.status == 1
).first()
if not tenant_app:
raise HTTPException(status_code=404, detail="租户应用配置不存在")
if not tenant_app.wechat_app_id:
raise HTTPException(status_code=400, detail="该应用未配置企业微信")
# 获取企微服务
wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
if not wechat_service:
raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
# 获取用户详情
user_detail = await wechat_service.get_user_detail(user_id)
if not user_detail:
raise HTTPException(status_code=404, detail="用户不存在")
return user_detail

View File

@@ -1,4 +1,22 @@
"""业务服务"""
from .crypto import encrypt_value, decrypt_value
from .cache import CacheService, get_cache, get_redis_client
from .wechat import WechatService, get_wechat_service_by_id
from .alert import AlertService
from .cost import CostCalculator, calculate_cost
from .quota import QuotaService, check_quota_middleware
__all__ = ["encrypt_value", "decrypt_value"]
__all__ = [
"encrypt_value",
"decrypt_value",
"CacheService",
"get_cache",
"get_redis_client",
"WechatService",
"get_wechat_service_by_id",
"AlertService",
"CostCalculator",
"calculate_cost",
"QuotaService",
"check_quota_middleware"
]

View File

@@ -0,0 +1,455 @@
"""告警服务"""
import logging
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
import httpx
from sqlalchemy.orm import Session
from sqlalchemy import func
from ..models.alert import AlertRule, AlertRecord, NotificationChannel
from ..models.stats import AICallEvent
from ..models.logs import PlatformLog
from .cache import get_cache
logger = logging.getLogger(__name__)
class AlertService:
"""告警服务
提供告警规则检测、告警记录管理、通知发送等功能
"""
def __init__(self, db: Session):
self.db = db
self._cache = get_cache()
async def check_all_rules(self) -> List[AlertRecord]:
"""检查所有启用的告警规则
Returns:
触发的告警记录列表
"""
rules = self.db.query(AlertRule).filter(AlertRule.status == 1).all()
triggered_alerts = []
for rule in rules:
try:
alert = await self.check_rule(rule)
if alert:
triggered_alerts.append(alert)
except Exception as e:
logger.error(f"Failed to check rule {rule.id}: {e}")
return triggered_alerts
async def check_rule(self, rule: AlertRule) -> Optional[AlertRecord]:
"""检查单个告警规则
Args:
rule: 告警规则
Returns:
触发的告警记录或None
"""
# 检查冷却期
if self._is_in_cooldown(rule):
logger.debug(f"Rule {rule.id} is in cooldown")
return None
# 检查每日告警次数限制
if self._exceeds_daily_limit(rule):
logger.debug(f"Rule {rule.id} exceeds daily limit")
return None
# 根据规则类型检查
metric_value = None
threshold_value = None
triggered = False
condition = rule.condition or {}
if rule.rule_type == 'error_rate':
triggered, metric_value, threshold_value = self._check_error_rate(rule, condition)
elif rule.rule_type == 'call_count':
triggered, metric_value, threshold_value = self._check_call_count(rule, condition)
elif rule.rule_type == 'token_usage':
triggered, metric_value, threshold_value = self._check_token_usage(rule, condition)
elif rule.rule_type == 'cost_threshold':
triggered, metric_value, threshold_value = self._check_cost_threshold(rule, condition)
elif rule.rule_type == 'latency':
triggered, metric_value, threshold_value = self._check_latency(rule, condition)
if triggered:
alert = self._create_alert_record(rule, metric_value, threshold_value)
return alert
return None
def _is_in_cooldown(self, rule: AlertRule) -> bool:
"""检查规则是否在冷却期"""
cache_key = f"alert:cooldown:{rule.id}"
return self._cache.exists(cache_key)
def _set_cooldown(self, rule: AlertRule):
"""设置规则冷却期"""
cache_key = f"alert:cooldown:{rule.id}"
self._cache.set(cache_key, "1", ttl=rule.cooldown_minutes * 60)
def _exceeds_daily_limit(self, rule: AlertRule) -> bool:
"""检查是否超过每日告警次数限制"""
today = datetime.now().date()
count = self.db.query(func.count(AlertRecord.id)).filter(
AlertRecord.rule_id == rule.id,
func.date(AlertRecord.created_at) == today
).scalar()
return count >= rule.max_alerts_per_day
def _check_error_rate(self, rule: AlertRule, condition: dict) -> tuple:
"""检查错误率"""
window_minutes = self._parse_window(condition.get('window', '5m'))
threshold = condition.get('threshold', 10) # 错误次数阈值
operator = condition.get('operator', '>')
since = datetime.now() - timedelta(minutes=window_minutes)
query = self.db.query(func.count(AICallEvent.id)).filter(
AICallEvent.created_at >= since,
AICallEvent.status == 'error'
)
# 应用作用范围
if rule.scope_type == 'tenant' and rule.scope_value:
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
elif rule.scope_type == 'app' and rule.scope_value:
query = query.filter(AICallEvent.app_code == rule.scope_value)
error_count = query.scalar() or 0
triggered = self._compare(error_count, threshold, operator)
return triggered, str(error_count), str(threshold)
def _check_call_count(self, rule: AlertRule, condition: dict) -> tuple:
"""检查调用次数"""
window_minutes = self._parse_window(condition.get('window', '1h'))
threshold = condition.get('threshold', 1000)
operator = condition.get('operator', '>')
since = datetime.now() - timedelta(minutes=window_minutes)
query = self.db.query(func.count(AICallEvent.id)).filter(
AICallEvent.created_at >= since
)
if rule.scope_type == 'tenant' and rule.scope_value:
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
elif rule.scope_type == 'app' and rule.scope_value:
query = query.filter(AICallEvent.app_code == rule.scope_value)
call_count = query.scalar() or 0
triggered = self._compare(call_count, threshold, operator)
return triggered, str(call_count), str(threshold)
def _check_token_usage(self, rule: AlertRule, condition: dict) -> tuple:
"""检查Token使用量"""
window_minutes = self._parse_window(condition.get('window', '1d'))
threshold = condition.get('threshold', 100000)
operator = condition.get('operator', '>')
since = datetime.now() - timedelta(minutes=window_minutes)
query = self.db.query(
func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0)
).filter(
AICallEvent.created_at >= since
)
if rule.scope_type == 'tenant' and rule.scope_value:
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
elif rule.scope_type == 'app' and rule.scope_value:
query = query.filter(AICallEvent.app_code == rule.scope_value)
token_usage = query.scalar() or 0
triggered = self._compare(token_usage, threshold, operator)
return triggered, str(token_usage), str(threshold)
def _check_cost_threshold(self, rule: AlertRule, condition: dict) -> tuple:
"""检查费用阈值"""
window_minutes = self._parse_window(condition.get('window', '1d'))
threshold = condition.get('threshold', 100) # 费用阈值(元)
operator = condition.get('operator', '>')
since = datetime.now() - timedelta(minutes=window_minutes)
query = self.db.query(
func.coalesce(func.sum(AICallEvent.cost), 0)
).filter(
AICallEvent.created_at >= since
)
if rule.scope_type == 'tenant' and rule.scope_value:
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
elif rule.scope_type == 'app' and rule.scope_value:
query = query.filter(AICallEvent.app_code == rule.scope_value)
total_cost = float(query.scalar() or 0)
triggered = self._compare(total_cost, threshold, operator)
return triggered, f"¥{total_cost:.2f}", f"¥{threshold:.2f}"
def _check_latency(self, rule: AlertRule, condition: dict) -> tuple:
"""检查延迟"""
window_minutes = self._parse_window(condition.get('window', '5m'))
threshold = condition.get('threshold', 5000) # 延迟阈值(ms)
operator = condition.get('operator', '>')
percentile = condition.get('percentile', 'avg') # avg, p95, p99, max
since = datetime.now() - timedelta(minutes=window_minutes)
query = self.db.query(AICallEvent.latency_ms).filter(
AICallEvent.created_at >= since,
AICallEvent.latency_ms.isnot(None)
)
if rule.scope_type == 'tenant' and rule.scope_value:
query = query.filter(AICallEvent.tenant_id == rule.scope_value)
elif rule.scope_type == 'app' and rule.scope_value:
query = query.filter(AICallEvent.app_code == rule.scope_value)
latencies = [r.latency_ms for r in query.all()]
if not latencies:
return False, "0", str(threshold)
if percentile == 'avg':
metric = sum(latencies) / len(latencies)
elif percentile == 'max':
metric = max(latencies)
elif percentile == 'p95':
latencies.sort()
idx = int(len(latencies) * 0.95)
metric = latencies[idx] if idx < len(latencies) else latencies[-1]
elif percentile == 'p99':
latencies.sort()
idx = int(len(latencies) * 0.99)
metric = latencies[idx] if idx < len(latencies) else latencies[-1]
else:
metric = sum(latencies) / len(latencies)
triggered = self._compare(metric, threshold, operator)
return triggered, f"{metric:.0f}ms", f"{threshold}ms"
def _parse_window(self, window: str) -> int:
"""解析时间窗口字符串为分钟数"""
if window.endswith('m'):
return int(window[:-1])
elif window.endswith('h'):
return int(window[:-1]) * 60
elif window.endswith('d'):
return int(window[:-1]) * 60 * 24
else:
return int(window)
def _compare(self, value: float, threshold: float, operator: str) -> bool:
"""比较值与阈值"""
if operator == '>':
return value > threshold
elif operator == '>=':
return value >= threshold
elif operator == '<':
return value < threshold
elif operator == '<=':
return value <= threshold
elif operator == '==':
return value == threshold
elif operator == '!=':
return value != threshold
return False
def _create_alert_record(
self,
rule: AlertRule,
metric_value: str,
threshold_value: str
) -> AlertRecord:
"""创建告警记录"""
title = f"[{rule.priority.upper()}] {rule.name}"
message = f"规则 '{rule.name}' 触发告警\n当前值: {metric_value}\n阈值: {threshold_value}"
if rule.scope_type == 'tenant':
message += f"\n租户: {rule.scope_value}"
elif rule.scope_type == 'app':
message += f"\n应用: {rule.scope_value}"
alert = AlertRecord(
rule_id=rule.id,
rule_name=rule.name,
alert_type=rule.rule_type,
severity=self._priority_to_severity(rule.priority),
title=title,
message=message,
tenant_id=rule.scope_value if rule.scope_type == 'tenant' else None,
app_code=rule.scope_value if rule.scope_type == 'app' else None,
metric_value=metric_value,
threshold_value=threshold_value,
notification_status='pending'
)
self.db.add(alert)
self.db.commit()
self.db.refresh(alert)
# 设置冷却期
self._set_cooldown(rule)
logger.info(f"Alert triggered: {title}")
return alert
def _priority_to_severity(self, priority: str) -> str:
"""将优先级转换为严重程度"""
mapping = {
'low': 'info',
'medium': 'warning',
'high': 'error',
'critical': 'critical'
}
return mapping.get(priority, 'warning')
async def send_notification(self, alert: AlertRecord, rule: AlertRule) -> bool:
"""发送告警通知
Args:
alert: 告警记录
rule: 告警规则
Returns:
是否发送成功
"""
if not rule.notification_channels:
alert.notification_status = 'skipped'
self.db.commit()
return True
results = []
success = True
for channel_config in rule.notification_channels:
try:
result = await self._send_to_channel(channel_config, alert)
results.append(result)
if not result.get('success'):
success = False
except Exception as e:
logger.error(f"Failed to send notification: {e}")
results.append({'success': False, 'error': str(e)})
success = False
alert.notification_status = 'sent' if success else 'failed'
alert.notification_result = results
alert.notified_at = datetime.now()
self.db.commit()
return success
async def _send_to_channel(self, channel_config: dict, alert: AlertRecord) -> dict:
"""发送到指定渠道"""
channel_type = channel_config.get('type')
if channel_type == 'wechat_bot':
return await self._send_wechat_bot(channel_config, alert)
elif channel_type == 'webhook':
return await self._send_webhook(channel_config, alert)
else:
return {'success': False, 'error': f'Unsupported channel type: {channel_type}'}
async def _send_wechat_bot(self, config: dict, alert: AlertRecord) -> dict:
"""发送到企微机器人"""
webhook = config.get('webhook')
if not webhook:
return {'success': False, 'error': 'Missing webhook URL'}
# 构建消息
content = f"**{alert.title}**\n\n{alert.message}\n\n时间: {alert.created_at}"
payload = {
"msgtype": "markdown",
"markdown": {
"content": content
}
}
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.post(webhook, json=payload)
result = response.json()
if result.get('errcode', 0) == 0:
return {'success': True, 'channel': 'wechat_bot'}
else:
return {'success': False, 'error': result.get('errmsg')}
except Exception as e:
return {'success': False, 'error': str(e)}
async def _send_webhook(self, config: dict, alert: AlertRecord) -> dict:
"""发送到Webhook"""
url = config.get('url')
if not url:
return {'success': False, 'error': 'Missing webhook URL'}
payload = {
"alert_id": alert.id,
"title": alert.title,
"message": alert.message,
"severity": alert.severity,
"alert_type": alert.alert_type,
"metric_value": alert.metric_value,
"threshold_value": alert.threshold_value,
"created_at": alert.created_at.isoformat()
}
headers = config.get('headers', {})
method = config.get('method', 'POST')
try:
async with httpx.AsyncClient(timeout=10) as client:
if method.upper() == 'POST':
response = await client.post(url, json=payload, headers=headers)
else:
response = await client.get(url, params=payload, headers=headers)
if response.status_code < 400:
return {'success': True, 'channel': 'webhook', 'status': response.status_code}
else:
return {'success': False, 'error': f'HTTP {response.status_code}'}
except Exception as e:
return {'success': False, 'error': str(e)}
def acknowledge_alert(self, alert_id: int, acknowledged_by: str) -> Optional[AlertRecord]:
"""确认告警"""
alert = self.db.query(AlertRecord).filter(AlertRecord.id == alert_id).first()
if not alert:
return None
alert.status = 'acknowledged'
alert.acknowledged_by = acknowledged_by
alert.acknowledged_at = datetime.now()
self.db.commit()
return alert
def resolve_alert(self, alert_id: int) -> Optional[AlertRecord]:
"""解决告警"""
alert = self.db.query(AlertRecord).filter(AlertRecord.id == alert_id).first()
if not alert:
return None
alert.status = 'resolved'
alert.resolved_at = datetime.now()
self.db.commit()
return alert

View File

@@ -0,0 +1,309 @@
"""Redis缓存服务"""
import json
import logging
from typing import Optional, Any, Union
from functools import lru_cache
try:
import redis
from redis import Redis
REDIS_AVAILABLE = True
except ImportError:
REDIS_AVAILABLE = False
Redis = None
from ..config import get_settings
logger = logging.getLogger(__name__)
# 全局Redis连接池
_redis_pool: Optional[Any] = None
_redis_client: Optional[Any] = None
def get_redis_client() -> Optional[Any]:
"""获取Redis客户端单例"""
global _redis_pool, _redis_client
if not REDIS_AVAILABLE:
logger.warning("Redis module not installed, cache disabled")
return None
if _redis_client is not None:
return _redis_client
settings = get_settings()
try:
_redis_pool = redis.ConnectionPool.from_url(
settings.REDIS_URL,
max_connections=20,
decode_responses=True
)
_redis_client = Redis(connection_pool=_redis_pool)
# 测试连接
_redis_client.ping()
logger.info(f"Redis connected: {settings.REDIS_URL}")
return _redis_client
except Exception as e:
logger.warning(f"Redis connection failed: {e}, cache disabled")
_redis_client = None
return None
class CacheService:
"""缓存服务
提供统一的缓存接口支持Redis和内存回退
使用示例:
cache = CacheService()
# 设置缓存
cache.set("user:123", {"name": "test"}, ttl=3600)
# 获取缓存
user = cache.get("user:123")
# 删除缓存
cache.delete("user:123")
"""
def __init__(self, prefix: Optional[str] = None):
"""初始化缓存服务
Args:
prefix: 键前缀默认使用配置中的REDIS_PREFIX
"""
settings = get_settings()
self.prefix = prefix or settings.REDIS_PREFIX
self._client = get_redis_client()
# 内存回退缓存当Redis不可用时使用
self._memory_cache: dict = {}
@property
def is_available(self) -> bool:
"""Redis是否可用"""
return self._client is not None
def _make_key(self, key: str) -> str:
"""生成完整的缓存键"""
return f"{self.prefix}{key}"
def get(self, key: str, default: Any = None) -> Any:
"""获取缓存值
Args:
key: 缓存键
default: 默认值
Returns:
缓存值或默认值
"""
full_key = self._make_key(key)
if self._client:
try:
value = self._client.get(full_key)
if value is None:
return default
# 尝试解析JSON
try:
return json.loads(value)
except (json.JSONDecodeError, TypeError):
return value
except Exception as e:
logger.error(f"Cache get error: {e}")
return default
else:
# 内存回退
return self._memory_cache.get(full_key, default)
def set(
self,
key: str,
value: Any,
ttl: Optional[int] = None,
nx: bool = False
) -> bool:
"""设置缓存值
Args:
key: 缓存键
value: 缓存值
ttl: 过期时间(秒)
nx: 只在键不存在时设置
Returns:
是否设置成功
"""
full_key = self._make_key(key)
# 序列化值
if isinstance(value, (dict, list)):
serialized = json.dumps(value, ensure_ascii=False)
else:
serialized = str(value) if value is not None else ""
if self._client:
try:
if nx:
result = self._client.set(full_key, serialized, ex=ttl, nx=True)
else:
result = self._client.set(full_key, serialized, ex=ttl)
return bool(result)
except Exception as e:
logger.error(f"Cache set error: {e}")
return False
else:
# 内存回退不支持TTL和NX
if nx and full_key in self._memory_cache:
return False
self._memory_cache[full_key] = value
return True
def delete(self, key: str) -> bool:
"""删除缓存
Args:
key: 缓存键
Returns:
是否删除成功
"""
full_key = self._make_key(key)
if self._client:
try:
return bool(self._client.delete(full_key))
except Exception as e:
logger.error(f"Cache delete error: {e}")
return False
else:
return self._memory_cache.pop(full_key, None) is not None
def exists(self, key: str) -> bool:
"""检查键是否存在
Args:
key: 缓存键
Returns:
是否存在
"""
full_key = self._make_key(key)
if self._client:
try:
return bool(self._client.exists(full_key))
except Exception as e:
logger.error(f"Cache exists error: {e}")
return False
else:
return full_key in self._memory_cache
def ttl(self, key: str) -> int:
"""获取键的剩余过期时间
Args:
key: 缓存键
Returns:
剩余秒数,-1表示永不过期-2表示键不存在
"""
full_key = self._make_key(key)
if self._client:
try:
return self._client.ttl(full_key)
except Exception as e:
logger.error(f"Cache ttl error: {e}")
return -2
else:
return -1 if full_key in self._memory_cache else -2
def incr(self, key: str, amount: int = 1) -> int:
"""递增计数器
Args:
key: 缓存键
amount: 递增量
Returns:
递增后的值
"""
full_key = self._make_key(key)
if self._client:
try:
return self._client.incrby(full_key, amount)
except Exception as e:
logger.error(f"Cache incr error: {e}")
return 0
else:
current = self._memory_cache.get(full_key, 0)
new_value = int(current) + amount
self._memory_cache[full_key] = new_value
return new_value
def expire(self, key: str, ttl: int) -> bool:
"""设置键的过期时间
Args:
key: 缓存键
ttl: 过期时间(秒)
Returns:
是否设置成功
"""
full_key = self._make_key(key)
if self._client:
try:
return bool(self._client.expire(full_key, ttl))
except Exception as e:
logger.error(f"Cache expire error: {e}")
return False
else:
return full_key in self._memory_cache
def clear_prefix(self, prefix: str) -> int:
"""删除指定前缀的所有键
Args:
prefix: 键前缀
Returns:
删除的键数量
"""
full_prefix = self._make_key(prefix)
if self._client:
try:
keys = self._client.keys(f"{full_prefix}*")
if keys:
return self._client.delete(*keys)
return 0
except Exception as e:
logger.error(f"Cache clear_prefix error: {e}")
return 0
else:
count = 0
keys_to_delete = [k for k in self._memory_cache if k.startswith(full_prefix)]
for k in keys_to_delete:
del self._memory_cache[k]
count += 1
return count
# 全局缓存实例
_cache_instance: Optional[CacheService] = None
def get_cache() -> CacheService:
"""获取全局缓存实例"""
global _cache_instance
if _cache_instance is None:
_cache_instance = CacheService()
return _cache_instance

View File

@@ -0,0 +1,420 @@
"""费用计算服务"""
import logging
from datetime import datetime
from decimal import Decimal
from typing import Optional, Dict, List
from functools import lru_cache
from sqlalchemy.orm import Session
from sqlalchemy import func
from ..models.pricing import ModelPricing, TenantBilling
from ..models.stats import AICallEvent
from .cache import get_cache
logger = logging.getLogger(__name__)
class CostCalculator:
"""费用计算器
使用示例:
calculator = CostCalculator(db)
# 计算单次调用费用
cost = calculator.calculate_cost("gpt-4", input_tokens=100, output_tokens=200)
# 生成月度账单
billing = calculator.generate_monthly_billing("qiqi", "2026-01")
"""
# 默认模型价格(当数据库中无配置时使用)
DEFAULT_PRICING = {
# OpenAI
"gpt-4": {"input": 0.21, "output": 0.42}, # 元/1K tokens
"gpt-4-turbo": {"input": 0.07, "output": 0.21},
"gpt-4o": {"input": 0.035, "output": 0.105},
"gpt-4o-mini": {"input": 0.00105, "output": 0.0042},
"gpt-3.5-turbo": {"input": 0.0035, "output": 0.014},
# Anthropic
"claude-3-opus": {"input": 0.105, "output": 0.525},
"claude-3-sonnet": {"input": 0.021, "output": 0.105},
"claude-3-haiku": {"input": 0.00175, "output": 0.00875},
"claude-3.5-sonnet": {"input": 0.021, "output": 0.105},
# 国内模型
"qwen-max": {"input": 0.02, "output": 0.06},
"qwen-plus": {"input": 0.004, "output": 0.012},
"qwen-turbo": {"input": 0.002, "output": 0.006},
"glm-4": {"input": 0.01, "output": 0.01},
"glm-4-flash": {"input": 0.0001, "output": 0.0001},
"deepseek-chat": {"input": 0.001, "output": 0.002},
"deepseek-coder": {"input": 0.001, "output": 0.002},
# 默认
"default": {"input": 0.01, "output": 0.03}
}
def __init__(self, db: Session):
self.db = db
self._cache = get_cache()
self._pricing_cache: Dict[str, ModelPricing] = {}
def get_model_pricing(self, model_name: str) -> Optional[ModelPricing]:
"""获取模型价格配置
Args:
model_name: 模型名称
Returns:
ModelPricing实例或None
"""
# 尝试从缓存获取
cache_key = f"pricing:{model_name}"
cached = self._cache.get(cache_key)
if cached:
return self._dict_to_pricing(cached)
# 从数据库查询
pricing = self.db.query(ModelPricing).filter(
ModelPricing.model_name == model_name,
ModelPricing.status == 1
).first()
if pricing:
# 缓存1小时
self._cache.set(cache_key, self._pricing_to_dict(pricing), ttl=3600)
return pricing
return None
def _pricing_to_dict(self, pricing: ModelPricing) -> dict:
return {
"model_name": pricing.model_name,
"input_price_per_1k": str(pricing.input_price_per_1k),
"output_price_per_1k": str(pricing.output_price_per_1k),
"fixed_price_per_call": str(pricing.fixed_price_per_call),
"pricing_type": pricing.pricing_type
}
def _dict_to_pricing(self, d: dict) -> ModelPricing:
pricing = ModelPricing()
pricing.model_name = d.get("model_name")
pricing.input_price_per_1k = Decimal(d.get("input_price_per_1k", "0"))
pricing.output_price_per_1k = Decimal(d.get("output_price_per_1k", "0"))
pricing.fixed_price_per_call = Decimal(d.get("fixed_price_per_call", "0"))
pricing.pricing_type = d.get("pricing_type", "token")
return pricing
def calculate_cost(
self,
model_name: str,
input_tokens: int = 0,
output_tokens: int = 0,
call_count: int = 1
) -> Decimal:
"""计算调用费用
Args:
model_name: 模型名称
input_tokens: 输入token数
output_tokens: 输出token数
call_count: 调用次数
Returns:
费用(元)
"""
# 尝试获取数据库配置
pricing = self.get_model_pricing(model_name)
if pricing:
if pricing.pricing_type == 'call':
return pricing.fixed_price_per_call * call_count
elif pricing.pricing_type == 'hybrid':
token_cost = (
pricing.input_price_per_1k * Decimal(input_tokens) / 1000 +
pricing.output_price_per_1k * Decimal(output_tokens) / 1000
)
call_cost = pricing.fixed_price_per_call * call_count
return token_cost + call_cost
else: # token
return (
pricing.input_price_per_1k * Decimal(input_tokens) / 1000 +
pricing.output_price_per_1k * Decimal(output_tokens) / 1000
)
# 使用默认价格
default_prices = self.DEFAULT_PRICING.get(model_name) or self.DEFAULT_PRICING.get("default")
input_price = Decimal(str(default_prices["input"]))
output_price = Decimal(str(default_prices["output"]))
return (
input_price * Decimal(input_tokens) / 1000 +
output_price * Decimal(output_tokens) / 1000
)
def calculate_event_cost(self, event: AICallEvent) -> Decimal:
"""计算单个事件的费用
Args:
event: AI调用事件
Returns:
费用(元)
"""
return self.calculate_cost(
model_name=event.model or "default",
input_tokens=event.input_tokens or 0,
output_tokens=event.output_tokens or 0
)
def update_event_costs(self, start_date: str = None, end_date: str = None) -> int:
"""批量更新事件费用
对于cost为0或NULL的事件重新计算费用
Args:
start_date: 开始日期,格式 YYYY-MM-DD
end_date: 结束日期,格式 YYYY-MM-DD
Returns:
更新的记录数
"""
query = self.db.query(AICallEvent).filter(
(AICallEvent.cost == None) | (AICallEvent.cost == 0)
)
if start_date:
query = query.filter(AICallEvent.created_at >= start_date)
if end_date:
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
events = query.all()
updated = 0
for event in events:
try:
cost = self.calculate_event_cost(event)
event.cost = cost
updated += 1
except Exception as e:
logger.error(f"Failed to calculate cost for event {event.id}: {e}")
self.db.commit()
logger.info(f"Updated {updated} event costs")
return updated
def generate_monthly_billing(
self,
tenant_id: str,
billing_month: str
) -> TenantBilling:
"""生成月度账单
Args:
tenant_id: 租户ID
billing_month: 账单月份,格式 YYYY-MM
Returns:
TenantBilling实例
"""
# 检查是否已存在
existing = self.db.query(TenantBilling).filter(
TenantBilling.tenant_id == tenant_id,
TenantBilling.billing_month == billing_month
).first()
if existing:
billing = existing
else:
billing = TenantBilling(
tenant_id=tenant_id,
billing_month=billing_month
)
self.db.add(billing)
# 计算统计数据
start_date = f"{billing_month}-01"
year, month = billing_month.split("-")
if int(month) == 12:
end_date = f"{int(year)+1}-01-01"
else:
end_date = f"{year}-{int(month)+1:02d}-01"
# 聚合查询
stats = self.db.query(
func.count(AICallEvent.id).label('total_calls'),
func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'),
func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'),
func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost')
).filter(
AICallEvent.tenant_id == tenant_id,
AICallEvent.created_at >= start_date,
AICallEvent.created_at < end_date
).first()
billing.total_calls = stats.total_calls or 0
billing.total_input_tokens = int(stats.total_input or 0)
billing.total_output_tokens = int(stats.total_output or 0)
billing.total_cost = stats.total_cost or Decimal("0")
# 按模型统计
model_stats = self.db.query(
AICallEvent.model,
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
).filter(
AICallEvent.tenant_id == tenant_id,
AICallEvent.created_at >= start_date,
AICallEvent.created_at < end_date
).group_by(AICallEvent.model).all()
billing.cost_by_model = {
m.model or "unknown": float(m.cost) for m in model_stats
}
# 按应用统计
app_stats = self.db.query(
AICallEvent.app_code,
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
).filter(
AICallEvent.tenant_id == tenant_id,
AICallEvent.created_at >= start_date,
AICallEvent.created_at < end_date
).group_by(AICallEvent.app_code).all()
billing.cost_by_app = {
a.app_code or "unknown": float(a.cost) for a in app_stats
}
self.db.commit()
self.db.refresh(billing)
return billing
def get_cost_summary(
self,
tenant_id: str = None,
start_date: str = None,
end_date: str = None
) -> Dict:
"""获取费用汇总
Args:
tenant_id: 租户ID可选
start_date: 开始日期
end_date: 结束日期
Returns:
费用汇总字典
"""
query = self.db.query(
func.count(AICallEvent.id).label('total_calls'),
func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'),
func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'),
func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost')
)
if tenant_id:
query = query.filter(AICallEvent.tenant_id == tenant_id)
if start_date:
query = query.filter(AICallEvent.created_at >= start_date)
if end_date:
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
stats = query.first()
return {
"total_calls": stats.total_calls or 0,
"total_input_tokens": int(stats.total_input or 0),
"total_output_tokens": int(stats.total_output or 0),
"total_cost": float(stats.total_cost or 0)
}
def get_cost_by_tenant(
self,
start_date: str = None,
end_date: str = None
) -> List[Dict]:
"""按租户统计费用
Returns:
租户费用列表
"""
query = self.db.query(
AICallEvent.tenant_id,
func.count(AICallEvent.id).label('calls'),
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
)
if start_date:
query = query.filter(AICallEvent.created_at >= start_date)
if end_date:
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
results = query.group_by(AICallEvent.tenant_id).order_by(
func.sum(AICallEvent.cost).desc()
).all()
return [
{
"tenant_id": r.tenant_id,
"calls": r.calls,
"cost": float(r.cost)
}
for r in results
]
def get_cost_by_model(
self,
tenant_id: str = None,
start_date: str = None,
end_date: str = None
) -> List[Dict]:
"""按模型统计费用
Returns:
模型费用列表
"""
query = self.db.query(
AICallEvent.model,
func.count(AICallEvent.id).label('calls'),
func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('input_tokens'),
func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('output_tokens'),
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
)
if tenant_id:
query = query.filter(AICallEvent.tenant_id == tenant_id)
if start_date:
query = query.filter(AICallEvent.created_at >= start_date)
if end_date:
query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
results = query.group_by(AICallEvent.model).order_by(
func.sum(AICallEvent.cost).desc()
).all()
return [
{
"model": r.model or "unknown",
"calls": r.calls,
"input_tokens": int(r.input_tokens),
"output_tokens": int(r.output_tokens),
"cost": float(r.cost)
}
for r in results
]
# 便捷函数
def calculate_cost(
db: Session,
model_name: str,
input_tokens: int = 0,
output_tokens: int = 0
) -> Decimal:
"""快速计算费用"""
calculator = CostCalculator(db)
return calculator.calculate_cost(model_name, input_tokens, output_tokens)

View File

@@ -0,0 +1,346 @@
"""配额管理服务"""
import logging
from datetime import datetime, date, timedelta
from typing import Optional, Dict, Any, Tuple
from dataclasses import dataclass
from sqlalchemy.orm import Session
from sqlalchemy import func
from ..models.tenant import Tenant, Subscription
from ..models.stats import AICallEvent
from .cache import get_cache
logger = logging.getLogger(__name__)
@dataclass
class QuotaConfig:
"""配额配置"""
daily_calls: int = 0 # 每日调用限制0表示无限制
daily_tokens: int = 0 # 每日Token限制
monthly_calls: int = 0 # 每月调用限制
monthly_tokens: int = 0 # 每月Token限制
monthly_cost: float = 0 # 每月费用限制(元)
concurrent_calls: int = 0 # 并发调用限制
@dataclass
class QuotaUsage:
"""配额使用情况"""
daily_calls: int = 0
daily_tokens: int = 0
monthly_calls: int = 0
monthly_tokens: int = 0
monthly_cost: float = 0
@dataclass
class QuotaCheckResult:
"""配额检查结果"""
allowed: bool
reason: Optional[str] = None
quota_type: Optional[str] = None
limit: int = 0
used: int = 0
remaining: int = 0
class QuotaService:
"""配额管理服务
使用示例:
quota_service = QuotaService(db)
# 检查配额
result = quota_service.check_quota("qiqi", "tools")
if not result.allowed:
raise HTTPException(status_code=429, detail=result.reason)
# 获取使用情况
usage = quota_service.get_usage("qiqi", "tools")
"""
# 默认配额(当无订阅配置时使用)
DEFAULT_QUOTA = QuotaConfig(
daily_calls=1000,
daily_tokens=100000,
monthly_calls=30000,
monthly_tokens=3000000,
monthly_cost=100
)
def __init__(self, db: Session):
self.db = db
self._cache = get_cache()
def get_subscription(self, tenant_id: str, app_code: str) -> Optional[Subscription]:
"""获取租户订阅配置"""
return self.db.query(Subscription).filter(
Subscription.tenant_id == tenant_id,
Subscription.app_code == app_code,
Subscription.status == 'active'
).first()
def get_quota_config(self, tenant_id: str, app_code: str) -> QuotaConfig:
"""获取配额配置
Args:
tenant_id: 租户ID
app_code: 应用代码
Returns:
QuotaConfig实例
"""
# 尝试从缓存获取
cache_key = f"quota:config:{tenant_id}:{app_code}"
cached = self._cache.get(cache_key)
if cached:
return QuotaConfig(**cached)
# 从订阅表获取
subscription = self.get_subscription(tenant_id, app_code)
if subscription and subscription.quota:
quota = subscription.quota
config = QuotaConfig(
daily_calls=quota.get('daily_calls', 0),
daily_tokens=quota.get('daily_tokens', 0),
monthly_calls=quota.get('monthly_calls', 0),
monthly_tokens=quota.get('monthly_tokens', 0),
monthly_cost=quota.get('monthly_cost', 0),
concurrent_calls=quota.get('concurrent_calls', 0)
)
else:
config = self.DEFAULT_QUOTA
# 缓存5分钟
self._cache.set(cache_key, config.__dict__, ttl=300)
return config
def get_usage(self, tenant_id: str, app_code: str) -> QuotaUsage:
"""获取配额使用情况
Args:
tenant_id: 租户ID
app_code: 应用代码
Returns:
QuotaUsage实例
"""
today = date.today()
month_start = today.replace(day=1)
# 今日使用量
daily_stats = self.db.query(
func.count(AICallEvent.id).label('calls'),
func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0).label('tokens')
).filter(
AICallEvent.tenant_id == tenant_id,
AICallEvent.app_code == app_code,
func.date(AICallEvent.created_at) == today
).first()
# 本月使用量
monthly_stats = self.db.query(
func.count(AICallEvent.id).label('calls'),
func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0).label('tokens'),
func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
).filter(
AICallEvent.tenant_id == tenant_id,
AICallEvent.app_code == app_code,
func.date(AICallEvent.created_at) >= month_start
).first()
return QuotaUsage(
daily_calls=daily_stats.calls or 0,
daily_tokens=int(daily_stats.tokens or 0),
monthly_calls=monthly_stats.calls or 0,
monthly_tokens=int(monthly_stats.tokens or 0),
monthly_cost=float(monthly_stats.cost or 0)
)
def check_quota(
self,
tenant_id: str,
app_code: str,
estimated_tokens: int = 0
) -> QuotaCheckResult:
"""检查配额是否足够
Args:
tenant_id: 租户ID
app_code: 应用代码
estimated_tokens: 预估Token消耗
Returns:
QuotaCheckResult实例
"""
config = self.get_quota_config(tenant_id, app_code)
usage = self.get_usage(tenant_id, app_code)
# 检查日调用次数
if config.daily_calls > 0:
if usage.daily_calls >= config.daily_calls:
return QuotaCheckResult(
allowed=False,
reason=f"已达到每日调用限制 ({config.daily_calls} 次)",
quota_type="daily_calls",
limit=config.daily_calls,
used=usage.daily_calls,
remaining=0
)
# 检查日Token限制
if config.daily_tokens > 0:
if usage.daily_tokens + estimated_tokens > config.daily_tokens:
return QuotaCheckResult(
allowed=False,
reason=f"已达到每日Token限制 ({config.daily_tokens:,})",
quota_type="daily_tokens",
limit=config.daily_tokens,
used=usage.daily_tokens,
remaining=max(0, config.daily_tokens - usage.daily_tokens)
)
# 检查月调用次数
if config.monthly_calls > 0:
if usage.monthly_calls >= config.monthly_calls:
return QuotaCheckResult(
allowed=False,
reason=f"已达到每月调用限制 ({config.monthly_calls} 次)",
quota_type="monthly_calls",
limit=config.monthly_calls,
used=usage.monthly_calls,
remaining=0
)
# 检查月Token限制
if config.monthly_tokens > 0:
if usage.monthly_tokens + estimated_tokens > config.monthly_tokens:
return QuotaCheckResult(
allowed=False,
reason=f"已达到每月Token限制 ({config.monthly_tokens:,})",
quota_type="monthly_tokens",
limit=config.monthly_tokens,
used=usage.monthly_tokens,
remaining=max(0, config.monthly_tokens - usage.monthly_tokens)
)
# 检查月费用限制
if config.monthly_cost > 0:
if usage.monthly_cost >= config.monthly_cost:
return QuotaCheckResult(
allowed=False,
reason=f"已达到每月费用限制 (¥{config.monthly_cost:.2f})",
quota_type="monthly_cost",
limit=int(config.monthly_cost * 100), # 转为分
used=int(usage.monthly_cost * 100),
remaining=max(0, int((config.monthly_cost - usage.monthly_cost) * 100))
)
# 所有检查通过
return QuotaCheckResult(
allowed=True,
quota_type="daily_calls",
limit=config.daily_calls,
used=usage.daily_calls,
remaining=max(0, config.daily_calls - usage.daily_calls) if config.daily_calls > 0 else -1
)
def get_quota_summary(self, tenant_id: str, app_code: str) -> Dict[str, Any]:
"""获取配额汇总信息
Returns:
包含配额配置和使用情况的字典
"""
config = self.get_quota_config(tenant_id, app_code)
usage = self.get_usage(tenant_id, app_code)
def calc_percentage(used: int, limit: int) -> float:
if limit <= 0:
return 0
return min(100, round(used / limit * 100, 1))
return {
"config": {
"daily_calls": config.daily_calls,
"daily_tokens": config.daily_tokens,
"monthly_calls": config.monthly_calls,
"monthly_tokens": config.monthly_tokens,
"monthly_cost": config.monthly_cost
},
"usage": {
"daily_calls": usage.daily_calls,
"daily_tokens": usage.daily_tokens,
"monthly_calls": usage.monthly_calls,
"monthly_tokens": usage.monthly_tokens,
"monthly_cost": round(usage.monthly_cost, 2)
},
"percentage": {
"daily_calls": calc_percentage(usage.daily_calls, config.daily_calls),
"daily_tokens": calc_percentage(usage.daily_tokens, config.daily_tokens),
"monthly_calls": calc_percentage(usage.monthly_calls, config.monthly_calls),
"monthly_tokens": calc_percentage(usage.monthly_tokens, config.monthly_tokens),
"monthly_cost": calc_percentage(int(usage.monthly_cost * 100), int(config.monthly_cost * 100))
}
}
def update_quota(
self,
tenant_id: str,
app_code: str,
quota_config: Dict[str, Any]
) -> Subscription:
"""更新配额配置
Args:
tenant_id: 租户ID
app_code: 应用代码
quota_config: 配额配置字典
Returns:
更新后的Subscription实例
"""
subscription = self.get_subscription(tenant_id, app_code)
if not subscription:
# 创建新订阅
subscription = Subscription(
tenant_id=tenant_id,
app_code=app_code,
start_date=date.today(),
quota=quota_config,
status='active'
)
self.db.add(subscription)
else:
# 更新现有订阅
subscription.quota = quota_config
self.db.commit()
self.db.refresh(subscription)
# 清除缓存
cache_key = f"quota:config:{tenant_id}:{app_code}"
self._cache.delete(cache_key)
return subscription
def check_quota_middleware(
db: Session,
tenant_id: str,
app_code: str,
estimated_tokens: int = 0
) -> QuotaCheckResult:
"""配额检查中间件函数
可在路由中使用:
result = check_quota_middleware(db, "qiqi", "tools")
if not result.allowed:
raise HTTPException(status_code=429, detail=result.reason)
"""
service = QuotaService(db)
return service.check_quota(tenant_id, app_code, estimated_tokens)

View File

@@ -0,0 +1,371 @@
"""企业微信服务"""
import hashlib
import time
import logging
from typing import Optional, Dict, Any
from dataclasses import dataclass
import httpx
from ..config import get_settings
from .cache import get_cache
from .crypto import decrypt_config
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class WechatConfig:
"""企业微信应用配置"""
corp_id: str
agent_id: str
secret: str
class WechatService:
"""企业微信服务
提供access_token获取、JS-SDK签名、OAuth2等功能
使用示例:
wechat = WechatService(corp_id="wwxxxx", agent_id="1000001", secret="xxx")
# 获取access_token
token = await wechat.get_access_token()
# 获取JS-SDK签名
signature = await wechat.get_jssdk_signature("https://example.com/page")
"""
# 企业微信API基础URL
BASE_URL = "https://qyapi.weixin.qq.com"
def __init__(self, corp_id: str, agent_id: str, secret: str):
"""初始化企业微信服务
Args:
corp_id: 企业ID
agent_id: 应用AgentId
secret: 应用Secret明文
"""
self.corp_id = corp_id
self.agent_id = agent_id
self.secret = secret
self._cache = get_cache()
@classmethod
def from_wechat_app(cls, wechat_app) -> "WechatService":
"""从TenantWechatApp模型创建服务实例
Args:
wechat_app: TenantWechatApp数据库模型
Returns:
WechatService实例
"""
secret = ""
if wechat_app.secret_encrypted:
try:
secret = decrypt_config(wechat_app.secret_encrypted)
except Exception as e:
logger.error(f"Failed to decrypt secret: {e}")
return cls(
corp_id=wechat_app.corp_id,
agent_id=wechat_app.agent_id,
secret=secret
)
def _cache_key(self, key_type: str) -> str:
"""生成缓存键"""
return f"wechat:{self.corp_id}:{self.agent_id}:{key_type}"
async def get_access_token(self, force_refresh: bool = False) -> Optional[str]:
"""获取access_token
企业微信access_token有效期7200秒需要缓存
Args:
force_refresh: 是否强制刷新
Returns:
access_token或None
"""
cache_key = self._cache_key("access_token")
# 尝试从缓存获取
if not force_refresh:
cached = self._cache.get(cache_key)
if cached:
logger.debug(f"Access token from cache: {cached[:20]}...")
return cached
# 从企业微信API获取
url = f"{self.BASE_URL}/cgi-bin/gettoken"
params = {
"corpid": self.corp_id,
"corpsecret": self.secret
}
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(url, params=params)
result = response.json()
if result.get("errcode", 0) != 0:
logger.error(f"Get access_token failed: {result}")
return None
access_token = result.get("access_token")
expires_in = result.get("expires_in", 7200)
# 缓存提前200秒过期以确保安全
self._cache.set(
cache_key,
access_token,
ttl=min(expires_in - 200, settings.WECHAT_ACCESS_TOKEN_EXPIRE)
)
logger.info(f"Got new access_token for {self.corp_id}")
return access_token
except Exception as e:
logger.error(f"Get access_token error: {e}")
return None
async def get_jsapi_ticket(self, force_refresh: bool = False) -> Optional[str]:
"""获取jsapi_ticket
用于生成JS-SDK签名
Args:
force_refresh: 是否强制刷新
Returns:
jsapi_ticket或None
"""
cache_key = self._cache_key("jsapi_ticket")
# 尝试从缓存获取
if not force_refresh:
cached = self._cache.get(cache_key)
if cached:
logger.debug(f"JSAPI ticket from cache: {cached[:20]}...")
return cached
# 先获取access_token
access_token = await self.get_access_token()
if not access_token:
return None
# 获取jsapi_ticket
url = f"{self.BASE_URL}/cgi-bin/get_jsapi_ticket"
params = {"access_token": access_token}
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(url, params=params)
result = response.json()
if result.get("errcode", 0) != 0:
logger.error(f"Get jsapi_ticket failed: {result}")
return None
ticket = result.get("ticket")
expires_in = result.get("expires_in", 7200)
# 缓存
self._cache.set(
cache_key,
ticket,
ttl=min(expires_in - 200, settings.WECHAT_JSAPI_TICKET_EXPIRE)
)
logger.info(f"Got new jsapi_ticket for {self.corp_id}")
return ticket
except Exception as e:
logger.error(f"Get jsapi_ticket error: {e}")
return None
async def get_jssdk_signature(
self,
url: str,
noncestr: Optional[str] = None,
timestamp: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""生成JS-SDK签名
Args:
url: 当前页面URL不含#及其后面部分)
noncestr: 随机字符串,可选
timestamp: 时间戳,可选
Returns:
签名信息字典包含signature, noncestr, timestamp, appId等
"""
ticket = await self.get_jsapi_ticket()
if not ticket:
return None
# 生成随机字符串和时间戳
if noncestr is None:
import secrets
noncestr = secrets.token_hex(8)
if timestamp is None:
timestamp = int(time.time())
# 构建签名字符串
sign_str = f"jsapi_ticket={ticket}&noncestr={noncestr}&timestamp={timestamp}&url={url}"
# SHA1签名
signature = hashlib.sha1(sign_str.encode()).hexdigest()
return {
"appId": self.corp_id,
"agentId": self.agent_id,
"timestamp": timestamp,
"nonceStr": noncestr,
"signature": signature,
"url": url
}
def get_oauth2_url(
self,
redirect_uri: str,
scope: str = "snsapi_base",
state: str = ""
) -> str:
"""生成OAuth2授权URL
Args:
redirect_uri: 授权后重定向的URL
scope: 应用授权作用域
- snsapi_base: 静默授权,只能获取成员基础信息
- snsapi_privateinfo: 手动授权,可获取成员详细信息
state: 重定向后会带上state参数
Returns:
OAuth2授权URL
"""
import urllib.parse
encoded_uri = urllib.parse.quote(redirect_uri, safe='')
url = (
f"https://open.weixin.qq.com/connect/oauth2/authorize"
f"?appid={self.corp_id}"
f"&redirect_uri={encoded_uri}"
f"&response_type=code"
f"&scope={scope}"
f"&state={state}"
f"&agentid={self.agent_id}"
f"#wechat_redirect"
)
return url
async def get_user_info_by_code(self, code: str) -> Optional[Dict[str, Any]]:
"""通过OAuth2 code获取用户信息
Args:
code: OAuth2回调返回的code
Returns:
用户信息字典包含UserId, DeviceId等
"""
access_token = await self.get_access_token()
if not access_token:
return None
url = f"{self.BASE_URL}/cgi-bin/auth/getuserinfo"
params = {
"access_token": access_token,
"code": code
}
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(url, params=params)
result = response.json()
if result.get("errcode", 0) != 0:
logger.error(f"Get user info by code failed: {result}")
return None
return {
"user_id": result.get("userid") or result.get("UserId"),
"device_id": result.get("deviceid") or result.get("DeviceId"),
"open_id": result.get("openid") or result.get("OpenId"),
"external_userid": result.get("external_userid"),
}
except Exception as e:
logger.error(f"Get user info by code error: {e}")
return None
async def get_user_detail(self, user_id: str) -> Optional[Dict[str, Any]]:
"""获取成员详细信息
Args:
user_id: 成员UserID
Returns:
成员详细信息
"""
access_token = await self.get_access_token()
if not access_token:
return None
url = f"{self.BASE_URL}/cgi-bin/user/get"
params = {
"access_token": access_token,
"userid": user_id
}
try:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.get(url, params=params)
result = response.json()
if result.get("errcode", 0) != 0:
logger.error(f"Get user detail failed: {result}")
return None
return {
"userid": result.get("userid"),
"name": result.get("name"),
"department": result.get("department"),
"position": result.get("position"),
"mobile": result.get("mobile"),
"email": result.get("email"),
"avatar": result.get("avatar"),
"status": result.get("status"),
}
except Exception as e:
logger.error(f"Get user detail error: {e}")
return None
async def get_wechat_service_by_id(
wechat_app_id: int,
db_session
) -> Optional[WechatService]:
"""根据企微应用ID获取服务实例
Args:
wechat_app_id: platform_tenant_wechat_apps表的ID
db_session: 数据库session
Returns:
WechatService实例或None
"""
from ..models.tenant_wechat_app import TenantWechatApp
wechat_app = db_session.query(TenantWechatApp).filter(
TenantWechatApp.id == wechat_app_id,
TenantWechatApp.status == 1
).first()
if not wechat_app:
return None
return WechatService.from_wechat_app(wechat_app)