"""Cost calculation service."""
import logging
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Optional, Dict, List, Tuple
from functools import lru_cache

from sqlalchemy.orm import Session
from sqlalchemy import func

from ..models.pricing import ModelPricing, TenantBilling
from ..models.stats import AICallEvent
from .cache import get_cache

logger = logging.getLogger(__name__)


class CostCalculator:
    """Compute per-call costs and aggregate them into reports and bills.

    Usage example:
        calculator = CostCalculator(db)

        # Cost of a single call
        cost = calculator.calculate_cost("gpt-4", input_tokens=100, output_tokens=200)

        # Build a monthly bill
        billing = calculator.generate_monthly_billing("qiqi", "2026-01")
    """

    # Fallback prices (CNY per 1K tokens) used when the database holds no
    # active pricing row for a model.
    DEFAULT_PRICING = {
        # OpenAI
        "gpt-4": {"input": 0.21, "output": 0.42},
        "gpt-4-turbo": {"input": 0.07, "output": 0.21},
        "gpt-4o": {"input": 0.035, "output": 0.105},
        "gpt-4o-mini": {"input": 0.00105, "output": 0.0042},
        "gpt-3.5-turbo": {"input": 0.0035, "output": 0.014},
        # Anthropic
        "claude-3-opus": {"input": 0.105, "output": 0.525},
        "claude-3-sonnet": {"input": 0.021, "output": 0.105},
        "claude-3-haiku": {"input": 0.00175, "output": 0.00875},
        "claude-3.5-sonnet": {"input": 0.021, "output": 0.105},
        # Chinese domestic models
        "qwen-max": {"input": 0.02, "output": 0.06},
        "qwen-plus": {"input": 0.004, "output": 0.012},
        "qwen-turbo": {"input": 0.002, "output": 0.006},
        "glm-4": {"input": 0.01, "output": 0.01},
        "glm-4-flash": {"input": 0.0001, "output": 0.0001},
        "deepseek-chat": {"input": 0.001, "output": 0.002},
        "deepseek-coder": {"input": 0.001, "output": 0.002},
        # Catch-all for models not listed above
        "default": {"input": 0.01, "output": 0.03},
    }

    def __init__(self, db: Session):
        """Bind the calculator to a database session and the shared cache.

        Args:
            db: SQLAlchemy session used for every query and commit.
        """
        self.db = db
        self._cache = get_cache()
        self._pricing_cache: Dict[str, ModelPricing] = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_decimal(value) -> Decimal:
        """Coerce a price value to Decimal, treating missing values as zero.

        ``None``, the empty string and the literal string ``"None"`` (what
        ``str(None)`` yields, which older cache entries may contain) all map
        to ``Decimal("0")``.
        """
        if isinstance(value, Decimal):
            return value
        if value is None or value == "" or value == "None":
            return Decimal("0")
        return Decimal(str(value))

    @staticmethod
    def _token_cost(
        input_price: Decimal,
        output_price: Decimal,
        input_tokens: int,
        output_tokens: int,
    ) -> Decimal:
        """Return the token-based cost for prices quoted per 1K tokens."""
        return (
            input_price * Decimal(input_tokens) / 1000
            + output_price * Decimal(output_tokens) / 1000
        )

    @staticmethod
    def _apply_date_range(query, start_date: Optional[str], end_date: Optional[str]):
        """Restrict *query* to events created within [start_date, end_date].

        Both dates are ``YYYY-MM-DD`` strings and the range is inclusive of
        the whole end day.
        """
        if start_date:
            query = query.filter(AICallEvent.created_at >= start_date)
        if end_date:
            try:
                next_day = (
                    datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
                ).date().isoformat()
            except (ValueError, OverflowError):
                # Not a plain YYYY-MM-DD date: keep the legacy string bound
                # rather than failing a query that used to run.
                query = query.filter(
                    AICallEvent.created_at <= end_date + " 23:59:59"
                )
            else:
                # An exclusive next-day bound also includes timestamps with
                # fractional seconds in the final second of the end day.
                query = query.filter(AICallEvent.created_at < next_day)
        return query

    @staticmethod
    def _month_bounds(billing_month: str) -> Tuple[str, str]:
        """Return (first day, first day of next month) for a ``YYYY-MM`` string.

        Raises:
            ValueError: if *billing_month* is not ``YYYY-MM`` with month 1..12.
        """
        try:
            year_text, month_text = billing_month.split("-")
            year, month = int(year_text), int(month_text)
        except ValueError as exc:
            raise ValueError(
                f"billing_month must be in YYYY-MM format, got {billing_month!r}"
            ) from exc
        if year < 1 or not 1 <= month <= 12:
            raise ValueError(
                f"billing_month must be in YYYY-MM format, got {billing_month!r}"
            )

        start = f"{year:04d}-{month:02d}-01"
        if month == 12:
            end = f"{year + 1:04d}-01-01"
        else:
            end = f"{year:04d}-{month + 1:02d}-01"
        return start, end

    @staticmethod
    def _sum_costs(rows, key_attr: str) -> Dict[str, float]:
        """Total ``row.cost`` per label, folding empty labels into "unknown".

        Summing (instead of assigning) matters because NULL, empty-string and
        literal "unknown" groups all end up under the same "unknown" key.
        """
        totals: Dict[str, float] = {}
        for row in rows:
            label = getattr(row, key_attr) or "unknown"
            totals[label] = totals.get(label, 0.0) + float(row.cost)
        return totals

    def _commit(self) -> None:
        """Commit the session; on failure roll back and re-raise."""
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    # ------------------------------------------------------------------
    # Pricing lookup
    # ------------------------------------------------------------------

    def get_model_pricing(self, model_name: str) -> Optional[ModelPricing]:
        """Look up the active pricing configuration for a model.

        The cache is consulted first; on a miss the database is queried for a
        row with ``status == 1`` and the result is cached with ``ttl=3600``.

        Args:
            model_name: Name of the model.

        Returns:
            A ModelPricing instance, or None when no active row exists.
        """
        cache_key = f"pricing:{model_name}"
        cached = self._cache.get(cache_key)
        if cached:
            return self._dict_to_pricing(cached)

        pricing = self.db.query(ModelPricing).filter(
            ModelPricing.model_name == model_name,
            ModelPricing.status == 1
        ).first()

        if pricing:
            self._cache.set(cache_key, self._pricing_to_dict(pricing), ttl=3600)
            return pricing

        return None

    def _pricing_to_dict(self, pricing: ModelPricing) -> dict:
        """Serialise a ModelPricing into a plain dict suitable for caching.

        Prices are stored as strings so no Decimal precision is lost; a
        missing price is stored as "0".
        """
        return {
            "model_name": pricing.model_name,
            "input_price_per_1k": str(self._to_decimal(pricing.input_price_per_1k)),
            "output_price_per_1k": str(self._to_decimal(pricing.output_price_per_1k)),
            "fixed_price_per_call": str(self._to_decimal(pricing.fixed_price_per_call)),
            "pricing_type": pricing.pricing_type
        }

    def _dict_to_pricing(self, d: dict) -> ModelPricing:
        """Rebuild a detached ModelPricing from a cached dict."""
        pricing = ModelPricing()
        pricing.model_name = d.get("model_name")
        pricing.input_price_per_1k = self._to_decimal(d.get("input_price_per_1k"))
        pricing.output_price_per_1k = self._to_decimal(d.get("output_price_per_1k"))
        pricing.fixed_price_per_call = self._to_decimal(d.get("fixed_price_per_call"))
        pricing.pricing_type = d.get("pricing_type", "token")
        return pricing

    # ------------------------------------------------------------------
    # Cost calculation
    # ------------------------------------------------------------------

    def calculate_cost(
        self,
        model_name: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
        call_count: int = 1
    ) -> Decimal:
        """Compute the cost of one or more calls.

        Database pricing takes precedence. ``pricing_type`` selects the
        formula: ``'call'`` bills a fixed price per call, ``'hybrid'`` bills
        tokens plus the fixed price, anything else bills tokens only. Without
        database pricing, DEFAULT_PRICING is used (token-based).

        Args:
            model_name: Name of the model.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.
            call_count: Number of calls.

        Returns:
            The cost in CNY.
        """
        pricing = self.get_model_pricing(model_name)

        if pricing:
            if pricing.pricing_type == 'call':
                return self._to_decimal(pricing.fixed_price_per_call) * call_count

            token_cost = self._token_cost(
                self._to_decimal(pricing.input_price_per_1k),
                self._to_decimal(pricing.output_price_per_1k),
                input_tokens,
                output_tokens,
            )
            if pricing.pricing_type == 'hybrid':
                call_cost = self._to_decimal(pricing.fixed_price_per_call) * call_count
                return token_cost + call_cost
            return token_cost

        default_prices = self.DEFAULT_PRICING.get(
            model_name, self.DEFAULT_PRICING["default"]
        )
        # Go through str() so the float literals convert to exact decimals.
        return self._token_cost(
            Decimal(str(default_prices["input"])),
            Decimal(str(default_prices["output"])),
            input_tokens,
            output_tokens,
        )

    def calculate_event_cost(self, event: AICallEvent) -> Decimal:
        """Compute the cost of a single recorded call event.

        Missing model names fall back to "default" and missing token counts
        to 0.

        Args:
            event: The AI call event.

        Returns:
            The cost in CNY.
        """
        return self.calculate_cost(
            model_name=event.model or "default",
            input_tokens=event.input_tokens or 0,
            output_tokens=event.output_tokens or 0
        )

    def update_event_costs(
        self,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> int:
        """Recalculate the cost of events whose cost is 0 or NULL.

        A failure on one event is logged and does not stop the batch.

        Args:
            start_date: Start date, format YYYY-MM-DD.
            end_date: End date (inclusive), format YYYY-MM-DD.

        Returns:
            Number of events updated.
        """
        query = self.db.query(AICallEvent).filter(
            AICallEvent.cost.is_(None) | (AICallEvent.cost == 0)
        )
        query = self._apply_date_range(query, start_date, end_date)

        updated = 0
        for event in query.all():
            try:
                event.cost = self.calculate_event_cost(event)
                updated += 1
            except Exception as e:
                logger.error(
                    "Failed to calculate cost for event %s: %s", event.id, e
                )

        self._commit()
        logger.info("Updated %s event costs", updated)
        return updated

    # ------------------------------------------------------------------
    # Billing and reports
    # ------------------------------------------------------------------

    def generate_monthly_billing(
        self,
        tenant_id: str,
        billing_month: str
    ) -> TenantBilling:
        """Create or refresh the bill of one tenant for one month.

        An existing TenantBilling row for the same tenant and month is
        updated in place; otherwise a new one is added.

        Args:
            tenant_id: Tenant ID.
            billing_month: Billing month, format YYYY-MM.

        Returns:
            The committed and refreshed TenantBilling instance.

        Raises:
            ValueError: if *billing_month* is not a valid YYYY-MM value.
        """
        # Validate before touching the session so a bad month cannot leave a
        # half-initialised pending row behind.
        start_date, end_date = self._month_bounds(billing_month)

        billing = self.db.query(TenantBilling).filter(
            TenantBilling.tenant_id == tenant_id,
            TenantBilling.billing_month == billing_month
        ).first()
        if billing is None:
            billing = TenantBilling(
                tenant_id=tenant_id,
                billing_month=billing_month
            )
            self.db.add(billing)

        period = (
            AICallEvent.tenant_id == tenant_id,
            AICallEvent.created_at >= start_date,
            AICallEvent.created_at < end_date,
        )

        # Overall totals
        stats = self.db.query(
            func.count(AICallEvent.id).label('total_calls'),
            func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'),
            func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'),
            func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost')
        ).filter(*period).first()

        billing.total_calls = stats.total_calls or 0
        billing.total_input_tokens = int(stats.total_input or 0)
        billing.total_output_tokens = int(stats.total_output or 0)
        billing.total_cost = stats.total_cost or Decimal("0")

        # Breakdown by model
        model_stats = self.db.query(
            AICallEvent.model,
            func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
        ).filter(*period).group_by(AICallEvent.model).all()
        billing.cost_by_model = self._sum_costs(model_stats, "model")

        # Breakdown by application
        app_stats = self.db.query(
            AICallEvent.app_code,
            func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
        ).filter(*period).group_by(AICallEvent.app_code).all()
        billing.cost_by_app = self._sum_costs(app_stats, "app_code")

        self._commit()
        self.db.refresh(billing)
        return billing

    def get_cost_summary(
        self,
        tenant_id: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> Dict:
        """Return overall call, token and cost totals.

        Args:
            tenant_id: Tenant ID (optional; all tenants when omitted).
            start_date: Start date, format YYYY-MM-DD.
            end_date: End date (inclusive), format YYYY-MM-DD.

        Returns:
            Dict with total_calls, total_input_tokens, total_output_tokens
            and total_cost (as float).
        """
        query = self.db.query(
            func.count(AICallEvent.id).label('total_calls'),
            func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('total_input'),
            func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('total_output'),
            func.coalesce(func.sum(AICallEvent.cost), 0).label('total_cost')
        )
        if tenant_id:
            query = query.filter(AICallEvent.tenant_id == tenant_id)
        query = self._apply_date_range(query, start_date, end_date)

        stats = query.first()
        return {
            "total_calls": stats.total_calls or 0,
            "total_input_tokens": int(stats.total_input or 0),
            "total_output_tokens": int(stats.total_output or 0),
            "total_cost": float(stats.total_cost or 0)
        }

    def get_cost_by_tenant(
        self,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> List[Dict]:
        """Return per-tenant call counts and costs, highest cost first.

        Args:
            start_date: Start date, format YYYY-MM-DD.
            end_date: End date (inclusive), format YYYY-MM-DD.

        Returns:
            List of dicts with tenant_id, calls and cost (as float).
        """
        query = self.db.query(
            AICallEvent.tenant_id,
            func.count(AICallEvent.id).label('calls'),
            func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
        )
        query = self._apply_date_range(query, start_date, end_date)

        results = query.group_by(AICallEvent.tenant_id).order_by(
            func.sum(AICallEvent.cost).desc()
        ).all()

        return [
            {
                "tenant_id": r.tenant_id,
                "calls": r.calls,
                "cost": float(r.cost)
            }
            for r in results
        ]

    def get_cost_by_model(
        self,
        tenant_id: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> List[Dict]:
        """Return per-model call, token and cost totals, highest cost first.

        Args:
            tenant_id: Tenant ID (optional; all tenants when omitted).
            start_date: Start date, format YYYY-MM-DD.
            end_date: End date (inclusive), format YYYY-MM-DD.

        Returns:
            List of dicts with model, calls, input_tokens, output_tokens and
            cost (as float). Events without a model are reported as "unknown".
        """
        query = self.db.query(
            AICallEvent.model,
            func.count(AICallEvent.id).label('calls'),
            func.coalesce(func.sum(AICallEvent.input_tokens), 0).label('input_tokens'),
            func.coalesce(func.sum(AICallEvent.output_tokens), 0).label('output_tokens'),
            func.coalesce(func.sum(AICallEvent.cost), 0).label('cost')
        )
        if tenant_id:
            query = query.filter(AICallEvent.tenant_id == tenant_id)
        query = self._apply_date_range(query, start_date, end_date)

        results = query.group_by(AICallEvent.model).order_by(
            func.sum(AICallEvent.cost).desc()
        ).all()

        return [
            {
                "model": r.model or "unknown",
                "calls": r.calls,
                "input_tokens": int(r.input_tokens),
                "output_tokens": int(r.output_tokens),
                "cost": float(r.cost)
            }
            for r in results
        ]


# Convenience function
def calculate_cost(
    db: Session,
    model_name: str,
    input_tokens: int = 0,
    output_tokens: int = 0,
    call_count: int = 1
) -> Decimal:
    """Compute a call cost without instantiating CostCalculator yourself.

    Args:
        db: SQLAlchemy session used for the pricing lookup.
        model_name: Name of the model.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        call_count: Number of calls (relevant for per-call and hybrid pricing).

    Returns:
        The cost in CNY.
    """
    calculator = CostCalculator(db)
    return calculator.calculate_cost(
        model_name, input_tokens, output_tokens, call_count
    )