328 lines
11 KiB
Python
328 lines
11 KiB
Python
"""智能定价路由
|
|
|
|
智能定价建议相关的 API 接口
|
|
"""
|
|
|
|
from typing import Optional, List
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from fastapi.responses import StreamingResponse
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.database import get_db
|
|
from app.models import PricingPlan, Project
|
|
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
|
|
from app.schemas.pricing import (
|
|
StrategyType,
|
|
PricingPlanCreate,
|
|
PricingPlanUpdate,
|
|
PricingPlanResponse,
|
|
PricingPlanListResponse,
|
|
PricingPlanQuery,
|
|
GeneratePricingRequest,
|
|
GeneratePricingResponse,
|
|
SimulateStrategyRequest,
|
|
SimulateStrategyResponse,
|
|
)
|
|
from app.services.pricing_service import PricingService
|
|
|
|
# Router that every pricing-plan CRUD and AI pricing endpoint in this module registers on.
router = APIRouter()
|
|
|
|
|
|
# Pricing-plan CRUD

# Whitelist of PricingPlan column names accepted as `sort_by` by list_pricing_plans.
# A frozenset so the whitelist cannot be widened at runtime; membership tests behave
# exactly as with a regular set.
PRICING_PLAN_SORT_FIELDS = frozenset({
    "created_at",
    "updated_at",
    "plan_name",
    "base_cost",
    "target_margin",
    "suggested_price",
})
|
|
|
|
|
|
@router.get("/pricing-plans", response_model=ResponseModel[PaginatedData[PricingPlanListResponse]])
async def list_pricing_plans(
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    project_id: Optional[int] = None,
    strategy_type: Optional[StrategyType] = None,
    is_active: Optional[bool] = None,
    sort_by: str = "created_at",
    sort_order: str = "desc",
    db: AsyncSession = Depends(get_db),
):
    """List pricing plans with optional filtering, sorting and pagination.

    Args:
        page: 1-based page number.
        page_size: Number of items per page (1-100).
        project_id: When given, only plans belonging to this project.
        strategy_type: When given, only plans with this strategy type.
        is_active: When given, only plans whose active flag matches.
        sort_by: Column to sort by. Names outside ``PRICING_PLAN_SORT_FIELDS``
            silently fall back to ``created_at``.
        sort_order: ``"desc"`` (case-insensitive) sorts descending; any other
            value sorts ascending.
        db: Async database session.

    Returns:
        ``ResponseModel`` wrapping a ``PaginatedData`` page of
        ``PricingPlanListResponse`` items plus the total match count.
    """
    # Build the WHERE criteria once so the count query and the page query agree.
    filters = []
    if project_id is not None:
        filters.append(PricingPlan.project_id == project_id)
    if strategy_type is not None:
        filters.append(PricingPlan.strategy_type == strategy_type.value)
    if is_active is not None:
        filters.append(PricingPlan.is_active == is_active)

    # Plain filtered COUNT(*); loader options are irrelevant for counting.
    count_query = select(func.count()).select_from(PricingPlan).where(*filters)
    total = (await db.execute(count_query)).scalar() or 0

    # Only whitelisted names reach getattr(), so sort_by cannot address
    # arbitrary model attributes.
    if sort_by not in PRICING_PLAN_SORT_FIELDS:
        sort_by = "created_at"
    sort_column = getattr(PricingPlan, sort_by, PricingPlan.created_at)

    # The id tie-breaker makes ordering deterministic, so rows sharing the same
    # sort value do not move between pages.
    if sort_order.lower() == "desc":
        ordering = (sort_column.desc(), PricingPlan.id.desc())
    else:
        ordering = (sort_column.asc(), PricingPlan.id.asc())

    page_query = (
        select(PricingPlan)
        .options(
            selectinload(PricingPlan.project),
            selectinload(PricingPlan.creator),
        )
        .where(*filters)
        .order_by(*ordering)
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    plans = (await db.execute(page_query)).scalars().all()

    items = [
        PricingPlanListResponse(
            id=plan.id,
            project_id=plan.project_id,
            project_name=plan.project.project_name if plan.project else None,
            plan_name=plan.plan_name,
            strategy_type=plan.strategy_type,
            base_cost=float(plan.base_cost),
            target_margin=float(plan.target_margin),
            suggested_price=float(plan.suggested_price),
            # `is not None` so a legitimate final price of 0 is reported as 0.0, not None.
            final_price=float(plan.final_price) if plan.final_price is not None else None,
            is_active=plan.is_active,
            created_at=plan.created_at,
            created_by_name=plan.creator.username if plan.creator else None,
        )
        for plan in plans
    ]

    return ResponseModel(data=PaginatedData(
        items=items,
        total=total,
        page=page,
        page_size=page_size,
        # Ceiling division: a partial last page still counts as a page.
        total_pages=(total + page_size - 1) // page_size,
    ))
|
|
|
|
|
|
@router.post("/pricing-plans", response_model=ResponseModel[PricingPlanResponse])
async def create_pricing_plan(
    data: PricingPlanCreate,
    db: AsyncSession = Depends(get_db),
):
    """Create a pricing plan through ``PricingService`` and return it.

    Args:
        data: Creation payload (project, plan name, strategy type, target margin).
        db: Async database session; committed on success, rolled back on failure.

    Returns:
        ``ResponseModel`` with the created plan, including project and creator names.

    Raises:
        HTTPException: 400 carrying the error text when the service raises ``ValueError``.
    """
    service = PricingService(db)

    # Only the service call and the commit are expected to raise ValueError.
    # Errors while building the response are server bugs, not bad input.
    try:
        plan = await service.create_pricing_plan(
            project_id=data.project_id,
            plan_name=data.plan_name,
            # NOTE(review): update_pricing_plan passes `strategy_type.value` while this
            # passes the raw field; confirm which form PricingService expects.
            strategy_type=data.strategy_type,
            target_margin=data.target_margin,
        )
        await db.commit()
    except ValueError as e:
        # Discard anything the service may have staged before failing.
        await db.rollback()
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Reload the relationships used below (project name, creator name).
    await db.refresh(plan, ["project", "creator"])

    return ResponseModel(
        message="创建成功",
        data=PricingPlanResponse(
            id=plan.id,
            project_id=plan.project_id,
            project_name=plan.project.project_name if plan.project else None,
            plan_name=plan.plan_name,
            strategy_type=plan.strategy_type,
            base_cost=float(plan.base_cost),
            target_margin=float(plan.target_margin),
            suggested_price=float(plan.suggested_price),
            # `is not None` so a final price of 0 is not reported as missing.
            final_price=float(plan.final_price) if plan.final_price is not None else None,
            ai_advice=plan.ai_advice,
            is_active=plan.is_active,
            created_at=plan.created_at,
            updated_at=plan.updated_at,
            # The creator is refreshed above; expose it as get_pricing_plan does.
            created_by_name=plan.creator.username if plan.creator else None,
        ),
    )
|
|
|
|
|
|
@router.get("/pricing-plans/{plan_id}", response_model=ResponseModel[PricingPlanResponse])
async def get_pricing_plan(
    plan_id: int,
    db: AsyncSession = Depends(get_db),
):
    """Return a single pricing plan, with its project and creator loaded.

    Args:
        plan_id: Primary key of the plan.
        db: Async database session.

    Returns:
        ``ResponseModel`` wrapping the plan as ``PricingPlanResponse``.

    Raises:
        HTTPException: 404 with code ``ErrorCode.NOT_FOUND`` when no plan has ``plan_id``.
    """
    result = await db.execute(
        select(PricingPlan)
        .options(
            selectinload(PricingPlan.project),
            selectinload(PricingPlan.creator),
        )
        .where(PricingPlan.id == plan_id)
    )
    plan = result.scalar_one_or_none()

    if not plan:
        raise HTTPException(
            status_code=404,
            detail={"code": ErrorCode.NOT_FOUND, "message": "定价方案不存在"},
        )

    return ResponseModel(data=PricingPlanResponse(
        id=plan.id,
        project_id=plan.project_id,
        project_name=plan.project.project_name if plan.project else None,
        plan_name=plan.plan_name,
        strategy_type=plan.strategy_type,
        base_cost=float(plan.base_cost),
        target_margin=float(plan.target_margin),
        suggested_price=float(plan.suggested_price),
        # `is not None` so a legitimate final price of 0 is returned as 0.0, not None.
        final_price=float(plan.final_price) if plan.final_price is not None else None,
        ai_advice=plan.ai_advice,
        is_active=plan.is_active,
        created_at=plan.created_at,
        updated_at=plan.updated_at,
        created_by_name=plan.creator.username if plan.creator else None,
    ))
|
|
|
|
|
|
@router.put("/pricing-plans/{plan_id}", response_model=ResponseModel[PricingPlanResponse])
async def update_pricing_plan(
    plan_id: int,
    data: PricingPlanUpdate,
    db: AsyncSession = Depends(get_db),
):
    """Update a pricing plan through ``PricingService`` and return the result.

    Every field of ``data`` is forwarded, ``None`` values included. How ``None``
    is interpreted is up to ``PricingService.update_pricing_plan``; presumably
    it means "leave unchanged" (TODO confirm).

    Args:
        plan_id: Primary key of the plan to update.
        data: Partial update payload.
        db: Async database session; committed on success, rolled back on failure.

    Returns:
        ``ResponseModel`` with the updated plan, including project and creator names.

    Raises:
        HTTPException: 400 carrying the error text when the service raises ``ValueError``.
    """
    service = PricingService(db)

    # Only the service call and the commit are expected to raise ValueError.
    # Errors while building the response are server bugs, not bad input.
    try:
        plan = await service.update_pricing_plan(
            plan_id=plan_id,
            plan_name=data.plan_name,
            strategy_type=data.strategy_type.value if data.strategy_type else None,
            target_margin=data.target_margin,
            final_price=data.final_price,
            is_active=data.is_active,
        )
        await db.commit()
    except ValueError as e:
        # Discard any partially applied changes before reporting the error.
        await db.rollback()
        raise HTTPException(status_code=400, detail=str(e)) from e

    # Reload the relationships used below (project name, creator name).
    await db.refresh(plan, ["project", "creator"])

    return ResponseModel(
        message="更新成功",
        data=PricingPlanResponse(
            id=plan.id,
            project_id=plan.project_id,
            project_name=plan.project.project_name if plan.project else None,
            plan_name=plan.plan_name,
            strategy_type=plan.strategy_type,
            base_cost=float(plan.base_cost),
            target_margin=float(plan.target_margin),
            suggested_price=float(plan.suggested_price),
            # `is not None` so setting the final price to 0 is reflected as 0.0, not None.
            final_price=float(plan.final_price) if plan.final_price is not None else None,
            ai_advice=plan.ai_advice,
            is_active=plan.is_active,
            created_at=plan.created_at,
            updated_at=plan.updated_at,
            # The creator is refreshed above; expose it as get_pricing_plan does.
            created_by_name=plan.creator.username if plan.creator else None,
        ),
    )
|
|
|
|
|
|
@router.delete("/pricing-plans/{plan_id}", response_model=ResponseModel)
async def delete_pricing_plan(
    plan_id: int,
    db: AsyncSession = Depends(get_db),
):
    """Delete one pricing plan by primary key and commit immediately.

    Args:
        plan_id: Primary key of the plan to delete.
        db: Async database session.

    Returns:
        ``ResponseModel`` carrying only a success message.

    Raises:
        HTTPException: 404 with code ``ErrorCode.NOT_FOUND`` when no plan has ``plan_id``.
    """
    lookup = select(PricingPlan).where(PricingPlan.id == plan_id)
    target = (await db.execute(lookup)).scalar_one_or_none()

    if not target:
        not_found = {"code": ErrorCode.NOT_FOUND, "message": "定价方案不存在"}
        raise HTTPException(status_code=404, detail=not_found)

    await db.delete(target)
    await db.commit()
    return ResponseModel(message="删除成功")
|
|
|
|
|
|
# AI pricing advice

@router.post("/projects/{project_id}/generate-pricing", response_model=ResponseModel[GeneratePricingResponse])
async def generate_pricing(
    project_id: int,
    request: GeneratePricingRequest,
    db: AsyncSession = Depends(get_db),
):
    """Generate AI pricing advice for a project.

    Two modes are supported, selected by ``request.stream``:

    * streaming: returns a ``text/event-stream`` ``StreamingResponse`` fed by
      ``PricingService.generate_pricing_advice_stream``;
    * non-streaming: awaits ``PricingService.generate_pricing_advice`` and wraps
      the result in ``ResponseModel``.

    Args:
        project_id: Project to price.
        request: Target margin, strategies and the stream flag.
        db: Async database session.

    Raises:
        HTTPException: 404 (``ErrorCode.NOT_FOUND``) when the project does not exist.
            In non-streaming mode, 500 (``ErrorCode.AI_SERVICE_ERROR``) when generation
            fails with anything other than an ``HTTPException``.
    """
    service = PricingService(db)

    # Fail fast with a 404 before starting any (possibly streaming) AI work.
    result = await db.execute(
        select(Project).where(Project.id == project_id)
    )
    if not result.scalar_one_or_none():
        raise HTTPException(
            status_code=404,
            detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"},
        )

    if request.stream:
        # NOTE(review): the streaming path does not forward `request.strategies`,
        # unlike the non-streaming path. Confirm this is intentional.
        return StreamingResponse(
            service.generate_pricing_advice_stream(
                project_id=project_id,
                target_margin=request.target_margin,
            ),
            media_type="text/event-stream",
            headers={
                # Keep proxies/clients from caching or buffering the event stream.
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )

    try:
        response = await service.generate_pricing_advice(
            project_id=project_id,
            target_margin=request.target_margin,
            strategies=request.strategies,
        )
    except HTTPException:
        # An HTTP error raised on purpose keeps its own status code
        # instead of being masked as an AI-service failure.
        raise
    except Exception as e:
        # Boundary catch: any other generation failure becomes a 500 AI-service error.
        raise HTTPException(
            status_code=500,
            detail={"code": ErrorCode.AI_SERVICE_ERROR, "message": str(e)},
        ) from e

    return ResponseModel(data=response)
|
|
|
|
|
|
@router.post("/projects/{project_id}/simulate-strategy", response_model=ResponseModel[SimulateStrategyResponse])
async def simulate_strategy(
    project_id: int,
    request: SimulateStrategyRequest,
    db: AsyncSession = Depends(get_db),
):
    """Run ``PricingService.simulate_strategies`` for a project and return the result.

    Args:
        project_id: Project whose pricing strategies are simulated.
        request: Strategies to simulate and the target margin.
        db: Async database session handed to the service.

    Returns:
        ``ResponseModel`` wrapping the service's simulation result.

    Raises:
        HTTPException: 400 carrying the error text when a ``ValueError`` is raised.
    """
    pricing = PricingService(db)

    try:
        simulation = await pricing.simulate_strategies(
            project_id=project_id,
            strategies=request.strategies,
            target_margin=request.target_margin,
        )
        return ResponseModel(data=simulation)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
|