"""pytest 测试配置 提供测试所需的 fixtures: - 测试数据库会话 - 测试客户端 - 测试数据工厂 遵循瑞小美系统技术栈标准 """ import os from decimal import Decimal from datetime import datetime from typing import AsyncGenerator import pytest import pytest_asyncio from httpx import AsyncClient, ASGITransport from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.pool import StaticPool # 设置测试环境变量 os.environ["APP_ENV"] = "test" os.environ["DEBUG"] = "true" os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:" from app.main import app from app.database import Base, get_db from app.models import ( Category, Material, Equipment, StaffLevel, FixedCost, Project, ProjectCostItem, ProjectLaborCost, ProjectCostSummary, Competitor, CompetitorPrice, BenchmarkPrice, MarketAnalysisResult, PricingPlan, ProfitSimulation, SensitivityAnalysis ) # 测试数据库引擎(使用 SQLite 内存数据库) TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" test_engine = create_async_engine( TEST_DATABASE_URL, echo=False, poolclass=StaticPool, connect_args={"check_same_thread": False}, ) test_session_maker = async_sessionmaker( test_engine, class_=AsyncSession, expire_on_commit=False, autocommit=False, autoflush=False, ) @pytest_asyncio.fixture(scope="function") async def db_session() -> AsyncGenerator[AsyncSession, None]: """获取测试数据库会话 每个测试函数使用独立的数据库会话,测试后自动回滚 """ # 创建表 async with test_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) async with test_session_maker() as session: try: yield session finally: await session.rollback() await session.close() # 清理表 async with test_engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) @pytest_asyncio.fixture(scope="function") async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]: """获取测试客户端 使用测试数据库会话替换应用的数据库依赖 """ async def override_get_db(): yield db_session app.dependency_overrides[get_db] = override_get_db transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as 
ac: yield ac app.dependency_overrides.clear() # ============ 测试数据工厂 ============ @pytest_asyncio.fixture async def sample_category(db_session: AsyncSession) -> Category: """创建示例项目分类""" category = Category( category_name="光电类", parent_id=None, sort_order=1, is_active=True, ) db_session.add(category) await db_session.commit() await db_session.refresh(category) return category @pytest_asyncio.fixture async def sample_material(db_session: AsyncSession) -> Material: """创建示例耗材""" material = Material( material_code="MAT001", material_name="冷凝胶", unit="ml", unit_price=Decimal("2.00"), supplier="供应商A", material_type="consumable", is_active=True, ) db_session.add(material) await db_session.commit() await db_session.refresh(material) return material @pytest_asyncio.fixture async def sample_equipment(db_session: AsyncSession) -> Equipment: """创建示例设备""" equipment = Equipment( equipment_code="EQP001", equipment_name="光子仪", original_value=Decimal("100000.00"), residual_rate=Decimal("5.00"), service_years=5, estimated_uses=2000, depreciation_per_use=Decimal("47.50"), # (100000 - 5000) / 2000 purchase_date=datetime(2025, 1, 1).date(), is_active=True, ) db_session.add(equipment) await db_session.commit() await db_session.refresh(equipment) return equipment @pytest_asyncio.fixture async def sample_staff_level(db_session: AsyncSession) -> StaffLevel: """创建示例人员级别""" staff_level = StaffLevel( level_code="L2", level_name="中级美容师", hourly_rate=Decimal("50.00"), is_active=True, ) db_session.add(staff_level) await db_session.commit() await db_session.refresh(staff_level) return staff_level @pytest_asyncio.fixture async def sample_fixed_cost(db_session: AsyncSession) -> FixedCost: """创建示例固定成本""" fixed_cost = FixedCost( cost_name="房租", cost_type="rent", monthly_amount=Decimal("30000.00"), year_month=datetime.now().strftime("%Y-%m"), allocation_method="count", is_active=True, ) db_session.add(fixed_cost) await db_session.commit() await db_session.refresh(fixed_cost) return fixed_cost 
@pytest_asyncio.fixture
async def sample_project(
    db_session: AsyncSession,
    sample_category: Category
) -> Project:
    """Create and persist a sample service project in ``sample_category``."""
    project = Project(
        project_code="PRJ001",
        project_name="光子嫩肤",
        category_id=sample_category.id,
        description="IPL光子嫩肤项目",
        duration_minutes=60,
        is_active=True,
    )
    db_session.add(project)
    await db_session.commit()
    await db_session.refresh(project)
    return project


@pytest_asyncio.fixture
async def sample_project_with_costs(
    db_session: AsyncSession,
    sample_project: Project,
    sample_material: Material,
    sample_equipment: Equipment,
    sample_staff_level: StaffLevel,
    sample_fixed_cost: FixedCost,
) -> Project:
    """Attach material, equipment and labor cost rows to ``sample_project``.

    ``sample_fixed_cost`` is requested only so that a fixed-cost row exists
    for the project's cost calculations; it is not referenced directly.
    """
    # Material cost: 20 ml x 2.00
    cost_item = ProjectCostItem(
        project_id=sample_project.id,
        item_type="material",
        item_id=sample_material.id,
        quantity=Decimal("20"),
        unit_cost=Decimal("2.00"),
        total_cost=Decimal("40.00"),
    )
    # Equipment depreciation: one use at the equipment's per-use rate
    equip_cost = ProjectCostItem(
        project_id=sample_project.id,
        item_type="equipment",
        item_id=sample_equipment.id,
        quantity=Decimal("1"),
        unit_cost=Decimal("47.50"),
        total_cost=Decimal("47.50"),
    )
    # Labor cost: 60 minutes / 60 * 50.00 per hour
    labor_cost = ProjectLaborCost(
        project_id=sample_project.id,
        staff_level_id=sample_staff_level.id,
        duration_minutes=60,
        hourly_rate=Decimal("50.00"),
        labor_cost=Decimal("50.00"),
    )
    db_session.add_all([cost_item, equip_cost, labor_cost])
    await db_session.commit()
    await db_session.refresh(sample_project)
    return sample_project


@pytest_asyncio.fixture
async def sample_competitor(db_session: AsyncSession) -> Competitor:
    """Create and persist a sample key competitor."""
    competitor = Competitor(
        competitor_name="美丽人生医美",
        address="XX市XX路100号",
        distance_km=Decimal("2.5"),
        positioning="medium",
        contact="13800138000",
        is_key_competitor=True,
        is_active=True,
    )
    db_session.add(competitor)
    await db_session.commit()
    await db_session.refresh(competitor)
    return competitor


@pytest_asyncio.fixture
async def sample_competitor_price(
    db_session: AsyncSession,
    sample_competitor: Competitor,
    sample_project: Project,
) -> CompetitorPrice:
    """Create and persist a competitor price for ``sample_project``, collected today."""
    price = CompetitorPrice(
        competitor_id=sample_competitor.id,
        project_id=sample_project.id,
        project_name="光子嫩肤",
        original_price=Decimal("680.00"),
        promo_price=Decimal("480.00"),
        member_price=Decimal("580.00"),
        price_source="meituan",
        collected_at=datetime.now().date(),
    )
    db_session.add(price)
    await db_session.commit()
    await db_session.refresh(price)
    return price


@pytest_asyncio.fixture
async def sample_pricing_plan(
    db_session: AsyncSession,
    sample_project: Project,
) -> PricingPlan:
    """Create and persist an active profit-based pricing plan for ``sample_project``."""
    plan = PricingPlan(
        project_id=sample_project.id,
        plan_name="2026年Q1定价",
        strategy_type="profit",
        base_cost=Decimal("280.50"),
        target_margin=Decimal("50.00"),
        # 280.50 / (1 - 50%)
        suggested_price=Decimal("561.00"),
        final_price=Decimal("580.00"),
        is_active=True,
    )
    db_session.add(plan)
    await db_session.commit()
    await db_session.refresh(plan)
    return plan


@pytest_asyncio.fixture
async def sample_cost_summary(
    db_session: AsyncSession,
    sample_project: Project,
) -> ProjectCostSummary:
    """Create and persist a cost summary for ``sample_project``.

    The figures are standalone (100 + 50 + 50 + 30 = 230) and are not derived
    from the rows created by ``sample_project_with_costs``.
    """
    cost_summary = ProjectCostSummary(
        project_id=sample_project.id,
        material_cost=Decimal("100.00"),
        equipment_cost=Decimal("50.00"),
        labor_cost=Decimal("50.00"),
        fixed_cost_allocation=Decimal("30.00"),
        total_cost=Decimal("230.00"),
        calculated_at=datetime.now(),  # calculation timestamp must be set
    )
    db_session.add(cost_summary)
    await db_session.commit()
    await db_session.refresh(cost_summary)
    return cost_summary


# ============ Helper functions ============

def assert_response_success(response, expected_code=0):
    """Assert an HTTP 200 response whose JSON body has ``code == expected_code``.

    Returns the body's ``"data"`` field. Assertion messages include the raw
    response so failing tests show what the API actually returned.
    """
    assert response.status_code == 200, (
        f"expected HTTP 200, got {response.status_code}: {response.text}"
    )
    data = response.json()
    assert data["code"] == expected_code, (
        f"expected code {expected_code!r}, got {data['code']!r}: {data}"
    )
    assert "data" in data, f"response body has no 'data' field: {data}"
    return data["data"]


def assert_response_error(response, expected_code, expected_status=None):
    """Assert the JSON body has ``code == expected_code`` and return the body.

    The HTTP status is not checked unless ``expected_status`` is given.
    """
    if expected_status is not None:
        assert response.status_code == expected_status, (
            f"expected HTTP {expected_status}, got {response.status_code}: {response.text}"
        )
    data = response.json()
    assert data["code"] == expected_code, (
        f"expected code {expected_code!r}, got {data['code']!r}: {data}"
    )
    return data