"""Smart pricing router.

API endpoints for pricing plans (CRUD) and AI-generated pricing advice.
"""
import logging
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.database import get_db
from app.models import PricingPlan, Project
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
from app.schemas.pricing import (
    StrategyType,
    PricingPlanCreate,
    PricingPlanUpdate,
    PricingPlanResponse,
    PricingPlanListResponse,
    PricingPlanQuery,
    GeneratePricingRequest,
    GeneratePricingResponse,
    SimulateStrategyRequest,
    SimulateStrategyResponse,
)
from app.services.pricing_service import PricingService

logger = logging.getLogger(__name__)

router = APIRouter()


# ---------------------------------------------------------------------------
# Response-building helpers
# ---------------------------------------------------------------------------

def _optional_float(value) -> Optional[float]:
    """Convert *value* to ``float``, passing ``None`` through unchanged.

    An explicit ``is None`` check is used so that a legitimate zero
    (e.g. a final price of 0) becomes ``0.0`` rather than ``None``.
    """
    return None if value is None else float(value)


def _plan_to_response(plan: PricingPlan) -> PricingPlanResponse:
    """Build the detail response for *plan*.

    Reads ``plan.project`` and ``plan.creator``; every caller in this module
    loads those relationships first (``selectinload`` or ``db.refresh``).
    """
    return PricingPlanResponse(
        id=plan.id,
        project_id=plan.project_id,
        project_name=plan.project.project_name if plan.project else None,
        plan_name=plan.plan_name,
        strategy_type=plan.strategy_type,
        base_cost=float(plan.base_cost),
        target_margin=float(plan.target_margin),
        suggested_price=float(plan.suggested_price),
        final_price=_optional_float(plan.final_price),
        ai_advice=plan.ai_advice,
        is_active=plan.is_active,
        created_at=plan.created_at,
        updated_at=plan.updated_at,
        created_by_name=plan.creator.username if plan.creator else None,
    )


def _plan_to_list_item(plan: PricingPlan) -> PricingPlanListResponse:
    """Build the compact list-row representation of *plan*."""
    return PricingPlanListResponse(
        id=plan.id,
        project_id=plan.project_id,
        project_name=plan.project.project_name if plan.project else None,
        plan_name=plan.plan_name,
        strategy_type=plan.strategy_type,
        base_cost=float(plan.base_cost),
        target_margin=float(plan.target_margin),
        suggested_price=float(plan.suggested_price),
        final_price=_optional_float(plan.final_price),
        is_active=plan.is_active,
        created_at=plan.created_at,
        created_by_name=plan.creator.username if plan.creator else None,
    )


# ---------------------------------------------------------------------------
# Pricing plan CRUD
# ---------------------------------------------------------------------------

# Columns the list endpoint may sort by. Anything else falls back to
# ``created_at`` so request input never reaches ``getattr`` unchecked.
PRICING_PLAN_SORT_FIELDS = {"created_at", "updated_at", "plan_name", "base_cost", "target_margin", "suggested_price"}


@router.get("/pricing-plans", response_model=ResponseModel[PaginatedData[PricingPlanListResponse]])
async def list_pricing_plans(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    project_id: Optional[int] = None,
    strategy_type: Optional[StrategyType] = None,
    is_active: Optional[bool] = None,
    sort_by: str = "created_at",
    sort_order: str = "desc",
    db: AsyncSession = Depends(get_db),
):
    """List pricing plans with optional filters, sorting and pagination.

    Each filter is applied only when supplied. ``sort_by`` must be one of
    ``PRICING_PLAN_SORT_FIELDS`` (otherwise ``created_at`` is used);
    ``sort_order == "desc"`` sorts descending, any other value ascending.
    """
    query = select(PricingPlan).options(
        selectinload(PricingPlan.project),
        selectinload(PricingPlan.creator),
    )
    # Compare against None so falsy-but-valid values still filter.
    if project_id is not None:
        query = query.where(PricingPlan.project_id == project_id)
    if strategy_type is not None:
        query = query.where(PricingPlan.strategy_type == strategy_type.value)
    if is_active is not None:
        query = query.where(PricingPlan.is_active == is_active)

    # Total row count of the filtered query, before ordering and pagination.
    count_query = select(func.count()).select_from(query.subquery())
    total_result = await db.execute(count_query)
    total = total_result.scalar() or 0

    # Sorting: whitelist check prevents arbitrary attribute access / injection.
    if sort_by not in PRICING_PLAN_SORT_FIELDS:
        sort_by = "created_at"
    sort_column = getattr(PricingPlan, sort_by, PricingPlan.created_at)
    if sort_order == "desc":
        query = query.order_by(sort_column.desc())
    else:
        query = query.order_by(sort_column.asc())

    # Pagination.
    query = query.offset((page - 1) * page_size).limit(page_size)

    result = await db.execute(query)
    plans = result.scalars().all()

    return ResponseModel(data=PaginatedData(
        items=[_plan_to_list_item(plan) for plan in plans],
        total=total,
        page=page,
        page_size=page_size,
        total_pages=(total + page_size - 1) // page_size,
    ))


@router.post("/pricing-plans", response_model=ResponseModel[PricingPlanResponse])
async def create_pricing_plan(
    data: PricingPlanCreate,
    db: AsyncSession = Depends(get_db),
):
    """Create a pricing plan via ``PricingService`` and return it.

    Raises:
        HTTPException: 400 when the service raises ``ValueError``.
    """
    service = PricingService(db)
    try:
        plan = await service.create_pricing_plan(
            project_id=data.project_id,
            plan_name=data.plan_name,
            # NOTE(review): the enum member is passed here, whereas
            # update_pricing_plan passes ``.value`` — confirm which form
            # PricingService expects.
            strategy_type=data.strategy_type,
            target_margin=data.target_margin,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    await db.commit()
    # Reload relationships so the response can include project/creator names.
    await db.refresh(plan, ["project", "creator"])

    return ResponseModel(
        message="创建成功",
        data=_plan_to_response(plan),
    )


@router.get("/pricing-plans/{plan_id}", response_model=ResponseModel[PricingPlanResponse])
async def get_pricing_plan(
    plan_id: int,
    db: AsyncSession = Depends(get_db),
):
    """Return a single pricing plan by id.

    Raises:
        HTTPException: 404 when no plan with ``plan_id`` exists.
    """
    result = await db.execute(
        select(PricingPlan).options(
            selectinload(PricingPlan.project),
            selectinload(PricingPlan.creator),
        ).where(PricingPlan.id == plan_id)
    )
    plan = result.scalar_one_or_none()

    if plan is None:
        raise HTTPException(
            status_code=404,
            detail={"code": ErrorCode.NOT_FOUND, "message": "定价方案不存在"}
        )

    return ResponseModel(data=_plan_to_response(plan))


@router.put("/pricing-plans/{plan_id}", response_model=ResponseModel[PricingPlanResponse])
async def update_pricing_plan(
    plan_id: int,
    data: PricingPlanUpdate,
    db: AsyncSession = Depends(get_db),
):
    """Update a pricing plan via ``PricingService`` and return it.

    Fields left as ``None`` in ``data`` are forwarded as ``None``; how the
    service treats them is defined in PricingService.

    Raises:
        HTTPException: 400 when the service raises ``ValueError``.
    """
    service = PricingService(db)
    try:
        plan = await service.update_pricing_plan(
            plan_id=plan_id,
            plan_name=data.plan_name,
            strategy_type=data.strategy_type.value if data.strategy_type else None,
            target_margin=data.target_margin,
            final_price=data.final_price,
            is_active=data.is_active,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    await db.commit()
    # Reload relationships so the response can include project/creator names.
    await db.refresh(plan, ["project", "creator"])

    return ResponseModel(
        message="更新成功",
        data=_plan_to_response(plan),
    )


@router.delete("/pricing-plans/{plan_id}", response_model=ResponseModel)
async def delete_pricing_plan(
    plan_id: int,
    db: AsyncSession = Depends(get_db),
):
    """Delete a pricing plan by id and commit.

    Raises:
        HTTPException: 404 when no plan with ``plan_id`` exists.
    """
    result = await db.execute(
        select(PricingPlan).where(PricingPlan.id == plan_id)
    )
    plan = result.scalar_one_or_none()

    if plan is None:
        raise HTTPException(
            status_code=404,
            detail={"code": ErrorCode.NOT_FOUND, "message": "定价方案不存在"}
        )

    await db.delete(plan)
    await db.commit()

    return ResponseModel(message="删除成功")


# ---------------------------------------------------------------------------
# AI pricing advice
# ---------------------------------------------------------------------------

@router.post("/projects/{project_id}/generate-pricing", response_model=ResponseModel[GeneratePricingResponse])
async def generate_pricing(
    project_id: int,
    request: GeneratePricingRequest,
    db: AsyncSession = Depends(get_db),
):
    """Generate AI pricing advice for a project.

    When ``request.stream`` is true the advice is streamed as
    ``text/event-stream``; otherwise the full result is returned at once.

    Raises:
        HTTPException: 404 if the project does not exist; 500 with
            ``ErrorCode.AI_SERVICE_ERROR`` if non-streaming generation fails.
    """
    service = PricingService(db)

    # Make sure the project exists before doing any AI work.
    result = await db.execute(
        select(Project).where(Project.id == project_id)
    )
    if result.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=404,
            detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
        )

    if request.stream:
        # NOTE(review): the stream generator keeps using ``db`` after this
        # handler returns. Depending on the FastAPI version, a ``yield``
        # dependency such as get_db may be finalized before the stream ends —
        # confirm the session lifetime covers streaming.
        # NOTE(review): ``request.strategies`` is not forwarded in streaming
        # mode — confirm this is intentional.
        return StreamingResponse(
            service.generate_pricing_advice_stream(
                project_id=project_id,
                target_margin=request.target_margin,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            }
        )

    # Non-streaming: any failure is reported as an AI-service error, but the
    # traceback is logged since the client only receives str(e).
    try:
        response = await service.generate_pricing_advice(
            project_id=project_id,
            target_margin=request.target_margin,
            strategies=request.strategies,
        )
    except Exception as e:
        logger.exception("AI pricing generation failed for project %s", project_id)
        raise HTTPException(
            status_code=500,
            detail={"code": ErrorCode.AI_SERVICE_ERROR, "message": str(e)}
        ) from e

    return ResponseModel(data=response)


@router.post("/projects/{project_id}/simulate-strategy", response_model=ResponseModel[SimulateStrategyResponse])
async def simulate_strategy(
    project_id: int,
    request: SimulateStrategyRequest,
    db: AsyncSession = Depends(get_db),
):
    """Simulate the requested pricing strategies for a project.

    Raises:
        HTTPException: 400 when the service raises ``ValueError``.
    """
    service = PricingService(db)
    try:
        response = await service.simulate_strategies(
            project_id=project_id,
            strategies=request.strategies,
            target_margin=request.target_margin,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    return ResponseModel(data=response)