Initial commit: 智能项目定价模型

This commit is contained in:
kuzma
2026-01-31 21:33:06 +08:00
commit ef0824303f
174 changed files with 31705 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""服务层单元测试"""

View File

@@ -0,0 +1,415 @@
"""成本计算服务单元测试
测试 CostService 的核心业务逻辑
"""
import pytest
from decimal import Decimal
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.cost_service import CostService
from app.schemas.project_cost import AllocationMethod, CostItemType
from app.models import (
Material, Equipment, StaffLevel, Project, FixedCost,
ProjectCostItem, ProjectLaborCost, ProjectCostSummary
)
class TestCostService:
"""成本服务测试类"""
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_material_info(
self,
db_session: AsyncSession,
sample_material: Material
):
"""测试获取耗材信息"""
service = CostService(db_session)
# 获取存在的耗材
material = await service.get_material_info(sample_material.id)
assert material is not None
assert material.material_name == "冷凝胶"
assert material.unit_price == Decimal("2.00")
# 获取不存在的耗材
material = await service.get_material_info(99999)
assert material is None
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_equipment_info(
self,
db_session: AsyncSession,
sample_equipment: Equipment
):
"""测试获取设备信息"""
service = CostService(db_session)
# 获取存在的设备
equipment = await service.get_equipment_info(sample_equipment.id)
assert equipment is not None
assert equipment.equipment_name == "光子仪"
assert equipment.depreciation_per_use == Decimal("47.50")
# 获取不存在的设备
equipment = await service.get_equipment_info(99999)
assert equipment is None
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_staff_level_info(
self,
db_session: AsyncSession,
sample_staff_level: StaffLevel
):
"""测试获取人员级别信息"""
service = CostService(db_session)
# 获取存在的级别
level = await service.get_staff_level_info(sample_staff_level.id)
assert level is not None
assert level.level_name == "中级美容师"
assert level.hourly_rate == Decimal("50.00")
# 获取不存在的级别
level = await service.get_staff_level_info(99999)
assert level is None
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_material_cost(
self,
db_session: AsyncSession,
sample_project_with_costs: Project
):
"""测试耗材成本计算"""
service = CostService(db_session)
total, breakdown = await service.calculate_material_cost(
sample_project_with_costs.id
)
assert total == Decimal("40.00") # 20 * 2.00
assert len(breakdown) == 1
assert breakdown[0]["name"] == "冷凝胶"
assert breakdown[0]["quantity"] == 20.0
assert breakdown[0]["total"] == 40.0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_equipment_cost(
self,
db_session: AsyncSession,
sample_project_with_costs: Project
):
"""测试设备折旧成本计算"""
service = CostService(db_session)
total, breakdown = await service.calculate_equipment_cost(
sample_project_with_costs.id
)
assert total == Decimal("47.50") # 1 * 47.50
assert len(breakdown) == 1
assert breakdown[0]["name"] == "光子仪"
assert breakdown[0]["depreciation_per_use"] == 47.5
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_labor_cost(
self,
db_session: AsyncSession,
sample_project_with_costs: Project
):
"""测试人工成本计算"""
service = CostService(db_session)
total, breakdown = await service.calculate_labor_cost(
sample_project_with_costs.id
)
assert total == Decimal("50.00") # 60分钟 / 60 * 50
assert len(breakdown) == 1
assert breakdown[0]["name"] == "中级美容师"
assert breakdown[0]["duration_minutes"] == 60
assert breakdown[0]["hourly_rate"] == 50.0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_fixed_cost_allocation_by_count(
self,
db_session: AsyncSession,
sample_project: Project,
sample_fixed_cost: FixedCost
):
"""测试固定成本按项目数量分摊"""
service = CostService(db_session)
allocation, detail = await service.calculate_fixed_cost_allocation(
sample_project.id,
method=AllocationMethod.COUNT
)
# 只有一个项目,分摊全部固定成本
assert allocation == Decimal("30000.00")
assert detail["method"] == "count"
assert detail["project_count"] == 1
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_fixed_cost_allocation_by_duration(
self,
db_session: AsyncSession,
sample_project: Project,
sample_fixed_cost: FixedCost
):
"""测试固定成本按时长分摊"""
service = CostService(db_session)
allocation, detail = await service.calculate_fixed_cost_allocation(
sample_project.id,
method=AllocationMethod.DURATION
)
# 只有一个项目,占比 100%
assert allocation == Decimal("30000.00")
assert detail["method"] == "duration"
assert detail["project_duration"] == 60
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_project_cost(
self,
db_session: AsyncSession,
sample_project_with_costs: Project,
sample_fixed_cost: FixedCost
):
"""测试项目总成本计算"""
service = CostService(db_session)
result = await service.calculate_project_cost(
sample_project_with_costs.id,
allocation_method=AllocationMethod.COUNT
)
assert result.project_id == sample_project_with_costs.id
assert result.project_name == "光子嫩肤"
# 验证成本构成
breakdown = result.cost_breakdown
assert breakdown["material_cost"]["subtotal"] == 40.0
assert breakdown["equipment_cost"]["subtotal"] == 47.5
assert breakdown["labor_cost"]["subtotal"] == 50.0
# 总成本 = 耗材40 + 设备47.5 + 人工50 + 固定30000
expected_total = 40 + 47.5 + 50 + 30000
assert result.total_cost == expected_total
assert result.min_price_suggestion == expected_total
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_project_cost_not_found(
self,
db_session: AsyncSession
):
"""测试项目不存在时的错误处理"""
service = CostService(db_session)
with pytest.raises(ValueError, match="项目不存在"):
await service.calculate_project_cost(99999)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_add_cost_item_material(
self,
db_session: AsyncSession,
sample_project: Project,
sample_material: Material
):
"""测试添加耗材成本明细"""
service = CostService(db_session)
cost_item = await service.add_cost_item(
project_id=sample_project.id,
item_type=CostItemType.MATERIAL,
item_id=sample_material.id,
quantity=10,
remark="测试备注"
)
assert cost_item.project_id == sample_project.id
assert cost_item.item_type == "material"
assert cost_item.quantity == Decimal("10")
assert cost_item.unit_cost == Decimal("2.00")
assert cost_item.total_cost == Decimal("20.00")
assert cost_item.remark == "测试备注"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_add_cost_item_equipment(
self,
db_session: AsyncSession,
sample_project: Project,
sample_equipment: Equipment
):
"""测试添加设备折旧成本明细"""
service = CostService(db_session)
cost_item = await service.add_cost_item(
project_id=sample_project.id,
item_type=CostItemType.EQUIPMENT,
item_id=sample_equipment.id,
quantity=1,
)
assert cost_item.item_type == "equipment"
assert cost_item.unit_cost == Decimal("47.50")
assert cost_item.total_cost == Decimal("47.50")
@pytest.mark.unit
@pytest.mark.asyncio
async def test_add_cost_item_not_found(
self,
db_session: AsyncSession,
sample_project: Project
):
"""测试添加不存在的耗材/设备时的错误处理"""
service = CostService(db_session)
with pytest.raises(ValueError, match="耗材不存在"):
await service.add_cost_item(
project_id=sample_project.id,
item_type=CostItemType.MATERIAL,
item_id=99999,
quantity=1,
)
with pytest.raises(ValueError, match="设备不存在"):
await service.add_cost_item(
project_id=sample_project.id,
item_type=CostItemType.EQUIPMENT,
item_id=99999,
quantity=1,
)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_add_labor_cost(
self,
db_session: AsyncSession,
sample_project: Project,
sample_staff_level: StaffLevel
):
"""测试添加人工成本"""
service = CostService(db_session)
labor_cost = await service.add_labor_cost(
project_id=sample_project.id,
staff_level_id=sample_staff_level.id,
duration_minutes=30,
remark="测试人工"
)
assert labor_cost.project_id == sample_project.id
assert labor_cost.duration_minutes == 30
assert labor_cost.hourly_rate == Decimal("50.00")
# 30分钟 / 60 * 50 = 25
assert labor_cost.labor_cost == Decimal("25.00")
@pytest.mark.unit
@pytest.mark.asyncio
async def test_add_labor_cost_not_found(
self,
db_session: AsyncSession,
sample_project: Project
):
"""测试添加不存在的人员级别时的错误处理"""
service = CostService(db_session)
with pytest.raises(ValueError, match="人员级别不存在"):
await service.add_labor_cost(
project_id=sample_project.id,
staff_level_id=99999,
duration_minutes=30,
)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_update_cost_item(
self,
db_session: AsyncSession,
sample_project: Project,
sample_material: Material
):
"""测试更新成本明细"""
service = CostService(db_session)
# 先添加
cost_item = await service.add_cost_item(
project_id=sample_project.id,
item_type=CostItemType.MATERIAL,
item_id=sample_material.id,
quantity=10,
)
# 更新数量
updated = await service.update_cost_item(
cost_item=cost_item,
quantity=20,
remark="更新后备注"
)
assert updated.quantity == Decimal("20")
assert updated.total_cost == Decimal("40.00") # 20 * 2
assert updated.remark == "更新后备注"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_update_labor_cost(
self,
db_session: AsyncSession,
sample_project: Project,
sample_staff_level: StaffLevel
):
"""测试更新人工成本"""
service = CostService(db_session)
# 先添加
labor = await service.add_labor_cost(
project_id=sample_project.id,
staff_level_id=sample_staff_level.id,
duration_minutes=30,
)
# 更新时长
updated = await service.update_labor_cost(
labor_item=labor,
duration_minutes=60,
)
assert updated.duration_minutes == 60
assert updated.labor_cost == Decimal("50.00") # 60/60 * 50
@pytest.mark.unit
@pytest.mark.asyncio
async def test_empty_project_cost(
self,
db_session: AsyncSession,
sample_project: Project
):
"""测试没有成本明细的项目计算"""
service = CostService(db_session)
# 计算空项目成本(无固定成本)
total_material, _ = await service.calculate_material_cost(sample_project.id)
total_equipment, _ = await service.calculate_equipment_cost(sample_project.id)
total_labor, _ = await service.calculate_labor_cost(sample_project.id)
assert total_material == Decimal("0")
assert total_equipment == Decimal("0")
assert total_labor == Decimal("0")

View File

@@ -0,0 +1,305 @@
"""市场分析服务单元测试
测试 MarketService 的核心业务逻辑
"""
import pytest
from decimal import Decimal
from datetime import date
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.market_service import MarketService
from app.models import (
Project, Competitor, CompetitorPrice, BenchmarkPrice, Category
)
class TestMarketService:
"""市场分析服务测试类"""
@pytest.mark.unit
def test_calculate_price_statistics_empty(self):
"""测试空价格列表的统计"""
service = MarketService(None) # 不需要 db
stats = service.calculate_price_statistics([])
assert stats.min_price == 0
assert stats.max_price == 0
assert stats.avg_price == 0
assert stats.median_price == 0
assert stats.std_deviation is None or stats.std_deviation == 0
@pytest.mark.unit
def test_calculate_price_statistics_single(self):
"""测试单个价格的统计"""
service = MarketService(None)
stats = service.calculate_price_statistics([500.0])
assert stats.min_price == 500.0
assert stats.max_price == 500.0
assert stats.avg_price == 500.0
assert stats.median_price == 500.0
@pytest.mark.unit
def test_calculate_price_statistics_multiple(self):
"""测试多个价格的统计"""
service = MarketService(None)
prices = [300.0, 400.0, 500.0, 600.0, 700.0]
stats = service.calculate_price_statistics(prices)
assert stats.min_price == 300.0
assert stats.max_price == 700.0
assert stats.avg_price == 500.0 # (300+400+500+600+700)/5
assert stats.median_price == 500.0 # 中位数
assert stats.std_deviation is not None
assert stats.std_deviation > 0
@pytest.mark.unit
def test_calculate_price_distribution(self):
"""测试价格分布计算"""
service = MarketService(None)
# 价格范围 300-900分为三个区间
# 低: 300-500, 中: 500-700, 高: 700-900
prices = [350.0, 450.0, 550.0, 650.0, 750.0, 850.0]
distribution = service.calculate_price_distribution(
prices=prices,
min_price=300.0,
max_price=900.0
)
# 验证分布
assert distribution.low.count == 2 # 350, 450
assert distribution.medium.count == 2 # 550, 650
assert distribution.high.count == 2 # 750, 850
# 验证百分比
assert distribution.low.percentage == pytest.approx(33.3, rel=0.1)
assert distribution.medium.percentage == pytest.approx(33.3, rel=0.1)
assert distribution.high.percentage == pytest.approx(33.3, rel=0.1)
@pytest.mark.unit
def test_calculate_price_distribution_empty(self):
"""测试空价格列表的分布"""
service = MarketService(None)
distribution = service.calculate_price_distribution(
prices=[],
min_price=0,
max_price=0
)
assert distribution.low.count == 0
assert distribution.medium.count == 0
assert distribution.high.count == 0
@pytest.mark.unit
def test_calculate_suggested_range(self):
"""测试建议定价区间计算"""
service = MarketService(None)
suggested = service.calculate_suggested_range(
avg_price=500.0,
min_price=300.0,
max_price=700.0,
benchmark_avg=None
)
# 以均价为中心 ±20%
assert suggested.min == pytest.approx(400.0, rel=0.01) # 500 * 0.8
assert suggested.max == pytest.approx(600.0, rel=0.01) # 500 * 1.2
assert suggested.recommended == 500.0
@pytest.mark.unit
def test_calculate_suggested_range_with_benchmark(self):
"""测试带标杆参考的建议定价区间"""
service = MarketService(None)
suggested = service.calculate_suggested_range(
avg_price=500.0,
min_price=300.0,
max_price=700.0,
benchmark_avg=600.0
)
# 推荐价格 = 市场均价 * 0.6 + 标杆均价 * 0.4
expected_recommended = 500 * 0.6 + 600 * 0.4 # 540
assert suggested.recommended == pytest.approx(expected_recommended, rel=0.01)
@pytest.mark.unit
def test_calculate_suggested_range_zero_avg(self):
"""测试均价为0时的处理"""
service = MarketService(None)
suggested = service.calculate_suggested_range(
avg_price=0,
min_price=0,
max_price=0,
benchmark_avg=None
)
assert suggested.min == 0
assert suggested.max == 0
assert suggested.recommended == 0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_competitor_prices_for_project(
self,
db_session: AsyncSession,
sample_competitor_price: CompetitorPrice,
sample_project: Project
):
"""测试获取项目的竞品价格"""
service = MarketService(db_session)
prices = await service.get_competitor_prices_for_project(
sample_project.id
)
assert len(prices) == 1
assert float(prices[0].original_price) == 680.0
assert float(prices[0].promo_price) == 480.0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_competitor_prices_filter_by_competitor(
self,
db_session: AsyncSession,
sample_competitor_price: CompetitorPrice,
sample_project: Project,
sample_competitor: Competitor
):
"""测试按竞品机构筛选价格"""
service = MarketService(db_session)
# 指定竞品ID
prices = await service.get_competitor_prices_for_project(
sample_project.id,
competitor_ids=[sample_competitor.id]
)
assert len(prices) == 1
# 指定不存在的竞品ID
prices = await service.get_competitor_prices_for_project(
sample_project.id,
competitor_ids=[99999]
)
assert len(prices) == 0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_analyze_market(
self,
db_session: AsyncSession,
sample_project: Project,
sample_competitor_price: CompetitorPrice
):
"""测试市场分析"""
service = MarketService(db_session)
result = await service.analyze_market(
project_id=sample_project.id,
include_benchmark=False
)
assert result.project_id == sample_project.id
assert result.project_name == "光子嫩肤"
assert result.competitor_count == 1
assert result.price_statistics.min_price == 680.0
assert result.price_statistics.max_price == 680.0
assert result.price_statistics.avg_price == 680.0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_analyze_market_not_found(
self,
db_session: AsyncSession
):
"""测试项目不存在时的错误处理"""
service = MarketService(db_session)
with pytest.raises(ValueError, match="项目不存在"):
await service.analyze_market(project_id=99999)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_latest_analysis(
self,
db_session: AsyncSession,
sample_project: Project,
sample_competitor_price: CompetitorPrice
):
"""测试获取最新分析结果"""
service = MarketService(db_session)
# 先执行分析
await service.analyze_market(
project_id=sample_project.id,
include_benchmark=False
)
# 获取最新结果
latest = await service.get_latest_analysis(sample_project.id)
assert latest is not None
assert latest.project_id == sample_project.id
assert float(latest.market_avg_price) == 680.0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_benchmark_prices_empty(
self,
db_session: AsyncSession
):
"""测试没有标杆价格时的处理"""
service = MarketService(db_session)
benchmarks = await service.get_benchmark_prices_for_category(None)
assert benchmarks == []
benchmarks = await service.get_benchmark_prices_for_category(99999)
assert benchmarks == []
class TestMarketServiceEdgeCases:
"""市场分析服务边界情况测试"""
@pytest.mark.unit
def test_price_distribution_same_min_max(self):
"""测试最小最大价相同时的分布"""
service = MarketService(None)
distribution = service.calculate_price_distribution(
prices=[500.0, 500.0],
min_price=500.0,
max_price=500.0
)
# 应返回 N/A
assert distribution.low.range == "N/A"
@pytest.mark.unit
def test_statistics_with_outliers(self):
"""测试包含极端值的统计"""
service = MarketService(None)
# 包含一个极端高价
prices = [300.0, 400.0, 500.0, 600.0, 5000.0]
stats = service.calculate_price_statistics(prices)
assert stats.min_price == 300.0
assert stats.max_price == 5000.0
# 均值会被拉高
assert stats.avg_price == 1360.0 # (300+400+500+600+5000)/5
# 中位数不受极端值影响
assert stats.median_price == 500.0
# 标准差会很大
assert stats.std_deviation > 1000

View File

@@ -0,0 +1,369 @@
"""智能定价服务单元测试
测试 PricingService 的核心业务逻辑
"""
import pytest
from decimal import Decimal
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.pricing_service import PricingService
from app.schemas.pricing import (
StrategyType, MarketReference, StrategySuggestion, PricingSuggestions
)
from app.models import Project, ProjectCostSummary, PricingPlan
class TestPricingService:
"""智能定价服务测试类"""
@pytest.mark.unit
def test_calculate_strategy_price_traffic(self):
"""测试引流款定价策略"""
service = PricingService(None)
suggestion = service.calculate_strategy_price(
base_cost=100.0,
strategy=StrategyType.TRAFFIC,
)
# 引流款利润率 10%-20%,使用中间值 15%
# 价格 = 100 / (1 - 0.15) ≈ 117.65
assert suggestion.strategy == "引流款"
assert suggestion.suggested_price > 100 # 大于成本
assert suggestion.suggested_price < 130 # 利润率适中
assert suggestion.margin > 0
assert "引流" in suggestion.description
@pytest.mark.unit
def test_calculate_strategy_price_profit(self):
"""测试利润款定价策略"""
service = PricingService(None)
suggestion = service.calculate_strategy_price(
base_cost=100.0,
strategy=StrategyType.PROFIT,
target_margin=50, # 50% 目标毛利率
)
# 价格 = 100 / (1 - 0.5) = 200
assert suggestion.strategy == "利润款"
assert suggestion.suggested_price >= 200
assert suggestion.margin >= 45 # 接近目标
assert "日常" in suggestion.description
@pytest.mark.unit
def test_calculate_strategy_price_premium(self):
"""测试高端款定价策略"""
service = PricingService(None)
suggestion = service.calculate_strategy_price(
base_cost=100.0,
strategy=StrategyType.PREMIUM,
)
# 高端款利润率 60%-80%,使用中间值 70%
# 价格 = 100 / (1 - 0.7) ≈ 333
assert suggestion.strategy == "高端款"
assert suggestion.suggested_price > 300
assert suggestion.margin > 60
assert "高端" in suggestion.description
@pytest.mark.unit
def test_calculate_strategy_price_with_market_reference(self):
"""测试带市场参考的定价"""
service = PricingService(None)
market_ref = MarketReference(min=80.0, max=150.0, avg=100.0)
# 引流款应该参考市场最低价
suggestion = service.calculate_strategy_price(
base_cost=50.0,
strategy=StrategyType.TRAFFIC,
market_ref=market_ref,
)
# 应该取市场最低价的 90% 和成本定价的较低者
assert suggestion.suggested_price <= 100 # 不会太高
assert suggestion.suggested_price >= 50 * 1.05 # 不低于成本
@pytest.mark.unit
def test_calculate_strategy_price_ensures_profit(self):
"""测试确保价格不低于成本"""
service = PricingService(None)
market_ref = MarketReference(min=30.0, max=50.0, avg=40.0)
# 即使市场价很低,也不能低于成本
suggestion = service.calculate_strategy_price(
base_cost=100.0, # 成本高于市场价
strategy=StrategyType.TRAFFIC,
market_ref=market_ref,
)
# 价格至少是成本的 1.05 倍
assert suggestion.suggested_price >= 100 * 1.05
@pytest.mark.unit
def test_calculate_all_strategies(self):
"""测试计算所有策略"""
service = PricingService(None)
suggestions = service.calculate_all_strategies(
base_cost=100.0,
target_margin=50.0,
)
assert suggestions.traffic is not None
assert suggestions.profit is not None
assert suggestions.premium is not None
# 价格应该递增:引流款 < 利润款 < 高端款
assert suggestions.traffic.suggested_price < suggestions.profit.suggested_price
assert suggestions.profit.suggested_price < suggestions.premium.suggested_price
@pytest.mark.unit
def test_calculate_all_strategies_selected(self):
"""测试只计算选定的策略"""
service = PricingService(None)
suggestions = service.calculate_all_strategies(
base_cost=100.0,
target_margin=50.0,
strategies=[StrategyType.TRAFFIC, StrategyType.PROFIT],
)
assert suggestions.traffic is not None
assert suggestions.profit is not None
assert suggestions.premium is None
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_project_with_cost(
self,
db_session: AsyncSession,
sample_project_with_costs: Project
):
"""测试获取项目及成本"""
service = PricingService(db_session)
project, cost_summary = await service.get_project_with_cost(
sample_project_with_costs.id
)
assert project.id == sample_project_with_costs.id
assert project.project_name == "光子嫩肤"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_project_with_cost_not_found(
self,
db_session: AsyncSession
):
"""测试项目不存在时的错误处理"""
service = PricingService(db_session)
with pytest.raises(ValueError, match="项目不存在"):
await service.get_project_with_cost(99999)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_create_pricing_plan(
self,
db_session: AsyncSession,
sample_project: Project
):
"""测试创建定价方案"""
# 先添加成本汇总
cost_summary = ProjectCostSummary(
project_id=sample_project.id,
material_cost=Decimal("40.00"),
equipment_cost=Decimal("50.00"),
labor_cost=Decimal("60.00"),
fixed_cost_allocation=Decimal("30.00"),
total_cost=Decimal("180.00"),
calculated_at=datetime.now(),
)
db_session.add(cost_summary)
await db_session.commit()
service = PricingService(db_session)
plan = await service.create_pricing_plan(
project_id=sample_project.id,
plan_name="测试定价方案",
strategy_type=StrategyType.PROFIT,
target_margin=50.0,
)
assert plan.project_id == sample_project.id
assert plan.plan_name == "测试定价方案"
assert plan.strategy_type == "profit"
assert float(plan.target_margin) == 50.0
assert float(plan.base_cost) == 180.0
assert plan.suggested_price > plan.base_cost
@pytest.mark.unit
@pytest.mark.asyncio
async def test_update_pricing_plan(
self,
db_session: AsyncSession,
sample_pricing_plan: PricingPlan
):
"""测试更新定价方案"""
service = PricingService(db_session)
updated = await service.update_pricing_plan(
plan_id=sample_pricing_plan.id,
final_price=599.00,
plan_name="更新后方案名",
)
assert float(updated.final_price) == 599.00
assert updated.plan_name == "更新后方案名"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_update_pricing_plan_not_found(
self,
db_session: AsyncSession
):
"""测试更新不存在的方案"""
service = PricingService(db_session)
with pytest.raises(ValueError, match="定价方案不存在"):
await service.update_pricing_plan(
plan_id=99999,
final_price=599.00,
)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_simulate_strategies(
self,
db_session: AsyncSession,
sample_project: Project
):
"""测试策略模拟"""
# 添加成本汇总
cost_summary = ProjectCostSummary(
project_id=sample_project.id,
total_cost=Decimal("200.00"),
material_cost=Decimal("100.00"),
equipment_cost=Decimal("50.00"),
labor_cost=Decimal("50.00"),
fixed_cost_allocation=Decimal("0.00"),
calculated_at=datetime.now(),
)
db_session.add(cost_summary)
await db_session.commit()
service = PricingService(db_session)
response = await service.simulate_strategies(
project_id=sample_project.id,
strategies=[StrategyType.TRAFFIC, StrategyType.PROFIT, StrategyType.PREMIUM],
target_margin=50.0,
)
assert response.project_id == sample_project.id
assert response.base_cost == 200.0
assert len(response.results) == 3
# 验证结果排序
prices = [r.suggested_price for r in response.results]
assert prices == sorted(prices) # 应该是升序
@pytest.mark.unit
def test_format_cost_data(self):
"""测试成本数据格式化"""
service = PricingService(None)
# 测试空数据
result = service._format_cost_data(None)
assert "暂无成本数据" in result
@pytest.mark.unit
def test_format_market_data(self):
"""测试市场数据格式化"""
service = PricingService(None)
# 测试空数据
result = service._format_market_data(None)
assert "暂无市场行情数据" in result
# 测试有数据
market_ref = MarketReference(min=100.0, max=500.0, avg=300.0)
result = service._format_market_data(market_ref)
assert "100.00" in result
assert "500.00" in result
assert "300.00" in result
@pytest.mark.unit
def test_extract_recommendations(self):
"""测试提取 AI 建议列表"""
service = PricingService(None)
content = """
根据分析,建议如下:
- 建议一:常规定价 580 元
- 建议二:新客首单 388 元
* 建议三VIP 会员 520 元
1. 定期促销活动
2. 会员体系建设
"""
recommendations = service._extract_recommendations(content)
assert len(recommendations) == 5
assert "常规定价" in recommendations[0]
class TestPricingServiceWithAI:
"""需要 AI 服务的定价测试"""
@pytest.mark.unit
@pytest.mark.asyncio
@patch('app.services.pricing_service.AIServiceWrapper')
async def test_generate_pricing_advice_ai_failure(
self,
mock_ai_wrapper,
db_session: AsyncSession,
sample_project: Project
):
"""测试 AI 调用失败时的降级处理"""
# 添加成本汇总
cost_summary = ProjectCostSummary(
project_id=sample_project.id,
total_cost=Decimal("200.00"),
material_cost=Decimal("100.00"),
equipment_cost=Decimal("50.00"),
labor_cost=Decimal("50.00"),
fixed_cost_allocation=Decimal("0.00"),
calculated_at=datetime.now(),
)
db_session.add(cost_summary)
await db_session.commit()
# 模拟 AI 调用失败
mock_instance = MagicMock()
mock_instance.chat = AsyncMock(side_effect=Exception("AI 服务不可用"))
mock_ai_wrapper.return_value = mock_instance
service = PricingService(db_session)
# 即使 AI 失败,基本定价计算应该仍然返回
response = await service.generate_pricing_advice(
project_id=sample_project.id,
target_margin=50.0,
)
# 验证基本定价仍然可用
assert response.project_id == sample_project.id
assert response.cost_base == 200.0
assert response.pricing_suggestions is not None
# AI 建议可能为空
assert response.ai_advice is None or response.ai_usage is None

View File

@@ -0,0 +1,211 @@
"""利润模拟服务单元测试
测试 ProfitService 的核心业务逻辑
"""
import pytest
from decimal import Decimal
from unittest.mock import AsyncMock, patch, MagicMock
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.profit_service import ProfitService
from app.schemas.profit import PeriodType
from app.models import PricingPlan, FixedCost
class TestProfitService:
"""利润模拟服务测试类"""
@pytest.mark.unit
def test_calculate_profit_basic(self):
"""测试基础利润计算"""
service = ProfitService(None)
revenue, cost, profit, margin = service.calculate_profit(
price=100.0,
cost_per_unit=60.0,
volume=100
)
assert revenue == 10000.0
assert cost == 6000.0
assert profit == 4000.0
assert margin == 40.0
@pytest.mark.unit
def test_calculate_profit_zero_revenue(self):
"""测试零收入时的处理"""
service = ProfitService(None)
revenue, cost, profit, margin = service.calculate_profit(
price=100.0,
cost_per_unit=60.0,
volume=0
)
assert revenue == 0
assert cost == 0
assert profit == 0
assert margin == 0
@pytest.mark.unit
def test_calculate_profit_negative(self):
"""测试亏损情况"""
service = ProfitService(None)
revenue, cost, profit, margin = service.calculate_profit(
price=50.0,
cost_per_unit=60.0,
volume=100
)
assert revenue == 5000.0
assert cost == 6000.0
assert profit == -1000.0
assert margin == -20.0
@pytest.mark.unit
def test_calculate_breakeven_basic(self):
"""测试基础盈亏平衡计算"""
service = ProfitService(None)
breakeven = service.calculate_breakeven(
price=100.0,
variable_cost=60.0,
fixed_cost=0
)
assert breakeven == 1
@pytest.mark.unit
def test_calculate_breakeven_with_fixed_cost(self):
"""测试有固定成本的盈亏平衡"""
service = ProfitService(None)
breakeven = service.calculate_breakeven(
price=100.0,
variable_cost=60.0,
fixed_cost=4000.0
)
assert breakeven == 101
@pytest.mark.unit
def test_calculate_breakeven_no_margin(self):
"""测试边际贡献为负时的处理"""
service = ProfitService(None)
breakeven = service.calculate_breakeven(
price=50.0,
variable_cost=60.0,
fixed_cost=1000.0
)
assert breakeven == 999999
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_pricing_plan(
self,
db_session: AsyncSession,
sample_pricing_plan: PricingPlan
):
"""测试获取定价方案"""
service = ProfitService(db_session)
plan = await service.get_pricing_plan(sample_pricing_plan.id)
assert plan.id == sample_pricing_plan.id
assert plan.plan_name == "2026年Q1定价"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_pricing_plan_not_found(
self,
db_session: AsyncSession
):
"""测试获取不存在的方案"""
service = ProfitService(db_session)
with pytest.raises(ValueError, match="定价方案不存在"):
await service.get_pricing_plan(99999)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_monthly_fixed_cost(
self,
db_session: AsyncSession,
sample_fixed_cost: FixedCost
):
"""测试获取月度固定成本"""
service = ProfitService(db_session)
total = await service.get_monthly_fixed_cost()
assert total == Decimal("30000.00")
@pytest.mark.unit
@pytest.mark.asyncio
async def test_simulate_profit(
self,
db_session: AsyncSession,
sample_pricing_plan: PricingPlan
):
"""测试利润模拟"""
service = ProfitService(db_session)
response = await service.simulate_profit(
pricing_plan_id=sample_pricing_plan.id,
price=580.0,
estimated_volume=100,
period_type=PeriodType.MONTHLY,
)
assert response.pricing_plan_id == sample_pricing_plan.id
assert response.input.price == 580.0
assert response.input.estimated_volume == 100
@pytest.mark.unit
@pytest.mark.asyncio
async def test_sensitivity_analysis(
self,
db_session: AsyncSession,
sample_pricing_plan: PricingPlan
):
"""测试敏感性分析"""
service = ProfitService(db_session)
sim_response = await service.simulate_profit(
pricing_plan_id=sample_pricing_plan.id,
price=580.0,
estimated_volume=100,
period_type=PeriodType.MONTHLY,
)
response = await service.sensitivity_analysis(
simulation_id=sim_response.simulation_id,
price_change_rates=[-20, -10, 0, 10, 20]
)
assert response.simulation_id == sim_response.simulation_id
assert len(response.sensitivity_results) == 5
@pytest.mark.unit
@pytest.mark.asyncio
async def test_breakeven_analysis(
self,
db_session: AsyncSession,
sample_pricing_plan: PricingPlan,
sample_fixed_cost: FixedCost
):
"""测试盈亏平衡分析"""
service = ProfitService(db_session)
response = await service.breakeven_analysis(
pricing_plan_id=sample_pricing_plan.id
)
assert response.pricing_plan_id == sample_pricing_plan.id
assert response.price > 0
assert response.breakeven_volume > 0