diff --git a/.drone.yml b/.drone.yml index 9082340..3ec04f3 100644 --- a/.drone.yml +++ b/.drone.yml @@ -46,10 +46,11 @@ steps: CONFIG_ENCRYPT_KEY: from_secret: config_encrypt_key commands: + - docker network create platform-network 2>/dev/null || true - docker stop platform-backend-test platform-frontend-test || true - docker rm platform-backend-test platform-frontend-test || true - - docker run -d --name platform-backend-test -p 8001:8000 --restart unless-stopped -e DATABASE_URL=$DATABASE_URL -e API_KEY=$API_KEY -e JWT_SECRET=$JWT_SECRET -e CONFIG_ENCRYPT_KEY=$CONFIG_ENCRYPT_KEY platform-backend:latest - - docker run -d --name platform-frontend-test -p 3003:80 --restart unless-stopped platform-frontend:latest + - docker run -d --name platform-backend-test --network platform-network -p 8001:8000 --restart unless-stopped -e DATABASE_URL=$DATABASE_URL -e API_KEY=$API_KEY -e JWT_SECRET=$JWT_SECRET -e CONFIG_ENCRYPT_KEY=$CONFIG_ENCRYPT_KEY platform-backend:latest + - docker run -d --name platform-frontend-test --network platform-network -p 3003:80 --restart unless-stopped platform-frontend:latest when: branch: - develop diff --git a/backend/app/config.py b/backend/app/config.py index e0a91b4..6822eb8 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -2,6 +2,7 @@ import os from functools import lru_cache from pydantic_settings import BaseSettings +from typing import Optional class Settings(BaseSettings): @@ -14,6 +15,10 @@ class Settings(BaseSettings): # 数据库 DATABASE_URL: str = "mysql+pymysql://scrm_reader:ScrmReader2024Pass@47.107.71.55:3306/new_qiqi" + # Redis + REDIS_URL: str = "redis://localhost:6379/0" + REDIS_PREFIX: str = "platform:" + # API Key(内部服务调用) API_KEY: str = "platform_api_key_2026" @@ -29,6 +34,10 @@ class Settings(BaseSettings): # 配置加密密钥 CONFIG_ENCRYPT_KEY: str = "platform_config_key_32bytes!!" + # 企业微信配置 + WECHAT_ACCESS_TOKEN_EXPIRE: int = 7000 # access_token缓存时间(秒),企微有效期7200秒 + WECHAT_JSAPI_TICKET_EXPIRE: int = 7000 # jsapi_ticket缓存时间(秒) + class Config: env_file = ".env" extra = "ignore" diff --git a/backend/app/main.py b/backend/app/main.py index e024ba1..f7b094c 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,4 +1,5 @@ """平台服务入口""" +import logging from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -9,6 +10,15 @@ from .routers.tenants import router as tenants_router from .routers.tenant_apps import router as tenant_apps_router from .routers.tenant_wechat_apps import router as tenant_wechat_apps_router from .routers.apps import router as apps_router +from .routers.wechat import router as wechat_router +from .routers.alerts import router as alerts_router +from .routers.cost import router as cost_router +from .routers.quota import router as quota_router +from .middleware import TraceMiddleware, setup_exception_handlers, RequestLoggerMiddleware +from .middleware.trace import setup_logging + +# 配置日志(包含 TraceID) +setup_logging(level=logging.INFO, include_trace=True) settings = get_settings() @@ -18,6 +28,20 @@ app = FastAPI( description="平台基础设施服务 - 统计/日志/配置管理" ) +# 配置统一异常处理 +setup_exception_handlers(app) + +# 中间件按添加的反序执行,所以: +# 1. CORS 最后添加,最先执行 +# 2. TraceMiddleware 在 RequestLoggerMiddleware 之后添加,这样先执行 +# 3. RequestLoggerMiddleware 最先添加,最后执行(此时 trace_id 已设置) + +# 请求日志中间件(自动记录到数据库) +app.add_middleware(RequestLoggerMiddleware, app_code="000-platform") + +# TraceID 追踪中间件 +app.add_middleware(TraceMiddleware, log_requests=True) + # CORS app.add_middleware( CORSMiddleware, @@ -25,6 +49,7 @@ app.add_middleware( allow_credentials=True, allow_methods=["*"], allow_headers=["*"], + expose_headers=["X-Trace-ID", "X-Response-Time"] ) # 注册路由 @@ -37,6 +62,10 @@ app.include_router(apps_router, prefix="/api") app.include_router(stats_router, prefix="/api") app.include_router(logs_router, prefix="/api") app.include_router(config_router, prefix="/api") +app.include_router(wechat_router, prefix="/api") +app.include_router(alerts_router, prefix="/api") +app.include_router(cost_router, prefix="/api") +app.include_router(quota_router, prefix="/api") @app.get("/") diff --git a/backend/app/middleware/__init__.py b/backend/app/middleware/__init__.py new file mode 100644 index 0000000..2a04008 --- /dev/null +++ b/backend/app/middleware/__init__.py @@ -0,0 +1,19 @@ +""" +中间件模块 + +提供: +- TraceID 追踪 +- 统一异常处理 +- 请求日志记录 +""" +from .trace import TraceMiddleware, get_trace_id, set_trace_id +from .exception_handler import setup_exception_handlers +from .request_logger import RequestLoggerMiddleware + +__all__ = [ + "TraceMiddleware", + "get_trace_id", + "set_trace_id", + "setup_exception_handlers", + "RequestLoggerMiddleware" +] diff --git a/backend/app/middleware/exception_handler.py b/backend/app/middleware/exception_handler.py new file mode 100644 index 0000000..946efb5 --- /dev/null +++ b/backend/app/middleware/exception_handler.py @@ -0,0 +1,128 @@ +""" +统一异常处理 + +捕获所有异常,返回统一格式的错误响应,包含 TraceID。 +""" +import logging +import traceback +from typing import Union +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import JSONResponse +from fastapi.exceptions import RequestValidationError +from starlette.exceptions import HTTPException as StarletteHTTPException + +from .trace import get_trace_id + +logger = logging.getLogger(__name__) + + +class ErrorCode: + """错误码常量""" + BAD_REQUEST = "BAD_REQUEST" + UNAUTHORIZED = "UNAUTHORIZED" + FORBIDDEN = "FORBIDDEN" + NOT_FOUND = "NOT_FOUND" + VALIDATION_ERROR = "VALIDATION_ERROR" + RATE_LIMITED = "RATE_LIMITED" + INTERNAL_ERROR = "INTERNAL_ERROR" + SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE" + GATEWAY_ERROR = "GATEWAY_ERROR" + + +STATUS_TO_ERROR_CODE = { + 400: ErrorCode.BAD_REQUEST, + 401: ErrorCode.UNAUTHORIZED, + 403: ErrorCode.FORBIDDEN, + 404: ErrorCode.NOT_FOUND, + 422: ErrorCode.VALIDATION_ERROR, + 429: ErrorCode.RATE_LIMITED, + 500: ErrorCode.INTERNAL_ERROR, + 502: ErrorCode.GATEWAY_ERROR, + 503: ErrorCode.SERVICE_UNAVAILABLE, +} + + +def create_error_response( + status_code: int, + code: str, + message: str, + trace_id: str = None, + details: dict = None +) -> JSONResponse: + """创建统一格式的错误响应""" + if trace_id is None: + trace_id = get_trace_id() + + error_body = { + "code": code, + "message": message, + "trace_id": trace_id + } + + if details: + error_body["details"] = details + + return JSONResponse( + status_code=status_code, + content={"success": False, "error": error_body}, + headers={"X-Trace-ID": trace_id} + ) + + +async def http_exception_handler(request: Request, exc: Union[HTTPException, StarletteHTTPException]): + """处理 HTTP 异常""" + trace_id = get_trace_id() + status_code = exc.status_code + error_code = STATUS_TO_ERROR_CODE.get(status_code, ErrorCode.INTERNAL_ERROR) + message = exc.detail if isinstance(exc.detail, str) else str(exc.detail) + + logger.warning(f"[{trace_id}] HTTP {status_code}: {message}") + + return create_error_response( + status_code=status_code, + code=error_code, + message=message, + trace_id=trace_id + ) + + +async def validation_exception_handler(request: Request, exc: RequestValidationError): + """处理请求验证错误""" + trace_id = get_trace_id() + errors = exc.errors() + error_messages = [f"{'.'.join(str(l) for l in e['loc'])}: {e['msg']}" for e in errors] + + logger.warning(f"[{trace_id}] 验证错误: {error_messages}") + + return create_error_response( + status_code=422, + code=ErrorCode.VALIDATION_ERROR, + message="请求参数验证失败", + trace_id=trace_id, + details={"validation_errors": error_messages} + ) + + +async def generic_exception_handler(request: Request, exc: Exception): + """处理所有未捕获的异常""" + trace_id = get_trace_id() + + logger.error(f"[{trace_id}] 未捕获异常: {type(exc).__name__}: {exc}") + logger.error(f"[{trace_id}] 堆栈:\n{traceback.format_exc()}") + + return create_error_response( + status_code=500, + code=ErrorCode.INTERNAL_ERROR, + message="服务器内部错误,请稍后重试", + trace_id=trace_id + ) + + +def setup_exception_handlers(app: FastAPI): + """配置 FastAPI 应用的异常处理器""" + app.add_exception_handler(HTTPException, http_exception_handler) + app.add_exception_handler(StarletteHTTPException, http_exception_handler) + app.add_exception_handler(RequestValidationError, validation_exception_handler) + app.add_exception_handler(Exception, generic_exception_handler) + + logger.info("异常处理器已配置") diff --git a/backend/app/middleware/request_logger.py b/backend/app/middleware/request_logger.py new file mode 100644 index 0000000..e50bdc7 --- /dev/null +++ b/backend/app/middleware/request_logger.py @@ -0,0 +1,190 @@ +""" +请求日志中间件 + +自动将所有请求记录到数据库 platform_logs 表 +""" +import time +import logging +from typing import Optional, Set +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +from .trace import get_trace_id +from ..database import SessionLocal +from ..models.logs import PlatformLog + +logger = logging.getLogger(__name__) + + +class RequestLoggerMiddleware(BaseHTTPMiddleware): + """请求日志中间件 + + 自动记录所有请求到数据库,便于后续查询和分析 + + 使用示例: + app.add_middleware(RequestLoggerMiddleware, app_code="000-platform") + """ + + # 默认排除的路径(不记录这些请求) + DEFAULT_EXCLUDE_PATHS: Set[str] = { + "/", + "/docs", + "/redoc", + "/openapi.json", + "/api/health", + "/api/health/", + "/favicon.ico", + } + + def __init__( + self, + app, + app_code: str = "platform", + exclude_paths: Optional[Set[str]] = None, + log_request_body: bool = False, + log_response_body: bool = False, + max_body_length: int = 1000 + ): + """初始化中间件 + + Args: + app: FastAPI应用 + app_code: 应用代码,记录到日志中 + exclude_paths: 排除的路径集合,这些路径不记录日志 + log_request_body: 是否记录请求体 + log_response_body: 是否记录响应体 + max_body_length: 记录体的最大长度 + """ + super().__init__(app) + self.app_code = app_code + self.exclude_paths = exclude_paths or self.DEFAULT_EXCLUDE_PATHS + self.log_request_body = log_request_body + self.log_response_body = log_response_body + self.max_body_length = max_body_length + + async def dispatch(self, request: Request, call_next) -> Response: + path = request.url.path + + # 检查是否排除 + if self._should_exclude(path): + return await call_next(request) + + trace_id = get_trace_id() + method = request.method + start_time = time.time() + + # 获取客户端IP + client_ip = self._get_client_ip(request) + + # 获取租户ID(从查询参数) + tenant_id = request.query_params.get("tid") or request.query_params.get("tenant_id") + + # 执行请求 + response = None + error_message = None + status_code = 500 + + try: + response = await call_next(request) + status_code = response.status_code + except Exception as e: + error_message = str(e) + raise + finally: + duration_ms = int((time.time() - start_time) * 1000) + + # 异步写入数据库(不阻塞响应) + try: + self._save_log( + trace_id=trace_id, + method=method, + path=path, + status_code=status_code, + duration_ms=duration_ms, + ip_address=client_ip, + tenant_id=tenant_id, + error_message=error_message + ) + except Exception as e: + logger.error(f"Failed to save request log: {e}") + + return response + + def _should_exclude(self, path: str) -> bool: + """检查路径是否应排除""" + # 精确匹配 + if path in self.exclude_paths: + return True + + # 前缀匹配(静态文件等) + exclude_prefixes = ["/static/", "/assets/", "/_next/"] + for prefix in exclude_prefixes: + if path.startswith(prefix): + return True + + return False + + def _get_client_ip(self, request: Request) -> str: + """获取客户端真实IP""" + # 优先从代理头获取 + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + + real_ip = request.headers.get("X-Real-IP") + if real_ip: + return real_ip + + # 直连IP + if request.client: + return request.client.host + + return "unknown" + + def _save_log( + self, + trace_id: str, + method: str, + path: str, + status_code: int, + duration_ms: int, + ip_address: str, + tenant_id: Optional[str] = None, + error_message: Optional[str] = None + ): + """保存日志到数据库""" + from datetime import datetime + + # 使用独立的数据库会话 + db = SessionLocal() + try: + # 转换 tenant_id 为整数(如果是数字字符串) + tenant_id_int = None + if tenant_id: + try: + tenant_id_int = int(tenant_id) + except (ValueError, TypeError): + tenant_id_int = None + + log_entry = PlatformLog( + log_type="request", + level="error" if status_code >= 500 else ("warn" if status_code >= 400 else "info"), + app_code=self.app_code, + tenant_id=tenant_id_int, + trace_id=trace_id, + message=f"{method} {path}" + (f" - {error_message}" if error_message else ""), + path=path, + method=method, + status_code=status_code, + duration_ms=duration_ms, + log_time=datetime.now(), # 必须设置 log_time + context={"ip": ip_address} # ip_address 放到 context 中 + ) + db.add(log_entry) + db.commit() + except Exception as e: + logger.error(f"Database error saving log: {e}") + db.rollback() + finally: + db.close() diff --git a/backend/app/middleware/trace.py b/backend/app/middleware/trace.py new file mode 100644 index 0000000..4bd02fa --- /dev/null +++ b/backend/app/middleware/trace.py @@ -0,0 +1,114 @@ +""" +TraceID 追踪中间件 + +为每个请求生成唯一的 TraceID,用于日志追踪和问题排查。 + +功能: +- 自动生成 TraceID(或从请求头获取) +- 注入到响应头 X-Trace-ID +- 提供上下文变量供日志使用 +- 支持请求耗时统计 +""" +import time +import uuid +import logging +from contextvars import ContextVar +from typing import Optional +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +logger = logging.getLogger(__name__) + +# 上下文变量存储当前请求的 TraceID +_trace_id_var: ContextVar[Optional[str]] = ContextVar("trace_id", default=None) + +# 请求头名称 +TRACE_ID_HEADER = "X-Trace-ID" +REQUEST_ID_HEADER = "X-Request-ID" + + +def get_trace_id() -> str: + """获取当前请求的 TraceID""" + trace_id = _trace_id_var.get() + return trace_id if trace_id else "no-trace" + + +def set_trace_id(trace_id: str) -> None: + """设置当前请求的 TraceID""" + _trace_id_var.set(trace_id) + + +def generate_trace_id() -> str: + """生成新的 TraceID,格式: 时间戳-随机8位""" + timestamp = int(time.time()) + random_part = uuid.uuid4().hex[:8] + return f"{timestamp}-{random_part}" + + +class TraceMiddleware(BaseHTTPMiddleware): + """TraceID 追踪中间件""" + + def __init__(self, app, log_requests: bool = True): + super().__init__(app) + self.log_requests = log_requests + + async def dispatch(self, request: Request, call_next) -> Response: + # 从请求头获取 TraceID,或生成新的 + trace_id = ( + request.headers.get(TRACE_ID_HEADER) or + request.headers.get(REQUEST_ID_HEADER) or + generate_trace_id() + ) + + set_trace_id(trace_id) + + start_time = time.time() + method = request.method + path = request.url.path + + if self.log_requests: + logger.info(f"[{trace_id}] --> {method} {path}") + + try: + response = await call_next(request) + duration_ms = int((time.time() - start_time) * 1000) + + response.headers[TRACE_ID_HEADER] = trace_id + response.headers["X-Response-Time"] = f"{duration_ms}ms" + + if self.log_requests: + logger.info(f"[{trace_id}] <-- {response.status_code} ({duration_ms}ms)") + + return response + + except Exception as e: + duration_ms = int((time.time() - start_time) * 1000) + logger.error(f"[{trace_id}] !!! 请求异常: {e} ({duration_ms}ms)") + raise + + +class TraceLogFilter(logging.Filter): + """日志过滤器:自动添加 TraceID""" + + def filter(self, record): + record.trace_id = get_trace_id() + return True + + +def setup_logging(level: int = logging.INFO, include_trace: bool = True): + """配置日志格式""" + if include_trace: + format_str = "%(asctime)s [%(trace_id)s] %(levelname)s %(name)s: %(message)s" + else: + format_str = "%(asctime)s %(levelname)s %(name)s: %(message)s" + + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter(format_str, datefmt="%Y-%m-%d %H:%M:%S")) + + if include_trace: + handler.addFilter(TraceLogFilter()) + + root_logger = logging.getLogger() + root_logger.setLevel(level) + root_logger.handlers = [handler] diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index f610327..63c7c17 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -5,6 +5,8 @@ from .tenant_wechat_app import TenantWechatApp from .app import App from .stats import AICallEvent, TenantUsageDaily from .logs import PlatformLog +from .alert import AlertRule, AlertRecord, NotificationChannel +from .pricing import ModelPricing, TenantBilling __all__ = [ "Tenant", @@ -15,5 +17,10 @@ __all__ = [ "App", "AICallEvent", "TenantUsageDaily", - "PlatformLog" + "PlatformLog", + "AlertRule", + "AlertRecord", + "NotificationChannel", + "ModelPricing", + "TenantBilling" ] diff --git a/backend/app/models/alert.py b/backend/app/models/alert.py new file mode 100644 index 0000000..48e44de --- /dev/null +++ b/backend/app/models/alert.py @@ -0,0 +1,108 @@ +"""告警相关模型""" +from datetime import datetime +from sqlalchemy import Column, Integer, BigInteger, String, Text, Enum, SmallInteger, JSON, TIMESTAMP +from ..database import Base + + +class AlertRule(Base): + """告警规则表""" + __tablename__ = "platform_alert_rules" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(100), nullable=False) # 规则名称 + description = Column(Text) # 规则描述 + + # 规则类型 + rule_type = Column(Enum( + 'error_rate', # 错误率告警 + 'call_count', # 调用次数告警 + 'token_usage', # Token使用量告警 + 'cost_threshold', # 费用阈值告警 + 'latency', # 延迟告警 + 'custom' # 自定义告警 + ), nullable=False) + + # 作用范围 + scope_type = Column(Enum('global', 'tenant', 'app'), default='global') # 作用范围类型 + scope_value = Column(String(100)) # 作用范围值,如租户ID或应用代码 + + # 告警条件 + condition = Column(JSON, nullable=False) # 告警条件配置 + # 示例: {"metric": "error_count", "operator": ">", "threshold": 10, "window": "5m"} + + # 通知配置 + notification_channels = Column(JSON) # 通知渠道列表 + # 示例: [{"type": "wechat_bot", "webhook": "https://..."}, {"type": "email", "to": ["a@b.com"]}] + + # 告警限制 + cooldown_minutes = Column(Integer, default=30) # 冷却时间(分钟),避免重复告警 + max_alerts_per_day = Column(Integer, default=10) # 每天最大告警次数 + + # 状态 + status = Column(SmallInteger, default=1) # 0-禁用 1-启用 + priority = Column(Enum('low', 'medium', 'high', 'critical'), default='medium') # 优先级 + + created_at = Column(TIMESTAMP, default=datetime.now) + updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now) + + +class AlertRecord(Base): + """告警记录表""" + __tablename__ = "platform_alert_records" + + id = Column(BigInteger, primary_key=True, autoincrement=True) + rule_id = Column(Integer, nullable=False, index=True) # 关联的规则ID + rule_name = Column(String(100)) # 规则名称(冗余,便于查询) + + # 告警信息 + alert_type = Column(String(50), nullable=False) # 告警类型 + severity = Column(Enum('info', 'warning', 'error', 'critical'), default='warning') # 严重程度 + title = Column(String(200), nullable=False) # 告警标题 + message = Column(Text) # 告警详情 + + # 上下文 + tenant_id = Column(String(50), index=True) # 相关租户 + app_code = Column(String(50)) # 相关应用 + metric_value = Column(String(100)) # 触发告警的指标值 + threshold_value = Column(String(100)) # 阈值 + + # 通知状态 + notification_status = Column(Enum('pending', 'sent', 'failed', 'skipped'), default='pending') + notification_result = Column(JSON) # 通知结果 + notified_at = Column(TIMESTAMP) # 通知时间 + + # 处理状态 + status = Column(Enum('active', 'acknowledged', 'resolved', 'ignored'), default='active') + acknowledged_by = Column(String(100)) # 确认人 + acknowledged_at = Column(TIMESTAMP) # 确认时间 + resolved_at = Column(TIMESTAMP) # 解决时间 + + created_at = Column(TIMESTAMP, default=datetime.now, index=True) + + +class NotificationChannel(Base): + """通知渠道配置表""" + __tablename__ = "platform_notification_channels" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(100), nullable=False) # 渠道名称 + + channel_type = Column(Enum( + 'wechat_bot', # 企微机器人 + 'email', # 邮件 + 'sms', # 短信 + 'webhook', # Webhook + 'dingtalk' # 钉钉 + ), nullable=False) + + # 渠道配置 + config = Column(JSON, nullable=False) + # wechat_bot: {"webhook": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx"} + # email: {"smtp_host": "...", "smtp_port": 465, "username": "...", "password_encrypted": "..."} + # webhook: {"url": "https://...", "method": "POST", "headers": {...}} + + # 状态 + status = Column(SmallInteger, default=1) # 0-禁用 1-启用 + + created_at = Column(TIMESTAMP, default=datetime.now) + updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now) diff --git a/backend/app/models/pricing.py b/backend/app/models/pricing.py new file mode 100644 index 0000000..683eee6 --- /dev/null +++ b/backend/app/models/pricing.py @@ -0,0 +1,70 @@ +"""费用计算相关模型""" +from datetime import datetime +from decimal import Decimal +from sqlalchemy import Column, Integer, String, Text, DECIMAL, SmallInteger, JSON, TIMESTAMP +from ..database import Base + + +class ModelPricing(Base): + """模型价格配置表""" + __tablename__ = "platform_model_pricing" + + id = Column(Integer, primary_key=True, autoincrement=True) + + # 模型标识 + model_name = Column(String(100), nullable=False, unique=True) # 模型名称,如 gpt-4, claude-3-opus + provider = Column(String(50)) # 提供商,如 openai, anthropic, 4sapi + display_name = Column(String(100)) # 显示名称 + + # 价格配置(单位:元/1K tokens) + input_price_per_1k = Column(DECIMAL(10, 6), default=0) # 输入价格 + output_price_per_1k = Column(DECIMAL(10, 6), default=0) # 输出价格 + + # 或固定价格(每次调用) + fixed_price_per_call = Column(DECIMAL(10, 6), default=0) + + # 计费方式 + pricing_type = Column(String(20), default='token') # token / call / hybrid + + # 备注 + description = Column(Text) + + status = Column(SmallInteger, default=1) # 0-禁用 1-启用 + created_at = Column(TIMESTAMP, default=datetime.now) + updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now) + + +class TenantBilling(Base): + """租户账单表(月度汇总)""" + __tablename__ = "platform_tenant_billing" + + id = Column(Integer, primary_key=True, autoincrement=True) + + tenant_id = Column(String(50), nullable=False, index=True) + billing_month = Column(String(7), nullable=False) # 格式: YYYY-MM + + # 使用量统计 + total_calls = Column(Integer, default=0) # 总调用次数 + total_input_tokens = Column(Integer, default=0) # 总输入token + total_output_tokens = Column(Integer, default=0) # 总输出token + + # 费用统计 + total_cost = Column(DECIMAL(12, 4), default=0) # 总费用 + + # 按模型分类的费用明细 + cost_by_model = Column(JSON) # {"gpt-4": 10.5, "claude-3": 5.2} + + # 按应用分类的费用明细 + cost_by_app = Column(JSON) # {"tools": 8.0, "interview": 7.7} + + # 状态 + status = Column(String(20), default='pending') # pending / confirmed / paid + + created_at = Column(TIMESTAMP, default=datetime.now) + updated_at = Column(TIMESTAMP, default=datetime.now, onupdate=datetime.now) + + class Config: + # 联合唯一索引 + __table_args__ = ( + {'mysql_charset': 'utf8mb4'} + ) diff --git a/backend/app/routers/alerts.py b/backend/app/routers/alerts.py new file mode 100644 index 0000000..8a3b91e --- /dev/null +++ b/backend/app/routers/alerts.py @@ -0,0 +1,430 @@ +"""告警管理路由""" +from typing import Optional, List +from datetime import datetime, timedelta +from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks +from pydantic import BaseModel +from sqlalchemy.orm import Session +from sqlalchemy import desc, func + +from ..database import get_db +from ..models.alert import AlertRule, AlertRecord, NotificationChannel +from ..services.alert import AlertService +from .auth import get_current_user, require_operator +from ..models.user import User + +router = APIRouter(prefix="/alerts", tags=["告警管理"]) + + +# ============= Schemas ============= + +class AlertRuleCreate(BaseModel): + name: str + description: Optional[str] = None + rule_type: str + scope_type: str = "global" + scope_value: Optional[str] = None + condition: dict + notification_channels: Optional[List[dict]] = None + cooldown_minutes: int = 30 + max_alerts_per_day: int = 10 + priority: str = "medium" + + +class AlertRuleUpdate(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + condition: Optional[dict] = None + notification_channels: Optional[List[dict]] = None + cooldown_minutes: Optional[int] = None + max_alerts_per_day: Optional[int] = None + priority: Optional[str] = None + status: Optional[int] = None + + +class NotificationChannelCreate(BaseModel): + name: str + channel_type: str + config: dict + + +class NotificationChannelUpdate(BaseModel): + name: Optional[str] = None + config: Optional[dict] = None + status: Optional[int] = None + + +# ============= Alert Rules API ============= + +@router.get("/rules") +async def list_alert_rules( + page: int = Query(1, ge=1), + size: int = Query(20, ge=1, le=100), + rule_type: Optional[str] = None, + status: Optional[int] = None, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取告警规则列表""" + query = db.query(AlertRule) + + if rule_type: + query = query.filter(AlertRule.rule_type == rule_type) + if status is not None: + query = query.filter(AlertRule.status == status) + + total = query.count() + rules = query.order_by(desc(AlertRule.created_at)).offset((page - 1) * size).limit(size).all() + + return { + "total": total, + "page": page, + "size": size, + "items": [format_rule(r) for r in rules] + } + + +@router.get("/rules/{rule_id}") +async def get_alert_rule( + rule_id: int, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取告警规则详情""" + rule = db.query(AlertRule).filter(AlertRule.id == rule_id).first() + if not rule: + raise HTTPException(status_code=404, detail="告警规则不存在") + return format_rule(rule) + + +@router.post("/rules") +async def create_alert_rule( + data: AlertRuleCreate, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """创建告警规则""" + rule = AlertRule( + name=data.name, + description=data.description, + rule_type=data.rule_type, + scope_type=data.scope_type, + scope_value=data.scope_value, + condition=data.condition, + notification_channels=data.notification_channels, + cooldown_minutes=data.cooldown_minutes, + max_alerts_per_day=data.max_alerts_per_day, + priority=data.priority, + status=1 + ) + db.add(rule) + db.commit() + db.refresh(rule) + + return {"success": True, "id": rule.id} + + +@router.put("/rules/{rule_id}") +async def update_alert_rule( + rule_id: int, + data: AlertRuleUpdate, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """更新告警规则""" + rule = db.query(AlertRule).filter(AlertRule.id == rule_id).first() + if not rule: + raise HTTPException(status_code=404, detail="告警规则不存在") + + update_data = data.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(rule, key, value) + + db.commit() + return {"success": True} + + +@router.delete("/rules/{rule_id}") +async def delete_alert_rule( + rule_id: int, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """删除告警规则""" + rule = db.query(AlertRule).filter(AlertRule.id == rule_id).first() + if not rule: + raise HTTPException(status_code=404, detail="告警规则不存在") + + db.delete(rule) + db.commit() + return {"success": True} + + +# ============= Alert Records API ============= + +@router.get("/records") +async def list_alert_records( + page: int = Query(1, ge=1), + size: int = Query(20, ge=1, le=100), + status: Optional[str] = None, + severity: Optional[str] = None, + alert_type: Optional[str] = None, + tenant_id: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取告警记录列表""" + query = db.query(AlertRecord) + + if status: + query = query.filter(AlertRecord.status == status) + if severity: + query = query.filter(AlertRecord.severity == severity) + if alert_type: + query = query.filter(AlertRecord.alert_type == alert_type) + if tenant_id: + query = query.filter(AlertRecord.tenant_id == tenant_id) + if start_date: + query = query.filter(AlertRecord.created_at >= start_date) + if end_date: + query = query.filter(AlertRecord.created_at <= end_date + " 23:59:59") + + total = query.count() + records = query.order_by(desc(AlertRecord.created_at)).offset((page - 1) * size).limit(size).all() + + return { + "total": total, + "page": page, + "size": size, + "items": [format_record(r) for r in records] + } + + +@router.get("/records/summary") +async def get_alert_summary( + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取告警摘要统计""" + today = datetime.now().date() + week_start = today - timedelta(days=7) + + # 今日告警数 + today_count = db.query(func.count(AlertRecord.id)).filter( + func.date(AlertRecord.created_at) == today + ).scalar() + + # 本周告警数 + week_count = db.query(func.count(AlertRecord.id)).filter( + func.date(AlertRecord.created_at) >= week_start + ).scalar() + + # 活跃告警数 + active_count = db.query(func.count(AlertRecord.id)).filter( + AlertRecord.status == 'active' + ).scalar() + + # 按严重程度统计 + severity_stats = db.query( + AlertRecord.severity, + func.count(AlertRecord.id) + ).filter( + func.date(AlertRecord.created_at) >= week_start + ).group_by(AlertRecord.severity).all() + + return { + "today_count": today_count, + "week_count": week_count, + "active_count": active_count, + "by_severity": {s: c for s, c in severity_stats} + } + + +@router.get("/records/{record_id}") +async def get_alert_record( + record_id: int, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取告警记录详情""" + record = db.query(AlertRecord).filter(AlertRecord.id == record_id).first() + if not record: + raise HTTPException(status_code=404, detail="告警记录不存在") + return format_record(record) + + +@router.post("/records/{record_id}/acknowledge") +async def acknowledge_alert( + record_id: int, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """确认告警""" + service = AlertService(db) + record = service.acknowledge_alert(record_id, user.username) + if not record: + raise HTTPException(status_code=404, detail="告警记录不存在") + return {"success": True} + + +@router.post("/records/{record_id}/resolve") +async def resolve_alert( + record_id: int, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """解决告警""" + service = AlertService(db) + record = service.resolve_alert(record_id) + if not record: + raise HTTPException(status_code=404, detail="告警记录不存在") + return {"success": True} + + +# ============= Check Alerts API ============= + +@router.post("/check") +async def trigger_alert_check( + background_tasks: BackgroundTasks, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """手动触发告警检查""" + service = AlertService(db) + alerts = await service.check_all_rules() + + # 异步发送通知 + for alert in alerts: + rule = db.query(AlertRule).filter(AlertRule.id == alert.rule_id).first() + if rule: + background_tasks.add_task(service.send_notification, alert, rule) + + return { + "success": True, + "triggered_count": len(alerts), + "alerts": [format_record(a) for a in alerts] + } + + +# ============= Notification Channels API ============= + +@router.get("/channels") +async def list_notification_channels( + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取通知渠道列表""" + channels = db.query(NotificationChannel).order_by(desc(NotificationChannel.created_at)).all() + return [format_channel(c) for c in channels] + + +@router.post("/channels") +async def create_notification_channel( + data: NotificationChannelCreate, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """创建通知渠道""" + channel = NotificationChannel( + name=data.name, + channel_type=data.channel_type, + config=data.config, + status=1 + ) + db.add(channel) + db.commit() + db.refresh(channel) + + return {"success": True, "id": channel.id} + + +@router.put("/channels/{channel_id}") +async def update_notification_channel( + channel_id: int, + data: NotificationChannelUpdate, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """更新通知渠道""" + channel = db.query(NotificationChannel).filter(NotificationChannel.id == channel_id).first() + if not channel: + raise HTTPException(status_code=404, detail="通知渠道不存在") + + update_data = data.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(channel, key, value) + + db.commit() + return {"success": True} + + +@router.delete("/channels/{channel_id}") +async def delete_notification_channel( + channel_id: int, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """删除通知渠道""" + channel = db.query(NotificationChannel).filter(NotificationChannel.id == channel_id).first() + if not channel: + raise HTTPException(status_code=404, detail="通知渠道不存在") + + db.delete(channel) + db.commit() + return {"success": True} + + +# ============= Helper Functions ============= + +def format_rule(rule: AlertRule) -> dict: + return { + "id": rule.id, + "name": rule.name, + "description": rule.description, + "rule_type": rule.rule_type, + "scope_type": rule.scope_type, + "scope_value": rule.scope_value, + "condition": rule.condition, + "notification_channels": rule.notification_channels, + "cooldown_minutes": rule.cooldown_minutes, + "max_alerts_per_day": rule.max_alerts_per_day, + "priority": rule.priority, + "status": rule.status, + "created_at": rule.created_at, + "updated_at": rule.updated_at + } + + +def format_record(record: AlertRecord) -> dict: + return { + "id": record.id, + "rule_id": record.rule_id, + "rule_name": record.rule_name, + "alert_type": record.alert_type, + "severity": record.severity, + "title": record.title, + "message": record.message, + "tenant_id": record.tenant_id, + "app_code": record.app_code, + "metric_value": record.metric_value, + "threshold_value": record.threshold_value, + "notification_status": record.notification_status, + "status": record.status, + "acknowledged_by": record.acknowledged_by, + "acknowledged_at": record.acknowledged_at, + "resolved_at": record.resolved_at, + "created_at": record.created_at + } + + +def format_channel(channel: NotificationChannel) -> dict: + return { + "id": channel.id, + "name": channel.name, + "channel_type": channel.channel_type, + "config": channel.config, + "status": channel.status, + "created_at": channel.created_at, + "updated_at": channel.updated_at + } diff --git a/backend/app/routers/cost.py b/backend/app/routers/cost.py new file mode 100644 index 0000000..239f9e4 --- /dev/null +++ b/backend/app/routers/cost.py @@ -0,0 +1,333 @@ +"""费用管理路由""" +from typing import Optional, List +from decimal import Decimal +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel +from sqlalchemy.orm import Session +from sqlalchemy import desc + +from ..database import get_db +from ..models.pricing import ModelPricing, TenantBilling +from ..services.cost import CostCalculator +from .auth import get_current_user, require_operator +from ..models.user import User + +router = APIRouter(prefix="/cost", tags=["费用管理"]) + + +# ============= Schemas ============= + +class ModelPricingCreate(BaseModel): + model_name: str + provider: Optional[str] = None + display_name: Optional[str] = None + input_price_per_1k: float = 0 + output_price_per_1k: float = 0 + fixed_price_per_call: float = 0 + pricing_type: str = "token" + description: Optional[str] = None + + +class ModelPricingUpdate(BaseModel): + provider: Optional[str] = None + display_name: Optional[str] = None + input_price_per_1k: Optional[float] = None + output_price_per_1k: Optional[float] = None + fixed_price_per_call: Optional[float] = None + pricing_type: Optional[str] = None + description: Optional[str] = None + status: Optional[int] = None + + +class CostCalculateRequest(BaseModel): + model_name: str + input_tokens: int = 0 + output_tokens: int = 0 + + +# ============= Model Pricing API ============= + +@router.get("/pricing") +async def list_model_pricing( + page: int = Query(1, ge=1), + size: int = Query(50, ge=1, le=100), + provider: Optional[str] = None, + status: Optional[int] = None, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取模型价格配置列表""" + query = db.query(ModelPricing) + + if provider: + query = query.filter(ModelPricing.provider == provider) + if status is not None: + query = query.filter(ModelPricing.status == status) + + total = query.count() + items = query.order_by(ModelPricing.model_name).offset((page - 1) * size).limit(size).all() + + return { + "total": total, + "page": page, + "size": size, + "items": [format_pricing(p) for p in items] + } + + +@router.get("/pricing/{pricing_id}") +async def get_model_pricing( + pricing_id: int, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取模型价格详情""" + pricing = db.query(ModelPricing).filter(ModelPricing.id == pricing_id).first() + if not pricing: + raise HTTPException(status_code=404, detail="模型价格配置不存在") + return format_pricing(pricing) + + +@router.post("/pricing") +async def create_model_pricing( + data: ModelPricingCreate, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """创建模型价格配置""" + # 检查是否已存在 + existing = db.query(ModelPricing).filter(ModelPricing.model_name == data.model_name).first() + if existing: + raise HTTPException(status_code=400, detail="该模型价格配置已存在") + + pricing = ModelPricing( + model_name=data.model_name, + provider=data.provider, + display_name=data.display_name, + input_price_per_1k=Decimal(str(data.input_price_per_1k)), + output_price_per_1k=Decimal(str(data.output_price_per_1k)), + fixed_price_per_call=Decimal(str(data.fixed_price_per_call)), + pricing_type=data.pricing_type, + description=data.description, + status=1 + ) + db.add(pricing) + db.commit() + db.refresh(pricing) + + return {"success": True, "id": pricing.id} + + +@router.put("/pricing/{pricing_id}") +async def update_model_pricing( + pricing_id: int, + data: ModelPricingUpdate, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """更新模型价格配置""" + pricing = db.query(ModelPricing).filter(ModelPricing.id == pricing_id).first() + if not pricing: + raise HTTPException(status_code=404, detail="模型价格配置不存在") + + update_data = data.model_dump(exclude_unset=True) + + # 转换价格字段 + for field in ['input_price_per_1k', 'output_price_per_1k', 'fixed_price_per_call']: + if field in update_data and update_data[field] is not None: + update_data[field] = Decimal(str(update_data[field])) + + for key, value in update_data.items(): + setattr(pricing, key, value) + + db.commit() + return {"success": True} + + +@router.delete("/pricing/{pricing_id}") +async def delete_model_pricing( + pricing_id: int, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """删除模型价格配置""" + pricing = db.query(ModelPricing).filter(ModelPricing.id == pricing_id).first() + if not pricing: + raise HTTPException(status_code=404, detail="模型价格配置不存在") + + db.delete(pricing) + db.commit() + return {"success": True} + + +# ============= Cost Calculation API ============= + +@router.post("/calculate") +async def calculate_cost( + request: CostCalculateRequest, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """计算调用费用""" + calculator = CostCalculator(db) + cost = calculator.calculate_cost( + model_name=request.model_name, + input_tokens=request.input_tokens, + output_tokens=request.output_tokens + ) + + return { + "model": request.model_name, + "input_tokens": request.input_tokens, + "output_tokens": request.output_tokens, + "cost": float(cost) + } + + +@router.get("/summary") +async def get_cost_summary( + tenant_id: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取费用汇总""" + calculator = CostCalculator(db) + return calculator.get_cost_summary( + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date + ) + + +@router.get("/by-tenant") +async def get_cost_by_tenant( + start_date: Optional[str] = None, + end_date: Optional[str] = None, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """按租户统计费用""" + calculator = CostCalculator(db) + return calculator.get_cost_by_tenant( + start_date=start_date, + end_date=end_date + ) + + +@router.get("/by-model") +async def get_cost_by_model( + tenant_id: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """按模型统计费用""" + calculator = CostCalculator(db) + return calculator.get_cost_by_model( + tenant_id=tenant_id, + start_date=start_date, + end_date=end_date + ) + + +# ============= Billing API ============= + +@router.get("/billing") +async def list_billing( + page: int = Query(1, ge=1), + size: int = Query(20, ge=1, le=100), + tenant_id: Optional[str] = None, + billing_month: Optional[str] = None, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取账单列表""" + query = db.query(TenantBilling) + + if tenant_id: + query = query.filter(TenantBilling.tenant_id == tenant_id) + if billing_month: + query = query.filter(TenantBilling.billing_month == billing_month) + + total = query.count() + items = query.order_by(desc(TenantBilling.billing_month)).offset((page - 1) * size).limit(size).all() + + return { + "total": total, + "page": page, + "size": size, + "items": [format_billing(b) for b in items] + } + + +@router.post("/billing/generate") +async def generate_billing( + tenant_id: str = Query(...), + billing_month: str = Query(..., description="格式: YYYY-MM"), + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """生成月度账单""" + calculator = CostCalculator(db) + billing = calculator.generate_monthly_billing(tenant_id, billing_month) + + return { + "success": True, + "billing": format_billing(billing) + } + + +@router.post("/recalculate") +async def recalculate_costs( + start_date: Optional[str] = None, + end_date: Optional[str] = None, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """重新计算事件费用""" + calculator = CostCalculator(db) + updated = calculator.update_event_costs(start_date, end_date) + + return { + "success": True, + "updated_count": updated + } + + +# ============= Helper Functions ============= + +def format_pricing(pricing: ModelPricing) -> dict: + return { + "id": pricing.id, + "model_name": pricing.model_name, + "provider": pricing.provider, + "display_name": pricing.display_name, + "input_price_per_1k": float(pricing.input_price_per_1k or 0), + "output_price_per_1k": float(pricing.output_price_per_1k or 0), + "fixed_price_per_call": float(pricing.fixed_price_per_call or 0), + "pricing_type": pricing.pricing_type, + "description": pricing.description, + "status": pricing.status, + "created_at": pricing.created_at, + "updated_at": pricing.updated_at + } + + +def format_billing(billing: TenantBilling) -> dict: + return { + "id": billing.id, + "tenant_id": billing.tenant_id, + "billing_month": billing.billing_month, + "total_calls": billing.total_calls, + "total_input_tokens": billing.total_input_tokens, + "total_output_tokens": billing.total_output_tokens, + "total_cost": float(billing.total_cost or 0), + "cost_by_model": billing.cost_by_model, + "cost_by_app": billing.cost_by_app, + "status": billing.status, + "created_at": billing.created_at, + "updated_at": billing.updated_at + } diff --git a/backend/app/routers/logs.py b/backend/app/routers/logs.py index d64458c..d2da06a 100644 --- a/backend/app/routers/logs.py +++ b/backend/app/routers/logs.py @@ -1,6 +1,10 @@ """日志路由""" +import csv +import io from typing import Optional +from datetime import datetime from fastapi import APIRouter, Depends, Header, HTTPException, Query +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from sqlalchemy import desc @@ -13,6 +17,14 @@ from ..services.auth import decode_token router = APIRouter(prefix="/logs", tags=["logs"]) settings = get_settings() +# 尝试导入openpyxl +try: + from openpyxl import Workbook + from openpyxl.styles import Font, Alignment, PatternFill + OPENPYXL_AVAILABLE = True +except ImportError: + OPENPYXL_AVAILABLE = False + def get_current_user_optional(authorization: Optional[str] = Header(None)): """可选的用户认证""" @@ -113,3 +125,154 @@ async def query_logs( for item in items ] } + + +@router.get("/export") +async def export_logs( + format: str = Query("csv", description="导出格式: csv 或 excel"), + log_type: Optional[str] = None, + level: Optional[str] = None, + app_code: Optional[str] = None, + tenant_id: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + limit: int = Query(10000, ge=1, le=100000, description="最大导出记录数"), + db: Session = Depends(get_db), + user = Depends(get_current_user_optional) +): + """导出日志 + + 支持CSV和Excel格式,最多导出10万条记录 + """ + query = db.query(PlatformLog) + + if log_type: + query = query.filter(PlatformLog.log_type == log_type) + if level: + query = query.filter(PlatformLog.level == level) + if app_code: + query = query.filter(PlatformLog.app_code == app_code) + if tenant_id: + query = query.filter(PlatformLog.tenant_id == tenant_id) + if start_date: + query = query.filter(PlatformLog.log_time >= start_date) + if end_date: + query = query.filter(PlatformLog.log_time <= end_date + " 23:59:59") + + items = query.order_by(desc(PlatformLog.log_time)).limit(limit).all() + + if format.lower() == "excel": + return export_excel(items) + else: + return export_csv(items) + + +def export_csv(logs: list) -> StreamingResponse: + """导出为CSV格式""" + output = io.StringIO() + writer = csv.writer(output) + + # 写入表头 + headers = [ + "ID", "类型", "级别", "应用", "租户", "Trace ID", + "消息", "路径", "方法", "状态码", "耗时(ms)", + "IP地址", "时间" + ] + writer.writerow(headers) + + # 写入数据 + for log in logs: + writer.writerow([ + log.id, + log.log_type, + log.level, + log.app_code or "", + log.tenant_id or "", + log.trace_id or "", + log.message or "", + log.path or "", + log.method or "", + log.status_code or "", + log.duration_ms or "", + log.ip_address or "", + str(log.log_time) if log.log_time else "" + ]) + + output.seek(0) + + # 生成文件名 + filename = f"logs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" + + return StreamingResponse( + iter([output.getvalue()]), + media_type="text/csv", + headers={ + "Content-Disposition": f'attachment; filename="{filename}"', + "Content-Type": "text/csv; charset=utf-8-sig" + } + ) + + +def export_excel(logs: list) -> StreamingResponse: + """导出为Excel格式""" + if not OPENPYXL_AVAILABLE: + raise HTTPException(status_code=400, detail="Excel导出功能不可用,请安装openpyxl") + + wb = Workbook() + ws = wb.active + ws.title = "日志导出" + + # 表头样式 + header_font = Font(bold=True, color="FFFFFF") + header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid") + header_alignment = Alignment(horizontal="center", vertical="center") + + # 写入表头 + headers = [ + "ID", "类型", "级别", "应用", "租户", "Trace ID", + "消息", "路径", "方法", "状态码", "耗时(ms)", + "IP地址", "时间" + ] + + for col, header in enumerate(headers, 1): + cell = ws.cell(row=1, column=col, value=header) + cell.font = header_font + cell.fill = header_fill + cell.alignment = header_alignment + + # 写入数据 + for row, log in enumerate(logs, 2): + ws.cell(row=row, column=1, value=log.id) + ws.cell(row=row, column=2, value=log.log_type) + ws.cell(row=row, column=3, value=log.level) + ws.cell(row=row, column=4, value=log.app_code or "") + ws.cell(row=row, column=5, value=log.tenant_id or "") + ws.cell(row=row, column=6, value=log.trace_id or "") + ws.cell(row=row, column=7, value=log.message or "") + ws.cell(row=row, column=8, value=log.path or "") + ws.cell(row=row, column=9, value=log.method or "") + ws.cell(row=row, column=10, value=log.status_code or "") + ws.cell(row=row, column=11, value=log.duration_ms or "") + ws.cell(row=row, column=12, value=log.ip_address or "") + ws.cell(row=row, column=13, value=str(log.log_time) if log.log_time else "") + + # 调整列宽 + column_widths = [8, 10, 10, 12, 12, 36, 50, 30, 8, 10, 10, 15, 20] + for col, width in enumerate(column_widths, 1): + ws.column_dimensions[chr(64 + col)].width = width + + # 保存到内存 + output = io.BytesIO() + wb.save(output) + output.seek(0) + + # 生成文件名 + filename = f"logs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.xlsx" + + return StreamingResponse( + iter([output.getvalue()]), + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={ + "Content-Disposition": f'attachment; filename="{filename}"' + } + ) diff --git a/backend/app/routers/quota.py b/backend/app/routers/quota.py new file mode 100644 index 0000000..3be7838 --- /dev/null +++ b/backend/app/routers/quota.py @@ -0,0 +1,264 @@ +"""配额管理路由""" +from typing import Optional, Dict, Any +from datetime import date +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel +from sqlalchemy.orm import Session +from sqlalchemy import desc + +from ..database import get_db +from ..models.tenant import Subscription +from ..services.quota import QuotaService +from .auth import get_current_user, require_operator +from ..models.user import User + +router = APIRouter(prefix="/quota", tags=["配额管理"]) + + +# ============= Schemas ============= + +class QuotaConfigUpdate(BaseModel): + daily_calls: int = 0 + daily_tokens: int = 0 + monthly_calls: int = 0 + monthly_tokens: int = 0 + monthly_cost: float = 0 + concurrent_calls: int = 0 + + +class SubscriptionCreate(BaseModel): + tenant_id: str + app_code: str + start_date: Optional[str] = None + end_date: Optional[str] = None + quota: QuotaConfigUpdate + + +class SubscriptionUpdate(BaseModel): + start_date: Optional[str] = None + end_date: Optional[str] = None + quota: Optional[QuotaConfigUpdate] = None + status: Optional[str] = None + + +# ============= Quota Check API ============= + +@router.get("/check") +async def check_quota( + tenant_id: str = Query(..., alias="tid"), + app_code: str = Query(..., alias="aid"), + estimated_tokens: int = Query(0), + db: Session = Depends(get_db) +): + """检查配额是否足够 + + 用于调用前检查,返回是否允许继续调用 + """ + service = QuotaService(db) + result = service.check_quota(tenant_id, app_code, estimated_tokens) + + return { + "allowed": result.allowed, + "reason": result.reason, + "quota_type": result.quota_type, + "limit": result.limit, + "used": result.used, + "remaining": result.remaining + } + + +@router.get("/summary") +async def get_quota_summary( + tenant_id: str = Query(...), + app_code: str = Query(...), + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取配额使用汇总""" + service = QuotaService(db) + return service.get_quota_summary(tenant_id, app_code) + + +@router.get("/usage") +async def get_quota_usage( + tenant_id: str = Query(...), + app_code: str = Query(...), + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取配额使用情况""" + service = QuotaService(db) + usage = service.get_usage(tenant_id, app_code) + + return { + "daily_calls": usage.daily_calls, + "daily_tokens": usage.daily_tokens, + "monthly_calls": usage.monthly_calls, + "monthly_tokens": usage.monthly_tokens, + "monthly_cost": round(usage.monthly_cost, 2) + } + + +# ============= Subscription API ============= + +@router.get("/subscriptions") +async def list_subscriptions( + page: int = Query(1, ge=1), + size: int = Query(20, ge=1, le=100), + tenant_id: Optional[str] = None, + app_code: Optional[str] = None, + status: Optional[str] = None, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取订阅列表""" + query = db.query(Subscription) + + if tenant_id: + query = query.filter(Subscription.tenant_id == tenant_id) + if app_code: + query = query.filter(Subscription.app_code == app_code) + if status: + query = query.filter(Subscription.status == status) + + total = query.count() + items = query.order_by(desc(Subscription.created_at)).offset((page - 1) * size).limit(size).all() + + return { + "total": total, + "page": page, + "size": size, + "items": [format_subscription(s) for s in items] + } + + +@router.get("/subscriptions/{subscription_id}") +async def get_subscription( + subscription_id: int, + user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取订阅详情""" + subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first() + if not subscription: + raise HTTPException(status_code=404, detail="订阅不存在") + return format_subscription(subscription) + + +@router.post("/subscriptions") +async def create_subscription( + data: SubscriptionCreate, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """创建订阅""" + # 检查是否已存在 + existing = db.query(Subscription).filter( + Subscription.tenant_id == data.tenant_id, + Subscription.app_code == data.app_code, + Subscription.status == 'active' + ).first() + + if existing: + raise HTTPException(status_code=400, detail="该租户应用已有活跃订阅") + + subscription = Subscription( + tenant_id=data.tenant_id, + app_code=data.app_code, + start_date=data.start_date or date.today(), + end_date=data.end_date, + quota=data.quota.model_dump() if data.quota else {}, + status='active' + ) + db.add(subscription) + db.commit() + db.refresh(subscription) + + return {"success": True, "id": subscription.id} + + +@router.put("/subscriptions/{subscription_id}") +async def update_subscription( + subscription_id: int, + data: SubscriptionUpdate, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """更新订阅""" + subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first() + if not subscription: + raise HTTPException(status_code=404, detail="订阅不存在") + + if data.start_date: + subscription.start_date = data.start_date + if data.end_date: + subscription.end_date = data.end_date + if data.quota: + subscription.quota = data.quota.model_dump() + if data.status: + subscription.status = data.status + + db.commit() + + # 清除缓存 + service = QuotaService(db) + cache_key = f"quota:config:{subscription.tenant_id}:{subscription.app_code}" + service._cache.delete(cache_key) + + return {"success": True} + + +@router.delete("/subscriptions/{subscription_id}") +async def delete_subscription( + subscription_id: int, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """删除订阅""" + subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first() + if not subscription: + raise HTTPException(status_code=404, detail="订阅不存在") + + db.delete(subscription) + db.commit() + + return {"success": True} + + +@router.put("/subscriptions/{subscription_id}/quota") +async def update_quota( + subscription_id: int, + data: QuotaConfigUpdate, + user: User = Depends(require_operator), + db: Session = Depends(get_db) +): + """更新订阅配额""" + subscription = db.query(Subscription).filter(Subscription.id == subscription_id).first() + if not subscription: + raise HTTPException(status_code=404, detail="订阅不存在") + + subscription.quota = data.model_dump() + db.commit() + + # 清除缓存 + service = QuotaService(db) + cache_key = f"quota:config:{subscription.tenant_id}:{subscription.app_code}" + service._cache.delete(cache_key) + + return {"success": True} + + +# ============= Helper Functions ============= + +def format_subscription(subscription: Subscription) -> dict: + return { + "id": subscription.id, + "tenant_id": subscription.tenant_id, + "app_code": subscription.app_code, + "start_date": str(subscription.start_date) if subscription.start_date else None, + "end_date": str(subscription.end_date) if subscription.end_date else None, + "quota": subscription.quota or {}, + "status": subscription.status, + "created_at": subscription.created_at, + "updated_at": subscription.updated_at + } diff --git a/backend/app/routers/wechat.py b/backend/app/routers/wechat.py new file mode 100644 index 0000000..e98490b --- /dev/null +++ b/backend/app/routers/wechat.py @@ -0,0 +1,264 @@ +"""企业微信JS-SDK路由""" +from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, Query +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from ..database import get_db +from ..models.tenant_app import TenantApp +from ..models.tenant_wechat_app import TenantWechatApp +from ..services.wechat import WechatService, get_wechat_service_by_id + +router = APIRouter(prefix="/wechat", tags=["企业微信"]) + + +class JssdkSignatureRequest(BaseModel): + """JS-SDK签名请求""" + url: str # 当前页面URL(不含#及其后面部分) + + +class JssdkSignatureResponse(BaseModel): + """JS-SDK签名响应""" + appId: str + agentId: str + timestamp: int + nonceStr: str + signature: str + + +class OAuth2UrlRequest(BaseModel): + """OAuth2授权URL请求""" + redirect_uri: str + scope: str = "snsapi_base" + state: str = "" + + +class UserInfoRequest(BaseModel): + """用户信息请求""" + code: str + + +@router.post("/jssdk/signature") +async def get_jssdk_signature( + request: JssdkSignatureRequest, + tenant_id: str = Query(..., alias="tid"), + app_code: str = Query(..., alias="aid"), + db: Session = Depends(get_db) +): + """获取JS-SDK签名 + + 用于前端初始化企业微信JS-SDK + + Args: + request: 包含当前页面URL + tenant_id: 租户ID + app_code: 应用代码 + + Returns: + JS-SDK签名信息 + """ + # 查找租户应用配置 + tenant_app = db.query(TenantApp).filter( + TenantApp.tenant_id == tenant_id, + TenantApp.app_code == app_code, + TenantApp.status == 1 + ).first() + + if not tenant_app: + raise HTTPException(status_code=404, detail="租户应用配置不存在") + + if not tenant_app.wechat_app_id: + raise HTTPException(status_code=400, detail="该应用未配置企业微信") + + # 获取企微服务 + wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db) + if not wechat_service: + raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用") + + # 生成签名 + signature_data = await wechat_service.get_jssdk_signature(request.url) + if not signature_data: + raise HTTPException(status_code=500, detail="获取JS-SDK签名失败") + + return signature_data + + +@router.get("/jssdk/signature") +async def get_jssdk_signature_get( + url: str = Query(..., description="当前页面URL"), + tenant_id: str = Query(..., alias="tid"), + app_code: str = Query(..., alias="aid"), + db: Session = Depends(get_db) +): + """获取JS-SDK签名(GET方式) + + 方便前端JSONP调用 + """ + # 查找租户应用配置 + tenant_app = db.query(TenantApp).filter( + TenantApp.tenant_id == tenant_id, + TenantApp.app_code == app_code, + TenantApp.status == 1 + ).first() + + if not tenant_app: + raise HTTPException(status_code=404, detail="租户应用配置不存在") + + if not tenant_app.wechat_app_id: + raise HTTPException(status_code=400, detail="该应用未配置企业微信") + + # 获取企微服务 + wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db) + if not wechat_service: + raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用") + + # 生成签名 + signature_data = await wechat_service.get_jssdk_signature(url) + if not signature_data: + raise HTTPException(status_code=500, detail="获取JS-SDK签名失败") + + return signature_data + + +@router.post("/oauth2/url") +async def get_oauth2_url( + request: OAuth2UrlRequest, + tenant_id: str = Query(..., alias="tid"), + app_code: str = Query(..., alias="aid"), + db: Session = Depends(get_db) +): + """获取OAuth2授权URL + + 用于企业微信内网页获取用户身份 + """ + # 查找租户应用配置 + tenant_app = db.query(TenantApp).filter( + TenantApp.tenant_id == tenant_id, + TenantApp.app_code == app_code, + TenantApp.status == 1 + ).first() + + if not tenant_app: + raise HTTPException(status_code=404, detail="租户应用配置不存在") + + if not tenant_app.wechat_app_id: + raise HTTPException(status_code=400, detail="该应用未配置企业微信") + + # 获取企微服务 + wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db) + if not wechat_service: + raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用") + + # 生成OAuth2 URL + oauth_url = wechat_service.get_oauth2_url( + redirect_uri=request.redirect_uri, + scope=request.scope, + state=request.state + ) + + return {"url": oauth_url} + + +@router.post("/oauth2/userinfo") +async def get_user_info( + request: UserInfoRequest, + tenant_id: str = Query(..., alias="tid"), + app_code: str = Query(..., alias="aid"), + db: Session = Depends(get_db) +): + """通过OAuth2 code获取用户信息 + + 在OAuth2回调后,用code换取用户信息 + """ + # 查找租户应用配置 + tenant_app = db.query(TenantApp).filter( + TenantApp.tenant_id == tenant_id, + TenantApp.app_code == app_code, + TenantApp.status == 1 + ).first() + + if not tenant_app: + raise HTTPException(status_code=404, detail="租户应用配置不存在") + + if not tenant_app.wechat_app_id: + raise HTTPException(status_code=400, detail="该应用未配置企业微信") + + # 获取企微服务 + wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db) + if not wechat_service: + raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用") + + # 获取用户信息 + user_info = await wechat_service.get_user_info_by_code(request.code) + if not user_info: + raise HTTPException(status_code=400, detail="获取用户信息失败,code可能已过期") + + return user_info + + +@router.get("/oauth2/userinfo") +async def get_user_info_get( + code: str = Query(..., description="OAuth2回调的code"), + tenant_id: str = Query(..., alias="tid"), + app_code: str = Query(..., alias="aid"), + db: Session = Depends(get_db) +): + """通过OAuth2 code获取用户信息(GET方式)""" + # 查找租户应用配置 + tenant_app = db.query(TenantApp).filter( + TenantApp.tenant_id == tenant_id, + TenantApp.app_code == app_code, + TenantApp.status == 1 + ).first() + + if not tenant_app: + raise HTTPException(status_code=404, detail="租户应用配置不存在") + + if not tenant_app.wechat_app_id: + raise HTTPException(status_code=400, detail="该应用未配置企业微信") + + # 获取企微服务 + wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db) + if not wechat_service: + raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用") + + # 获取用户信息 + user_info = await wechat_service.get_user_info_by_code(code) + if not user_info: + raise HTTPException(status_code=400, detail="获取用户信息失败,code可能已过期") + + return user_info + + +@router.get("/user/{user_id}") +async def get_user_detail( + user_id: str, + tenant_id: str = Query(..., alias="tid"), + app_code: str = Query(..., alias="aid"), + db: Session = Depends(get_db) +): + """获取企业微信成员详细信息""" + # 查找租户应用配置 + tenant_app = db.query(TenantApp).filter( + TenantApp.tenant_id == tenant_id, + TenantApp.app_code == app_code, + TenantApp.status == 1 + ).first() + + if not tenant_app: + raise HTTPException(status_code=404, detail="租户应用配置不存在") + + if not tenant_app.wechat_app_id: + raise HTTPException(status_code=400, detail="该应用未配置企业微信") + + # 获取企微服务 + wechat_service = await get_wechat_service_by_id(tenant_app.wechat_app_id, db) + if not wechat_service: + raise HTTPException(status_code=404, detail="企业微信应用不存在或已禁用") + + # 获取用户详情 + user_detail = await wechat_service.get_user_detail(user_id) + if not user_detail: + raise HTTPException(status_code=404, detail="用户不存在") + + return user_detail diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py index 60261bc..615aa2b 100644 --- a/backend/app/services/__init__.py +++ b/backend/app/services/__init__.py @@ -1,4 +1,22 @@ """业务服务""" from .crypto import encrypt_value, decrypt_value +from .cache import CacheService, get_cache, get_redis_client +from .wechat import WechatService, get_wechat_service_by_id +from .alert import AlertService +from .cost import CostCalculator, calculate_cost +from .quota import QuotaService, check_quota_middleware -__all__ = ["encrypt_value", "decrypt_value"] +__all__ = [ + "encrypt_value", + "decrypt_value", + "CacheService", + "get_cache", + "get_redis_client", + "WechatService", + "get_wechat_service_by_id", + "AlertService", + "CostCalculator", + "calculate_cost", + "QuotaService", + "check_quota_middleware" +] diff --git a/backend/app/services/alert.py b/backend/app/services/alert.py new file mode 100644 index 0000000..b6b79ed --- /dev/null +++ b/backend/app/services/alert.py @@ -0,0 +1,455 @@ +"""告警服务""" +import logging +from datetime import datetime, timedelta +from typing import Optional, List, Dict, Any + +import httpx +from sqlalchemy.orm import Session +from sqlalchemy import func + +from ..models.alert import AlertRule, AlertRecord, NotificationChannel +from ..models.stats import AICallEvent +from ..models.logs import PlatformLog +from .cache import get_cache + +logger = logging.getLogger(__name__) + + +class AlertService: + """告警服务 + + 提供告警规则检测、告警记录管理、通知发送等功能 + """ + + def __init__(self, db: Session): + self.db = db + self._cache = get_cache() + + async def check_all_rules(self) -> List[AlertRecord]: + """检查所有启用的告警规则 + + Returns: + 触发的告警记录列表 + """ + rules = self.db.query(AlertRule).filter(AlertRule.status == 1).all() + triggered_alerts = [] + + for rule in rules: + try: + alert = await self.check_rule(rule) + if alert: + triggered_alerts.append(alert) + except Exception as e: + logger.error(f"Failed to check rule {rule.id}: {e}") + + return triggered_alerts + + async def check_rule(self, rule: AlertRule) -> Optional[AlertRecord]: + """检查单个告警规则 + + Args: + rule: 告警规则 + + Returns: + 触发的告警记录或None + """ + # 检查冷却期 + if self._is_in_cooldown(rule): + logger.debug(f"Rule {rule.id} is in cooldown") + return None + + # 检查每日告警次数限制 + if self._exceeds_daily_limit(rule): + logger.debug(f"Rule {rule.id} exceeds daily limit") + return None + + # 根据规则类型检查 + metric_value = None + threshold_value = None + triggered = False + + condition = rule.condition or {} + + if rule.rule_type == 'error_rate': + triggered, metric_value, threshold_value = self._check_error_rate(rule, condition) + elif rule.rule_type == 'call_count': + triggered, metric_value, threshold_value = self._check_call_count(rule, condition) + elif rule.rule_type == 'token_usage': + triggered, metric_value, threshold_value = self._check_token_usage(rule, condition) + elif rule.rule_type == 'cost_threshold': + triggered, metric_value, threshold_value = self._check_cost_threshold(rule, condition) + elif rule.rule_type == 'latency': + triggered, metric_value, threshold_value = self._check_latency(rule, condition) + + if triggered: + alert = self._create_alert_record(rule, metric_value, threshold_value) + return alert + + return None + + def _is_in_cooldown(self, rule: AlertRule) -> bool: + """检查规则是否在冷却期""" + cache_key = f"alert:cooldown:{rule.id}" + return self._cache.exists(cache_key) + + def _set_cooldown(self, rule: AlertRule): + """设置规则冷却期""" + cache_key = f"alert:cooldown:{rule.id}" + self._cache.set(cache_key, "1", ttl=rule.cooldown_minutes * 60) + + def _exceeds_daily_limit(self, rule: AlertRule) -> bool: + """检查是否超过每日告警次数限制""" + today = datetime.now().date() + count = self.db.query(func.count(AlertRecord.id)).filter( + AlertRecord.rule_id == rule.id, + func.date(AlertRecord.created_at) == today + ).scalar() + return count >= rule.max_alerts_per_day + + def _check_error_rate(self, rule: AlertRule, condition: dict) -> tuple: + """检查错误率""" + window_minutes = self._parse_window(condition.get('window', '5m')) + threshold = condition.get('threshold', 10) # 错误次数阈值 + operator = condition.get('operator', '>') + + since = datetime.now() - timedelta(minutes=window_minutes) + + query = self.db.query(func.count(AICallEvent.id)).filter( + AICallEvent.created_at >= since, + AICallEvent.status == 'error' + ) + + # 应用作用范围 + if rule.scope_type == 'tenant' and rule.scope_value: + query = query.filter(AICallEvent.tenant_id == rule.scope_value) + elif rule.scope_type == 'app' and rule.scope_value: + query = query.filter(AICallEvent.app_code == rule.scope_value) + + error_count = query.scalar() or 0 + triggered = self._compare(error_count, threshold, operator) + + return triggered, str(error_count), str(threshold) + + def _check_call_count(self, rule: AlertRule, condition: dict) -> tuple: + """检查调用次数""" + window_minutes = self._parse_window(condition.get('window', '1h')) + threshold = condition.get('threshold', 1000) + operator = condition.get('operator', '>') + + since = datetime.now() - timedelta(minutes=window_minutes) + + query = self.db.query(func.count(AICallEvent.id)).filter( + AICallEvent.created_at >= since + ) + + if rule.scope_type == 'tenant' and rule.scope_value: + query = query.filter(AICallEvent.tenant_id == rule.scope_value) + elif rule.scope_type == 'app' and rule.scope_value: + query = query.filter(AICallEvent.app_code == rule.scope_value) + + call_count = query.scalar() or 0 + triggered = self._compare(call_count, threshold, operator) + + return triggered, str(call_count), str(threshold) + + def _check_token_usage(self, rule: AlertRule, condition: dict) -> tuple: + """检查Token使用量""" + window_minutes = self._parse_window(condition.get('window', '1d')) + threshold = condition.get('threshold', 100000) + operator = condition.get('operator', '>') + + since = datetime.now() - timedelta(minutes=window_minutes) + + query = self.db.query( + func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0) + ).filter( + AICallEvent.created_at >= since + ) + + if rule.scope_type == 'tenant' and rule.scope_value: + query = query.filter(AICallEvent.tenant_id == rule.scope_value) + elif rule.scope_type == 'app' and rule.scope_value: + query = query.filter(AICallEvent.app_code == rule.scope_value) + + token_usage = query.scalar() or 0 + triggered = self._compare(token_usage, threshold, operator) + + return triggered, str(token_usage), str(threshold) + + def _check_cost_threshold(self, rule: AlertRule, condition: dict) -> tuple: + """检查费用阈值""" + window_minutes = self._parse_window(condition.get('window', '1d')) + threshold = condition.get('threshold', 100) # 费用阈值(元) + operator = condition.get('operator', '>') + + since = datetime.now() - timedelta(minutes=window_minutes) + + query = self.db.query( + func.coalesce(func.sum(AICallEvent.cost), 0) + ).filter( + AICallEvent.created_at >= since + ) + + if rule.scope_type == 'tenant' and rule.scope_value: + query = query.filter(AICallEvent.tenant_id == rule.scope_value) + elif rule.scope_type == 'app' and rule.scope_value: + query = query.filter(AICallEvent.app_code == rule.scope_value) + + total_cost = float(query.scalar() or 0) + triggered = self._compare(total_cost, threshold, operator) + + return triggered, f"¥{total_cost:.2f}", f"¥{threshold:.2f}" + + def _check_latency(self, rule: AlertRule, condition: dict) -> tuple: + """检查延迟""" + window_minutes = self._parse_window(condition.get('window', '5m')) + threshold = condition.get('threshold', 5000) # 延迟阈值(ms) + operator = condition.get('operator', '>') + percentile = condition.get('percentile', 'avg') # avg, p95, p99, max + + since = datetime.now() - timedelta(minutes=window_minutes) + + query = self.db.query(AICallEvent.latency_ms).filter( + AICallEvent.created_at >= since, + AICallEvent.latency_ms.isnot(None) + ) + + if rule.scope_type == 'tenant' and rule.scope_value: + query = query.filter(AICallEvent.tenant_id == rule.scope_value) + elif rule.scope_type == 'app' and rule.scope_value: + query = query.filter(AICallEvent.app_code == rule.scope_value) + + latencies = [r.latency_ms for r in query.all()] + + if not latencies: + return False, "0", str(threshold) + + if percentile == 'avg': + metric = sum(latencies) / len(latencies) + elif percentile == 'max': + metric = max(latencies) + elif percentile == 'p95': + latencies.sort() + idx = int(len(latencies) * 0.95) + metric = latencies[idx] if idx < len(latencies) else latencies[-1] + elif percentile == 'p99': + latencies.sort() + idx = int(len(latencies) * 0.99) + metric = latencies[idx] if idx < len(latencies) else latencies[-1] + else: + metric = sum(latencies) / len(latencies) + + triggered = self._compare(metric, threshold, operator) + + return triggered, f"{metric:.0f}ms", f"{threshold}ms" + + def _parse_window(self, window: str) -> int: + """解析时间窗口字符串为分钟数""" + if window.endswith('m'): + return int(window[:-1]) + elif window.endswith('h'): + return int(window[:-1]) * 60 + elif window.endswith('d'): + return int(window[:-1]) * 60 * 24 + else: + return int(window) + + def _compare(self, value: float, threshold: float, operator: str) -> bool: + """比较值与阈值""" + if operator == '>': + return value > threshold + elif operator == '>=': + return value >= threshold + elif operator == '<': + return value < threshold + elif operator == '<=': + return value <= threshold + elif operator == '==': + return value == threshold + elif operator == '!=': + return value != threshold + return False + + def _create_alert_record( + self, + rule: AlertRule, + metric_value: str, + threshold_value: str + ) -> AlertRecord: + """创建告警记录""" + title = f"[{rule.priority.upper()}] {rule.name}" + message = f"规则 '{rule.name}' 触发告警\n当前值: {metric_value}\n阈值: {threshold_value}" + + if rule.scope_type == 'tenant': + message += f"\n租户: {rule.scope_value}" + elif rule.scope_type == 'app': + message += f"\n应用: {rule.scope_value}" + + alert = AlertRecord( + rule_id=rule.id, + rule_name=rule.name, + alert_type=rule.rule_type, + severity=self._priority_to_severity(rule.priority), + title=title, + message=message, + tenant_id=rule.scope_value if rule.scope_type == 'tenant' else None, + app_code=rule.scope_value if rule.scope_type == 'app' else None, + metric_value=metric_value, + threshold_value=threshold_value, + notification_status='pending' + ) + + self.db.add(alert) + self.db.commit() + self.db.refresh(alert) + + # 设置冷却期 + self._set_cooldown(rule) + + logger.info(f"Alert triggered: {title}") + + return alert + + def _priority_to_severity(self, priority: str) -> str: + """将优先级转换为严重程度""" + mapping = { + 'low': 'info', + 'medium': 'warning', + 'high': 'error', + 'critical': 'critical' + } + return mapping.get(priority, 'warning') + + async def send_notification(self, alert: AlertRecord, rule: AlertRule) -> bool: + """发送告警通知 + + Args: + alert: 告警记录 + rule: 告警规则 + + Returns: + 是否发送成功 + """ + if not rule.notification_channels: + alert.notification_status = 'skipped' + self.db.commit() + return True + + results = [] + success = True + + for channel_config in rule.notification_channels: + try: + result = await self._send_to_channel(channel_config, alert) + results.append(result) + if not result.get('success'): + success = False + except Exception as e: + logger.error(f"Failed to send notification: {e}") + results.append({'success': False, 'error': str(e)}) + success = False + + alert.notification_status = 'sent' if success else 'failed' + alert.notification_result = results + alert.notified_at = datetime.now() + self.db.commit() + + return success + + async def _send_to_channel(self, channel_config: dict, alert: AlertRecord) -> dict: + """发送到指定渠道""" + channel_type = channel_config.get('type') + + if channel_type == 'wechat_bot': + return await self._send_wechat_bot(channel_config, alert) + elif channel_type == 'webhook': + return await self._send_webhook(channel_config, alert) + else: + return {'success': False, 'error': f'Unsupported channel type: {channel_type}'} + + async def _send_wechat_bot(self, config: dict, alert: AlertRecord) -> dict: + """发送到企微机器人""" + webhook = config.get('webhook') + if not webhook: + return {'success': False, 'error': 'Missing webhook URL'} + + # 构建消息 + content = f"**{alert.title}**\n\n{alert.message}\n\n时间: {alert.created_at}" + + payload = { + "msgtype": "markdown", + "markdown": { + "content": content + } + } + + try: + async with httpx.AsyncClient(timeout=10) as client: + response = await client.post(webhook, json=payload) + result = response.json() + + if result.get('errcode', 0) == 0: + return {'success': True, 'channel': 'wechat_bot'} + else: + return {'success': False, 'error': result.get('errmsg')} + except Exception as e: + return {'success': False, 'error': str(e)} + + async def _send_webhook(self, config: dict, alert: AlertRecord) -> dict: + """发送到Webhook""" + url = config.get('url') + if not url: + return {'success': False, 'error': 'Missing webhook URL'} + + payload = { + "alert_id": alert.id, + "title": alert.title, + "message": alert.message, + "severity": alert.severity, + "alert_type": alert.alert_type, + "metric_value": alert.metric_value, + "threshold_value": alert.threshold_value, + "created_at": alert.created_at.isoformat() + } + + headers = config.get('headers', {}) + method = config.get('method', 'POST') + + try: + async with httpx.AsyncClient(timeout=10) as client: + if method.upper() == 'POST': + response = await client.post(url, json=payload, headers=headers) + else: + response = await client.get(url, params=payload, headers=headers) + + if response.status_code < 400: + return {'success': True, 'channel': 'webhook', 'status': response.status_code} + else: + return {'success': False, 'error': f'HTTP {response.status_code}'} + except Exception as e: + return {'success': False, 'error': str(e)} + + def acknowledge_alert(self, alert_id: int, acknowledged_by: str) -> Optional[AlertRecord]: + """确认告警""" + alert = self.db.query(AlertRecord).filter(AlertRecord.id == alert_id).first() + if not alert: + return None + + alert.status = 'acknowledged' + alert.acknowledged_by = acknowledged_by + alert.acknowledged_at = datetime.now() + self.db.commit() + + return alert + + def resolve_alert(self, alert_id: int) -> Optional[AlertRecord]: + """解决告警""" + alert = self.db.query(AlertRecord).filter(AlertRecord.id == alert_id).first() + if not alert: + return None + + alert.status = 'resolved' + alert.resolved_at = datetime.now() + self.db.commit() + + return alert diff --git a/backend/app/services/cache.py b/backend/app/services/cache.py new file mode 100644 index 0000000..8fd6baa --- /dev/null +++ b/backend/app/services/cache.py @@ -0,0 +1,309 @@ +"""Redis缓存服务""" +import json +import logging +from typing import Optional, Any, Union +from functools import lru_cache + +try: + import redis + from redis import Redis + REDIS_AVAILABLE = True +except ImportError: + REDIS_AVAILABLE = False + Redis = None + +from ..config import get_settings + +logger = logging.getLogger(__name__) + +# 全局Redis连接池 +_redis_pool: Optional[Any] = None +_redis_client: Optional[Any] = None + + +def get_redis_client() -> Optional[Any]: + """获取Redis客户端单例""" + global _redis_pool, _redis_client + + if not REDIS_AVAILABLE: + logger.warning("Redis module not installed, cache disabled") + return None + + if _redis_client is not None: + return _redis_client + + settings = get_settings() + + try: + _redis_pool = redis.ConnectionPool.from_url( + settings.REDIS_URL, + max_connections=20, + decode_responses=True + ) + _redis_client = Redis(connection_pool=_redis_pool) + # 测试连接 + _redis_client.ping() + logger.info(f"Redis connected: {settings.REDIS_URL}") + return _redis_client + except Exception as e: + logger.warning(f"Redis connection failed: {e}, cache disabled") + _redis_client = None + return None + + +class CacheService: + """缓存服务 + + 提供统一的缓存接口,支持Redis和内存回退 + + 使用示例: + cache = CacheService() + + # 设置缓存 + cache.set("user:123", {"name": "test"}, ttl=3600) + + # 获取缓存 + user = cache.get("user:123") + + # 删除缓存 + cache.delete("user:123") + """ + + def __init__(self, prefix: Optional[str] = None): + """初始化缓存服务 + + Args: + prefix: 键前缀,默认使用配置中的REDIS_PREFIX + """ + settings = get_settings() + self.prefix = prefix or settings.REDIS_PREFIX + self._client = get_redis_client() + + # 内存回退缓存(当Redis不可用时使用) + self._memory_cache: dict = {} + + @property + def is_available(self) -> bool: + """Redis是否可用""" + return self._client is not None + + def _make_key(self, key: str) -> str: + """生成完整的缓存键""" + return f"{self.prefix}{key}" + + def get(self, key: str, default: Any = None) -> Any: + """获取缓存值 + + Args: + key: 缓存键 + default: 默认值 + + Returns: + 缓存值或默认值 + """ + full_key = self._make_key(key) + + if self._client: + try: + value = self._client.get(full_key) + if value is None: + return default + # 尝试解析JSON + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError): + return value + except Exception as e: + logger.error(f"Cache get error: {e}") + return default + else: + # 内存回退 + return self._memory_cache.get(full_key, default) + + def set( + self, + key: str, + value: Any, + ttl: Optional[int] = None, + nx: bool = False + ) -> bool: + """设置缓存值 + + Args: + key: 缓存键 + value: 缓存值 + ttl: 过期时间(秒) + nx: 只在键不存在时设置 + + Returns: + 是否设置成功 + """ + full_key = self._make_key(key) + + # 序列化值 + if isinstance(value, (dict, list)): + serialized = json.dumps(value, ensure_ascii=False) + else: + serialized = str(value) if value is not None else "" + + if self._client: + try: + if nx: + result = self._client.set(full_key, serialized, ex=ttl, nx=True) + else: + result = self._client.set(full_key, serialized, ex=ttl) + return bool(result) + except Exception as e: + logger.error(f"Cache set error: {e}") + return False + else: + # 内存回退(不支持TTL和NX) + if nx and full_key in self._memory_cache: + return False + self._memory_cache[full_key] = value + return True + + def delete(self, key: str) -> bool: + """删除缓存 + + Args: + key: 缓存键 + + Returns: + 是否删除成功 + """ + full_key = self._make_key(key) + + if self._client: + try: + return bool(self._client.delete(full_key)) + except Exception as e: + logger.error(f"Cache delete error: {e}") + return False + else: + return self._memory_cache.pop(full_key, None) is not None + + def exists(self, key: str) -> bool: + """检查键是否存在 + + Args: + key: 缓存键 + + Returns: + 是否存在 + """ + full_key = self._make_key(key) + + if self._client: + try: + return bool(self._client.exists(full_key)) + except Exception as e: + logger.error(f"Cache exists error: {e}") + return False + else: + return full_key in self._memory_cache + + def ttl(self, key: str) -> int: + """获取键的剩余过期时间 + + Args: + key: 缓存键 + + Returns: + 剩余秒数,-1表示永不过期,-2表示键不存在 + """ + full_key = self._make_key(key) + + if self._client: + try: + return self._client.ttl(full_key) + except Exception as e: + logger.error(f"Cache ttl error: {e}") + return -2 + else: + return -1 if full_key in self._memory_cache else -2 + + def incr(self, key: str, amount: int = 1) -> int: + """递增计数器 + + Args: + key: 缓存键 + amount: 递增量 + + Returns: + 递增后的值 + """ + full_key = self._make_key(key) + + if self._client: + try: + return self._client.incrby(full_key, amount) + except Exception as e: + logger.error(f"Cache incr error: {e}") + return 0 + else: + current = self._memory_cache.get(full_key, 0) + new_value = int(current) + amount + self._memory_cache[full_key] = new_value + return new_value + + def expire(self, key: str, ttl: int) -> bool: + """设置键的过期时间 + + Args: + key: 缓存键 + ttl: 过期时间(秒) + + Returns: + 是否设置成功 + """ + full_key = self._make_key(key) + + if self._client: + try: + return bool(self._client.expire(full_key, ttl)) + except Exception as e: + logger.error(f"Cache expire error: {e}") + return False + else: + return full_key in self._memory_cache + + def clear_prefix(self, prefix: str) -> int: + """删除指定前缀的所有键 + + Args: + prefix: 键前缀 + + Returns: + 删除的键数量 + """ + full_prefix = self._make_key(prefix) + + if self._client: + try: + keys = self._client.keys(f"{full_prefix}*") + if keys: + return self._client.delete(*keys) + return 0 + except Exception as e: + logger.error(f"Cache clear_prefix error: {e}") + return 0 + else: + count = 0 + keys_to_delete = [k for k in self._memory_cache if k.startswith(full_prefix)] + for k in keys_to_delete: + del self._memory_cache[k] + count += 1 + return count + + +# 全局缓存实例 +_cache_instance: Optional[CacheService] = None + + +def get_cache() -> CacheService: + """获取全局缓存实例""" + global _cache_instance + if _cache_instance is None: + _cache_instance = CacheService() + return _cache_instance diff --git a/backend/app/services/cost.py b/backend/app/services/cost.py new file mode 100644 index 0000000..0f39fef --- /dev/null +++ b/backend/app/services/cost.py @@ -0,0 +1,420 @@ +"""费用计算服务""" +import logging +from datetime import datetime +from decimal import Decimal +from typing import Optional, Dict, List +from functools import lru_cache + +from sqlalchemy.orm import Session +from sqlalchemy import func + +from ..models.pricing import ModelPricing, TenantBilling +from ..models.stats import AICallEvent +from .cache import get_cache + +logger = logging.getLogger(__name__) + + +class CostCalculator: + """费用计算器 + + 使用示例: + calculator = CostCalculator(db) + + # 计算单次调用费用 + cost = calculator.calculate_cost("gpt-4", input_tokens=100, output_tokens=200) + + # 生成月度账单 + billing = calculator.generate_monthly_billing("qiqi", "2026-01") + """ + + # 默认模型价格(当数据库中无配置时使用) + DEFAULT_PRICING = { + # OpenAI + "gpt-4": {"input": 0.21, "output": 0.42}, # 元/1K tokens + "gpt-4-turbo": {"input": 0.07, "output": 0.21}, + "gpt-4o": {"input": 0.035, "output": 0.105}, + "gpt-4o-mini": {"input": 0.00105, "output": 0.0042}, + "gpt-3.5-turbo": {"input": 0.0035, "output": 0.014}, + + # Anthropic + "claude-3-opus": {"input": 0.105, "output": 0.525}, + "claude-3-sonnet": {"input": 0.021, "output": 0.105}, + "claude-3-haiku": {"input": 0.00175, "output": 0.00875}, + "claude-3.5-sonnet": {"input": 0.021, "output": 0.105}, + + # 国内模型 + "qwen-max": {"input": 0.02, "output": 0.06}, + "qwen-plus": {"input": 0.004, "output": 0.012}, + "qwen-turbo": {"input": 0.002, "output": 0.006}, + "glm-4": {"input": 0.01, "output": 0.01}, + "glm-4-flash": {"input": 0.0001, "output": 0.0001}, + "deepseek-chat": {"input": 0.001, "output": 0.002}, + "deepseek-coder": {"input": 0.001, "output": 0.002}, + + # 默认 + "default": {"input": 0.01, "output": 0.03} + } + + def __init__(self, db: Session): + self.db = db + self._cache = get_cache() + self._pricing_cache: Dict[str, ModelPricing] = {} + + def get_model_pricing(self, model_name: str) -> Optional[ModelPricing]: + """获取模型价格配置 + + Args: + model_name: 模型名称 + + Returns: + ModelPricing实例或None + """ + # 尝试从缓存获取 + cache_key = f"pricing:{model_name}" + cached = self._cache.get(cache_key) + if cached: + return self._dict_to_pricing(cached) + + # 从数据库查询 + pricing = self.db.query(ModelPricing).filter( + ModelPricing.model_name == model_name, + ModelPricing.status == 1 + ).first() + + if pricing: + # 缓存1小时 + self._cache.set(cache_key, self._pricing_to_dict(pricing), ttl=3600) + return pricing + + return None + + def _pricing_to_dict(self, pricing: ModelPricing) -> dict: + return { + "model_name": pricing.model_name, + "input_price_per_1k": str(pricing.input_price_per_1k), + "output_price_per_1k": str(pricing.output_price_per_1k), + "fixed_price_per_call": str(pricing.fixed_price_per_call), + "pricing_type": pricing.pricing_type + } + + def _dict_to_pricing(self, d: dict) -> ModelPricing: + pricing = ModelPricing() + pricing.model_name = d.get("model_name") + pricing.input_price_per_1k = Decimal(d.get("input_price_per_1k", "0")) + pricing.output_price_per_1k = Decimal(d.get("output_price_per_1k", "0")) + pricing.fixed_price_per_call = Decimal(d.get("fixed_price_per_call", "0")) + pricing.pricing_type = d.get("pricing_type", "token") + return pricing + + def calculate_cost( + self, + model_name: str, + input_tokens: int = 0, + output_tokens: int = 0, + call_count: int = 1 + ) -> Decimal: + """计算调用费用 + + Args: + model_name: 模型名称 + input_tokens: 输入token数 + output_tokens: 输出token数 + call_count: 调用次数 + + Returns: + 费用(元) + """ + # 尝试获取数据库配置 + pricing = self.get_model_pricing(model_name) + + if pricing: + if pricing.pricing_type == 'call': + return pricing.fixed_price_per_call * call_count + elif pricing.pricing_type == 'hybrid': + token_cost = ( + pricing.input_price_per_1k * Decimal(input_tokens) / 1000 + + pricing.output_price_per_1k * Decimal(output_tokens) / 1000 + ) + call_cost = pricing.fixed_price_per_call * call_count + return token_cost + call_cost + else: # token + return ( + pricing.input_price_per_1k * Decimal(input_tokens) / 1000 + + pricing.output_price_per_1k * Decimal(output_tokens) / 1000 + ) + + # 使用默认价格 + default_prices = self.DEFAULT_PRICING.get(model_name) or self.DEFAULT_PRICING.get("default") + input_price = Decimal(str(default_prices["input"])) + output_price = Decimal(str(default_prices["output"])) + + return ( + input_price * Decimal(input_tokens) / 1000 + + output_price * Decimal(output_tokens) / 1000 + ) + + def calculate_event_cost(self, event: AICallEvent) -> Decimal: + """计算单个事件的费用 + + Args: + event: AI调用事件 + + Returns: + 费用(元) + """ + return self.calculate_cost( + model_name=event.model or "default", + input_tokens=event.input_tokens or 0, + output_tokens=event.output_tokens or 0 + ) + + def update_event_costs(self, start_date: str = None, end_date: str = None) -> int: + """批量更新事件费用 + + 对于cost为0或NULL的事件,重新计算费用 + + Args: + start_date: 开始日期,格式 YYYY-MM-DD + end_date: 结束日期,格式 YYYY-MM-DD + + Returns: + 更新的记录数 + """ + query = self.db.query(AICallEvent).filter( + (AICallEvent.cost == None) | (AICallEvent.cost == 0) + ) + + if start_date: + query = query.filter(AICallEvent.created_at >= start_date) + if end_date: + query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59") + + events = query.all() + updated = 0 + + for event in events: + try: + cost = self.calculate_event_cost(event) + event.cost = cost + updated += 1 + except Exception as e: + logger.error(f"Failed to calculate cost for event {event.id}: {e}") + + self.db.commit() + logger.info(f"Updated {updated} event costs") + + return updated + + def generate_monthly_billing( + self, + tenant_id: str, + billing_month: str + ) -> TenantBilling: + """生成月度账单 + + Args: + tenant_id: 租户ID + billing_month: 账单月份,格式 YYYY-MM + + Returns: + TenantBilling实例 + """ + # 检查是否已存在 + existing = self.db.query(TenantBilling).filter( + TenantBilling.tenant_id == tenant_id, + TenantBilling.billing_month == billing_month + ).first() + + if existing: + billing = existing + else: + billing = TenantBilling( + tenant_id=tenant_id, + billing_month=billing_month + ) + self.db.add(billing) + + # 计算统计数据 + start_date = f"{billing_month}-01" + year, month = billing_month.split("-") + if int(month) == 12: + end_date = f"{int(year)+1}-01-01" + else: + end_date = f"{year}-{int(month)+1:02d}-01" + + # 聚合查询 + stats = self.db.query( + func.count(AICallEvent.id).label('total_calls'), + func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'), + func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'), + func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost') + ).filter( + AICallEvent.tenant_id == tenant_id, + AICallEvent.created_at >= start_date, + AICallEvent.created_at < end_date + ).first() + + billing.total_calls = stats.total_calls or 0 + billing.total_input_tokens = int(stats.total_input or 0) + billing.total_output_tokens = int(stats.total_output or 0) + billing.total_cost = stats.total_cost or Decimal("0") + + # 按模型统计 + model_stats = self.db.query( + AICallEvent.model, + func.coalesce(func.sum(AICallEvent.cost), 0).label('cost') + ).filter( + AICallEvent.tenant_id == tenant_id, + AICallEvent.created_at >= start_date, + AICallEvent.created_at < end_date + ).group_by(AICallEvent.model).all() + + billing.cost_by_model = { + m.model or "unknown": float(m.cost) for m in model_stats + } + + # 按应用统计 + app_stats = self.db.query( + AICallEvent.app_code, + func.coalesce(func.sum(AICallEvent.cost), 0).label('cost') + ).filter( + AICallEvent.tenant_id == tenant_id, + AICallEvent.created_at >= start_date, + AICallEvent.created_at < end_date + ).group_by(AICallEvent.app_code).all() + + billing.cost_by_app = { + a.app_code or "unknown": float(a.cost) for a in app_stats + } + + self.db.commit() + self.db.refresh(billing) + + return billing + + def get_cost_summary( + self, + tenant_id: str = None, + start_date: str = None, + end_date: str = None + ) -> Dict: + """获取费用汇总 + + Args: + tenant_id: 租户ID(可选) + start_date: 开始日期 + end_date: 结束日期 + + Returns: + 费用汇总字典 + """ + query = self.db.query( + func.count(AICallEvent.id).label('total_calls'), + func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'), + func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'), + func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost') + ) + + if tenant_id: + query = query.filter(AICallEvent.tenant_id == tenant_id) + if start_date: + query = query.filter(AICallEvent.created_at >= start_date) + if end_date: + query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59") + + stats = query.first() + + return { + "total_calls": stats.total_calls or 0, + "total_input_tokens": int(stats.total_input or 0), + "total_output_tokens": int(stats.total_output or 0), + "total_cost": float(stats.total_cost or 0) + } + + def get_cost_by_tenant( + self, + start_date: str = None, + end_date: str = None + ) -> List[Dict]: + """按租户统计费用 + + Returns: + 租户费用列表 + """ + query = self.db.query( + AICallEvent.tenant_id, + func.count(AICallEvent.id).label('calls'), + func.coalesce(func.sum(AICallEvent.cost), 0).label('cost') + ) + + if start_date: + query = query.filter(AICallEvent.created_at >= start_date) + if end_date: + query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59") + + results = query.group_by(AICallEvent.tenant_id).order_by( + func.sum(AICallEvent.cost).desc() + ).all() + + return [ + { + "tenant_id": r.tenant_id, + "calls": r.calls, + "cost": float(r.cost) + } + for r in results + ] + + def get_cost_by_model( + self, + tenant_id: str = None, + start_date: str = None, + end_date: str = None + ) -> List[Dict]: + """按模型统计费用 + + Returns: + 模型费用列表 + """ + query = self.db.query( + AICallEvent.model, + func.count(AICallEvent.id).label('calls'), + func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('input_tokens'), + func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('output_tokens'), + func.coalesce(func.sum(AICallEvent.cost), 0).label('cost') + ) + + if tenant_id: + query = query.filter(AICallEvent.tenant_id == tenant_id) + if start_date: + query = query.filter(AICallEvent.created_at >= start_date) + if end_date: + query = query.filter(AICallEvent.created_at <= end_date + " 23:59:59") + + results = query.group_by(AICallEvent.model).order_by( + func.sum(AICallEvent.cost).desc() + ).all() + + return [ + { + "model": r.model or "unknown", + "calls": r.calls, + "input_tokens": int(r.input_tokens), + "output_tokens": int(r.output_tokens), + "cost": float(r.cost) + } + for r in results + ] + + +# 便捷函数 +def calculate_cost( + db: Session, + model_name: str, + input_tokens: int = 0, + output_tokens: int = 0 +) -> Decimal: + """快速计算费用""" + calculator = CostCalculator(db) + return calculator.calculate_cost(model_name, input_tokens, output_tokens) diff --git a/backend/app/services/quota.py b/backend/app/services/quota.py new file mode 100644 index 0000000..d4ad5e2 --- /dev/null +++ b/backend/app/services/quota.py @@ -0,0 +1,346 @@ +"""配额管理服务""" +import logging +from datetime import datetime, date, timedelta +from typing import Optional, Dict, Any, Tuple +from dataclasses import dataclass + +from sqlalchemy.orm import Session +from sqlalchemy import func + +from ..models.tenant import Tenant, Subscription +from ..models.stats import AICallEvent +from .cache import get_cache + +logger = logging.getLogger(__name__) + + +@dataclass +class QuotaConfig: + """配额配置""" + daily_calls: int = 0 # 每日调用限制,0表示无限制 + daily_tokens: int = 0 # 每日Token限制 + monthly_calls: int = 0 # 每月调用限制 + monthly_tokens: int = 0 # 每月Token限制 + monthly_cost: float = 0 # 每月费用限制(元) + concurrent_calls: int = 0 # 并发调用限制 + + +@dataclass +class QuotaUsage: + """配额使用情况""" + daily_calls: int = 0 + daily_tokens: int = 0 + monthly_calls: int = 0 + monthly_tokens: int = 0 + monthly_cost: float = 0 + + +@dataclass +class QuotaCheckResult: + """配额检查结果""" + allowed: bool + reason: Optional[str] = None + quota_type: Optional[str] = None + limit: int = 0 + used: int = 0 + remaining: int = 0 + + +class QuotaService: + """配额管理服务 + + 使用示例: + quota_service = QuotaService(db) + + # 检查配额 + result = quota_service.check_quota("qiqi", "tools") + if not result.allowed: + raise HTTPException(status_code=429, detail=result.reason) + + # 获取使用情况 + usage = quota_service.get_usage("qiqi", "tools") + """ + + # 默认配额(当无订阅配置时使用) + DEFAULT_QUOTA = QuotaConfig( + daily_calls=1000, + daily_tokens=100000, + monthly_calls=30000, + monthly_tokens=3000000, + monthly_cost=100 + ) + + def __init__(self, db: Session): + self.db = db + self._cache = get_cache() + + def get_subscription(self, tenant_id: str, app_code: str) -> Optional[Subscription]: + """获取租户订阅配置""" + return self.db.query(Subscription).filter( + Subscription.tenant_id == tenant_id, + Subscription.app_code == app_code, + Subscription.status == 'active' + ).first() + + def get_quota_config(self, tenant_id: str, app_code: str) -> QuotaConfig: + """获取配额配置 + + Args: + tenant_id: 租户ID + app_code: 应用代码 + + Returns: + QuotaConfig实例 + """ + # 尝试从缓存获取 + cache_key = f"quota:config:{tenant_id}:{app_code}" + cached = self._cache.get(cache_key) + if cached: + return QuotaConfig(**cached) + + # 从订阅表获取 + subscription = self.get_subscription(tenant_id, app_code) + + if subscription and subscription.quota: + quota = subscription.quota + config = QuotaConfig( + daily_calls=quota.get('daily_calls', 0), + daily_tokens=quota.get('daily_tokens', 0), + monthly_calls=quota.get('monthly_calls', 0), + monthly_tokens=quota.get('monthly_tokens', 0), + monthly_cost=quota.get('monthly_cost', 0), + concurrent_calls=quota.get('concurrent_calls', 0) + ) + else: + config = self.DEFAULT_QUOTA + + # 缓存5分钟 + self._cache.set(cache_key, config.__dict__, ttl=300) + + return config + + def get_usage(self, tenant_id: str, app_code: str) -> QuotaUsage: + """获取配额使用情况 + + Args: + tenant_id: 租户ID + app_code: 应用代码 + + Returns: + QuotaUsage实例 + """ + today = date.today() + month_start = today.replace(day=1) + + # 今日使用量 + daily_stats = self.db.query( + func.count(AICallEvent.id).label('calls'), + func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0).label('tokens') + ).filter( + AICallEvent.tenant_id == tenant_id, + AICallEvent.app_code == app_code, + func.date(AICallEvent.created_at) == today + ).first() + + # 本月使用量 + monthly_stats = self.db.query( + func.count(AICallEvent.id).label('calls'), + func.coalesce(func.sum(AICallEvent.input_tokens + AICallEvent.output_tokens), 0).label('tokens'), + func.coalesce(func.sum(AICallEvent.cost), 0).label('cost') + ).filter( + AICallEvent.tenant_id == tenant_id, + AICallEvent.app_code == app_code, + func.date(AICallEvent.created_at) >= month_start + ).first() + + return QuotaUsage( + daily_calls=daily_stats.calls or 0, + daily_tokens=int(daily_stats.tokens or 0), + monthly_calls=monthly_stats.calls or 0, + monthly_tokens=int(monthly_stats.tokens or 0), + monthly_cost=float(monthly_stats.cost or 0) + ) + + def check_quota( + self, + tenant_id: str, + app_code: str, + estimated_tokens: int = 0 + ) -> QuotaCheckResult: + """检查配额是否足够 + + Args: + tenant_id: 租户ID + app_code: 应用代码 + estimated_tokens: 预估Token消耗 + + Returns: + QuotaCheckResult实例 + """ + config = self.get_quota_config(tenant_id, app_code) + usage = self.get_usage(tenant_id, app_code) + + # 检查日调用次数 + if config.daily_calls > 0: + if usage.daily_calls >= config.daily_calls: + return QuotaCheckResult( + allowed=False, + reason=f"已达到每日调用限制 ({config.daily_calls} 次)", + quota_type="daily_calls", + limit=config.daily_calls, + used=usage.daily_calls, + remaining=0 + ) + + # 检查日Token限制 + if config.daily_tokens > 0: + if usage.daily_tokens + estimated_tokens > config.daily_tokens: + return QuotaCheckResult( + allowed=False, + reason=f"已达到每日Token限制 ({config.daily_tokens:,})", + quota_type="daily_tokens", + limit=config.daily_tokens, + used=usage.daily_tokens, + remaining=max(0, config.daily_tokens - usage.daily_tokens) + ) + + # 检查月调用次数 + if config.monthly_calls > 0: + if usage.monthly_calls >= config.monthly_calls: + return QuotaCheckResult( + allowed=False, + reason=f"已达到每月调用限制 ({config.monthly_calls} 次)", + quota_type="monthly_calls", + limit=config.monthly_calls, + used=usage.monthly_calls, + remaining=0 + ) + + # 检查月Token限制 + if config.monthly_tokens > 0: + if usage.monthly_tokens + estimated_tokens > config.monthly_tokens: + return QuotaCheckResult( + allowed=False, + reason=f"已达到每月Token限制 ({config.monthly_tokens:,})", + quota_type="monthly_tokens", + limit=config.monthly_tokens, + used=usage.monthly_tokens, + remaining=max(0, config.monthly_tokens - usage.monthly_tokens) + ) + + # 检查月费用限制 + if config.monthly_cost > 0: + if usage.monthly_cost >= config.monthly_cost: + return QuotaCheckResult( + allowed=False, + reason=f"已达到每月费用限制 (¥{config.monthly_cost:.2f})", + quota_type="monthly_cost", + limit=int(config.monthly_cost * 100), # 转为分 + used=int(usage.monthly_cost * 100), + remaining=max(0, int((config.monthly_cost - usage.monthly_cost) * 100)) + ) + + # 所有检查通过 + return QuotaCheckResult( + allowed=True, + quota_type="daily_calls", + limit=config.daily_calls, + used=usage.daily_calls, + remaining=max(0, config.daily_calls - usage.daily_calls) if config.daily_calls > 0 else -1 + ) + + def get_quota_summary(self, tenant_id: str, app_code: str) -> Dict[str, Any]: + """获取配额汇总信息 + + Returns: + 包含配额配置和使用情况的字典 + """ + config = self.get_quota_config(tenant_id, app_code) + usage = self.get_usage(tenant_id, app_code) + + def calc_percentage(used: int, limit: int) -> float: + if limit <= 0: + return 0 + return min(100, round(used / limit * 100, 1)) + + return { + "config": { + "daily_calls": config.daily_calls, + "daily_tokens": config.daily_tokens, + "monthly_calls": config.monthly_calls, + "monthly_tokens": config.monthly_tokens, + "monthly_cost": config.monthly_cost + }, + "usage": { + "daily_calls": usage.daily_calls, + "daily_tokens": usage.daily_tokens, + "monthly_calls": usage.monthly_calls, + "monthly_tokens": usage.monthly_tokens, + "monthly_cost": round(usage.monthly_cost, 2) + }, + "percentage": { + "daily_calls": calc_percentage(usage.daily_calls, config.daily_calls), + "daily_tokens": calc_percentage(usage.daily_tokens, config.daily_tokens), + "monthly_calls": calc_percentage(usage.monthly_calls, config.monthly_calls), + "monthly_tokens": calc_percentage(usage.monthly_tokens, config.monthly_tokens), + "monthly_cost": calc_percentage(int(usage.monthly_cost * 100), int(config.monthly_cost * 100)) + } + } + + def update_quota( + self, + tenant_id: str, + app_code: str, + quota_config: Dict[str, Any] + ) -> Subscription: + """更新配额配置 + + Args: + tenant_id: 租户ID + app_code: 应用代码 + quota_config: 配额配置字典 + + Returns: + 更新后的Subscription实例 + """ + subscription = self.get_subscription(tenant_id, app_code) + + if not subscription: + # 创建新订阅 + subscription = Subscription( + tenant_id=tenant_id, + app_code=app_code, + start_date=date.today(), + quota=quota_config, + status='active' + ) + self.db.add(subscription) + else: + # 更新现有订阅 + subscription.quota = quota_config + + self.db.commit() + self.db.refresh(subscription) + + # 清除缓存 + cache_key = f"quota:config:{tenant_id}:{app_code}" + self._cache.delete(cache_key) + + return subscription + + +def check_quota_middleware( + db: Session, + tenant_id: str, + app_code: str, + estimated_tokens: int = 0 +) -> QuotaCheckResult: + """配额检查中间件函数 + + 可在路由中使用: + result = check_quota_middleware(db, "qiqi", "tools") + if not result.allowed: + raise HTTPException(status_code=429, detail=result.reason) + """ + service = QuotaService(db) + return service.check_quota(tenant_id, app_code, estimated_tokens) diff --git a/backend/app/services/wechat.py b/backend/app/services/wechat.py new file mode 100644 index 0000000..1b59d78 --- /dev/null +++ b/backend/app/services/wechat.py @@ -0,0 +1,371 @@ +"""企业微信服务""" +import hashlib +import time +import logging +from typing import Optional, Dict, Any +from dataclasses import dataclass + +import httpx + +from ..config import get_settings +from .cache import get_cache +from .crypto import decrypt_config + +logger = logging.getLogger(__name__) +settings = get_settings() + + +@dataclass +class WechatConfig: + """企业微信应用配置""" + corp_id: str + agent_id: str + secret: str + + +class WechatService: + """企业微信服务 + + 提供access_token获取、JS-SDK签名、OAuth2等功能 + + 使用示例: + wechat = WechatService(corp_id="wwxxxx", agent_id="1000001", secret="xxx") + + # 获取access_token + token = await wechat.get_access_token() + + # 获取JS-SDK签名 + signature = await wechat.get_jssdk_signature("https://example.com/page") + """ + + # 企业微信API基础URL + BASE_URL = "https://qyapi.weixin.qq.com" + + def __init__(self, corp_id: str, agent_id: str, secret: str): + """初始化企业微信服务 + + Args: + corp_id: 企业ID + agent_id: 应用AgentId + secret: 应用Secret(明文) + """ + self.corp_id = corp_id + self.agent_id = agent_id + self.secret = secret + self._cache = get_cache() + + @classmethod + def from_wechat_app(cls, wechat_app) -> "WechatService": + """从TenantWechatApp模型创建服务实例 + + Args: + wechat_app: TenantWechatApp数据库模型 + + Returns: + WechatService实例 + """ + secret = "" + if wechat_app.secret_encrypted: + try: + secret = decrypt_config(wechat_app.secret_encrypted) + except Exception as e: + logger.error(f"Failed to decrypt secret: {e}") + + return cls( + corp_id=wechat_app.corp_id, + agent_id=wechat_app.agent_id, + secret=secret + ) + + def _cache_key(self, key_type: str) -> str: + """生成缓存键""" + return f"wechat:{self.corp_id}:{self.agent_id}:{key_type}" + + async def get_access_token(self, force_refresh: bool = False) -> Optional[str]: + """获取access_token + + 企业微信access_token有效期7200秒,需要缓存 + + Args: + force_refresh: 是否强制刷新 + + Returns: + access_token或None + """ + cache_key = self._cache_key("access_token") + + # 尝试从缓存获取 + if not force_refresh: + cached = self._cache.get(cache_key) + if cached: + logger.debug(f"Access token from cache: {cached[:20]}...") + return cached + + # 从企业微信API获取 + url = f"{self.BASE_URL}/cgi-bin/gettoken" + params = { + "corpid": self.corp_id, + "corpsecret": self.secret + } + + try: + async with httpx.AsyncClient(timeout=10) as client: + response = await client.get(url, params=params) + result = response.json() + + if result.get("errcode", 0) != 0: + logger.error(f"Get access_token failed: {result}") + return None + + access_token = result.get("access_token") + expires_in = result.get("expires_in", 7200) + + # 缓存,提前200秒过期以确保安全 + self._cache.set( + cache_key, + access_token, + ttl=min(expires_in - 200, settings.WECHAT_ACCESS_TOKEN_EXPIRE) + ) + + logger.info(f"Got new access_token for {self.corp_id}") + return access_token + except Exception as e: + logger.error(f"Get access_token error: {e}") + return None + + async def get_jsapi_ticket(self, force_refresh: bool = False) -> Optional[str]: + """获取jsapi_ticket + + 用于生成JS-SDK签名 + + Args: + force_refresh: 是否强制刷新 + + Returns: + jsapi_ticket或None + """ + cache_key = self._cache_key("jsapi_ticket") + + # 尝试从缓存获取 + if not force_refresh: + cached = self._cache.get(cache_key) + if cached: + logger.debug(f"JSAPI ticket from cache: {cached[:20]}...") + return cached + + # 先获取access_token + access_token = await self.get_access_token() + if not access_token: + return None + + # 获取jsapi_ticket + url = f"{self.BASE_URL}/cgi-bin/get_jsapi_ticket" + params = {"access_token": access_token} + + try: + async with httpx.AsyncClient(timeout=10) as client: + response = await client.get(url, params=params) + result = response.json() + + if result.get("errcode", 0) != 0: + logger.error(f"Get jsapi_ticket failed: {result}") + return None + + ticket = result.get("ticket") + expires_in = result.get("expires_in", 7200) + + # 缓存 + self._cache.set( + cache_key, + ticket, + ttl=min(expires_in - 200, settings.WECHAT_JSAPI_TICKET_EXPIRE) + ) + + logger.info(f"Got new jsapi_ticket for {self.corp_id}") + return ticket + except Exception as e: + logger.error(f"Get jsapi_ticket error: {e}") + return None + + async def get_jssdk_signature( + self, + url: str, + noncestr: Optional[str] = None, + timestamp: Optional[int] = None + ) -> Optional[Dict[str, Any]]: + """生成JS-SDK签名 + + Args: + url: 当前页面URL(不含#及其后面部分) + noncestr: 随机字符串,可选 + timestamp: 时间戳,可选 + + Returns: + 签名信息字典,包含signature, noncestr, timestamp, appId等 + """ + ticket = await self.get_jsapi_ticket() + if not ticket: + return None + + # 生成随机字符串和时间戳 + if noncestr is None: + import secrets + noncestr = secrets.token_hex(8) + if timestamp is None: + timestamp = int(time.time()) + + # 构建签名字符串 + sign_str = f"jsapi_ticket={ticket}&noncestr={noncestr}×tamp={timestamp}&url={url}" + + # SHA1签名 + signature = hashlib.sha1(sign_str.encode()).hexdigest() + + return { + "appId": self.corp_id, + "agentId": self.agent_id, + "timestamp": timestamp, + "nonceStr": noncestr, + "signature": signature, + "url": url + } + + def get_oauth2_url( + self, + redirect_uri: str, + scope: str = "snsapi_base", + state: str = "" + ) -> str: + """生成OAuth2授权URL + + Args: + redirect_uri: 授权后重定向的URL + scope: 应用授权作用域 + - snsapi_base: 静默授权,只能获取成员基础信息 + - snsapi_privateinfo: 手动授权,可获取成员详细信息 + state: 重定向后会带上state参数 + + Returns: + OAuth2授权URL + """ + import urllib.parse + + encoded_uri = urllib.parse.quote(redirect_uri, safe='') + + url = ( + f"https://open.weixin.qq.com/connect/oauth2/authorize" + f"?appid={self.corp_id}" + f"&redirect_uri={encoded_uri}" + f"&response_type=code" + f"&scope={scope}" + f"&state={state}" + f"&agentid={self.agent_id}" + f"#wechat_redirect" + ) + + return url + + async def get_user_info_by_code(self, code: str) -> Optional[Dict[str, Any]]: + """通过OAuth2 code获取用户信息 + + Args: + code: OAuth2回调返回的code + + Returns: + 用户信息字典,包含UserId, DeviceId等 + """ + access_token = await self.get_access_token() + if not access_token: + return None + + url = f"{self.BASE_URL}/cgi-bin/auth/getuserinfo" + params = { + "access_token": access_token, + "code": code + } + + try: + async with httpx.AsyncClient(timeout=10) as client: + response = await client.get(url, params=params) + result = response.json() + + if result.get("errcode", 0) != 0: + logger.error(f"Get user info by code failed: {result}") + return None + + return { + "user_id": result.get("userid") or result.get("UserId"), + "device_id": result.get("deviceid") or result.get("DeviceId"), + "open_id": result.get("openid") or result.get("OpenId"), + "external_userid": result.get("external_userid"), + } + except Exception as e: + logger.error(f"Get user info by code error: {e}") + return None + + async def get_user_detail(self, user_id: str) -> Optional[Dict[str, Any]]: + """获取成员详细信息 + + Args: + user_id: 成员UserID + + Returns: + 成员详细信息 + """ + access_token = await self.get_access_token() + if not access_token: + return None + + url = f"{self.BASE_URL}/cgi-bin/user/get" + params = { + "access_token": access_token, + "userid": user_id + } + + try: + async with httpx.AsyncClient(timeout=10) as client: + response = await client.get(url, params=params) + result = response.json() + + if result.get("errcode", 0) != 0: + logger.error(f"Get user detail failed: {result}") + return None + + return { + "userid": result.get("userid"), + "name": result.get("name"), + "department": result.get("department"), + "position": result.get("position"), + "mobile": result.get("mobile"), + "email": result.get("email"), + "avatar": result.get("avatar"), + "status": result.get("status"), + } + except Exception as e: + logger.error(f"Get user detail error: {e}") + return None + + +async def get_wechat_service_by_id( + wechat_app_id: int, + db_session +) -> Optional[WechatService]: + """根据企微应用ID获取服务实例 + + Args: + wechat_app_id: platform_tenant_wechat_apps表的ID + db_session: 数据库session + + Returns: + WechatService实例或None + """ + from ..models.tenant_wechat_app import TenantWechatApp + + wechat_app = db_session.query(TenantWechatApp).filter( + TenantWechatApp.id == wechat_app_id, + TenantWechatApp.status == 1 + ).first() + + if not wechat_app: + return None + + return WechatService.from_wechat_app(wechat_app) diff --git a/backend/env.template b/backend/env.template new file mode 100644 index 0000000..5df0c88 --- /dev/null +++ b/backend/env.template @@ -0,0 +1,26 @@ +# 000-platform 环境配置模板 +# 复制此文件为 .env 并填写实际值 + +# ==================== 应用配置 ==================== +APP_NAME=platform +APP_VERSION=1.0.0 +DEBUG=false + +# ==================== 数据库配置 ==================== +DB_HOST=localhost +DB_PORT=3306 +DB_USER=root +DB_PASSWORD=your-password +DB_NAME=new_qiqi + +# ==================== JWT 配置 ==================== +JWT_SECRET_KEY=your-secret-key-change-in-production +JWT_ALGORITHM=HS256 +JWT_EXPIRE_MINUTES=1440 + +# ==================== 安全配置 ==================== +# 用于加密敏感数据(如企微 Secret) +ENCRYPTION_KEY=your-encryption-key-32-bytes + +# ==================== 可选:Redis 缓存 ==================== +# REDIS_URL=redis://localhost:6379/0 diff --git a/backend/requirements.txt b/backend/requirements.txt index c68d32b..ac816e7 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -10,3 +10,5 @@ passlib[bcrypt]>=1.7.4 bcrypt>=4.0.0 python-multipart>=0.0.6 httpx>=0.26.0 +redis>=5.0.0 +openpyxl>=3.1.0 diff --git a/deploy/nginx/frontend.conf b/deploy/nginx/frontend.conf index fcc26d1..5a12504 100644 --- a/deploy/nginx/frontend.conf +++ b/deploy/nginx/frontend.conf @@ -5,18 +5,19 @@ server { root /usr/share/nginx/html; index index.html; + # Docker 内部 DNS 解析器 + resolver 127.0.0.11 valid=30s; + # Vue Router history mode location / { try_files $uri $uri/ /index.html; } # API 代理到后端 - # 使用宿主机网关地址(Docker 默认网桥) - # 如果 172.18.0.1 不可用,可能需要调整为实际的 Docker 网桥地址 - set $backend_host 172.18.0.1; - + # 使用 Docker 容器名称,通过 Docker DNS 解析 location /api/ { - proxy_pass http://$backend_host:8001/api/; + set $backend platform-backend-test:8000; + proxy_pass http://$backend/api/; proxy_http_version 1.1; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; diff --git a/frontend/src/api/index.js b/frontend/src/api/index.js index 1b14bad..075fbef 100644 --- a/frontend/src/api/index.js +++ b/frontend/src/api/index.js @@ -7,6 +7,53 @@ const api = axios.create({ timeout: 30000 }) +/** + * 解析 API 错误响应 + */ +function parseApiError(error) { + const result = { + code: 'UNKNOWN_ERROR', + message: '发生了未知错误', + traceId: '', + status: 500 + } + + if (!error.response) { + result.code = 'NETWORK_ERROR' + result.message = '网络连接失败,请检查网络后重试' + return result + } + + const { status, data, headers } = error.response + result.status = status + result.traceId = headers['x-trace-id'] || headers['X-Trace-ID'] || '' + + if (data && data.error) { + result.code = data.error.code || result.code + result.message = data.error.message || result.message + result.traceId = data.error.trace_id || result.traceId + } else if (data && data.detail) { + result.message = typeof data.detail === 'string' ? data.detail : JSON.stringify(data.detail) + } + + return result +} + +/** + * 跳转到错误页面 + */ +function navigateToErrorPage(errorInfo) { + router.push({ + name: 'Error', + query: { + code: errorInfo.code, + message: errorInfo.message, + trace_id: errorInfo.traceId, + status: String(errorInfo.status) + } + }) +} + // 请求拦截器 api.interceptors.request.use( config => { @@ -19,10 +66,15 @@ api.interceptors.request.use( error => Promise.reject(error) ) -// 响应拦截器 +// 响应拦截器(集成 TraceID 追踪) api.interceptors.response.use( response => response, error => { + const errorInfo = parseApiError(error) + const traceLog = errorInfo.traceId ? ` (trace: ${errorInfo.traceId})` : '' + + console.error(`[API Error] ${errorInfo.code}: ${errorInfo.message}${traceLog}`) + if (error.response?.status === 401) { localStorage.removeItem('token') localStorage.removeItem('user') @@ -30,9 +82,14 @@ api.interceptors.response.use( ElMessage.error('登录已过期,请重新登录') } else if (error.response?.status === 403) { ElMessage.error('没有权限执行此操作') + } else if (['INTERNAL_ERROR', 'SERVICE_UNAVAILABLE', 'GATEWAY_ERROR'].includes(errorInfo.code)) { + // 严重错误跳转到错误页面 + navigateToErrorPage(errorInfo) } else { - ElMessage.error(error.response?.data?.detail || error.message || '请求失败') + // 普通错误显示消息 + ElMessage.error(errorInfo.message) } + return Promise.reject(error) } ) diff --git a/frontend/src/router/index.js b/frontend/src/router/index.js index fad7e4a..bd68d40 100644 --- a/frontend/src/router/index.js +++ b/frontend/src/router/index.js @@ -8,6 +8,12 @@ const routes = [ component: () => import('@/views/login/index.vue'), meta: { title: '登录', public: true } }, + { + path: '/error', + name: 'Error', + component: () => import('@/views/error/index.vue'), + meta: { title: '出错了', public: true } + }, { path: '/', component: () => import('@/components/Layout.vue'), diff --git a/frontend/src/views/dashboard/index.vue b/frontend/src/views/dashboard/index.vue index 6e49d46..fdf44e3 100644 --- a/frontend/src/views/dashboard/index.vue +++ b/frontend/src/views/dashboard/index.vue @@ -7,11 +7,15 @@ const stats = ref({ totalTenants: 0, activeTenants: 0, todayCalls: 0, - todayTokens: 0 + todayTokens: 0, + weekCalls: 0, + weekTokens: 0 }) const recentLogs = ref([]) +const trendData = ref([]) const chartRef = ref(null) +const chartLoading = ref(false) let chartInstance = null async function fetchStats() { @@ -25,6 +29,8 @@ async function fetchStats() { if (statsRes.data) { stats.value.todayCalls = statsRes.data.today_calls || 0 stats.value.todayTokens = statsRes.data.today_tokens || 0 + stats.value.weekCalls = statsRes.data.week_calls || 0 + stats.value.weekTokens = statsRes.data.week_tokens || 0 } } catch (e) { console.error('获取统计失败:', e) @@ -40,10 +46,38 @@ async function fetchRecentLogs() { } } +async function fetchTrendData() { + chartLoading.value = true + try { + const res = await api.get('/api/stats/trend', { params: { days: 7 } }) + trendData.value = res.data.trend || [] + updateChart() + } catch (e) { + console.error('获取趋势数据失败:', e) + // 如果API失败,使用空数据 + trendData.value = [] + updateChart() + } finally { + chartLoading.value = false + } +} + function initChart() { if (!chartRef.value) return - chartInstance = echarts.init(chartRef.value) +} + +function updateChart() { + if (!chartInstance) return + + // 从API数据提取日期和调用次数 + const dates = trendData.value.map(item => { + // 格式化日期为 MM-DD + const date = new Date(item.date) + return `${(date.getMonth() + 1).toString().padStart(2, '0')}-${date.getDate().toString().padStart(2, '0')}` + }) + const calls = trendData.value.map(item => item.calls || 0) + const tokens = trendData.value.map(item => item.tokens || 0) const option = { title: { @@ -51,22 +85,44 @@ function initChart() { textStyle: { fontSize: 14, fontWeight: 500 } }, tooltip: { - trigger: 'axis' + trigger: 'axis', + formatter: function(params) { + let result = params[0].axisValue + '
' + params.forEach(param => { + result += `${param.marker} ${param.seriesName}: ${param.value.toLocaleString()}
` + }) + return result + } + }, + legend: { + data: ['调用次数', 'Token 消耗'], + top: 0, + right: 0 }, grid: { left: '3%', right: '4%', bottom: '3%', + top: 50, containLabel: true }, xAxis: { type: 'category', boundaryGap: false, - data: ['周一', '周二', '周三', '周四', '周五', '周六', '周日'] - }, - yAxis: { - type: 'value' + data: dates.length > 0 ? dates : ['暂无数据'] }, + yAxis: [ + { + type: 'value', + name: '调用次数', + position: 'left' + }, + { + type: 'value', + name: 'Token', + position: 'right' + } + ], series: [ { name: '调用次数', @@ -80,7 +136,16 @@ function initChart() { }, lineStyle: { color: '#409eff' }, itemStyle: { color: '#409eff' }, - data: [120, 132, 101, 134, 90, 230, 210] + data: calls.length > 0 ? calls : [0] + }, + { + name: 'Token 消耗', + type: 'line', + yAxisIndex: 1, + smooth: true, + lineStyle: { color: '#67c23a' }, + itemStyle: { color: '#67c23a' }, + data: tokens.length > 0 ? tokens : [0] } ] } @@ -96,6 +161,7 @@ onMounted(() => { fetchStats() fetchRecentLogs() initChart() + fetchTrendData() window.addEventListener('resize', handleResize) }) @@ -113,22 +179,22 @@ onUnmounted(() => {
租户总数
{{ stats.totalTenants }}
-
-
活跃租户
-
{{ stats.activeTenants || '-' }}
-
今日 AI 调用
-
{{ stats.todayCalls }}
+
{{ stats.todayCalls.toLocaleString() }}
今日 Token 消耗
{{ stats.todayTokens.toLocaleString() }}
+
+
本周 AI 调用
+
{{ stats.weekCalls.toLocaleString() }}
+
-
+
diff --git a/frontend/src/views/error/index.vue b/frontend/src/views/error/index.vue new file mode 100644 index 0000000..9db86d8 --- /dev/null +++ b/frontend/src/views/error/index.vue @@ -0,0 +1,179 @@ + + + + + diff --git a/sdk/stats_client.py b/sdk/stats_client.py index 1812650..a5695ea 100644 --- a/sdk/stats_client.py +++ b/sdk/stats_client.py @@ -1,12 +1,25 @@ """AI统计上报客户端""" import os +import json +import asyncio +import logging +import threading from datetime import datetime from decimal import Decimal from typing import Optional, List from dataclasses import dataclass, asdict +from pathlib import Path + +try: + import httpx + HTTPX_AVAILABLE = True +except ImportError: + HTTPX_AVAILABLE = False from .trace import get_trace_id, get_tenant_id, get_user_id +logger = logging.getLogger(__name__) + @dataclass class AICallEvent: @@ -32,6 +45,24 @@ class AICallEvent: self.trace_id = get_trace_id() if self.user_id is None: self.user_id = get_user_id() + + def to_dict(self) -> dict: + """转换为可序列化的字典""" + return { + "tenant_id": self.tenant_id, + "app_code": self.app_code, + "module_code": self.module_code, + "prompt_name": self.prompt_name, + "model": self.model, + "input_tokens": self.input_tokens, + "output_tokens": self.output_tokens, + "cost": str(self.cost), + "latency_ms": self.latency_ms, + "status": self.status, + "user_id": self.user_id, + "trace_id": self.trace_id, + "event_time": self.event_time.isoformat() if self.event_time else None + } class StatsClient: @@ -51,23 +82,37 @@ class StatsClient: ) """ + # 失败事件持久化文件 + FAILED_EVENTS_FILE = ".platform_failed_events.json" + def __init__( self, tenant_id: int, app_code: str, platform_url: Optional[str] = None, api_key: Optional[str] = None, - local_only: bool = True + local_only: bool = False, + max_retries: int = 3, + retry_delay: float = 1.0, + timeout: float = 10.0 ): self.tenant_id = tenant_id self.app_code = app_code self.platform_url = platform_url or os.getenv("PLATFORM_URL", "") self.api_key = api_key or os.getenv("PLATFORM_API_KEY", "") - self.local_only = local_only or not self.platform_url + self.local_only = local_only or not self.platform_url or not HTTPX_AVAILABLE + self.max_retries = max_retries + self.retry_delay = retry_delay + self.timeout = timeout # 批量上报缓冲区 self._buffer: List[AICallEvent] = [] self._buffer_size = 10 # 达到此数量时自动上报 + self._lock = threading.Lock() + + # 在启动时尝试发送之前失败的事件 + if not self.local_only: + self._retry_failed_events() def report_ai_call( self, @@ -113,36 +158,172 @@ class StatsClient: user_id=user_id ) - self._buffer.append(event) + with self._lock: + self._buffer.append(event) + should_flush = flush or len(self._buffer) >= self._buffer_size - if flush or len(self._buffer) >= self._buffer_size: + if should_flush: self.flush() return event def flush(self): """发送缓冲区中的所有事件""" - if not self._buffer: - return - - events = self._buffer.copy() - self._buffer.clear() + with self._lock: + if not self._buffer: + return + events = self._buffer.copy() + self._buffer.clear() if self.local_only: # 本地模式:仅打印 for event in events: - print(f"[STATS] {event.app_code}/{event.module_code}: " - f"{event.prompt_name} - {event.input_tokens}+{event.output_tokens} tokens") + logger.info(f"[STATS] {event.app_code}/{event.module_code}: " + f"{event.prompt_name} - {event.input_tokens}+{event.output_tokens} tokens") else: # 远程上报 self._send_to_platform(events) def _send_to_platform(self, events: List[AICallEvent]): - """发送事件到平台(异步,后续实现)""" - # TODO: 使用httpx异步发送 - pass + """发送事件到平台""" + if not HTTPX_AVAILABLE: + logger.warning("httpx not installed, falling back to local mode") + return + + # 转换事件为可序列化格式 + payload = {"events": [e.to_dict() for e in events]} + + # 尝试在事件循环中运行 + try: + loop = asyncio.get_running_loop() + # 已在异步上下文中,创建任务 + asyncio.create_task(self._send_async(payload, events)) + except RuntimeError: + # 没有运行中的事件循环,使用同步方式 + self._send_sync(payload, events) + + def _send_sync(self, payload: dict, events: List[AICallEvent]): + """同步发送事件""" + url = f"{self.platform_url.rstrip('/')}/api/stats/report/batch" + headers = {"X-API-Key": self.api_key, "Content-Type": "application/json"} + + for attempt in range(self.max_retries): + try: + with httpx.Client(timeout=self.timeout) as client: + response = client.post(url, json=payload, headers=headers) + + if response.status_code == 200: + result = response.json() + logger.debug(f"Stats reported successfully: {result.get('count', len(events))} events") + return + else: + logger.warning(f"Stats report failed with status {response.status_code}: {response.text}") + except httpx.TimeoutException: + logger.warning(f"Stats report timeout (attempt {attempt + 1}/{self.max_retries})") + except httpx.RequestError as e: + logger.warning(f"Stats report request error (attempt {attempt + 1}/{self.max_retries}): {e}") + except Exception as e: + logger.error(f"Stats report unexpected error: {e}") + break + + # 重试延迟 + if attempt < self.max_retries - 1: + import time + time.sleep(self.retry_delay * (attempt + 1)) + + # 所有重试都失败,持久化到文件 + self._persist_failed_events(events) + + async def _send_async(self, payload: dict, events: List[AICallEvent]): + """异步发送事件""" + url = f"{self.platform_url.rstrip('/')}/api/stats/report/batch" + headers = {"X-API-Key": self.api_key, "Content-Type": "application/json"} + + for attempt in range(self.max_retries): + try: + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post(url, json=payload, headers=headers) + + if response.status_code == 200: + result = response.json() + logger.debug(f"Stats reported successfully: {result.get('count', len(events))} events") + return + else: + logger.warning(f"Stats report failed with status {response.status_code}: {response.text}") + except httpx.TimeoutException: + logger.warning(f"Stats report timeout (attempt {attempt + 1}/{self.max_retries})") + except httpx.RequestError as e: + logger.warning(f"Stats report request error (attempt {attempt + 1}/{self.max_retries}): {e}") + except Exception as e: + logger.error(f"Stats report unexpected error: {e}") + break + + # 重试延迟 + if attempt < self.max_retries - 1: + await asyncio.sleep(self.retry_delay * (attempt + 1)) + + # 所有重试都失败,持久化到文件 + self._persist_failed_events(events) + + def _persist_failed_events(self, events: List[AICallEvent]): + """持久化失败的事件到文件""" + try: + failed_file = Path(self.FAILED_EVENTS_FILE) + existing = [] + + if failed_file.exists(): + try: + existing = json.loads(failed_file.read_text()) + except (json.JSONDecodeError, IOError): + existing = [] + + # 添加新的失败事件 + for event in events: + existing.append(event.to_dict()) + + # 限制最多保存1000条 + if len(existing) > 1000: + existing = existing[-1000:] + + failed_file.write_text(json.dumps(existing, ensure_ascii=False, indent=2)) + logger.info(f"Persisted {len(events)} failed events to {self.FAILED_EVENTS_FILE}") + except Exception as e: + logger.error(f"Failed to persist events: {e}") + + def _retry_failed_events(self): + """重试之前失败的事件""" + try: + failed_file = Path(self.FAILED_EVENTS_FILE) + if not failed_file.exists(): + return + + events_data = json.loads(failed_file.read_text()) + if not events_data: + return + + logger.info(f"Retrying {len(events_data)} previously failed events") + + # 尝试发送 + payload = {"events": events_data} + url = f"{self.platform_url.rstrip('/')}/api/stats/report/batch" + headers = {"X-API-Key": self.api_key, "Content-Type": "application/json"} + + try: + with httpx.Client(timeout=self.timeout) as client: + response = client.post(url, json=payload, headers=headers) + if response.status_code == 200: + # 成功后删除文件 + failed_file.unlink() + logger.info(f"Successfully sent {len(events_data)} previously failed events") + except Exception as e: + logger.warning(f"Failed to retry events: {e}") + except Exception as e: + logger.error(f"Error loading failed events: {e}") def __del__(self): """析构时发送剩余事件""" - if self._buffer: - self.flush() + try: + if self._buffer: + self.flush() + except Exception: + pass # 忽略析构时的错误