# Source listing header (from file viewer): 350 lines, 9.6 KiB, Python
"""pytest 测试配置
|
||
|
||
提供测试所需的 fixtures:
|
||
- 测试数据库会话
|
||
- 测试客户端
|
||
- 测试数据工厂
|
||
|
||
遵循瑞小美系统技术栈标准
|
||
"""
|
||
|
||
import os
|
||
from decimal import Decimal
|
||
from datetime import datetime
|
||
from typing import AsyncGenerator
|
||
|
||
import pytest
|
||
import pytest_asyncio
|
||
from httpx import AsyncClient, ASGITransport
|
||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
# Configure the process environment for the test run. This runs before the
# `app.*` imports further down; presumably those modules read these variables
# at import time (TODO confirm), which is why the order matters.
# Values are assigned unconditionally (not setdefault) so a real
# DATABASE_URL from a developer shell or CI can never leak into the tests.
os.environ.update(
    {
        "APP_ENV": "test",
        "DEBUG": "true",
        "DATABASE_URL": "sqlite+aiosqlite:///:memory:",
    }
)
|
||
|
||
from app.main import app
|
||
from app.database import Base, get_db
|
||
from app.models import (
|
||
Category, Material, Equipment, StaffLevel, FixedCost,
|
||
Project, ProjectCostItem, ProjectLaborCost, ProjectCostSummary,
|
||
Competitor, CompetitorPrice, BenchmarkPrice, MarketAnalysisResult,
|
||
PricingPlan, ProfitSimulation, SensitivityAnalysis
|
||
)
|
||
|
||
|
||
# Test database: an in-memory SQLite database accessed through aiosqlite.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"

# StaticPool hands out the same single DBAPI connection on every checkout.
# With a ":memory:" SQLite URL each new connection would otherwise open its
# own empty database, so tables created via one connection would be invisible
# through another. check_same_thread=False allows that one shared connection
# to be used from a thread other than the one that created it.
test_engine = create_async_engine(
    TEST_DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
    echo=False,
)

# Session factory used by the db_session fixture. expire_on_commit=False keeps
# loaded attributes readable after commit(), so objects returned by the
# factory fixtures can be inspected without another database round-trip.
test_session_maker = async_sessionmaker(
    bind=test_engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
|
||
|
||
|
||
@pytest_asyncio.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session on a freshly created schema.

    Every table in ``Base.metadata`` is created before the test and dropped
    afterwards, so each test starts from an empty database. The factory
    fixtures in this file call ``commit()``, so the final ``rollback()`` only
    discards work the test itself left uncommitted. Isolation between tests
    comes from dropping the tables.

    Yields:
        An ``AsyncSession`` created by ``test_session_maker``.
    """
    async with test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    try:
        async with test_session_maker() as session:
            try:
                yield session
            finally:
                # Discard any transaction the test left open. Exiting the
                # ``async with`` block closes the session.
                await session.rollback()
    finally:
        # Always drop the schema, even if rollback/close raised. The engine
        # is module-level and uses one shared in-memory connection, so
        # leftover rows would otherwise leak into the next test.
        async with test_engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
|
||
|
||
|
||
@pytest_asyncio.fixture(scope="function")
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an HTTP client bound to the app, using the test database session.

    The app's ``get_db`` dependency is overridden to yield ``db_session``.
    Requests made through this client therefore share the session that the
    test and the factory fixtures use, so data created in fixtures is visible
    to the API. Requests are dispatched in-process via ``ASGITransport``.

    Yields:
        An ``httpx.AsyncClient`` with base URL ``http://test``.
    """

    async def override_get_db():
        yield db_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        # Remove overrides even if opening or closing the client raised, so
        # later tests never inherit this test's session override.
        app.dependency_overrides.clear()
|
||
|
||
|
||
# ============ 测试数据工厂 ============
|
||
|
||
@pytest_asyncio.fixture
async def sample_category(db_session: AsyncSession) -> Category:
    """Persist and return an active "光电类" category with no parent."""
    new_category = Category(
        category_name="光电类",
        sort_order=1,
        parent_id=None,
        is_active=True,
    )
    db_session.add(new_category)
    await db_session.commit()
    # Reload the row so any database-generated column values are populated.
    await db_session.refresh(new_category)
    return new_category
|
||
|
||
|
||
@pytest_asyncio.fixture
async def sample_material(db_session: AsyncSession) -> Material:
    """Persist and return active consumable "MAT001" (冷凝胶) at 2.00 per ml."""
    gel = Material(
        material_code="MAT001",
        material_name="冷凝胶",
        material_type="consumable",
        unit="ml",
        unit_price=Decimal("2.00"),
        supplier="供应商A",
        is_active=True,
    )
    db_session.add(gel)
    await db_session.commit()
    # Reload the row so any database-generated column values are populated.
    await db_session.refresh(gel)
    return gel
|
||
|
||
|
||
@pytest_asyncio.fixture
async def sample_equipment(db_session: AsyncSession) -> Equipment:
    """Persist and return active equipment "EQP001" (光子仪) with depreciation data."""
    device = Equipment(
        equipment_code="EQP001",
        equipment_name="光子仪",
        purchase_date=datetime(2025, 1, 1).date(),
        original_value=Decimal("100000.00"),
        residual_rate=Decimal("5.00"),
        service_years=5,
        estimated_uses=2000,
        # (original value - 5% residual value) / estimated uses
        # = (100000 - 5000) / 2000 = 47.50
        depreciation_per_use=Decimal("47.50"),
        is_active=True,
    )
    db_session.add(device)
    await db_session.commit()
    # Reload the row so any database-generated column values are populated.
    await db_session.refresh(device)
    return device
|
||
|
||
|
||
@pytest_asyncio.fixture
async def sample_staff_level(db_session: AsyncSession) -> StaffLevel:
    """Persist and return active staff level "L2" (中级美容师) at 50.00 per hour."""
    level = StaffLevel(
        level_code="L2",
        level_name="中级美容师",
        is_active=True,
        hourly_rate=Decimal("50.00"),
    )
    db_session.add(level)
    await db_session.commit()
    # Reload the row so any database-generated column values are populated.
    await db_session.refresh(level)
    return level
|
||
|
||
|
||
@pytest_asyncio.fixture
async def sample_fixed_cost(db_session: AsyncSession) -> FixedCost:
    """Persist and return an active monthly rent cost of 30000.00 for the current month.

    ``year_month`` is the local current month formatted as ``YYYY-MM``.
    """
    current_month = f"{datetime.now():%Y-%m}"
    rent = FixedCost(
        cost_name="房租",
        cost_type="rent",
        year_month=current_month,
        monthly_amount=Decimal("30000.00"),
        allocation_method="count",
        is_active=True,
    )
    db_session.add(rent)
    await db_session.commit()
    # Reload the row so any database-generated column values are populated.
    await db_session.refresh(rent)
    return rent
|
||
|
||
|
||
@pytest_asyncio.fixture
async def sample_project(
    db_session: AsyncSession,
    sample_category: Category,
) -> Project:
    """Persist and return active 60-minute project "PRJ001" (光子嫩肤).

    The project is linked to ``sample_category``.
    """
    ipl_project = Project(
        project_code="PRJ001",
        project_name="光子嫩肤",
        description="IPL光子嫩肤项目",
        category_id=sample_category.id,
        duration_minutes=60,
        is_active=True,
    )
    db_session.add(ipl_project)
    await db_session.commit()
    # Reload the row so any database-generated column values are populated.
    await db_session.refresh(ipl_project)
    return ipl_project
|
||
|
||
|
||
@pytest_asyncio.fixture
async def sample_project_with_costs(
    db_session: AsyncSession,
    sample_project: Project,
    sample_material: Material,
    sample_equipment: Equipment,
    sample_staff_level: StaffLevel,
    sample_fixed_cost: FixedCost,
) -> Project:
    """Attach material, equipment and labor cost lines to ``sample_project``.

    Three rows are added:

    - material: 20 units of ``sample_material`` at 2.00 = 40.00
    - equipment: 1 use of ``sample_equipment`` at 47.50 = 47.50
    - labor: 60 minutes of ``sample_staff_level`` at 50.00/hour = 50.00

    ``sample_fixed_cost`` is not referenced in the body. Presumably it is
    requested so that a fixed-cost row exists for cost calculations (TODO
    confirm). After commit, the project is refreshed and returned.
    """
    project_id = sample_project.id
    db_session.add_all(
        [
            # Material: 20 x 2.00
            ProjectCostItem(
                project_id=project_id,
                item_type="material",
                item_id=sample_material.id,
                quantity=Decimal("20"),
                unit_cost=Decimal("2.00"),
                total_cost=Decimal("40.00"),
            ),
            # Equipment depreciation: one use at 47.50 per use
            ProjectCostItem(
                project_id=project_id,
                item_type="equipment",
                item_id=sample_equipment.id,
                quantity=Decimal("1"),
                unit_cost=Decimal("47.50"),
                total_cost=Decimal("47.50"),
            ),
            # Labor: 60 min / 60 * 50.00 per hour
            ProjectLaborCost(
                project_id=project_id,
                staff_level_id=sample_staff_level.id,
                duration_minutes=60,
                hourly_rate=Decimal("50.00"),
                labor_cost=Decimal("50.00"),
            ),
        ]
    )
    await db_session.commit()
    await db_session.refresh(sample_project)
    return sample_project
|
||
|
||
|
||
@pytest_asyncio.fixture
async def sample_competitor(db_session: AsyncSession) -> Competitor:
    """Persist and return an active key competitor located 2.5 km away."""
    rival = Competitor(
        competitor_name="美丽人生医美",
        positioning="medium",
        address="XX市XX路100号",
        distance_km=Decimal("2.5"),
        contact="13800138000",
        is_key_competitor=True,
        is_active=True,
    )
    db_session.add(rival)
    await db_session.commit()
    # Reload the row so any database-generated column values are populated.
    await db_session.refresh(rival)
    return rival
|
||
|
||
|
||
@pytest_asyncio.fixture
async def sample_competitor_price(
    db_session: AsyncSession,
    sample_competitor: Competitor,
    sample_project: Project,
) -> CompetitorPrice:
    """Persist and return a Meituan price record for the competitor's 光子嫩肤.

    Prices: original 680.00, promo 480.00, member 580.00. The collection date
    is today's local date.
    """
    collected_on = datetime.now().date()
    quote = CompetitorPrice(
        competitor_id=sample_competitor.id,
        project_id=sample_project.id,
        project_name="光子嫩肤",
        original_price=Decimal("680.00"),
        promo_price=Decimal("480.00"),
        member_price=Decimal("580.00"),
        price_source="meituan",
        collected_at=collected_on,
    )
    db_session.add(quote)
    await db_session.commit()
    # Reload the row so any database-generated column values are populated.
    await db_session.refresh(quote)
    return quote
|
||
|
||
|
||
@pytest_asyncio.fixture
async def sample_pricing_plan(
    db_session: AsyncSession,
    sample_project: Project,
) -> PricingPlan:
    """Persist and return an active profit-strategy pricing plan for ``sample_project``."""
    pricing_plan = PricingPlan(
        project_id=sample_project.id,
        plan_name="2026年Q1定价",
        strategy_type="profit",
        base_cost=Decimal("280.50"),
        target_margin=Decimal("50.00"),
        # 280.50 / (1 - 50%) = 561.00 -- assumes the margin is measured on
        # price; TODO confirm against the pricing service.
        suggested_price=Decimal("561.00"),
        final_price=Decimal("580.00"),
        is_active=True,
    )
    db_session.add(pricing_plan)
    await db_session.commit()
    # Reload the row so any database-generated column values are populated.
    await db_session.refresh(pricing_plan)
    return pricing_plan
|
||
|
||
|
||
@pytest_asyncio.fixture
async def sample_cost_summary(
    db_session: AsyncSession,
    sample_project: Project,
) -> ProjectCostSummary:
    """Persist and return a cost summary for ``sample_project`` totalling 230.00.

    Components: material 100.00 + equipment 50.00 + labor 50.00
    + fixed-cost allocation 30.00 = 230.00.
    """
    summary = ProjectCostSummary(
        project_id=sample_project.id,
        material_cost=Decimal("100.00"),
        equipment_cost=Decimal("50.00"),
        labor_cost=Decimal("50.00"),
        fixed_cost_allocation=Decimal("30.00"),
        total_cost=Decimal("230.00"),
        # Set the calculation timestamp explicitly.
        calculated_at=datetime.now(),
    )
    db_session.add(summary)
    await db_session.commit()
    # Reload the row so any database-generated column values are populated.
    await db_session.refresh(summary)
    return summary
|
||
|
||
|
||
# ============ 辅助函数 ============
|
||
|
||
def assert_response_success(response, expected_code=0):
    """Assert that an API response is a successful envelope and return its payload.

    Checks, in order:

    1. the HTTP status is 200;
    2. the JSON body's ``code`` equals ``expected_code``;
    3. the body contains a ``data`` key.

    Args:
        response: Response object exposing ``status_code``, ``text`` and
            ``json()`` (e.g. an ``httpx.Response``).
        expected_code: Business code expected in the body. Defaults to ``0``.

    Returns:
        The value stored under the body's ``"data"`` key.

    Raises:
        AssertionError: If any check fails. The message includes the response
            body so the cause of the failure is visible.
    """
    assert response.status_code == 200, (
        f"expected HTTP 200, got {response.status_code}: {response.text}"
    )
    data = response.json()
    assert data["code"] == expected_code, (
        f"expected code {expected_code!r}, got {data['code']!r}: {data}"
    )
    assert "data" in data, f"response body has no 'data' key: {data}"
    return data["data"]
|
||
|
||
|
||
def assert_response_error(response, expected_code):
    """Assert that an API response body carries the given error code.

    The HTTP status code is not checked; only the JSON body's ``code`` field
    is compared.

    Args:
        response: Response object exposing ``json()`` (e.g. an ``httpx.Response``).
        expected_code: Error code expected in the body's ``code`` field.

    Returns:
        The full decoded JSON body, so callers can inspect further fields.

    Raises:
        AssertionError: If the code does not match. The message includes the
            body so the actual error is visible.
    """
    data = response.json()
    assert data["code"] == expected_code, (
        f"expected error code {expected_code!r}, got {data['code']!r}: {data}"
    )
    return data
|