"""市场分析服务单元测试

测试 MarketService 的核心业务逻辑
"""
|
||
|
||
import pytest
|
||
from decimal import Decimal
|
||
from datetime import date
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.services.market_service import MarketService
|
||
from app.models import (
|
||
Project, Competitor, CompetitorPrice, BenchmarkPrice, Category
|
||
)
|
||
|
||
|
||
class TestMarketService:
    """Unit tests for the core business logic of MarketService."""

    @staticmethod
    def _offline_service():
        """Build a MarketService without a DB session, for pure-computation tests."""
        return MarketService(None)

    @pytest.mark.unit
    def test_calculate_price_statistics_empty(self):
        """An empty price list yields all-zero statistics."""
        stats = self._offline_service().calculate_price_statistics([])

        for value in (
            stats.min_price,
            stats.max_price,
            stats.avg_price,
            stats.median_price,
        ):
            assert value == 0
        # Either "no deviation" representation (None or 0) is accepted.
        spread = stats.std_deviation
        assert spread is None or spread == 0

    @pytest.mark.unit
    def test_calculate_price_statistics_single(self):
        """A single price is simultaneously the min, max, mean and median."""
        only_price = 500.0
        stats = self._offline_service().calculate_price_statistics([only_price])

        assert stats.min_price == only_price
        assert stats.max_price == only_price
        assert stats.avg_price == only_price
        assert stats.median_price == only_price

    @pytest.mark.unit
    def test_calculate_price_statistics_multiple(self):
        """Statistics over an evenly spaced series of five prices."""
        stats = self._offline_service().calculate_price_statistics(
            [300.0, 400.0, 500.0, 600.0, 700.0]
        )

        assert stats.min_price == 300.0
        assert stats.max_price == 700.0
        # Mean: (300 + 400 + 500 + 600 + 700) / 5; median is the middle value.
        assert stats.avg_price == 500.0
        assert stats.median_price == 500.0
        # Distinct prices must produce a strictly positive spread.
        spread = stats.std_deviation
        assert spread is not None
        assert spread > 0

    @pytest.mark.unit
    def test_calculate_price_distribution(self):
        """Six prices spread evenly across three equal-width bands over 300-900."""
        # Bands: low 300-500, medium 500-700, high 700-900.
        distribution = self._offline_service().calculate_price_distribution(
            prices=[350.0, 450.0, 550.0, 650.0, 750.0, 850.0],
            min_price=300.0,
            max_price=900.0,
        )
        bands = (distribution.low, distribution.medium, distribution.high)

        # low: 350, 450 / medium: 550, 650 / high: 750, 850
        for band in bands:
            assert band.count == 2

        # Each band holds roughly one third of the prices.
        for band in bands:
            assert band.percentage == pytest.approx(33.3, rel=0.1)

    @pytest.mark.unit
    def test_calculate_price_distribution_empty(self):
        """No prices means every band is empty."""
        distribution = self._offline_service().calculate_price_distribution(
            prices=[],
            min_price=0,
            max_price=0,
        )

        for band in (distribution.low, distribution.medium, distribution.high):
            assert band.count == 0

    @pytest.mark.unit
    def test_calculate_suggested_range(self):
        """Without a benchmark the range is the market mean +/- 20%."""
        market_avg = 500.0
        suggested = self._offline_service().calculate_suggested_range(
            avg_price=market_avg,
            min_price=300.0,
            max_price=700.0,
            benchmark_avg=None,
        )

        assert suggested.min == pytest.approx(400.0, rel=0.01)  # 500 * 0.8
        assert suggested.max == pytest.approx(600.0, rel=0.01)  # 500 * 1.2
        assert suggested.recommended == market_avg

    @pytest.mark.unit
    def test_calculate_suggested_range_with_benchmark(self):
        """With a benchmark the recommendation blends market and benchmark means."""
        suggested = self._offline_service().calculate_suggested_range(
            avg_price=500.0,
            min_price=300.0,
            max_price=700.0,
            benchmark_avg=600.0,
        )

        # recommended = market mean * 0.6 + benchmark mean * 0.4  (= 540)
        market_weight = 0.6
        benchmark_weight = 0.4
        expected = 500 * market_weight + 600 * benchmark_weight
        assert suggested.recommended == pytest.approx(expected, rel=0.01)

    @pytest.mark.unit
    def test_calculate_suggested_range_zero_avg(self):
        """A zero market mean collapses the suggested range to zero."""
        suggested = self._offline_service().calculate_suggested_range(
            avg_price=0,
            min_price=0,
            max_price=0,
            benchmark_avg=None,
        )

        for bound in (suggested.min, suggested.max, suggested.recommended):
            assert bound == 0

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_get_competitor_prices_for_project(
        self,
        db_session: AsyncSession,
        sample_competitor_price: CompetitorPrice,
        sample_project: Project,
    ):
        """The project's single fixture competitor price is returned intact."""
        service = MarketService(db_session)

        prices = await service.get_competitor_prices_for_project(sample_project.id)

        assert len(prices) == 1
        record = prices[0]
        assert float(record.original_price) == 680.0
        assert float(record.promo_price) == 480.0

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_get_competitor_prices_filter_by_competitor(
        self,
        db_session: AsyncSession,
        sample_competitor_price: CompetitorPrice,
        sample_project: Project,
        sample_competitor: Competitor,
    ):
        """Filtering by competitor id keeps matches and drops unknown ids."""
        service = MarketService(db_session)

        async def count_for(competitor_ids):
            # Number of prices returned for the fixture project under the filter.
            rows = await service.get_competitor_prices_for_project(
                sample_project.id,
                competitor_ids=competitor_ids,
            )
            return len(rows)

        # The fixture competitor's id matches its price.
        assert await count_for([sample_competitor.id]) == 1
        # An id that does not exist matches nothing.
        assert await count_for([99999]) == 0

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_analyze_market(
        self,
        db_session: AsyncSession,
        sample_project: Project,
        sample_competitor_price: CompetitorPrice,
    ):
        """Market analysis over a single fixture price."""
        service = MarketService(db_session)

        result = await service.analyze_market(
            project_id=sample_project.id,
            include_benchmark=False,
        )

        assert result.project_id == sample_project.id
        assert result.project_name == "光子嫩肤"
        assert result.competitor_count == 1
        # With one price (680) the min, max and mean all coincide.
        price_stats = result.price_statistics
        assert price_stats.min_price == 680.0
        assert price_stats.max_price == 680.0
        assert price_stats.avg_price == 680.0

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_analyze_market_not_found(self, db_session: AsyncSession):
        """Analysing an unknown project raises ValueError."""
        missing_project_id = 99999
        service = MarketService(db_session)

        with pytest.raises(ValueError, match="项目不存在"):
            await service.analyze_market(project_id=missing_project_id)

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_get_latest_analysis(
        self,
        db_session: AsyncSession,
        sample_project: Project,
        sample_competitor_price: CompetitorPrice,
    ):
        """The most recent analysis is retrievable after running one."""
        service = MarketService(db_session)

        # Run an analysis first so there is a result to fetch.
        await service.analyze_market(
            project_id=sample_project.id,
            include_benchmark=False,
        )

        latest = await service.get_latest_analysis(sample_project.id)

        assert latest is not None
        assert latest.project_id == sample_project.id
        assert float(latest.market_avg_price) == 680.0

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_get_benchmark_prices_empty(self, db_session: AsyncSession):
        """A missing or unknown category yields an empty benchmark list."""
        service = MarketService(db_session)

        # Neither "no category" (None) nor an unknown id has benchmark prices.
        for category_id in (None, 99999):
            benchmarks = await service.get_benchmark_prices_for_category(category_id)
            assert benchmarks == []


class TestMarketServiceEdgeCases:
    """Edge-case tests for MarketService."""

    @pytest.mark.unit
    def test_price_distribution_same_min_max(self):
        """A zero-width price range (min == max) cannot be split into bands."""
        service = MarketService(None)

        distribution = service.calculate_price_distribution(
            prices=[500.0, 500.0],
            min_price=500.0,
            max_price=500.0,
        )

        # The band's range label is expected to be "N/A".
        assert distribution.low.range == "N/A"

    @pytest.mark.unit
    def test_statistics_with_outliers(self):
        """One extreme price skews the mean and spread but not the median."""
        service = MarketService(None)

        # Moderate prices plus a single extreme high price.
        prices = [300.0, 400.0, 500.0, 600.0, 5000.0]
        stats = service.calculate_price_statistics(prices)

        assert stats.min_price == 300.0
        assert stats.max_price == 5000.0
        # The mean is pulled up: (300 + 400 + 500 + 600 + 5000) / 5.
        assert stats.avg_price == 1360.0
        # The median is unaffected by the outlier.
        assert stats.median_price == 500.0
        # The spread is large. Check for None first (as the multi-price test
        # does) so a missing value fails as an assertion, not as a TypeError
        # raised by `None > 1000`.
        assert stats.std_deviation is not None
        assert stats.std_deviation > 1000