Files
smart-project-pricing/后端服务/tests/conftest.py
2026-01-31 21:33:06 +08:00

350 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""pytest 测试配置
提供测试所需的 fixtures
- 测试数据库会话
- 测试客户端
- 测试数据工厂
遵循瑞小美系统技术栈标准
"""
import os
from decimal import Decimal
from datetime import datetime
from typing import AsyncGenerator
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.pool import StaticPool
# Configure the test environment through process environment variables.
# NOTE(review): these assignments sit before the `app.*` imports, presumably so the
# application reads them at import time — confirm before reordering.
os.environ["APP_ENV"] = "test"
os.environ["DEBUG"] = "true"
os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
from app.main import app
from app.database import Base, get_db
from app.models import (
Category, Material, Equipment, StaffLevel, FixedCost,
Project, ProjectCostItem, ProjectLaborCost, ProjectCostSummary,
Competitor, CompetitorPrice, BenchmarkPrice, MarketAnalysisResult,
PricingPlan, ProfitSimulation, SensitivityAnalysis
)
# Test database engine (SQLite in-memory database through the aiosqlite driver).
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
test_engine = create_async_engine(
    TEST_DATABASE_URL,
    echo=False,
    # NOTE(review): StaticPool presumably keeps a single connection so every
    # session sees the same in-memory database — confirm against SQLAlchemy docs.
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)
# Session factory bound to the test engine; objects are not expired on commit,
# and autoflush is disabled.
test_session_maker = async_sessionmaker(
    test_engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
@pytest_asyncio.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session backed by a freshly created schema.

    Every table registered on ``Base.metadata`` is created before the test
    and dropped afterwards. Uncommitted work is rolled back on teardown;
    rows that were committed during the test are removed when the tables
    are dropped.

    Yields:
        AsyncSession: a session produced by ``test_session_maker``.
    """
    # Create the schema for this test.
    async with test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    try:
        # The context manager closes the session on exit, so no explicit
        # ``close()`` call is needed.
        async with test_session_maker() as session:
            try:
                yield session
            finally:
                # Discard anything the test left uncommitted.
                await session.rollback()
    finally:
        # Always drop the schema, even if session teardown raised: the
        # module-level engine is reused by later tests, so leftover tables
        # would otherwise carry data over into them.
        async with test_engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
@pytest_asyncio.fixture(scope="function")
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTP client wired to the application under test.

    The application's ``get_db`` dependency is overridden so that request
    handlers use the same ``db_session`` as the test itself.

    Args:
        db_session: the per-test database session fixture.

    Yields:
        AsyncClient: a client that sends requests to ``app`` through an
        in-process ASGI transport.
    """
    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        # Remove the overrides even if client teardown raises, so the
        # override does not stay registered on the shared ``app`` object.
        app.dependency_overrides.clear()
# ============ 测试数据工厂 ============
@pytest_asyncio.fixture
async def sample_category(db_session: AsyncSession) -> Category:
    """Persist and return a sample project category with no parent."""
    fields = {
        "category_name": "光电类",
        "parent_id": None,
        "sort_order": 1,
        "is_active": True,
    }
    record = Category(**fields)
    db_session.add(record)
    await db_session.commit()
    # Reload the row's state from the database after the commit.
    await db_session.refresh(record)
    return record
@pytest_asyncio.fixture
async def sample_material(db_session: AsyncSession) -> Material:
    """Persist and return a sample consumable material."""
    fields = {
        "material_code": "MAT001",
        "material_name": "冷凝胶",
        "unit": "ml",
        "unit_price": Decimal("2.00"),
        "supplier": "供应商A",
        "material_type": "consumable",
        "is_active": True,
    }
    record = Material(**fields)
    db_session.add(record)
    await db_session.commit()
    # Reload the row's state from the database after the commit.
    await db_session.refresh(record)
    return record
@pytest_asyncio.fixture
async def sample_equipment(db_session: AsyncSession) -> Equipment:
    """Persist and return a sample piece of equipment."""
    bought_on = datetime(2025, 1, 1).date()
    fields = {
        "equipment_code": "EQP001",
        "equipment_name": "光子仪",
        "original_value": Decimal("100000.00"),
        "residual_rate": Decimal("5.00"),
        "service_years": 5,
        "estimated_uses": 2000,
        # (100000 - 5000) / 2000
        "depreciation_per_use": Decimal("47.50"),
        "purchase_date": bought_on,
        "is_active": True,
    }
    record = Equipment(**fields)
    db_session.add(record)
    await db_session.commit()
    # Reload the row's state from the database after the commit.
    await db_session.refresh(record)
    return record
@pytest_asyncio.fixture
async def sample_staff_level(db_session: AsyncSession) -> StaffLevel:
    """Persist and return a sample staff level."""
    fields = {
        "level_code": "L2",
        "level_name": "中级美容师",
        "hourly_rate": Decimal("50.00"),
        "is_active": True,
    }
    record = StaffLevel(**fields)
    db_session.add(record)
    await db_session.commit()
    # Reload the row's state from the database after the commit.
    await db_session.refresh(record)
    return record
@pytest_asyncio.fixture
async def sample_fixed_cost(db_session: AsyncSession) -> FixedCost:
    """Persist and return a sample fixed cost dated to the current month."""
    # The period is the current local year and month, formatted as "YYYY-MM".
    current_period = datetime.now().strftime("%Y-%m")
    fields = {
        "cost_name": "房租",
        "cost_type": "rent",
        "monthly_amount": Decimal("30000.00"),
        "year_month": current_period,
        "allocation_method": "count",
        "is_active": True,
    }
    record = FixedCost(**fields)
    db_session.add(record)
    await db_session.commit()
    # Reload the row's state from the database after the commit.
    await db_session.refresh(record)
    return record
@pytest_asyncio.fixture
async def sample_project(
    db_session: AsyncSession,
    sample_category: Category,
) -> Project:
    """Persist and return a sample service project under ``sample_category``."""
    fields = {
        "project_code": "PRJ001",
        "project_name": "光子嫩肤",
        "category_id": sample_category.id,
        "description": "IPL光子嫩肤项目",
        "duration_minutes": 60,
        "is_active": True,
    }
    record = Project(**fields)
    db_session.add(record)
    await db_session.commit()
    # Reload the row's state from the database after the commit.
    await db_session.refresh(record)
    return record
@pytest_asyncio.fixture
async def sample_project_with_costs(
    db_session: AsyncSession,
    sample_project: Project,
    sample_material: Material,
    sample_equipment: Equipment,
    sample_staff_level: StaffLevel,
    sample_fixed_cost: FixedCost,
) -> Project:
    """Attach material, equipment and labor cost rows to ``sample_project``.

    ``sample_fixed_cost`` is requested but not referenced in the body.
    """
    project_id = sample_project.id

    # Material cost line.
    material_line = ProjectCostItem(
        project_id=project_id,
        item_type="material",
        item_id=sample_material.id,
        quantity=Decimal("20"),
        unit_cost=Decimal("2.00"),
        total_cost=Decimal("40.00"),
    )
    # Equipment depreciation cost line.
    equipment_line = ProjectCostItem(
        project_id=project_id,
        item_type="equipment",
        item_id=sample_equipment.id,
        quantity=Decimal("1"),
        unit_cost=Decimal("47.50"),
        total_cost=Decimal("47.50"),
    )
    # Labor cost line: 60 minutes / 60 * 50.
    labor_line = ProjectLaborCost(
        project_id=project_id,
        staff_level_id=sample_staff_level.id,
        duration_minutes=60,
        hourly_rate=Decimal("50.00"),
        labor_cost=Decimal("50.00"),
    )

    # Register the rows in the same order as before: material, equipment, labor.
    for row in (material_line, equipment_line, labor_line):
        db_session.add(row)

    await db_session.commit()
    await db_session.refresh(sample_project)
    return sample_project
@pytest_asyncio.fixture
async def sample_competitor(db_session: AsyncSession) -> Competitor:
    """Persist and return a sample competitor flagged as a key competitor."""
    fields = {
        "competitor_name": "美丽人生医美",
        "address": "XX市XX路100号",
        "distance_km": Decimal("2.5"),
        "positioning": "medium",
        "contact": "13800138000",
        "is_key_competitor": True,
        "is_active": True,
    }
    record = Competitor(**fields)
    db_session.add(record)
    await db_session.commit()
    # Reload the row's state from the database after the commit.
    await db_session.refresh(record)
    return record
@pytest_asyncio.fixture
async def sample_competitor_price(
    db_session: AsyncSession,
    sample_competitor: Competitor,
    sample_project: Project,
) -> CompetitorPrice:
    """Persist and return a sample competitor price collected today.

    The row links ``sample_competitor`` to ``sample_project``.
    """
    collected_on = datetime.now().date()
    fields = {
        "competitor_id": sample_competitor.id,
        "project_id": sample_project.id,
        "project_name": "光子嫩肤",
        "original_price": Decimal("680.00"),
        "promo_price": Decimal("480.00"),
        "member_price": Decimal("580.00"),
        "price_source": "meituan",
        "collected_at": collected_on,
    }
    record = CompetitorPrice(**fields)
    db_session.add(record)
    await db_session.commit()
    # Reload the row's state from the database after the commit.
    await db_session.refresh(record)
    return record
@pytest_asyncio.fixture
async def sample_pricing_plan(
    db_session: AsyncSession,
    sample_project: Project,
) -> PricingPlan:
    """Persist and return a sample pricing plan for ``sample_project``."""
    fields = {
        "project_id": sample_project.id,
        "plan_name": "2026年Q1定价",
        "strategy_type": "profit",
        "base_cost": Decimal("280.50"),
        "target_margin": Decimal("50.00"),
        "suggested_price": Decimal("561.00"),
        "final_price": Decimal("580.00"),
        "is_active": True,
    }
    record = PricingPlan(**fields)
    db_session.add(record)
    await db_session.commit()
    # Reload the row's state from the database after the commit.
    await db_session.refresh(record)
    return record
@pytest_asyncio.fixture
async def sample_cost_summary(
    db_session: AsyncSession,
    sample_project: Project,
) -> ProjectCostSummary:
    """Persist and return a sample cost summary for ``sample_project``."""
    # The calculation time is always set explicitly to the current local time.
    computed_at = datetime.now()
    fields = {
        "project_id": sample_project.id,
        "material_cost": Decimal("100.00"),
        "equipment_cost": Decimal("50.00"),
        "labor_cost": Decimal("50.00"),
        "fixed_cost_allocation": Decimal("30.00"),
        "total_cost": Decimal("230.00"),
        "calculated_at": computed_at,
    }
    record = ProjectCostSummary(**fields)
    db_session.add(record)
    await db_session.commit()
    # Reload the row's state from the database after the commit.
    await db_session.refresh(record)
    return record
# ============ 辅助函数 ============
def assert_response_success(response, expected_code=0):
    """Assert that a response is successful and return its ``data`` field.

    Checks, in order: the HTTP status is 200, the JSON body's ``code``
    equals ``expected_code``, and the body has a ``data`` key.

    Args:
        response: object exposing ``status_code`` and ``json()``.
        expected_code: expected value of the body's ``code`` field.

    Returns:
        The value stored under ``data`` in the JSON body.
    """
    status = response.status_code
    assert status == 200

    # The body is only parsed once the status check has passed.
    payload = response.json()
    actual_code = payload["code"]
    assert actual_code == expected_code
    assert "data" in payload
    return payload["data"]
def assert_response_error(response, expected_code):
    """Assert that a response carries the expected error code.

    Only the JSON body's ``code`` field is checked; the HTTP status is not
    inspected.

    Args:
        response: object exposing ``json()``.
        expected_code: expected value of the body's ``code`` field.

    Returns:
        The full parsed JSON body.
    """
    payload = response.json()
    actual_code = payload["code"]
    assert actual_code == expected_code
    return payload