diff --git a/.drone.yml b/.drone.yml
index 9082340..3ec04f3 100644
--- a/.drone.yml
+++ b/.drone.yml
@@ -46,10 +46,11 @@ steps:
CONFIG_ENCRYPT_KEY:
from_secret: config_encrypt_key
commands:
+ - docker network create platform-network 2>/dev/null || true
- docker stop platform-backend-test platform-frontend-test || true
- docker rm platform-backend-test platform-frontend-test || true
- - docker run -d --name platform-backend-test -p 8001:8000 --restart unless-stopped -e DATABASE_URL=$DATABASE_URL -e API_KEY=$API_KEY -e JWT_SECRET=$JWT_SECRET -e CONFIG_ENCRYPT_KEY=$CONFIG_ENCRYPT_KEY platform-backend:latest
- - docker run -d --name platform-frontend-test -p 3003:80 --restart unless-stopped platform-frontend:latest
+ - docker run -d --name platform-backend-test --network platform-network -p 8001:8000 --restart unless-stopped -e DATABASE_URL=$DATABASE_URL -e API_KEY=$API_KEY -e JWT_SECRET=$JWT_SECRET -e CONFIG_ENCRYPT_KEY=$CONFIG_ENCRYPT_KEY platform-backend:latest
+ - docker run -d --name platform-frontend-test --network platform-network -p 3003:80 --restart unless-stopped platform-frontend:latest
when:
branch:
- develop
diff --git a/backend/app/config.py b/backend/app/config.py
index e0a91b4..6822eb8 100644
--- a/backend/app/config.py
+++ b/backend/app/config.py
@@ -2,6 +2,7 @@
import os
from functools import lru_cache
from pydantic_settings import BaseSettings
+from typing import Optional
class Settings(BaseSettings):
@@ -14,6 +15,10 @@ class Settings(BaseSettings):
# 数据库
DATABASE_URL: str = "mysql+pymysql://scrm_reader:ScrmReader2024Pass@47.107.71.55:3306/new_qiqi"
+ # Redis
+ REDIS_URL: str = "redis://localhost:6379/0"
+ REDIS_PREFIX: str = "platform:"
+
# API Key(内部服务调用)
API_KEY: str = "platform_api_key_2026"
@@ -29,6 +34,10 @@ class Settings(BaseSettings):
# 配置加密密钥
CONFIG_ENCRYPT_KEY: str = "platform_config_key_32bytes!!"
+ # 企业微信配置
+ WECHAT_ACCESS_TOKEN_EXPIRE: int = 7000 # access_token缓存时间(秒),企微有效期7200秒
+ WECHAT_JSAPI_TICKET_EXPIRE: int = 7000 # jsapi_ticket缓存时间(秒)
+
class Config:
env_file = ".env"
extra = "ignore"
diff --git a/backend/app/main.py b/backend/app/main.py
index e024ba1..f7b094c 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -1,4 +1,5 @@
"""平台服务入口"""
+import logging
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@@ -9,6 +10,15 @@ from .routers.tenants import router as tenants_router
from .routers.tenant_apps import router as tenant_apps_router
from .routers.tenant_wechat_apps import router as tenant_wechat_apps_router
from .routers.apps import router as apps_router
+from .routers.wechat import router as wechat_router
+from .routers.alerts import router as alerts_router
+from .routers.cost import router as cost_router
+from .routers.quota import router as quota_router
+from .middleware import TraceMiddleware, setup_exception_handlers, RequestLoggerMiddleware
+from .middleware.trace import setup_logging
+
+# 配置日志(包含 TraceID)
+setup_logging(level=logging.INFO, include_trace=True)
settings = get_settings()
@@ -18,6 +28,20 @@ app = FastAPI(
description="平台基础设施服务 - 统计/日志/配置管理"
)
+# 配置统一异常处理
+setup_exception_handlers(app)
+
+# 中间件按添加的反序执行,所以:
+# 1. CORS 最后添加,最先执行
+# 2. TraceMiddleware 在 RequestLoggerMiddleware 之后添加,这样先执行
+# 3. RequestLoggerMiddleware 最先添加,最后执行(此时 trace_id 已设置)
+
+# 请求日志中间件(自动记录到数据库)
+app.add_middleware(RequestLoggerMiddleware, app_code="000-platform")
+
+# TraceID 追踪中间件
+app.add_middleware(TraceMiddleware, log_requests=True)
+
# CORS
app.add_middleware(
CORSMiddleware,
@@ -25,6 +49,7 @@ app.add_middleware(
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
+ expose_headers=["X-Trace-ID", "X-Response-Time"]
)
# 注册路由
@@ -37,6 +62,10 @@ app.include_router(apps_router, prefix="/api")
app.include_router(stats_router, prefix="/api")
app.include_router(logs_router, prefix="/api")
app.include_router(config_router, prefix="/api")
+app.include_router(wechat_router, prefix="/api")
+app.include_router(alerts_router, prefix="/api")
+app.include_router(cost_router, prefix="/api")
+app.include_router(quota_router, prefix="/api")
@app.get("/")
diff --git a/backend/app/middleware/__init__.py b/backend/app/middleware/__init__.py
new file mode 100644
index 0000000..2a04008
--- /dev/null
+++ b/backend/app/middleware/__init__.py
@@ -0,0 +1,19 @@
+"""
+中间件模块
+
+提供:
+- TraceID 追踪
+- 统一异常处理
+- 请求日志记录
+"""
+from .trace import TraceMiddleware, get_trace_id, set_trace_id
+from .exception_handler import setup_exception_handlers
+from .request_logger import RequestLoggerMiddleware
+
+__all__ = [
+ "TraceMiddleware",
+ "get_trace_id",
+ "set_trace_id",
+ "setup_exception_handlers",
+ "RequestLoggerMiddleware"
+]
diff --git a/backend/app/middleware/exception_handler.py b/backend/app/middleware/exception_handler.py
new file mode 100644
index 0000000..946efb5
--- /dev/null
+++ b/backend/app/middleware/exception_handler.py
@@ -0,0 +1,128 @@
+"""
+统一异常处理
+
+捕获所有异常,返回统一格式的错误响应,包含 TraceID。
+"""
+import logging
+import traceback
+from typing import Union
+from fastapi import FastAPI, Request, HTTPException
+from fastapi.responses import JSONResponse
+from fastapi.exceptions import RequestValidationError
+from starlette.exceptions import HTTPException as StarletteHTTPException
+
+from .trace import get_trace_id
+
+logger = logging.getLogger(__name__)
+
+
+class ErrorCode:
+ """错误码常量"""
+ BAD_REQUEST = "BAD_REQUEST"
+ UNAUTHORIZED = "UNAUTHORIZED"
+ FORBIDDEN = "FORBIDDEN"
+ NOT_FOUND = "NOT_FOUND"
+ VALIDATION_ERROR = "VALIDATION_ERROR"
+ RATE_LIMITED = "RATE_LIMITED"
+ INTERNAL_ERROR = "INTERNAL_ERROR"
+ SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
+ GATEWAY_ERROR = "GATEWAY_ERROR"
+
+
+STATUS_TO_ERROR_CODE = {
+ 400: ErrorCode.BAD_REQUEST,
+ 401: ErrorCode.UNAUTHORIZED,
+ 403: ErrorCode.FORBIDDEN,
+ 404: ErrorCode.NOT_FOUND,
+ 422: ErrorCode.VALIDATION_ERROR,
+ 429: ErrorCode.RATE_LIMITED,
+ 500: ErrorCode.INTERNAL_ERROR,
+ 502: ErrorCode.GATEWAY_ERROR,
+ 503: ErrorCode.SERVICE_UNAVAILABLE,
+}
+
+
+def create_error_response(
+ status_code: int,
+ code: str,
+ message: str,
+ trace_id: str = None,
+ details: dict = None
+) -> JSONResponse:
+ """创建统一格式的错误响应"""
+ if trace_id is None:
+ trace_id = get_trace_id()
+
+ error_body = {
+ "code": code,
+ "message": message,
+ "trace_id": trace_id
+ }
+
+ if details:
+ error_body["details"] = details
+
+ return JSONResponse(
+ status_code=status_code,
+ content={"success": False, "error": error_body},
+ headers={"X-Trace-ID": trace_id}
+ )
+
+
+async def http_exception_handler(request: Request, exc: Union[HTTPException, StarletteHTTPException]):
+ """处理 HTTP 异常"""
+ trace_id = get_trace_id()
+ status_code = exc.status_code
+ error_code = STATUS_TO_ERROR_CODE.get(status_code, ErrorCode.INTERNAL_ERROR)
+ message = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
+
+ logger.warning(f"[{trace_id}] HTTP {status_code}: {message}")
+
+ return create_error_response(
+ status_code=status_code,
+ code=error_code,
+ message=message,
+ trace_id=trace_id
+ )
+
+
+async def validation_exception_handler(request: Request, exc: RequestValidationError):
+ """处理请求验证错误"""
+ trace_id = get_trace_id()
+ errors = exc.errors()
+ error_messages = [f"{'.'.join(str(l) for l in e['loc'])}: {e['msg']}" for e in errors]
+
+ logger.warning(f"[{trace_id}] 验证错误: {error_messages}")
+
+ return create_error_response(
+ status_code=422,
+ code=ErrorCode.VALIDATION_ERROR,
+ message="请求参数验证失败",
+ trace_id=trace_id,
+ details={"validation_errors": error_messages}
+ )
+
+
+async def generic_exception_handler(request: Request, exc: Exception):
+ """处理所有未捕获的异常"""
+ trace_id = get_trace_id()
+
+ logger.error(f"[{trace_id}] 未捕获异常: {type(exc).__name__}: {exc}")
+ logger.error(f"[{trace_id}] 堆栈:\n{traceback.format_exc()}")
+
+ return create_error_response(
+ status_code=500,
+ code=ErrorCode.INTERNAL_ERROR,
+ message="服务器内部错误,请稍后重试",
+ trace_id=trace_id
+ )
+
+
+def setup_exception_handlers(app: FastAPI):
+ """配置 FastAPI 应用的异常处理器"""
+ app.add_exception_handler(HTTPException, http_exception_handler)
+ app.add_exception_handler(StarletteHTTPException, http_exception_handler)
+ app.add_exception_handler(RequestValidationError, validation_exception_handler)
+ app.add_exception_handler(Exception, generic_exception_handler)
+
+ logger.info("异常处理器已配置")
diff --git a/backend/app/middleware/request_logger.py b/backend/app/middleware/request_logger.py
new file mode 100644
index 0000000..e50bdc7
--- /dev/null
+++ b/backend/app/middleware/request_logger.py
@@ -0,0 +1,190 @@
+"""
+请求日志中间件
+
+自动将所有请求记录到数据库 platform_logs 表
+"""
+import time
+import logging
+from typing import Optional, Set
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.requests import Request
+from starlette.responses import Response
+
+from .trace import get_trace_id
+from ..database import SessionLocal
+from ..models.logs import PlatformLog
+
+logger = logging.getLogger(__name__)
+
+
+class RequestLoggerMiddleware(BaseHTTPMiddleware):
+ """请求日志中间件
+
+ 自动记录所有请求到数据库,便于后续查询和分析
+
+ 使用示例:
+ app.add_middleware(RequestLoggerMiddleware, app_code="000-platform")
+ """
+
+ # 默认排除的路径(不记录这些请求)
+ DEFAULT_EXCLUDE_PATHS: Set[str] = {
+ "/",
+ "/docs",
+ "/redoc",
+ "/openapi.json",
+ "/api/health",
+ "/api/health/",
+ "/favicon.ico",
+ }
+
+ def __init__(
+ self,
+ app,
+ app_code: str = "platform",
+ exclude_paths: Optional[Set[str]] = None,
+ log_request_body: bool = False,
+ log_response_body: bool = False,
+ max_body_length: int = 1000
+ ):
+ """初始化中间件
+
+ Args:
+ app: FastAPI应用
+ app_code: 应用代码,记录到日志中
+ exclude_paths: 排除的路径集合,这些路径不记录日志
+ log_request_body: 是否记录请求体
+ log_response_body: 是否记录响应体
+ max_body_length: 记录体的最大长度
+ """
+ super().__init__(app)
+ self.app_code = app_code
+ self.exclude_paths = exclude_paths or self.DEFAULT_EXCLUDE_PATHS
+ self.log_request_body = log_request_body
+ self.log_response_body = log_response_body
+ self.max_body_length = max_body_length
+
+ async def dispatch(self, request: Request, call_next) -> Response:
+ path = request.url.path
+
+ # 检查是否排除
+ if self._should_exclude(path):
+ return await call_next(request)
+
+ trace_id = get_trace_id()
+ method = request.method
+ start_time = time.time()
+
+ # 获取客户端IP
+ client_ip = self._get_client_ip(request)
+
+ # 获取租户ID(从查询参数)
+ tenant_id = request.query_params.get("tid") or request.query_params.get("tenant_id")
+
+ # 执行请求
+ response = None
+ error_message = None
+ status_code = 500
+
+ try:
+ response = await call_next(request)
+ status_code = response.status_code
+ except Exception as e:
+ error_message = str(e)
+ raise
+ finally:
+ duration_ms = int((time.time() - start_time) * 1000)
+
+ # 异步写入数据库(不阻塞响应)
+ try:
+ self._save_log(
+ trace_id=trace_id,
+ method=method,
+ path=path,
+ status_code=status_code,
+ duration_ms=duration_ms,
+ ip_address=client_ip,
+ tenant_id=tenant_id,
+ error_message=error_message
+ )
+ except Exception as e:
+ logger.error(f"Failed to save request log: {e}")
+
+ return response
+
+ def _should_exclude(self, path: str) -> bool:
+ """检查路径是否应排除"""
+ # 精确匹配
+ if path in self.exclude_paths:
+ return True
+
+ # 前缀匹配(静态文件等)
+ exclude_prefixes = ["/static/", "/assets/", "/_next/"]
+ for prefix in exclude_prefixes:
+ if path.startswith(prefix):
+ return True
+
+ return False
+
+ def _get_client_ip(self, request: Request) -> str:
+ """获取客户端真实IP"""
+ # 优先从代理头获取
+ forwarded = request.headers.get("X-Forwarded-For")
+ if forwarded:
+ return forwarded.split(",")[0].strip()
+
+ real_ip = request.headers.get("X-Real-IP")
+ if real_ip:
+ return real_ip
+
+ # 直连IP
+ if request.client:
+ return request.client.host
+
+ return "unknown"
+
+ def _save_log(
+ self,
+ trace_id: str,
+ method: str,
+ path: str,
+ status_code: int,
+ duration_ms: int,
+ ip_address: str,
+ tenant_id: Optional[str] = None,
+ error_message: Optional[str] = None
+ ):
+ """保存日志到数据库"""
+ from datetime import datetime
+
+ # 使用独立的数据库会话
+ db = SessionLocal()
+ try:
+ # 转换 tenant_id 为整数(如果是数字字符串)
+ tenant_id_int = None
+ if tenant_id:
+ try:
+ tenant_id_int = int(tenant_id)
+ except (ValueError, TypeError):
+ tenant_id_int = None
+
+ log_entry = PlatformLog(
+ log_type="request",
+ level="error" if status_code >= 500 else ("warn" if status_code >= 400 else "info"),
+ app_code=self.app_code,
+ tenant_id=tenant_id_int,
+ trace_id=trace_id,
+ message=f"{method} {path}" + (f" - {error_message}" if error_message else ""),
+ path=path,
+ method=method,
+ status_code=status_code,
+ duration_ms=duration_ms,
+ log_time=datetime.now(), # 必须设置 log_time
+ context={"ip": ip_address} # ip_address 放到 context 中
+ )
+ db.add(log_entry)
+ db.commit()
+ except Exception as e:
+ logger.error(f"Database error saving log: {e}")
+ db.rollback()
+ finally:
+ db.close()
diff --git a/backend/app/middleware/trace.py b/backend/app/middleware/trace.py
new file mode 100644
index 0000000..4bd02fa
--- /dev/null
+++ b/backend/app/middleware/trace.py
@@ -0,0 +1,114 @@
+"""
+TraceID 追踪中间件
+
+为每个请求生成唯一的 TraceID,用于日志追踪和问题排查。
+
+功能:
+- 自动生成 TraceID(或从请求头获取)
+- 注入到响应头 X-Trace-ID
+- 提供上下文变量供日志使用
+- 支持请求耗时统计
+"""
+import time
+import uuid
+import logging
+from contextvars import ContextVar
+from typing import Optional
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.requests import Request
+from starlette.responses import Response
+
+logger = logging.getLogger(__name__)
+
+# 上下文变量存储当前请求的 TraceID
+_trace_id_var: ContextVar[Optional[str]] = ContextVar("trace_id", default=None)
+
+# 请求头名称
+TRACE_ID_HEADER = "X-Trace-ID"
+REQUEST_ID_HEADER = "X-Request-ID"
+
+
+def get_trace_id() -> str:
+ """获取当前请求的 TraceID"""
+ trace_id = _trace_id_var.get()
+ return trace_id if trace_id else "no-trace"
+
+
+def set_trace_id(trace_id: str) -> None:
+ """设置当前请求的 TraceID"""
+ _trace_id_var.set(trace_id)
+
+
+def generate_trace_id() -> str:
+ """生成新的 TraceID,格式: 时间戳-随机8位"""
+ timestamp = int(time.time())
+ random_part = uuid.uuid4().hex[:8]
+ return f"{timestamp}-{random_part}"
+
+
+class TraceMiddleware(BaseHTTPMiddleware):
+ """TraceID 追踪中间件"""
+
+ def __init__(self, app, log_requests: bool = True):
+ super().__init__(app)
+ self.log_requests = log_requests
+
+ async def dispatch(self, request: Request, call_next) -> Response:
+ # 从请求头获取 TraceID,或生成新的
+ trace_id = (
+ request.headers.get(TRACE_ID_HEADER) or
+ request.headers.get(REQUEST_ID_HEADER) or
+ generate_trace_id()
+ )
+
+ set_trace_id(trace_id)
+
+ start_time = time.time()
+ method = request.method
+ path = request.url.path
+
+ if self.log_requests:
+ logger.info(f"[{trace_id}] --> {method} {path}")
+
+ try:
+ response = await call_next(request)
+ duration_ms = int((time.time() - start_time) * 1000)
+
+ response.headers[TRACE_ID_HEADER] = trace_id
+ response.headers["X-Response-Time"] = f"{duration_ms}ms"
+
+ if self.log_requests:
+ logger.info(f"[{trace_id}] <-- {response.status_code} ({duration_ms}ms)")
+
+ return response
+
+ except Exception as e:
+ duration_ms = int((time.time() - start_time) * 1000)
+ logger.error(f"[{trace_id}] !!! 请求异常: {e} ({duration_ms}ms)")
+ raise
+
+
+class TraceLogFilter(logging.Filter):
+ """日志过滤器:自动添加 TraceID"""
+
+ def filter(self, record):
+ record.trace_id = get_trace_id()
+ return True
+
+
+def setup_logging(level: int = logging.INFO, include_trace: bool = True):
+ """配置日志格式"""
+ if include_trace:
+ format_str = "%(asctime)s [%(trace_id)s] %(levelname)s %(name)s: %(message)s"
+ else:
+ format_str = "%(asctime)s %(levelname)s %(name)s: %(message)s"
+
+ handler = logging.StreamHandler()
+ handler.setFormatter(logging.Formatter(format_str, datefmt="%Y-%m-%d %H:%M:%S"))
+
+ if include_trace:
+ handler.addFilter(TraceLogFilter())
+
+ root_logger = logging.getLogger()
+ root_logger.setLevel(level)
+ root_logger.handlers = [handler]
diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py
index f610327..63c7c17 100644
--- a/backend/app/models/__init__.py
+++ b/backend/app/models/__init__.py
@@ -5,6 +5,8 @@ from .tenant_wechat_app import TenantWechatApp
from .app import App
from .stats import AICallEvent, TenantUsageDaily
from .logs import PlatformLog
+from .alert import AlertRule, AlertRecord, NotificationChannel
+from .pricing import ModelPricing, TenantBilling
__all__ = [
"Tenant",
@@ -15,5 +17,10 @@ __all__ = [
"App",
"AICallEvent",
"TenantUsageDaily",
- "PlatformLog"
+ "PlatformLog",
+ "AlertRule",
+ "AlertRecord",
+ "NotificationChannel",
+ "ModelPricing",
+ "TenantBilling"
]
diff --git a/backend/app/models/alert.py b/backend/app/models/alert.py
new file mode 100644
index 0000000..48e44de
--- /dev/null
+++ b/backend/app/models/alert.py
@@ -0,0 +1,108 @@
+"""告警相关模型"""
+from datetime import datetime
+from sqlalchemy import Column, Integer, BigInteger, String, Text, Enum, SmallInteger, JSON, TIMESTAMP
+from ..database import Base
+
+
+class AlertRule(Base):
+ """告警规则表"""
+ __tablename__ = "platform_alert_rules"
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ name = Column(String(100), nullable=False) # 规则名称
+ description = Column(Text) # 规则描述
+
+ # 规则类型
+ rule_type = Column(Enum(
+ 'error_rate', # 错误率告警
+ 'call_count', # 调用次数告警
+ 'token_usage', # Token使用量告警
+ 'cost_threshold', # 费用阈值告警
+ 'latency', # 延迟告警
+ 'custom' # 自定义告警
+ ), nullable=False)
+
+ # 作用范围
+ scope_type = Column(Enum('global', 'tenant', 'app'), default='global') # 作用范围类型
+ scope_value = Column(String(100)) # 作用范围值,如租户ID或应用代码
+
+ # 告警条件
+ condition = Column(JSON, nullable=False) # 告警条件配置
+ # 示例: {"metric": "error_count", "operator": ">", "threshold": 10, "window": "5m"}
+
+ # 通知配置
+ notification_channels = Column(JSON) # 通知渠道列表
+ # 示例: [{"type": "wechat_bot", "webhook": "https://..."}, {"type": "email", "to": ["a@b.com"]}]
+
+ # 告警限制
+ cooldown_minutes = Column(Integer, default=30) # 冷却时间(分钟),避免重复告警
+ max_alerts_per_day = Column(Integer, default=10) # 每天最大告警次数
+
+ # 状态
+ status = Column(SmallInteger, default=1) # 0-禁用 1-启用
+ priority = Column(Enum('low', 'medium', 'high', 'critical'), default='medium') # 优先级
+
+ created_at = Column(TIMESTAMP, default=datetime.now)
+ updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now)
+
+
+class AlertRecord(Base):
+ """告警记录表"""
+ __tablename__ = "platform_alert_records"
+
+ id = Column(BigInteger, primary_key=True, autoincrement=True)
+ rule_id = Column(Integer, nullable=False, index=True) # 关联的规则ID
+ rule_name = Column(String(100)) # 规则名称(冗余,便于查询)
+
+ # 告警信息
+ alert_type = Column(String(50), nullable=False) # 告警类型
+ severity = Column(Enum('info', 'warning', 'error', 'critical'), default='warning') # 严重程度
+ title = Column(String(200), nullable=False) # 告警标题
+ message = Column(Text) # 告警详情
+
+ # 上下文
+ tenant_id = Column(String(50), index=True) # 相关租户
+ app_code = Column(String(50)) # 相关应用
+ metric_value = Column(String(100)) # 触发告警的指标值
+ threshold_value = Column(String(100)) # 阈值
+
+ # 通知状态
+ notification_status = Column(Enum('pending', 'sent', 'failed', 'skipped'), default='pending')
+ notification_result = Column(JSON) # 通知结果
+ notified_at = Column(TIMESTAMP) # 通知时间
+
+ # 处理状态
+ status = Column(Enum('active', 'acknowledged', 'resolved', 'ignored'), default='active')
+ acknowledged_by = Column(String(100)) # 确认人
+ acknowledged_at = Column(TIMESTAMP) # 确认时间
+ resolved_at = Column(TIMESTAMP) # 解决时间
+
+ created_at = Column(TIMESTAMP, default=datetime.now, index=True)
+
+
+class NotificationChannel(Base):
+ """通知渠道配置表"""
+ __tablename__ = "platform_notification_channels"
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ name = Column(String(100), nullable=False) # 渠道名称
+
+ channel_type = Column(Enum(
+ 'wechat_bot', # 企微机器人
+ 'email', # 邮件
+ 'sms', # 短信
+ 'webhook', # Webhook
+ 'dingtalk' # 钉钉
+ ), nullable=False)
+
+ # 渠道配置
+ config = Column(JSON, nullable=False)
+ # wechat_bot: {"webhook": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx"}
+ # email: {"smtp_host": "...", "smtp_port": 465, "username": "...", "password_encrypted": "..."}
+ # webhook: {"url": "https://...", "method": "POST", "headers": {...}}
+
+ # 状态
+ status = Column(SmallInteger, default=1) # 0-禁用 1-启用
+
+ created_at = Column(TIMESTAMP, default=datetime.now)
+ updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now)
diff --git a/backend/app/models/pricing.py b/backend/app/models/pricing.py
new file mode 100644
index 0000000..683eee6
--- /dev/null
+++ b/backend/app/models/pricing.py
@@ -0,0 +1,70 @@
+"""费用计算相关模型"""
+from datetime import datetime
+from decimal import Decimal
+from sqlalchemy import Column, Integer, String, Text, DECIMAL, SmallInteger, JSON, TIMESTAMP
+from ..database import Base
+
+
+class ModelPricing(Base):
+ """模型价格配置表"""
+ __tablename__ = "platform_model_pricing"
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+
+ # 模型标识
+ model_name = Column(String(100), nullable=False, unique=True) # 模型名称,如 gpt-4, claude-3-opus
+ provider = Column(String(50)) # 提供商,如 openai, anthropic, 4sapi
+ display_name = Column(String(100)) # 显示名称
+
+ # 价格配置(单位:元/1K tokens)
+ input_price_per_1k = Column(DECIMAL(10, 6), default=0) # 输入价格
+ output_price_per_1k = Column(DECIMAL(10, 6), default=0) # 输出价格
+
+ # 或固定价格(每次调用)
+ fixed_price_per_call = Column(DECIMAL(10, 6), default=0)
+
+ # 计费方式
+ pricing_type = Column(String(20), default='token') # token / call / hybrid
+
+ # 备注
+ description = Column(Text)
+
+ status = Column(SmallInteger, default=1) # 0-禁用 1-启用
+ created_at = Column(TIMESTAMP, default=datetime.now)
+ updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now)
+
+
+class TenantBilling(Base):
+ """租户账单表(月度汇总)"""
+ __tablename__ = "platform_tenant_billing"
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+
+ tenant_id = Column(String(50), nullable=False, index=True)
+ billing_month = Column(String(7), nullable=False) # 格式: YYYY-MM
+
+ # 使用量统计
+ total_calls = Column(Integer, default=0) # 总调用次数
+ total_input_tokens = Column(Integer, default=0) # 总输入token
+ total_output_tokens = Column(Integer, default=0) # 总输出token
+
+ # 费用统计
+ total_cost = Column(DECIMAL(12, 4), default=0) # 总费用
+
+ # 按模型分类的费用明细
+ cost_by_model = Column(JSON) # {"gpt-4": 10.5, "claude-3": 5.2}
+
+ # 按应用分类的费用明细
+ cost_by_app = Column(JSON) # {"tools": 8.0, "interview": 7.7}
+
+ # 状态
+ status = Column(String(20), default='pending') # pending / confirmed / paid
+
+ created_at = Column(TIMESTAMP, default=datetime.now)
+ updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now)
+
+ class Config:
+ # 联合唯一索引
+ __table_args__ = (
+ {'mysql_charset': 'utf8mb4'}
+ )
diff --git a/backend/app/routers/alerts.py b/backend/app/routers/alerts.py
new file mode 100644
index 0000000..8a3b91e
--- /dev/null
+++ b/backend/app/routers/alerts.py
@@ -0,0 +1,430 @@
+"""告警管理路由"""
+from typing import Optional, List
+from datetime import datetime, timedelta
+from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks
+from pydantic import BaseModel
+from sqlalchemy.orm import Session
+from sqlalchemy import desc, func
+
+from ..database import get_db
+from ..models.alert import AlertRule, AlertRecord, NotificationChannel
+from ..services.alert import AlertService
+from .auth import get_current_user, require_operator
+from ..models.user import User
+
+router = APIRouter(prefix="/alerts", tags=["告警管理"])
+
+
+# ============= Schemas =============
+
+class AlertRuleCreate(BaseModel):
+ name: str
+ description: Optional[str] = None
+ rule_type: str
+ scope_type: str = "global"
+ scope_value: Optional[str] = None
+ condition: dict
+ notification_channels: Optional[List[dict]] = None
+ cooldown_minutes: int = 30
+ max_alerts_per_day: int = 10
+ priority: str = "medium"
+
+
+class AlertRuleUpdate(BaseModel):
+ name: Optional[str] = None
+ description: Optional[str] = None
+ condition: Optional[dict] = None
+ notification_channels: Optional[List[dict]] = None
+ cooldown_minutes: Optional[int] = None
+ max_alerts_per_day: Optional[int] = None
+ priority: Optional[str] = None
+ status: Optional[int] = None
+
+
+class NotificationChannelCreate(BaseModel):
+ name: str
+ channel_type: str
+ config: dict
+
+
+class NotificationChannelUpdate(BaseModel):
+ name: Optional[str] = None
+ config: Optional[dict] = None
+ status: Optional[int] = None
+
+
+# ============= Alert Rules API =============
+
+@router.get("/rules")
+async def list_alert_rules(
+ page: int = Query(1, ge=1),
+ size: int = Query(20, ge=1, le=100),
+ rule_type: Optional[str] = None,
+ status: Optional[int] = None,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取告警规则列表"""
+ query = db.query(AlertRule)
+
+ if rule_type:
+ query = query.filter(AlertRule.rule_type == rule_type)
+ if status is not None:
+ query = query.filter(AlertRule.status == status)
+
+ total = query.count()
+ rules = query.order_by(desc(AlertRule.created_at)).offset((page - 1) * size).limit(size).all()
+
+ return {
+ "total": total,
+ "page": page,
+ "size": size,
+ "items": [format_rule(r) for r in rules]
+ }
+
+
+@router.get("/rules/{rule_id}")
+async def get_alert_rule(
+ rule_id: int,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取告警规则详情"""
+ rule = db.query(AlertRule).filter(AlertRule.id == rule_id).first()
+ if not rule:
+ raise HTTPException(status_code=404, detail="告警规则不存在")
+ return format_rule(rule)
+
+
+@router.post("/rules")
+async def create_alert_rule(
+ data: AlertRuleCreate,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """创建告警规则"""
+ rule = AlertRule(
+ name=data.name,
+ description=data.description,
+ rule_type=data.rule_type,
+ scope_type=data.scope_type,
+ scope_value=data.scope_value,
+ condition=data.condition,
+ notification_channels=data.notification_channels,
+ cooldown_minutes=data.cooldown_minutes,
+ max_alerts_per_day=data.max_alerts_per_day,
+ priority=data.priority,
+ status=1
+ )
+ db.add(rule)
+ db.commit()
+ db.refresh(rule)
+
+ return {"success": True, "id": rule.id}
+
+
+@router.put("/rules/{rule_id}")
+async def update_alert_rule(
+ rule_id: int,
+ data: AlertRuleUpdate,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """更新告警规则"""
+ rule = db.query(AlertRule).filter(AlertRule.id == rule_id).first()
+ if not rule:
+ raise HTTPException(status_code=404, detail="告警规则不存在")
+
+ update_data = data.model_dump(exclude_unset=True)
+ for key, value in update_data.items():
+ setattr(rule, key, value)
+
+ db.commit()
+ return {"success": True}
+
+
+@router.delete("/rules/{rule_id}")
+async def delete_alert_rule(
+ rule_id: int,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """删除告警规则"""
+ rule = db.query(AlertRule).filter(AlertRule.id == rule_id).first()
+ if not rule:
+ raise HTTPException(status_code=404, detail="告警规则不存在")
+
+ db.delete(rule)
+ db.commit()
+ return {"success": True}
+
+
+# ============= Alert Records API =============
+
+@router.get("/records")
+async def list_alert_records(
+ page: int = Query(1, ge=1),
+ size: int = Query(20, ge=1, le=100),
+ status: Optional[str] = None,
+ severity: Optional[str] = None,
+ alert_type: Optional[str] = None,
+ tenant_id: Optional[str] = None,
+ start_date: Optional[str] = None,
+ end_date: Optional[str] = None,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取告警记录列表"""
+ query = db.query(AlertRecord)
+
+ if status:
+ query = query.filter(AlertRecord.status == status)
+ if severity:
+ query = query.filter(AlertRecord.severity == severity)
+ if alert_type:
+ query = query.filter(AlertRecord.alert_type == alert_type)
+ if tenant_id:
+ query = query.filter(AlertRecord.tenant_id == tenant_id)
+ if start_date:
+ query = query.filter(AlertRecord.created_at >= start_date)
+ if end_date:
+ query = query.filter(AlertRecord.created_at <= end_date + " 23:59:59")
+
+ total = query.count()
+ records = query.order_by(desc(AlertRecord.created_at)).offset((page - 1) * size).limit(size).all()
+
+ return {
+ "total": total,
+ "page": page,
+ "size": size,
+ "items": [format_record(r) for r in records]
+ }
+
+
+@router.get("/records/summary")
+async def get_alert_summary(
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取告警摘要统计"""
+ today = datetime.now().date()
+ week_start = today - timedelta(days=7)
+
+ # 今日告警数
+ today_count = db.query(func.count(AlertRecord.id)).filter(
+ func.date(AlertRecord.created_at) == today
+ ).scalar()
+
+ # 本周告警数
+ week_count = db.query(func.count(AlertRecord.id)).filter(
+ func.date(AlertRecord.created_at) >= week_start
+ ).scalar()
+
+ # 活跃告警数
+ active_count = db.query(func.count(AlertRecord.id)).filter(
+ AlertRecord.status == 'active'
+ ).scalar()
+
+ # 按严重程度统计
+ severity_stats = db.query(
+ AlertRecord.severity,
+ func.count(AlertRecord.id)
+ ).filter(
+ func.date(AlertRecord.created_at) >= week_start
+ ).group_by(AlertRecord.severity).all()
+
+ return {
+ "today_count": today_count,
+ "week_count": week_count,
+ "active_count": active_count,
+ "by_severity": {s: c for s, c in severity_stats}
+ }
+
+
+@router.get("/records/{record_id}")
+async def get_alert_record(
+ record_id: int,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取告警记录详情"""
+ record = db.query(AlertRecord).filter(AlertRecord.id == record_id).first()
+ if not record:
+ raise HTTPException(status_code=404, detail="告警记录不存在")
+ return format_record(record)
+
+
+@router.post("/records/{record_id}/acknowledge")
+async def acknowledge_alert(
+ record_id: int,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """确认告警"""
+ service = AlertService(db)
+ record = service.acknowledge_alert(record_id, user.username)
+ if not record:
+ raise HTTPException(status_code=404, detail="告警记录不存在")
+ return {"success": True}
+
+
+@router.post("/records/{record_id}/resolve")
+async def resolve_alert(
+ record_id: int,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """解决告警"""
+ service = AlertService(db)
+ record = service.resolve_alert(record_id)
+ if not record:
+ raise HTTPException(status_code=404, detail="告警记录不存在")
+ return {"success": True}
+
+
+# ============= Check Alerts API =============
+
+@router.post("/check")
+async def trigger_alert_check(
+ background_tasks: BackgroundTasks,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """手动触发告警检查"""
+ service = AlertService(db)
+ alerts = await service.check_all_rules()
+
+ # 异步发送通知
+ for alert in alerts:
+ rule = db.query(AlertRule).filter(AlertRule.id == alert.rule_id).first()
+ if rule:
+ background_tasks.add_task(service.send_notification, alert, rule)
+
+ return {
+ "success": True,
+ "triggered_count": len(alerts),
+ "alerts": [format_record(a) for a in alerts]
+ }
+
+
+# ============= Notification Channels API =============
+
+@router.get("/channels")
+async def list_notification_channels(
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取通知渠道列表"""
+ channels = db.query(NotificationChannel).order_by(desc(NotificationChannel.created_at)).all()
+ return [format_channel(c) for c in channels]
+
+
+@router.post("/channels")
+async def create_notification_channel(
+ data: NotificationChannelCreate,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """创建通知渠道"""
+ channel = NotificationChannel(
+ name=data.name,
+ channel_type=data.channel_type,
+ config=data.config,
+ status=1
+ )
+ db.add(channel)
+ db.commit()
+ db.refresh(channel)
+
+ return {"success": True, "id": channel.id}
+
+
+@router.put("/channels/{channel_id}")
+async def update_notification_channel(
+ channel_id: int,
+ data: NotificationChannelUpdate,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """更新通知渠道"""
+ channel = db.query(NotificationChannel).filter(NotificationChannel.id == channel_id).first()
+ if not channel:
+ raise HTTPException(status_code=404, detail="通知渠道不存在")
+
+ update_data = data.model_dump(exclude_unset=True)
+ for key, value in update_data.items():
+ setattr(channel, key, value)
+
+ db.commit()
+ return {"success": True}
+
+
+@router.delete("/channels/{channel_id}")
+async def delete_notification_channel(
+ channel_id: int,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """删除通知渠道"""
+ channel = db.query(NotificationChannel).filter(NotificationChannel.id == channel_id).first()
+ if not channel:
+ raise HTTPException(status_code=404, detail="通知渠道不存在")
+
+ db.delete(channel)
+ db.commit()
+ return {"success": True}
+
+
+# ============= Helper Functions =============
+
+def format_rule(rule: AlertRule) -> dict:
+ return {
+ "id": rule.id,
+ "name": rule.name,
+ "description": rule.description,
+ "rule_type": rule.rule_type,
+ "scope_type": rule.scope_type,
+ "scope_value": rule.scope_value,
+ "condition": rule.condition,
+ "notification_channels": rule.notification_channels,
+ "cooldown_minutes": rule.cooldown_minutes,
+ "max_alerts_per_day": rule.max_alerts_per_day,
+ "priority": rule.priority,
+ "status": rule.status,
+ "created_at": rule.created_at,
+ "updated_at": rule.updated_at
+ }
+
+
+def format_record(record: AlertRecord) -> dict:
+ return {
+ "id": record.id,
+ "rule_id": record.rule_id,
+ "rule_name": record.rule_name,
+ "alert_type": record.alert_type,
+ "severity": record.severity,
+ "title": record.title,
+ "message": record.message,
+ "tenant_id": record.tenant_id,
+ "app_code": record.app_code,
+ "metric_value": record.metric_value,
+ "threshold_value": record.threshold_value,
+ "notification_status": record.notification_status,
+ "status": record.status,
+ "acknowledged_by": record.acknowledged_by,
+ "acknowledged_at": record.acknowledged_at,
+ "resolved_at": record.resolved_at,
+ "created_at": record.created_at
+ }
+
+
+def format_channel(channel: NotificationChannel) -> dict:
+ return {
+ "id": channel.id,
+ "name": channel.name,
+ "channel_type": channel.channel_type,
+ "config": channel.config,
+ "status": channel.status,
+ "created_at": channel.created_at,
+ "updated_at": channel.updated_at
+ }
diff --git a/backend/app/routers/cost.py b/backend/app/routers/cost.py
new file mode 100644
index 0000000..239f9e4
--- /dev/null
+++ b/backend/app/routers/cost.py
@@ -0,0 +1,333 @@
+"""费用管理路由"""
+from typing import Optional, List
+from decimal import Decimal
+from fastapi import APIRouter, Depends, HTTPException, Query
+from pydantic import BaseModel
+from sqlalchemy.orm import Session
+from sqlalchemy import desc
+
+from ..database import get_db
+from ..models.pricing import ModelPricing, TenantBilling
+from ..services.cost import CostCalculator
+from .auth import get_current_user, require_operator
+from ..models.user import User
+
+router = APIRouter(prefix="/cost", tags=["费用管理"])
+
+
+# ============= Schemas =============
+
+class ModelPricingCreate(BaseModel):
+ model_name: str
+ provider: Optional[str] = None
+ display_name: Optional[str] = None
+ input_price_per_1k: float = 0
+ output_price_per_1k: float = 0
+ fixed_price_per_call: float = 0
+ pricing_type: str = "token"
+ description: Optional[str] = None
+
+
+class ModelPricingUpdate(BaseModel):
+ provider: Optional[str] = None
+ display_name: Optional[str] = None
+ input_price_per_1k: Optional[float] = None
+ output_price_per_1k: Optional[float] = None
+ fixed_price_per_call: Optional[float] = None
+ pricing_type: Optional[str] = None
+ description: Optional[str] = None
+ status: Optional[int] = None
+
+
+class CostCalculateRequest(BaseModel):
+ model_name: str
+ input_tokens: int = 0
+ output_tokens: int = 0
+
+
+# ============= Model Pricing API =============
+
+@router.get("/pricing")
+async def list_model_pricing(
+ page: int = Query(1, ge=1),
+ size: int = Query(50, ge=1, le=100),
+ provider: Optional[str] = None,
+ status: Optional[int] = None,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取模型价格配置列表"""
+ query = db.query(ModelPricing)
+
+ if provider:
+ query = query.filter(ModelPricing.provider == provider)
+ if status is not None:
+ query = query.filter(ModelPricing.status == status)
+
+ total = query.count()
+ items = query.order_by(ModelPricing.model_name).offset((page - 1) * size).limit(size).all()
+
+ return {
+ "total": total,
+ "page": page,
+ "size": size,
+ "items": [format_pricing(p) for p in items]
+ }
+
+
+@router.get("/pricing/{pricing_id}")
+async def get_model_pricing(
+ pricing_id: int,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取模型价格详情"""
+ pricing = db.query(ModelPricing).filter(ModelPricing.id == pricing_id).first()
+ if not pricing:
+ raise HTTPException(status_code=404, detail="模型价格配置不存在")
+ return format_pricing(pricing)
+
+
+@router.post("/pricing")
+async def create_model_pricing(
+ data: ModelPricingCreate,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """创建模型价格配置"""
+ # 检查是否已存在
+ existing = db.query(ModelPricing).filter(ModelPricing.model_name == data.model_name).first()
+ if existing:
+ raise HTTPException(status_code=400, detail="该模型价格配置已存在")
+
+ pricing = ModelPricing(
+ model_name=data.model_name,
+ provider=data.provider,
+ display_name=data.display_name,
+ input_price_per_1k=Decimal(str(data.input_price_per_1k)),
+ output_price_per_1k=Decimal(str(data.output_price_per_1k)),
+ fixed_price_per_call=Decimal(str(data.fixed_price_per_call)),
+ pricing_type=data.pricing_type,
+ description=data.description,
+ status=1
+ )
+ db.add(pricing)
+ db.commit()
+ db.refresh(pricing)
+
+ return {"success": True, "id": pricing.id}
+
+
+@router.put("/pricing/{pricing_id}")
+async def update_model_pricing(
+ pricing_id: int,
+ data: ModelPricingUpdate,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """更新模型价格配置"""
+ pricing = db.query(ModelPricing).filter(ModelPricing.id == pricing_id).first()
+ if not pricing:
+ raise HTTPException(status_code=404, detail="模型价格配置不存在")
+
+ update_data = data.model_dump(exclude_unset=True)
+
+ # 转换价格字段
+ for field in ['input_price_per_1k', 'output_price_per_1k', 'fixed_price_per_call']:
+ if field in update_data and update_data[field] is not None:
+ update_data[field] = Decimal(str(update_data[field]))
+
+ for key, value in update_data.items():
+ setattr(pricing, key, value)
+
+ db.commit()
+ return {"success": True}
+
+
+@router.delete("/pricing/{pricing_id}")
+async def delete_model_pricing(
+ pricing_id: int,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """删除模型价格配置"""
+ pricing = db.query(ModelPricing).filter(ModelPricing.id == pricing_id).first()
+ if not pricing:
+ raise HTTPException(status_code=404, detail="模型价格配置不存在")
+
+ db.delete(pricing)
+ db.commit()
+ return {"success": True}
+
+
+# ============= Cost Calculation API =============
+
+@router.post("/calculate")
+async def calculate_cost(
+ request: CostCalculateRequest,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """计算调用费用"""
+ calculator = CostCalculator(db)
+ cost = calculator.calculate_cost(
+ model_name=request.model_name,
+ input_tokens=request.input_tokens,
+ output_tokens=request.output_tokens
+ )
+
+ return {
+ "model": request.model_name,
+ "input_tokens": request.input_tokens,
+ "output_tokens": request.output_tokens,
+ "cost": float(cost)
+ }
+
+
+@router.get("/summary")
+async def get_cost_summary(
+ tenant_id: Optional[str] = None,
+ start_date: Optional[str] = None,
+ end_date: Optional[str] = None,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取费用汇总"""
+ calculator = CostCalculator(db)
+ return calculator.get_cost_summary(
+ tenant_id=tenant_id,
+ start_date=start_date,
+ end_date=end_date
+ )
+
+
+@router.get("/by-tenant")
+async def get_cost_by_tenant(
+ start_date: Optional[str] = None,
+ end_date: Optional[str] = None,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """按租户统计费用"""
+ calculator = CostCalculator(db)
+ return calculator.get_cost_by_tenant(
+ start_date=start_date,
+ end_date=end_date
+ )
+
+
+@router.get("/by-model")
+async def get_cost_by_model(
+ tenant_id: Optional[str] = None,
+ start_date: Optional[str] = None,
+ end_date: Optional[str] = None,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """按模型统计费用"""
+ calculator = CostCalculator(db)
+ return calculator.get_cost_by_model(
+ tenant_id=tenant_id,
+ start_date=start_date,
+ end_date=end_date
+ )
+
+
+# ============= Billing API =============
+
+@router.get("/billing")
+async def list_billing(
+ page: int = Query(1, ge=1),
+ size: int = Query(20, ge=1, le=100),
+ tenant_id: Optional[str] = None,
+ billing_month: Optional[str] = None,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取账单列表"""
+ query = db.query(TenantBilling)
+
+ if tenant_id:
+ query = query.filter(TenantBilling.tenant_id == tenant_id)
+ if billing_month:
+ query = query.filter(TenantBilling.billing_month == billing_month)
+
+ total = query.count()
+ items = query.order_by(desc(TenantBilling.billing_month)).offset((page - 1) * size).limit(size).all()
+
+ return {
+ "total": total,
+ "page": page,
+ "size": size,
+ "items": [format_billing(b) for b in items]
+ }
+
+
+@router.post("/billing/generate")
+async def generate_billing(
+ tenant_id: str = Query(...),
+ billing_month: str = Query(..., description="格式: YYYY-MM"),
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """生成月度账单"""
+ calculator = CostCalculator(db)
+ billing = calculator.generate_monthly_billing(tenant_id, billing_month)
+
+ return {
+ "success": True,
+ "billing": format_billing(billing)
+ }
+
+
+@router.post("/recalculate")
+async def recalculate_costs(
+ start_date: Optional[str] = None,
+ end_date: Optional[str] = None,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """重新计算事件费用"""
+ calculator = CostCalculator(db)
+ updated = calculator.update_event_costs(start_date, end_date)
+
+ return {
+ "success": True,
+ "updated_count": updated
+ }
+
+
+# ============= Helper Functions =============
+
+def format_pricing(pricing: ModelPricing) -> dict:
+ return {
+ "id": pricing.id,
+ "model_name": pricing.model_name,
+ "provider": pricing.provider,
+ "display_name": pricing.display_name,
+ "input_price_per_1k": float(pricing.input_price_per_1k or 0),
+ "output_price_per_1k": float(pricing.output_price_per_1k or 0),
+ "fixed_price_per_call": float(pricing.fixed_price_per_call or 0),
+ "pricing_type": pricing.pricing_type,
+ "description": pricing.description,
+ "status": pricing.status,
+ "created_at": pricing.created_at,
+ "updated_at": pricing.updated_at
+ }
+
+
+def format_billing(billing: TenantBilling) -> dict:
+ return {
+ "id": billing.id,
+ "tenant_id": billing.tenant_id,
+ "billing_month": billing.billing_month,
+ "total_calls": billing.total_calls,
+ "total_input_tokens": billing.total_input_tokens,
+ "total_output_tokens": billing.total_output_tokens,
+ "total_cost": float(billing.total_cost or 0),
+ "cost_by_model": billing.cost_by_model,
+ "cost_by_app": billing.cost_by_app,
+ "status": billing.status,
+ "created_at": billing.created_at,
+ "updated_at": billing.updated_at
+ }
diff --git a/backend/app/routers/logs.py b/backend/app/routers/logs.py
index d64458c..d2da06a 100644
--- a/backend/app/routers/logs.py
+++ b/backend/app/routers/logs.py
@@ -1,6 +1,10 @@
"""日志路由"""
+import csv
+import io
from typing import Optional
+from datetime import datetime
from fastapi import APIRouter, Depends, Header, HTTPException, Query
+from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from sqlalchemy import desc
@@ -13,6 +17,14 @@ from ..services.auth import decode_token
router = APIRouter(prefix="/logs", tags=["logs"])
settings = get_settings()
+# 尝试导入openpyxl
+try:
+ from openpyxl import Workbook
+ from openpyxl.styles import Font, Alignment, PatternFill
+ OPENPYXL_AVAILABLE = True
+except ImportError:
+ OPENPYXL_AVAILABLE = False
+
def get_current_user_optional(authorization: Optional[str] = Header(None)):
"""可选的用户认证"""
@@ -113,3 +125,154 @@ async def query_logs(
for item in items
]
}
+
+
+@router.get("/export")
+async def export_logs(
+ format: str = Query("csv", description="导出格式: csv 或 excel"),
+ log_type: Optional[str] = None,
+ level: Optional[str] = None,
+ app_code: Optional[str] = None,
+ tenant_id: Optional[str] = None,
+ start_date: Optional[str] = None,
+ end_date: Optional[str] = None,
+ limit: int = Query(10000, ge=1, le=100000, description="最大导出记录数"),
+ db: Session = Depends(get_db),
+ user = Depends(get_current_user_optional)
+):
+ """导出日志
+
+ 支持CSV和Excel格式,最多导出10万条记录
+ """
+ query = db.query(PlatformLog)
+
+ if log_type:
+ query = query.filter(PlatformLog.log_type == log_type)
+ if level:
+ query = query.filter(PlatformLog.level == level)
+ if app_code:
+ query = query.filter(PlatformLog.app_code == app_code)
+ if tenant_id:
+ query = query.filter(PlatformLog.tenant_id == tenant_id)
+ if start_date:
+ query = query.filter(PlatformLog.log_time >= start_date)
+ if end_date:
+ query = query.filter(PlatformLog.log_time <= end_date + " 23:59:59")
+
+ items = query.order_by(desc(PlatformLog.log_time)).limit(limit).all()
+
+ if format.lower() == "excel":
+ return export_excel(items)
+ else:
+ return export_csv(items)
+
+
+def export_csv(logs: list) -> StreamingResponse:
+ """导出为CSV格式"""
+ output = io.StringIO()
+ writer = csv.writer(output)
+
+ # 写入表头
+ headers = [
+ "ID", "类型", "级别", "应用", "租户", "Trace ID",
+ "消息", "路径", "方法", "状态码", "耗时(ms)",
+ "IP地址", "时间"
+ ]
+ writer.writerow(headers)
+
+ # 写入数据
+ for log in logs:
+ writer.writerow([
+ log.id,
+ log.log_type,
+ log.level,
+ log.app_code or "",
+ log.tenant_id or "",
+ log.trace_id or "",
+ log.message or "",
+ log.path or "",
+ log.method or "",
+ log.status_code or "",
+ log.duration_ms or "",
+ log.ip_address or "",
+ str(log.log_time) if log.log_time else ""
+ ])
+
+ output.seek(0)
+
+ # 生成文件名
+ filename = f"logs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
+
+ return StreamingResponse(
+ iter([output.getvalue()]),
+ media_type="text/csv",
+ headers={
+ "Content-Disposition": f'attachment; filename="{filename}"',
+ "Content-Type": "text/csv; charset=utf-8-sig"
+ }
+ )
+
+
+def export_excel(logs: list) -> StreamingResponse:
+ """导出为Excel格式"""
+ if not OPENPYXL_AVAILABLE:
+ raise HTTPException(status_code=400, detail="Excel导出功能不可用,请安装openpyxl")
+
+ wb = Workbook()
+ ws = wb.active
+ ws.title = "日志导出"
+
+ # 表头样式
+ header_font = Font(bold=True, color="FFFFFF")
+ header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
+ header_alignment = Alignment(horizontal="center", vertical="center")
+
+ # 写入表头
+ headers = [
+ "ID", "类型", "级别", "应用", "租户", "Trace ID",
+ "消息", "路径", "方法", "状态码", "耗时(ms)",
+ "IP地址", "时间"
+ ]
+
+ for col, header in enumerate(headers, 1):
+ cell = ws.cell(row=1, column=col, value=header)
+ cell.font = header_font
+ cell.fill = header_fill
+ cell.alignment = header_alignment
+
+ # 写入数据
+ for row, log in enumerate(logs, 2):
+ ws.cell(row=row, column=1, value=log.id)
+ ws.cell(row=row, column=2, value=log.log_type)
+ ws.cell(row=row, column=3, value=log.level)
+ ws.cell(row=row, column=4, value=log.app_code or "")
+ ws.cell(row=row, column=5, value=log.tenant_id or "")
+ ws.cell(row=row, column=6, value=log.trace_id or "")
+ ws.cell(row=row, column=7, value=log.message or "")
+ ws.cell(row=row, column=8, value=log.path or "")
+ ws.cell(row=row, column=9, value=log.method or "")
+ ws.cell(row=row, column=10, value=log.status_code or "")
+ ws.cell(row=row, column=11, value=log.duration_ms or "")
+ ws.cell(row=row, column=12, value=log.ip_address or "")
+ ws.cell(row=row, column=13, value=str(log.log_time) if log.log_time else "")
+
+ # 调整列宽
+ column_widths = [8, 10, 10, 12, 12, 36, 50, 30, 8, 10, 10, 15, 20]
+ for col, width in enumerate(column_widths, 1):
+ ws.column_dimensions[chr(64 + col)].width = width
+
+ # 保存到内存
+ output = io.BytesIO()
+ wb.save(output)
+ output.seek(0)
+
+ # 生成文件名
+ filename = f"logs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx"
+
+ return StreamingResponse(
+ iter([output.getvalue()]),
+ media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+ headers={
+ "Content-Disposition": f'attachment; filename="{filename}"'
+ }
+ )
diff --git a/backend/app/routers/quota.py b/backend/app/routers/quota.py
new file mode 100644
index 0000000..3be7838
--- /dev/null
+++ b/backend/app/routers/quota.py
@@ -0,0 +1,264 @@
+"""配额管理路由"""
+from typing import Optional, Dict, Any
+from datetime import date
+from fastapi import APIRouter, Depends, HTTPException, Query
+from pydantic import BaseModel
+from sqlalchemy.orm import Session
+from sqlalchemy import desc
+
+from ..database import get_db
+from ..models.tenant import Subscription
+from ..services.quota import QuotaService
+from .auth import get_current_user, require_operator
+from ..models.user import User
+
+router = APIRouter(prefix="/quota", tags=["配额管理"])
+
+
+# ============= Schemas =============
+
+class QuotaConfigUpdate(BaseModel):
+ daily_calls: int = 0
+ daily_tokens: int = 0
+ monthly_calls: int = 0
+ monthly_tokens: int = 0
+ monthly_cost: float = 0
+ concurrent_calls: int = 0
+
+
+class SubscriptionCreate(BaseModel):
+ tenant_id: str
+ app_code: str
+ start_date: Optional[str] = None
+ end_date: Optional[str] = None
+ quota: QuotaConfigUpdate
+
+
+class SubscriptionUpdate(BaseModel):
+ start_date: Optional[str] = None
+ end_date: Optional[str] = None
+ quota: Optional[QuotaConfigUpdate] = None
+ status: Optional[str] = None
+
+
+# ============= Quota Check API =============
+
+@router.get("/check")
+async def check_quota(
+ tenant_id: str = Query(..., alias="tid"),
+ app_code: str = Query(..., alias="aid"),
+ estimated_tokens: int = Query(0),
+ db: Session = Depends(get_db)
+):
+ """检查配额是否足够
+
+ 用于调用前检查,返回是否允许继续调用
+ """
+ service = QuotaService(db)
+ result = service.check_quota(tenant_id, app_code, estimated_tokens)
+
+ return {
+ "allowed": result.allowed,
+ "reason": result.reason,
+ "quota_type": result.quota_type,
+ "limit": result.limit,
+ "used": result.used,
+ "remaining": result.remaining
+ }
+
+
+@router.get("/summary")
+async def get_quota_summary(
+ tenant_id: str = Query(...),
+ app_code: str = Query(...),
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取配额使用汇总"""
+ service = QuotaService(db)
+ return service.get_quota_summary(tenant_id, app_code)
+
+
+@router.get("/usage")
+async def get_quota_usage(
+ tenant_id: str = Query(...),
+ app_code: str = Query(...),
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取配额使用情况"""
+ service = QuotaService(db)
+ usage = service.get_usage(tenant_id, app_code)
+
+ return {
+ "daily_calls": usage.daily_calls,
+ "daily_tokens": usage.daily_tokens,
+ "monthly_calls": usage.monthly_calls,
+ "monthly_tokens": usage.monthly_tokens,
+ "monthly_cost": round(usage.monthly_cost, 2)
+ }
+
+
+# ============= Subscription API =============
+
+@router.get("/subscriptions")
+async def list_subscriptions(
+ page: int = Query(1, ge=1),
+ size: int = Query(20, ge=1, le=100),
+ tenant_id: Optional[str] = None,
+ app_code: Optional[str] = None,
+ status: Optional[str] = None,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取订阅列表"""
+ query = db.query(Subscription)
+
+ if tenant_id:
+ query = query.filter(Subscription.tenant_id == tenant_id)
+ if app_code:
+ query = query.filter(Subscription.app_code == app_code)
+ if status:
+ query = query.filter(Subscription.status == status)
+
+ total = query.count()
+ items = query.order_by(desc(Subscription.created_at)).offset((page - 1) * size).limit(size).all()
+
+ return {
+ "total": total,
+ "page": page,
+ "size": size,
+ "items": [format_subscription(s) for s in items]
+ }
+
+
+@router.get("/subscriptions/{subscription_id}")
+async def get_subscription(
+ subscription_id: int,
+ user: User = Depends(get_current_user),
+ db: Session = Depends(get_db)
+):
+ """获取订阅详情"""
+ subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first()
+ if not subscription:
+ raise HTTPException(status_code=404, detail="订阅不存在")
+ return format_subscription(subscription)
+
+
+@router.post("/subscriptions")
+async def create_subscription(
+ data: SubscriptionCreate,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """创建订阅"""
+ # 检查是否已存在
+ existing = db.query(Subscription).filter(
+ Subscription.tenant_id == data.tenant_id,
+ Subscription.app_code == data.app_code,
+ Subscription.status == 'active'
+ ).first()
+
+ if existing:
+ raise HTTPException(status_code=400, detail="该租户应用已有活跃订阅")
+
+ subscription = Subscription(
+ tenant_id=data.tenant_id,
+ app_code=data.app_code,
+ start_date=data.start_date or date.today(),
+ end_date=data.end_date,
+ quota=data.quota.model_dump() if data.quota else {},
+ status='active'
+ )
+ db.add(subscription)
+ db.commit()
+ db.refresh(subscription)
+
+ return {"success": True, "id": subscription.id}
+
+
+@router.put("/subscriptions/{subscription_id}")
+async def update_subscription(
+ subscription_id: int,
+ data: SubscriptionUpdate,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """更新订阅"""
+ subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first()
+ if not subscription:
+ raise HTTPException(status_code=404, detail="订阅不存在")
+
+ if data.start_date:
+ subscription.start_date = data.start_date
+ if data.end_date:
+ subscription.end_date = data.end_date
+ if data.quota:
+ subscription.quota = data.quota.model_dump()
+ if data.status:
+ subscription.status = data.status
+
+ db.commit()
+
+ # 清除缓存
+ service = QuotaService(db)
+ cache_key = f"quota:config:{subscription.tenant_id}:{subscription.app_code}"
+ service._cache.delete(cache_key)
+
+ return {"success": True}
+
+
+@router.delete("/subscriptions/{subscription_id}")
+async def delete_subscription(
+ subscription_id: int,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """删除订阅"""
+ subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first()
+ if not subscription:
+ raise HTTPException(status_code=404, detail="订阅不存在")
+
+ db.delete(subscription)
+ db.commit()
+
+ return {"success": True}
+
+
+@router.put("/subscriptions/{subscription_id}/quota")
+async def update_quota(
+ subscription_id: int,
+ data: QuotaConfigUpdate,
+ user: User = Depends(require_operator),
+ db: Session = Depends(get_db)
+):
+ """更新订阅配额"""
+ subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first()
+ if not subscription:
+ raise HTTPException(status_code=404, detail="订阅不存在")
+
+ subscription.quota = data.model_dump()
+ db.commit()
+
+ # 清除缓存
+ service = QuotaService(db)
+ cache_key = f"quota:config:{subscription.tenant_id}:{subscription.app_code}"
+ service._cache.delete(cache_key)
+
+ return {"success": True}
+
+
+# ============= Helper Functions =============
+
+def format_subscription(subscription: Subscription) -> dict:
+ return {
+ "id": subscription.id,
+ "tenant_id": subscription.tenant_id,
+ "app_code": subscription.app_code,
+ "start_date": str(subscription.start_date) if subscription.start_date else None,
+ "end_date": str(subscription.end_date) if subscription.end_date else None,
+ "quota": subscription.quota or {},
+ "status": subscription.status,
+ "created_at": subscription.created_at,
+ "updated_at": subscription.updated_at
+ }
diff --git a/backend/app/routers/wechat.py b/backend/app/routers/wechat.py
new file mode 100644
index 0000000..e98490b
--- /dev/null
+++ b/backend/app/routers/wechat.py
@@ -0,0 +1,264 @@
+"""企业微信JS-SDK路由"""
+from typing import Optional
+from fastapi import APIRouter, Depends, HTTPException, Query
+from pydantic import BaseModel
+from sqlalchemy.orm import Session
+
+from ..database import get_db
+from ..models.tenant_app import TenantApp
+from ..models.tenant_wechat_app import TenantWechatApp
+from ..services.wechat import WechatService, get_wechat_service_by_id
+
+router = APIRouter(prefix="/wechat", tags=["企业微信"])
+
+
+class JssdkSignatureRequest(BaseModel):
+ """JS-SDK签名请求"""
+ url: str # 当前页面URL(不含#及其后面部分)
+
+
+class JssdkSignatureResponse(BaseModel):
+ """JS-SDK签名响应"""
+ appId: str
+ agentId: str
+ timestamp: int
+ nonceStr: str
+ signature: str
+
+
+class OAuth2UrlRequest(BaseModel):
+ """OAuth2授权URL请求"""
+ redirect_uri: str
+ scope: str = "snsapi_base"
+ state: str = ""
+
+
+class UserInfoRequest(BaseModel):
+ """用户信息请求"""
+ code: str
+
+
+@router.post("/jssdk/signature")
+async def get_jssdk_signature(
+ request: JssdkSignatureRequest,
+ tenant_id: str = Query(..., alias="tid"),
+ app_code: str = Query(..., alias="aid"),
+ db: Session = Depends(get_db)
+):
+ """获取JS-SDK签名
+
+ 用于前端初始化企业微信JS-SDK
+
+ Args:
+ request: 包含当前页面URL
+ tenant_id: 租户ID
+ app_code: 应用代码
+
+ Returns:
+ JS-SDK签名信息
+ """
+ # 查找租户应用配置
+ tenant_app = db.query(TenantApp).filter(
+ TenantApp.tenant_id == tenant_id,
+ TenantApp.app_code == app_code,
+ TenantApp.status == 1
+ ).first()
+
+ if not tenant_app:
+ raise HTTPException(status_code=404, detail="租户应用配置不存在")
+
+ if not tenant_app.wechat_app_id:
+ raise HTTPException(status_code=400, detail="该应用未配置企业微信")
+
+ # 获取企微服务
+ wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
+ if not wechat_service:
+ raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
+
+ # 生成签名
+ signature_data = await wechat_service.get_jssdk_signature(request.url)
+ if not signature_data:
+ raise HTTPException(status_code=500, detail="获取JS-SDK签名失败")
+
+ return signature_data
+
+
+@router.get("/jssdk/signature")
+async def get_jssdk_signature_get(
+ url: str = Query(..., description="当前页面URL"),
+ tenant_id: str = Query(..., alias="tid"),
+ app_code: str = Query(..., alias="aid"),
+ db: Session = Depends(get_db)
+):
+ """获取JS-SDK签名(GET方式)
+
+ 方便前端JSONP调用
+ """
+ # 查找租户应用配置
+ tenant_app = db.query(TenantApp).filter(
+ TenantApp.tenant_id == tenant_id,
+ TenantApp.app_code == app_code,
+ TenantApp.status == 1
+ ).first()
+
+ if not tenant_app:
+ raise HTTPException(status_code=404, detail="租户应用配置不存在")
+
+ if not tenant_app.wechat_app_id:
+ raise HTTPException(status_code=400, detail="该应用未配置企业微信")
+
+ # 获取企微服务
+ wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
+ if not wechat_service:
+ raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
+
+ # 生成签名
+ signature_data = await wechat_service.get_jssdk_signature(url)
+ if not signature_data:
+ raise HTTPException(status_code=500, detail="获取JS-SDK签名失败")
+
+ return signature_data
+
+
+@router.post("/oauth2/url")
+async def get_oauth2_url(
+ request: OAuth2UrlRequest,
+ tenant_id: str = Query(..., alias="tid"),
+ app_code: str = Query(..., alias="aid"),
+ db: Session = Depends(get_db)
+):
+ """获取OAuth2授权URL
+
+ 用于企业微信内网页获取用户身份
+ """
+ # 查找租户应用配置
+ tenant_app = db.query(TenantApp).filter(
+ TenantApp.tenant_id == tenant_id,
+ TenantApp.app_code == app_code,
+ TenantApp.status == 1
+ ).first()
+
+ if not tenant_app:
+ raise HTTPException(status_code=404, detail="租户应用配置不存在")
+
+ if not tenant_app.wechat_app_id:
+ raise HTTPException(status_code=400, detail="该应用未配置企业微信")
+
+ # 获取企微服务
+ wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
+ if not wechat_service:
+ raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
+
+ # 生成OAuth2 URL
+ oauth_url = wechat_service.get_oauth2_url(
+ redirect_uri=request.redirect_uri,
+ scope=request.scope,
+ state=request.state
+ )
+
+ return {"url": oauth_url}
+
+
+@router.post("/oauth2/userinfo")
+async def get_user_info(
+ request: UserInfoRequest,
+ tenant_id: str = Query(..., alias="tid"),
+ app_code: str = Query(..., alias="aid"),
+ db: Session = Depends(get_db)
+):
+ """通过OAuth2 code获取用户信息
+
+ 在OAuth2回调后,用code换取用户信息
+ """
+ # 查找租户应用配置
+ tenant_app = db.query(TenantApp).filter(
+ TenantApp.tenant_id == tenant_id,
+ TenantApp.app_code == app_code,
+ TenantApp.status == 1
+ ).first()
+
+ if not tenant_app:
+ raise HTTPException(status_code=404, detail="租户应用配置不存在")
+
+ if not tenant_app.wechat_app_id:
+ raise HTTPException(status_code=400, detail="该应用未配置企业微信")
+
+ # 获取企微服务
+ wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
+ if not wechat_service:
+ raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
+
+ # 获取用户信息
+ user_info = await wechat_service.get_user_info_by_code(request.code)
+ if not user_info:
+ raise HTTPException(status_code=400, detail="获取用户信息失败,code可能已过期")
+
+ return user_info
+
+
+@router.get("/oauth2/userinfo")
+async def get_user_info_get(
+ code: str = Query(..., description="OAuth2回调的code"),
+ tenant_id: str = Query(..., alias="tid"),
+ app_code: str = Query(..., alias="aid"),
+ db: Session = Depends(get_db)
+):
+ """通过OAuth2 code获取用户信息(GET方式)"""
+ # 查找租户应用配置
+ tenant_app = db.query(TenantApp).filter(
+ TenantApp.tenant_id == tenant_id,
+ TenantApp.app_code == app_code,
+ TenantApp.status == 1
+ ).first()
+
+ if not tenant_app:
+ raise HTTPException(status_code=404, detail="租户应用配置不存在")
+
+ if not tenant_app.wechat_app_id:
+ raise HTTPException(status_code=400, detail="该应用未配置企业微信")
+
+ # 获取企微服务
+ wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
+ if not wechat_service:
+ raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
+
+ # 获取用户信息
+ user_info = await wechat_service.get_user_info_by_code(code)
+ if not user_info:
+ raise HTTPException(status_code=400, detail="获取用户信息失败,code可能已过期")
+
+ return user_info
+
+
+@router.get("/user/{user_id}")
+async def get_user_detail(
+ user_id: str,
+ tenant_id: str = Query(..., alias="tid"),
+ app_code: str = Query(..., alias="aid"),
+ db: Session = Depends(get_db)
+):
+ """获取企业微信成员详细信息"""
+ # 查找租户应用配置
+ tenant_app = db.query(TenantApp).filter(
+ TenantApp.tenant_id == tenant_id,
+ TenantApp.app_code == app_code,
+ TenantApp.status == 1
+ ).first()
+
+ if not tenant_app:
+ raise HTTPException(status_code=404, detail="租户应用配置不存在")
+
+ if not tenant_app.wechat_app_id:
+ raise HTTPException(status_code=400, detail="该应用未配置企业微信")
+
+ # 获取企微服务
+ wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db)
+ if not wechat_service:
+ raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用")
+
+ # 获取用户详情
+ user_detail = await wechat_service.get_user_detail(user_id)
+ if not user_detail:
+ raise HTTPException(status_code=404, detail="用户不存在")
+
+ return user_detail
diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py
index 60261bc..615aa2b 100644
--- a/backend/app/services/__init__.py
+++ b/backend/app/services/__init__.py
@@ -1,4 +1,22 @@
"""业务服务"""
from .crypto import encrypt_value, decrypt_value
+from .cache import CacheService, get_cache, get_redis_client
+from .wechat import WechatService, get_wechat_service_by_id
+from .alert import AlertService
+from .cost import CostCalculator, calculate_cost
+from .quota import QuotaService, check_quota_middleware
-__all__ = ["encrypt_value", "decrypt_value"]
+__all__ = [
+ "encrypt_value",
+ "decrypt_value",
+ "CacheService",
+ "get_cache",
+ "get_redis_client",
+ "WechatService",
+ "get_wechat_service_by_id",
+ "AlertService",
+ "CostCalculator",
+ "calculate_cost",
+ "QuotaService",
+ "check_quota_middleware"
+]
diff --git a/backend/app/services/alert.py b/backend/app/services/alert.py
new file mode 100644
index 0000000..b6b79ed
--- /dev/null
+++ b/backend/app/services/alert.py
@@ -0,0 +1,455 @@
+"""告警服务"""
+import logging
+from datetime import datetime, timedelta
+from typing import Optional, List, Dict, Any
+
+import httpx
+from sqlalchemy.orm import Session
+from sqlalchemy import func
+
+from ..models.alert import AlertRule, AlertRecord, NotificationChannel
+from ..models.stats import AICallEvent
+from ..models.logs import PlatformLog
+from .cache import get_cache
+
+logger = logging.getLogger(__name__)
+
+
+class AlertService:
+ """告警服务
+
+ 提供告警规则检测、告警记录管理、通知发送等功能
+ """
+
+ def __init__(self, db: Session):
+ self.db = db
+ self._cache = get_cache()
+
+ async def check_all_rules(self) -> List[AlertRecord]:
+ """检查所有启用的告警规则
+
+ Returns:
+ 触发的告警记录列表
+ """
+ rules = self.db.query(AlertRule).filter(AlertRule.status == 1).all()
+ triggered_alerts = []
+
+ for rule in rules:
+ try:
+ alert = await self.check_rule(rule)
+ if alert:
+ triggered_alerts.append(alert)
+ except Exception as e:
+ logger.error(f"Failed to check rule {rule.id}: {e}")
+
+ return triggered_alerts
+
+ async def check_rule(self, rule: AlertRule) -> Optional[AlertRecord]:
+ """检查单个告警规则
+
+ Args:
+ rule: 告警规则
+
+ Returns:
+ 触发的告警记录或None
+ """
+ # 检查冷却期
+ if self._is_in_cooldown(rule):
+ logger.debug(f"Rule {rule.id} is in cooldown")
+ return None
+
+ # 检查每日告警次数限制
+ if self._exceeds_daily_limit(rule):
+ logger.debug(f"Rule {rule.id} exceeds daily limit")
+ return None
+
+ # 根据规则类型检查
+ metric_value = None
+ threshold_value = None
+ triggered = False
+
+ condition = rule.condition or {}
+
+ if rule.rule_type == 'error_rate':
+ triggered, metric_value, threshold_value = self._check_error_rate(rule, condition)
+ elif rule.rule_type == 'call_count':
+ triggered, metric_value, threshold_value = self._check_call_count(rule, condition)
+ elif rule.rule_type == 'token_usage':
+ triggered, metric_value, threshold_value = self._check_token_usage(rule, condition)
+ elif rule.rule_type == 'cost_threshold':
+ triggered, metric_value, threshold_value = self._check_cost_threshold(rule, condition)
+ elif rule.rule_type == 'latency':
+ triggered, metric_value, threshold_value = self._check_latency(rule, condition)
+
+ if triggered:
+ alert = self._create_alert_record(rule, metric_value, threshold_value)
+ return alert
+
+ return None
+
+ def _is_in_cooldown(self, rule: AlertRule) -> bool:
+ """检查规则是否在冷却期"""
+ cache_key = f"alert:cooldown:{rule.id}"
+ return self._cache.exists(cache_key)
+
+ def _set_cooldown(self, rule: AlertRule):
+ """设置规则冷却期"""
+ cache_key = f"alert:cooldown:{rule.id}"
+ self._cache.set(cache_key, "1", ttl=rule.cooldown_minutes * 60)
+
+ def _exceeds_daily_limit(self, rule: AlertRule) -> bool:
+ """检查是否超过每日告警次数限制"""
+ today = datetime.now().date()
+ count = self.db.query(func.count(AlertRecord.id)).filter(
+ AlertRecord.rule_id == rule.id,
+ func.date(AlertRecord.created_at) == today
+ ).scalar()
+ return count >= rule.max_alerts_per_day
+
+ def _check_error_rate(self, rule: AlertRule, condition: dict) -> tuple:
+ """检查错误率"""
+ window_minutes = self._parse_window(condition.get('window', '5m'))
+ threshold = condition.get('threshold', 10) # 错误次数阈值
+ operator = condition.get('operator', '>')
+
+ since = datetime.now() - timedelta(minutes=window_minutes)
+
+ query = self.db.query(func.count(AICallEvent.id)).filter(
+ AICallEvent.created_at >= since,
+ AICallEvent.status == 'error'
+ )
+
+ # 应用作用范围
+ if rule.scope_type == 'tenant' and rule.scope_value:
+ query = query.filter(AICallEvent.tenant_id == rule.scope_value)
+ elif rule.scope_type == 'app' and rule.scope_value:
+ query = query.filter(AICallEvent.app_code == rule.scope_value)
+
+ error_count = query.scalar() or 0
+ triggered = self._compare(error_count, threshold, operator)
+
+ return triggered, str(error_count), str(threshold)
+
+ def _check_call_count(self, rule: AlertRule, condition: dict) -> tuple:
+ """检查调用次数"""
+ window_minutes = self._parse_window(condition.get('window', '1h'))
+ threshold = condition.get('threshold', 1000)
+ operator = condition.get('operator', '>')
+
+ since = datetime.now() - timedelta(minutes=window_minutes)
+
+ query = self.db.query(func.count(AICallEvent.id)).filter(
+ AICallEvent.created_at >= since
+ )
+
+ if rule.scope_type == 'tenant' and rule.scope_value:
+ query = query.filter(AICallEvent.tenant_id == rule.scope_value)
+ elif rule.scope_type == 'app' and rule.scope_value:
+ query = query.filter(AICallEvent.app_code == rule.scope_value)
+
+ call_count = query.scalar() or 0
+ triggered = self._compare(call_count, threshold, operator)
+
+ return triggered, str(call_count), str(threshold)
+
+ def _check_token_usage(self, rule: AlertRule, condition: dict) -> tuple:
+ """检查Token使用量"""
+ window_minutes = self._parse_window(condition.get('window', '1d'))
+ threshold = condition.get('threshold', 100000)
+ operator = condition.get('operator', '>')
+
+ since = datetime.now() - timedelta(minutes=window_minutes)
+
+ query = self.db.query(
+ func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0)
+ ).filter(
+ AICallEvent.created_at >= since
+ )
+
+ if rule.scope_type == 'tenant' and rule.scope_value:
+ query = query.filter(AICallEvent.tenant_id == rule.scope_value)
+ elif rule.scope_type == 'app' and rule.scope_value:
+ query = query.filter(AICallEvent.app_code == rule.scope_value)
+
+ token_usage = query.scalar() or 0
+ triggered = self._compare(token_usage, threshold, operator)
+
+ return triggered, str(token_usage), str(threshold)
+
+ def _check_cost_threshold(self, rule: AlertRule, condition: dict) -> tuple:
+ """检查费用阈值"""
+ window_minutes = self._parse_window(condition.get('window', '1d'))
+ threshold = condition.get('threshold', 100) # 费用阈值(元)
+ operator = condition.get('operator', '>')
+
+ since = datetime.now() - timedelta(minutes=window_minutes)
+
+ query = self.db.query(
+ func.coalesce(func.sum(AICallEvent.cost), 0)
+ ).filter(
+ AICallEvent.created_at >= since
+ )
+
+ if rule.scope_type == 'tenant' and rule.scope_value:
+ query = query.filter(AICallEvent.tenant_id == rule.scope_value)
+ elif rule.scope_type == 'app' and rule.scope_value:
+ query = query.filter(AICallEvent.app_code == rule.scope_value)
+
+ total_cost = float(query.scalar() or 0)
+ triggered = self._compare(total_cost, threshold, operator)
+
+ return triggered, f"¥{total_cost:.2f}", f"¥{threshold:.2f}"
+
+ def _check_latency(self, rule: AlertRule, condition: dict) -> tuple:
+ """检查延迟"""
+ window_minutes = self._parse_window(condition.get('window', '5m'))
+ threshold = condition.get('threshold', 5000) # 延迟阈值(ms)
+ operator = condition.get('operator', '>')
+ percentile = condition.get('percentile', 'avg') # avg, p95, p99, max
+
+ since = datetime.now() - timedelta(minutes=window_minutes)
+
+ query = self.db.query(AICallEvent.latency_ms).filter(
+ AICallEvent.created_at >= since,
+ AICallEvent.latency_ms.isnot(None)
+ )
+
+ if rule.scope_type == 'tenant' and rule.scope_value:
+ query = query.filter(AICallEvent.tenant_id == rule.scope_value)
+ elif rule.scope_type == 'app' and rule.scope_value:
+ query = query.filter(AICallEvent.app_code == rule.scope_value)
+
+ latencies = [r.latency_ms for r in query.all()]
+
+ if not latencies:
+ return False, "0", str(threshold)
+
+ if percentile == 'avg':
+ metric = sum(latencies) / len(latencies)
+ elif percentile == 'max':
+ metric = max(latencies)
+ elif percentile == 'p95':
+ latencies.sort()
+ idx = int(len(latencies) * 0.95)
+ metric = latencies[idx] if idx < len(latencies) else latencies[-1]
+ elif percentile == 'p99':
+ latencies.sort()
+ idx = int(len(latencies) * 0.99)
+ metric = latencies[idx] if idx < len(latencies) else latencies[-1]
+ else:
+ metric = sum(latencies) / len(latencies)
+
+ triggered = self._compare(metric, threshold, operator)
+
+ return triggered, f"{metric:.0f}ms", f"{threshold}ms"
+
+ def _parse_window(self, window: str) -> int:
+ """解析时间窗口字符串为分钟数"""
+ if window.endswith('m'):
+ return int(window[:-1])
+ elif window.endswith('h'):
+ return int(window[:-1]) * 60
+ elif window.endswith('d'):
+ return int(window[:-1]) * 60 * 24
+ else:
+ return int(window)
+
+ def _compare(self, value: float, threshold: float, operator: str) -> bool:
+ """比较值与阈值"""
+ if operator == '>':
+ return value > threshold
+ elif operator == '>=':
+ return value >= threshold
+ elif operator == '<':
+ return value < threshold
+ elif operator == '<=':
+ return value <= threshold
+ elif operator == '==':
+ return value == threshold
+ elif operator == '!=':
+ return value != threshold
+ return False
+
+ def _create_alert_record(
+ self,
+ rule: AlertRule,
+ metric_value: str,
+ threshold_value: str
+ ) -> AlertRecord:
+ """创建告警记录"""
+ title = f"[{rule.priority.upper()}] {rule.name}"
+ message = f"规则 '{rule.name}' 触发告警\n当前值: {metric_value}\n阈值: {threshold_value}"
+
+ if rule.scope_type == 'tenant':
+ message += f"\n租户: {rule.scope_value}"
+ elif rule.scope_type == 'app':
+ message += f"\n应用: {rule.scope_value}"
+
+ alert = AlertRecord(
+ rule_id=rule.id,
+ rule_name=rule.name,
+ alert_type=rule.rule_type,
+ severity=self._priority_to_severity(rule.priority),
+ title=title,
+ message=message,
+ tenant_id=rule.scope_value if rule.scope_type == 'tenant' else None,
+ app_code=rule.scope_value if rule.scope_type == 'app' else None,
+ metric_value=metric_value,
+ threshold_value=threshold_value,
+ notification_status='pending'
+ )
+
+ self.db.add(alert)
+ self.db.commit()
+ self.db.refresh(alert)
+
+ # 设置冷却期
+ self._set_cooldown(rule)
+
+ logger.info(f"Alert triggered: {title}")
+
+ return alert
+
+ def _priority_to_severity(self, priority: str) -> str:
+ """将优先级转换为严重程度"""
+ mapping = {
+ 'low': 'info',
+ 'medium': 'warning',
+ 'high': 'error',
+ 'critical': 'critical'
+ }
+ return mapping.get(priority, 'warning')
+
+ async def send_notification(self, alert: AlertRecord, rule: AlertRule) -> bool:
+ """发送告警通知
+
+ Args:
+ alert: 告警记录
+ rule: 告警规则
+
+ Returns:
+ 是否发送成功
+ """
+ if not rule.notification_channels:
+ alert.notification_status = 'skipped'
+ self.db.commit()
+ return True
+
+ results = []
+ success = True
+
+ for channel_config in rule.notification_channels:
+ try:
+ result = await self._send_to_channel(channel_config, alert)
+ results.append(result)
+ if not result.get('success'):
+ success = False
+ except Exception as e:
+ logger.error(f"Failed to send notification: {e}")
+ results.append({'success': False, 'error': str(e)})
+ success = False
+
+ alert.notification_status = 'sent' if success else 'failed'
+ alert.notification_result = results
+ alert.notified_at = datetime.now()
+ self.db.commit()
+
+ return success
+
+ async def _send_to_channel(self, channel_config: dict, alert: AlertRecord) -> dict:
+ """发送到指定渠道"""
+ channel_type = channel_config.get('type')
+
+ if channel_type == 'wechat_bot':
+ return await self._send_wechat_bot(channel_config, alert)
+ elif channel_type == 'webhook':
+ return await self._send_webhook(channel_config, alert)
+ else:
+ return {'success': False, 'error': f'Unsupported channel type: {channel_type}'}
+
+ async def _send_wechat_bot(self, config: dict, alert: AlertRecord) -> dict:
+ """发送到企微机器人"""
+ webhook = config.get('webhook')
+ if not webhook:
+ return {'success': False, 'error': 'Missing webhook URL'}
+
+ # 构建消息
+ content = f"**{alert.title}**\n\n{alert.message}\n\n时间: {alert.created_at}"
+
+ payload = {
+ "msgtype": "markdown",
+ "markdown": {
+ "content": content
+ }
+ }
+
+ try:
+ async with httpx.AsyncClient(timeout=10) as client:
+ response = await client.post(webhook, json=payload)
+ result = response.json()
+
+ if result.get('errcode', 0) == 0:
+ return {'success': True, 'channel': 'wechat_bot'}
+ else:
+ return {'success': False, 'error': result.get('errmsg')}
+ except Exception as e:
+ return {'success': False, 'error': str(e)}
+
+ async def _send_webhook(self, config: dict, alert: AlertRecord) -> dict:
+ """发送到Webhook"""
+ url = config.get('url')
+ if not url:
+ return {'success': False, 'error': 'Missing webhook URL'}
+
+ payload = {
+ "alert_id": alert.id,
+ "title": alert.title,
+ "message": alert.message,
+ "severity": alert.severity,
+ "alert_type": alert.alert_type,
+ "metric_value": alert.metric_value,
+ "threshold_value": alert.threshold_value,
+ "created_at": alert.created_at.isoformat()
+ }
+
+ headers = config.get('headers', {})
+ method = config.get('method', 'POST')
+
+ try:
+ async with httpx.AsyncClient(timeout=10) as client:
+ if method.upper() == 'POST':
+ response = await client.post(url, json=payload, headers=headers)
+ else:
+ response = await client.get(url, params=payload, headers=headers)
+
+ if response.status_code < 400:
+ return {'success': True, 'channel': 'webhook', 'status': response.status_code}
+ else:
+ return {'success': False, 'error': f'HTTP {response.status_code}'}
+ except Exception as e:
+ return {'success': False, 'error': str(e)}
+
+ def acknowledge_alert(self, alert_id: int, acknowledged_by: str) -> Optional[AlertRecord]:
+ """确认告警"""
+ alert = self.db.query(AlertRecord).filter(AlertRecord.id == alert_id).first()
+ if not alert:
+ return None
+
+ alert.status = 'acknowledged'
+ alert.acknowledged_by = acknowledged_by
+ alert.acknowledged_at = datetime.now()
+ self.db.commit()
+
+ return alert
+
+ def resolve_alert(self, alert_id: int) -> Optional[AlertRecord]:
+ """解决告警"""
+ alert = self.db.query(AlertRecord).filter(AlertRecord.id == alert_id).first()
+ if not alert:
+ return None
+
+ alert.status = 'resolved'
+ alert.resolved_at = datetime.now()
+ self.db.commit()
+
+ return alert
diff --git a/backend/app/services/cache.py b/backend/app/services/cache.py
new file mode 100644
index 0000000..8fd6baa
--- /dev/null
+++ b/backend/app/services/cache.py
@@ -0,0 +1,309 @@
+"""Redis缓存服务"""
+import json
+import logging
+from typing import Optional, Any, Union
+from functools import lru_cache
+
+try:
+ import redis
+ from redis import Redis
+ REDIS_AVAILABLE = True
+except ImportError:
+ REDIS_AVAILABLE = False
+ Redis = None
+
+from ..config import get_settings
+
+logger = logging.getLogger(__name__)
+
+# 全局Redis连接池
+_redis_pool: Optional[Any] = None
+_redis_client: Optional[Any] = None
+
+
+def get_redis_client() -> Optional[Any]:
+ """获取Redis客户端单例"""
+ global _redis_pool, _redis_client
+
+ if not REDIS_AVAILABLE:
+ logger.warning("Redis module not installed, cache disabled")
+ return None
+
+ if _redis_client is not None:
+ return _redis_client
+
+ settings = get_settings()
+
+ try:
+ _redis_pool = redis.ConnectionPool.from_url(
+ settings.REDIS_URL,
+ max_connections=20,
+ decode_responses=True
+ )
+ _redis_client = Redis(connection_pool=_redis_pool)
+ # 测试连接
+ _redis_client.ping()
+ logger.info(f"Redis connected: {settings.REDIS_URL}")
+ return _redis_client
+ except Exception as e:
+ logger.warning(f"Redis connection failed: {e}, cache disabled")
+ _redis_client = None
+ return None
+
+
+class CacheService:
+ """缓存服务
+
+ 提供统一的缓存接口,支持Redis和内存回退
+
+ 使用示例:
+ cache = CacheService()
+
+ # 设置缓存
+ cache.set("user:123", {"name": "test"}, ttl=3600)
+
+ # 获取缓存
+ user = cache.get("user:123")
+
+ # 删除缓存
+ cache.delete("user:123")
+ """
+
+ def __init__(self, prefix: Optional[str] = None):
+ """初始化缓存服务
+
+ Args:
+ prefix: 键前缀,默认使用配置中的REDIS_PREFIX
+ """
+ settings = get_settings()
+ self.prefix = prefix or settings.REDIS_PREFIX
+ self._client = get_redis_client()
+
+ # 内存回退缓存(当Redis不可用时使用)
+ self._memory_cache: dict = {}
+
+ @property
+ def is_available(self) -> bool:
+ """Redis是否可用"""
+ return self._client is not None
+
+ def _make_key(self, key: str) -> str:
+ """生成完整的缓存键"""
+ return f"{self.prefix}{key}"
+
+ def get(self, key: str, default: Any = None) -> Any:
+ """获取缓存值
+
+ Args:
+ key: 缓存键
+ default: 默认值
+
+ Returns:
+ 缓存值或默认值
+ """
+ full_key = self._make_key(key)
+
+ if self._client:
+ try:
+ value = self._client.get(full_key)
+ if value is None:
+ return default
+ # 尝试解析JSON
+ try:
+ return json.loads(value)
+ except (json.JSONDecodeError, TypeError):
+ return value
+ except Exception as e:
+ logger.error(f"Cache get error: {e}")
+ return default
+ else:
+ # 内存回退
+ return self._memory_cache.get(full_key, default)
+
+ def set(
+ self,
+ key: str,
+ value: Any,
+ ttl: Optional[int] = None,
+ nx: bool = False
+ ) -> bool:
+ """设置缓存值
+
+ Args:
+ key: 缓存键
+ value: 缓存值
+ ttl: 过期时间(秒)
+ nx: 只在键不存在时设置
+
+ Returns:
+ 是否设置成功
+ """
+ full_key = self._make_key(key)
+
+ # 序列化值
+ if isinstance(value, (dict, list)):
+ serialized = json.dumps(value, ensure_ascii=False)
+ else:
+ serialized = str(value) if value is not None else ""
+
+ if self._client:
+ try:
+ if nx:
+ result = self._client.set(full_key, serialized, ex=ttl, nx=True)
+ else:
+ result = self._client.set(full_key, serialized, ex=ttl)
+ return bool(result)
+ except Exception as e:
+ logger.error(f"Cache set error: {e}")
+ return False
+ else:
+ # 内存回退(不支持TTL和NX)
+ if nx and full_key in self._memory_cache:
+ return False
+ self._memory_cache[full_key] = value
+ return True
+
+ def delete(self, key: str) -> bool:
+ """删除缓存
+
+ Args:
+ key: 缓存键
+
+ Returns:
+ 是否删除成功
+ """
+ full_key = self._make_key(key)
+
+ if self._client:
+ try:
+ return bool(self._client.delete(full_key))
+ except Exception as e:
+ logger.error(f"Cache delete error: {e}")
+ return False
+ else:
+ return self._memory_cache.pop(full_key, None) is not None
+
+ def exists(self, key: str) -> bool:
+ """检查键是否存在
+
+ Args:
+ key: 缓存键
+
+ Returns:
+ 是否存在
+ """
+ full_key = self._make_key(key)
+
+ if self._client:
+ try:
+ return bool(self._client.exists(full_key))
+ except Exception as e:
+ logger.error(f"Cache exists error: {e}")
+ return False
+ else:
+ return full_key in self._memory_cache
+
+ def ttl(self, key: str) -> int:
+ """获取键的剩余过期时间
+
+ Args:
+ key: 缓存键
+
+ Returns:
+ 剩余秒数,-1表示永不过期,-2表示键不存在
+ """
+ full_key = self._make_key(key)
+
+ if self._client:
+ try:
+ return self._client.ttl(full_key)
+ except Exception as e:
+ logger.error(f"Cache ttl error: {e}")
+ return -2
+ else:
+ return -1 if full_key in self._memory_cache else -2
+
+ def incr(self, key: str, amount: int = 1) -> int:
+ """递增计数器
+
+ Args:
+ key: 缓存键
+ amount: 递增量
+
+ Returns:
+ 递增后的值
+ """
+ full_key = self._make_key(key)
+
+ if self._client:
+ try:
+ return self._client.incrby(full_key, amount)
+ except Exception as e:
+ logger.error(f"Cache incr error: {e}")
+ return 0
+ else:
+ current = self._memory_cache.get(full_key, 0)
+ new_value = int(current) + amount
+ self._memory_cache[full_key] = new_value
+ return new_value
+
+ def expire(self, key: str, ttl: int) -> bool:
+ """设置键的过期时间
+
+ Args:
+ key: 缓存键
+ ttl: 过期时间(秒)
+
+ Returns:
+ 是否设置成功
+ """
+ full_key = self._make_key(key)
+
+ if self._client:
+ try:
+ return bool(self._client.expire(full_key, ttl))
+ except Exception as e:
+ logger.error(f"Cache expire error: {e}")
+ return False
+ else:
+ return full_key in self._memory_cache
+
+ def clear_prefix(self, prefix: str) -> int:
+ """删除指定前缀的所有键
+
+ Args:
+ prefix: 键前缀
+
+ Returns:
+ 删除的键数量
+ """
+ full_prefix = self._make_key(prefix)
+
+ if self._client:
+ try:
+ keys = self._client.keys(f"{full_prefix}*")
+ if keys:
+ return self._client.delete(*keys)
+ return 0
+ except Exception as e:
+ logger.error(f"Cache clear_prefix error: {e}")
+ return 0
+ else:
+ count = 0
+ keys_to_delete = [k for k in self._memory_cache if k.startswith(full_prefix)]
+ for k in keys_to_delete:
+ del self._memory_cache[k]
+ count += 1
+ return count
+
+
+# 全局缓存实例
+_cache_instance: Optional[CacheService] = None
+
+
+def get_cache() -> CacheService:
+ """获取全局缓存实例"""
+ global _cache_instance
+ if _cache_instance is None:
+ _cache_instance = CacheService()
+ return _cache_instance
diff --git a/backend/app/services/cost.py b/backend/app/services/cost.py
new file mode 100644
index 0000000..0f39fef
--- /dev/null
+++ b/backend/app/services/cost.py
@@ -0,0 +1,420 @@
+"""费用计算服务"""
+import logging
+from datetime import datetime
+from decimal import Decimal
+from typing import Optional, Dict, List
+from functools import lru_cache
+
+from sqlalchemy.orm import Session
+from sqlalchemy import func
+
+from ..models.pricing import ModelPricing, TenantBilling
+from ..models.stats import AICallEvent
+from .cache import get_cache
+
+logger = logging.getLogger(__name__)
+
+
+class CostCalculator:
+ """费用计算器
+
+ 使用示例:
+ calculator = CostCalculator(db)
+
+ # 计算单次调用费用
+ cost = calculator.calculate_cost("gpt-4", input_tokens=100, output_tokens=200)
+
+ # 生成月度账单
+ billing = calculator.generate_monthly_billing("qiqi", "2026-01")
+ """
+
+ # 默认模型价格(当数据库中无配置时使用)
+ DEFAULT_PRICING = {
+ # OpenAI
+ "gpt-4": {"input": 0.21, "output": 0.42}, # 元/1K tokens
+ "gpt-4-turbo": {"input": 0.07, "output": 0.21},
+ "gpt-4o": {"input": 0.035, "output": 0.105},
+ "gpt-4o-mini": {"input": 0.00105, "output": 0.0042},
+ "gpt-3.5-turbo": {"input": 0.0035, "output": 0.014},
+
+ # Anthropic
+ "claude-3-opus": {"input": 0.105, "output": 0.525},
+ "claude-3-sonnet": {"input": 0.021, "output": 0.105},
+ "claude-3-haiku": {"input": 0.00175, "output": 0.00875},
+ "claude-3.5-sonnet": {"input": 0.021, "output": 0.105},
+
+ # 国内模型
+ "qwen-max": {"input": 0.02, "output": 0.06},
+ "qwen-plus": {"input": 0.004, "output": 0.012},
+ "qwen-turbo": {"input": 0.002, "output": 0.006},
+ "glm-4": {"input": 0.01, "output": 0.01},
+ "glm-4-flash": {"input": 0.0001, "output": 0.0001},
+ "deepseek-chat": {"input": 0.001, "output": 0.002},
+ "deepseek-coder": {"input": 0.001, "output": 0.002},
+
+ # 默认
+ "default": {"input": 0.01, "output": 0.03}
+ }
+
+ def __init__(self, db: Session):
+ self.db = db
+ self._cache = get_cache()
+ self._pricing_cache: Dict[str, ModelPricing] = {}
+
+ def get_model_pricing(self, model_name: str) -> Optional[ModelPricing]:
+ """获取模型价格配置
+
+ Args:
+ model_name: 模型名称
+
+ Returns:
+ ModelPricing实例或None
+ """
+ # 尝试从缓存获取
+ cache_key = f"pricing:{model_name}"
+ cached = self._cache.get(cache_key)
+ if cached:
+ return self._dict_to_pricing(cached)
+
+ # 从数据库查询
+ pricing = self.db.query(ModelPricing).filter(
+ ModelPricing.model_name == model_name,
+ ModelPricing.status == 1
+ ).first()
+
+ if pricing:
+ # 缓存1小时
+ self._cache.set(cache_key, self._pricing_to_dict(pricing), ttl=3600)
+ return pricing
+
+ return None
+
+ def _pricing_to_dict(self, pricing: ModelPricing) -> dict:
+ return {
+ "model_name": pricing.model_name,
+ "input_price_per_1k": str(pricing.input_price_per_1k),
+ "output_price_per_1k": str(pricing.output_price_per_1k),
+ "fixed_price_per_call": str(pricing.fixed_price_per_call),
+ "pricing_type": pricing.pricing_type
+ }
+
+ def _dict_to_pricing(self, d: dict) -> ModelPricing:
+ pricing = ModelPricing()
+ pricing.model_name = d.get("model_name")
+ pricing.input_price_per_1k = Decimal(d.get("input_price_per_1k", "0"))
+ pricing.output_price_per_1k = Decimal(d.get("output_price_per_1k", "0"))
+ pricing.fixed_price_per_call = Decimal(d.get("fixed_price_per_call", "0"))
+ pricing.pricing_type = d.get("pricing_type", "token")
+ return pricing
+
+ def calculate_cost(
+ self,
+ model_name: str,
+ input_tokens: int = 0,
+ output_tokens: int = 0,
+ call_count: int = 1
+ ) -> Decimal:
+ """计算调用费用
+
+ Args:
+ model_name: 模型名称
+ input_tokens: 输入token数
+ output_tokens: 输出token数
+ call_count: 调用次数
+
+ Returns:
+ 费用(元)
+ """
+ # 尝试获取数据库配置
+ pricing = self.get_model_pricing(model_name)
+
+ if pricing:
+ if pricing.pricing_type == 'call':
+ return pricing.fixed_price_per_call * call_count
+ elif pricing.pricing_type == 'hybrid':
+ token_cost = (
+ pricing.input_price_per_1k * Decimal(input_tokens) / 1000 +
+ pricing.output_price_per_1k * Decimal(output_tokens) / 1000
+ )
+ call_cost = pricing.fixed_price_per_call * call_count
+ return token_cost + call_cost
+ else: # token
+ return (
+ pricing.input_price_per_1k * Decimal(input_tokens) / 1000 +
+ pricing.output_price_per_1k * Decimal(output_tokens) / 1000
+ )
+
+ # 使用默认价格
+ default_prices = self.DEFAULT_PRICING.get(model_name) or self.DEFAULT_PRICING.get("default")
+ input_price = Decimal(str(default_prices["input"]))
+ output_price = Decimal(str(default_prices["output"]))
+
+ return (
+ input_price * Decimal(input_tokens) / 1000 +
+ output_price * Decimal(output_tokens) / 1000
+ )
+
+ def calculate_event_cost(self, event: AICallEvent) -> Decimal:
+ """计算单个事件的费用
+
+ Args:
+ event: AI调用事件
+
+ Returns:
+ 费用(元)
+ """
+ return self.calculate_cost(
+ model_name=event.model or "default",
+ input_tokens=event.input_tokens or 0,
+ output_tokens=event.output_tokens or 0
+ )
+
+ def update_event_costs(self, start_date: str = None, end_date: str = None) -> int:
+ """批量更新事件费用
+
+ 对于cost为0或NULL的事件,重新计算费用
+
+ Args:
+ start_date: 开始日期,格式 YYYY-MM-DD
+ end_date: 结束日期,格式 YYYY-MM-DD
+
+ Returns:
+ 更新的记录数
+ """
+ query = self.db.query(AICallEvent).filter(
+ (AICallEvent.cost == None) | (AICallEvent.cost == 0)
+ )
+
+ if start_date:
+ query = query.filter(AICallEvent.created_at >= start_date)
+ if end_date:
+ query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
+
+ events = query.all()
+ updated = 0
+
+ for event in events:
+ try:
+ cost = self.calculate_event_cost(event)
+ event.cost = cost
+ updated += 1
+ except Exception as e:
+ logger.error(f"Failed to calculate cost for event {event.id}: {e}")
+
+ self.db.commit()
+ logger.info(f"Updated {updated} event costs")
+
+ return updated
+
+ def generate_monthly_billing(
+ self,
+ tenant_id: str,
+ billing_month: str
+ ) -> TenantBilling:
+ """生成月度账单
+
+ Args:
+ tenant_id: 租户ID
+ billing_month: 账单月份,格式 YYYY-MM
+
+ Returns:
+ TenantBilling实例
+ """
+ # 检查是否已存在
+ existing = self.db.query(TenantBilling).filter(
+ TenantBilling.tenant_id == tenant_id,
+ TenantBilling.billing_month == billing_month
+ ).first()
+
+ if existing:
+ billing = existing
+ else:
+ billing = TenantBilling(
+ tenant_id=tenant_id,
+ billing_month=billing_month
+ )
+ self.db.add(billing)
+
+ # 计算统计数据
+ start_date = f"{billing_month}-01"
+ year, month = billing_month.split("-")
+ if int(month) == 12:
+ end_date = f"{int(year)+1}-01-01"
+ else:
+ end_date = f"{year}-{int(month)+1:02d}-01"
+
+ # 聚合查询
+ stats = self.db.query(
+ func.count(AICallEvent.id).label('total_calls'),
+ func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'),
+ func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'),
+ func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost')
+ ).filter(
+ AICallEvent.tenant_id == tenant_id,
+ AICallEvent.created_at >= start_date,
+ AICallEvent.created_at < end_date
+ ).first()
+
+ billing.total_calls = stats.total_calls or 0
+ billing.total_input_tokens = int(stats.total_input or 0)
+ billing.total_output_tokens = int(stats.total_output or 0)
+ billing.total_cost = stats.total_cost or Decimal("0")
+
+ # 按模型统计
+ model_stats = self.db.query(
+ AICallEvent.model,
+ func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
+ ).filter(
+ AICallEvent.tenant_id == tenant_id,
+ AICallEvent.created_at >= start_date,
+ AICallEvent.created_at < end_date
+ ).group_by(AICallEvent.model).all()
+
+ billing.cost_by_model = {
+ m.model or "unknown": float(m.cost) for m in model_stats
+ }
+
+ # 按应用统计
+ app_stats = self.db.query(
+ AICallEvent.app_code,
+ func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
+ ).filter(
+ AICallEvent.tenant_id == tenant_id,
+ AICallEvent.created_at >= start_date,
+ AICallEvent.created_at < end_date
+ ).group_by(AICallEvent.app_code).all()
+
+ billing.cost_by_app = {
+ a.app_code or "unknown": float(a.cost) for a in app_stats
+ }
+
+ self.db.commit()
+ self.db.refresh(billing)
+
+ return billing
+
+ def get_cost_summary(
+ self,
+ tenant_id: str = None,
+ start_date: str = None,
+ end_date: str = None
+ ) -> Dict:
+ """获取费用汇总
+
+ Args:
+ tenant_id: 租户ID(可选)
+ start_date: 开始日期
+ end_date: 结束日期
+
+ Returns:
+ 费用汇总字典
+ """
+ query = self.db.query(
+ func.count(AICallEvent.id).label('total_calls'),
+ func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'),
+ func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'),
+ func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost')
+ )
+
+ if tenant_id:
+ query = query.filter(AICallEvent.tenant_id == tenant_id)
+ if start_date:
+ query = query.filter(AICallEvent.created_at >= start_date)
+ if end_date:
+ query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
+
+ stats = query.first()
+
+ return {
+ "total_calls": stats.total_calls or 0,
+ "total_input_tokens": int(stats.total_input or 0),
+ "total_output_tokens": int(stats.total_output or 0),
+ "total_cost": float(stats.total_cost or 0)
+ }
+
+ def get_cost_by_tenant(
+ self,
+ start_date: str = None,
+ end_date: str = None
+ ) -> List[Dict]:
+ """按租户统计费用
+
+ Returns:
+ 租户费用列表
+ """
+ query = self.db.query(
+ AICallEvent.tenant_id,
+ func.count(AICallEvent.id).label('calls'),
+ func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
+ )
+
+ if start_date:
+ query = query.filter(AICallEvent.created_at >= start_date)
+ if end_date:
+ query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
+
+ results = query.group_by(AICallEvent.tenant_id).order_by(
+ func.sum(AICallEvent.cost).desc()
+ ).all()
+
+ return [
+ {
+ "tenant_id": r.tenant_id,
+ "calls": r.calls,
+ "cost": float(r.cost)
+ }
+ for r in results
+ ]
+
+ def get_cost_by_model(
+ self,
+ tenant_id: str = None,
+ start_date: str = None,
+ end_date: str = None
+ ) -> List[Dict]:
+ """按模型统计费用
+
+ Returns:
+ 模型费用列表
+ """
+ query = self.db.query(
+ AICallEvent.model,
+ func.count(AICallEvent.id).label('calls'),
+ func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('input_tokens'),
+ func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('output_tokens'),
+ func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
+ )
+
+ if tenant_id:
+ query = query.filter(AICallEvent.tenant_id == tenant_id)
+ if start_date:
+ query = query.filter(AICallEvent.created_at >= start_date)
+ if end_date:
+ query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59")
+
+ results = query.group_by(AICallEvent.model).order_by(
+ func.sum(AICallEvent.cost).desc()
+ ).all()
+
+ return [
+ {
+ "model": r.model or "unknown",
+ "calls": r.calls,
+ "input_tokens": int(r.input_tokens),
+ "output_tokens": int(r.output_tokens),
+ "cost": float(r.cost)
+ }
+ for r in results
+ ]
+
+
+# 便捷函数
+def calculate_cost(
+ db: Session,
+ model_name: str,
+ input_tokens: int = 0,
+ output_tokens: int = 0
+) -> Decimal:
+ """快速计算费用"""
+ calculator = CostCalculator(db)
+ return calculator.calculate_cost(model_name, input_tokens, output_tokens)
diff --git a/backend/app/services/quota.py b/backend/app/services/quota.py
new file mode 100644
index 0000000..d4ad5e2
--- /dev/null
+++ b/backend/app/services/quota.py
@@ -0,0 +1,346 @@
+"""配额管理服务"""
+import logging
+from datetime import datetime, date, timedelta
+from typing import Optional, Dict, Any, Tuple
+from dataclasses import dataclass
+
+from sqlalchemy.orm import Session
+from sqlalchemy import func
+
+from ..models.tenant import Tenant, Subscription
+from ..models.stats import AICallEvent
+from .cache import get_cache
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class QuotaConfig:
+ """配额配置"""
+ daily_calls: int = 0 # 每日调用限制,0表示无限制
+ daily_tokens: int = 0 # 每日Token限制
+ monthly_calls: int = 0 # 每月调用限制
+ monthly_tokens: int = 0 # 每月Token限制
+ monthly_cost: float = 0 # 每月费用限制(元)
+ concurrent_calls: int = 0 # 并发调用限制
+
+
+@dataclass
+class QuotaUsage:
+ """配额使用情况"""
+ daily_calls: int = 0
+ daily_tokens: int = 0
+ monthly_calls: int = 0
+ monthly_tokens: int = 0
+ monthly_cost: float = 0
+
+
+@dataclass
+class QuotaCheckResult:
+ """配额检查结果"""
+ allowed: bool
+ reason: Optional[str] = None
+ quota_type: Optional[str] = None
+ limit: int = 0
+ used: int = 0
+ remaining: int = 0
+
+
+class QuotaService:
+ """配额管理服务
+
+ 使用示例:
+ quota_service = QuotaService(db)
+
+ # 检查配额
+ result = quota_service.check_quota("qiqi", "tools")
+ if not result.allowed:
+ raise HTTPException(status_code=429, detail=result.reason)
+
+ # 获取使用情况
+ usage = quota_service.get_usage("qiqi", "tools")
+ """
+
+ # 默认配额(当无订阅配置时使用)
+ DEFAULT_QUOTA = QuotaConfig(
+ daily_calls=1000,
+ daily_tokens=100000,
+ monthly_calls=30000,
+ monthly_tokens=3000000,
+ monthly_cost=100
+ )
+
+ def __init__(self, db: Session):
+ self.db = db
+ self._cache = get_cache()
+
+ def get_subscription(self, tenant_id: str, app_code: str) -> Optional[Subscription]:
+ """获取租户订阅配置"""
+ return self.db.query(Subscription).filter(
+ Subscription.tenant_id == tenant_id,
+ Subscription.app_code == app_code,
+ Subscription.status == 'active'
+ ).first()
+
+ def get_quota_config(self, tenant_id: str, app_code: str) -> QuotaConfig:
+ """获取配额配置
+
+ Args:
+ tenant_id: 租户ID
+ app_code: 应用代码
+
+ Returns:
+ QuotaConfig实例
+ """
+ # 尝试从缓存获取
+ cache_key = f"quota:config:{tenant_id}:{app_code}"
+ cached = self._cache.get(cache_key)
+ if cached:
+ return QuotaConfig(**cached)
+
+ # 从订阅表获取
+ subscription = self.get_subscription(tenant_id, app_code)
+
+ if subscription and subscription.quota:
+ quota = subscription.quota
+ config = QuotaConfig(
+ daily_calls=quota.get('daily_calls', 0),
+ daily_tokens=quota.get('daily_tokens', 0),
+ monthly_calls=quota.get('monthly_calls', 0),
+ monthly_tokens=quota.get('monthly_tokens', 0),
+ monthly_cost=quota.get('monthly_cost', 0),
+ concurrent_calls=quota.get('concurrent_calls', 0)
+ )
+ else:
+ config = self.DEFAULT_QUOTA
+
+ # 缓存5分钟
+ self._cache.set(cache_key, config.__dict__, ttl=300)
+
+ return config
+
+ def get_usage(self, tenant_id: str, app_code: str) -> QuotaUsage:
+ """获取配额使用情况
+
+ Args:
+ tenant_id: 租户ID
+ app_code: 应用代码
+
+ Returns:
+ QuotaUsage实例
+ """
+ today = date.today()
+ month_start = today.replace(day=1)
+
+ # 今日使用量
+ daily_stats = self.db.query(
+ func.count(AICallEvent.id).label('calls'),
+ func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0).label('tokens')
+ ).filter(
+ AICallEvent.tenant_id == tenant_id,
+ AICallEvent.app_code == app_code,
+ func.date(AICallEvent.created_at) == today
+ ).first()
+
+ # 本月使用量
+ monthly_stats = self.db.query(
+ func.count(AICallEvent.id).label('calls'),
+ func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0).label('tokens'),
+ func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
+ ).filter(
+ AICallEvent.tenant_id == tenant_id,
+ AICallEvent.app_code == app_code,
+ func.date(AICallEvent.created_at) >= month_start
+ ).first()
+
+ return QuotaUsage(
+ daily_calls=daily_stats.calls or 0,
+ daily_tokens=int(daily_stats.tokens or 0),
+ monthly_calls=monthly_stats.calls or 0,
+ monthly_tokens=int(monthly_stats.tokens or 0),
+ monthly_cost=float(monthly_stats.cost or 0)
+ )
+
+ def check_quota(
+ self,
+ tenant_id: str,
+ app_code: str,
+ estimated_tokens: int = 0
+ ) -> QuotaCheckResult:
+ """检查配额是否足够
+
+ Args:
+ tenant_id: 租户ID
+ app_code: 应用代码
+ estimated_tokens: 预估Token消耗
+
+ Returns:
+ QuotaCheckResult实例
+ """
+ config = self.get_quota_config(tenant_id, app_code)
+ usage = self.get_usage(tenant_id, app_code)
+
+ # 检查日调用次数
+ if config.daily_calls > 0:
+ if usage.daily_calls >= config.daily_calls:
+ return QuotaCheckResult(
+ allowed=False,
+ reason=f"已达到每日调用限制 ({config.daily_calls} 次)",
+ quota_type="daily_calls",
+ limit=config.daily_calls,
+ used=usage.daily_calls,
+ remaining=0
+ )
+
+ # 检查日Token限制
+ if config.daily_tokens > 0:
+ if usage.daily_tokens + estimated_tokens > config.daily_tokens:
+ return QuotaCheckResult(
+ allowed=False,
+ reason=f"已达到每日Token限制 ({config.daily_tokens:,})",
+ quota_type="daily_tokens",
+ limit=config.daily_tokens,
+ used=usage.daily_tokens,
+ remaining=max(0, config.daily_tokens - usage.daily_tokens)
+ )
+
+ # 检查月调用次数
+ if config.monthly_calls > 0:
+ if usage.monthly_calls >= config.monthly_calls:
+ return QuotaCheckResult(
+ allowed=False,
+ reason=f"已达到每月调用限制 ({config.monthly_calls} 次)",
+ quota_type="monthly_calls",
+ limit=config.monthly_calls,
+ used=usage.monthly_calls,
+ remaining=0
+ )
+
+ # 检查月Token限制
+ if config.monthly_tokens > 0:
+ if usage.monthly_tokens + estimated_tokens > config.monthly_tokens:
+ return QuotaCheckResult(
+ allowed=False,
+ reason=f"已达到每月Token限制 ({config.monthly_tokens:,})",
+ quota_type="monthly_tokens",
+ limit=config.monthly_tokens,
+ used=usage.monthly_tokens,
+ remaining=max(0, config.monthly_tokens - usage.monthly_tokens)
+ )
+
+ # 检查月费用限制
+ if config.monthly_cost > 0:
+ if usage.monthly_cost >= config.monthly_cost:
+ return QuotaCheckResult(
+ allowed=False,
+ reason=f"已达到每月费用限制 (¥{config.monthly_cost:.2f})",
+ quota_type="monthly_cost",
+ limit=int(config.monthly_cost * 100), # 转为分
+ used=int(usage.monthly_cost * 100),
+ remaining=max(0, int((config.monthly_cost - usage.monthly_cost) * 100))
+ )
+
+ # 所有检查通过
+ return QuotaCheckResult(
+ allowed=True,
+ quota_type="daily_calls",
+ limit=config.daily_calls,
+ used=usage.daily_calls,
+ remaining=max(0, config.daily_calls - usage.daily_calls) if config.daily_calls > 0 else -1
+ )
+
+ def get_quota_summary(self, tenant_id: str, app_code: str) -> Dict[str, Any]:
+ """获取配额汇总信息
+
+ Returns:
+ 包含配额配置和使用情况的字典
+ """
+ config = self.get_quota_config(tenant_id, app_code)
+ usage = self.get_usage(tenant_id, app_code)
+
+ def calc_percentage(used: int, limit: int) -> float:
+ if limit <= 0:
+ return 0
+ return min(100, round(used / limit * 100, 1))
+
+ return {
+ "config": {
+ "daily_calls": config.daily_calls,
+ "daily_tokens": config.daily_tokens,
+ "monthly_calls": config.monthly_calls,
+ "monthly_tokens": config.monthly_tokens,
+ "monthly_cost": config.monthly_cost
+ },
+ "usage": {
+ "daily_calls": usage.daily_calls,
+ "daily_tokens": usage.daily_tokens,
+ "monthly_calls": usage.monthly_calls,
+ "monthly_tokens": usage.monthly_tokens,
+ "monthly_cost": round(usage.monthly_cost, 2)
+ },
+ "percentage": {
+ "daily_calls": calc_percentage(usage.daily_calls, config.daily_calls),
+ "daily_tokens": calc_percentage(usage.daily_tokens, config.daily_tokens),
+ "monthly_calls": calc_percentage(usage.monthly_calls, config.monthly_calls),
+ "monthly_tokens": calc_percentage(usage.monthly_tokens, config.monthly_tokens),
+ "monthly_cost": calc_percentage(int(usage.monthly_cost * 100), int(config.monthly_cost * 100))
+ }
+ }
+
+ def update_quota(
+ self,
+ tenant_id: str,
+ app_code: str,
+ quota_config: Dict[str, Any]
+ ) -> Subscription:
+ """更新配额配置
+
+ Args:
+ tenant_id: 租户ID
+ app_code: 应用代码
+ quota_config: 配额配置字典
+
+ Returns:
+ 更新后的Subscription实例
+ """
+ subscription = self.get_subscription(tenant_id, app_code)
+
+ if not subscription:
+ # 创建新订阅
+ subscription = Subscription(
+ tenant_id=tenant_id,
+ app_code=app_code,
+ start_date=date.today(),
+ quota=quota_config,
+ status='active'
+ )
+ self.db.add(subscription)
+ else:
+ # 更新现有订阅
+ subscription.quota = quota_config
+
+ self.db.commit()
+ self.db.refresh(subscription)
+
+ # 清除缓存
+ cache_key = f"quota:config:{tenant_id}:{app_code}"
+ self._cache.delete(cache_key)
+
+ return subscription
+
+
+def check_quota_middleware(
+ db: Session,
+ tenant_id: str,
+ app_code: str,
+ estimated_tokens: int = 0
+) -> QuotaCheckResult:
+ """配额检查中间件函数
+
+ 可在路由中使用:
+ result = check_quota_middleware(db, "qiqi", "tools")
+ if not result.allowed:
+ raise HTTPException(status_code=429, detail=result.reason)
+ """
+ service = QuotaService(db)
+ return service.check_quota(tenant_id, app_code, estimated_tokens)
diff --git a/backend/app/services/wechat.py b/backend/app/services/wechat.py
new file mode 100644
index 0000000..1b59d78
--- /dev/null
+++ b/backend/app/services/wechat.py
@@ -0,0 +1,371 @@
+"""企业微信服务"""
+import hashlib
+import time
+import logging
+from typing import Optional, Dict, Any
+from dataclasses import dataclass
+
+import httpx
+
+from ..config import get_settings
+from .cache import get_cache
+from .crypto import decrypt_config
+
+logger = logging.getLogger(__name__)
+settings = get_settings()
+
+
+@dataclass
+class WechatConfig:
+ """企业微信应用配置"""
+ corp_id: str
+ agent_id: str
+ secret: str
+
+
+class WechatService:
+ """企业微信服务
+
+ 提供access_token获取、JS-SDK签名、OAuth2等功能
+
+ 使用示例:
+ wechat = WechatService(corp_id="wwxxxx", agent_id="1000001", secret="xxx")
+
+ # 获取access_token
+ token = await wechat.get_access_token()
+
+ # 获取JS-SDK签名
+ signature = await wechat.get_jssdk_signature("https://example.com/page")
+ """
+
+ # 企业微信API基础URL
+ BASE_URL = "https://qyapi.weixin.qq.com"
+
+ def __init__(self, corp_id: str, agent_id: str, secret: str):
+ """初始化企业微信服务
+
+ Args:
+ corp_id: 企业ID
+ agent_id: 应用AgentId
+ secret: 应用Secret(明文)
+ """
+ self.corp_id = corp_id
+ self.agent_id = agent_id
+ self.secret = secret
+ self._cache = get_cache()
+
+ @classmethod
+ def from_wechat_app(cls, wechat_app) -> "WechatService":
+ """从TenantWechatApp模型创建服务实例
+
+ Args:
+ wechat_app: TenantWechatApp数据库模型
+
+ Returns:
+ WechatService实例
+ """
+ secret = ""
+ if wechat_app.secret_encrypted:
+ try:
+ secret = decrypt_config(wechat_app.secret_encrypted)
+ except Exception as e:
+ logger.error(f"Failed to decrypt secret: {e}")
+
+ return cls(
+ corp_id=wechat_app.corp_id,
+ agent_id=wechat_app.agent_id,
+ secret=secret
+ )
+
+ def _cache_key(self, key_type: str) -> str:
+ """生成缓存键"""
+ return f"wechat:{self.corp_id}:{self.agent_id}:{key_type}"
+
+ async def get_access_token(self, force_refresh: bool = False) -> Optional[str]:
+ """获取access_token
+
+ 企业微信access_token有效期7200秒,需要缓存
+
+ Args:
+ force_refresh: 是否强制刷新
+
+ Returns:
+ access_token或None
+ """
+ cache_key = self._cache_key("access_token")
+
+ # 尝试从缓存获取
+ if not force_refresh:
+ cached = self._cache.get(cache_key)
+ if cached:
+ logger.debug(f"Access token from cache: {cached[:20]}...")
+ return cached
+
+ # 从企业微信API获取
+ url = f"{self.BASE_URL}/cgi-bin/gettoken"
+ params = {
+ "corpid": self.corp_id,
+ "corpsecret": self.secret
+ }
+
+ try:
+ async with httpx.AsyncClient(timeout=10) as client:
+ response = await client.get(url, params=params)
+ result = response.json()
+
+ if result.get("errcode", 0) != 0:
+ logger.error(f"Get access_token failed: {result}")
+ return None
+
+ access_token = result.get("access_token")
+ expires_in = result.get("expires_in", 7200)
+
+ # 缓存,提前200秒过期以确保安全
+ self._cache.set(
+ cache_key,
+ access_token,
+ ttl=min(expires_in - 200, settings.WECHAT_ACCESS_TOKEN_EXPIRE)
+ )
+
+ logger.info(f"Got new access_token for {self.corp_id}")
+ return access_token
+ except Exception as e:
+ logger.error(f"Get access_token error: {e}")
+ return None
+
+ async def get_jsapi_ticket(self, force_refresh: bool = False) -> Optional[str]:
+ """获取jsapi_ticket
+
+ 用于生成JS-SDK签名
+
+ Args:
+ force_refresh: 是否强制刷新
+
+ Returns:
+ jsapi_ticket或None
+ """
+ cache_key = self._cache_key("jsapi_ticket")
+
+ # 尝试从缓存获取
+ if not force_refresh:
+ cached = self._cache.get(cache_key)
+ if cached:
+ logger.debug(f"JSAPI ticket from cache: {cached[:20]}...")
+ return cached
+
+ # 先获取access_token
+ access_token = await self.get_access_token()
+ if not access_token:
+ return None
+
+ # 获取jsapi_ticket
+ url = f"{self.BASE_URL}/cgi-bin/get_jsapi_ticket"
+ params = {"access_token": access_token}
+
+ try:
+ async with httpx.AsyncClient(timeout=10) as client:
+ response = await client.get(url, params=params)
+ result = response.json()
+
+ if result.get("errcode", 0) != 0:
+ logger.error(f"Get jsapi_ticket failed: {result}")
+ return None
+
+ ticket = result.get("ticket")
+ expires_in = result.get("expires_in", 7200)
+
+ # 缓存
+ self._cache.set(
+ cache_key,
+ ticket,
+ ttl=min(expires_in - 200, settings.WECHAT_JSAPI_TICKET_EXPIRE)
+ )
+
+ logger.info(f"Got new jsapi_ticket for {self.corp_id}")
+ return ticket
+ except Exception as e:
+ logger.error(f"Get jsapi_ticket error: {e}")
+ return None
+
+ async def get_jssdk_signature(
+ self,
+ url: str,
+ noncestr: Optional[str] = None,
+ timestamp: Optional[int] = None
+ ) -> Optional[Dict[str, Any]]:
+ """生成JS-SDK签名
+
+ Args:
+ url: 当前页面URL(不含#及其后面部分)
+ noncestr: 随机字符串,可选
+ timestamp: 时间戳,可选
+
+ Returns:
+ 签名信息字典,包含signature, noncestr, timestamp, appId等
+ """
+ ticket = await self.get_jsapi_ticket()
+ if not ticket:
+ return None
+
+ # 生成随机字符串和时间戳
+ if noncestr is None:
+ import secrets
+ noncestr = secrets.token_hex(8)
+ if timestamp is None:
+ timestamp = int(time.time())
+
+ # 构建签名字符串
+ sign_str = f"jsapi_ticket={ticket}&noncestr={noncestr}×tamp={timestamp}&url={url}"
+
+ # SHA1签名
+ signature = hashlib.sha1(sign_str.encode()).hexdigest()
+
+ return {
+ "appId": self.corp_id,
+ "agentId": self.agent_id,
+ "timestamp": timestamp,
+ "nonceStr": noncestr,
+ "signature": signature,
+ "url": url
+ }
+
+ def get_oauth2_url(
+ self,
+ redirect_uri: str,
+ scope: str = "snsapi_base",
+ state: str = ""
+ ) -> str:
+ """生成OAuth2授权URL
+
+ Args:
+ redirect_uri: 授权后重定向的URL
+ scope: 应用授权作用域
+ - snsapi_base: 静默授权,只能获取成员基础信息
+ - snsapi_privateinfo: 手动授权,可获取成员详细信息
+ state: 重定向后会带上state参数
+
+ Returns:
+ OAuth2授权URL
+ """
+ import urllib.parse
+
+ encoded_uri = urllib.parse.quote(redirect_uri, safe='')
+
+ url = (
+ f"https://open.weixin.qq.com/connect/oauth2/authorize"
+ f"?appid={self.corp_id}"
+ f"&redirect_uri={encoded_uri}"
+ f"&response_type=code"
+ f"&scope={scope}"
+ f"&state={state}"
+ f"&agentid={self.agent_id}"
+ f"#wechat_redirect"
+ )
+
+ return url
+
+ async def get_user_info_by_code(self, code: str) -> Optional[Dict[str, Any]]:
+ """通过OAuth2 code获取用户信息
+
+ Args:
+ code: OAuth2回调返回的code
+
+ Returns:
+ 用户信息字典,包含UserId, DeviceId等
+ """
+ access_token = await self.get_access_token()
+ if not access_token:
+ return None
+
+ url = f"{self.BASE_URL}/cgi-bin/auth/getuserinfo"
+ params = {
+ "access_token": access_token,
+ "code": code
+ }
+
+ try:
+ async with httpx.AsyncClient(timeout=10) as client:
+ response = await client.get(url, params=params)
+ result = response.json()
+
+ if result.get("errcode", 0) != 0:
+ logger.error(f"Get user info by code failed: {result}")
+ return None
+
+ return {
+ "user_id": result.get("userid") or result.get("UserId"),
+ "device_id": result.get("deviceid") or result.get("DeviceId"),
+ "open_id": result.get("openid") or result.get("OpenId"),
+ "external_userid": result.get("external_userid"),
+ }
+ except Exception as e:
+ logger.error(f"Get user info by code error: {e}")
+ return None
+
+ async def get_user_detail(self, user_id: str) -> Optional[Dict[str, Any]]:
+ """获取成员详细信息
+
+ Args:
+ user_id: 成员UserID
+
+ Returns:
+ 成员详细信息
+ """
+ access_token = await self.get_access_token()
+ if not access_token:
+ return None
+
+ url = f"{self.BASE_URL}/cgi-bin/user/get"
+ params = {
+ "access_token": access_token,
+ "userid": user_id
+ }
+
+ try:
+ async with httpx.AsyncClient(timeout=10) as client:
+ response = await client.get(url, params=params)
+ result = response.json()
+
+ if result.get("errcode", 0) != 0:
+ logger.error(f"Get user detail failed: {result}")
+ return None
+
+ return {
+ "userid": result.get("userid"),
+ "name": result.get("name"),
+ "department": result.get("department"),
+ "position": result.get("position"),
+ "mobile": result.get("mobile"),
+ "email": result.get("email"),
+ "avatar": result.get("avatar"),
+ "status": result.get("status"),
+ }
+ except Exception as e:
+ logger.error(f"Get user detail error: {e}")
+ return None
+
+
+async def get_wechat_service_by_id(
+ wechat_app_id: int,
+ db_session
+) -> Optional[WechatService]:
+ """根据企微应用ID获取服务实例
+
+ Args:
+ wechat_app_id: platform_tenant_wechat_apps表的ID
+ db_session: 数据库session
+
+ Returns:
+ WechatService实例或None
+ """
+ from ..models.tenant_wechat_app import TenantWechatApp
+
+ wechat_app = db_session.query(TenantWechatApp).filter(
+ TenantWechatApp.id == wechat_app_id,
+ TenantWechatApp.status == 1
+ ).first()
+
+ if not wechat_app:
+ return None
+
+ return WechatService.from_wechat_app(wechat_app)
diff --git a/backend/env.template b/backend/env.template
new file mode 100644
index 0000000..5df0c88
--- /dev/null
+++ b/backend/env.template
@@ -0,0 +1,26 @@
+# 000-platform 环境配置模板
+# 复制此文件为 .env 并填写实际值
+
+# ==================== 应用配置 ====================
+APP_NAME=platform
+APP_VERSION=1.0.0
+DEBUG=false
+
+# ==================== 数据库配置 ====================
+DB_HOST=localhost
+DB_PORT=3306
+DB_USER=root
+DB_PASSWORD=your-password
+DB_NAME=new_qiqi
+
+# ==================== JWT 配置 ====================
+JWT_SECRET_KEY=your-secret-key-change-in-production
+JWT_ALGORITHM=HS256
+JWT_EXPIRE_MINUTES=1440
+
+# ==================== 安全配置 ====================
+# 用于加密敏感数据(如企微 Secret)
+ENCRYPTION_KEY=your-encryption-key-32-bytes
+
+# ==================== 可选:Redis 缓存 ====================
+# REDIS_URL=redis://localhost:6379/0
diff --git a/backend/requirements.txt b/backend/requirements.txt
index c68d32b..ac816e7 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -10,3 +10,5 @@ passlib[bcrypt]>=1.7.4
bcrypt>=4.0.0
python-multipart>=0.0.6
httpx>=0.26.0
+redis>=5.0.0
+openpyxl>=3.1.0
diff --git a/deploy/nginx/frontend.conf b/deploy/nginx/frontend.conf
index fcc26d1..5a12504 100644
--- a/deploy/nginx/frontend.conf
+++ b/deploy/nginx/frontend.conf
@@ -5,18 +5,19 @@ server {
root /usr/share/nginx/html;
index index.html;
+ # Docker 内部 DNS 解析器
+ resolver 127.0.0.11 valid=30s;
+
# Vue Router history mode
location / {
try_files $uri $uri/ /index.html;
}
# API 代理到后端
- # 使用宿主机网关地址(Docker 默认网桥)
- # 如果 172.18.0.1 不可用,可能需要调整为实际的 Docker 网桥地址
- set $backend_host 172.18.0.1;
-
+ # 使用 Docker 容器名称,通过 Docker DNS 解析
location /api/ {
- proxy_pass http://$backend_host:8001/api/;
+ set $backend platform-backend-test:8000;
+ proxy_pass http://$backend/api/;
proxy_http_version 1.1;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
diff --git a/frontend/src/api/index.js b/frontend/src/api/index.js
index 1b14bad..075fbef 100644
--- a/frontend/src/api/index.js
+++ b/frontend/src/api/index.js
@@ -7,6 +7,53 @@ const api = axios.create({
timeout: 30000
})
+/**
+ * 解析 API 错误响应
+ */
+function parseApiError(error) {
+ const result = {
+ code: 'UNKNOWN_ERROR',
+ message: '发生了未知错误',
+ traceId: '',
+ status: 500
+ }
+
+ if (!error.response) {
+ result.code = 'NETWORK_ERROR'
+ result.message = '网络连接失败,请检查网络后重试'
+ return result
+ }
+
+ const { status, data, headers } = error.response
+ result.status = status
+ result.traceId = headers['x-trace-id'] || headers['X-Trace-ID'] || ''
+
+ if (data && data.error) {
+ result.code = data.error.code || result.code
+ result.message = data.error.message || result.message
+ result.traceId = data.error.trace_id || result.traceId
+ } else if (data && data.detail) {
+ result.message = typeof data.detail === 'string' ? data.detail : JSON.stringify(data.detail)
+ }
+
+ return result
+}
+
+/**
+ * 跳转到错误页面
+ */
+function navigateToErrorPage(errorInfo) {
+ router.push({
+ name: 'Error',
+ query: {
+ code: errorInfo.code,
+ message: errorInfo.message,
+ trace_id: errorInfo.traceId,
+ status: String(errorInfo.status)
+ }
+ })
+}
+
// 请求拦截器
api.interceptors.request.use(
config => {
@@ -19,10 +66,15 @@ api.interceptors.request.use(
error => Promise.reject(error)
)
-// 响应拦截器
+// 响应拦截器(集成 TraceID 追踪)
api.interceptors.response.use(
response => response,
error => {
+ const errorInfo = parseApiError(error)
+ const traceLog = errorInfo.traceId ? ` (trace: ${errorInfo.traceId})` : ''
+
+ console.error(`[API Error] ${errorInfo.code}: ${errorInfo.message}${traceLog}`)
+
if (error.response?.status === 401) {
localStorage.removeItem('token')
localStorage.removeItem('user')
@@ -30,9 +82,14 @@ api.interceptors.response.use(
ElMessage.error('登录已过期,请重新登录')
} else if (error.response?.status === 403) {
ElMessage.error('没有权限执行此操作')
+ } else if (['INTERNAL_ERROR', 'SERVICE_UNAVAILABLE', 'GATEWAY_ERROR'].includes(errorInfo.code)) {
+ // 严重错误跳转到错误页面
+ navigateToErrorPage(errorInfo)
} else {
- ElMessage.error(error.response?.data?.detail || error.message || '请求失败')
+ // 普通错误显示消息
+ ElMessage.error(errorInfo.message)
}
+
return Promise.reject(error)
}
)
diff --git a/frontend/src/router/index.js b/frontend/src/router/index.js
index fad7e4a..bd68d40 100644
--- a/frontend/src/router/index.js
+++ b/frontend/src/router/index.js
@@ -8,6 +8,12 @@ const routes = [
component: () => import('@/views/login/index.vue'),
meta: { title: '登录', public: true }
},
+ {
+ path: '/error',
+ name: 'Error',
+ component: () => import('@/views/error/index.vue'),
+ meta: { title: '出错了', public: true }
+ },
{
path: '/',
component: () => import('@/components/Layout.vue'),
diff --git a/frontend/src/views/dashboard/index.vue b/frontend/src/views/dashboard/index.vue
index 6e49d46..fdf44e3 100644
--- a/frontend/src/views/dashboard/index.vue
+++ b/frontend/src/views/dashboard/index.vue
@@ -7,11 +7,15 @@ const stats = ref({
totalTenants: 0,
activeTenants: 0,
todayCalls: 0,
- todayTokens: 0
+ todayTokens: 0,
+ weekCalls: 0,
+ weekTokens: 0
})
const recentLogs = ref([])
+const trendData = ref([])
const chartRef = ref(null)
+const chartLoading = ref(false)
let chartInstance = null
async function fetchStats() {
@@ -25,6 +29,8 @@ async function fetchStats() {
if (statsRes.data) {
stats.value.todayCalls = statsRes.data.today_calls || 0
stats.value.todayTokens = statsRes.data.today_tokens || 0
+ stats.value.weekCalls = statsRes.data.week_calls || 0
+ stats.value.weekTokens = statsRes.data.week_tokens || 0
}
} catch (e) {
console.error('获取统计失败:', e)
@@ -40,10 +46,38 @@ async function fetchRecentLogs() {
}
}
+async function fetchTrendData() {
+ chartLoading.value = true
+ try {
+ const res = await api.get('/api/stats/trend', { params: { days: 7 } })
+ trendData.value = res.data.trend || []
+ updateChart()
+ } catch (e) {
+ console.error('获取趋势数据失败:', e)
+ // 如果API失败,使用空数据
+ trendData.value = []
+ updateChart()
+ } finally {
+ chartLoading.value = false
+ }
+}
+
function initChart() {
if (!chartRef.value) return
-
chartInstance = echarts.init(chartRef.value)
+}
+
+function updateChart() {
+ if (!chartInstance) return
+
+ // 从API数据提取日期和调用次数
+ const dates = trendData.value.map(item => {
+ // 格式化日期为 MM-DD
+ const date = new Date(item.date)
+ return `${(date.getMonth() + 1).toString().padStart(2, '0')}-${date.getDate().toString().padStart(2, '0')}`
+ })
+ const calls = trendData.value.map(item => item.calls || 0)
+ const tokens = trendData.value.map(item => item.tokens || 0)
const option = {
title: {
@@ -51,22 +85,44 @@ function initChart() {
textStyle: { fontSize: 14, fontWeight: 500 }
},
tooltip: {
- trigger: 'axis'
+ trigger: 'axis',
+ formatter: function(params) {
+ let result = params[0].axisValue + '
'
+ params.forEach(param => {
+ result += `${param.marker} ${param.seriesName}: ${param.value.toLocaleString()}
`
+ })
+ return result
+ }
+ },
+ legend: {
+ data: ['调用次数', 'Token 消耗'],
+ top: 0,
+ right: 0
},
grid: {
left: '3%',
right: '4%',
bottom: '3%',
+ top: 50,
containLabel: true
},
xAxis: {
type: 'category',
boundaryGap: false,
- data: ['周一', '周二', '周三', '周四', '周五', '周六', '周日']
- },
- yAxis: {
- type: 'value'
+ data: dates.length > 0 ? dates : ['暂无数据']
},
+ yAxis: [
+ {
+ type: 'value',
+ name: '调用次数',
+ position: 'left'
+ },
+ {
+ type: 'value',
+ name: 'Token',
+ position: 'right'
+ }
+ ],
series: [
{
name: '调用次数',
@@ -80,7 +136,16 @@ function initChart() {
},
lineStyle: { color: '#409eff' },
itemStyle: { color: '#409eff' },
- data: [120, 132, 101, 134, 90, 230, 210]
+ data: calls.length > 0 ? calls : [0]
+ },
+ {
+ name: 'Token 消耗',
+ type: 'line',
+ yAxisIndex: 1,
+ smooth: true,
+ lineStyle: { color: '#67c23a' },
+ itemStyle: { color: '#67c23a' },
+ data: tokens.length > 0 ? tokens : [0]
}
]
}
@@ -96,6 +161,7 @@ onMounted(() => {
fetchStats()
fetchRecentLogs()
initChart()
+ fetchTrendData()
window.addEventListener('resize', handleResize)
})
@@ -113,22 +179,22 @@ onUnmounted(() => {
{{ traceId }}
+ 如需技术支持,请提供此追踪码
+