Files
smart-project-pricing/后端服务/tests/test_services/test_market_service.py
2026-01-31 21:33:06 +08:00

306 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""市场分析服务单元测试
测试 MarketService 的核心业务逻辑
"""
import pytest
from decimal import Decimal
from datetime import date
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.market_service import MarketService
from app.models import (
Project, Competitor, CompetitorPrice, BenchmarkPrice, Category
)
class TestMarketService:
    """Unit tests for the core business logic of MarketService."""

    @pytest.mark.unit
    def test_calculate_price_statistics_empty(self):
        """An empty price list yields all-zero statistics."""
        service = MarketService(None)  # pure computation; no DB session needed
        stats = service.calculate_price_statistics([])
        assert stats.min_price == 0
        assert stats.max_price == 0
        assert stats.avg_price == 0
        assert stats.median_price == 0
        # The deviation of an empty sample is accepted as either None or 0.
        assert stats.std_deviation is None or stats.std_deviation == 0

    @pytest.mark.unit
    def test_calculate_price_statistics_single(self):
        """A single price is simultaneously the min, max, mean and median."""
        service = MarketService(None)
        stats = service.calculate_price_statistics([500.0])
        assert stats.min_price == 500.0
        assert stats.max_price == 500.0
        assert stats.avg_price == 500.0
        assert stats.median_price == 500.0

    @pytest.mark.unit
    def test_calculate_price_statistics_multiple(self):
        """Statistics over several distinct prices."""
        service = MarketService(None)
        prices = [300.0, 400.0, 500.0, 600.0, 700.0]
        stats = service.calculate_price_statistics(prices)
        assert stats.min_price == 300.0
        assert stats.max_price == 700.0
        assert stats.avg_price == 500.0  # (300+400+500+600+700)/5
        assert stats.median_price == 500.0  # middle element of the sorted list
        assert stats.std_deviation is not None
        assert stats.std_deviation > 0

    @pytest.mark.unit
    def test_calculate_price_distribution(self):
        """Evenly spread prices fall evenly into the three price bands."""
        service = MarketService(None)
        # The 300-900 range is split into three bands:
        # low: 300-500, medium: 500-700, high: 700-900
        prices = [350.0, 450.0, 550.0, 650.0, 750.0, 850.0]
        distribution = service.calculate_price_distribution(
            prices=prices,
            min_price=300.0,
            max_price=900.0,
        )
        # Bucket counts
        assert distribution.low.count == 2  # 350, 450
        assert distribution.medium.count == 2  # 550, 650
        assert distribution.high.count == 2  # 750, 850
        # Each band holds 2 of the 6 prices, i.e. exactly 100/3 percent.
        # An absolute tolerance of 0.1 accepts both the raw value (33.33...)
        # and one rounded to one decimal (33.3). A 10% relative tolerance
        # would accept anything from roughly 30 to 36.6.
        expected_share = 100 / 3
        assert distribution.low.percentage == pytest.approx(expected_share, abs=0.1)
        assert distribution.medium.percentage == pytest.approx(expected_share, abs=0.1)
        assert distribution.high.percentage == pytest.approx(expected_share, abs=0.1)

    @pytest.mark.unit
    def test_calculate_price_distribution_empty(self):
        """An empty price list produces empty bands."""
        service = MarketService(None)
        distribution = service.calculate_price_distribution(
            prices=[],
            min_price=0,
            max_price=0,
        )
        assert distribution.low.count == 0
        assert distribution.medium.count == 0
        assert distribution.high.count == 0

    @pytest.mark.unit
    def test_calculate_suggested_range(self):
        """Suggested range without a benchmark reference."""
        service = MarketService(None)
        suggested = service.calculate_suggested_range(
            avg_price=500.0,
            min_price=300.0,
            max_price=700.0,
            benchmark_avg=None,
        )
        # Centered on the market average, +/-20%
        assert suggested.min == pytest.approx(400.0, rel=0.01)  # 500 * 0.8
        assert suggested.max == pytest.approx(600.0, rel=0.01)  # 500 * 1.2
        assert suggested.recommended == 500.0

    @pytest.mark.unit
    def test_calculate_suggested_range_with_benchmark(self):
        """Suggested range blended with a benchmark average."""
        service = MarketService(None)
        suggested = service.calculate_suggested_range(
            avg_price=500.0,
            min_price=300.0,
            max_price=700.0,
            benchmark_avg=600.0,
        )
        # Recommended price = market average * 0.6 + benchmark average * 0.4
        expected_recommended = 500 * 0.6 + 600 * 0.4  # 540
        assert suggested.recommended == pytest.approx(expected_recommended, rel=0.01)

    @pytest.mark.unit
    def test_calculate_suggested_range_zero_avg(self):
        """A zero market average yields an all-zero suggested range."""
        service = MarketService(None)
        suggested = service.calculate_suggested_range(
            avg_price=0,
            min_price=0,
            max_price=0,
            benchmark_avg=None,
        )
        assert suggested.min == 0
        assert suggested.max == 0
        assert suggested.recommended == 0

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_get_competitor_prices_for_project(
        self,
        db_session: AsyncSession,
        sample_competitor_price: CompetitorPrice,
        sample_project: Project,
    ):
        """Competitor prices recorded for a project are returned."""
        service = MarketService(db_session)
        prices = await service.get_competitor_prices_for_project(
            sample_project.id
        )
        assert len(prices) == 1
        assert float(prices[0].original_price) == 680.0
        assert float(prices[0].promo_price) == 480.0

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_get_competitor_prices_filter_by_competitor(
        self,
        db_session: AsyncSession,
        sample_competitor_price: CompetitorPrice,
        sample_project: Project,
        sample_competitor: Competitor,
    ):
        """Competitor prices can be filtered by competitor id."""
        service = MarketService(db_session)
        # Filter on the existing competitor id
        prices = await service.get_competitor_prices_for_project(
            sample_project.id,
            competitor_ids=[sample_competitor.id],
        )
        assert len(prices) == 1
        # Filter on a competitor id that does not exist
        prices = await service.get_competitor_prices_for_project(
            sample_project.id,
            competitor_ids=[99999],
        )
        assert len(prices) == 0

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_analyze_market(
        self,
        db_session: AsyncSession,
        sample_project: Project,
        sample_competitor_price: CompetitorPrice,
    ):
        """Market analysis over a project with a single competitor price."""
        service = MarketService(db_session)
        result = await service.analyze_market(
            project_id=sample_project.id,
            include_benchmark=False,
        )
        assert result.project_id == sample_project.id
        assert result.project_name == "光子嫩肤"
        assert result.competitor_count == 1
        # With one sample price, min == max == avg == 680.0 (the fixture's
        # original price; its promo price is 480.0).
        assert result.price_statistics.min_price == 680.0
        assert result.price_statistics.max_price == 680.0
        assert result.price_statistics.avg_price == 680.0

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_analyze_market_not_found(
        self,
        db_session: AsyncSession,
    ):
        """Analyzing a nonexistent project raises ValueError."""
        service = MarketService(db_session)
        with pytest.raises(ValueError, match="项目不存在"):
            await service.analyze_market(project_id=99999)

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_get_latest_analysis(
        self,
        db_session: AsyncSession,
        sample_project: Project,
        sample_competitor_price: CompetitorPrice,
    ):
        """The most recent analysis can be read back after running one."""
        service = MarketService(db_session)
        # Run an analysis first so there is something to fetch
        await service.analyze_market(
            project_id=sample_project.id,
            include_benchmark=False,
        )
        # Fetch the latest result
        latest = await service.get_latest_analysis(sample_project.id)
        assert latest is not None
        assert latest.project_id == sample_project.id
        assert float(latest.market_avg_price) == 680.0

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_get_benchmark_prices_empty(
        self,
        db_session: AsyncSession,
    ):
        """Lookups with no matching benchmark prices return an empty list."""
        service = MarketService(db_session)
        # No category at all
        benchmarks = await service.get_benchmark_prices_for_category(None)
        assert benchmarks == []
        # A category id that does not exist
        benchmarks = await service.get_benchmark_prices_for_category(99999)
        assert benchmarks == []
class TestMarketServiceEdgeCases:
    """Boundary-condition tests for MarketService."""

    @pytest.mark.unit
    def test_price_distribution_same_min_max(self):
        """When min and max price coincide, the bands cannot be split."""
        service = MarketService(None)
        distribution = service.calculate_price_distribution(
            prices=[500.0, 500.0],
            min_price=500.0,
            max_price=500.0,
        )
        # A zero-width range is expected to be reported as "N/A"
        assert distribution.low.range == "N/A"

    @pytest.mark.unit
    def test_statistics_with_outliers(self):
        """Statistics over a sample containing one extreme value."""
        service = MarketService(None)
        # One extreme high price
        prices = [300.0, 400.0, 500.0, 600.0, 5000.0]
        stats = service.calculate_price_statistics(prices)
        assert stats.min_price == 300.0
        assert stats.max_price == 5000.0
        # The mean is pulled up by the outlier
        assert stats.avg_price == 1360.0  # (300+400+500+600+5000)/5
        # The median is unaffected by the outlier
        assert stats.median_price == 500.0
        # The spread is large: about 1822.7 (population) or 2037.9 (sample),
        # so > 1000 either way. Check for None first so that a missing value
        # fails as a clear assertion instead of a TypeError from `None > 1000`.
        assert stats.std_deviation is not None
        assert stats.std_deviation > 1000