Initial commit: 智能项目定价模型
This commit is contained in:
68
后端服务/Dockerfile
Normal file
68
后端服务/Dockerfile
Normal file
@@ -0,0 +1,68 @@
|
||||
# 智能项目定价模型 - 后端 Dockerfile
|
||||
# 遵循瑞小美部署规范:使用具体版本号,配置阿里云镜像源
|
||||
|
||||
# 构建阶段
|
||||
FROM python:3.11.9-slim AS builder
|
||||
|
||||
# 配置阿里云 APT 源
|
||||
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources
|
||||
|
||||
# 安装构建依赖
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc \
|
||||
libffi-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 配置阿里云 pip 源
|
||||
RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ \
|
||||
&& pip config set global.trusted-host mirrors.aliyun.com
|
||||
|
||||
# 创建虚拟环境
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# 复制依赖文件并安装
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 运行阶段
|
||||
FROM python:3.11.9-slim AS runner
|
||||
|
||||
# 配置阿里云 APT 源
|
||||
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources
|
||||
|
||||
# 安装运行时依赖
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 创建非 root 用户
|
||||
RUN groupadd -r appgroup && useradd -r -g appgroup appuser
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 从构建阶段复制虚拟环境
|
||||
COPY --from=builder /opt/venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# 复制应用代码
|
||||
COPY --chown=appuser:appgroup . .
|
||||
|
||||
# 设置环境变量
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
TZ=Asia/Shanghai
|
||||
|
||||
# 切换到非 root 用户
|
||||
USER appuser
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8000
|
||||
|
||||
# 健康检查
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost:8000/health || exit 1
|
||||
|
||||
# 启动命令
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
29
后端服务/Dockerfile.dev
Normal file
29
后端服务/Dockerfile.dev
Normal file
@@ -0,0 +1,29 @@
|
||||
# 开发环境 Dockerfile
|
||||
|
||||
FROM python:3.11.9-slim
|
||||
|
||||
# 配置阿里云源
|
||||
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources
|
||||
|
||||
# 安装依赖
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# 配置 pip
|
||||
RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 安装依赖
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 设置环境变量
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
TZ=Asia/Shanghai
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
||||
90
后端服务/SECURITY_CHECKLIST.md
Normal file
90
后端服务/SECURITY_CHECKLIST.md
Normal file
@@ -0,0 +1,90 @@
|
||||
# 安全检查清单
|
||||
|
||||
> 智能项目定价模型 - M4 测试优化阶段
|
||||
> 遵循瑞小美系统技术栈标准与安全规范
|
||||
|
||||
---
|
||||
|
||||
## 1. API Key 管理
|
||||
|
||||
- [x] 未在代码中硬编码 API Key
|
||||
- [x] 通过 `shared_backend.AIService` 调用 AI 服务
|
||||
- [x] API Key 从门户系统统一获取
|
||||
- [x] `.env` 文件在 `.gitignore` 中排除
|
||||
|
||||
## 2. 输入验证
|
||||
|
||||
- [x] 使用 Pydantic 进行请求参数验证
|
||||
- [x] 添加 SQL 注入检测
|
||||
- [x] 添加 XSS 检测
|
||||
- [x] 限制请求数据嵌套深度
|
||||
|
||||
## 3. 身份认证
|
||||
|
||||
- [ ] 集成 OAuth 认证(待 M5 部署阶段完成)
|
||||
- [x] 预留认证中间件接口
|
||||
- [x] API 路由支持权限控制
|
||||
|
||||
## 4. 速率限制
|
||||
|
||||
- [x] 实现请求速率限制中间件
|
||||
- [x] AI 接口有特殊限制(10 次/分钟)
|
||||
- [x] 默认限制 100 次/分钟
|
||||
|
||||
## 5. 安全响应头
|
||||
|
||||
- [x] X-XSS-Protection
|
||||
- [x] X-Content-Type-Options
|
||||
- [x] X-Frame-Options
|
||||
- [x] Content-Security-Policy
|
||||
|
||||
## 6. 数据保护
|
||||
|
||||
- [x] 敏感数据使用 DECIMAL 类型存储
|
||||
- [x] 数据库连接使用连接池
|
||||
- [x] 支持敏感字段脱敏(待实现具体业务)
|
||||
|
||||
## 7. 日志审计
|
||||
|
||||
- [x] 实现审计日志记录器
|
||||
- [x] 记录敏感操作
|
||||
- [x] 日志格式为 JSON(便于分析)
|
||||
|
||||
## 8. 错误处理
|
||||
|
||||
- [x] 统一错误响应格式
|
||||
- [x] 生产环境不暴露内部错误详情
|
||||
- [x] 错误码规范化
|
||||
|
||||
## 9. 依赖安全
|
||||
|
||||
- [x] 使用固定版本的依赖
|
||||
- [ ] 定期检查依赖漏洞(建议使用 `pip-audit`)
|
||||
|
||||
## 10. 部署安全
|
||||
|
||||
- [x] Docker 容器使用非 root 用户(待验证)
|
||||
- [x] 只暴露必要端口(Nginx 80/443)
|
||||
- [x] 配置健康检查
|
||||
- [x] 配置资源限制
|
||||
|
||||
---
|
||||
|
||||
## 安全测试命令
|
||||
|
||||
```bash
|
||||
# 运行安全相关测试
|
||||
pytest tests/ -m "security" -v
|
||||
|
||||
# 检查依赖漏洞
|
||||
pip install pip-audit
|
||||
pip-audit
|
||||
|
||||
# 检查代码安全问题
|
||||
pip install bandit
|
||||
bandit -r app/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
*瑞小美技术团队 · 2026-01-20*
|
||||
3
后端服务/app/__init__.py
Normal file
3
后端服务/app/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""智能项目定价模型 - 后端服务"""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
85
后端服务/app/config.py
Normal file
85
后端服务/app/config.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""配置管理模块
|
||||
|
||||
使用 Pydantic Settings 管理环境变量配置
|
||||
遵循瑞小美系统技术栈标准
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import model_validator
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用配置"""
|
||||
|
||||
# 应用配置
|
||||
APP_NAME: str = "智能项目定价模型"
|
||||
APP_VERSION: str = "1.0.0"
|
||||
APP_ENV: str = "development"
|
||||
DEBUG: bool = True
|
||||
SECRET_KEY: str = "" # 必须通过环境变量设置
|
||||
|
||||
# 数据库配置 - MySQL 8.0, utf8mb4
|
||||
# 开发环境使用默认值,生产环境必须通过环境变量设置
|
||||
DATABASE_URL: str = "sqlite+aiosqlite:///:memory:" # 安全的开发默认值
|
||||
|
||||
# 数据库连接池配置
|
||||
DB_POOL_SIZE: int = 5
|
||||
DB_MAX_OVERFLOW: int = 10
|
||||
DB_POOL_RECYCLE: int = 3600
|
||||
|
||||
# 门户系统配置 - AI Key 从门户获取
|
||||
PORTAL_CONFIG_API: str = "http://portal-backend:8000/api/ai/internal/config"
|
||||
|
||||
# AI 服务配置
|
||||
AI_MODULE_CODE: str = "pricing_model"
|
||||
|
||||
# 时区配置 - Asia/Shanghai
|
||||
TIMEZONE: str = "Asia/Shanghai"
|
||||
|
||||
# CORS 配置 - 生产环境应限制为具体域名
|
||||
CORS_ORIGINS: list[str] = ["http://localhost:5173", "http://127.0.0.1:5173"]
|
||||
|
||||
# API 配置
|
||||
API_V1_PREFIX: str = "/api/v1"
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_production_config(self) -> 'Settings':
|
||||
"""验证生产环境配置安全性"""
|
||||
if self.APP_ENV == "production":
|
||||
# 生产环境必须设置安全的 SECRET_KEY
|
||||
if not self.SECRET_KEY or self.SECRET_KEY == "your-secret-key-change-in-production":
|
||||
raise ValueError("生产环境必须设置 SECRET_KEY 环境变量")
|
||||
|
||||
# 生产环境 CORS 不能使用 "*"
|
||||
if "*" in self.CORS_ORIGINS:
|
||||
warnings.warn("生产环境 CORS_ORIGINS 不应使用 '*',请设置具体的允许域名")
|
||||
|
||||
# 生产环境不应开启 DEBUG
|
||||
if self.DEBUG:
|
||||
warnings.warn("生产环境建议关闭 DEBUG 模式")
|
||||
|
||||
# 如果没有设置 SECRET_KEY,开发环境自动生成一个
|
||||
if not self.SECRET_KEY:
|
||||
self.SECRET_KEY = secrets.token_urlsafe(32)
|
||||
|
||||
return self
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
case_sensitive = True
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""获取配置单例"""
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
84
后端服务/app/database.py
Normal file
84
后端服务/app/database.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""数据库连接模块
|
||||
|
||||
使用 SQLAlchemy 异步引擎
|
||||
遵循瑞小美系统技术栈标准:MySQL 8.0, utf8mb4, utf8mb4_unicode_ci
|
||||
"""
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy import MetaData
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
# 命名约定,便于数据库迁移
|
||||
convention = {
|
||||
"ix": "ix_%(column_0_label)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s"
|
||||
}
|
||||
|
||||
metadata = MetaData(naming_convention=convention)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""SQLAlchemy 模型基类"""
|
||||
metadata = metadata
|
||||
|
||||
|
||||
# 创建异步引擎
|
||||
# SQLite 不支持连接池参数,需要区分处理
|
||||
_engine_kwargs = {
|
||||
"echo": settings.DEBUG,
|
||||
}
|
||||
|
||||
# 仅在非 SQLite 环境下添加连接池参数
|
||||
if not settings.DATABASE_URL.startswith("sqlite"):
|
||||
_engine_kwargs.update({
|
||||
"pool_size": settings.DB_POOL_SIZE,
|
||||
"max_overflow": settings.DB_MAX_OVERFLOW,
|
||||
"pool_recycle": settings.DB_POOL_RECYCLE,
|
||||
"pool_pre_ping": True,
|
||||
})
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
**_engine_kwargs,
|
||||
)
|
||||
|
||||
# 创建异步会话工厂
|
||||
async_session_maker = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取数据库会话依赖"""
|
||||
async with async_session_maker() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""初始化数据库表"""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def close_db():
|
||||
"""关闭数据库连接"""
|
||||
await engine.dispose()
|
||||
127
后端服务/app/main.py
Normal file
127
后端服务/app/main.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""FastAPI 应用入口
|
||||
|
||||
智能项目定价模型后端服务
|
||||
遵循瑞小美系统技术栈标准
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.database import init_db, close_db
|
||||
from app.routers import health, categories, materials, equipments, staff_levels, fixed_costs, projects, market, pricing, profit, dashboard
|
||||
from app.middleware import (
|
||||
PerformanceMiddleware,
|
||||
ResponseCacheMiddleware,
|
||||
RateLimitMiddleware,
|
||||
SecurityHeadersMiddleware,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
# 启动时初始化数据库
|
||||
await init_db()
|
||||
yield
|
||||
# 关闭时清理资源
|
||||
await close_db()
|
||||
|
||||
|
||||
# 创建 FastAPI 应用
|
||||
app = FastAPI(
|
||||
title=settings.APP_NAME,
|
||||
version=settings.APP_VERSION,
|
||||
description="智能项目定价模型 - 帮助机构精准核算成本、分析市场、智能定价",
|
||||
docs_url="/docs" if settings.DEBUG else None,
|
||||
redoc_url="/redoc" if settings.DEBUG else None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# 配置 CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 性能优化中间件
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000) # 压缩大于 1KB 的响应
|
||||
app.add_middleware(PerformanceMiddleware) # 性能监控
|
||||
app.add_middleware(ResponseCacheMiddleware) # 响应缓存
|
||||
|
||||
# 安全中间件
|
||||
app.add_middleware(RateLimitMiddleware, enabled=not settings.DEBUG) # 速率限制(生产环境)
|
||||
app.add_middleware(SecurityHeadersMiddleware) # 安全响应头
|
||||
|
||||
# 注册路由
|
||||
app.include_router(health.router, tags=["健康检查"])
|
||||
app.include_router(
|
||||
categories.router,
|
||||
prefix=f"{settings.API_V1_PREFIX}/categories",
|
||||
tags=["项目分类"]
|
||||
)
|
||||
app.include_router(
|
||||
materials.router,
|
||||
prefix=f"{settings.API_V1_PREFIX}/materials",
|
||||
tags=["耗材管理"]
|
||||
)
|
||||
app.include_router(
|
||||
equipments.router,
|
||||
prefix=f"{settings.API_V1_PREFIX}/equipments",
|
||||
tags=["设备管理"]
|
||||
)
|
||||
app.include_router(
|
||||
staff_levels.router,
|
||||
prefix=f"{settings.API_V1_PREFIX}/staff-levels",
|
||||
tags=["人员级别"]
|
||||
)
|
||||
app.include_router(
|
||||
fixed_costs.router,
|
||||
prefix=f"{settings.API_V1_PREFIX}/fixed-costs",
|
||||
tags=["固定成本"]
|
||||
)
|
||||
app.include_router(
|
||||
projects.router,
|
||||
prefix=f"{settings.API_V1_PREFIX}/projects",
|
||||
tags=["服务项目"]
|
||||
)
|
||||
app.include_router(
|
||||
market.router,
|
||||
prefix=f"{settings.API_V1_PREFIX}",
|
||||
tags=["市场行情"]
|
||||
)
|
||||
app.include_router(
|
||||
pricing.router,
|
||||
prefix=f"{settings.API_V1_PREFIX}",
|
||||
tags=["智能定价"]
|
||||
)
|
||||
app.include_router(
|
||||
profit.router,
|
||||
prefix=f"{settings.API_V1_PREFIX}",
|
||||
tags=["利润模拟"]
|
||||
)
|
||||
app.include_router(
|
||||
dashboard.router,
|
||||
prefix=f"{settings.API_V1_PREFIX}",
|
||||
tags=["仪表盘"]
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""根路径"""
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"name": settings.APP_NAME,
|
||||
"version": settings.APP_VERSION,
|
||||
"docs": "/docs" if settings.DEBUG else None
|
||||
}
|
||||
}
|
||||
20
后端服务/app/middleware/__init__.py
Normal file
20
后端服务/app/middleware/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""中间件模块"""
|
||||
from .performance import PerformanceMiddleware
|
||||
from .cache import ResponseCacheMiddleware
|
||||
from .security import (
|
||||
RateLimitMiddleware,
|
||||
SecurityHeadersMiddleware,
|
||||
InputSanitizer,
|
||||
validate_request_body,
|
||||
AuditLogger,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PerformanceMiddleware",
|
||||
"ResponseCacheMiddleware",
|
||||
"RateLimitMiddleware",
|
||||
"SecurityHeadersMiddleware",
|
||||
"InputSanitizer",
|
||||
"validate_request_body",
|
||||
"AuditLogger",
|
||||
]
|
||||
121
后端服务/app/middleware/cache.py
Normal file
121
后端服务/app/middleware/cache.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""响应缓存中间件
|
||||
|
||||
对 GET 请求的响应进行缓存,减少数据库查询
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Callable, Optional, Set
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.services.cache_service import get_cache, CacheNamespace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResponseCacheMiddleware(BaseHTTPMiddleware):
|
||||
"""响应缓存中间件
|
||||
|
||||
对符合条件的 GET 请求进行响应缓存
|
||||
"""
|
||||
|
||||
# 需要缓存的路径前缀和 TTL 配置
|
||||
CACHE_CONFIG = {
|
||||
"/api/v1/categories": {"ttl": 300, "namespace": CacheNamespace.CATEGORIES},
|
||||
"/api/v1/materials": {"ttl": 300, "namespace": CacheNamespace.MATERIALS},
|
||||
"/api/v1/equipments": {"ttl": 300, "namespace": CacheNamespace.EQUIPMENTS},
|
||||
"/api/v1/staff-levels": {"ttl": 300, "namespace": CacheNamespace.STAFF_LEVELS},
|
||||
}
|
||||
|
||||
# 不缓存的路径(精确匹配)
|
||||
NO_CACHE_PATHS: Set[str] = {
|
||||
"/health",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
}
|
||||
|
||||
def _should_cache(self, request: Request) -> Optional[dict]:
|
||||
"""判断是否应该缓存"""
|
||||
# 只缓存 GET 请求
|
||||
if request.method != "GET":
|
||||
return None
|
||||
|
||||
path = request.url.path
|
||||
|
||||
# 排除不缓存的路径
|
||||
if path in self.NO_CACHE_PATHS:
|
||||
return None
|
||||
|
||||
# 检查是否在缓存配置中
|
||||
for prefix, config in self.CACHE_CONFIG.items():
|
||||
if path.startswith(prefix):
|
||||
return config
|
||||
|
||||
return None
|
||||
|
||||
def _generate_cache_key(self, request: Request) -> str:
|
||||
"""生成缓存键"""
|
||||
# 包含路径和查询参数
|
||||
key_parts = [
|
||||
request.method,
|
||||
request.url.path,
|
||||
str(sorted(request.query_params.items())),
|
||||
]
|
||||
key_str = "|".join(key_parts)
|
||||
return hashlib.md5(key_str.encode()).hexdigest()
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
cache_config = self._should_cache(request)
|
||||
|
||||
if not cache_config:
|
||||
return await call_next(request)
|
||||
|
||||
# 生成缓存键
|
||||
cache_key = self._generate_cache_key(request)
|
||||
cache = get_cache(cache_config["namespace"])
|
||||
|
||||
# 尝试从缓存获取
|
||||
cached_data = cache.get(cache_key)
|
||||
if cached_data is not None:
|
||||
logger.debug(f"Cache hit: {request.url.path}")
|
||||
response = Response(
|
||||
content=cached_data["content"],
|
||||
status_code=cached_data["status_code"],
|
||||
headers=dict(cached_data["headers"]),
|
||||
media_type="application/json",
|
||||
)
|
||||
response.headers["X-Cache"] = "HIT"
|
||||
return response
|
||||
|
||||
# 执行请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 只缓存成功的响应
|
||||
if response.status_code == 200:
|
||||
# 读取响应体
|
||||
body = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body += chunk
|
||||
|
||||
# 保存到缓存
|
||||
cache_data = {
|
||||
"content": body,
|
||||
"status_code": response.status_code,
|
||||
"headers": dict(response.headers),
|
||||
}
|
||||
cache.set(cache_key, cache_data, cache_config["ttl"])
|
||||
|
||||
# 重新构建响应
|
||||
response = Response(
|
||||
content=body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type="application/json",
|
||||
)
|
||||
response.headers["X-Cache"] = "MISS"
|
||||
|
||||
return response
|
||||
50
后端服务/app/middleware/performance.py
Normal file
50
后端服务/app/middleware/performance.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""性能监控中间件
|
||||
|
||||
记录请求响应时间,用于性能分析和优化
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PerformanceMiddleware(BaseHTTPMiddleware):
|
||||
"""性能监控中间件
|
||||
|
||||
记录每个请求的响应时间,并在响应头中添加 X-Response-Time
|
||||
"""
|
||||
|
||||
# 慢请求阈值(毫秒)
|
||||
SLOW_REQUEST_THRESHOLD = 1000
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
start_time = time.time()
|
||||
|
||||
# 执行请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 计算响应时间
|
||||
process_time = (time.time() - start_time) * 1000
|
||||
|
||||
# 添加响应头
|
||||
response.headers["X-Response-Time"] = f"{process_time:.2f}ms"
|
||||
|
||||
# 记录慢请求
|
||||
if process_time > self.SLOW_REQUEST_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Slow request: {request.method} {request.url.path} "
|
||||
f"took {process_time:.2f}ms"
|
||||
)
|
||||
|
||||
# 记录请求日志(开发环境)
|
||||
logger.debug(
|
||||
f"{request.method} {request.url.path} - "
|
||||
f"{response.status_code} - {process_time:.2f}ms"
|
||||
)
|
||||
|
||||
return response
|
||||
267
后端服务/app/middleware/security.py
Normal file
267
后端服务/app/middleware/security.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""安全中间件
|
||||
|
||||
实现安全相关功能:
|
||||
- 请求验证
|
||||
- 速率限制
|
||||
- 安全头设置
|
||||
- 敏感数据保护
|
||||
|
||||
遵循瑞小美系统技术栈标准
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, Optional
|
||||
import re
|
||||
|
||||
from fastapi import Request, Response, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""速率限制中间件
|
||||
|
||||
防止 API 滥用,保护服务稳定性
|
||||
"""
|
||||
|
||||
# 速率限制配置
|
||||
RATE_LIMITS = {
|
||||
"default": {"requests": 100, "window": 60}, # 默认:100 次/分钟
|
||||
"/api/v1/projects/*/generate-pricing": {"requests": 10, "window": 60}, # AI 接口:10 次/分钟
|
||||
"/api/v1/projects/*/market-analysis": {"requests": 20, "window": 60}, # 分析接口:20 次/分钟
|
||||
}
|
||||
|
||||
def __init__(self, app, enabled: bool = True):
|
||||
super().__init__(app)
|
||||
self.enabled = enabled
|
||||
self._requests: Dict[str, list] = defaultdict(list)
|
||||
|
||||
def _get_client_id(self, request: Request) -> str:
|
||||
"""获取客户端标识"""
|
||||
# 优先使用 X-Forwarded-For(反向代理场景)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
# 使用客户端 IP
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
def _get_rate_limit(self, path: str) -> Dict:
|
||||
"""获取路径的速率限制配置"""
|
||||
for pattern, limit in self.RATE_LIMITS.items():
|
||||
if pattern == "default":
|
||||
continue
|
||||
# 简单的路径匹配(* 匹配任意字符)
|
||||
regex = pattern.replace("*", "[^/]+")
|
||||
if re.match(regex, path):
|
||||
return limit
|
||||
|
||||
return self.RATE_LIMITS["default"]
|
||||
|
||||
def _is_rate_limited(self, client_id: str, path: str) -> bool:
|
||||
"""检查是否超过速率限制"""
|
||||
limit_config = self._get_rate_limit(path)
|
||||
requests = limit_config["requests"]
|
||||
window = limit_config["window"]
|
||||
|
||||
now = time.time()
|
||||
key = f"{client_id}:{path}"
|
||||
|
||||
# 清理过期记录
|
||||
self._requests[key] = [t for t in self._requests[key] if now - t < window]
|
||||
|
||||
# 检查是否超限
|
||||
if len(self._requests[key]) >= requests:
|
||||
return True
|
||||
|
||||
# 记录请求
|
||||
self._requests[key].append(now)
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
client_id = self._get_client_id(request)
|
||||
path = request.url.path
|
||||
|
||||
if self._is_rate_limited(client_id, path):
|
||||
logger.warning(f"Rate limit exceeded: {client_id} -> {path}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"code": 40001,
|
||||
"message": "请求过于频繁,请稍后再试",
|
||||
"data": None
|
||||
}
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""安全响应头中间件
|
||||
|
||||
添加安全相关的 HTTP 响应头
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
response = await call_next(request)
|
||||
|
||||
# 防止 XSS 攻击
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
|
||||
# 防止 MIME 类型嗅探
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
|
||||
# 点击劫持保护
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
|
||||
# 内容安全策略(基础版)
|
||||
response.headers["Content-Security-Policy"] = "default-src 'self'"
|
||||
|
||||
# 严格传输安全(仅在 HTTPS 环境)
|
||||
# response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class InputSanitizer:
|
||||
"""输入清理工具
|
||||
|
||||
防止 SQL 注入、XSS 等攻击
|
||||
"""
|
||||
|
||||
# SQL 注入关键字
|
||||
SQL_KEYWORDS = [
|
||||
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "UNION",
|
||||
"OR", "AND", "--", "/*", "*/", "EXEC", "EXECUTE"
|
||||
]
|
||||
|
||||
# XSS 危险模式
|
||||
XSS_PATTERNS = [
|
||||
r"<script.*?>",
|
||||
r"javascript:",
|
||||
r"on\w+\s*=",
|
||||
r"eval\s*\(",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def check_sql_injection(cls, value: str) -> bool:
|
||||
"""检查是否包含 SQL 注入特征"""
|
||||
upper_value = value.upper()
|
||||
for keyword in cls.SQL_KEYWORDS:
|
||||
if keyword in upper_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def check_xss(cls, value: str) -> bool:
|
||||
"""检查是否包含 XSS 特征"""
|
||||
for pattern in cls.XSS_PATTERNS:
|
||||
if re.search(pattern, value, re.IGNORECASE):
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def sanitize(cls, value: str) -> str:
|
||||
"""清理输入值
|
||||
|
||||
移除潜在危险字符
|
||||
"""
|
||||
# 移除 HTML 标签
|
||||
value = re.sub(r'<[^>]+>', '', value)
|
||||
|
||||
# 转义特殊字符
|
||||
value = value.replace("&", "&")
|
||||
value = value.replace("<", "<")
|
||||
value = value.replace(">", ">")
|
||||
value = value.replace('"', """)
|
||||
value = value.replace("'", "'")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def validate_request_body(data: dict, max_depth: int = 10) -> None:
|
||||
"""验证请求体
|
||||
|
||||
检查数据深度和内容安全
|
||||
|
||||
Args:
|
||||
data: 请求数据
|
||||
max_depth: 最大嵌套深度
|
||||
|
||||
Raises:
|
||||
HTTPException: 验证失败
|
||||
"""
|
||||
def check_depth(obj, depth=0):
|
||||
if depth > max_depth:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": 10001, "message": "请求数据嵌套层级过深"}
|
||||
)
|
||||
|
||||
if isinstance(obj, dict):
|
||||
for value in obj.values():
|
||||
check_depth(value, depth + 1)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
check_depth(item, depth + 1)
|
||||
elif isinstance(obj, str):
|
||||
if InputSanitizer.check_sql_injection(obj):
|
||||
logger.warning(f"Potential SQL injection detected: {obj[:100]}")
|
||||
if InputSanitizer.check_xss(obj):
|
||||
logger.warning(f"Potential XSS detected: {obj[:100]}")
|
||||
|
||||
check_depth(data)
|
||||
|
||||
|
||||
class AuditLogger:
|
||||
"""审计日志记录器
|
||||
|
||||
记录敏感操作的审计日志
|
||||
"""
|
||||
|
||||
# 需要审计的操作
|
||||
AUDIT_OPERATIONS = {
|
||||
("POST", "/api/v1/pricing-plans"): "创建定价方案",
|
||||
("PUT", "/api/v1/pricing-plans/*"): "更新定价方案",
|
||||
("DELETE", "/api/v1/pricing-plans/*"): "删除定价方案",
|
||||
("POST", "/api/v1/projects/*/generate-pricing"): "生成 AI 定价建议",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def should_audit(cls, method: str, path: str) -> Optional[str]:
|
||||
"""检查是否需要审计"""
|
||||
for (m, p), desc in cls.AUDIT_OPERATIONS.items():
|
||||
if m != method:
|
||||
continue
|
||||
regex = p.replace("*", "[^/]+")
|
||||
if re.match(regex, path):
|
||||
return desc
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def log(
|
||||
cls,
|
||||
operation: str,
|
||||
user_id: Optional[int],
|
||||
request: Request,
|
||||
response_code: int,
|
||||
details: Optional[Dict] = None
|
||||
):
|
||||
"""记录审计日志"""
|
||||
log_data = {
|
||||
"operation": operation,
|
||||
"user_id": user_id,
|
||||
"method": request.method,
|
||||
"path": str(request.url.path),
|
||||
"client_ip": request.client.host if request.client else "unknown",
|
||||
"response_code": response_code,
|
||||
"details": details or {},
|
||||
}
|
||||
logger.info(f"AUDIT: {log_data}")
|
||||
44
后端服务/app/models/__init__.py
Normal file
44
后端服务/app/models/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""SQLAlchemy 数据模型"""
|
||||
|
||||
from app.models.base import BaseModel, TimestampMixin
|
||||
from app.models.category import Category
|
||||
from app.models.material import Material
|
||||
from app.models.equipment import Equipment
|
||||
from app.models.staff_level import StaffLevel
|
||||
from app.models.fixed_cost import FixedCost
|
||||
from app.models.project import Project
|
||||
from app.models.project_cost_item import ProjectCostItem
|
||||
from app.models.project_labor_cost import ProjectLaborCost
|
||||
from app.models.project_cost_summary import ProjectCostSummary
|
||||
from app.models.competitor import Competitor
|
||||
from app.models.competitor_price import CompetitorPrice
|
||||
from app.models.benchmark_price import BenchmarkPrice
|
||||
from app.models.market_analysis_result import MarketAnalysisResult
|
||||
from app.models.pricing_plan import PricingPlan
|
||||
from app.models.profit_simulation import ProfitSimulation
|
||||
from app.models.sensitivity_analysis import SensitivityAnalysis
|
||||
from app.models.user import User
|
||||
from app.models.operation_log import OperationLog
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
"TimestampMixin",
|
||||
"Category",
|
||||
"Material",
|
||||
"Equipment",
|
||||
"StaffLevel",
|
||||
"FixedCost",
|
||||
"Project",
|
||||
"ProjectCostItem",
|
||||
"ProjectLaborCost",
|
||||
"ProjectCostSummary",
|
||||
"Competitor",
|
||||
"CompetitorPrice",
|
||||
"BenchmarkPrice",
|
||||
"MarketAnalysisResult",
|
||||
"PricingPlan",
|
||||
"ProfitSimulation",
|
||||
"SensitivityAnalysis",
|
||||
"User",
|
||||
"OperationLog",
|
||||
]
|
||||
51
后端服务/app/models/base.py
Normal file
51
后端服务/app/models/base.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""模型基类
|
||||
|
||||
包含时间戳 Mixin 和通用基类
|
||||
遵循瑞小美数据库设计规范
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BigInteger, Integer, DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
from app.config import settings
|
||||
|
||||
|
||||
# 主键类型:SQLite 使用 Integer(支持自增),MySQL 使用 BigInteger
|
||||
# SQLite 的 AUTOINCREMENT 只在 INTEGER PRIMARY KEY 时生效
|
||||
_PrimaryKeyType = Integer if settings.DATABASE_URL.startswith("sqlite") else BigInteger
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
"""时间戳 Mixin"""
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=func.now(),
|
||||
server_default=func.now(),
|
||||
comment="创建时间"
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=func.now(),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
comment="更新时间"
|
||||
)
|
||||
|
||||
|
||||
class BaseModel(Base, TimestampMixin):
|
||||
"""模型基类"""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
_PrimaryKeyType,
|
||||
primary_key=True,
|
||||
autoincrement=True,
|
||||
comment="主键ID"
|
||||
)
|
||||
72
后端服务/app/models/benchmark_price.py
Normal file
72
后端服务/app/models/benchmark_price.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""标杆价格模型
|
||||
|
||||
维护行业标杆机构的价格参考
|
||||
"""
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
|
||||
from sqlalchemy import BigInteger, String, Date, DECIMAL, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.category import Category
|
||||
|
||||
|
||||
class BenchmarkPrice(BaseModel):
|
||||
"""标杆价格表"""
|
||||
|
||||
__tablename__ = "benchmark_prices"
|
||||
|
||||
benchmark_name: Mapped[str] = mapped_column(
|
||||
String(100),
|
||||
nullable=False,
|
||||
comment="标杆机构名称"
|
||||
)
|
||||
category_id: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("categories.id"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="项目分类ID"
|
||||
)
|
||||
min_price: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="最低价"
|
||||
)
|
||||
max_price: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="最高价"
|
||||
)
|
||||
avg_price: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="均价"
|
||||
)
|
||||
price_tier: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
default="medium",
|
||||
comment="价格带:low-低端, medium-中端, high-高端, premium-奢华"
|
||||
)
|
||||
effective_date: Mapped[date] = mapped_column(
|
||||
Date,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="生效日期"
|
||||
)
|
||||
remark: Mapped[Optional[str]] = mapped_column(
|
||||
String(200),
|
||||
nullable=True,
|
||||
comment="备注"
|
||||
)
|
||||
|
||||
# 关系
|
||||
category: Mapped[Optional["Category"]] = relationship(
|
||||
"Category"
|
||||
)
|
||||
60
后端服务/app/models/category.py
Normal file
60
后端服务/app/models/category.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""项目分类模型
|
||||
|
||||
支持树形分类结构
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
from sqlalchemy import BigInteger, String, Integer, Boolean, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class Category(BaseModel):
|
||||
"""项目分类表"""
|
||||
|
||||
__tablename__ = "categories"
|
||||
|
||||
category_name: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
nullable=False,
|
||||
comment="分类名称"
|
||||
)
|
||||
parent_id: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("categories.id"),
|
||||
nullable=True,
|
||||
comment="父分类ID"
|
||||
)
|
||||
sort_order: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=0,
|
||||
comment="排序"
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
comment="是否启用"
|
||||
)
|
||||
|
||||
# 关系
|
||||
parent: Mapped[Optional["Category"]] = relationship(
|
||||
"Category",
|
||||
remote_side="Category.id",
|
||||
back_populates="children"
|
||||
)
|
||||
children: Mapped[List["Category"]] = relationship(
|
||||
"Category",
|
||||
back_populates="parent"
|
||||
)
|
||||
projects: Mapped[List["Project"]] = relationship(
|
||||
"Project",
|
||||
back_populates="category"
|
||||
)
|
||||
|
||||
|
||||
# 避免循环导入
|
||||
from app.models.project import Project
|
||||
69
后端服务/app/models/competitor.py
Normal file
69
后端服务/app/models/competitor.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""竞品机构模型
|
||||
|
||||
管理周边竞品医美机构信息
|
||||
"""
|
||||
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import String, Boolean, DECIMAL
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.competitor_price import CompetitorPrice
|
||||
|
||||
|
||||
class Competitor(BaseModel):
|
||||
"""竞品机构表"""
|
||||
|
||||
__tablename__ = "competitors"
|
||||
|
||||
competitor_name: Mapped[str] = mapped_column(
|
||||
String(100),
|
||||
nullable=False,
|
||||
comment="机构名称"
|
||||
)
|
||||
address: Mapped[Optional[str]] = mapped_column(
|
||||
String(200),
|
||||
nullable=True,
|
||||
comment="地址"
|
||||
)
|
||||
distance_km: Mapped[Optional[Decimal]] = mapped_column(
|
||||
DECIMAL(5, 2),
|
||||
nullable=True,
|
||||
comment="距离(公里)"
|
||||
)
|
||||
positioning: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
default="medium",
|
||||
index=True,
|
||||
comment="定位:high-高端, medium-中端, budget-大众"
|
||||
)
|
||||
contact: Mapped[Optional[str]] = mapped_column(
|
||||
String(50),
|
||||
nullable=True,
|
||||
comment="联系方式"
|
||||
)
|
||||
is_key_competitor: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
index=True,
|
||||
comment="是否重点关注"
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
comment="是否启用"
|
||||
)
|
||||
|
||||
# 关系
|
||||
prices: Mapped[List["CompetitorPrice"]] = relationship(
|
||||
"CompetitorPrice",
|
||||
back_populates="competitor",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
83
后端服务/app/models/competitor_price.py
Normal file
83
后端服务/app/models/competitor_price.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""竞品价格模型
|
||||
|
||||
记录竞品机构的项目价格信息
|
||||
"""
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
|
||||
from sqlalchemy import BigInteger, String, Date, DECIMAL, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.competitor import Competitor
|
||||
from app.models.project import Project
|
||||
|
||||
|
||||
class CompetitorPrice(BaseModel):
|
||||
"""竞品价格表"""
|
||||
|
||||
__tablename__ = "competitor_prices"
|
||||
|
||||
competitor_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("competitors.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="竞品机构ID"
|
||||
)
|
||||
project_id: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("projects.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="关联本店项目ID"
|
||||
)
|
||||
project_name: Mapped[str] = mapped_column(
|
||||
String(100),
|
||||
nullable=False,
|
||||
comment="竞品项目名称"
|
||||
)
|
||||
original_price: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="原价"
|
||||
)
|
||||
promo_price: Mapped[Optional[Decimal]] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=True,
|
||||
comment="促销价"
|
||||
)
|
||||
member_price: Mapped[Optional[Decimal]] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=True,
|
||||
comment="会员价"
|
||||
)
|
||||
price_source: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
comment="来源:official-官网, meituan-美团, dianping-大众点评, survey-实地调研"
|
||||
)
|
||||
collected_at: Mapped[date] = mapped_column(
|
||||
Date,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="采集日期"
|
||||
)
|
||||
remark: Mapped[Optional[str]] = mapped_column(
|
||||
String(200),
|
||||
nullable=True,
|
||||
comment="备注"
|
||||
)
|
||||
|
||||
# 关系
|
||||
competitor: Mapped["Competitor"] = relationship(
|
||||
"Competitor",
|
||||
back_populates="prices"
|
||||
)
|
||||
project: Mapped[Optional["Project"]] = relationship(
|
||||
"Project"
|
||||
)
|
||||
75
后端服务/app/models/equipment.py
Normal file
75
后端服务/app/models/equipment.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""设备模型
|
||||
|
||||
管理设备基础信息和折旧计算
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import date
|
||||
|
||||
from sqlalchemy import String, Boolean, Integer, Date, DECIMAL
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class Equipment(BaseModel):
|
||||
"""设备表"""
|
||||
|
||||
__tablename__ = "equipments"
|
||||
|
||||
equipment_code: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="设备编码"
|
||||
)
|
||||
equipment_name: Mapped[str] = mapped_column(
|
||||
String(100),
|
||||
nullable=False,
|
||||
comment="设备名称"
|
||||
)
|
||||
original_value: Mapped[float] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="设备原值"
|
||||
)
|
||||
residual_rate: Mapped[float] = mapped_column(
|
||||
DECIMAL(5, 2),
|
||||
nullable=False,
|
||||
default=5.00,
|
||||
comment="残值率(%)"
|
||||
)
|
||||
service_years: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
comment="预计使用年限"
|
||||
)
|
||||
estimated_uses: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
comment="预计使用次数"
|
||||
)
|
||||
depreciation_per_use: Mapped[float] = mapped_column(
|
||||
DECIMAL(12, 4),
|
||||
nullable=False,
|
||||
comment="单次折旧成本 = (原值 - 残值) / 总次数"
|
||||
)
|
||||
purchase_date: Mapped[Optional[date]] = mapped_column(
|
||||
Date,
|
||||
nullable=True,
|
||||
comment="购入日期"
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
comment="是否启用"
|
||||
)
|
||||
|
||||
def calculate_depreciation(self) -> float:
|
||||
"""计算单次折旧成本"""
|
||||
if self.estimated_uses <= 0:
|
||||
return 0
|
||||
residual_value = float(self.original_value) * float(self.residual_rate) / 100
|
||||
return (float(self.original_value) - residual_value) / self.estimated_uses
|
||||
49
后端服务/app/models/fixed_cost.py
Normal file
49
后端服务/app/models/fixed_cost.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""固定成本模型
|
||||
|
||||
管理月度固定成本(房租、水电等)及分摊方式
|
||||
"""
|
||||
|
||||
from sqlalchemy import String, Boolean, DECIMAL
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class FixedCost(BaseModel):
|
||||
"""固定成本表"""
|
||||
|
||||
__tablename__ = "fixed_costs"
|
||||
|
||||
cost_name: Mapped[str] = mapped_column(
|
||||
String(100),
|
||||
nullable=False,
|
||||
comment="成本名称"
|
||||
)
|
||||
cost_type: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
comment="类型:rent-房租, utilities-水电, property-物业, other-其他"
|
||||
)
|
||||
monthly_amount: Mapped[float] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="月度金额"
|
||||
)
|
||||
year_month: Mapped[str] = mapped_column(
|
||||
String(7),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="年月:2026-01"
|
||||
)
|
||||
allocation_method: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
default="count",
|
||||
comment="分摊方式:count-按项目数, revenue-按营收, duration-按时长"
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
comment="是否启用"
|
||||
)
|
||||
76
后端服务/app/models/market_analysis_result.py
Normal file
76
后端服务/app/models/market_analysis_result.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""市场分析结果模型
|
||||
|
||||
存储项目的市场价格分析结果
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
|
||||
from sqlalchemy import BigInteger, Integer, Date, DECIMAL, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.project import Project
|
||||
|
||||
|
||||
class MarketAnalysisResult(BaseModel):
|
||||
"""市场分析结果表"""
|
||||
|
||||
__tablename__ = "market_analysis_results"
|
||||
|
||||
project_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="项目ID"
|
||||
)
|
||||
analysis_date: Mapped[date] = mapped_column(
|
||||
Date,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="分析日期"
|
||||
)
|
||||
competitor_count: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
comment="样本竞品数量"
|
||||
)
|
||||
market_min_price: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="市场最低价"
|
||||
)
|
||||
market_max_price: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="市场最高价"
|
||||
)
|
||||
market_avg_price: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="市场均价"
|
||||
)
|
||||
market_median_price: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="市场中位价"
|
||||
)
|
||||
suggested_range_min: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="建议区间下限"
|
||||
)
|
||||
suggested_range_max: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="建议区间上限"
|
||||
)
|
||||
|
||||
# 关系
|
||||
project: Mapped["Project"] = relationship(
|
||||
"Project"
|
||||
)
|
||||
57
后端服务/app/models/material.py
Normal file
57
后端服务/app/models/material.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""耗材模型
|
||||
|
||||
管理耗材基础信息:名称、单位、单价、供应商等
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import String, Boolean, DECIMAL
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class Material(BaseModel):
|
||||
"""耗材表"""
|
||||
|
||||
__tablename__ = "materials"
|
||||
|
||||
material_code: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="耗材编码"
|
||||
)
|
||||
material_name: Mapped[str] = mapped_column(
|
||||
String(100),
|
||||
nullable=False,
|
||||
comment="耗材名称"
|
||||
)
|
||||
unit: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
comment="单位(支/ml/个)"
|
||||
)
|
||||
unit_price: Mapped[float] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="单价"
|
||||
)
|
||||
supplier: Mapped[Optional[str]] = mapped_column(
|
||||
String(100),
|
||||
nullable=True,
|
||||
comment="供应商"
|
||||
)
|
||||
material_type: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="类型:consumable-耗材, injectable-针剂, product-产品"
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
comment="是否启用"
|
||||
)
|
||||
81
后端服务/app/models/operation_log.py
Normal file
81
后端服务/app/models/operation_log.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""操作日志模型
|
||||
|
||||
记录用户操作审计日志
|
||||
"""
|
||||
|
||||
from typing import Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BigInteger, String, DateTime, JSON, ForeignKey, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class OperationLog(Base):
|
||||
"""操作日志表"""
|
||||
|
||||
__tablename__ = "operation_logs"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
primary_key=True,
|
||||
autoincrement=True,
|
||||
comment="主键ID"
|
||||
)
|
||||
user_id: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="用户ID"
|
||||
)
|
||||
module: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="模块:cost/market/pricing/profit"
|
||||
)
|
||||
action: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
nullable=False,
|
||||
comment="操作:create/update/delete/export"
|
||||
)
|
||||
target_type: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
nullable=False,
|
||||
comment="对象类型"
|
||||
)
|
||||
target_id: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=True,
|
||||
comment="对象ID"
|
||||
)
|
||||
detail: Mapped[Optional[dict]] = mapped_column(
|
||||
JSON,
|
||||
nullable=True,
|
||||
comment="详情"
|
||||
)
|
||||
ip_address: Mapped[Optional[str]] = mapped_column(
|
||||
String(45),
|
||||
nullable=True,
|
||||
comment="IP地址"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=func.now(),
|
||||
server_default=func.now(),
|
||||
index=True,
|
||||
comment="操作时间"
|
||||
)
|
||||
|
||||
# 关系
|
||||
user: Mapped[Optional["User"]] = relationship(
|
||||
"User",
|
||||
back_populates="operation_logs"
|
||||
)
|
||||
|
||||
|
||||
# 避免循环导入
|
||||
from app.models.user import User
|
||||
94
后端服务/app/models/pricing_plan.py
Normal file
94
后端服务/app/models/pricing_plan.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""定价方案模型
|
||||
|
||||
管理项目定价方案,支持多种定价策略
|
||||
"""
|
||||
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import BigInteger, String, Boolean, Text, ForeignKey, DECIMAL
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.project import Project
|
||||
from app.models.user import User
|
||||
from app.models.profit_simulation import ProfitSimulation
|
||||
|
||||
|
||||
class PricingPlan(BaseModel):
|
||||
"""定价方案表"""
|
||||
|
||||
__tablename__ = "pricing_plans"
|
||||
|
||||
project_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("projects.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="项目ID"
|
||||
)
|
||||
plan_name: Mapped[str] = mapped_column(
|
||||
String(100),
|
||||
nullable=False,
|
||||
comment="方案名称"
|
||||
)
|
||||
strategy_type: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="策略类型:traffic-引流款, profit-利润款, premium-高端款"
|
||||
)
|
||||
base_cost: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="基础成本"
|
||||
)
|
||||
target_margin: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(5, 2),
|
||||
nullable=False,
|
||||
comment="目标毛利率(%)"
|
||||
)
|
||||
suggested_price: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="建议价格"
|
||||
)
|
||||
final_price: Mapped[Optional[Decimal]] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=True,
|
||||
comment="最终定价"
|
||||
)
|
||||
ai_advice: Mapped[Optional[str]] = mapped_column(
|
||||
Text,
|
||||
nullable=True,
|
||||
comment="AI建议内容"
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
comment="是否启用"
|
||||
)
|
||||
created_by: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
comment="创建人ID"
|
||||
)
|
||||
|
||||
# 关系
|
||||
project: Mapped["Project"] = relationship(
|
||||
"Project",
|
||||
back_populates="pricing_plans"
|
||||
)
|
||||
creator: Mapped[Optional["User"]] = relationship(
|
||||
"User",
|
||||
back_populates="created_pricing_plans"
|
||||
)
|
||||
profit_simulations: Mapped[List["ProfitSimulation"]] = relationship(
|
||||
"ProfitSimulation",
|
||||
back_populates="pricing_plan",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
97
后端服务/app/models/profit_simulation.py
Normal file
97
后端服务/app/models/profit_simulation.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""利润模拟模型
|
||||
|
||||
管理利润模拟测算记录
|
||||
"""
|
||||
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import BigInteger, String, Integer, ForeignKey, DECIMAL
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.pricing_plan import PricingPlan
|
||||
from app.models.user import User
|
||||
from app.models.sensitivity_analysis import SensitivityAnalysis
|
||||
|
||||
|
||||
class ProfitSimulation(BaseModel):
|
||||
"""利润模拟表"""
|
||||
|
||||
__tablename__ = "profit_simulations"
|
||||
|
||||
pricing_plan_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("pricing_plans.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="定价方案ID"
|
||||
)
|
||||
simulation_name: Mapped[str] = mapped_column(
|
||||
String(100),
|
||||
nullable=False,
|
||||
comment="模拟名称"
|
||||
)
|
||||
price: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="模拟价格"
|
||||
)
|
||||
estimated_volume: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
comment="预估客量"
|
||||
)
|
||||
period_type: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
comment="周期类型:daily-日, weekly-周, monthly-月"
|
||||
)
|
||||
estimated_revenue: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(14, 2),
|
||||
nullable=False,
|
||||
comment="预估收入"
|
||||
)
|
||||
estimated_cost: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(14, 2),
|
||||
nullable=False,
|
||||
comment="预估成本"
|
||||
)
|
||||
estimated_profit: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(14, 2),
|
||||
nullable=False,
|
||||
comment="预估利润"
|
||||
)
|
||||
profit_margin: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(5, 2),
|
||||
nullable=False,
|
||||
comment="利润率(%)"
|
||||
)
|
||||
breakeven_volume: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
comment="盈亏平衡客量"
|
||||
)
|
||||
created_by: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
comment="创建人ID"
|
||||
)
|
||||
|
||||
# 关系
|
||||
pricing_plan: Mapped["PricingPlan"] = relationship(
|
||||
"PricingPlan",
|
||||
back_populates="profit_simulations"
|
||||
)
|
||||
creator: Mapped[Optional["User"]] = relationship(
|
||||
"User",
|
||||
back_populates="created_profit_simulations"
|
||||
)
|
||||
sensitivity_analyses: Mapped[List["SensitivityAnalysis"]] = relationship(
|
||||
"SensitivityAnalysis",
|
||||
back_populates="simulation",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
102
后端服务/app/models/project.py
Normal file
102
后端服务/app/models/project.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""服务项目模型
|
||||
|
||||
管理医美服务项目基础信息
|
||||
"""
|
||||
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import BigInteger, String, Boolean, Integer, Text, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.category import Category
|
||||
from app.models.user import User
|
||||
from app.models.project_cost_item import ProjectCostItem
|
||||
from app.models.project_labor_cost import ProjectLaborCost
|
||||
from app.models.project_cost_summary import ProjectCostSummary
|
||||
from app.models.pricing_plan import PricingPlan
|
||||
|
||||
|
||||
class Project(BaseModel):
|
||||
"""服务项目表"""
|
||||
|
||||
__tablename__ = "projects"
|
||||
|
||||
project_code: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="项目编码"
|
||||
)
|
||||
project_name: Mapped[str] = mapped_column(
|
||||
String(100),
|
||||
nullable=False,
|
||||
comment="项目名称"
|
||||
)
|
||||
category_id: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("categories.id"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="项目分类ID"
|
||||
)
|
||||
description: Mapped[Optional[str]] = mapped_column(
|
||||
Text,
|
||||
nullable=True,
|
||||
comment="项目描述"
|
||||
)
|
||||
duration_minutes: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=0,
|
||||
comment="操作时长(分钟)"
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
index=True,
|
||||
comment="是否启用"
|
||||
)
|
||||
created_by: Mapped[Optional[int]] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("users.id"),
|
||||
nullable=True,
|
||||
comment="创建人ID"
|
||||
)
|
||||
|
||||
# 关系
|
||||
category: Mapped[Optional["Category"]] = relationship(
|
||||
"Category",
|
||||
back_populates="projects"
|
||||
)
|
||||
creator: Mapped[Optional["User"]] = relationship(
|
||||
"User",
|
||||
back_populates="created_projects"
|
||||
)
|
||||
|
||||
# 成本相关关系
|
||||
cost_items: Mapped[List["ProjectCostItem"]] = relationship(
|
||||
"ProjectCostItem",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
labor_costs: Mapped[List["ProjectLaborCost"]] = relationship(
|
||||
"ProjectLaborCost",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
cost_summary: Mapped[Optional["ProjectCostSummary"]] = relationship(
|
||||
"ProjectCostSummary",
|
||||
back_populates="project",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
pricing_plans: Mapped[List["PricingPlan"]] = relationship(
|
||||
"PricingPlan",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
66
后端服务/app/models/project_cost_item.py
Normal file
66
后端服务/app/models/project_cost_item.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""项目成本明细模型
|
||||
|
||||
管理项目的耗材成本和设备折旧成本
|
||||
"""
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import BigInteger, String, DECIMAL, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.project import Project
|
||||
|
||||
|
||||
class ProjectCostItem(BaseModel):
|
||||
"""项目成本明细表(耗材/设备)"""
|
||||
|
||||
__tablename__ = "project_cost_items"
|
||||
|
||||
project_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="项目ID"
|
||||
)
|
||||
item_type: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="类型:material-耗材, equipment-设备"
|
||||
)
|
||||
item_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
comment="耗材/设备ID"
|
||||
)
|
||||
quantity: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(10, 4),
|
||||
nullable=False,
|
||||
comment="用量"
|
||||
)
|
||||
unit_cost: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 4),
|
||||
nullable=False,
|
||||
comment="单位成本"
|
||||
)
|
||||
total_cost: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="总成本 = quantity * unit_cost"
|
||||
)
|
||||
remark: Mapped[Optional[str]] = mapped_column(
|
||||
String(200),
|
||||
nullable=True,
|
||||
comment="备注"
|
||||
)
|
||||
|
||||
# 关系
|
||||
project: Mapped["Project"] = relationship(
|
||||
"Project",
|
||||
back_populates="cost_items"
|
||||
)
|
||||
72
后端服务/app/models/project_cost_summary.py
Normal file
72
后端服务/app/models/project_cost_summary.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""项目成本汇总模型
|
||||
|
||||
存储项目的成本计算结果
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import BigInteger, DateTime, DECIMAL, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.project import Project
|
||||
|
||||
|
||||
class ProjectCostSummary(BaseModel):
|
||||
"""项目成本汇总表"""
|
||||
|
||||
__tablename__ = "project_cost_summaries"
|
||||
|
||||
project_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True,
|
||||
comment="项目ID"
|
||||
)
|
||||
material_cost: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
default=0,
|
||||
comment="耗材成本"
|
||||
)
|
||||
equipment_cost: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
default=0,
|
||||
comment="设备折旧成本"
|
||||
)
|
||||
labor_cost: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
default=0,
|
||||
comment="人工成本"
|
||||
)
|
||||
fixed_cost_allocation: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
default=0,
|
||||
comment="固定成本分摊"
|
||||
)
|
||||
total_cost: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
default=0,
|
||||
comment="总成本(最低成本线)"
|
||||
)
|
||||
calculated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
comment="计算时间"
|
||||
)
|
||||
|
||||
# 关系
|
||||
project: Mapped["Project"] = relationship(
|
||||
"Project",
|
||||
back_populates="cost_summary"
|
||||
)
|
||||
67
后端服务/app/models/project_labor_cost.py
Normal file
67
后端服务/app/models/project_labor_cost.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""项目人工成本模型
|
||||
|
||||
管理项目的人工成本配置
|
||||
"""
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import BigInteger, Integer, String, DECIMAL, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.project import Project
|
||||
from app.models.staff_level import StaffLevel
|
||||
|
||||
|
||||
class ProjectLaborCost(BaseModel):
|
||||
"""项目人工成本表"""
|
||||
|
||||
__tablename__ = "project_labor_costs"
|
||||
|
||||
project_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="项目ID"
|
||||
)
|
||||
staff_level_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("staff_levels.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="人员级别ID"
|
||||
)
|
||||
duration_minutes: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
comment="操作时长(分钟)"
|
||||
)
|
||||
hourly_rate: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(10, 2),
|
||||
nullable=False,
|
||||
comment="时薪(记录时的快照)"
|
||||
)
|
||||
labor_cost: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="人工成本 = duration/60 * hourly_rate"
|
||||
)
|
||||
remark: Mapped[Optional[str]] = mapped_column(
|
||||
String(200),
|
||||
nullable=True,
|
||||
comment="备注"
|
||||
)
|
||||
|
||||
# 关系
|
||||
project: Mapped["Project"] = relationship(
|
||||
"Project",
|
||||
back_populates="labor_costs"
|
||||
)
|
||||
staff_level: Mapped["StaffLevel"] = relationship(
|
||||
"StaffLevel",
|
||||
back_populates="project_labor_costs"
|
||||
)
|
||||
55
后端服务/app/models/sensitivity_analysis.py
Normal file
55
后端服务/app/models/sensitivity_analysis.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""敏感性分析模型
|
||||
|
||||
记录价格变动对利润的敏感性分析结果
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from decimal import Decimal
|
||||
|
||||
from sqlalchemy import BigInteger, ForeignKey, DECIMAL
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.profit_simulation import ProfitSimulation
|
||||
|
||||
|
||||
class SensitivityAnalysis(BaseModel):
|
||||
"""敏感性分析表"""
|
||||
|
||||
__tablename__ = "sensitivity_analyses"
|
||||
|
||||
simulation_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
ForeignKey("profit_simulations.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="模拟ID"
|
||||
)
|
||||
price_change_rate: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(5, 2),
|
||||
nullable=False,
|
||||
comment="价格变动率(%):如 -20, -10, 0, 10, 20"
|
||||
)
|
||||
adjusted_price: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(12, 2),
|
||||
nullable=False,
|
||||
comment="调整后价格"
|
||||
)
|
||||
adjusted_profit: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(14, 2),
|
||||
nullable=False,
|
||||
comment="调整后利润"
|
||||
)
|
||||
profit_change_rate: Mapped[Decimal] = mapped_column(
|
||||
DECIMAL(5, 2),
|
||||
nullable=False,
|
||||
comment="利润变动率(%)"
|
||||
)
|
||||
|
||||
# 关系
|
||||
simulation: Mapped["ProfitSimulation"] = relationship(
|
||||
"ProfitSimulation",
|
||||
back_populates="sensitivity_analyses"
|
||||
)
|
||||
49
后端服务/app/models/staff_level.py
Normal file
49
后端服务/app/models/staff_level.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""人员级别模型
|
||||
|
||||
管理不同岗位/级别的时薪标准
|
||||
"""
|
||||
|
||||
from typing import List, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import String, Boolean, DECIMAL
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.project_labor_cost import ProjectLaborCost
|
||||
|
||||
|
||||
class StaffLevel(BaseModel):
|
||||
"""人员级别表"""
|
||||
|
||||
__tablename__ = "staff_levels"
|
||||
|
||||
level_code: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
comment="级别编码"
|
||||
)
|
||||
level_name: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
nullable=False,
|
||||
comment="级别名称"
|
||||
)
|
||||
hourly_rate: Mapped[float] = mapped_column(
|
||||
DECIMAL(10, 2),
|
||||
nullable=False,
|
||||
comment="时薪(元/小时)"
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
comment="是否启用"
|
||||
)
|
||||
|
||||
# 关系
|
||||
project_labor_costs: Mapped[List["ProjectLaborCost"]] = relationship(
|
||||
"ProjectLaborCost",
|
||||
back_populates="staff_level"
|
||||
)
|
||||
66
后端服务/app/models/user.py
Normal file
66
后端服务/app/models/user.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""用户模型
|
||||
|
||||
与门户系统关联的用户信息
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import BigInteger, String, Boolean
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""用户表"""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
portal_user_id: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="门户用户ID"
|
||||
)
|
||||
username: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
nullable=False,
|
||||
comment="用户名"
|
||||
)
|
||||
role: Mapped[str] = mapped_column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
comment="角色:admin-管理员, manager-经理, operator-操作员"
|
||||
)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
comment="是否启用"
|
||||
)
|
||||
|
||||
# 关系
|
||||
created_projects: Mapped[List["Project"]] = relationship(
|
||||
"Project",
|
||||
back_populates="creator"
|
||||
)
|
||||
operation_logs: Mapped[List["OperationLog"]] = relationship(
|
||||
"OperationLog",
|
||||
back_populates="user"
|
||||
)
|
||||
created_pricing_plans: Mapped[List["PricingPlan"]] = relationship(
|
||||
"PricingPlan",
|
||||
back_populates="creator"
|
||||
)
|
||||
created_profit_simulations: Mapped[List["ProfitSimulation"]] = relationship(
|
||||
"ProfitSimulation",
|
||||
back_populates="creator"
|
||||
)
|
||||
|
||||
|
||||
# 避免循环导入
|
||||
from app.models.project import Project
|
||||
from app.models.operation_log import OperationLog
|
||||
from app.models.pricing_plan import PricingPlan
|
||||
from app.models.profit_simulation import ProfitSimulation
|
||||
1
后端服务/app/repositories/__init__.py
Normal file
1
后端服务/app/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""数据访问层"""
|
||||
29
后端服务/app/routers/__init__.py
Normal file
29
后端服务/app/routers/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""API 路由模块"""
|
||||
|
||||
from app.routers import (
|
||||
health,
|
||||
categories,
|
||||
materials,
|
||||
equipments,
|
||||
staff_levels,
|
||||
fixed_costs,
|
||||
projects,
|
||||
market,
|
||||
pricing,
|
||||
profit,
|
||||
dashboard,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"health",
|
||||
"categories",
|
||||
"materials",
|
||||
"equipments",
|
||||
"staff_levels",
|
||||
"fixed_costs",
|
||||
"projects",
|
||||
"market",
|
||||
"pricing",
|
||||
"profit",
|
||||
"dashboard",
|
||||
]
|
||||
210
后端服务/app/routers/categories.py
Normal file
210
后端服务/app/routers/categories.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""项目分类路由
|
||||
|
||||
实现分类的 CRUD 操作,支持树形结构
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.category import Category
|
||||
from app.schemas.common import ResponseModel, PaginatedResponse, PaginatedData, ErrorCode
|
||||
from app.schemas.category import (
|
||||
CategoryCreate,
|
||||
CategoryUpdate,
|
||||
CategoryResponse,
|
||||
CategoryTreeResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse[CategoryResponse])
|
||||
async def get_categories(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
parent_id: Optional[int] = Query(None, description="父分类ID筛选"),
|
||||
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取项目分类列表"""
|
||||
# 构建查询
|
||||
query = select(Category)
|
||||
|
||||
if parent_id is not None:
|
||||
query = query.where(Category.parent_id == parent_id)
|
||||
if is_active is not None:
|
||||
query = query.where(Category.is_active == is_active)
|
||||
|
||||
query = query.order_by(Category.sort_order, Category.id)
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size)
|
||||
|
||||
result = await db.execute(query)
|
||||
categories = result.scalars().all()
|
||||
|
||||
# 统计总数
|
||||
count_query = select(func.count(Category.id))
|
||||
if parent_id is not None:
|
||||
count_query = count_query.where(Category.parent_id == parent_id)
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(Category.is_active == is_active)
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
return PaginatedResponse(
|
||||
data=PaginatedData(
|
||||
items=[CategoryResponse.model_validate(c) for c in categories],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tree", response_model=ResponseModel[List[CategoryTreeResponse]])
|
||||
async def get_category_tree(
|
||||
is_active: Optional[bool] = Query(True, description="是否只返回启用的分类"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取分类树形结构"""
|
||||
query = select(Category).options(selectinload(Category.children))
|
||||
|
||||
if is_active is not None:
|
||||
query = query.where(Category.is_active == is_active)
|
||||
|
||||
# 只获取顶级分类
|
||||
query = query.where(Category.parent_id.is_(None))
|
||||
query = query.order_by(Category.sort_order, Category.id)
|
||||
|
||||
result = await db.execute(query)
|
||||
categories = result.scalars().all()
|
||||
|
||||
return ResponseModel(data=[CategoryTreeResponse.model_validate(c) for c in categories])
|
||||
|
||||
|
||||
@router.get("/{category_id}", response_model=ResponseModel[CategoryResponse])
|
||||
async def get_category(
|
||||
category_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取单个分类详情"""
|
||||
result = await db.execute(select(Category).where(Category.id == category_id))
|
||||
category = result.scalar_one_or_none()
|
||||
|
||||
if not category:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
|
||||
)
|
||||
|
||||
return ResponseModel(data=CategoryResponse.model_validate(category))
|
||||
|
||||
|
||||
@router.post("", response_model=ResponseModel[CategoryResponse])
|
||||
async def create_category(
|
||||
data: CategoryCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建项目分类"""
|
||||
# 检查父分类是否存在
|
||||
if data.parent_id:
|
||||
parent_result = await db.execute(
|
||||
select(Category).where(Category.id == data.parent_id)
|
||||
)
|
||||
if not parent_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.PARAM_ERROR, "message": "父分类不存在"}
|
||||
)
|
||||
|
||||
# 创建分类
|
||||
category = Category(**data.model_dump())
|
||||
db.add(category)
|
||||
await db.flush()
|
||||
await db.refresh(category)
|
||||
|
||||
return ResponseModel(message="创建成功", data=CategoryResponse.model_validate(category))
|
||||
|
||||
|
||||
@router.put("/{category_id}", response_model=ResponseModel[CategoryResponse])
|
||||
async def update_category(
|
||||
category_id: int,
|
||||
data: CategoryUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新项目分类"""
|
||||
result = await db.execute(select(Category).where(Category.id == category_id))
|
||||
category = result.scalar_one_or_none()
|
||||
|
||||
if not category:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
|
||||
)
|
||||
|
||||
# 检查父分类
|
||||
if data.parent_id is not None and data.parent_id != category.parent_id:
|
||||
if data.parent_id == category_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.PARAM_ERROR, "message": "不能将自己设为父分类"}
|
||||
)
|
||||
parent_result = await db.execute(
|
||||
select(Category).where(Category.id == data.parent_id)
|
||||
)
|
||||
if not parent_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.PARAM_ERROR, "message": "父分类不存在"}
|
||||
)
|
||||
|
||||
# 更新字段
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(category, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(category)
|
||||
|
||||
return ResponseModel(message="更新成功", data=CategoryResponse.model_validate(category))
|
||||
|
||||
|
||||
@router.delete("/{category_id}", response_model=ResponseModel)
|
||||
async def delete_category(
|
||||
category_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除项目分类"""
|
||||
result = await db.execute(select(Category).where(Category.id == category_id))
|
||||
category = result.scalar_one_or_none()
|
||||
|
||||
if not category:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
|
||||
)
|
||||
|
||||
# 检查是否有子分类
|
||||
children_result = await db.execute(
|
||||
select(func.count(Category.id)).where(Category.parent_id == category_id)
|
||||
)
|
||||
children_count = children_result.scalar() or 0
|
||||
|
||||
if children_count > 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.NOT_ALLOWED, "message": "该分类下有子分类,无法删除"}
|
||||
)
|
||||
|
||||
await db.delete(category)
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
305
后端服务/app/routers/dashboard.py
Normal file
305
后端服务/app/routers/dashboard.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""仪表盘路由
|
||||
|
||||
仪表盘数据相关的 API 接口
|
||||
"""
|
||||
|
||||
from datetime import datetime, date, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import select, func, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import (
|
||||
Project,
|
||||
ProjectCostSummary,
|
||||
Competitor,
|
||||
CompetitorPrice,
|
||||
PricingPlan,
|
||||
ProfitSimulation,
|
||||
OperationLog,
|
||||
)
|
||||
from app.schemas.common import ResponseModel
|
||||
from app.schemas.dashboard import (
|
||||
DashboardSummaryResponse,
|
||||
ProjectOverview,
|
||||
CostOverview,
|
||||
CostProjectInfo,
|
||||
MarketOverview,
|
||||
PricingOverview,
|
||||
StrategiesDistribution,
|
||||
AIUsageOverview,
|
||||
RecentActivity,
|
||||
CostTrendResponse,
|
||||
MarketTrendResponse,
|
||||
TrendDataPoint,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/dashboard/summary", response_model=ResponseModel[DashboardSummaryResponse])
|
||||
async def get_dashboard_summary(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取仪表盘概览数据"""
|
||||
|
||||
# 项目概览
|
||||
total_projects_result = await db.execute(
|
||||
select(func.count(Project.id))
|
||||
)
|
||||
total_projects = total_projects_result.scalar() or 0
|
||||
|
||||
active_projects_result = await db.execute(
|
||||
select(func.count(Project.id)).where(Project.is_active == True)
|
||||
)
|
||||
active_projects = active_projects_result.scalar() or 0
|
||||
|
||||
projects_with_pricing_result = await db.execute(
|
||||
select(func.count(func.distinct(PricingPlan.project_id)))
|
||||
)
|
||||
projects_with_pricing = projects_with_pricing_result.scalar() or 0
|
||||
|
||||
project_overview = ProjectOverview(
|
||||
total_projects=total_projects,
|
||||
active_projects=active_projects,
|
||||
projects_with_pricing=projects_with_pricing,
|
||||
)
|
||||
|
||||
# 成本概览
|
||||
avg_cost_result = await db.execute(
|
||||
select(func.avg(ProjectCostSummary.total_cost))
|
||||
)
|
||||
avg_project_cost = float(avg_cost_result.scalar() or 0)
|
||||
|
||||
# 最高成本项目
|
||||
highest_cost_result = await db.execute(
|
||||
select(ProjectCostSummary).options(
|
||||
).order_by(ProjectCostSummary.total_cost.desc()).limit(1)
|
||||
)
|
||||
highest_cost_summary = highest_cost_result.scalar_one_or_none()
|
||||
highest_cost_project = None
|
||||
if highest_cost_summary:
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == highest_cost_summary.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if project:
|
||||
highest_cost_project = CostProjectInfo(
|
||||
id=project.id,
|
||||
name=project.project_name,
|
||||
cost=float(highest_cost_summary.total_cost),
|
||||
)
|
||||
|
||||
# 最低成本项目
|
||||
lowest_cost_result = await db.execute(
|
||||
select(ProjectCostSummary).where(
|
||||
ProjectCostSummary.total_cost > 0
|
||||
).order_by(ProjectCostSummary.total_cost.asc()).limit(1)
|
||||
)
|
||||
lowest_cost_summary = lowest_cost_result.scalar_one_or_none()
|
||||
lowest_cost_project = None
|
||||
if lowest_cost_summary:
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == lowest_cost_summary.project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
if project:
|
||||
lowest_cost_project = CostProjectInfo(
|
||||
id=project.id,
|
||||
name=project.project_name,
|
||||
cost=float(lowest_cost_summary.total_cost),
|
||||
)
|
||||
|
||||
cost_overview = CostOverview(
|
||||
avg_project_cost=round(avg_project_cost, 2),
|
||||
highest_cost_project=highest_cost_project,
|
||||
lowest_cost_project=lowest_cost_project,
|
||||
)
|
||||
|
||||
# 市场概览
|
||||
competitors_result = await db.execute(
|
||||
select(func.count(Competitor.id)).where(Competitor.is_active == True)
|
||||
)
|
||||
competitors_tracked = competitors_result.scalar() or 0
|
||||
|
||||
# 本月价格记录数
|
||||
this_month_start = date.today().replace(day=1)
|
||||
price_records_result = await db.execute(
|
||||
select(func.count(CompetitorPrice.id)).where(
|
||||
CompetitorPrice.collected_at >= this_month_start
|
||||
)
|
||||
)
|
||||
price_records_this_month = price_records_result.scalar() or 0
|
||||
|
||||
# 市场平均价
|
||||
avg_market_price_result = await db.execute(
|
||||
select(func.avg(CompetitorPrice.original_price))
|
||||
)
|
||||
avg_market_price = avg_market_price_result.scalar()
|
||||
|
||||
market_overview = MarketOverview(
|
||||
competitors_tracked=competitors_tracked,
|
||||
price_records_this_month=price_records_this_month,
|
||||
avg_market_price=float(avg_market_price) if avg_market_price else None,
|
||||
)
|
||||
|
||||
# 定价概览
|
||||
pricing_plans_result = await db.execute(
|
||||
select(func.count(PricingPlan.id))
|
||||
)
|
||||
pricing_plans_count = pricing_plans_result.scalar() or 0
|
||||
|
||||
avg_margin_result = await db.execute(
|
||||
select(func.avg(PricingPlan.target_margin))
|
||||
)
|
||||
avg_target_margin = avg_margin_result.scalar()
|
||||
|
||||
# 策略分布
|
||||
traffic_count_result = await db.execute(
|
||||
select(func.count(PricingPlan.id)).where(PricingPlan.strategy_type == "traffic")
|
||||
)
|
||||
profit_count_result = await db.execute(
|
||||
select(func.count(PricingPlan.id)).where(PricingPlan.strategy_type == "profit")
|
||||
)
|
||||
premium_count_result = await db.execute(
|
||||
select(func.count(PricingPlan.id)).where(PricingPlan.strategy_type == "premium")
|
||||
)
|
||||
|
||||
pricing_overview = PricingOverview(
|
||||
pricing_plans_count=pricing_plans_count,
|
||||
avg_target_margin=float(avg_target_margin) if avg_target_margin else None,
|
||||
strategies_distribution=StrategiesDistribution(
|
||||
traffic=traffic_count_result.scalar() or 0,
|
||||
profit=profit_count_result.scalar() or 0,
|
||||
premium=premium_count_result.scalar() or 0,
|
||||
),
|
||||
)
|
||||
|
||||
# 最近活动(从操作日志获取)
|
||||
recent_logs_result = await db.execute(
|
||||
select(OperationLog).order_by(
|
||||
OperationLog.created_at.desc()
|
||||
).limit(10)
|
||||
)
|
||||
recent_logs = recent_logs_result.scalars().all()
|
||||
|
||||
recent_activities = []
|
||||
for log in recent_logs:
|
||||
recent_activities.append(RecentActivity(
|
||||
type=f"{log.module}_{log.action}",
|
||||
project_name=log.target_type,
|
||||
user=None, # 简化处理
|
||||
time=log.created_at,
|
||||
))
|
||||
|
||||
return ResponseModel(data=DashboardSummaryResponse(
|
||||
project_overview=project_overview,
|
||||
cost_overview=cost_overview,
|
||||
market_overview=market_overview,
|
||||
pricing_overview=pricing_overview,
|
||||
ai_usage_this_month=None, # AI 使用统计需要从 ai_call_logs 表获取
|
||||
recent_activities=recent_activities,
|
||||
))
|
||||
|
||||
|
||||
@router.get("/dashboard/cost-trend", response_model=ResponseModel[CostTrendResponse])
|
||||
async def get_cost_trend(
|
||||
period: str = Query("month", description="统计周期:week/month/quarter"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取成本趋势数据"""
|
||||
# 根据周期确定时间范围
|
||||
today = date.today()
|
||||
if period == "week":
|
||||
start_date = today - timedelta(days=7)
|
||||
elif period == "quarter":
|
||||
start_date = today - timedelta(days=90)
|
||||
else: # month
|
||||
start_date = today - timedelta(days=30)
|
||||
|
||||
# 按日期分组统计平均成本
|
||||
# 简化实现:返回最近的成本汇总数据
|
||||
result = await db.execute(
|
||||
select(ProjectCostSummary).order_by(
|
||||
ProjectCostSummary.calculated_at.desc()
|
||||
).limit(30)
|
||||
)
|
||||
summaries = result.scalars().all()
|
||||
|
||||
# 按日期聚合
|
||||
date_costs = {}
|
||||
for summary in summaries:
|
||||
# 检查 calculated_at 是否为 None
|
||||
if summary.calculated_at is None:
|
||||
continue
|
||||
day = summary.calculated_at.strftime("%Y-%m-%d")
|
||||
if day not in date_costs:
|
||||
date_costs[day] = []
|
||||
date_costs[day].append(float(summary.total_cost))
|
||||
|
||||
data = []
|
||||
total_cost = 0
|
||||
for day in sorted(date_costs.keys()):
|
||||
avg = sum(date_costs[day]) / len(date_costs[day])
|
||||
data.append(TrendDataPoint(date=day, value=round(avg, 2)))
|
||||
total_cost += avg
|
||||
|
||||
avg_cost = total_cost / len(data) if data else 0
|
||||
|
||||
return ResponseModel(data=CostTrendResponse(
|
||||
period=period,
|
||||
data=data,
|
||||
avg_cost=round(avg_cost, 2),
|
||||
))
|
||||
|
||||
|
||||
@router.get("/dashboard/market-trend", response_model=ResponseModel[MarketTrendResponse])
|
||||
async def get_market_trend(
|
||||
period: str = Query("month", description="统计周期:week/month/quarter"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取市场价格趋势数据"""
|
||||
# 根据周期确定时间范围
|
||||
today = date.today()
|
||||
if period == "week":
|
||||
start_date = today - timedelta(days=7)
|
||||
elif period == "quarter":
|
||||
start_date = today - timedelta(days=90)
|
||||
else: # month
|
||||
start_date = today - timedelta(days=30)
|
||||
|
||||
# 获取价格记录
|
||||
result = await db.execute(
|
||||
select(CompetitorPrice).where(
|
||||
CompetitorPrice.collected_at >= start_date
|
||||
).order_by(CompetitorPrice.collected_at.desc())
|
||||
)
|
||||
prices = result.scalars().all()
|
||||
|
||||
# 按日期聚合
|
||||
date_prices = {}
|
||||
for price in prices:
|
||||
# 检查 collected_at 是否为 None
|
||||
if price.collected_at is None:
|
||||
continue
|
||||
day = price.collected_at.strftime("%Y-%m-%d")
|
||||
if day not in date_prices:
|
||||
date_prices[day] = []
|
||||
date_prices[day].append(float(price.original_price))
|
||||
|
||||
data = []
|
||||
total_price = 0
|
||||
for day in sorted(date_prices.keys()):
|
||||
avg = sum(date_prices[day]) / len(date_prices[day])
|
||||
data.append(TrendDataPoint(date=day, value=round(avg, 2)))
|
||||
total_price += avg
|
||||
|
||||
avg_price = total_price / len(data) if data else 0
|
||||
|
||||
return ResponseModel(data=MarketTrendResponse(
|
||||
period=period,
|
||||
data=data,
|
||||
avg_price=round(avg_price, 2),
|
||||
))
|
||||
188
后端服务/app/routers/equipments.py
Normal file
188
后端服务/app/routers/equipments.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""设备管理路由
|
||||
|
||||
实现设备的 CRUD 操作,包含折旧计算
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.equipment import Equipment
|
||||
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
|
||||
from app.schemas.equipment import (
|
||||
EquipmentCreate,
|
||||
EquipmentUpdate,
|
||||
EquipmentResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel[PaginatedData[EquipmentResponse]])
|
||||
async def get_equipments(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
keyword: Optional[str] = Query(None, description="关键词搜索"),
|
||||
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取设备列表"""
|
||||
query = select(Equipment)
|
||||
|
||||
if keyword:
|
||||
query = query.where(
|
||||
or_(
|
||||
Equipment.equipment_code.contains(keyword),
|
||||
Equipment.equipment_name.contains(keyword),
|
||||
)
|
||||
)
|
||||
if is_active is not None:
|
||||
query = query.where(Equipment.is_active == is_active)
|
||||
|
||||
query = query.order_by(Equipment.id.desc())
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size)
|
||||
|
||||
result = await db.execute(query)
|
||||
equipments = result.scalars().all()
|
||||
|
||||
# 统计总数
|
||||
count_query = select(func.count(Equipment.id))
|
||||
if keyword:
|
||||
count_query = count_query.where(
|
||||
or_(
|
||||
Equipment.equipment_code.contains(keyword),
|
||||
Equipment.equipment_name.contains(keyword),
|
||||
)
|
||||
)
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(Equipment.is_active == is_active)
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
return ResponseModel(
|
||||
data=PaginatedData(
|
||||
items=[EquipmentResponse.model_validate(e) for e in equipments],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{equipment_id}", response_model=ResponseModel[EquipmentResponse])
|
||||
async def get_equipment(
|
||||
equipment_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取单个设备详情"""
|
||||
result = await db.execute(select(Equipment).where(Equipment.id == equipment_id))
|
||||
equipment = result.scalar_one_or_none()
|
||||
|
||||
if not equipment:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "设备不存在"}
|
||||
)
|
||||
|
||||
return ResponseModel(data=EquipmentResponse.model_validate(equipment))
|
||||
|
||||
|
||||
@router.post("", response_model=ResponseModel[EquipmentResponse])
|
||||
async def create_equipment(
|
||||
data: EquipmentCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建设备"""
|
||||
# 检查编码是否已存在
|
||||
existing = await db.execute(
|
||||
select(Equipment).where(Equipment.equipment_code == data.equipment_code)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "设备编码已存在"}
|
||||
)
|
||||
|
||||
# 计算单次折旧成本
|
||||
residual_value = data.original_value * data.residual_rate / 100
|
||||
depreciation_per_use = (data.original_value - residual_value) / data.estimated_uses
|
||||
|
||||
equipment = Equipment(
|
||||
**data.model_dump(),
|
||||
depreciation_per_use=depreciation_per_use,
|
||||
)
|
||||
db.add(equipment)
|
||||
await db.flush()
|
||||
await db.refresh(equipment)
|
||||
|
||||
return ResponseModel(message="创建成功", data=EquipmentResponse.model_validate(equipment))
|
||||
|
||||
|
||||
@router.put("/{equipment_id}", response_model=ResponseModel[EquipmentResponse])
|
||||
async def update_equipment(
|
||||
equipment_id: int,
|
||||
data: EquipmentUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新设备"""
|
||||
result = await db.execute(select(Equipment).where(Equipment.id == equipment_id))
|
||||
equipment = result.scalar_one_or_none()
|
||||
|
||||
if not equipment:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "设备不存在"}
|
||||
)
|
||||
|
||||
# 检查编码是否重复
|
||||
if data.equipment_code and data.equipment_code != equipment.equipment_code:
|
||||
existing = await db.execute(
|
||||
select(Equipment).where(Equipment.equipment_code == data.equipment_code)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "设备编码已存在"}
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(equipment, field, value)
|
||||
|
||||
# 重新计算折旧(如果相关字段更新了)
|
||||
if any(f in update_data for f in ['original_value', 'residual_rate', 'estimated_uses']):
|
||||
residual_value = float(equipment.original_value) * float(equipment.residual_rate) / 100
|
||||
equipment.depreciation_per_use = (float(equipment.original_value) - residual_value) / equipment.estimated_uses
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(equipment)
|
||||
|
||||
return ResponseModel(message="更新成功", data=EquipmentResponse.model_validate(equipment))
|
||||
|
||||
|
||||
@router.delete("/{equipment_id}", response_model=ResponseModel)
|
||||
async def delete_equipment(
|
||||
equipment_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除设备"""
|
||||
result = await db.execute(select(Equipment).where(Equipment.id == equipment_id))
|
||||
equipment = result.scalar_one_or_none()
|
||||
|
||||
if not equipment:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "设备不存在"}
|
||||
)
|
||||
|
||||
await db.delete(equipment)
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
187
后端服务/app/routers/fixed_costs.py
Normal file
187
后端服务/app/routers/fixed_costs.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""固定成本路由
|
||||
|
||||
实现固定成本的 CRUD 操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.fixed_cost import FixedCost
|
||||
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
|
||||
from app.schemas.fixed_cost import (
|
||||
FixedCostCreate,
|
||||
FixedCostUpdate,
|
||||
FixedCostResponse,
|
||||
CostType,
|
||||
AllocationMethod,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel[PaginatedData[FixedCostResponse]])
|
||||
async def get_fixed_costs(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
year_month: Optional[str] = Query(None, description="年月筛选"),
|
||||
cost_type: Optional[CostType] = Query(None, description="类型筛选"),
|
||||
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取固定成本列表"""
|
||||
query = select(FixedCost)
|
||||
|
||||
if year_month:
|
||||
query = query.where(FixedCost.year_month == year_month)
|
||||
if cost_type:
|
||||
query = query.where(FixedCost.cost_type == cost_type.value)
|
||||
if is_active is not None:
|
||||
query = query.where(FixedCost.is_active == is_active)
|
||||
|
||||
query = query.order_by(FixedCost.year_month.desc(), FixedCost.id)
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size)
|
||||
|
||||
result = await db.execute(query)
|
||||
fixed_costs = result.scalars().all()
|
||||
|
||||
# 统计总数
|
||||
count_query = select(func.count(FixedCost.id))
|
||||
if year_month:
|
||||
count_query = count_query.where(FixedCost.year_month == year_month)
|
||||
if cost_type:
|
||||
count_query = count_query.where(FixedCost.cost_type == cost_type.value)
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(FixedCost.is_active == is_active)
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
return ResponseModel(
|
||||
data=PaginatedData(
|
||||
items=[FixedCostResponse.model_validate(f) for f in fixed_costs],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/summary", response_model=ResponseModel)
|
||||
async def get_fixed_costs_summary(
|
||||
year_month: str = Query(..., description="年月"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取指定月份的固定成本汇总"""
|
||||
query = select(FixedCost).where(
|
||||
FixedCost.year_month == year_month,
|
||||
FixedCost.is_active == True,
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
fixed_costs = result.scalars().all()
|
||||
|
||||
# 按类型汇总
|
||||
summary_by_type = {}
|
||||
total_amount = 0
|
||||
|
||||
for cost in fixed_costs:
|
||||
cost_type = cost.cost_type
|
||||
if cost_type not in summary_by_type:
|
||||
summary_by_type[cost_type] = 0
|
||||
summary_by_type[cost_type] += float(cost.monthly_amount)
|
||||
total_amount += float(cost.monthly_amount)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"year_month": year_month,
|
||||
"total_amount": total_amount,
|
||||
"by_type": summary_by_type,
|
||||
"count": len(fixed_costs),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{fixed_cost_id}", response_model=ResponseModel[FixedCostResponse])
|
||||
async def get_fixed_cost(
|
||||
fixed_cost_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取单个固定成本详情"""
|
||||
result = await db.execute(select(FixedCost).where(FixedCost.id == fixed_cost_id))
|
||||
fixed_cost = result.scalar_one_or_none()
|
||||
|
||||
if not fixed_cost:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "固定成本不存在"}
|
||||
)
|
||||
|
||||
return ResponseModel(data=FixedCostResponse.model_validate(fixed_cost))
|
||||
|
||||
|
||||
@router.post("", response_model=ResponseModel[FixedCostResponse])
|
||||
async def create_fixed_cost(
|
||||
data: FixedCostCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建固定成本"""
|
||||
fixed_cost = FixedCost(**data.model_dump())
|
||||
db.add(fixed_cost)
|
||||
await db.flush()
|
||||
await db.refresh(fixed_cost)
|
||||
|
||||
return ResponseModel(message="创建成功", data=FixedCostResponse.model_validate(fixed_cost))
|
||||
|
||||
|
||||
@router.put("/{fixed_cost_id}", response_model=ResponseModel[FixedCostResponse])
|
||||
async def update_fixed_cost(
|
||||
fixed_cost_id: int,
|
||||
data: FixedCostUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新固定成本"""
|
||||
result = await db.execute(select(FixedCost).where(FixedCost.id == fixed_cost_id))
|
||||
fixed_cost = result.scalar_one_or_none()
|
||||
|
||||
if not fixed_cost:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "固定成本不存在"}
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(fixed_cost, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(fixed_cost)
|
||||
|
||||
return ResponseModel(message="更新成功", data=FixedCostResponse.model_validate(fixed_cost))
|
||||
|
||||
|
||||
@router.delete("/{fixed_cost_id}", response_model=ResponseModel)
|
||||
async def delete_fixed_cost(
|
||||
fixed_cost_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除固定成本"""
|
||||
result = await db.execute(select(FixedCost).where(FixedCost.id == fixed_cost_id))
|
||||
fixed_cost = result.scalar_one_or_none()
|
||||
|
||||
if not fixed_cost:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "固定成本不存在"}
|
||||
)
|
||||
|
||||
await db.delete(fixed_cost)
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
44
后端服务/app/routers/health.py
Normal file
44
后端服务/app/routers/health.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""健康检查路由
|
||||
|
||||
提供 /health 端点用于 Docker 健康检查
|
||||
遵循瑞小美部署规范:30s interval, 10s timeout, 3 retries
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check(db: AsyncSession = Depends(get_db)):
|
||||
"""健康检查端点
|
||||
|
||||
检查内容:
|
||||
- 应用运行状态
|
||||
- 数据库连接状态
|
||||
- 当前时间戳
|
||||
"""
|
||||
# 检查数据库连接
|
||||
db_status = "connected"
|
||||
try:
|
||||
await db.execute(text("SELECT 1"))
|
||||
except Exception as e:
|
||||
db_status = f"error: {str(e)}"
|
||||
|
||||
return {
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"status": "healthy" if db_status == "connected" else "unhealthy",
|
||||
"version": settings.APP_VERSION,
|
||||
"database": db_status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
733
后端服务/app/routers/market.py
Normal file
733
后端服务/app/routers/market.py
Normal file
@@ -0,0 +1,733 @@
|
||||
"""市场行情管理路由
|
||||
|
||||
实现竞品机构、竞品价格、标杆价格和市场分析
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import (
|
||||
Project,
|
||||
Competitor,
|
||||
CompetitorPrice,
|
||||
BenchmarkPrice,
|
||||
MarketAnalysisResult,
|
||||
Category,
|
||||
)
|
||||
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
|
||||
from app.schemas.competitor import (
|
||||
CompetitorCreate,
|
||||
CompetitorUpdate,
|
||||
CompetitorResponse,
|
||||
CompetitorPriceCreate,
|
||||
CompetitorPriceUpdate,
|
||||
CompetitorPriceResponse,
|
||||
Positioning,
|
||||
)
|
||||
from app.schemas.market import (
|
||||
BenchmarkPriceCreate,
|
||||
BenchmarkPriceUpdate,
|
||||
BenchmarkPriceResponse,
|
||||
MarketAnalysisRequest,
|
||||
MarketAnalysisResult as MarketAnalysisResultSchema,
|
||||
MarketAnalysisResponse,
|
||||
)
|
||||
from app.services.market_service import MarketService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============ 竞品机构 CRUD ============
|
||||
|
||||
@router.get("/competitors", response_model=ResponseModel[PaginatedData[CompetitorResponse]])
|
||||
async def get_competitors(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
positioning: Optional[Positioning] = Query(None, description="定位筛选"),
|
||||
is_key_competitor: Optional[bool] = Query(None, description="是否重点关注"),
|
||||
keyword: Optional[str] = Query(None, description="关键词搜索"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取竞品机构列表"""
|
||||
query = select(Competitor).options(selectinload(Competitor.prices))
|
||||
|
||||
if positioning:
|
||||
query = query.where(Competitor.positioning == positioning.value)
|
||||
if is_key_competitor is not None:
|
||||
query = query.where(Competitor.is_key_competitor == is_key_competitor)
|
||||
if keyword:
|
||||
query = query.where(
|
||||
or_(
|
||||
Competitor.competitor_name.contains(keyword),
|
||||
Competitor.address.contains(keyword),
|
||||
)
|
||||
)
|
||||
|
||||
query = query.order_by(Competitor.is_key_competitor.desc(), Competitor.id.desc())
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size)
|
||||
|
||||
result = await db.execute(query)
|
||||
competitors = result.scalars().all()
|
||||
|
||||
# 统计总数
|
||||
count_query = select(func.count(Competitor.id))
|
||||
if positioning:
|
||||
count_query = count_query.where(Competitor.positioning == positioning.value)
|
||||
if is_key_competitor is not None:
|
||||
count_query = count_query.where(Competitor.is_key_competitor == is_key_competitor)
|
||||
if keyword:
|
||||
count_query = count_query.where(
|
||||
or_(
|
||||
Competitor.competitor_name.contains(keyword),
|
||||
Competitor.address.contains(keyword),
|
||||
)
|
||||
)
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for c in competitors:
|
||||
# 获取最新价格更新日期
|
||||
last_price_update = None
|
||||
if c.prices:
|
||||
last_price_update = max(p.collected_at for p in c.prices)
|
||||
|
||||
items.append(CompetitorResponse(
|
||||
id=c.id,
|
||||
competitor_name=c.competitor_name,
|
||||
address=c.address,
|
||||
distance_km=float(c.distance_km) if c.distance_km else None,
|
||||
positioning=Positioning(c.positioning),
|
||||
contact=c.contact,
|
||||
is_key_competitor=c.is_key_competitor,
|
||||
is_active=c.is_active,
|
||||
price_count=len(c.prices),
|
||||
last_price_update=last_price_update,
|
||||
created_at=c.created_at,
|
||||
updated_at=c.updated_at,
|
||||
))
|
||||
|
||||
return ResponseModel(
|
||||
data=PaginatedData(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/competitors/{competitor_id}", response_model=ResponseModel[CompetitorResponse])
|
||||
async def get_competitor(
|
||||
competitor_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取竞品机构详情"""
|
||||
result = await db.execute(
|
||||
select(Competitor).options(
|
||||
selectinload(Competitor.prices)
|
||||
).where(Competitor.id == competitor_id)
|
||||
)
|
||||
competitor = result.scalar_one_or_none()
|
||||
|
||||
if not competitor:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品机构不存在"}
|
||||
)
|
||||
|
||||
last_price_update = None
|
||||
if competitor.prices:
|
||||
last_price_update = max(p.collected_at for p in competitor.prices)
|
||||
|
||||
return ResponseModel(
|
||||
data=CompetitorResponse(
|
||||
id=competitor.id,
|
||||
competitor_name=competitor.competitor_name,
|
||||
address=competitor.address,
|
||||
distance_km=float(competitor.distance_km) if competitor.distance_km else None,
|
||||
positioning=Positioning(competitor.positioning),
|
||||
contact=competitor.contact,
|
||||
is_key_competitor=competitor.is_key_competitor,
|
||||
is_active=competitor.is_active,
|
||||
price_count=len(competitor.prices),
|
||||
last_price_update=last_price_update,
|
||||
created_at=competitor.created_at,
|
||||
updated_at=competitor.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/competitors", response_model=ResponseModel[CompetitorResponse])
|
||||
async def create_competitor(
|
||||
data: CompetitorCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建竞品机构"""
|
||||
competitor = Competitor(
|
||||
competitor_name=data.competitor_name,
|
||||
address=data.address,
|
||||
distance_km=data.distance_km,
|
||||
positioning=data.positioning.value,
|
||||
contact=data.contact,
|
||||
is_key_competitor=data.is_key_competitor,
|
||||
is_active=data.is_active,
|
||||
)
|
||||
db.add(competitor)
|
||||
await db.flush()
|
||||
await db.refresh(competitor)
|
||||
|
||||
return ResponseModel(
|
||||
message="创建成功",
|
||||
data=CompetitorResponse(
|
||||
id=competitor.id,
|
||||
competitor_name=competitor.competitor_name,
|
||||
address=competitor.address,
|
||||
distance_km=float(competitor.distance_km) if competitor.distance_km else None,
|
||||
positioning=Positioning(competitor.positioning),
|
||||
contact=competitor.contact,
|
||||
is_key_competitor=competitor.is_key_competitor,
|
||||
is_active=competitor.is_active,
|
||||
price_count=0,
|
||||
last_price_update=None,
|
||||
created_at=competitor.created_at,
|
||||
updated_at=competitor.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.put("/competitors/{competitor_id}", response_model=ResponseModel[CompetitorResponse])
|
||||
async def update_competitor(
|
||||
competitor_id: int,
|
||||
data: CompetitorUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新竞品机构"""
|
||||
result = await db.execute(
|
||||
select(Competitor).options(
|
||||
selectinload(Competitor.prices)
|
||||
).where(Competitor.id == competitor_id)
|
||||
)
|
||||
competitor = result.scalar_one_or_none()
|
||||
|
||||
if not competitor:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品机构不存在"}
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
if field == "positioning" and value:
|
||||
value = value.value
|
||||
setattr(competitor, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(competitor)
|
||||
|
||||
last_price_update = None
|
||||
if competitor.prices:
|
||||
last_price_update = max(p.collected_at for p in competitor.prices)
|
||||
|
||||
return ResponseModel(
|
||||
message="更新成功",
|
||||
data=CompetitorResponse(
|
||||
id=competitor.id,
|
||||
competitor_name=competitor.competitor_name,
|
||||
address=competitor.address,
|
||||
distance_km=float(competitor.distance_km) if competitor.distance_km else None,
|
||||
positioning=Positioning(competitor.positioning),
|
||||
contact=competitor.contact,
|
||||
is_key_competitor=competitor.is_key_competitor,
|
||||
is_active=competitor.is_active,
|
||||
price_count=len(competitor.prices),
|
||||
last_price_update=last_price_update,
|
||||
created_at=competitor.created_at,
|
||||
updated_at=competitor.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/competitors/{competitor_id}", response_model=ResponseModel)
|
||||
async def delete_competitor(
|
||||
competitor_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除竞品机构"""
|
||||
result = await db.execute(
|
||||
select(Competitor).where(Competitor.id == competitor_id)
|
||||
)
|
||||
competitor = result.scalar_one_or_none()
|
||||
|
||||
if not competitor:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品机构不存在"}
|
||||
)
|
||||
|
||||
await db.delete(competitor)
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
|
||||
|
||||
# ============ 竞品价格管理 ============
|
||||
|
||||
@router.get("/competitors/{competitor_id}/prices", response_model=ResponseModel[list[CompetitorPriceResponse]])
|
||||
async def get_competitor_prices(
|
||||
competitor_id: int,
|
||||
project_id: Optional[int] = Query(None, description="项目筛选"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取竞品价格列表"""
|
||||
# 检查竞品机构是否存在
|
||||
competitor_result = await db.execute(
|
||||
select(Competitor).where(Competitor.id == competitor_id)
|
||||
)
|
||||
competitor = competitor_result.scalar_one_or_none()
|
||||
|
||||
if not competitor:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品机构不存在"}
|
||||
)
|
||||
|
||||
query = select(CompetitorPrice).where(
|
||||
CompetitorPrice.competitor_id == competitor_id
|
||||
)
|
||||
|
||||
if project_id:
|
||||
query = query.where(CompetitorPrice.project_id == project_id)
|
||||
|
||||
query = query.order_by(CompetitorPrice.collected_at.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
prices = result.scalars().all()
|
||||
|
||||
response_items = []
|
||||
for p in prices:
|
||||
response_items.append(CompetitorPriceResponse(
|
||||
id=p.id,
|
||||
competitor_id=p.competitor_id,
|
||||
competitor_name=competitor.competitor_name,
|
||||
project_id=p.project_id,
|
||||
project_name=p.project_name,
|
||||
original_price=float(p.original_price),
|
||||
promo_price=float(p.promo_price) if p.promo_price else None,
|
||||
member_price=float(p.member_price) if p.member_price else None,
|
||||
price_source=p.price_source,
|
||||
collected_at=p.collected_at,
|
||||
remark=p.remark,
|
||||
created_at=p.created_at,
|
||||
updated_at=p.updated_at,
|
||||
))
|
||||
|
||||
return ResponseModel(data=response_items)
|
||||
|
||||
|
||||
@router.post("/competitors/{competitor_id}/prices", response_model=ResponseModel[CompetitorPriceResponse])
|
||||
async def create_competitor_price(
|
||||
competitor_id: int,
|
||||
data: CompetitorPriceCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""添加竞品价格"""
|
||||
# 检查竞品机构是否存在
|
||||
competitor_result = await db.execute(
|
||||
select(Competitor).where(Competitor.id == competitor_id)
|
||||
)
|
||||
competitor = competitor_result.scalar_one_or_none()
|
||||
|
||||
if not competitor:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品机构不存在"}
|
||||
)
|
||||
|
||||
# 检查关联项目是否存在
|
||||
if data.project_id:
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == data.project_id)
|
||||
)
|
||||
if not project_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "关联项目不存在"}
|
||||
)
|
||||
|
||||
price = CompetitorPrice(
|
||||
competitor_id=competitor_id,
|
||||
project_id=data.project_id,
|
||||
project_name=data.project_name,
|
||||
original_price=data.original_price,
|
||||
promo_price=data.promo_price,
|
||||
member_price=data.member_price,
|
||||
price_source=data.price_source.value,
|
||||
collected_at=data.collected_at,
|
||||
remark=data.remark,
|
||||
)
|
||||
db.add(price)
|
||||
await db.flush()
|
||||
await db.refresh(price)
|
||||
|
||||
return ResponseModel(
|
||||
message="添加成功",
|
||||
data=CompetitorPriceResponse(
|
||||
id=price.id,
|
||||
competitor_id=price.competitor_id,
|
||||
competitor_name=competitor.competitor_name,
|
||||
project_id=price.project_id,
|
||||
project_name=price.project_name,
|
||||
original_price=float(price.original_price),
|
||||
promo_price=float(price.promo_price) if price.promo_price else None,
|
||||
member_price=float(price.member_price) if price.member_price else None,
|
||||
price_source=price.price_source,
|
||||
collected_at=price.collected_at,
|
||||
remark=price.remark,
|
||||
created_at=price.created_at,
|
||||
updated_at=price.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.put("/competitor-prices/{price_id}", response_model=ResponseModel[CompetitorPriceResponse])
|
||||
async def update_competitor_price(
|
||||
price_id: int,
|
||||
data: CompetitorPriceUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新竞品价格"""
|
||||
result = await db.execute(
|
||||
select(CompetitorPrice).options(
|
||||
selectinload(CompetitorPrice.competitor)
|
||||
).where(CompetitorPrice.id == price_id)
|
||||
)
|
||||
price = result.scalar_one_or_none()
|
||||
|
||||
if not price:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品价格不存在"}
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
if field == "price_source" and value:
|
||||
value = value.value
|
||||
setattr(price, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(price)
|
||||
|
||||
return ResponseModel(
|
||||
message="更新成功",
|
||||
data=CompetitorPriceResponse(
|
||||
id=price.id,
|
||||
competitor_id=price.competitor_id,
|
||||
competitor_name=price.competitor.competitor_name if price.competitor else None,
|
||||
project_id=price.project_id,
|
||||
project_name=price.project_name,
|
||||
original_price=float(price.original_price),
|
||||
promo_price=float(price.promo_price) if price.promo_price else None,
|
||||
member_price=float(price.member_price) if price.member_price else None,
|
||||
price_source=price.price_source,
|
||||
collected_at=price.collected_at,
|
||||
remark=price.remark,
|
||||
created_at=price.created_at,
|
||||
updated_at=price.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/competitor-prices/{price_id}", response_model=ResponseModel)
|
||||
async def delete_competitor_price(
|
||||
price_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除竞品价格"""
|
||||
result = await db.execute(
|
||||
select(CompetitorPrice).where(CompetitorPrice.id == price_id)
|
||||
)
|
||||
price = result.scalar_one_or_none()
|
||||
|
||||
if not price:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品价格不存在"}
|
||||
)
|
||||
|
||||
await db.delete(price)
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
|
||||
|
||||
# ============ 标杆价格管理 ============
|
||||
|
||||
@router.get("/benchmark-prices", response_model=ResponseModel[PaginatedData[BenchmarkPriceResponse]])
|
||||
async def get_benchmark_prices(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
category_id: Optional[int] = Query(None, description="分类筛选"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取标杆价格列表"""
|
||||
query = select(BenchmarkPrice).options(selectinload(BenchmarkPrice.category))
|
||||
|
||||
if category_id:
|
||||
query = query.where(BenchmarkPrice.category_id == category_id)
|
||||
|
||||
query = query.order_by(BenchmarkPrice.effective_date.desc())
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size)
|
||||
|
||||
result = await db.execute(query)
|
||||
benchmarks = result.scalars().all()
|
||||
|
||||
# 统计总数
|
||||
count_query = select(func.count(BenchmarkPrice.id))
|
||||
if category_id:
|
||||
count_query = count_query.where(BenchmarkPrice.category_id == category_id)
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
items = []
|
||||
for b in benchmarks:
|
||||
items.append(BenchmarkPriceResponse(
|
||||
id=b.id,
|
||||
benchmark_name=b.benchmark_name,
|
||||
category_id=b.category_id,
|
||||
category_name=b.category.category_name if b.category else None,
|
||||
min_price=float(b.min_price),
|
||||
max_price=float(b.max_price),
|
||||
avg_price=float(b.avg_price),
|
||||
price_tier=b.price_tier,
|
||||
effective_date=b.effective_date,
|
||||
remark=b.remark,
|
||||
created_at=b.created_at,
|
||||
updated_at=b.updated_at,
|
||||
))
|
||||
|
||||
return ResponseModel(
|
||||
data=PaginatedData(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/benchmark-prices", response_model=ResponseModel[BenchmarkPriceResponse])
|
||||
async def create_benchmark_price(
|
||||
data: BenchmarkPriceCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建标杆价格"""
|
||||
# 检查分类是否存在
|
||||
category_name = None
|
||||
if data.category_id:
|
||||
category_result = await db.execute(
|
||||
select(Category).where(Category.id == data.category_id)
|
||||
)
|
||||
category = category_result.scalar_one_or_none()
|
||||
if not category:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
|
||||
)
|
||||
category_name = category.category_name
|
||||
|
||||
benchmark = BenchmarkPrice(
|
||||
benchmark_name=data.benchmark_name,
|
||||
category_id=data.category_id,
|
||||
min_price=data.min_price,
|
||||
max_price=data.max_price,
|
||||
avg_price=data.avg_price,
|
||||
price_tier=data.price_tier.value,
|
||||
effective_date=data.effective_date,
|
||||
remark=data.remark,
|
||||
)
|
||||
db.add(benchmark)
|
||||
await db.flush()
|
||||
await db.refresh(benchmark)
|
||||
|
||||
return ResponseModel(
|
||||
message="创建成功",
|
||||
data=BenchmarkPriceResponse(
|
||||
id=benchmark.id,
|
||||
benchmark_name=benchmark.benchmark_name,
|
||||
category_id=benchmark.category_id,
|
||||
category_name=category_name,
|
||||
min_price=float(benchmark.min_price),
|
||||
max_price=float(benchmark.max_price),
|
||||
avg_price=float(benchmark.avg_price),
|
||||
price_tier=benchmark.price_tier,
|
||||
effective_date=benchmark.effective_date,
|
||||
remark=benchmark.remark,
|
||||
created_at=benchmark.created_at,
|
||||
updated_at=benchmark.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.put("/benchmark-prices/{benchmark_id}", response_model=ResponseModel[BenchmarkPriceResponse])
|
||||
async def update_benchmark_price(
|
||||
benchmark_id: int,
|
||||
data: BenchmarkPriceUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新标杆价格"""
|
||||
result = await db.execute(
|
||||
select(BenchmarkPrice).options(
|
||||
selectinload(BenchmarkPrice.category)
|
||||
).where(BenchmarkPrice.id == benchmark_id)
|
||||
)
|
||||
benchmark = result.scalar_one_or_none()
|
||||
|
||||
if not benchmark:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "标杆价格不存在"}
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
if field == "price_tier" and value:
|
||||
value = value.value
|
||||
setattr(benchmark, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(benchmark)
|
||||
|
||||
# 获取分类名称
|
||||
category_name = None
|
||||
if benchmark.category_id:
|
||||
cat_result = await db.execute(
|
||||
select(Category).where(Category.id == benchmark.category_id)
|
||||
)
|
||||
category = cat_result.scalar_one_or_none()
|
||||
if category:
|
||||
category_name = category.category_name
|
||||
|
||||
return ResponseModel(
|
||||
message="更新成功",
|
||||
data=BenchmarkPriceResponse(
|
||||
id=benchmark.id,
|
||||
benchmark_name=benchmark.benchmark_name,
|
||||
category_id=benchmark.category_id,
|
||||
category_name=category_name,
|
||||
min_price=float(benchmark.min_price),
|
||||
max_price=float(benchmark.max_price),
|
||||
avg_price=float(benchmark.avg_price),
|
||||
price_tier=benchmark.price_tier,
|
||||
effective_date=benchmark.effective_date,
|
||||
remark=benchmark.remark,
|
||||
created_at=benchmark.created_at,
|
||||
updated_at=benchmark.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/benchmark-prices/{benchmark_id}", response_model=ResponseModel)
|
||||
async def delete_benchmark_price(
|
||||
benchmark_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除标杆价格"""
|
||||
result = await db.execute(
|
||||
select(BenchmarkPrice).where(BenchmarkPrice.id == benchmark_id)
|
||||
)
|
||||
benchmark = result.scalar_one_or_none()
|
||||
|
||||
if not benchmark:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "标杆价格不存在"}
|
||||
)
|
||||
|
||||
await db.delete(benchmark)
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
|
||||
|
||||
# ============ 市场分析 ============
|
||||
|
||||
@router.post("/projects/{project_id}/market-analysis", response_model=ResponseModel[MarketAnalysisResultSchema])
|
||||
async def analyze_market(
|
||||
project_id: int,
|
||||
data: MarketAnalysisRequest = MarketAnalysisRequest(),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""执行市场价格分析"""
|
||||
market_service = MarketService(db)
|
||||
|
||||
try:
|
||||
result = await market_service.analyze_market(
|
||||
project_id=project_id,
|
||||
competitor_ids=data.competitor_ids,
|
||||
include_benchmark=data.include_benchmark,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": str(e)}
|
||||
)
|
||||
|
||||
return ResponseModel(message="分析完成", data=result)
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/market-analysis", response_model=ResponseModel[MarketAnalysisResponse])
|
||||
async def get_market_analysis(
|
||||
project_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取最新市场分析结果"""
|
||||
# 检查项目是否存在
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
if not project_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
|
||||
)
|
||||
|
||||
market_service = MarketService(db)
|
||||
result = await market_service.get_latest_analysis(project_id)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "暂无分析结果,请先执行市场分析"}
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data=MarketAnalysisResponse(
|
||||
id=result.id,
|
||||
project_id=result.project_id,
|
||||
analysis_date=result.analysis_date,
|
||||
competitor_count=result.competitor_count,
|
||||
market_min_price=float(result.market_min_price),
|
||||
market_max_price=float(result.market_max_price),
|
||||
market_avg_price=float(result.market_avg_price),
|
||||
market_median_price=float(result.market_median_price),
|
||||
suggested_range_min=float(result.suggested_range_min),
|
||||
suggested_range_max=float(result.suggested_range_max),
|
||||
created_at=result.created_at,
|
||||
)
|
||||
)
|
||||
272
后端服务/app/routers/materials.py
Normal file
272
后端服务/app/routers/materials.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""耗材管理路由
|
||||
|
||||
实现耗材的 CRUD 操作和批量导入
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File
|
||||
from sqlalchemy import select, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.material import Material
|
||||
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
|
||||
from app.schemas.material import (
|
||||
MaterialCreate,
|
||||
MaterialUpdate,
|
||||
MaterialResponse,
|
||||
MaterialImportResult,
|
||||
MaterialType,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel[PaginatedData[MaterialResponse]])
|
||||
async def get_materials(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
keyword: Optional[str] = Query(None, description="关键词搜索"),
|
||||
material_type: Optional[MaterialType] = Query(None, description="类型筛选"),
|
||||
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取耗材列表"""
|
||||
query = select(Material)
|
||||
|
||||
if keyword:
|
||||
query = query.where(
|
||||
or_(
|
||||
Material.material_code.contains(keyword),
|
||||
Material.material_name.contains(keyword),
|
||||
)
|
||||
)
|
||||
if material_type:
|
||||
query = query.where(Material.material_type == material_type.value)
|
||||
if is_active is not None:
|
||||
query = query.where(Material.is_active == is_active)
|
||||
|
||||
query = query.order_by(Material.id.desc())
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size)
|
||||
|
||||
result = await db.execute(query)
|
||||
materials = result.scalars().all()
|
||||
|
||||
# 统计总数
|
||||
count_query = select(func.count(Material.id))
|
||||
if keyword:
|
||||
count_query = count_query.where(
|
||||
or_(
|
||||
Material.material_code.contains(keyword),
|
||||
Material.material_name.contains(keyword),
|
||||
)
|
||||
)
|
||||
if material_type:
|
||||
count_query = count_query.where(Material.material_type == material_type.value)
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(Material.is_active == is_active)
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
return ResponseModel(
|
||||
data=PaginatedData(
|
||||
items=[MaterialResponse.model_validate(m) for m in materials],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{material_id}", response_model=ResponseModel[MaterialResponse])
|
||||
async def get_material(
|
||||
material_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取单个耗材详情"""
|
||||
result = await db.execute(select(Material).where(Material.id == material_id))
|
||||
material = result.scalar_one_or_none()
|
||||
|
||||
if not material:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "耗材不存在"}
|
||||
)
|
||||
|
||||
return ResponseModel(data=MaterialResponse.model_validate(material))
|
||||
|
||||
|
||||
@router.post("", response_model=ResponseModel[MaterialResponse])
|
||||
async def create_material(
|
||||
data: MaterialCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建耗材"""
|
||||
# 检查编码是否已存在
|
||||
existing = await db.execute(
|
||||
select(Material).where(Material.material_code == data.material_code)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "耗材编码已存在"}
|
||||
)
|
||||
|
||||
material = Material(**data.model_dump())
|
||||
db.add(material)
|
||||
await db.flush()
|
||||
await db.refresh(material)
|
||||
|
||||
return ResponseModel(message="创建成功", data=MaterialResponse.model_validate(material))
|
||||
|
||||
|
||||
@router.put("/{material_id}", response_model=ResponseModel[MaterialResponse])
|
||||
async def update_material(
|
||||
material_id: int,
|
||||
data: MaterialUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新耗材"""
|
||||
result = await db.execute(select(Material).where(Material.id == material_id))
|
||||
material = result.scalar_one_or_none()
|
||||
|
||||
if not material:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "耗材不存在"}
|
||||
)
|
||||
|
||||
# 检查编码是否重复
|
||||
if data.material_code and data.material_code != material.material_code:
|
||||
existing = await db.execute(
|
||||
select(Material).where(Material.material_code == data.material_code)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "耗材编码已存在"}
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(material, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(material)
|
||||
|
||||
return ResponseModel(message="更新成功", data=MaterialResponse.model_validate(material))
|
||||
|
||||
|
||||
@router.delete("/{material_id}", response_model=ResponseModel)
|
||||
async def delete_material(
|
||||
material_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除耗材"""
|
||||
result = await db.execute(select(Material).where(Material.id == material_id))
|
||||
material = result.scalar_one_or_none()
|
||||
|
||||
if not material:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "耗材不存在"}
|
||||
)
|
||||
|
||||
await db.delete(material)
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
|
||||
|
||||
@router.post("/import", response_model=ResponseModel[MaterialImportResult])
|
||||
async def import_materials(
|
||||
file: UploadFile = File(..., description="Excel 文件"),
|
||||
update_existing: bool = Query(False, description="是否更新已存在的数据"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""批量导入耗材
|
||||
|
||||
Excel 格式:耗材编码 | 耗材名称 | 单位 | 单价 | 供应商 | 类型
|
||||
"""
|
||||
import openpyxl
|
||||
from io import BytesIO
|
||||
|
||||
if not file.filename.endswith(('.xlsx', '.xls')):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.PARAM_ERROR, "message": "请上传 Excel 文件"}
|
||||
)
|
||||
|
||||
content = await file.read()
|
||||
wb = openpyxl.load_workbook(BytesIO(content))
|
||||
ws = wb.active
|
||||
|
||||
total = 0
|
||||
success = 0
|
||||
errors = []
|
||||
|
||||
for row_num, row in enumerate(ws.iter_rows(min_row=2, values_only=True), start=2):
|
||||
if not row[0]: # 跳过空行
|
||||
continue
|
||||
|
||||
total += 1
|
||||
|
||||
try:
|
||||
material_code = str(row[0]).strip()
|
||||
material_name = str(row[1]).strip() if row[1] else ""
|
||||
unit = str(row[2]).strip() if row[2] else ""
|
||||
unit_price = float(row[3]) if row[3] else 0
|
||||
supplier = str(row[4]).strip() if row[4] else None
|
||||
material_type = str(row[5]).strip() if row[5] else "consumable"
|
||||
|
||||
if not material_name or not unit:
|
||||
errors.append({"row": row_num, "error": "名称或单位不能为空"})
|
||||
continue
|
||||
|
||||
# 检查是否已存在
|
||||
existing = await db.execute(
|
||||
select(Material).where(Material.material_code == material_code)
|
||||
)
|
||||
existing_material = existing.scalar_one_or_none()
|
||||
|
||||
if existing_material:
|
||||
if update_existing:
|
||||
existing_material.material_name = material_name
|
||||
existing_material.unit = unit
|
||||
existing_material.unit_price = unit_price
|
||||
existing_material.supplier = supplier
|
||||
existing_material.material_type = material_type
|
||||
success += 1
|
||||
else:
|
||||
errors.append({"row": row_num, "error": f"耗材编码 {material_code} 已存在"})
|
||||
else:
|
||||
material = Material(
|
||||
material_code=material_code,
|
||||
material_name=material_name,
|
||||
unit=unit,
|
||||
unit_price=unit_price,
|
||||
supplier=supplier,
|
||||
material_type=material_type,
|
||||
)
|
||||
db.add(material)
|
||||
success += 1
|
||||
|
||||
except Exception as e:
|
||||
errors.append({"row": row_num, "error": str(e)})
|
||||
|
||||
await db.flush()
|
||||
|
||||
return ResponseModel(
|
||||
message="导入完成",
|
||||
data=MaterialImportResult(
|
||||
total=total,
|
||||
success=success,
|
||||
failed=len(errors),
|
||||
errors=errors,
|
||||
)
|
||||
)
|
||||
327
后端服务/app/routers/pricing.py
Normal file
327
后端服务/app/routers/pricing.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""智能定价路由
|
||||
|
||||
智能定价建议相关的 API 接口
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import PricingPlan, Project
|
||||
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
|
||||
from app.schemas.pricing import (
|
||||
StrategyType,
|
||||
PricingPlanCreate,
|
||||
PricingPlanUpdate,
|
||||
PricingPlanResponse,
|
||||
PricingPlanListResponse,
|
||||
PricingPlanQuery,
|
||||
GeneratePricingRequest,
|
||||
GeneratePricingResponse,
|
||||
SimulateStrategyRequest,
|
||||
SimulateStrategyResponse,
|
||||
)
|
||||
from app.services.pricing_service import PricingService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# 定价方案 CRUD
|
||||
|
||||
# 定价方案允许的排序字段白名单
|
||||
PRICING_PLAN_SORT_FIELDS = {"created_at", "updated_at", "plan_name", "base_cost", "target_margin", "suggested_price"}
|
||||
|
||||
|
||||
@router.get("/pricing-plans", response_model=ResponseModel[PaginatedData[PricingPlanListResponse]])
|
||||
async def list_pricing_plans(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
project_id: Optional[int] = None,
|
||||
strategy_type: Optional[StrategyType] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取定价方案列表"""
|
||||
query = select(PricingPlan).options(
|
||||
selectinload(PricingPlan.project),
|
||||
selectinload(PricingPlan.creator),
|
||||
)
|
||||
|
||||
if project_id:
|
||||
query = query.where(PricingPlan.project_id == project_id)
|
||||
if strategy_type:
|
||||
query = query.where(PricingPlan.strategy_type == strategy_type.value)
|
||||
if is_active is not None:
|
||||
query = query.where(PricingPlan.is_active == is_active)
|
||||
|
||||
# 计算总数
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# 排序 - 使用白名单验证防止注入
|
||||
if sort_by not in PRICING_PLAN_SORT_FIELDS:
|
||||
sort_by = "created_at"
|
||||
sort_column = getattr(PricingPlan, sort_by, PricingPlan.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# 分页
|
||||
query = query.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
result = await db.execute(query)
|
||||
plans = result.scalars().all()
|
||||
|
||||
items = []
|
||||
for plan in plans:
|
||||
items.append(PricingPlanListResponse(
|
||||
id=plan.id,
|
||||
project_id=plan.project_id,
|
||||
project_name=plan.project.project_name if plan.project else None,
|
||||
plan_name=plan.plan_name,
|
||||
strategy_type=plan.strategy_type,
|
||||
base_cost=float(plan.base_cost),
|
||||
target_margin=float(plan.target_margin),
|
||||
suggested_price=float(plan.suggested_price),
|
||||
final_price=float(plan.final_price) if plan.final_price else None,
|
||||
is_active=plan.is_active,
|
||||
created_at=plan.created_at,
|
||||
created_by_name=plan.creator.username if plan.creator else None,
|
||||
))
|
||||
|
||||
return ResponseModel(data=PaginatedData(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
))
|
||||
|
||||
|
||||
@router.post("/pricing-plans", response_model=ResponseModel[PricingPlanResponse])
|
||||
async def create_pricing_plan(
|
||||
data: PricingPlanCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建定价方案"""
|
||||
service = PricingService(db)
|
||||
|
||||
try:
|
||||
plan = await service.create_pricing_plan(
|
||||
project_id=data.project_id,
|
||||
plan_name=data.plan_name,
|
||||
strategy_type=data.strategy_type,
|
||||
target_margin=data.target_margin,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# 重新加载关系
|
||||
await db.refresh(plan, ["project", "creator"])
|
||||
|
||||
return ResponseModel(
|
||||
message="创建成功",
|
||||
data=PricingPlanResponse(
|
||||
id=plan.id,
|
||||
project_id=plan.project_id,
|
||||
project_name=plan.project.project_name if plan.project else None,
|
||||
plan_name=plan.plan_name,
|
||||
strategy_type=plan.strategy_type,
|
||||
base_cost=float(plan.base_cost),
|
||||
target_margin=float(plan.target_margin),
|
||||
suggested_price=float(plan.suggested_price),
|
||||
final_price=float(plan.final_price) if plan.final_price else None,
|
||||
ai_advice=plan.ai_advice,
|
||||
is_active=plan.is_active,
|
||||
created_at=plan.created_at,
|
||||
updated_at=plan.updated_at,
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/pricing-plans/{plan_id}", response_model=ResponseModel[PricingPlanResponse])
|
||||
async def get_pricing_plan(
|
||||
plan_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取定价方案详情"""
|
||||
result = await db.execute(
|
||||
select(PricingPlan).options(
|
||||
selectinload(PricingPlan.project),
|
||||
selectinload(PricingPlan.creator),
|
||||
).where(PricingPlan.id == plan_id)
|
||||
)
|
||||
plan = result.scalar_one_or_none()
|
||||
|
||||
if not plan:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "定价方案不存在"}
|
||||
)
|
||||
|
||||
return ResponseModel(data=PricingPlanResponse(
|
||||
id=plan.id,
|
||||
project_id=plan.project_id,
|
||||
project_name=plan.project.project_name if plan.project else None,
|
||||
plan_name=plan.plan_name,
|
||||
strategy_type=plan.strategy_type,
|
||||
base_cost=float(plan.base_cost),
|
||||
target_margin=float(plan.target_margin),
|
||||
suggested_price=float(plan.suggested_price),
|
||||
final_price=float(plan.final_price) if plan.final_price else None,
|
||||
ai_advice=plan.ai_advice,
|
||||
is_active=plan.is_active,
|
||||
created_at=plan.created_at,
|
||||
updated_at=plan.updated_at,
|
||||
created_by_name=plan.creator.username if plan.creator else None,
|
||||
))
|
||||
|
||||
|
||||
@router.put("/pricing-plans/{plan_id}", response_model=ResponseModel[PricingPlanResponse])
|
||||
async def update_pricing_plan(
|
||||
plan_id: int,
|
||||
data: PricingPlanUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新定价方案"""
|
||||
service = PricingService(db)
|
||||
|
||||
try:
|
||||
plan = await service.update_pricing_plan(
|
||||
plan_id=plan_id,
|
||||
plan_name=data.plan_name,
|
||||
strategy_type=data.strategy_type.value if data.strategy_type else None,
|
||||
target_margin=data.target_margin,
|
||||
final_price=data.final_price,
|
||||
is_active=data.is_active,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
await db.refresh(plan, ["project", "creator"])
|
||||
|
||||
return ResponseModel(
|
||||
message="更新成功",
|
||||
data=PricingPlanResponse(
|
||||
id=plan.id,
|
||||
project_id=plan.project_id,
|
||||
project_name=plan.project.project_name if plan.project else None,
|
||||
plan_name=plan.plan_name,
|
||||
strategy_type=plan.strategy_type,
|
||||
base_cost=float(plan.base_cost),
|
||||
target_margin=float(plan.target_margin),
|
||||
suggested_price=float(plan.suggested_price),
|
||||
final_price=float(plan.final_price) if plan.final_price else None,
|
||||
ai_advice=plan.ai_advice,
|
||||
is_active=plan.is_active,
|
||||
created_at=plan.created_at,
|
||||
updated_at=plan.updated_at,
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/pricing-plans/{plan_id}", response_model=ResponseModel)
|
||||
async def delete_pricing_plan(
|
||||
plan_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除定价方案"""
|
||||
result = await db.execute(
|
||||
select(PricingPlan).where(PricingPlan.id == plan_id)
|
||||
)
|
||||
plan = result.scalar_one_or_none()
|
||||
|
||||
if not plan:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "定价方案不存在"}
|
||||
)
|
||||
|
||||
await db.delete(plan)
|
||||
await db.commit()
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
|
||||
|
||||
# AI 定价建议
|
||||
|
||||
@router.post("/projects/{project_id}/generate-pricing", response_model=ResponseModel[GeneratePricingResponse])
|
||||
async def generate_pricing(
|
||||
project_id: int,
|
||||
request: GeneratePricingRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""AI 生成定价建议
|
||||
|
||||
支持流式和非流式两种模式
|
||||
"""
|
||||
service = PricingService(db)
|
||||
|
||||
# 检查项目是否存在
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
if not result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
# 流式返回
|
||||
return StreamingResponse(
|
||||
service.generate_pricing_advice_stream(
|
||||
project_id=project_id,
|
||||
target_margin=request.target_margin,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式返回
|
||||
try:
|
||||
response = await service.generate_pricing_advice(
|
||||
project_id=project_id,
|
||||
target_margin=request.target_margin,
|
||||
strategies=request.strategies,
|
||||
)
|
||||
return ResponseModel(data=response)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"code": ErrorCode.AI_SERVICE_ERROR, "message": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/simulate-strategy", response_model=ResponseModel[SimulateStrategyResponse])
|
||||
async def simulate_strategy(
|
||||
project_id: int,
|
||||
request: SimulateStrategyRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""模拟定价策略"""
|
||||
service = PricingService(db)
|
||||
|
||||
try:
|
||||
response = await service.simulate_strategies(
|
||||
project_id=project_id,
|
||||
strategies=request.strategies,
|
||||
target_margin=request.target_margin,
|
||||
)
|
||||
return ResponseModel(data=response)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
273
后端服务/app/routers/profit.py
Normal file
273
后端服务/app/routers/profit.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""利润模拟路由
|
||||
|
||||
利润模拟测算相关的 API 接口
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import ProfitSimulation, PricingPlan
|
||||
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
|
||||
from app.schemas.profit import (
|
||||
PeriodType,
|
||||
ProfitSimulationResponse,
|
||||
ProfitSimulationListResponse,
|
||||
SimulateProfitRequest,
|
||||
SimulateProfitResponse,
|
||||
SensitivityAnalysisRequest,
|
||||
SensitivityAnalysisResponse,
|
||||
BreakevenRequest,
|
||||
BreakevenResponse,
|
||||
)
|
||||
from app.services.profit_service import ProfitService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# 利润模拟 CRUD
|
||||
|
||||
@router.get("/profit-simulations", response_model=ResponseModel[PaginatedData[ProfitSimulationListResponse]])
|
||||
async def list_profit_simulations(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
pricing_plan_id: Optional[int] = None,
|
||||
period_type: Optional[PeriodType] = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取利润模拟列表"""
|
||||
service = ProfitService(db)
|
||||
|
||||
simulations, total = await service.get_simulation_list(
|
||||
pricing_plan_id=pricing_plan_id,
|
||||
period_type=period_type,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
items = []
|
||||
for sim in simulations:
|
||||
items.append(ProfitSimulationListResponse(
|
||||
id=sim.id,
|
||||
pricing_plan_id=sim.pricing_plan_id,
|
||||
plan_name=sim.pricing_plan.plan_name if sim.pricing_plan else None,
|
||||
project_name=sim.pricing_plan.project.project_name if sim.pricing_plan and sim.pricing_plan.project else None,
|
||||
simulation_name=sim.simulation_name,
|
||||
price=float(sim.price),
|
||||
estimated_volume=sim.estimated_volume,
|
||||
period_type=sim.period_type,
|
||||
estimated_profit=float(sim.estimated_profit),
|
||||
profit_margin=float(sim.profit_margin),
|
||||
created_at=sim.created_at,
|
||||
))
|
||||
|
||||
return ResponseModel(data=PaginatedData(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
))
|
||||
|
||||
|
||||
@router.get("/profit-simulations/{simulation_id}", response_model=ResponseModel[ProfitSimulationResponse])
|
||||
async def get_profit_simulation(
|
||||
simulation_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取模拟详情"""
|
||||
result = await db.execute(
|
||||
select(ProfitSimulation).options(
|
||||
selectinload(ProfitSimulation.pricing_plan).selectinload(PricingPlan.project),
|
||||
selectinload(ProfitSimulation.creator),
|
||||
).where(ProfitSimulation.id == simulation_id)
|
||||
)
|
||||
sim = result.scalar_one_or_none()
|
||||
|
||||
if not sim:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "模拟记录不存在"}
|
||||
)
|
||||
|
||||
return ResponseModel(data=ProfitSimulationResponse(
|
||||
id=sim.id,
|
||||
pricing_plan_id=sim.pricing_plan_id,
|
||||
plan_name=sim.pricing_plan.plan_name if sim.pricing_plan else None,
|
||||
project_name=sim.pricing_plan.project.project_name if sim.pricing_plan and sim.pricing_plan.project else None,
|
||||
simulation_name=sim.simulation_name,
|
||||
price=float(sim.price),
|
||||
estimated_volume=sim.estimated_volume,
|
||||
period_type=sim.period_type,
|
||||
estimated_revenue=float(sim.estimated_revenue),
|
||||
estimated_cost=float(sim.estimated_cost),
|
||||
estimated_profit=float(sim.estimated_profit),
|
||||
profit_margin=float(sim.profit_margin),
|
||||
breakeven_volume=sim.breakeven_volume,
|
||||
created_at=sim.created_at,
|
||||
created_by_name=sim.creator.username if sim.creator else None,
|
||||
))
|
||||
|
||||
|
||||
@router.delete("/profit-simulations/{simulation_id}", response_model=ResponseModel)
|
||||
async def delete_profit_simulation(
|
||||
simulation_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除模拟记录"""
|
||||
service = ProfitService(db)
|
||||
|
||||
deleted = await service.delete_simulation(simulation_id)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "模拟记录不存在"}
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
|
||||
|
||||
# 执行模拟
|
||||
|
||||
@router.post("/pricing-plans/{plan_id}/simulate-profit", response_model=ResponseModel[SimulateProfitResponse])
|
||||
async def simulate_profit(
|
||||
plan_id: int,
|
||||
request: SimulateProfitRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""执行利润模拟"""
|
||||
service = ProfitService(db)
|
||||
|
||||
try:
|
||||
response = await service.simulate_profit(
|
||||
pricing_plan_id=plan_id,
|
||||
price=request.price,
|
||||
estimated_volume=request.estimated_volume,
|
||||
period_type=request.period_type,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return ResponseModel(data=response)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
# 敏感性分析
|
||||
|
||||
@router.post("/profit-simulations/{simulation_id}/sensitivity", response_model=ResponseModel[SensitivityAnalysisResponse])
|
||||
async def create_sensitivity_analysis(
|
||||
simulation_id: int,
|
||||
request: SensitivityAnalysisRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""执行敏感性分析"""
|
||||
service = ProfitService(db)
|
||||
|
||||
try:
|
||||
response = await service.sensitivity_analysis(
|
||||
simulation_id=simulation_id,
|
||||
price_change_rates=request.price_change_rates,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return ResponseModel(data=response)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/profit-simulations/{simulation_id}/sensitivity", response_model=ResponseModel[SensitivityAnalysisResponse])
|
||||
async def get_sensitivity_analysis(
|
||||
simulation_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取敏感性分析结果"""
|
||||
result = await db.execute(
|
||||
select(ProfitSimulation).options(
|
||||
selectinload(ProfitSimulation.sensitivity_analyses),
|
||||
selectinload(ProfitSimulation.pricing_plan),
|
||||
).where(ProfitSimulation.id == simulation_id)
|
||||
)
|
||||
sim = result.scalar_one_or_none()
|
||||
|
||||
if not sim:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "模拟记录不存在"}
|
||||
)
|
||||
|
||||
if not sim.sensitivity_analyses:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "尚未执行敏感性分析"}
|
||||
)
|
||||
|
||||
from app.schemas.profit import SensitivityResultItem
|
||||
|
||||
results = [
|
||||
SensitivityResultItem(
|
||||
price_change_rate=float(sa.price_change_rate),
|
||||
adjusted_price=float(sa.adjusted_price),
|
||||
adjusted_profit=float(sa.adjusted_profit),
|
||||
profit_change_rate=float(sa.profit_change_rate),
|
||||
)
|
||||
for sa in sorted(sim.sensitivity_analyses, key=lambda x: x.price_change_rate)
|
||||
]
|
||||
|
||||
return ResponseModel(data=SensitivityAnalysisResponse(
|
||||
simulation_id=simulation_id,
|
||||
base_price=float(sim.price),
|
||||
base_profit=float(sim.estimated_profit),
|
||||
sensitivity_results=results,
|
||||
))
|
||||
|
||||
|
||||
# 盈亏平衡分析
|
||||
|
||||
@router.get("/pricing-plans/{plan_id}/breakeven", response_model=ResponseModel[BreakevenResponse])
|
||||
async def get_breakeven_analysis(
|
||||
plan_id: int,
|
||||
target_profit: Optional[float] = Query(None, description="目标利润"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取盈亏平衡分析"""
|
||||
service = ProfitService(db)
|
||||
|
||||
try:
|
||||
response = await service.breakeven_analysis(
|
||||
pricing_plan_id=plan_id,
|
||||
target_profit=target_profit,
|
||||
)
|
||||
return ResponseModel(data=response)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
# AI 利润预测
|
||||
|
||||
@router.post("/profit-simulations/{simulation_id}/forecast", response_model=ResponseModel[dict])
|
||||
async def generate_profit_forecast(
|
||||
simulation_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""AI 生成利润预测分析"""
|
||||
service = ProfitService(db)
|
||||
|
||||
try:
|
||||
content = await service.generate_profit_forecast(simulation_id)
|
||||
return ResponseModel(data={"content": content})
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"code": ErrorCode.AI_SERVICE_ERROR, "message": str(e)}
|
||||
)
|
||||
904
后端服务/app/routers/projects.py
Normal file
904
后端服务/app/routers/projects.py
Normal file
@@ -0,0 +1,904 @@
|
||||
"""服务项目管理路由
|
||||
|
||||
实现服务项目的 CRUD 操作和成本管理
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, func, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import (
|
||||
Project,
|
||||
ProjectCostItem,
|
||||
ProjectLaborCost,
|
||||
ProjectCostSummary,
|
||||
Category,
|
||||
Material,
|
||||
Equipment,
|
||||
StaffLevel,
|
||||
)
|
||||
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
|
||||
from app.schemas.project import (
|
||||
ProjectCreate,
|
||||
ProjectUpdate,
|
||||
ProjectResponse,
|
||||
ProjectListResponse,
|
||||
CostSummaryBrief,
|
||||
)
|
||||
from app.schemas.project_cost import (
|
||||
CostItemCreate,
|
||||
CostItemUpdate,
|
||||
CostItemResponse,
|
||||
LaborCostCreate,
|
||||
LaborCostUpdate,
|
||||
LaborCostResponse,
|
||||
CalculateCostRequest,
|
||||
CostCalculationResult,
|
||||
CostSummaryResponse,
|
||||
ProjectDetailResponse,
|
||||
CostItemType,
|
||||
AllocationMethod,
|
||||
)
|
||||
from app.services.cost_service import CostService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# 项目允许的排序字段白名单
|
||||
PROJECT_SORT_FIELDS = {"created_at", "updated_at", "project_code", "project_name", "duration_minutes"}
|
||||
|
||||
|
||||
# ============ 项目 CRUD ============
|
||||
|
||||
@router.get("", response_model=ResponseModel[PaginatedData[ProjectListResponse]])
|
||||
async def get_projects(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
category_id: Optional[int] = Query(None, description="分类筛选"),
|
||||
keyword: Optional[str] = Query(None, description="关键词搜索"),
|
||||
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
|
||||
sort_by: str = Query("created_at", description="排序字段"),
|
||||
sort_order: str = Query("desc", description="排序方向"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取服务项目列表"""
|
||||
query = select(Project).options(
|
||||
selectinload(Project.category),
|
||||
selectinload(Project.cost_summary),
|
||||
)
|
||||
|
||||
if keyword:
|
||||
query = query.where(
|
||||
or_(
|
||||
Project.project_code.contains(keyword),
|
||||
Project.project_name.contains(keyword),
|
||||
)
|
||||
)
|
||||
if category_id:
|
||||
query = query.where(Project.category_id == category_id)
|
||||
if is_active is not None:
|
||||
query = query.where(Project.is_active == is_active)
|
||||
|
||||
# 排序 - 使用白名单验证防止注入
|
||||
if sort_by not in PROJECT_SORT_FIELDS:
|
||||
sort_by = "created_at"
|
||||
sort_column = getattr(Project, sort_by, Project.created_at)
|
||||
if sort_order == "asc":
|
||||
query = query.order_by(sort_column.asc())
|
||||
else:
|
||||
query = query.order_by(sort_column.desc())
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size)
|
||||
|
||||
result = await db.execute(query)
|
||||
projects = result.scalars().all()
|
||||
|
||||
# 统计总数
|
||||
count_query = select(func.count(Project.id))
|
||||
if keyword:
|
||||
count_query = count_query.where(
|
||||
or_(
|
||||
Project.project_code.contains(keyword),
|
||||
Project.project_name.contains(keyword),
|
||||
)
|
||||
)
|
||||
if category_id:
|
||||
count_query = count_query.where(Project.category_id == category_id)
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(Project.is_active == is_active)
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for p in projects:
|
||||
item = {
|
||||
"id": p.id,
|
||||
"project_code": p.project_code,
|
||||
"project_name": p.project_name,
|
||||
"category_id": p.category_id,
|
||||
"category_name": p.category.category_name if p.category else None,
|
||||
"description": p.description,
|
||||
"duration_minutes": p.duration_minutes,
|
||||
"is_active": p.is_active,
|
||||
"cost_summary": None,
|
||||
"created_at": p.created_at,
|
||||
"updated_at": p.updated_at,
|
||||
}
|
||||
if p.cost_summary:
|
||||
item["cost_summary"] = CostSummaryBrief(
|
||||
total_cost=float(p.cost_summary.total_cost),
|
||||
material_cost=float(p.cost_summary.material_cost),
|
||||
equipment_cost=float(p.cost_summary.equipment_cost),
|
||||
labor_cost=float(p.cost_summary.labor_cost),
|
||||
fixed_cost_allocation=float(p.cost_summary.fixed_cost_allocation),
|
||||
)
|
||||
items.append(ProjectListResponse(**item))
|
||||
|
||||
return ResponseModel(
|
||||
data=PaginatedData(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{project_id}", response_model=ResponseModel[ProjectDetailResponse])
|
||||
async def get_project(
|
||||
project_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取项目详情(含成本明细)"""
|
||||
result = await db.execute(
|
||||
select(Project).options(
|
||||
selectinload(Project.category),
|
||||
selectinload(Project.cost_items),
|
||||
selectinload(Project.labor_costs).selectinload(ProjectLaborCost.staff_level),
|
||||
selectinload(Project.cost_summary),
|
||||
).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
|
||||
)
|
||||
|
||||
# 构建成本明细响应
|
||||
cost_items = []
|
||||
for item in project.cost_items:
|
||||
item_name = None
|
||||
unit = None
|
||||
if item.item_type == CostItemType.MATERIAL.value:
|
||||
material_result = await db.execute(
|
||||
select(Material).where(Material.id == item.item_id)
|
||||
)
|
||||
material = material_result.scalar_one_or_none()
|
||||
if material:
|
||||
item_name = material.material_name
|
||||
unit = material.unit
|
||||
else:
|
||||
equipment_result = await db.execute(
|
||||
select(Equipment).where(Equipment.id == item.item_id)
|
||||
)
|
||||
equipment = equipment_result.scalar_one_or_none()
|
||||
if equipment:
|
||||
item_name = equipment.equipment_name
|
||||
unit = "次"
|
||||
|
||||
cost_items.append(CostItemResponse(
|
||||
id=item.id,
|
||||
item_type=CostItemType(item.item_type),
|
||||
item_id=item.item_id,
|
||||
item_name=item_name,
|
||||
quantity=float(item.quantity),
|
||||
unit=unit,
|
||||
unit_cost=float(item.unit_cost),
|
||||
total_cost=float(item.total_cost),
|
||||
remark=item.remark,
|
||||
created_at=item.created_at,
|
||||
updated_at=item.updated_at,
|
||||
))
|
||||
|
||||
# 构建人工成本响应
|
||||
labor_costs = []
|
||||
for item in project.labor_costs:
|
||||
labor_costs.append(LaborCostResponse(
|
||||
id=item.id,
|
||||
staff_level_id=item.staff_level_id,
|
||||
level_name=item.staff_level.level_name if item.staff_level else None,
|
||||
duration_minutes=item.duration_minutes,
|
||||
hourly_rate=float(item.hourly_rate),
|
||||
labor_cost=float(item.labor_cost),
|
||||
remark=item.remark,
|
||||
created_at=item.created_at,
|
||||
updated_at=item.updated_at,
|
||||
))
|
||||
|
||||
# 构建成本汇总
|
||||
cost_summary = None
|
||||
if project.cost_summary:
|
||||
cost_summary = CostSummaryResponse(
|
||||
project_id=project.id,
|
||||
material_cost=float(project.cost_summary.material_cost),
|
||||
equipment_cost=float(project.cost_summary.equipment_cost),
|
||||
labor_cost=float(project.cost_summary.labor_cost),
|
||||
fixed_cost_allocation=float(project.cost_summary.fixed_cost_allocation),
|
||||
total_cost=float(project.cost_summary.total_cost),
|
||||
calculated_at=project.cost_summary.calculated_at,
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data=ProjectDetailResponse(
|
||||
id=project.id,
|
||||
project_code=project.project_code,
|
||||
project_name=project.project_name,
|
||||
category_id=project.category_id,
|
||||
category_name=project.category.category_name if project.category else None,
|
||||
description=project.description,
|
||||
duration_minutes=project.duration_minutes,
|
||||
is_active=project.is_active,
|
||||
cost_items=cost_items,
|
||||
labor_costs=labor_costs,
|
||||
cost_summary=cost_summary,
|
||||
created_at=project.created_at,
|
||||
updated_at=project.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=ResponseModel[ProjectResponse])
|
||||
async def create_project(
|
||||
data: ProjectCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建服务项目"""
|
||||
# 检查编码是否已存在
|
||||
existing = await db.execute(
|
||||
select(Project).where(Project.project_code == data.project_code)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "项目编码已存在"}
|
||||
)
|
||||
|
||||
# 检查分类是否存在
|
||||
if data.category_id:
|
||||
category_result = await db.execute(
|
||||
select(Category).where(Category.id == data.category_id)
|
||||
)
|
||||
if not category_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
|
||||
)
|
||||
|
||||
project = Project(**data.model_dump())
|
||||
db.add(project)
|
||||
await db.flush()
|
||||
await db.refresh(project)
|
||||
|
||||
# 获取分类名称
|
||||
category_name = None
|
||||
if project.category_id:
|
||||
result = await db.execute(
|
||||
select(Category).where(Category.id == project.category_id)
|
||||
)
|
||||
category = result.scalar_one_or_none()
|
||||
if category:
|
||||
category_name = category.category_name
|
||||
|
||||
return ResponseModel(
|
||||
message="创建成功",
|
||||
data=ProjectResponse(
|
||||
id=project.id,
|
||||
project_code=project.project_code,
|
||||
project_name=project.project_name,
|
||||
category_id=project.category_id,
|
||||
category_name=category_name,
|
||||
description=project.description,
|
||||
duration_minutes=project.duration_minutes,
|
||||
is_active=project.is_active,
|
||||
cost_summary=None,
|
||||
created_at=project.created_at,
|
||||
updated_at=project.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{project_id}", response_model=ResponseModel[ProjectResponse])
|
||||
async def update_project(
|
||||
project_id: int,
|
||||
data: ProjectUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新服务项目"""
|
||||
result = await db.execute(
|
||||
select(Project).options(
|
||||
selectinload(Project.category),
|
||||
selectinload(Project.cost_summary),
|
||||
).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
|
||||
)
|
||||
|
||||
# 检查编码是否重复
|
||||
if data.project_code and data.project_code != project.project_code:
|
||||
existing = await db.execute(
|
||||
select(Project).where(Project.project_code == data.project_code)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "项目编码已存在"}
|
||||
)
|
||||
|
||||
# 检查分类是否存在
|
||||
if data.category_id:
|
||||
category_result = await db.execute(
|
||||
select(Category).where(Category.id == data.category_id)
|
||||
)
|
||||
if not category_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(project, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(project)
|
||||
|
||||
# 获取分类名称
|
||||
category_name = None
|
||||
if project.category_id:
|
||||
cat_result = await db.execute(
|
||||
select(Category).where(Category.id == project.category_id)
|
||||
)
|
||||
category = cat_result.scalar_one_or_none()
|
||||
if category:
|
||||
category_name = category.category_name
|
||||
|
||||
# 构建成本汇总
|
||||
cost_summary = None
|
||||
if project.cost_summary:
|
||||
cost_summary = CostSummaryBrief(
|
||||
total_cost=float(project.cost_summary.total_cost),
|
||||
material_cost=float(project.cost_summary.material_cost),
|
||||
equipment_cost=float(project.cost_summary.equipment_cost),
|
||||
labor_cost=float(project.cost_summary.labor_cost),
|
||||
fixed_cost_allocation=float(project.cost_summary.fixed_cost_allocation),
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
message="更新成功",
|
||||
data=ProjectResponse(
|
||||
id=project.id,
|
||||
project_code=project.project_code,
|
||||
project_name=project.project_name,
|
||||
category_id=project.category_id,
|
||||
category_name=category_name,
|
||||
description=project.description,
|
||||
duration_minutes=project.duration_minutes,
|
||||
is_active=project.is_active,
|
||||
cost_summary=cost_summary,
|
||||
created_at=project.created_at,
|
||||
updated_at=project.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{project_id}", response_model=ResponseModel)
|
||||
async def delete_project(
|
||||
project_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除服务项目"""
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
|
||||
)
|
||||
|
||||
await db.delete(project)
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
|
||||
|
||||
# ============ 成本明细(耗材/设备)管理 ============
|
||||
|
||||
@router.get("/{project_id}/cost-items", response_model=ResponseModel[list[CostItemResponse]])
|
||||
async def get_cost_items(
|
||||
project_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取项目成本明细"""
|
||||
# 检查项目是否存在
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
if not project_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
select(ProjectCostItem).where(
|
||||
ProjectCostItem.project_id == project_id
|
||||
).order_by(ProjectCostItem.id)
|
||||
)
|
||||
items = result.scalars().all()
|
||||
|
||||
response_items = []
|
||||
for item in items:
|
||||
item_name = None
|
||||
unit = None
|
||||
if item.item_type == CostItemType.MATERIAL.value:
|
||||
material_result = await db.execute(
|
||||
select(Material).where(Material.id == item.item_id)
|
||||
)
|
||||
material = material_result.scalar_one_or_none()
|
||||
if material:
|
||||
item_name = material.material_name
|
||||
unit = material.unit
|
||||
else:
|
||||
equipment_result = await db.execute(
|
||||
select(Equipment).where(Equipment.id == item.item_id)
|
||||
)
|
||||
equipment = equipment_result.scalar_one_or_none()
|
||||
if equipment:
|
||||
item_name = equipment.equipment_name
|
||||
unit = "次"
|
||||
|
||||
response_items.append(CostItemResponse(
|
||||
id=item.id,
|
||||
item_type=CostItemType(item.item_type),
|
||||
item_id=item.item_id,
|
||||
item_name=item_name,
|
||||
quantity=float(item.quantity),
|
||||
unit=unit,
|
||||
unit_cost=float(item.unit_cost),
|
||||
total_cost=float(item.total_cost),
|
||||
remark=item.remark,
|
||||
created_at=item.created_at,
|
||||
updated_at=item.updated_at,
|
||||
))
|
||||
|
||||
return ResponseModel(data=response_items)
|
||||
|
||||
|
||||
@router.post("/{project_id}/cost-items", response_model=ResponseModel[CostItemResponse])
|
||||
async def create_cost_item(
|
||||
project_id: int,
|
||||
data: CostItemCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""添加成本明细"""
|
||||
# 检查项目是否存在
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
if not project_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
|
||||
)
|
||||
|
||||
cost_service = CostService(db)
|
||||
|
||||
try:
|
||||
cost_item = await cost_service.add_cost_item(
|
||||
project_id=project_id,
|
||||
item_type=data.item_type,
|
||||
item_id=data.item_id,
|
||||
quantity=data.quantity,
|
||||
remark=data.remark,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": str(e)}
|
||||
)
|
||||
|
||||
# 获取物品名称
|
||||
item_name = None
|
||||
unit = None
|
||||
if data.item_type == CostItemType.MATERIAL:
|
||||
material_result = await db.execute(
|
||||
select(Material).where(Material.id == data.item_id)
|
||||
)
|
||||
material = material_result.scalar_one_or_none()
|
||||
if material:
|
||||
item_name = material.material_name
|
||||
unit = material.unit
|
||||
else:
|
||||
equipment_result = await db.execute(
|
||||
select(Equipment).where(Equipment.id == data.item_id)
|
||||
)
|
||||
equipment = equipment_result.scalar_one_or_none()
|
||||
if equipment:
|
||||
item_name = equipment.equipment_name
|
||||
unit = "次"
|
||||
|
||||
return ResponseModel(
|
||||
message="添加成功",
|
||||
data=CostItemResponse(
|
||||
id=cost_item.id,
|
||||
item_type=CostItemType(cost_item.item_type),
|
||||
item_id=cost_item.item_id,
|
||||
item_name=item_name,
|
||||
quantity=float(cost_item.quantity),
|
||||
unit=unit,
|
||||
unit_cost=float(cost_item.unit_cost),
|
||||
total_cost=float(cost_item.total_cost),
|
||||
remark=cost_item.remark,
|
||||
created_at=cost_item.created_at,
|
||||
updated_at=cost_item.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{project_id}/cost-items/{item_id}", response_model=ResponseModel[CostItemResponse])
|
||||
async def update_cost_item(
|
||||
project_id: int,
|
||||
item_id: int,
|
||||
data: CostItemUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新成本明细"""
|
||||
result = await db.execute(
|
||||
select(ProjectCostItem).where(
|
||||
ProjectCostItem.id == item_id,
|
||||
ProjectCostItem.project_id == project_id,
|
||||
)
|
||||
)
|
||||
cost_item = result.scalar_one_or_none()
|
||||
|
||||
if not cost_item:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "成本明细不存在"}
|
||||
)
|
||||
|
||||
cost_service = CostService(db)
|
||||
cost_item = await cost_service.update_cost_item(
|
||||
cost_item=cost_item,
|
||||
quantity=data.quantity,
|
||||
remark=data.remark,
|
||||
)
|
||||
|
||||
# 获取物品名称
|
||||
item_name = None
|
||||
unit = None
|
||||
if cost_item.item_type == CostItemType.MATERIAL.value:
|
||||
material_result = await db.execute(
|
||||
select(Material).where(Material.id == cost_item.item_id)
|
||||
)
|
||||
material = material_result.scalar_one_or_none()
|
||||
if material:
|
||||
item_name = material.material_name
|
||||
unit = material.unit
|
||||
else:
|
||||
equipment_result = await db.execute(
|
||||
select(Equipment).where(Equipment.id == cost_item.item_id)
|
||||
)
|
||||
equipment = equipment_result.scalar_one_or_none()
|
||||
if equipment:
|
||||
item_name = equipment.equipment_name
|
||||
unit = "次"
|
||||
|
||||
return ResponseModel(
|
||||
message="更新成功",
|
||||
data=CostItemResponse(
|
||||
id=cost_item.id,
|
||||
item_type=CostItemType(cost_item.item_type),
|
||||
item_id=cost_item.item_id,
|
||||
item_name=item_name,
|
||||
quantity=float(cost_item.quantity),
|
||||
unit=unit,
|
||||
unit_cost=float(cost_item.unit_cost),
|
||||
total_cost=float(cost_item.total_cost),
|
||||
remark=cost_item.remark,
|
||||
created_at=cost_item.created_at,
|
||||
updated_at=cost_item.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{project_id}/cost-items/{item_id}", response_model=ResponseModel)
|
||||
async def delete_cost_item(
|
||||
project_id: int,
|
||||
item_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除成本明细"""
|
||||
result = await db.execute(
|
||||
select(ProjectCostItem).where(
|
||||
ProjectCostItem.id == item_id,
|
||||
ProjectCostItem.project_id == project_id,
|
||||
)
|
||||
)
|
||||
cost_item = result.scalar_one_or_none()
|
||||
|
||||
if not cost_item:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "成本明细不存在"}
|
||||
)
|
||||
|
||||
await db.delete(cost_item)
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
|
||||
|
||||
# ============ 人工成本管理 ============
|
||||
|
||||
@router.get("/{project_id}/labor-costs", response_model=ResponseModel[list[LaborCostResponse]])
|
||||
async def get_labor_costs(
|
||||
project_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取项目人工成本"""
|
||||
# 检查项目是否存在
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
if not project_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
select(ProjectLaborCost).options(
|
||||
selectinload(ProjectLaborCost.staff_level)
|
||||
).where(
|
||||
ProjectLaborCost.project_id == project_id
|
||||
).order_by(ProjectLaborCost.id)
|
||||
)
|
||||
items = result.scalars().all()
|
||||
|
||||
response_items = []
|
||||
for item in items:
|
||||
response_items.append(LaborCostResponse(
|
||||
id=item.id,
|
||||
staff_level_id=item.staff_level_id,
|
||||
level_name=item.staff_level.level_name if item.staff_level else None,
|
||||
duration_minutes=item.duration_minutes,
|
||||
hourly_rate=float(item.hourly_rate),
|
||||
labor_cost=float(item.labor_cost),
|
||||
remark=item.remark,
|
||||
created_at=item.created_at,
|
||||
updated_at=item.updated_at,
|
||||
))
|
||||
|
||||
return ResponseModel(data=response_items)
|
||||
|
||||
|
||||
@router.post("/{project_id}/labor-costs", response_model=ResponseModel[LaborCostResponse])
|
||||
async def create_labor_cost(
|
||||
project_id: int,
|
||||
data: LaborCostCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""添加人工成本"""
|
||||
# 检查项目是否存在
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
if not project_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
|
||||
)
|
||||
|
||||
cost_service = CostService(db)
|
||||
|
||||
try:
|
||||
labor_cost = await cost_service.add_labor_cost(
|
||||
project_id=project_id,
|
||||
staff_level_id=data.staff_level_id,
|
||||
duration_minutes=data.duration_minutes,
|
||||
remark=data.remark,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": str(e)}
|
||||
)
|
||||
|
||||
# 获取级别名称
|
||||
level_result = await db.execute(
|
||||
select(StaffLevel).where(StaffLevel.id == data.staff_level_id)
|
||||
)
|
||||
staff_level = level_result.scalar_one_or_none()
|
||||
|
||||
return ResponseModel(
|
||||
message="添加成功",
|
||||
data=LaborCostResponse(
|
||||
id=labor_cost.id,
|
||||
staff_level_id=labor_cost.staff_level_id,
|
||||
level_name=staff_level.level_name if staff_level else None,
|
||||
duration_minutes=labor_cost.duration_minutes,
|
||||
hourly_rate=float(labor_cost.hourly_rate),
|
||||
labor_cost=float(labor_cost.labor_cost),
|
||||
remark=labor_cost.remark,
|
||||
created_at=labor_cost.created_at,
|
||||
updated_at=labor_cost.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{project_id}/labor-costs/{item_id}", response_model=ResponseModel[LaborCostResponse])
|
||||
async def update_labor_cost(
|
||||
project_id: int,
|
||||
item_id: int,
|
||||
data: LaborCostUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新人工成本"""
|
||||
result = await db.execute(
|
||||
select(ProjectLaborCost).where(
|
||||
ProjectLaborCost.id == item_id,
|
||||
ProjectLaborCost.project_id == project_id,
|
||||
)
|
||||
)
|
||||
labor_item = result.scalar_one_or_none()
|
||||
|
||||
if not labor_item:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "人工成本记录不存在"}
|
||||
)
|
||||
|
||||
cost_service = CostService(db)
|
||||
|
||||
try:
|
||||
labor_item = await cost_service.update_labor_cost(
|
||||
labor_item=labor_item,
|
||||
staff_level_id=data.staff_level_id,
|
||||
duration_minutes=data.duration_minutes,
|
||||
remark=data.remark,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": str(e)}
|
||||
)
|
||||
|
||||
# 获取级别名称
|
||||
level_result = await db.execute(
|
||||
select(StaffLevel).where(StaffLevel.id == labor_item.staff_level_id)
|
||||
)
|
||||
staff_level = level_result.scalar_one_or_none()
|
||||
|
||||
return ResponseModel(
|
||||
message="更新成功",
|
||||
data=LaborCostResponse(
|
||||
id=labor_item.id,
|
||||
staff_level_id=labor_item.staff_level_id,
|
||||
level_name=staff_level.level_name if staff_level else None,
|
||||
duration_minutes=labor_item.duration_minutes,
|
||||
hourly_rate=float(labor_item.hourly_rate),
|
||||
labor_cost=float(labor_item.labor_cost),
|
||||
remark=labor_item.remark,
|
||||
created_at=labor_item.created_at,
|
||||
updated_at=labor_item.updated_at,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{project_id}/labor-costs/{item_id}", response_model=ResponseModel)
|
||||
async def delete_labor_cost(
|
||||
project_id: int,
|
||||
item_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除人工成本"""
|
||||
result = await db.execute(
|
||||
select(ProjectLaborCost).where(
|
||||
ProjectLaborCost.id == item_id,
|
||||
ProjectLaborCost.project_id == project_id,
|
||||
)
|
||||
)
|
||||
labor_item = result.scalar_one_or_none()
|
||||
|
||||
if not labor_item:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "人工成本记录不存在"}
|
||||
)
|
||||
|
||||
await db.delete(labor_item)
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
|
||||
|
||||
# ============ 成本计算 ============
|
||||
|
||||
@router.post("/{project_id}/calculate-cost", response_model=ResponseModel[CostCalculationResult])
|
||||
async def calculate_cost(
|
||||
project_id: int,
|
||||
data: CalculateCostRequest = CalculateCostRequest(),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""计算项目总成本"""
|
||||
cost_service = CostService(db)
|
||||
|
||||
try:
|
||||
result = await cost_service.calculate_project_cost(
|
||||
project_id=project_id,
|
||||
allocation_method=data.fixed_cost_allocation_method,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": str(e)}
|
||||
)
|
||||
|
||||
return ResponseModel(message="计算完成", data=result)
|
||||
|
||||
|
||||
@router.get("/{project_id}/cost-summary", response_model=ResponseModel[CostSummaryResponse])
|
||||
async def get_cost_summary(
|
||||
project_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取成本汇总"""
|
||||
# 检查项目是否存在
|
||||
project_result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
if not project_result.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
select(ProjectCostSummary).where(
|
||||
ProjectCostSummary.project_id == project_id
|
||||
)
|
||||
)
|
||||
summary = result.scalar_one_or_none()
|
||||
|
||||
if not summary:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "成本汇总不存在,请先计算成本"}
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data=CostSummaryResponse(
|
||||
project_id=summary.project_id,
|
||||
material_cost=float(summary.material_cost),
|
||||
equipment_cost=float(summary.equipment_cost),
|
||||
labor_cost=float(summary.labor_cost),
|
||||
fixed_cost_allocation=float(summary.fixed_cost_allocation),
|
||||
total_cost=float(summary.total_cost),
|
||||
calculated_at=summary.calculated_at,
|
||||
)
|
||||
)
|
||||
172
后端服务/app/routers/staff_levels.py
Normal file
172
后端服务/app/routers/staff_levels.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""人员级别路由
|
||||
|
||||
实现人员级别的 CRUD 操作
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.staff_level import StaffLevel
|
||||
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
|
||||
from app.schemas.staff_level import (
|
||||
StaffLevelCreate,
|
||||
StaffLevelUpdate,
|
||||
StaffLevelResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel[PaginatedData[StaffLevelResponse]])
|
||||
async def get_staff_levels(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
keyword: Optional[str] = Query(None, description="关键词搜索"),
|
||||
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取人员级别列表"""
|
||||
query = select(StaffLevel)
|
||||
|
||||
if keyword:
|
||||
query = query.where(
|
||||
StaffLevel.level_name.contains(keyword) |
|
||||
StaffLevel.level_code.contains(keyword)
|
||||
)
|
||||
if is_active is not None:
|
||||
query = query.where(StaffLevel.is_active == is_active)
|
||||
|
||||
query = query.order_by(StaffLevel.hourly_rate)
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size)
|
||||
|
||||
result = await db.execute(query)
|
||||
staff_levels = result.scalars().all()
|
||||
|
||||
# 统计总数
|
||||
count_query = select(func.count(StaffLevel.id))
|
||||
if keyword:
|
||||
count_query = count_query.where(
|
||||
StaffLevel.level_name.contains(keyword) |
|
||||
StaffLevel.level_code.contains(keyword)
|
||||
)
|
||||
if is_active is not None:
|
||||
count_query = count_query.where(StaffLevel.is_active == is_active)
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
return ResponseModel(
|
||||
data=PaginatedData(
|
||||
items=[StaffLevelResponse.model_validate(s) for s in staff_levels],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{staff_level_id}", response_model=ResponseModel[StaffLevelResponse])
|
||||
async def get_staff_level(
|
||||
staff_level_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取单个人员级别详情"""
|
||||
result = await db.execute(select(StaffLevel).where(StaffLevel.id == staff_level_id))
|
||||
staff_level = result.scalar_one_or_none()
|
||||
|
||||
if not staff_level:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "人员级别不存在"}
|
||||
)
|
||||
|
||||
return ResponseModel(data=StaffLevelResponse.model_validate(staff_level))
|
||||
|
||||
|
||||
@router.post("", response_model=ResponseModel[StaffLevelResponse])
|
||||
async def create_staff_level(
|
||||
data: StaffLevelCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""创建人员级别"""
|
||||
# 检查编码是否已存在
|
||||
existing = await db.execute(
|
||||
select(StaffLevel).where(StaffLevel.level_code == data.level_code)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "级别编码已存在"}
|
||||
)
|
||||
|
||||
staff_level = StaffLevel(**data.model_dump())
|
||||
db.add(staff_level)
|
||||
await db.flush()
|
||||
await db.refresh(staff_level)
|
||||
|
||||
return ResponseModel(message="创建成功", data=StaffLevelResponse.model_validate(staff_level))
|
||||
|
||||
|
||||
@router.put("/{staff_level_id}", response_model=ResponseModel[StaffLevelResponse])
|
||||
async def update_staff_level(
|
||||
staff_level_id: int,
|
||||
data: StaffLevelUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""更新人员级别"""
|
||||
result = await db.execute(select(StaffLevel).where(StaffLevel.id == staff_level_id))
|
||||
staff_level = result.scalar_one_or_none()
|
||||
|
||||
if not staff_level:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "人员级别不存在"}
|
||||
)
|
||||
|
||||
# 检查编码是否重复
|
||||
if data.level_code and data.level_code != staff_level.level_code:
|
||||
existing = await db.execute(
|
||||
select(StaffLevel).where(StaffLevel.level_code == data.level_code)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "级别编码已存在"}
|
||||
)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(staff_level, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(staff_level)
|
||||
|
||||
return ResponseModel(message="更新成功", data=StaffLevelResponse.model_validate(staff_level))
|
||||
|
||||
|
||||
@router.delete("/{staff_level_id}", response_model=ResponseModel)
|
||||
async def delete_staff_level(
|
||||
staff_level_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除人员级别"""
|
||||
result = await db.execute(select(StaffLevel).where(StaffLevel.id == staff_level_id))
|
||||
staff_level = result.scalar_one_or_none()
|
||||
|
||||
if not staff_level:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"code": ErrorCode.NOT_FOUND, "message": "人员级别不存在"}
|
||||
)
|
||||
|
||||
await db.delete(staff_level)
|
||||
|
||||
return ResponseModel(message="删除成功")
|
||||
131
后端服务/app/schemas/__init__.py
Normal file
131
后端服务/app/schemas/__init__.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Pydantic 数据模型"""
|
||||
|
||||
from app.schemas.common import ResponseModel, PaginatedResponse
|
||||
from app.schemas.category import CategoryCreate, CategoryUpdate, CategoryResponse
|
||||
from app.schemas.material import MaterialCreate, MaterialUpdate, MaterialResponse
|
||||
from app.schemas.equipment import EquipmentCreate, EquipmentUpdate, EquipmentResponse
|
||||
from app.schemas.staff_level import StaffLevelCreate, StaffLevelUpdate, StaffLevelResponse
|
||||
from app.schemas.fixed_cost import FixedCostCreate, FixedCostUpdate, FixedCostResponse
|
||||
from app.schemas.project import (
|
||||
ProjectCreate, ProjectUpdate, ProjectResponse,
|
||||
ProjectListResponse, ProjectQuery
|
||||
)
|
||||
from app.schemas.project_cost import (
|
||||
CostItemCreate, CostItemUpdate, CostItemResponse,
|
||||
LaborCostCreate, LaborCostUpdate, LaborCostResponse,
|
||||
CalculateCostRequest, CostCalculationResult, CostSummaryResponse,
|
||||
ProjectDetailResponse, CostItemType, AllocationMethod
|
||||
)
|
||||
from app.schemas.competitor import (
|
||||
CompetitorCreate, CompetitorUpdate, CompetitorResponse,
|
||||
CompetitorPriceCreate, CompetitorPriceUpdate, CompetitorPriceResponse,
|
||||
Positioning, PriceSource
|
||||
)
|
||||
from app.schemas.market import (
|
||||
BenchmarkPriceCreate, BenchmarkPriceUpdate, BenchmarkPriceResponse,
|
||||
MarketAnalysisRequest, MarketAnalysisResult, MarketAnalysisResponse,
|
||||
PriceTier
|
||||
)
|
||||
from app.schemas.pricing import (
|
||||
PricingPlanCreate, PricingPlanUpdate, PricingPlanResponse,
|
||||
PricingPlanListResponse, PricingPlanQuery,
|
||||
GeneratePricingRequest, GeneratePricingResponse,
|
||||
SimulateStrategyRequest, SimulateStrategyResponse,
|
||||
StrategyType
|
||||
)
|
||||
from app.schemas.profit import (
|
||||
ProfitSimulationCreate, ProfitSimulationResponse,
|
||||
ProfitSimulationListResponse, ProfitSimulationQuery,
|
||||
SimulateProfitRequest, SimulateProfitResponse,
|
||||
SensitivityAnalysisRequest, SensitivityAnalysisResponse,
|
||||
BreakevenRequest, BreakevenResponse,
|
||||
PeriodType
|
||||
)
|
||||
from app.schemas.dashboard import (
|
||||
DashboardSummaryResponse, CostTrendResponse, MarketTrendResponse,
|
||||
AIUsageStatsResponse
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ResponseModel",
|
||||
"PaginatedResponse",
|
||||
"CategoryCreate",
|
||||
"CategoryUpdate",
|
||||
"CategoryResponse",
|
||||
"MaterialCreate",
|
||||
"MaterialUpdate",
|
||||
"MaterialResponse",
|
||||
"EquipmentCreate",
|
||||
"EquipmentUpdate",
|
||||
"EquipmentResponse",
|
||||
"StaffLevelCreate",
|
||||
"StaffLevelUpdate",
|
||||
"StaffLevelResponse",
|
||||
"FixedCostCreate",
|
||||
"FixedCostUpdate",
|
||||
"FixedCostResponse",
|
||||
# Project schemas
|
||||
"ProjectCreate",
|
||||
"ProjectUpdate",
|
||||
"ProjectResponse",
|
||||
"ProjectListResponse",
|
||||
"ProjectQuery",
|
||||
# Project cost schemas
|
||||
"CostItemCreate",
|
||||
"CostItemUpdate",
|
||||
"CostItemResponse",
|
||||
"LaborCostCreate",
|
||||
"LaborCostUpdate",
|
||||
"LaborCostResponse",
|
||||
"CalculateCostRequest",
|
||||
"CostCalculationResult",
|
||||
"CostSummaryResponse",
|
||||
"ProjectDetailResponse",
|
||||
"CostItemType",
|
||||
"AllocationMethod",
|
||||
# Competitor schemas
|
||||
"CompetitorCreate",
|
||||
"CompetitorUpdate",
|
||||
"CompetitorResponse",
|
||||
"CompetitorPriceCreate",
|
||||
"CompetitorPriceUpdate",
|
||||
"CompetitorPriceResponse",
|
||||
"Positioning",
|
||||
"PriceSource",
|
||||
# Market schemas
|
||||
"BenchmarkPriceCreate",
|
||||
"BenchmarkPriceUpdate",
|
||||
"BenchmarkPriceResponse",
|
||||
"MarketAnalysisRequest",
|
||||
"MarketAnalysisResult",
|
||||
"MarketAnalysisResponse",
|
||||
"PriceTier",
|
||||
# Pricing schemas
|
||||
"PricingPlanCreate",
|
||||
"PricingPlanUpdate",
|
||||
"PricingPlanResponse",
|
||||
"PricingPlanListResponse",
|
||||
"PricingPlanQuery",
|
||||
"GeneratePricingRequest",
|
||||
"GeneratePricingResponse",
|
||||
"SimulateStrategyRequest",
|
||||
"SimulateStrategyResponse",
|
||||
"StrategyType",
|
||||
# Profit simulation schemas
|
||||
"ProfitSimulationCreate",
|
||||
"ProfitSimulationResponse",
|
||||
"ProfitSimulationListResponse",
|
||||
"ProfitSimulationQuery",
|
||||
"SimulateProfitRequest",
|
||||
"SimulateProfitResponse",
|
||||
"SensitivityAnalysisRequest",
|
||||
"SensitivityAnalysisResponse",
|
||||
"BreakevenRequest",
|
||||
"BreakevenResponse",
|
||||
"PeriodType",
|
||||
# Dashboard schemas
|
||||
"DashboardSummaryResponse",
|
||||
"CostTrendResponse",
|
||||
"MarketTrendResponse",
|
||||
"AIUsageStatsResponse",
|
||||
]
|
||||
53
后端服务/app/schemas/category.py
Normal file
53
后端服务/app/schemas/category.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""项目分类 Schema"""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CategoryBase(BaseModel):
|
||||
"""分类基础字段"""
|
||||
|
||||
category_name: str = Field(..., min_length=1, max_length=50, description="分类名称")
|
||||
parent_id: Optional[int] = Field(None, description="父分类ID")
|
||||
sort_order: int = Field(0, ge=0, description="排序")
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
|
||||
|
||||
class CategoryCreate(CategoryBase):
|
||||
"""创建分类请求"""
|
||||
pass
|
||||
|
||||
|
||||
class CategoryUpdate(BaseModel):
|
||||
"""更新分类请求"""
|
||||
|
||||
category_name: Optional[str] = Field(None, min_length=1, max_length=50, description="分类名称")
|
||||
parent_id: Optional[int] = Field(None, description="父分类ID")
|
||||
sort_order: Optional[int] = Field(None, ge=0, description="排序")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class CategoryResponse(CategoryBase):
|
||||
"""分类响应"""
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CategoryTreeResponse(CategoryResponse):
|
||||
"""分类树形响应"""
|
||||
|
||||
children: List["CategoryTreeResponse"] = []
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# 解决循环引用
|
||||
CategoryTreeResponse.model_rebuild()
|
||||
60
后端服务/app/schemas/common.py
Normal file
60
后端服务/app/schemas/common.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""通用响应模型
|
||||
|
||||
遵循瑞小美 API 响应格式规范
|
||||
"""
|
||||
|
||||
from typing import Generic, TypeVar, Optional, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ResponseModel(BaseModel, Generic[T]):
|
||||
"""统一响应格式"""
|
||||
|
||||
code: int = 0
|
||||
message: str = "success"
|
||||
data: Optional[T] = None
|
||||
|
||||
|
||||
class PaginatedData(BaseModel, Generic[T]):
|
||||
"""分页数据"""
|
||||
|
||||
items: List[T]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
"""分页响应"""
|
||||
|
||||
code: int = 0
|
||||
message: str = "success"
|
||||
data: Optional[PaginatedData[T]] = None
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""错误响应"""
|
||||
|
||||
code: int
|
||||
message: str
|
||||
data: None = None
|
||||
|
||||
|
||||
# 错误码定义
|
||||
class ErrorCode:
|
||||
SUCCESS = 0
|
||||
PARAM_ERROR = 10001
|
||||
NOT_FOUND = 10002
|
||||
ALREADY_EXISTS = 10003
|
||||
NOT_ALLOWED = 10004
|
||||
AUTH_FAILED = 20001
|
||||
PERMISSION_DENIED = 20002
|
||||
TOKEN_EXPIRED = 20003
|
||||
INTERNAL_ERROR = 30001
|
||||
SERVICE_UNAVAILABLE = 30002
|
||||
AI_SERVICE_ERROR = 40001
|
||||
AI_SERVICE_TIMEOUT = 40002
|
||||
114
后端服务/app/schemas/competitor.py
Normal file
114
后端服务/app/schemas/competitor.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""竞品机构 Schema"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime, date
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Positioning(str, Enum):
|
||||
"""机构定位枚举"""
|
||||
|
||||
HIGH = "high" # 高端
|
||||
MEDIUM = "medium" # 中端
|
||||
BUDGET = "budget" # 大众
|
||||
|
||||
|
||||
class PriceSource(str, Enum):
|
||||
"""价格来源枚举"""
|
||||
|
||||
OFFICIAL = "official" # 官网
|
||||
MEITUAN = "meituan" # 美团
|
||||
DIANPING = "dianping" # 大众点评
|
||||
SURVEY = "survey" # 实地调研
|
||||
|
||||
|
||||
# ============ 竞品机构 Schema ============
|
||||
|
||||
class CompetitorBase(BaseModel):
|
||||
"""竞品机构基础字段"""
|
||||
|
||||
competitor_name: str = Field(..., min_length=1, max_length=100, description="机构名称")
|
||||
address: Optional[str] = Field(None, max_length=200, description="地址")
|
||||
distance_km: Optional[float] = Field(None, ge=0, description="距离(公里)")
|
||||
positioning: Positioning = Field(Positioning.MEDIUM, description="定位")
|
||||
contact: Optional[str] = Field(None, max_length=50, description="联系方式")
|
||||
is_key_competitor: bool = Field(False, description="是否重点关注")
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
|
||||
|
||||
class CompetitorCreate(CompetitorBase):
|
||||
"""创建竞品机构请求"""
|
||||
pass
|
||||
|
||||
|
||||
class CompetitorUpdate(BaseModel):
|
||||
"""更新竞品机构请求"""
|
||||
|
||||
competitor_name: Optional[str] = Field(None, min_length=1, max_length=100, description="机构名称")
|
||||
address: Optional[str] = Field(None, max_length=200, description="地址")
|
||||
distance_km: Optional[float] = Field(None, ge=0, description="距离(公里)")
|
||||
positioning: Optional[Positioning] = Field(None, description="定位")
|
||||
contact: Optional[str] = Field(None, max_length=50, description="联系方式")
|
||||
is_key_competitor: Optional[bool] = Field(None, description="是否重点关注")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class CompetitorResponse(CompetitorBase):
|
||||
"""竞品机构响应"""
|
||||
|
||||
id: int
|
||||
price_count: int = Field(0, description="价格记录数")
|
||||
last_price_update: Optional[date] = Field(None, description="最后价格更新日期")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ============ 竞品价格 Schema ============
|
||||
|
||||
class CompetitorPriceBase(BaseModel):
|
||||
"""竞品价格基础字段"""
|
||||
|
||||
project_id: Optional[int] = Field(None, description="关联本店项目ID")
|
||||
project_name: str = Field(..., min_length=1, max_length=100, description="竞品项目名称")
|
||||
original_price: float = Field(..., gt=0, description="原价")
|
||||
promo_price: Optional[float] = Field(None, gt=0, description="促销价")
|
||||
member_price: Optional[float] = Field(None, gt=0, description="会员价")
|
||||
price_source: PriceSource = Field(..., description="价格来源")
|
||||
collected_at: date = Field(..., description="采集日期")
|
||||
remark: Optional[str] = Field(None, max_length=200, description="备注")
|
||||
|
||||
|
||||
class CompetitorPriceCreate(CompetitorPriceBase):
|
||||
"""创建竞品价格请求"""
|
||||
pass
|
||||
|
||||
|
||||
class CompetitorPriceUpdate(BaseModel):
|
||||
"""更新竞品价格请求"""
|
||||
|
||||
project_id: Optional[int] = Field(None, description="关联本店项目ID")
|
||||
project_name: Optional[str] = Field(None, min_length=1, max_length=100, description="竞品项目名称")
|
||||
original_price: Optional[float] = Field(None, gt=0, description="原价")
|
||||
promo_price: Optional[float] = Field(None, gt=0, description="促销价")
|
||||
member_price: Optional[float] = Field(None, gt=0, description="会员价")
|
||||
price_source: Optional[PriceSource] = Field(None, description="价格来源")
|
||||
collected_at: Optional[date] = Field(None, description="采集日期")
|
||||
remark: Optional[str] = Field(None, max_length=200, description="备注")
|
||||
|
||||
|
||||
class CompetitorPriceResponse(CompetitorPriceBase):
|
||||
"""竞品价格响应"""
|
||||
|
||||
id: int
|
||||
competitor_id: int
|
||||
competitor_name: Optional[str] = Field(None, description="竞品机构名称")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
142
后端服务/app/schemas/dashboard.py
Normal file
142
后端服务/app/schemas/dashboard.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""仪表盘 Schema
|
||||
|
||||
仪表盘数据相关的响应模型
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ProjectOverview(BaseModel):
|
||||
"""项目概览"""
|
||||
|
||||
total_projects: int = Field(..., description="总项目数")
|
||||
active_projects: int = Field(..., description="启用项目数")
|
||||
projects_with_pricing: int = Field(..., description="已定价项目数")
|
||||
|
||||
|
||||
class CostProjectInfo(BaseModel):
|
||||
"""成本项目信息"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
cost: float
|
||||
|
||||
|
||||
class CostOverview(BaseModel):
|
||||
"""成本概览"""
|
||||
|
||||
avg_project_cost: float = Field(..., description="平均项目成本")
|
||||
highest_cost_project: Optional[CostProjectInfo] = Field(None, description="最高成本项目")
|
||||
lowest_cost_project: Optional[CostProjectInfo] = Field(None, description="最低成本项目")
|
||||
|
||||
|
||||
class MarketOverview(BaseModel):
|
||||
"""市场概览"""
|
||||
|
||||
competitors_tracked: int = Field(..., description="跟踪竞品数")
|
||||
price_records_this_month: int = Field(..., description="本月价格记录数")
|
||||
avg_market_price: Optional[float] = Field(None, description="市场平均价")
|
||||
|
||||
|
||||
class StrategiesDistribution(BaseModel):
|
||||
"""策略分布"""
|
||||
|
||||
traffic: int = Field(0, description="引流款数量")
|
||||
profit: int = Field(0, description="利润款数量")
|
||||
premium: int = Field(0, description="高端款数量")
|
||||
|
||||
|
||||
class PricingOverview(BaseModel):
|
||||
"""定价概览"""
|
||||
|
||||
pricing_plans_count: int = Field(..., description="定价方案总数")
|
||||
avg_target_margin: Optional[float] = Field(None, description="平均目标毛利率")
|
||||
strategies_distribution: StrategiesDistribution = Field(..., description="策略分布")
|
||||
|
||||
|
||||
class ProviderDistribution(BaseModel):
|
||||
"""服务商分布"""
|
||||
|
||||
primary: int = Field(0, alias="4sapi", description="4sapi 调用次数")
|
||||
fallback: int = Field(0, alias="openrouter", description="OpenRouter 调用次数")
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
|
||||
|
||||
class AIUsageOverview(BaseModel):
|
||||
"""AI 使用概览"""
|
||||
|
||||
total_calls: int = Field(..., description="总调用次数")
|
||||
total_tokens: int = Field(..., description="总 Token 消耗")
|
||||
total_cost_usd: float = Field(..., description="总费用(美元)")
|
||||
provider_distribution: Dict[str, int] = Field(..., description="服务商分布")
|
||||
|
||||
|
||||
class RecentActivity(BaseModel):
|
||||
"""最近活动"""
|
||||
|
||||
type: str = Field(..., description="活动类型")
|
||||
project_name: str = Field(..., description="项目名称")
|
||||
user: Optional[str] = Field(None, description="用户")
|
||||
time: datetime = Field(..., description="时间")
|
||||
|
||||
|
||||
class DashboardSummaryResponse(BaseModel):
|
||||
"""仪表盘概览响应"""
|
||||
|
||||
project_overview: ProjectOverview
|
||||
cost_overview: CostOverview
|
||||
market_overview: MarketOverview
|
||||
pricing_overview: PricingOverview
|
||||
ai_usage_this_month: Optional[AIUsageOverview] = None
|
||||
recent_activities: List[RecentActivity] = Field(default_factory=list)
|
||||
|
||||
|
||||
# 趋势数据
|
||||
|
||||
class TrendDataPoint(BaseModel):
|
||||
"""趋势数据点"""
|
||||
|
||||
date: str = Field(..., description="日期")
|
||||
value: float = Field(..., description="值")
|
||||
|
||||
|
||||
class CostTrendResponse(BaseModel):
|
||||
"""成本趋势响应"""
|
||||
|
||||
period: str = Field(..., description="统计周期")
|
||||
data: List[TrendDataPoint]
|
||||
avg_cost: float = Field(..., description="平均成本")
|
||||
|
||||
|
||||
class MarketTrendResponse(BaseModel):
|
||||
"""市场趋势响应"""
|
||||
|
||||
period: str = Field(..., description="统计周期")
|
||||
data: List[TrendDataPoint]
|
||||
avg_price: float = Field(..., description="平均价格")
|
||||
|
||||
|
||||
# AI 使用统计
|
||||
|
||||
class AIUsageStatItem(BaseModel):
|
||||
"""AI 使用统计项"""
|
||||
|
||||
date: str
|
||||
calls: int
|
||||
tokens: int
|
||||
cost: float
|
||||
|
||||
|
||||
class AIUsageStatsResponse(BaseModel):
|
||||
"""AI 使用统计响应"""
|
||||
|
||||
period: str = Field(..., description="统计周期")
|
||||
total_calls: int
|
||||
total_tokens: int
|
||||
total_cost: float
|
||||
daily_stats: List[AIUsageStatItem]
|
||||
54
后端服务/app/schemas/equipment.py
Normal file
54
后端服务/app/schemas/equipment.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""设备 Schema"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime, date
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class EquipmentBase(BaseModel):
|
||||
"""设备基础字段"""
|
||||
|
||||
equipment_code: str = Field(..., min_length=1, max_length=50, description="设备编码")
|
||||
equipment_name: str = Field(..., min_length=1, max_length=100, description="设备名称")
|
||||
original_value: float = Field(..., gt=0, description="设备原值")
|
||||
residual_rate: float = Field(5.00, ge=0, le=100, description="残值率(%)")
|
||||
service_years: int = Field(..., gt=0, description="预计使用年限")
|
||||
estimated_uses: int = Field(..., gt=0, description="预计使用次数")
|
||||
purchase_date: Optional[date] = Field(None, description="购入日期")
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
|
||||
|
||||
class EquipmentCreate(EquipmentBase):
|
||||
"""创建设备请求"""
|
||||
|
||||
@property
|
||||
def depreciation_per_use(self) -> float:
|
||||
"""计算单次折旧成本"""
|
||||
residual_value = self.original_value * self.residual_rate / 100
|
||||
return (self.original_value - residual_value) / self.estimated_uses
|
||||
|
||||
|
||||
class EquipmentUpdate(BaseModel):
|
||||
"""更新设备请求"""
|
||||
|
||||
equipment_code: Optional[str] = Field(None, min_length=1, max_length=50, description="设备编码")
|
||||
equipment_name: Optional[str] = Field(None, min_length=1, max_length=100, description="设备名称")
|
||||
original_value: Optional[float] = Field(None, gt=0, description="设备原值")
|
||||
residual_rate: Optional[float] = Field(None, ge=0, le=100, description="残值率(%)")
|
||||
service_years: Optional[int] = Field(None, gt=0, description="预计使用年限")
|
||||
estimated_uses: Optional[int] = Field(None, gt=0, description="预计使用次数")
|
||||
purchase_date: Optional[date] = Field(None, description="购入日期")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class EquipmentResponse(EquipmentBase):
|
||||
"""设备响应"""
|
||||
|
||||
id: int
|
||||
depreciation_per_use: float = Field(..., description="单次折旧成本")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
79
后端服务/app/schemas/fixed_cost.py
Normal file
79
后端服务/app/schemas/fixed_cost.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""固定成本 Schema"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
import re
|
||||
|
||||
|
||||
class CostType(str, Enum):
|
||||
"""成本类型枚举"""
|
||||
|
||||
RENT = "rent" # 房租
|
||||
UTILITIES = "utilities" # 水电
|
||||
PROPERTY = "property" # 物业
|
||||
OTHER = "other" # 其他
|
||||
|
||||
|
||||
class AllocationMethod(str, Enum):
|
||||
"""分摊方式枚举"""
|
||||
|
||||
COUNT = "count" # 按项目数量
|
||||
REVENUE = "revenue" # 按营收占比
|
||||
DURATION = "duration" # 按时长占比
|
||||
|
||||
|
||||
class FixedCostBase(BaseModel):
|
||||
"""固定成本基础字段"""
|
||||
|
||||
cost_name: str = Field(..., min_length=1, max_length=100, description="成本名称")
|
||||
cost_type: CostType = Field(..., description="类型")
|
||||
monthly_amount: float = Field(..., gt=0, description="月度金额")
|
||||
year_month: str = Field(..., description="年月:2026-01")
|
||||
allocation_method: AllocationMethod = Field(AllocationMethod.COUNT, description="分摊方式")
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
|
||||
@field_validator("year_month")
|
||||
@classmethod
|
||||
def validate_year_month(cls, v: str) -> str:
|
||||
"""验证年月格式"""
|
||||
if not re.match(r"^\d{4}-\d{2}$", v):
|
||||
raise ValueError("年月格式必须为 YYYY-MM")
|
||||
return v
|
||||
|
||||
|
||||
class FixedCostCreate(FixedCostBase):
|
||||
"""创建固定成本请求"""
|
||||
pass
|
||||
|
||||
|
||||
class FixedCostUpdate(BaseModel):
|
||||
"""更新固定成本请求"""
|
||||
|
||||
cost_name: Optional[str] = Field(None, min_length=1, max_length=100, description="成本名称")
|
||||
cost_type: Optional[CostType] = Field(None, description="类型")
|
||||
monthly_amount: Optional[float] = Field(None, gt=0, description="月度金额")
|
||||
year_month: Optional[str] = Field(None, description="年月:2026-01")
|
||||
allocation_method: Optional[AllocationMethod] = Field(None, description="分摊方式")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
@field_validator("year_month")
|
||||
@classmethod
|
||||
def validate_year_month(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证年月格式"""
|
||||
if v is not None and not re.match(r"^\d{4}-\d{2}$", v):
|
||||
raise ValueError("年月格式必须为 YYYY-MM")
|
||||
return v
|
||||
|
||||
|
||||
class FixedCostResponse(FixedCostBase):
|
||||
"""固定成本响应"""
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
156
后端服务/app/schemas/market.py
Normal file
156
后端服务/app/schemas/market.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""市场分析 Schema"""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, date
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PriceTier(str, Enum):
|
||||
"""价格带枚举"""
|
||||
|
||||
LOW = "low" # 低端
|
||||
MEDIUM = "medium" # 中端
|
||||
HIGH = "high" # 高端
|
||||
PREMIUM = "premium" # 奢华
|
||||
|
||||
|
||||
# ============ 标杆价格 Schema ============
|
||||
|
||||
class BenchmarkPriceBase(BaseModel):
|
||||
"""标杆价格基础字段"""
|
||||
|
||||
benchmark_name: str = Field(..., min_length=1, max_length=100, description="标杆机构名称")
|
||||
category_id: Optional[int] = Field(None, description="项目分类ID")
|
||||
min_price: float = Field(..., gt=0, description="最低价")
|
||||
max_price: float = Field(..., gt=0, description="最高价")
|
||||
avg_price: float = Field(..., gt=0, description="均价")
|
||||
price_tier: PriceTier = Field(PriceTier.MEDIUM, description="价格带")
|
||||
effective_date: date = Field(..., description="生效日期")
|
||||
remark: Optional[str] = Field(None, max_length=200, description="备注")
|
||||
|
||||
|
||||
class BenchmarkPriceCreate(BenchmarkPriceBase):
|
||||
"""创建标杆价格请求"""
|
||||
pass
|
||||
|
||||
|
||||
class BenchmarkPriceUpdate(BaseModel):
|
||||
"""更新标杆价格请求"""
|
||||
|
||||
benchmark_name: Optional[str] = Field(None, min_length=1, max_length=100, description="标杆机构名称")
|
||||
category_id: Optional[int] = Field(None, description="项目分类ID")
|
||||
min_price: Optional[float] = Field(None, gt=0, description="最低价")
|
||||
max_price: Optional[float] = Field(None, gt=0, description="最高价")
|
||||
avg_price: Optional[float] = Field(None, gt=0, description="均价")
|
||||
price_tier: Optional[PriceTier] = Field(None, description="价格带")
|
||||
effective_date: Optional[date] = Field(None, description="生效日期")
|
||||
remark: Optional[str] = Field(None, max_length=200, description="备注")
|
||||
|
||||
|
||||
class BenchmarkPriceResponse(BenchmarkPriceBase):
|
||||
"""标杆价格响应"""
|
||||
|
||||
id: int
|
||||
category_name: Optional[str] = Field(None, description="分类名称")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ============ 市场分析 Schema ============
|
||||
|
||||
class MarketAnalysisRequest(BaseModel):
|
||||
"""市场分析请求"""
|
||||
|
||||
competitor_ids: Optional[List[int]] = Field(None, description="指定竞品机构ID列表")
|
||||
include_benchmark: bool = Field(True, description="是否包含标杆价格参考")
|
||||
|
||||
|
||||
class PriceStatistics(BaseModel):
|
||||
"""价格统计"""
|
||||
|
||||
min_price: float
|
||||
max_price: float
|
||||
avg_price: float
|
||||
median_price: float
|
||||
std_deviation: Optional[float] = None
|
||||
|
||||
|
||||
class PriceDistributionItem(BaseModel):
|
||||
"""价格分布项"""
|
||||
|
||||
range: str
|
||||
count: int
|
||||
percentage: float
|
||||
|
||||
|
||||
class PriceDistribution(BaseModel):
|
||||
"""价格分布"""
|
||||
|
||||
low: PriceDistributionItem
|
||||
medium: PriceDistributionItem
|
||||
high: PriceDistributionItem
|
||||
|
||||
|
||||
class CompetitorPriceSummary(BaseModel):
|
||||
"""竞品价格摘要"""
|
||||
|
||||
competitor_name: str
|
||||
positioning: str
|
||||
original_price: float
|
||||
promo_price: Optional[float] = None
|
||||
collected_at: date
|
||||
|
||||
|
||||
class BenchmarkReference(BaseModel):
|
||||
"""标杆参考"""
|
||||
|
||||
tier: str
|
||||
min_price: float
|
||||
max_price: float
|
||||
avg_price: float
|
||||
|
||||
|
||||
class SuggestedRange(BaseModel):
|
||||
"""建议定价区间"""
|
||||
|
||||
min: float
|
||||
max: float
|
||||
recommended: float
|
||||
|
||||
|
||||
class MarketAnalysisResult(BaseModel):
|
||||
"""市场分析结果"""
|
||||
|
||||
project_id: int
|
||||
project_name: str
|
||||
analysis_date: date
|
||||
competitor_count: int
|
||||
price_statistics: PriceStatistics
|
||||
price_distribution: Optional[PriceDistribution] = None
|
||||
competitor_prices: List[CompetitorPriceSummary] = []
|
||||
benchmark_reference: Optional[BenchmarkReference] = None
|
||||
suggested_range: SuggestedRange
|
||||
|
||||
|
||||
class MarketAnalysisResponse(BaseModel):
|
||||
"""市场分析响应(数据库记录)"""
|
||||
|
||||
id: int
|
||||
project_id: int
|
||||
analysis_date: date
|
||||
competitor_count: int
|
||||
market_min_price: float
|
||||
market_max_price: float
|
||||
market_avg_price: float
|
||||
market_median_price: float
|
||||
suggested_range_min: float
|
||||
suggested_range_max: float
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
64
后端服务/app/schemas/material.py
Normal file
64
后端服务/app/schemas/material.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""耗材 Schema"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MaterialType(str, Enum):
|
||||
"""耗材类型枚举"""
|
||||
|
||||
CONSUMABLE = "consumable" # 一般耗材
|
||||
INJECTABLE = "injectable" # 针剂
|
||||
PRODUCT = "product" # 产品
|
||||
|
||||
|
||||
class MaterialBase(BaseModel):
|
||||
"""耗材基础字段"""
|
||||
|
||||
material_code: str = Field(..., min_length=1, max_length=50, description="耗材编码")
|
||||
material_name: str = Field(..., min_length=1, max_length=100, description="耗材名称")
|
||||
unit: str = Field(..., min_length=1, max_length=20, description="单位")
|
||||
unit_price: float = Field(..., ge=0, description="单价")
|
||||
supplier: Optional[str] = Field(None, max_length=100, description="供应商")
|
||||
material_type: MaterialType = Field(..., description="类型")
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
|
||||
|
||||
class MaterialCreate(MaterialBase):
|
||||
"""创建耗材请求"""
|
||||
pass
|
||||
|
||||
|
||||
class MaterialUpdate(BaseModel):
|
||||
"""更新耗材请求"""
|
||||
|
||||
material_code: Optional[str] = Field(None, min_length=1, max_length=50, description="耗材编码")
|
||||
material_name: Optional[str] = Field(None, min_length=1, max_length=100, description="耗材名称")
|
||||
unit: Optional[str] = Field(None, min_length=1, max_length=20, description="单位")
|
||||
unit_price: Optional[float] = Field(None, ge=0, description="单价")
|
||||
supplier: Optional[str] = Field(None, max_length=100, description="供应商")
|
||||
material_type: Optional[MaterialType] = Field(None, description="类型")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class MaterialResponse(MaterialBase):
|
||||
"""耗材响应"""
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class MaterialImportResult(BaseModel):
|
||||
"""批量导入结果"""
|
||||
|
||||
total: int = Field(..., description="总数")
|
||||
success: int = Field(..., description="成功数")
|
||||
failed: int = Field(..., description="失败数")
|
||||
errors: list[dict] = Field(default_factory=list, description="错误详情")
|
||||
194
后端服务/app/schemas/pricing.py
Normal file
194
后端服务/app/schemas/pricing.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""定价方案 Schema
|
||||
|
||||
智能定价建议相关的请求和响应模型
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class StrategyType(str, Enum):
|
||||
"""定价策略类型"""
|
||||
TRAFFIC = "traffic" # 引流款
|
||||
PROFIT = "profit" # 利润款
|
||||
PREMIUM = "premium" # 高端款
|
||||
|
||||
|
||||
class PricingPlanBase(BaseModel):
|
||||
"""定价方案基础字段"""
|
||||
|
||||
project_id: int = Field(..., description="项目ID")
|
||||
plan_name: str = Field(..., min_length=1, max_length=100, description="方案名称")
|
||||
strategy_type: StrategyType = Field(..., description="策略类型")
|
||||
target_margin: float = Field(..., ge=0, le=100, description="目标毛利率(%)")
|
||||
|
||||
|
||||
class PricingPlanCreate(PricingPlanBase):
|
||||
"""创建定价方案请求"""
|
||||
pass
|
||||
|
||||
|
||||
class PricingPlanUpdate(BaseModel):
|
||||
"""更新定价方案请求"""
|
||||
|
||||
plan_name: Optional[str] = Field(None, min_length=1, max_length=100, description="方案名称")
|
||||
strategy_type: Optional[StrategyType] = Field(None, description="策略类型")
|
||||
target_margin: Optional[float] = Field(None, ge=0, le=100, description="目标毛利率(%)")
|
||||
final_price: Optional[float] = Field(None, ge=0, description="最终定价")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class PricingPlanResponse(BaseModel):
|
||||
"""定价方案响应"""
|
||||
|
||||
id: int
|
||||
project_id: int
|
||||
project_name: Optional[str] = Field(None, description="项目名称")
|
||||
plan_name: str
|
||||
strategy_type: str
|
||||
base_cost: float
|
||||
target_margin: float
|
||||
suggested_price: float
|
||||
final_price: Optional[float] = None
|
||||
ai_advice: Optional[str] = None
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
created_by_name: Optional[str] = Field(None, description="创建人姓名")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PricingPlanListResponse(BaseModel):
|
||||
"""定价方案列表响应"""
|
||||
|
||||
id: int
|
||||
project_id: int
|
||||
project_name: Optional[str] = None
|
||||
plan_name: str
|
||||
strategy_type: str
|
||||
base_cost: float
|
||||
target_margin: float
|
||||
suggested_price: float
|
||||
final_price: Optional[float] = None
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
created_by_name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PricingPlanQuery(BaseModel):
|
||||
"""定价方案查询参数"""
|
||||
|
||||
page: int = Field(1, ge=1, description="页码")
|
||||
page_size: int = Field(20, ge=1, le=100, description="每页数量")
|
||||
project_id: Optional[int] = Field(None, description="项目筛选")
|
||||
strategy_type: Optional[StrategyType] = Field(None, description="策略类型筛选")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
sort_by: str = Field("created_at", description="排序字段")
|
||||
sort_order: str = Field("desc", description="排序方向")
|
||||
|
||||
|
||||
# AI 定价建议相关
|
||||
|
||||
class GeneratePricingRequest(BaseModel):
|
||||
"""生成定价建议请求"""
|
||||
|
||||
target_margin: float = Field(50, ge=0, le=100, description="目标毛利率(%),默认50")
|
||||
strategies: Optional[List[StrategyType]] = Field(None, description="策略类型列表,默认全部")
|
||||
stream: bool = Field(False, description="是否流式返回")
|
||||
|
||||
|
||||
class StrategySuggestion(BaseModel):
|
||||
"""单个策略的定价建议"""
|
||||
|
||||
strategy: str = Field(..., description="策略名称")
|
||||
suggested_price: float = Field(..., description="建议价格")
|
||||
margin: float = Field(..., description="毛利率(%)")
|
||||
description: str = Field(..., description="策略说明")
|
||||
|
||||
|
||||
class AIAdvice(BaseModel):
|
||||
"""AI 建议内容"""
|
||||
|
||||
summary: str = Field(..., description="综合建议摘要")
|
||||
cost_analysis: str = Field(..., description="成本分析")
|
||||
market_analysis: str = Field(..., description="市场分析")
|
||||
risk_notes: str = Field(..., description="风险提示")
|
||||
recommendations: List[str] = Field(..., description="具体建议列表")
|
||||
|
||||
|
||||
class AIUsage(BaseModel):
|
||||
"""AI 调用使用情况"""
|
||||
|
||||
provider: str = Field(..., description="服务商")
|
||||
model: str = Field(..., description="模型")
|
||||
tokens: int = Field(..., description="Token 消耗")
|
||||
latency_ms: int = Field(..., description="延迟(毫秒)")
|
||||
|
||||
|
||||
class MarketReference(BaseModel):
|
||||
"""市场参考数据"""
|
||||
|
||||
min: float
|
||||
max: float
|
||||
avg: float
|
||||
|
||||
|
||||
class PricingSuggestions(BaseModel):
|
||||
"""定价建议集合"""
|
||||
|
||||
traffic: Optional[StrategySuggestion] = None
|
||||
profit: Optional[StrategySuggestion] = None
|
||||
premium: Optional[StrategySuggestion] = None
|
||||
|
||||
|
||||
class GeneratePricingResponse(BaseModel):
|
||||
"""生成定价建议响应"""
|
||||
|
||||
project_id: int
|
||||
project_name: str
|
||||
cost_base: float = Field(..., description="基础成本")
|
||||
market_reference: Optional[MarketReference] = Field(None, description="市场参考")
|
||||
pricing_suggestions: PricingSuggestions = Field(..., description="各策略定价建议")
|
||||
ai_advice: Optional[AIAdvice] = Field(None, description="AI 建议详情")
|
||||
ai_usage: Optional[AIUsage] = Field(None, description="AI 使用统计")
|
||||
|
||||
|
||||
class SimulateStrategyRequest(BaseModel):
|
||||
"""模拟定价策略请求"""
|
||||
|
||||
strategies: List[StrategyType] = Field(..., description="要模拟的策略类型")
|
||||
target_margin: float = Field(50, ge=0, le=100, description="目标毛利率(%)")
|
||||
|
||||
|
||||
class StrategySimulationResult(BaseModel):
|
||||
"""策略模拟结果"""
|
||||
|
||||
strategy_type: str
|
||||
strategy_name: str
|
||||
suggested_price: float
|
||||
margin: float
|
||||
profit_per_unit: float
|
||||
market_position: str = Field(..., description="市场位置描述")
|
||||
|
||||
|
||||
class SimulateStrategyResponse(BaseModel):
|
||||
"""模拟定价策略响应"""
|
||||
|
||||
project_id: int
|
||||
project_name: str
|
||||
base_cost: float
|
||||
results: List[StrategySimulationResult]
|
||||
|
||||
|
||||
class ExportReportRequest(BaseModel):
|
||||
"""导出报告请求"""
|
||||
|
||||
format: str = Field("pdf", description="导出格式:pdf/excel")
|
||||
194
后端服务/app/schemas/profit.py
Normal file
194
后端服务/app/schemas/profit.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""利润模拟 Schema
|
||||
|
||||
利润模拟测算相关的请求和响应模型
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PeriodType(str, Enum):
|
||||
"""周期类型"""
|
||||
DAILY = "daily" # 日
|
||||
WEEKLY = "weekly" # 周
|
||||
MONTHLY = "monthly" # 月
|
||||
|
||||
|
||||
# 利润模拟相关
|
||||
|
||||
class ProfitSimulationBase(BaseModel):
|
||||
"""利润模拟基础字段"""
|
||||
|
||||
pricing_plan_id: int = Field(..., description="定价方案ID")
|
||||
simulation_name: str = Field(..., min_length=1, max_length=100, description="模拟名称")
|
||||
price: float = Field(..., gt=0, description="模拟价格")
|
||||
estimated_volume: int = Field(..., gt=0, description="预估客量")
|
||||
period_type: PeriodType = Field(..., description="周期类型")
|
||||
|
||||
|
||||
class ProfitSimulationCreate(ProfitSimulationBase):
|
||||
"""创建利润模拟请求"""
|
||||
pass
|
||||
|
||||
|
||||
class SimulateProfitRequest(BaseModel):
|
||||
"""执行利润模拟请求"""
|
||||
|
||||
price: float = Field(..., gt=0, description="模拟价格")
|
||||
estimated_volume: int = Field(..., gt=0, description="预估客量")
|
||||
period_type: PeriodType = Field(PeriodType.MONTHLY, description="周期类型")
|
||||
|
||||
|
||||
class SimulationInput(BaseModel):
|
||||
"""模拟输入参数"""
|
||||
|
||||
price: float
|
||||
cost_per_unit: float
|
||||
estimated_volume: int
|
||||
period_type: str
|
||||
|
||||
|
||||
class SimulationResult(BaseModel):
|
||||
"""模拟计算结果"""
|
||||
|
||||
estimated_revenue: float = Field(..., description="预估收入")
|
||||
estimated_cost: float = Field(..., description="预估成本")
|
||||
estimated_profit: float = Field(..., description="预估利润")
|
||||
profit_margin: float = Field(..., description="利润率(%)")
|
||||
profit_per_unit: float = Field(..., description="单位利润")
|
||||
|
||||
|
||||
class BreakevenAnalysis(BaseModel):
|
||||
"""盈亏平衡分析"""
|
||||
|
||||
breakeven_volume: int = Field(..., description="盈亏平衡客量")
|
||||
current_volume: int = Field(..., description="当前预估客量")
|
||||
safety_margin: int = Field(..., description="安全边际(客量)")
|
||||
safety_margin_percentage: float = Field(..., description="安全边际率(%)")
|
||||
|
||||
|
||||
class SimulateProfitResponse(BaseModel):
|
||||
"""执行利润模拟响应"""
|
||||
|
||||
simulation_id: int
|
||||
pricing_plan_id: int
|
||||
project_name: str
|
||||
input: SimulationInput
|
||||
result: SimulationResult
|
||||
breakeven_analysis: BreakevenAnalysis
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ProfitSimulationResponse(BaseModel):
|
||||
"""利润模拟响应"""
|
||||
|
||||
id: int
|
||||
pricing_plan_id: int
|
||||
plan_name: Optional[str] = None
|
||||
project_name: Optional[str] = None
|
||||
simulation_name: str
|
||||
price: float
|
||||
estimated_volume: int
|
||||
period_type: str
|
||||
estimated_revenue: float
|
||||
estimated_cost: float
|
||||
estimated_profit: float
|
||||
profit_margin: float
|
||||
breakeven_volume: int
|
||||
created_at: datetime
|
||||
created_by_name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ProfitSimulationListResponse(BaseModel):
|
||||
"""利润模拟列表响应"""
|
||||
|
||||
id: int
|
||||
pricing_plan_id: int
|
||||
plan_name: Optional[str] = None
|
||||
project_name: Optional[str] = None
|
||||
simulation_name: str
|
||||
price: float
|
||||
estimated_volume: int
|
||||
period_type: str
|
||||
estimated_profit: float
|
||||
profit_margin: float
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ProfitSimulationQuery(BaseModel):
|
||||
"""利润模拟查询参数"""
|
||||
|
||||
page: int = Field(1, ge=1, description="页码")
|
||||
page_size: int = Field(20, ge=1, le=100, description="每页数量")
|
||||
pricing_plan_id: Optional[int] = Field(None, description="定价方案筛选")
|
||||
period_type: Optional[PeriodType] = Field(None, description="周期类型筛选")
|
||||
sort_by: str = Field("created_at", description="排序字段")
|
||||
sort_order: str = Field("desc", description="排序方向")
|
||||
|
||||
|
||||
# 敏感性分析相关
|
||||
|
||||
class SensitivityAnalysisRequest(BaseModel):
|
||||
"""敏感性分析请求"""
|
||||
|
||||
price_change_rates: List[float] = Field(
|
||||
default=[-20, -15, -10, -5, 0, 5, 10, 15, 20],
|
||||
description="价格变动率列表(%)"
|
||||
)
|
||||
|
||||
|
||||
class SensitivityResultItem(BaseModel):
|
||||
"""敏感性分析单项结果"""
|
||||
|
||||
price_change_rate: float = Field(..., description="价格变动率(%)")
|
||||
adjusted_price: float = Field(..., description="调整后价格")
|
||||
adjusted_profit: float = Field(..., description="调整后利润")
|
||||
profit_change_rate: float = Field(..., description="利润变动率(%)")
|
||||
|
||||
|
||||
class SensitivityInsights(BaseModel):
|
||||
"""敏感性分析洞察"""
|
||||
|
||||
price_elasticity: str = Field(..., description="价格弹性描述")
|
||||
risk_level: str = Field(..., description="风险等级")
|
||||
recommendation: str = Field(..., description="建议")
|
||||
|
||||
|
||||
class SensitivityAnalysisResponse(BaseModel):
|
||||
"""敏感性分析响应"""
|
||||
|
||||
simulation_id: int
|
||||
base_price: float
|
||||
base_profit: float
|
||||
sensitivity_results: List[SensitivityResultItem]
|
||||
insights: Optional[SensitivityInsights] = None
|
||||
|
||||
|
||||
# 盈亏平衡分析
|
||||
|
||||
class BreakevenRequest(BaseModel):
|
||||
"""盈亏平衡分析请求"""
|
||||
|
||||
target_profit: Optional[float] = Field(None, description="目标利润(可选)")
|
||||
|
||||
|
||||
class BreakevenResponse(BaseModel):
|
||||
"""盈亏平衡分析响应"""
|
||||
|
||||
pricing_plan_id: int
|
||||
project_name: str
|
||||
price: float
|
||||
unit_cost: float
|
||||
fixed_cost_monthly: float
|
||||
breakeven_volume: int = Field(..., description="盈亏平衡客量")
|
||||
current_margin: float = Field(..., description="当前边际贡献")
|
||||
target_profit_volume: Optional[int] = Field(None, description="达到目标利润所需客量")
|
||||
84
后端服务/app/schemas/project.py
Normal file
84
后端服务/app/schemas/project.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""服务项目 Schema"""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ProjectBase(BaseModel):
|
||||
"""项目基础字段"""
|
||||
|
||||
project_code: str = Field(..., min_length=1, max_length=50, description="项目编码")
|
||||
project_name: str = Field(..., min_length=1, max_length=100, description="项目名称")
|
||||
category_id: Optional[int] = Field(None, description="项目分类ID")
|
||||
description: Optional[str] = Field(None, description="项目描述")
|
||||
duration_minutes: int = Field(0, ge=0, description="操作时长(分钟)")
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
"""创建项目请求"""
|
||||
pass
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
"""更新项目请求"""
|
||||
|
||||
project_code: Optional[str] = Field(None, min_length=1, max_length=50, description="项目编码")
|
||||
project_name: Optional[str] = Field(None, min_length=1, max_length=100, description="项目名称")
|
||||
category_id: Optional[int] = Field(None, description="项目分类ID")
|
||||
description: Optional[str] = Field(None, description="项目描述")
|
||||
duration_minutes: Optional[int] = Field(None, ge=0, description="操作时长(分钟)")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class CostSummaryBrief(BaseModel):
|
||||
"""成本汇总简要信息"""
|
||||
|
||||
total_cost: float = Field(..., description="总成本(最低成本线)")
|
||||
material_cost: float = Field(..., description="耗材成本")
|
||||
equipment_cost: float = Field(..., description="设备折旧成本")
|
||||
labor_cost: float = Field(..., description="人工成本")
|
||||
fixed_cost_allocation: float = Field(..., description="固定成本分摊")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ProjectResponse(ProjectBase):
|
||||
"""项目响应"""
|
||||
|
||||
id: int
|
||||
category_name: Optional[str] = Field(None, description="分类名称")
|
||||
cost_summary: Optional[CostSummaryBrief] = Field(None, description="成本汇总")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ProjectListResponse(ProjectBase):
|
||||
"""项目列表响应"""
|
||||
|
||||
id: int
|
||||
category_name: Optional[str] = Field(None, description="分类名称")
|
||||
cost_summary: Optional[CostSummaryBrief] = Field(None, description="成本汇总")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ProjectQuery(BaseModel):
|
||||
"""项目查询参数"""
|
||||
|
||||
page: int = Field(1, ge=1, description="页码")
|
||||
page_size: int = Field(20, ge=1, le=100, description="每页数量")
|
||||
category_id: Optional[int] = Field(None, description="分类筛选")
|
||||
keyword: Optional[str] = Field(None, description="关键词搜索")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
sort_by: str = Field("created_at", description="排序字段")
|
||||
sort_order: str = Field("desc", description="排序方向")
|
||||
195
后端服务/app/schemas/project_cost.py
Normal file
195
后端服务/app/schemas/project_cost.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""项目成本相关 Schema"""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CostItemType(str, Enum):
|
||||
"""成本明细类型枚举"""
|
||||
|
||||
MATERIAL = "material" # 耗材
|
||||
EQUIPMENT = "equipment" # 设备
|
||||
|
||||
|
||||
class AllocationMethod(str, Enum):
|
||||
"""固定成本分摊方式枚举"""
|
||||
|
||||
COUNT = "count" # 按项目数量平均分摊
|
||||
REVENUE = "revenue" # 按项目营收占比分摊
|
||||
DURATION = "duration" # 按项目时长占比分摊
|
||||
|
||||
|
||||
# ============ 项目成本明细(耗材/设备)Schema ============
|
||||
|
||||
class CostItemBase(BaseModel):
|
||||
"""成本明细基础字段"""
|
||||
|
||||
item_type: CostItemType = Field(..., description="类型:material/equipment")
|
||||
item_id: int = Field(..., description="耗材/设备ID")
|
||||
quantity: float = Field(..., gt=0, description="用量")
|
||||
remark: Optional[str] = Field(None, max_length=200, description="备注")
|
||||
|
||||
|
||||
class CostItemCreate(CostItemBase):
|
||||
"""创建成本明细请求"""
|
||||
pass
|
||||
|
||||
|
||||
class CostItemUpdate(BaseModel):
|
||||
"""更新成本明细请求"""
|
||||
|
||||
quantity: Optional[float] = Field(None, gt=0, description="用量")
|
||||
remark: Optional[str] = Field(None, max_length=200, description="备注")
|
||||
|
||||
|
||||
class CostItemResponse(BaseModel):
|
||||
"""成本明细响应"""
|
||||
|
||||
id: int
|
||||
item_type: CostItemType
|
||||
item_id: int
|
||||
item_name: Optional[str] = Field(None, description="耗材/设备名称")
|
||||
quantity: float
|
||||
unit: Optional[str] = Field(None, description="单位")
|
||||
unit_cost: float
|
||||
total_cost: float
|
||||
remark: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ============ 项目人工成本 Schema ============
|
||||
|
||||
class LaborCostBase(BaseModel):
|
||||
"""人工成本基础字段"""
|
||||
|
||||
staff_level_id: int = Field(..., description="人员级别ID")
|
||||
duration_minutes: int = Field(..., gt=0, description="操作时长(分钟)")
|
||||
remark: Optional[str] = Field(None, max_length=200, description="备注")
|
||||
|
||||
|
||||
class LaborCostCreate(LaborCostBase):
|
||||
"""创建人工成本请求"""
|
||||
pass
|
||||
|
||||
|
||||
class LaborCostUpdate(BaseModel):
|
||||
"""更新人工成本请求"""
|
||||
|
||||
staff_level_id: Optional[int] = Field(None, description="人员级别ID")
|
||||
duration_minutes: Optional[int] = Field(None, gt=0, description="操作时长(分钟)")
|
||||
remark: Optional[str] = Field(None, max_length=200, description="备注")
|
||||
|
||||
|
||||
class LaborCostResponse(BaseModel):
|
||||
"""人工成本响应"""
|
||||
|
||||
id: int
|
||||
staff_level_id: int
|
||||
level_name: Optional[str] = Field(None, description="级别名称")
|
||||
duration_minutes: int
|
||||
hourly_rate: float
|
||||
labor_cost: float
|
||||
remark: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ============ 成本计算相关 Schema ============
|
||||
|
||||
class CalculateCostRequest(BaseModel):
|
||||
"""计算成本请求"""
|
||||
|
||||
fixed_cost_allocation_method: AllocationMethod = Field(
|
||||
AllocationMethod.COUNT,
|
||||
description="固定成本分摊方式"
|
||||
)
|
||||
|
||||
|
||||
class CostBreakdownItem(BaseModel):
|
||||
"""成本明细项"""
|
||||
|
||||
name: str
|
||||
quantity: Optional[float] = None
|
||||
unit: Optional[str] = None
|
||||
unit_cost: Optional[float] = None
|
||||
depreciation_per_use: Optional[float] = None
|
||||
duration_minutes: Optional[int] = None
|
||||
hourly_rate: Optional[float] = None
|
||||
total: float
|
||||
|
||||
|
||||
class CostBreakdown(BaseModel):
|
||||
"""成本分项"""
|
||||
|
||||
items: List[CostBreakdownItem]
|
||||
subtotal: float
|
||||
|
||||
|
||||
class FixedCostAllocationDetail(BaseModel):
|
||||
"""固定成本分摊详情"""
|
||||
|
||||
method: str
|
||||
total_fixed_cost: float
|
||||
project_count: Optional[int] = None
|
||||
total_revenue: Optional[float] = None
|
||||
total_duration: Optional[int] = None
|
||||
allocation: float
|
||||
|
||||
|
||||
class CostCalculationResult(BaseModel):
|
||||
"""成本计算结果"""
|
||||
|
||||
project_id: int
|
||||
project_name: str
|
||||
cost_breakdown: dict = Field(..., description="成本分项明细")
|
||||
total_cost: float = Field(..., description="总成本")
|
||||
min_price_suggestion: float = Field(..., description="建议最低售价(等于总成本)")
|
||||
calculated_at: datetime
|
||||
|
||||
|
||||
class CostSummaryResponse(BaseModel):
|
||||
"""成本汇总响应"""
|
||||
|
||||
project_id: int
|
||||
material_cost: float
|
||||
equipment_cost: float
|
||||
labor_cost: float
|
||||
fixed_cost_allocation: float
|
||||
total_cost: float
|
||||
calculated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ============ 项目详情(含成本)Schema ============
|
||||
|
||||
class ProjectDetailResponse(BaseModel):
|
||||
"""项目详情响应(含成本明细)"""
|
||||
|
||||
id: int
|
||||
project_code: str
|
||||
project_name: str
|
||||
category_id: Optional[int]
|
||||
category_name: Optional[str]
|
||||
description: Optional[str]
|
||||
duration_minutes: int
|
||||
is_active: bool
|
||||
cost_items: List[CostItemResponse] = []
|
||||
labor_costs: List[LaborCostResponse] = []
|
||||
cost_summary: Optional[CostSummaryResponse] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
40
后端服务/app/schemas/staff_level.py
Normal file
40
后端服务/app/schemas/staff_level.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""人员级别 Schema"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class StaffLevelBase(BaseModel):
|
||||
"""人员级别基础字段"""
|
||||
|
||||
level_code: str = Field(..., min_length=1, max_length=20, description="级别编码")
|
||||
level_name: str = Field(..., min_length=1, max_length=50, description="级别名称")
|
||||
hourly_rate: float = Field(..., gt=0, description="时薪(元/小时)")
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
|
||||
|
||||
class StaffLevelCreate(StaffLevelBase):
|
||||
"""创建人员级别请求"""
|
||||
pass
|
||||
|
||||
|
||||
class StaffLevelUpdate(BaseModel):
|
||||
"""更新人员级别请求"""
|
||||
|
||||
level_code: Optional[str] = Field(None, min_length=1, max_length=20, description="级别编码")
|
||||
level_name: Optional[str] = Field(None, min_length=1, max_length=50, description="级别名称")
|
||||
hourly_rate: Optional[float] = Field(None, gt=0, description="时薪(元/小时)")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class StaffLevelResponse(StaffLevelBase):
|
||||
"""人员级别响应"""
|
||||
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
13
后端服务/app/services/__init__.py
Normal file
13
后端服务/app/services/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""业务逻辑服务"""
|
||||
|
||||
from app.services.cost_service import CostService
|
||||
from app.services.market_service import MarketService
|
||||
from app.services.pricing_service import PricingService
|
||||
from app.services.profit_service import ProfitService
|
||||
|
||||
__all__ = [
|
||||
"CostService",
|
||||
"MarketService",
|
||||
"PricingService",
|
||||
"ProfitService",
|
||||
]
|
||||
257
后端服务/app/services/ai_service_wrapper.py
Normal file
257
后端服务/app/services/ai_service_wrapper.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""AI 服务封装
|
||||
|
||||
遵循瑞小美 AI 接入规范:
|
||||
- 通过 shared_backend.AIService 调用
|
||||
- 初始化时传入 db_session(用于日志记录)
|
||||
- 调用时传入 prompt_name(用于统计)
|
||||
|
||||
性能优化:
|
||||
- 相同输入的响应缓存(减少 API 调用成本)
|
||||
- 缓存键基于消息内容哈希
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Optional, List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.services.cache_service import get_cache, CacheNamespace
|
||||
|
||||
|
||||
class AIServiceWrapper:
|
||||
"""AI 服务封装类
|
||||
|
||||
封装 shared_backend.AIService 的调用
|
||||
提供统一的接口供业务层使用
|
||||
支持响应缓存以减少 API 调用
|
||||
"""
|
||||
|
||||
# 默认缓存 TTL(1 小时)- AI 响应通常不会频繁变化
|
||||
DEFAULT_CACHE_TTL = 3600
|
||||
|
||||
def __init__(self, db_session: AsyncSession, enable_cache: bool = True):
|
||||
"""初始化 AI 服务
|
||||
|
||||
Args:
|
||||
db_session: 数据库会话,用于记录 AI 调用日志
|
||||
enable_cache: 是否启用响应缓存
|
||||
"""
|
||||
self.db_session = db_session
|
||||
self.module_code = settings.AI_MODULE_CODE
|
||||
self.enable_cache = enable_cache
|
||||
self._ai_service = None
|
||||
self._cache = get_cache(CacheNamespace.AI_RESPONSES, maxsize=100, ttl=self.DEFAULT_CACHE_TTL)
|
||||
|
||||
def _generate_cache_key(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
prompt_name: str,
|
||||
model: Optional[str] = None,
|
||||
) -> str:
|
||||
"""生成缓存键
|
||||
|
||||
基于消息内容和参数生成唯一的缓存键
|
||||
"""
|
||||
key_data = {
|
||||
"messages": messages,
|
||||
"prompt_name": prompt_name,
|
||||
"model": model or "default",
|
||||
}
|
||||
key_str = json.dumps(key_data, sort_keys=True, ensure_ascii=False)
|
||||
return f"ai:{hashlib.sha256(key_str.encode()).hexdigest()[:16]}"
|
||||
|
||||
async def _get_service(self):
|
||||
"""获取 AIService 实例(延迟加载)"""
|
||||
if self._ai_service is None:
|
||||
try:
|
||||
from shared_backend.services.ai_service import AIService
|
||||
self._ai_service = AIService(
|
||||
module_code=self.module_code,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
except ImportError:
|
||||
# 开发环境可能没有 shared_backend
|
||||
# 使用 Mock 实现
|
||||
self._ai_service = MockAIService(self.module_code)
|
||||
return self._ai_service
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
prompt_name: str,
|
||||
model: Optional[str] = None,
|
||||
use_cache: bool = True,
|
||||
cache_ttl: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""调用 AI 聊天接口(带缓存支持)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
prompt_name: 提示词名称(必填,用于统计)
|
||||
model: 模型名称,默认使用配置的模型
|
||||
use_cache: 是否使用缓存
|
||||
cache_ttl: 缓存 TTL(秒),默认 3600
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
AIResponse 对象
|
||||
"""
|
||||
# 检查缓存
|
||||
if self.enable_cache and use_cache:
|
||||
cache_key = self._generate_cache_key(messages, prompt_name, model)
|
||||
cached_response = self._cache.get(cache_key)
|
||||
if cached_response is not None:
|
||||
# 返回缓存的响应(添加标记)
|
||||
cached_response.from_cache = True
|
||||
return cached_response
|
||||
|
||||
# 调用 AI 服务
|
||||
service = await self._get_service()
|
||||
response = await service.chat(
|
||||
messages=messages,
|
||||
prompt_name=prompt_name,
|
||||
model=model,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# 存入缓存
|
||||
if self.enable_cache and use_cache and response is not None:
|
||||
cache_key = self._generate_cache_key(messages, prompt_name, model)
|
||||
response.from_cache = False
|
||||
self._cache.set(cache_key, response, cache_ttl or self.DEFAULT_CACHE_TTL)
|
||||
|
||||
return response
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
prompt_name: str,
|
||||
model: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""调用 AI 聊天流式接口
|
||||
|
||||
注意:流式接口不使用缓存
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
prompt_name: 提示词名称(必填,用于统计)
|
||||
model: 模型名称
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
响应片段
|
||||
"""
|
||||
service = await self._get_service()
|
||||
async for chunk in service.chat_stream(
|
||||
messages=messages,
|
||||
prompt_name=prompt_name,
|
||||
model=model,
|
||||
**kwargs,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
def clear_cache(self, prompt_name: Optional[str] = None):
|
||||
"""清除 AI 响应缓存
|
||||
|
||||
Args:
|
||||
prompt_name: 指定提示词名称清除,None 则清除全部
|
||||
"""
|
||||
# 简单实现:清除整个缓存
|
||||
# 更精细的实现可以按 prompt_name 过滤
|
||||
self._cache.clear()
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
return self._cache.stats()
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockAIResponse:
|
||||
"""Mock AI 响应"""
|
||||
content: str = "这是一个 Mock 响应,用于开发测试。实际部署时会使用真实的 AI 服务。"
|
||||
model: str = "mock-model"
|
||||
provider: str = "mock"
|
||||
input_tokens: int = 100
|
||||
output_tokens: int = 50
|
||||
total_tokens: int = 150
|
||||
cost: float = 0.0
|
||||
latency_ms: int = 100
|
||||
raw_response: dict = None
|
||||
images: list = None
|
||||
annotations: dict = None
|
||||
from_cache: bool = False
|
||||
|
||||
|
||||
class MockAIService:
|
||||
"""Mock AI 服务(开发环境使用)"""
|
||||
|
||||
def __init__(self, module_code: str):
|
||||
self.module_code = module_code
|
||||
|
||||
async def chat(self, messages, prompt_name, **kwargs):
|
||||
"""Mock 聊天接口"""
|
||||
# 生成一个基于输入的简单响应
|
||||
user_message = ""
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
user_message = msg.get("content", "")[:100]
|
||||
break
|
||||
|
||||
response = MockAIResponse(
|
||||
content=f"""## Mock AI 分析报告
|
||||
|
||||
根据您提供的数据,以下是分析结果:
|
||||
|
||||
### 定价建议
|
||||
- **推荐价格**: 根据成本和市场分析,建议定价在合理区间内
|
||||
- **引流款策略**: 适合新客引流,建议价格较低
|
||||
- **利润款策略**: 适合日常经营,建议价格适中
|
||||
- **高端款策略**: 适合高端客群,可考虑较高定价
|
||||
|
||||
### 风险提示
|
||||
- 请密切关注市场动态
|
||||
- 建议定期复核定价策略
|
||||
|
||||
*注意:这是开发测试环境的 Mock 响应,实际部署时会使用真实的 AI 服务。*
|
||||
""",
|
||||
model="mock-model",
|
||||
provider="mock",
|
||||
input_tokens=len(str(messages)),
|
||||
output_tokens=200,
|
||||
total_tokens=len(str(messages)) + 200,
|
||||
cost=0.0,
|
||||
latency_ms=50,
|
||||
)
|
||||
return response
|
||||
|
||||
async def chat_stream(self, messages, prompt_name, **kwargs):
|
||||
"""Mock 流式接口"""
|
||||
chunks = [
|
||||
"## Mock AI 分析报告\n\n",
|
||||
"根据您提供的数据,以下是分析结果:\n\n",
|
||||
"### 定价建议\n",
|
||||
"- 推荐价格:合理区间内\n",
|
||||
"- 引流款策略:适合新客引流\n",
|
||||
"- 利润款策略:适合日常经营\n\n",
|
||||
"*这是 Mock 响应,实际部署时会使用真实的 AI 服务。*",
|
||||
]
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
async def get_ai_service(db_session: AsyncSession, enable_cache: bool = True) -> AIServiceWrapper:
|
||||
"""获取 AI 服务实例(依赖注入)
|
||||
|
||||
Args:
|
||||
db_session: 数据库会话
|
||||
enable_cache: 是否启用缓存
|
||||
|
||||
Returns:
|
||||
AIServiceWrapper 实例
|
||||
"""
|
||||
return AIServiceWrapper(db_session, enable_cache=enable_cache)
|
||||
233
后端服务/app/services/cache_service.py
Normal file
233
后端服务/app/services/cache_service.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""缓存服务
|
||||
|
||||
实现简单的内存缓存和 LRU 缓存策略
|
||||
用于优化频繁查询的数据
|
||||
|
||||
遵循瑞小美系统技术栈标准
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Any, Callable, Optional, Dict
|
||||
from collections import OrderedDict
|
||||
import threading
|
||||
|
||||
|
||||
class TTLCache:
|
||||
"""带过期时间的缓存
|
||||
|
||||
线程安全的 TTL 缓存实现
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int = 1000, ttl: int = 300):
|
||||
"""
|
||||
Args:
|
||||
maxsize: 最大缓存条目数
|
||||
ttl: 默认过期时间(秒)
|
||||
"""
|
||||
self.maxsize = maxsize
|
||||
self.ttl = ttl
|
||||
self._cache: OrderedDict = OrderedDict()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _is_expired(self, expire_at: datetime) -> bool:
|
||||
"""检查是否过期"""
|
||||
return datetime.now() > expire_at
|
||||
|
||||
def get(self, key: str) -> Optional[Any]:
|
||||
"""获取缓存值"""
|
||||
with self._lock:
|
||||
if key not in self._cache:
|
||||
return None
|
||||
|
||||
value, expire_at = self._cache[key]
|
||||
|
||||
if self._is_expired(expire_at):
|
||||
del self._cache[key]
|
||||
return None
|
||||
|
||||
# 移动到末尾(LRU)
|
||||
self._cache.move_to_end(key)
|
||||
return value
|
||||
|
||||
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
|
||||
"""设置缓存值"""
|
||||
with self._lock:
|
||||
expire_at = datetime.now() + timedelta(seconds=ttl or self.ttl)
|
||||
|
||||
if key in self._cache:
|
||||
self._cache.move_to_end(key)
|
||||
|
||||
self._cache[key] = (value, expire_at)
|
||||
|
||||
# 超过最大容量时删除最旧的
|
||||
while len(self._cache) > self.maxsize:
|
||||
self._cache.popitem(last=False)
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
"""删除缓存"""
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
del self._cache[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空缓存"""
|
||||
with self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
def cleanup(self) -> int:
|
||||
"""清理过期缓存,返回清理数量"""
|
||||
with self._lock:
|
||||
expired_keys = [
|
||||
key for key, (_, expire_at) in self._cache.items()
|
||||
if self._is_expired(expire_at)
|
||||
]
|
||||
for key in expired_keys:
|
||||
del self._cache[key]
|
||||
return len(expired_keys)
|
||||
|
||||
def stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计"""
|
||||
with self._lock:
|
||||
return {
|
||||
"size": len(self._cache),
|
||||
"maxsize": self.maxsize,
|
||||
"ttl": self.ttl,
|
||||
}
|
||||
|
||||
|
||||
# 全局缓存实例
|
||||
_cache_instances: Dict[str, TTLCache] = {}
|
||||
|
||||
|
||||
def get_cache(namespace: str = "default", maxsize: int = 1000, ttl: int = 300) -> TTLCache:
|
||||
"""获取缓存实例
|
||||
|
||||
Args:
|
||||
namespace: 缓存命名空间
|
||||
maxsize: 最大缓存条目数
|
||||
ttl: 默认过期时间(秒)
|
||||
|
||||
Returns:
|
||||
缓存实例
|
||||
"""
|
||||
if namespace not in _cache_instances:
|
||||
_cache_instances[namespace] = TTLCache(maxsize=maxsize, ttl=ttl)
|
||||
return _cache_instances[namespace]
|
||||
|
||||
|
||||
def cache_key(*args, **kwargs) -> str:
|
||||
"""生成缓存键"""
|
||||
key_parts = [str(arg) for arg in args]
|
||||
key_parts.extend(f"{k}={v}" for k, v in sorted(kwargs.items()))
|
||||
key_str = "|".join(key_parts)
|
||||
return hashlib.md5(key_str.encode()).hexdigest()
|
||||
|
||||
|
||||
def cached(
|
||||
namespace: str = "default",
|
||||
ttl: int = 300,
|
||||
key_prefix: str = "",
|
||||
):
|
||||
"""缓存装饰器
|
||||
|
||||
Args:
|
||||
namespace: 缓存命名空间
|
||||
ttl: 过期时间(秒)
|
||||
key_prefix: 键前缀
|
||||
|
||||
Example:
|
||||
@cached(namespace="projects", ttl=60, key_prefix="project_detail")
|
||||
async def get_project(project_id: int):
|
||||
...
|
||||
"""
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
cache = get_cache(namespace, ttl=ttl)
|
||||
|
||||
# 生成缓存键
|
||||
key = f"{key_prefix}:{cache_key(*args, **kwargs)}"
|
||||
|
||||
# 尝试获取缓存
|
||||
cached_value = cache.get(key)
|
||||
if cached_value is not None:
|
||||
return cached_value
|
||||
|
||||
# 执行函数
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
# 设置缓存
|
||||
if result is not None:
|
||||
cache.set(key, result, ttl)
|
||||
|
||||
return result
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
cache = get_cache(namespace, ttl=ttl)
|
||||
key = f"{key_prefix}:{cache_key(*args, **kwargs)}"
|
||||
|
||||
cached_value = cache.get(key)
|
||||
if cached_value is not None:
|
||||
return cached_value
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
if result is not None:
|
||||
cache.set(key, result, ttl)
|
||||
|
||||
return result
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def invalidate_cache(namespace: str, key_prefix: str = "") -> None:
|
||||
"""使缓存失效
|
||||
|
||||
注意:这会清除整个命名空间的缓存
|
||||
"""
|
||||
cache = get_cache(namespace)
|
||||
cache.clear()
|
||||
|
||||
|
||||
# 预定义缓存命名空间
|
||||
class CacheNamespace:
|
||||
"""缓存命名空间常量"""
|
||||
CATEGORIES = "categories"
|
||||
MATERIALS = "materials"
|
||||
EQUIPMENTS = "equipments"
|
||||
STAFF_LEVELS = "staff_levels"
|
||||
PROJECTS = "projects"
|
||||
MARKET_ANALYSIS = "market_analysis"
|
||||
AI_RESPONSES = "ai_responses"
|
||||
|
||||
|
||||
# 预初始化常用缓存
|
||||
def init_caches():
|
||||
"""初始化缓存实例"""
|
||||
# 基础数据缓存(较长 TTL)
|
||||
get_cache(CacheNamespace.CATEGORIES, maxsize=100, ttl=600)
|
||||
get_cache(CacheNamespace.MATERIALS, maxsize=500, ttl=600)
|
||||
get_cache(CacheNamespace.EQUIPMENTS, maxsize=200, ttl=600)
|
||||
get_cache(CacheNamespace.STAFF_LEVELS, maxsize=50, ttl=600)
|
||||
|
||||
# 业务数据缓存(较短 TTL)
|
||||
get_cache(CacheNamespace.PROJECTS, maxsize=500, ttl=300)
|
||||
get_cache(CacheNamespace.MARKET_ANALYSIS, maxsize=200, ttl=180)
|
||||
|
||||
# AI 响应缓存(较长 TTL,减少 API 调用)
|
||||
get_cache(CacheNamespace.AI_RESPONSES, maxsize=100, ttl=3600)
|
||||
|
||||
|
||||
# 应用启动时初始化
|
||||
init_caches()
|
||||
478
后端服务/app/services/cost_service.py
Normal file
478
后端服务/app/services/cost_service.py
Normal file
@@ -0,0 +1,478 @@
|
||||
"""成本计算服务
|
||||
|
||||
实现项目成本计算核心业务逻辑,包含:
|
||||
- 耗材成本计算
|
||||
- 设备折旧成本计算
|
||||
- 人工成本计算
|
||||
- 固定成本分摊(三种分摊方式)
|
||||
- 成本汇总
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models import (
|
||||
Project,
|
||||
ProjectCostItem,
|
||||
ProjectLaborCost,
|
||||
ProjectCostSummary,
|
||||
Material,
|
||||
Equipment,
|
||||
StaffLevel,
|
||||
FixedCost,
|
||||
)
|
||||
from app.schemas.project_cost import (
|
||||
AllocationMethod,
|
||||
CostItemType,
|
||||
CostBreakdownItem,
|
||||
CostCalculationResult,
|
||||
)
|
||||
|
||||
|
||||
class CostService:
|
||||
"""成本计算服务"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_material_info(self, material_id: int) -> Optional[Material]:
|
||||
"""获取耗材信息"""
|
||||
result = await self.db.execute(
|
||||
select(Material).where(Material.id == material_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_equipment_info(self, equipment_id: int) -> Optional[Equipment]:
|
||||
"""获取设备信息"""
|
||||
result = await self.db.execute(
|
||||
select(Equipment).where(Equipment.id == equipment_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_staff_level_info(self, staff_level_id: int) -> Optional[StaffLevel]:
|
||||
"""获取人员级别信息"""
|
||||
result = await self.db.execute(
|
||||
select(StaffLevel).where(StaffLevel.id == staff_level_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def calculate_material_cost(self, project_id: int) -> tuple[Decimal, List[Dict[str, Any]]]:
|
||||
"""计算耗材成本
|
||||
|
||||
Returns:
|
||||
(总耗材成本, 耗材明细列表)
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(ProjectCostItem).where(
|
||||
ProjectCostItem.project_id == project_id,
|
||||
ProjectCostItem.item_type == CostItemType.MATERIAL.value
|
||||
)
|
||||
)
|
||||
items = result.scalars().all()
|
||||
|
||||
total = Decimal("0")
|
||||
breakdown = []
|
||||
|
||||
for item in items:
|
||||
material = await self.get_material_info(item.item_id)
|
||||
item_detail = {
|
||||
"name": material.material_name if material else f"耗材#{item.item_id}",
|
||||
"quantity": float(item.quantity),
|
||||
"unit": material.unit if material else "",
|
||||
"unit_cost": float(item.unit_cost),
|
||||
"total": float(item.total_cost),
|
||||
}
|
||||
breakdown.append(item_detail)
|
||||
total += item.total_cost
|
||||
|
||||
return total, breakdown
|
||||
|
||||
async def calculate_equipment_cost(self, project_id: int) -> tuple[Decimal, List[Dict[str, Any]]]:
|
||||
"""计算设备折旧成本
|
||||
|
||||
Returns:
|
||||
(总设备折旧成本, 设备明细列表)
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(ProjectCostItem).where(
|
||||
ProjectCostItem.project_id == project_id,
|
||||
ProjectCostItem.item_type == CostItemType.EQUIPMENT.value
|
||||
)
|
||||
)
|
||||
items = result.scalars().all()
|
||||
|
||||
total = Decimal("0")
|
||||
breakdown = []
|
||||
|
||||
for item in items:
|
||||
equipment = await self.get_equipment_info(item.item_id)
|
||||
item_detail = {
|
||||
"name": equipment.equipment_name if equipment else f"设备#{item.item_id}",
|
||||
"depreciation_per_use": float(item.unit_cost),
|
||||
"total": float(item.total_cost),
|
||||
}
|
||||
breakdown.append(item_detail)
|
||||
total += item.total_cost
|
||||
|
||||
return total, breakdown
|
||||
|
||||
async def calculate_labor_cost(self, project_id: int) -> tuple[Decimal, List[Dict[str, Any]]]:
|
||||
"""计算人工成本
|
||||
|
||||
Returns:
|
||||
(总人工成本, 人工明细列表)
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(ProjectLaborCost).where(
|
||||
ProjectLaborCost.project_id == project_id
|
||||
)
|
||||
)
|
||||
items = result.scalars().all()
|
||||
|
||||
total = Decimal("0")
|
||||
breakdown = []
|
||||
|
||||
for item in items:
|
||||
staff_level = await self.get_staff_level_info(item.staff_level_id)
|
||||
item_detail = {
|
||||
"name": staff_level.level_name if staff_level else f"级别#{item.staff_level_id}",
|
||||
"duration_minutes": item.duration_minutes,
|
||||
"hourly_rate": float(item.hourly_rate),
|
||||
"total": float(item.labor_cost),
|
||||
}
|
||||
breakdown.append(item_detail)
|
||||
total += item.labor_cost
|
||||
|
||||
return total, breakdown
|
||||
|
||||
async def calculate_fixed_cost_allocation(
|
||||
self,
|
||||
project_id: int,
|
||||
method: AllocationMethod = AllocationMethod.COUNT,
|
||||
year_month: Optional[str] = None
|
||||
) -> tuple[Decimal, Dict[str, Any]]:
|
||||
"""计算固定成本分摊
|
||||
|
||||
三种分摊方式:
|
||||
- COUNT: 按项目数量平均分摊
|
||||
- REVENUE: 按项目营收占比分摊(当前简化为平均分摊)
|
||||
- DURATION: 按项目时长占比分摊
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
method: 分摊方式
|
||||
year_month: 年月,默认当前月份
|
||||
|
||||
Returns:
|
||||
(分摊金额, 分摊详情)
|
||||
"""
|
||||
if not year_month:
|
||||
year_month = datetime.now().strftime("%Y-%m")
|
||||
|
||||
# 获取当月固定成本总额
|
||||
result = await self.db.execute(
|
||||
select(func.sum(FixedCost.monthly_amount)).where(
|
||||
FixedCost.year_month == year_month,
|
||||
FixedCost.is_active == True
|
||||
)
|
||||
)
|
||||
total_fixed_cost = result.scalar() or Decimal("0")
|
||||
|
||||
if total_fixed_cost == 0:
|
||||
return Decimal("0"), {
|
||||
"method": method.value,
|
||||
"total_fixed_cost": 0,
|
||||
"allocation": 0,
|
||||
}
|
||||
|
||||
allocation = Decimal("0")
|
||||
detail = {
|
||||
"method": method.value,
|
||||
"total_fixed_cost": float(total_fixed_cost),
|
||||
}
|
||||
|
||||
if method == AllocationMethod.COUNT:
|
||||
# 按项目数量平均分摊
|
||||
count_result = await self.db.execute(
|
||||
select(func.count(Project.id)).where(Project.is_active == True)
|
||||
)
|
||||
project_count = count_result.scalar() or 1
|
||||
allocation = total_fixed_cost / Decimal(str(project_count))
|
||||
detail["project_count"] = project_count
|
||||
|
||||
elif method == AllocationMethod.DURATION:
|
||||
# 按项目时长占比分摊
|
||||
# 获取当前项目时长
|
||||
project_result = await self.db.execute(
|
||||
select(Project.duration_minutes).where(Project.id == project_id)
|
||||
)
|
||||
project_duration = project_result.scalar() or 0
|
||||
|
||||
# 获取所有活跃项目总时长
|
||||
total_duration_result = await self.db.execute(
|
||||
select(func.sum(Project.duration_minutes)).where(Project.is_active == True)
|
||||
)
|
||||
total_duration = total_duration_result.scalar() or 1
|
||||
|
||||
if total_duration > 0:
|
||||
ratio = Decimal(str(project_duration)) / Decimal(str(total_duration))
|
||||
allocation = total_fixed_cost * ratio
|
||||
|
||||
detail["project_duration"] = project_duration
|
||||
detail["total_duration"] = total_duration
|
||||
|
||||
elif method == AllocationMethod.REVENUE:
|
||||
# 按营收占比分摊(暂简化为平均分摊,后续可接入实际营收数据)
|
||||
count_result = await self.db.execute(
|
||||
select(func.count(Project.id)).where(Project.is_active == True)
|
||||
)
|
||||
project_count = count_result.scalar() or 1
|
||||
allocation = total_fixed_cost / Decimal(str(project_count))
|
||||
detail["project_count"] = project_count
|
||||
detail["note"] = "暂按项目数量平均分摊,后续可接入实际营收数据"
|
||||
|
||||
detail["allocation"] = float(allocation)
|
||||
|
||||
return allocation, detail
|
||||
|
||||
async def calculate_project_cost(
|
||||
self,
|
||||
project_id: int,
|
||||
allocation_method: AllocationMethod = AllocationMethod.COUNT
|
||||
) -> CostCalculationResult:
|
||||
"""计算项目总成本
|
||||
|
||||
总成本 = 耗材成本 + 设备折旧成本 + 人工成本 + 固定成本分摊
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
allocation_method: 固定成本分摊方式
|
||||
|
||||
Returns:
|
||||
成本计算结果
|
||||
"""
|
||||
# 获取项目信息
|
||||
project_result = await self.db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
raise ValueError(f"项目不存在: {project_id}")
|
||||
|
||||
# 计算各项成本
|
||||
material_cost, material_breakdown = await self.calculate_material_cost(project_id)
|
||||
equipment_cost, equipment_breakdown = await self.calculate_equipment_cost(project_id)
|
||||
labor_cost, labor_breakdown = await self.calculate_labor_cost(project_id)
|
||||
fixed_allocation, fixed_detail = await self.calculate_fixed_cost_allocation(
|
||||
project_id, allocation_method
|
||||
)
|
||||
|
||||
# 计算总成本
|
||||
total_cost = material_cost + equipment_cost + labor_cost + fixed_allocation
|
||||
|
||||
# 构建成本分项
|
||||
cost_breakdown = {
|
||||
"material_cost": {
|
||||
"items": material_breakdown,
|
||||
"subtotal": float(material_cost),
|
||||
},
|
||||
"equipment_cost": {
|
||||
"items": equipment_breakdown,
|
||||
"subtotal": float(equipment_cost),
|
||||
},
|
||||
"labor_cost": {
|
||||
"items": labor_breakdown,
|
||||
"subtotal": float(labor_cost),
|
||||
},
|
||||
"fixed_cost_allocation": fixed_detail,
|
||||
}
|
||||
|
||||
calculated_at = datetime.now()
|
||||
|
||||
# 更新或创建成本汇总记录
|
||||
await self._save_cost_summary(
|
||||
project_id=project_id,
|
||||
material_cost=material_cost,
|
||||
equipment_cost=equipment_cost,
|
||||
labor_cost=labor_cost,
|
||||
fixed_cost_allocation=fixed_allocation,
|
||||
total_cost=total_cost,
|
||||
calculated_at=calculated_at,
|
||||
)
|
||||
|
||||
return CostCalculationResult(
|
||||
project_id=project_id,
|
||||
project_name=project.project_name,
|
||||
cost_breakdown=cost_breakdown,
|
||||
total_cost=float(total_cost),
|
||||
min_price_suggestion=float(total_cost),
|
||||
calculated_at=calculated_at,
|
||||
)
|
||||
|
||||
async def _save_cost_summary(
|
||||
self,
|
||||
project_id: int,
|
||||
material_cost: Decimal,
|
||||
equipment_cost: Decimal,
|
||||
labor_cost: Decimal,
|
||||
fixed_cost_allocation: Decimal,
|
||||
total_cost: Decimal,
|
||||
calculated_at: datetime,
|
||||
):
|
||||
"""保存或更新成本汇总"""
|
||||
result = await self.db.execute(
|
||||
select(ProjectCostSummary).where(
|
||||
ProjectCostSummary.project_id == project_id
|
||||
)
|
||||
)
|
||||
summary = result.scalar_one_or_none()
|
||||
|
||||
if summary:
|
||||
summary.material_cost = material_cost
|
||||
summary.equipment_cost = equipment_cost
|
||||
summary.labor_cost = labor_cost
|
||||
summary.fixed_cost_allocation = fixed_cost_allocation
|
||||
summary.total_cost = total_cost
|
||||
summary.calculated_at = calculated_at
|
||||
else:
|
||||
summary = ProjectCostSummary(
|
||||
project_id=project_id,
|
||||
material_cost=material_cost,
|
||||
equipment_cost=equipment_cost,
|
||||
labor_cost=labor_cost,
|
||||
fixed_cost_allocation=fixed_cost_allocation,
|
||||
total_cost=total_cost,
|
||||
calculated_at=calculated_at,
|
||||
)
|
||||
self.db.add(summary)
|
||||
|
||||
await self.db.flush()
|
||||
|
||||
async def add_cost_item(
|
||||
self,
|
||||
project_id: int,
|
||||
item_type: CostItemType,
|
||||
item_id: int,
|
||||
quantity: float,
|
||||
remark: Optional[str] = None,
|
||||
) -> ProjectCostItem:
|
||||
"""添加成本明细项
|
||||
|
||||
自动计算 unit_cost 和 total_cost
|
||||
"""
|
||||
# 根据类型获取单位成本
|
||||
if item_type == CostItemType.MATERIAL:
|
||||
material = await self.get_material_info(item_id)
|
||||
if not material:
|
||||
raise ValueError(f"耗材不存在: {item_id}")
|
||||
unit_cost = Decimal(str(material.unit_price))
|
||||
else:
|
||||
equipment = await self.get_equipment_info(item_id)
|
||||
if not equipment:
|
||||
raise ValueError(f"设备不存在: {item_id}")
|
||||
unit_cost = equipment.depreciation_per_use
|
||||
|
||||
total_cost = unit_cost * Decimal(str(quantity))
|
||||
|
||||
cost_item = ProjectCostItem(
|
||||
project_id=project_id,
|
||||
item_type=item_type.value,
|
||||
item_id=item_id,
|
||||
quantity=Decimal(str(quantity)),
|
||||
unit_cost=unit_cost,
|
||||
total_cost=total_cost,
|
||||
remark=remark,
|
||||
)
|
||||
self.db.add(cost_item)
|
||||
await self.db.flush()
|
||||
await self.db.refresh(cost_item)
|
||||
|
||||
return cost_item
|
||||
|
||||
async def update_cost_item(
|
||||
self,
|
||||
cost_item: ProjectCostItem,
|
||||
quantity: Optional[float] = None,
|
||||
remark: Optional[str] = None,
|
||||
) -> ProjectCostItem:
|
||||
"""更新成本明细项"""
|
||||
if quantity is not None:
|
||||
cost_item.quantity = Decimal(str(quantity))
|
||||
cost_item.total_cost = cost_item.unit_cost * cost_item.quantity
|
||||
|
||||
if remark is not None:
|
||||
cost_item.remark = remark
|
||||
|
||||
await self.db.flush()
|
||||
await self.db.refresh(cost_item)
|
||||
|
||||
return cost_item
|
||||
|
||||
async def add_labor_cost(
|
||||
self,
|
||||
project_id: int,
|
||||
staff_level_id: int,
|
||||
duration_minutes: int,
|
||||
remark: Optional[str] = None,
|
||||
) -> ProjectLaborCost:
|
||||
"""添加人工成本
|
||||
|
||||
自动获取时薪并计算人工成本
|
||||
"""
|
||||
staff_level = await self.get_staff_level_info(staff_level_id)
|
||||
if not staff_level:
|
||||
raise ValueError(f"人员级别不存在: {staff_level_id}")
|
||||
|
||||
hourly_rate = Decimal(str(staff_level.hourly_rate))
|
||||
labor_cost = hourly_rate * Decimal(str(duration_minutes)) / Decimal("60")
|
||||
|
||||
labor_item = ProjectLaborCost(
|
||||
project_id=project_id,
|
||||
staff_level_id=staff_level_id,
|
||||
duration_minutes=duration_minutes,
|
||||
hourly_rate=hourly_rate,
|
||||
labor_cost=labor_cost,
|
||||
remark=remark,
|
||||
)
|
||||
self.db.add(labor_item)
|
||||
await self.db.flush()
|
||||
await self.db.refresh(labor_item)
|
||||
|
||||
return labor_item
|
||||
|
||||
async def update_labor_cost(
|
||||
self,
|
||||
labor_item: ProjectLaborCost,
|
||||
staff_level_id: Optional[int] = None,
|
||||
duration_minutes: Optional[int] = None,
|
||||
remark: Optional[str] = None,
|
||||
) -> ProjectLaborCost:
|
||||
"""更新人工成本"""
|
||||
if staff_level_id is not None:
|
||||
staff_level = await self.get_staff_level_info(staff_level_id)
|
||||
if not staff_level:
|
||||
raise ValueError(f"人员级别不存在: {staff_level_id}")
|
||||
labor_item.staff_level_id = staff_level_id
|
||||
labor_item.hourly_rate = Decimal(str(staff_level.hourly_rate))
|
||||
|
||||
if duration_minutes is not None:
|
||||
labor_item.duration_minutes = duration_minutes
|
||||
|
||||
# 重新计算人工成本
|
||||
labor_item.labor_cost = (
|
||||
labor_item.hourly_rate * Decimal(str(labor_item.duration_minutes)) / Decimal("60")
|
||||
)
|
||||
|
||||
if remark is not None:
|
||||
labor_item.remark = remark
|
||||
|
||||
await self.db.flush()
|
||||
await self.db.refresh(labor_item)
|
||||
|
||||
return labor_item
|
||||
367
后端服务/app/services/market_service.py
Normal file
367
后端服务/app/services/market_service.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""市场分析服务
|
||||
|
||||
实现市场价格分析核心业务逻辑,包含:
|
||||
- 竞品价格统计分析
|
||||
- 价格分布计算
|
||||
- 建议定价区间生成
|
||||
"""
|
||||
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List
|
||||
from statistics import mean, median, stdev
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models import (
|
||||
Project,
|
||||
Competitor,
|
||||
CompetitorPrice,
|
||||
BenchmarkPrice,
|
||||
MarketAnalysisResult,
|
||||
Category,
|
||||
)
|
||||
from app.schemas.market import (
|
||||
MarketAnalysisResult as MarketAnalysisResultSchema,
|
||||
PriceStatistics,
|
||||
PriceDistribution,
|
||||
PriceDistributionItem,
|
||||
CompetitorPriceSummary,
|
||||
BenchmarkReference,
|
||||
SuggestedRange,
|
||||
)
|
||||
|
||||
|
||||
class MarketService:
|
||||
"""市场分析服务"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_competitor_prices_for_project(
|
||||
self,
|
||||
project_id: int,
|
||||
competitor_ids: Optional[List[int]] = None
|
||||
) -> List[CompetitorPrice]:
|
||||
"""获取项目的竞品价格数据
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
competitor_ids: 指定竞品机构ID列表
|
||||
|
||||
Returns:
|
||||
竞品价格列表
|
||||
"""
|
||||
query = select(CompetitorPrice).options(
|
||||
selectinload(CompetitorPrice.competitor)
|
||||
).where(
|
||||
CompetitorPrice.project_id == project_id
|
||||
)
|
||||
|
||||
if competitor_ids:
|
||||
query = query.where(CompetitorPrice.competitor_id.in_(competitor_ids))
|
||||
|
||||
result = await self.db.execute(query.order_by(CompetitorPrice.collected_at.desc()))
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_benchmark_prices_for_category(
|
||||
self,
|
||||
category_id: Optional[int]
|
||||
) -> List[BenchmarkPrice]:
|
||||
"""获取分类的标杆价格
|
||||
|
||||
Args:
|
||||
category_id: 分类ID
|
||||
|
||||
Returns:
|
||||
标杆价格列表
|
||||
"""
|
||||
if not category_id:
|
||||
return []
|
||||
|
||||
query = select(BenchmarkPrice).where(
|
||||
BenchmarkPrice.category_id == category_id
|
||||
).order_by(BenchmarkPrice.effective_date.desc())
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
def calculate_price_statistics(self, prices: List[float]) -> PriceStatistics:
|
||||
"""计算价格统计数据
|
||||
|
||||
Args:
|
||||
prices: 价格列表
|
||||
|
||||
Returns:
|
||||
价格统计
|
||||
"""
|
||||
if not prices:
|
||||
return PriceStatistics(
|
||||
min_price=0,
|
||||
max_price=0,
|
||||
avg_price=0,
|
||||
median_price=0,
|
||||
std_deviation=0
|
||||
)
|
||||
|
||||
std_dev = None
|
||||
if len(prices) > 1:
|
||||
try:
|
||||
std_dev = round(stdev(prices), 2)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return PriceStatistics(
|
||||
min_price=round(min(prices), 2),
|
||||
max_price=round(max(prices), 2),
|
||||
avg_price=round(mean(prices), 2),
|
||||
median_price=round(median(prices), 2),
|
||||
std_deviation=std_dev
|
||||
)
|
||||
|
||||
def calculate_price_distribution(
|
||||
self,
|
||||
prices: List[float],
|
||||
min_price: float,
|
||||
max_price: float
|
||||
) -> PriceDistribution:
|
||||
"""计算价格分布
|
||||
|
||||
将价格分为低/中/高三个区间
|
||||
|
||||
Args:
|
||||
prices: 价格列表
|
||||
min_price: 最低价
|
||||
max_price: 最高价
|
||||
|
||||
Returns:
|
||||
价格分布
|
||||
"""
|
||||
if not prices or min_price >= max_price:
|
||||
return PriceDistribution(
|
||||
low=PriceDistributionItem(range="N/A", count=0, percentage=0),
|
||||
medium=PriceDistributionItem(range="N/A", count=0, percentage=0),
|
||||
high=PriceDistributionItem(range="N/A", count=0, percentage=0),
|
||||
)
|
||||
|
||||
# 计算三个区间的边界
|
||||
range_size = (max_price - min_price) / 3
|
||||
low_upper = min_price + range_size
|
||||
mid_upper = min_price + range_size * 2
|
||||
|
||||
# 统计各区间数量
|
||||
low_count = sum(1 for p in prices if p < low_upper)
|
||||
mid_count = sum(1 for p in prices if low_upper <= p < mid_upper)
|
||||
high_count = sum(1 for p in prices if p >= mid_upper)
|
||||
|
||||
total = len(prices)
|
||||
|
||||
return PriceDistribution(
|
||||
low=PriceDistributionItem(
|
||||
range=f"{int(min_price)}-{int(low_upper)}",
|
||||
count=low_count,
|
||||
percentage=round(low_count / total * 100, 1) if total > 0 else 0
|
||||
),
|
||||
medium=PriceDistributionItem(
|
||||
range=f"{int(low_upper)}-{int(mid_upper)}",
|
||||
count=mid_count,
|
||||
percentage=round(mid_count / total * 100, 1) if total > 0 else 0
|
||||
),
|
||||
high=PriceDistributionItem(
|
||||
range=f"{int(mid_upper)}-{int(max_price)}",
|
||||
count=high_count,
|
||||
percentage=round(high_count / total * 100, 1) if total > 0 else 0
|
||||
),
|
||||
)
|
||||
|
||||
def calculate_suggested_range(
|
||||
self,
|
||||
avg_price: float,
|
||||
min_price: float,
|
||||
max_price: float,
|
||||
benchmark_avg: Optional[float] = None
|
||||
) -> SuggestedRange:
|
||||
"""计算建议定价区间
|
||||
|
||||
Args:
|
||||
avg_price: 市场均价
|
||||
min_price: 市场最低价
|
||||
max_price: 市场最高价
|
||||
benchmark_avg: 标杆均价(可选)
|
||||
|
||||
Returns:
|
||||
建议定价区间
|
||||
"""
|
||||
if avg_price == 0:
|
||||
return SuggestedRange(min=0, max=0, recommended=0)
|
||||
|
||||
# 建议区间:以均价为中心,±20%
|
||||
range_factor = 0.2
|
||||
suggested_min = round(avg_price * (1 - range_factor), 2)
|
||||
suggested_max = round(avg_price * (1 + range_factor), 2)
|
||||
|
||||
# 确保不低于市场最低价的80%,不高于市场最高价的120%
|
||||
suggested_min = max(suggested_min, round(min_price * 0.8, 2))
|
||||
suggested_max = min(suggested_max, round(max_price * 1.2, 2))
|
||||
|
||||
# 推荐价格:如果有标杆价格,取市场均价和标杆均价的加权平均
|
||||
if benchmark_avg:
|
||||
recommended = round((avg_price * 0.6 + benchmark_avg * 0.4), 2)
|
||||
else:
|
||||
recommended = avg_price
|
||||
|
||||
return SuggestedRange(
|
||||
min=suggested_min,
|
||||
max=suggested_max,
|
||||
recommended=recommended
|
||||
)
|
||||
|
||||
async def analyze_market(
|
||||
self,
|
||||
project_id: int,
|
||||
competitor_ids: Optional[List[int]] = None,
|
||||
include_benchmark: bool = True
|
||||
) -> MarketAnalysisResultSchema:
|
||||
"""执行市场价格分析
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
competitor_ids: 指定竞品机构ID列表
|
||||
include_benchmark: 是否包含标杆参考
|
||||
|
||||
Returns:
|
||||
市场分析结果
|
||||
"""
|
||||
# 获取项目信息
|
||||
project_result = await self.db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = project_result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
raise ValueError(f"项目不存在: {project_id}")
|
||||
|
||||
# 获取竞品价格
|
||||
competitor_prices = await self.get_competitor_prices_for_project(
|
||||
project_id, competitor_ids
|
||||
)
|
||||
|
||||
# 提取价格列表(使用原价)
|
||||
prices = [float(cp.original_price) for cp in competitor_prices]
|
||||
|
||||
# 计算统计数据
|
||||
price_stats = self.calculate_price_statistics(prices)
|
||||
|
||||
# 计算价格分布
|
||||
price_distribution = None
|
||||
if len(prices) >= 3:
|
||||
price_distribution = self.calculate_price_distribution(
|
||||
prices, price_stats.min_price, price_stats.max_price
|
||||
)
|
||||
|
||||
# 构建竞品价格摘要
|
||||
competitor_summaries = []
|
||||
for cp in competitor_prices[:10]: # 最多返回10条
|
||||
competitor_summaries.append(CompetitorPriceSummary(
|
||||
competitor_name=cp.competitor.competitor_name if cp.competitor else "未知",
|
||||
positioning=cp.competitor.positioning if cp.competitor else "medium",
|
||||
original_price=float(cp.original_price),
|
||||
promo_price=float(cp.promo_price) if cp.promo_price else None,
|
||||
collected_at=cp.collected_at,
|
||||
))
|
||||
|
||||
# 获取标杆参考
|
||||
benchmark_ref = None
|
||||
benchmark_avg = None
|
||||
if include_benchmark and project.category_id:
|
||||
benchmarks = await self.get_benchmark_prices_for_category(project.category_id)
|
||||
if benchmarks:
|
||||
latest_benchmark = benchmarks[0]
|
||||
benchmark_avg = float(latest_benchmark.avg_price)
|
||||
benchmark_ref = BenchmarkReference(
|
||||
tier=latest_benchmark.price_tier,
|
||||
min_price=float(latest_benchmark.min_price),
|
||||
max_price=float(latest_benchmark.max_price),
|
||||
avg_price=benchmark_avg,
|
||||
)
|
||||
|
||||
# 计算建议区间
|
||||
suggested_range = self.calculate_suggested_range(
|
||||
price_stats.avg_price,
|
||||
price_stats.min_price,
|
||||
price_stats.max_price,
|
||||
benchmark_avg
|
||||
)
|
||||
|
||||
analysis_date = date.today()
|
||||
|
||||
# 保存分析结果到数据库
|
||||
await self._save_analysis_result(
|
||||
project_id=project_id,
|
||||
analysis_date=analysis_date,
|
||||
competitor_count=len(competitor_prices),
|
||||
min_price=price_stats.min_price,
|
||||
max_price=price_stats.max_price,
|
||||
avg_price=price_stats.avg_price,
|
||||
median_price=price_stats.median_price,
|
||||
suggested_min=suggested_range.min,
|
||||
suggested_max=suggested_range.max,
|
||||
)
|
||||
|
||||
return MarketAnalysisResultSchema(
|
||||
project_id=project_id,
|
||||
project_name=project.project_name,
|
||||
analysis_date=analysis_date,
|
||||
competitor_count=len(competitor_prices),
|
||||
price_statistics=price_stats,
|
||||
price_distribution=price_distribution,
|
||||
competitor_prices=competitor_summaries,
|
||||
benchmark_reference=benchmark_ref,
|
||||
suggested_range=suggested_range,
|
||||
)
|
||||
|
||||
async def _save_analysis_result(
|
||||
self,
|
||||
project_id: int,
|
||||
analysis_date: date,
|
||||
competitor_count: int,
|
||||
min_price: float,
|
||||
max_price: float,
|
||||
avg_price: float,
|
||||
median_price: float,
|
||||
suggested_min: float,
|
||||
suggested_max: float,
|
||||
):
|
||||
"""保存分析结果"""
|
||||
# 创建新记录(保留历史)
|
||||
result = MarketAnalysisResult(
|
||||
project_id=project_id,
|
||||
analysis_date=analysis_date,
|
||||
competitor_count=competitor_count,
|
||||
market_min_price=Decimal(str(min_price)),
|
||||
market_max_price=Decimal(str(max_price)),
|
||||
market_avg_price=Decimal(str(avg_price)),
|
||||
market_median_price=Decimal(str(median_price)),
|
||||
suggested_range_min=Decimal(str(suggested_min)),
|
||||
suggested_range_max=Decimal(str(suggested_max)),
|
||||
)
|
||||
self.db.add(result)
|
||||
await self.db.flush()
|
||||
|
||||
async def get_latest_analysis(self, project_id: int) -> Optional[MarketAnalysisResult]:
|
||||
"""获取最新的市场分析结果
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
|
||||
Returns:
|
||||
最新分析结果
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(MarketAnalysisResult).where(
|
||||
MarketAnalysisResult.project_id == project_id
|
||||
).order_by(MarketAnalysisResult.analysis_date.desc()).limit(1)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
574
后端服务/app/services/pricing_service.py
Normal file
574
后端服务/app/services/pricing_service.py
Normal file
@@ -0,0 +1,574 @@
|
||||
"""智能定价服务
|
||||
|
||||
实现智能定价核心业务逻辑,包含:
|
||||
- 综合定价计算
|
||||
- AI 定价建议生成(遵循瑞小美 AI 接入规范)
|
||||
- 定价策略模拟
|
||||
- 定价报告导出
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List, Dict, Any, AsyncIterator
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models import (
|
||||
Project,
|
||||
ProjectCostSummary,
|
||||
MarketAnalysisResult,
|
||||
PricingPlan,
|
||||
)
|
||||
from app.schemas.pricing import (
|
||||
StrategyType,
|
||||
PricingSuggestions,
|
||||
StrategySuggestion,
|
||||
MarketReference,
|
||||
AIAdvice,
|
||||
AIUsage,
|
||||
GeneratePricingResponse,
|
||||
StrategySimulationResult,
|
||||
SimulateStrategyResponse,
|
||||
)
|
||||
from app.services.ai_service_wrapper import AIServiceWrapper
|
||||
from app.services.cost_service import CostService
|
||||
from app.services.market_service import MarketService
|
||||
|
||||
# 导入提示词
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../prompts'))
|
||||
from prompts.pricing_advice_prompts import SYSTEM_PROMPT, USER_PROMPT, PROMPT_META
|
||||
|
||||
|
||||
class PricingService:
|
||||
"""智能定价服务"""
|
||||
|
||||
# 各策略的利润率范围
|
||||
STRATEGY_MARGINS = {
|
||||
StrategyType.TRAFFIC: (0.10, 0.20), # 引流款:10%-20%
|
||||
StrategyType.PROFIT: (0.40, 0.60), # 利润款:40%-60%
|
||||
StrategyType.PREMIUM: (0.60, 0.80), # 高端款:60%-80%
|
||||
}
|
||||
|
||||
STRATEGY_NAMES = {
|
||||
StrategyType.TRAFFIC: "引流款",
|
||||
StrategyType.PROFIT: "利润款",
|
||||
StrategyType.PREMIUM: "高端款",
|
||||
}
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.cost_service = CostService(db)
|
||||
self.market_service = MarketService(db)
|
||||
|
||||
async def get_project_with_cost(self, project_id: int) -> tuple[Project, Optional[ProjectCostSummary]]:
|
||||
"""获取项目及其成本汇总"""
|
||||
result = await self.db.execute(
|
||||
select(Project).options(
|
||||
selectinload(Project.cost_summary)
|
||||
).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
raise ValueError(f"项目不存在: {project_id}")
|
||||
|
||||
return project, project.cost_summary
|
||||
|
||||
async def get_market_reference(self, project_id: int) -> Optional[MarketReference]:
|
||||
"""获取市场参考数据"""
|
||||
result = await self.db.execute(
|
||||
select(MarketAnalysisResult).where(
|
||||
MarketAnalysisResult.project_id == project_id
|
||||
).order_by(MarketAnalysisResult.analysis_date.desc()).limit(1)
|
||||
)
|
||||
market_result = result.scalar_one_or_none()
|
||||
|
||||
if market_result:
|
||||
return MarketReference(
|
||||
min=float(market_result.market_min_price),
|
||||
max=float(market_result.market_max_price),
|
||||
avg=float(market_result.market_avg_price),
|
||||
)
|
||||
return None
|
||||
|
||||
def calculate_strategy_price(
|
||||
self,
|
||||
base_cost: float,
|
||||
strategy: StrategyType,
|
||||
target_margin: Optional[float] = None,
|
||||
market_ref: Optional[MarketReference] = None
|
||||
) -> StrategySuggestion:
|
||||
"""计算单个策略的定价建议
|
||||
|
||||
Args:
|
||||
base_cost: 基础成本
|
||||
strategy: 策略类型
|
||||
target_margin: 自定义目标毛利率(可选)
|
||||
market_ref: 市场参考(可选)
|
||||
|
||||
Returns:
|
||||
策略定价建议
|
||||
"""
|
||||
min_margin, max_margin = self.STRATEGY_MARGINS[strategy]
|
||||
|
||||
# 使用策略默认的中间利润率,或自定义目标
|
||||
if target_margin is not None:
|
||||
margin = target_margin / 100
|
||||
else:
|
||||
margin = (min_margin + max_margin) / 2
|
||||
|
||||
# 成本加成定价法:价格 = 成本 / (1 - 毛利率)
|
||||
suggested_price = base_cost / (1 - margin)
|
||||
|
||||
# 如果有市场参考,调整价格
|
||||
if market_ref:
|
||||
if strategy == StrategyType.TRAFFIC:
|
||||
# 引流款:取市场最低价和成本定价的较低者
|
||||
market_low = market_ref.min * 0.9
|
||||
suggested_price = min(suggested_price, market_low)
|
||||
elif strategy == StrategyType.PREMIUM:
|
||||
# 高端款:取市场高位
|
||||
market_high = market_ref.max * 1.1
|
||||
suggested_price = max(suggested_price, market_high * 0.9)
|
||||
|
||||
# 确保价格不低于成本(边界处理:零成本时设置最低价格)
|
||||
if base_cost == 0:
|
||||
# 零成本时,使用市场参考或设置一个最低保护价格
|
||||
if market_ref:
|
||||
suggested_price = market_ref.avg * 0.8
|
||||
else:
|
||||
suggested_price = 1.0 # 最低保护价格
|
||||
else:
|
||||
suggested_price = max(suggested_price, base_cost * 1.05)
|
||||
|
||||
# 计算实际毛利率(防止除零)
|
||||
if suggested_price > 0:
|
||||
actual_margin = (suggested_price - base_cost) / suggested_price * 100
|
||||
else:
|
||||
actual_margin = 0
|
||||
|
||||
descriptions = {
|
||||
StrategyType.TRAFFIC: "低于市场均价,适合引流获客、新店开业、淡季促销",
|
||||
StrategyType.PROFIT: "接近市场均价,平衡利润与竞争力,适合日常经营",
|
||||
StrategyType.PREMIUM: "定位高端,高利润空间,需配套优质服务和品牌溢价",
|
||||
}
|
||||
|
||||
return StrategySuggestion(
|
||||
strategy=self.STRATEGY_NAMES[strategy],
|
||||
suggested_price=round(suggested_price, 2),
|
||||
margin=round(actual_margin, 1),
|
||||
description=descriptions[strategy],
|
||||
)
|
||||
|
||||
def calculate_all_strategies(
|
||||
self,
|
||||
base_cost: float,
|
||||
target_margin: float,
|
||||
market_ref: Optional[MarketReference] = None,
|
||||
strategies: Optional[List[StrategyType]] = None
|
||||
) -> PricingSuggestions:
|
||||
"""计算所有策略的定价建议"""
|
||||
if strategies is None:
|
||||
strategies = list(StrategyType)
|
||||
|
||||
suggestions = PricingSuggestions()
|
||||
|
||||
for strategy in strategies:
|
||||
suggestion = self.calculate_strategy_price(
|
||||
base_cost=base_cost,
|
||||
strategy=strategy,
|
||||
target_margin=target_margin if strategy == StrategyType.PROFIT else None,
|
||||
market_ref=market_ref,
|
||||
)
|
||||
|
||||
if strategy == StrategyType.TRAFFIC:
|
||||
suggestions.traffic = suggestion
|
||||
elif strategy == StrategyType.PROFIT:
|
||||
suggestions.profit = suggestion
|
||||
elif strategy == StrategyType.PREMIUM:
|
||||
suggestions.premium = suggestion
|
||||
|
||||
return suggestions
|
||||
|
||||
async def generate_pricing_advice(
|
||||
self,
|
||||
project_id: int,
|
||||
target_margin: float = 50,
|
||||
strategies: Optional[List[StrategyType]] = None,
|
||||
stream: bool = False,
|
||||
) -> GeneratePricingResponse:
|
||||
"""生成智能定价建议
|
||||
|
||||
遵循瑞小美 AI 接入规范:
|
||||
- 通过 AIServiceWrapper 调用
|
||||
- 必须传入 prompt_name(用于统计)
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
target_margin: 目标毛利率
|
||||
strategies: 要计算的策略列表
|
||||
stream: 是否流式输出(此方法返回完整结果,流式由路由处理)
|
||||
|
||||
Returns:
|
||||
定价建议响应
|
||||
"""
|
||||
# 获取项目和成本数据
|
||||
project, cost_summary = await self.get_project_with_cost(project_id)
|
||||
|
||||
if not cost_summary:
|
||||
# 如果没有成本汇总,先计算
|
||||
cost_result = await self.cost_service.calculate_project_cost(project_id)
|
||||
base_cost = cost_result.total_cost
|
||||
else:
|
||||
base_cost = float(cost_summary.total_cost)
|
||||
|
||||
# 获取市场参考
|
||||
market_ref = await self.get_market_reference(project_id)
|
||||
|
||||
# 计算各策略价格
|
||||
pricing_suggestions = self.calculate_all_strategies(
|
||||
base_cost=base_cost,
|
||||
target_margin=target_margin,
|
||||
market_ref=market_ref,
|
||||
strategies=strategies,
|
||||
)
|
||||
|
||||
# 构建 AI 输入数据
|
||||
cost_data = self._format_cost_data(cost_summary)
|
||||
market_data = self._format_market_data(market_ref)
|
||||
|
||||
# 调用 AI 生成建议(遵循规范)
|
||||
ai_service = AIServiceWrapper(db_session=self.db)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": USER_PROMPT.format(
|
||||
project_name=project.project_name,
|
||||
cost_data=cost_data,
|
||||
market_data=market_data,
|
||||
target_margin=target_margin,
|
||||
)},
|
||||
]
|
||||
|
||||
ai_advice = None
|
||||
ai_usage = None
|
||||
|
||||
try:
|
||||
response = await ai_service.chat(
|
||||
messages=messages,
|
||||
prompt_name=PROMPT_META["name"], # 必填!用于调用统计
|
||||
)
|
||||
|
||||
ai_advice = AIAdvice(
|
||||
summary=self._extract_section(response.content, "推荐方案") or response.content[:200],
|
||||
cost_analysis=self._extract_section(response.content, "成本") or "",
|
||||
market_analysis=self._extract_section(response.content, "市场") or "",
|
||||
risk_notes=self._extract_section(response.content, "风险") or "",
|
||||
recommendations=self._extract_recommendations(response.content),
|
||||
)
|
||||
|
||||
ai_usage = AIUsage(
|
||||
provider=response.provider,
|
||||
model=response.model,
|
||||
tokens=response.total_tokens,
|
||||
latency_ms=response.latency_ms,
|
||||
)
|
||||
except Exception as e:
|
||||
# AI 调用失败不影响基本定价计算
|
||||
print(f"AI 调用失败: {e}")
|
||||
|
||||
return GeneratePricingResponse(
|
||||
project_id=project_id,
|
||||
project_name=project.project_name,
|
||||
cost_base=base_cost,
|
||||
market_reference=market_ref,
|
||||
pricing_suggestions=pricing_suggestions,
|
||||
ai_advice=ai_advice,
|
||||
ai_usage=ai_usage,
|
||||
)
|
||||
|
||||
async def generate_pricing_advice_stream(
|
||||
self,
|
||||
project_id: int,
|
||||
target_margin: float = 50,
|
||||
) -> AsyncIterator[str]:
|
||||
"""流式生成定价建议
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
target_margin: 目标毛利率
|
||||
|
||||
Yields:
|
||||
SSE 格式的响应片段
|
||||
"""
|
||||
# 获取基础数据
|
||||
project, cost_summary = await self.get_project_with_cost(project_id)
|
||||
|
||||
if not cost_summary:
|
||||
cost_result = await self.cost_service.calculate_project_cost(project_id)
|
||||
base_cost = cost_result.total_cost
|
||||
else:
|
||||
base_cost = float(cost_summary.total_cost)
|
||||
|
||||
market_ref = await self.get_market_reference(project_id)
|
||||
|
||||
# 先返回基础定价计算结果
|
||||
pricing_suggestions = self.calculate_all_strategies(
|
||||
base_cost=base_cost,
|
||||
target_margin=target_margin,
|
||||
market_ref=market_ref,
|
||||
)
|
||||
|
||||
# 发送初始数据
|
||||
initial_data = {
|
||||
"type": "init",
|
||||
"project_name": project.project_name,
|
||||
"cost_base": base_cost,
|
||||
"market_reference": market_ref.model_dump() if market_ref else None,
|
||||
"pricing_suggestions": pricing_suggestions.model_dump(),
|
||||
}
|
||||
yield f"data: {json.dumps(initial_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 构建 AI 输入
|
||||
cost_data = self._format_cost_data(cost_summary)
|
||||
market_data = self._format_market_data(market_ref)
|
||||
|
||||
ai_service = AIServiceWrapper(db_session=self.db)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": USER_PROMPT.format(
|
||||
project_name=project.project_name,
|
||||
cost_data=cost_data,
|
||||
market_data=market_data,
|
||||
target_margin=target_margin,
|
||||
)},
|
||||
]
|
||||
|
||||
# 流式返回 AI 建议
|
||||
try:
|
||||
async for chunk in ai_service.chat_stream(
|
||||
messages=messages,
|
||||
prompt_name=PROMPT_META["name"],
|
||||
):
|
||||
yield f"data: {json.dumps({'type': 'chunk', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 发送完成信号
|
||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
async def simulate_strategies(
|
||||
self,
|
||||
project_id: int,
|
||||
strategies: List[StrategyType],
|
||||
target_margin: float = 50,
|
||||
) -> SimulateStrategyResponse:
|
||||
"""模拟定价策略
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
strategies: 要模拟的策略列表
|
||||
target_margin: 目标毛利率
|
||||
|
||||
Returns:
|
||||
策略模拟结果
|
||||
"""
|
||||
project, cost_summary = await self.get_project_with_cost(project_id)
|
||||
|
||||
if not cost_summary:
|
||||
cost_result = await self.cost_service.calculate_project_cost(project_id)
|
||||
base_cost = cost_result.total_cost
|
||||
else:
|
||||
base_cost = float(cost_summary.total_cost)
|
||||
|
||||
market_ref = await self.get_market_reference(project_id)
|
||||
|
||||
results = []
|
||||
for strategy in strategies:
|
||||
suggestion = self.calculate_strategy_price(
|
||||
base_cost=base_cost,
|
||||
strategy=strategy,
|
||||
target_margin=target_margin if strategy == StrategyType.PROFIT else None,
|
||||
market_ref=market_ref,
|
||||
)
|
||||
|
||||
# 确定市场位置
|
||||
market_position = "中等"
|
||||
if market_ref:
|
||||
if suggestion.suggested_price < market_ref.avg * 0.8:
|
||||
market_position = "低于市场均价"
|
||||
elif suggestion.suggested_price > market_ref.avg * 1.2:
|
||||
market_position = "高于市场均价"
|
||||
else:
|
||||
market_position = "接近市场均价"
|
||||
|
||||
results.append(StrategySimulationResult(
|
||||
strategy_type=strategy.value,
|
||||
strategy_name=self.STRATEGY_NAMES[strategy],
|
||||
suggested_price=suggestion.suggested_price,
|
||||
margin=suggestion.margin,
|
||||
profit_per_unit=round(suggestion.suggested_price - base_cost, 2),
|
||||
market_position=market_position,
|
||||
))
|
||||
|
||||
return SimulateStrategyResponse(
|
||||
project_id=project_id,
|
||||
project_name=project.project_name,
|
||||
base_cost=base_cost,
|
||||
results=results,
|
||||
)
|
||||
|
||||
async def create_pricing_plan(
|
||||
self,
|
||||
project_id: int,
|
||||
plan_name: str,
|
||||
strategy_type: StrategyType,
|
||||
target_margin: float,
|
||||
created_by: Optional[int] = None,
|
||||
) -> PricingPlan:
|
||||
"""创建定价方案
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
plan_name: 方案名称
|
||||
strategy_type: 策略类型
|
||||
target_margin: 目标毛利率
|
||||
created_by: 创建人ID
|
||||
|
||||
Returns:
|
||||
创建的定价方案
|
||||
"""
|
||||
# 获取成本数据
|
||||
project, cost_summary = await self.get_project_with_cost(project_id)
|
||||
|
||||
if not cost_summary:
|
||||
cost_result = await self.cost_service.calculate_project_cost(project_id)
|
||||
base_cost = Decimal(str(cost_result.total_cost))
|
||||
else:
|
||||
base_cost = cost_summary.total_cost
|
||||
|
||||
# 获取市场参考
|
||||
market_ref = await self.get_market_reference(project_id)
|
||||
|
||||
# 计算建议价格
|
||||
suggestion = self.calculate_strategy_price(
|
||||
base_cost=float(base_cost),
|
||||
strategy=strategy_type,
|
||||
target_margin=target_margin if strategy_type == StrategyType.PROFIT else None,
|
||||
market_ref=market_ref,
|
||||
)
|
||||
|
||||
# 创建定价方案
|
||||
pricing_plan = PricingPlan(
|
||||
project_id=project_id,
|
||||
plan_name=plan_name,
|
||||
strategy_type=strategy_type.value,
|
||||
base_cost=base_cost,
|
||||
target_margin=Decimal(str(target_margin)),
|
||||
suggested_price=Decimal(str(suggestion.suggested_price)),
|
||||
is_active=True,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
self.db.add(pricing_plan)
|
||||
await self.db.flush()
|
||||
await self.db.refresh(pricing_plan)
|
||||
|
||||
return pricing_plan
|
||||
|
||||
async def update_pricing_plan(
|
||||
self,
|
||||
plan_id: int,
|
||||
**kwargs
|
||||
) -> PricingPlan:
|
||||
"""更新定价方案"""
|
||||
result = await self.db.execute(
|
||||
select(PricingPlan).where(PricingPlan.id == plan_id)
|
||||
)
|
||||
plan = result.scalar_one_or_none()
|
||||
|
||||
if not plan:
|
||||
raise ValueError(f"定价方案不存在: {plan_id}")
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if value is not None and hasattr(plan, key):
|
||||
if key in ['target_margin', 'final_price', 'base_cost', 'suggested_price']:
|
||||
setattr(plan, key, Decimal(str(value)))
|
||||
else:
|
||||
setattr(plan, key, value)
|
||||
|
||||
await self.db.flush()
|
||||
await self.db.refresh(plan)
|
||||
|
||||
return plan
|
||||
|
||||
async def save_ai_advice(self, plan_id: int, advice: str) -> None:
|
||||
"""保存 AI 建议到定价方案"""
|
||||
result = await self.db.execute(
|
||||
select(PricingPlan).where(PricingPlan.id == plan_id)
|
||||
)
|
||||
plan = result.scalar_one_or_none()
|
||||
|
||||
if plan:
|
||||
plan.ai_advice = advice
|
||||
await self.db.flush()
|
||||
|
||||
def _format_cost_data(self, cost_summary: Optional[ProjectCostSummary]) -> str:
|
||||
"""格式化成本数据用于 AI 输入"""
|
||||
if not cost_summary:
|
||||
return "暂无成本数据"
|
||||
|
||||
return f"""- 耗材成本:{float(cost_summary.material_cost):.2f} 元
|
||||
- 设备折旧:{float(cost_summary.equipment_cost):.2f} 元
|
||||
- 人工成本:{float(cost_summary.labor_cost):.2f} 元
|
||||
- 固定成本分摊:{float(cost_summary.fixed_cost_allocation):.2f} 元
|
||||
- **总成本(最低成本线):{float(cost_summary.total_cost):.2f} 元**"""
|
||||
|
||||
def _format_market_data(self, market_ref: Optional[MarketReference]) -> str:
|
||||
"""格式化市场数据用于 AI 输入"""
|
||||
if not market_ref:
|
||||
return "暂无市场行情数据"
|
||||
|
||||
return f"""- 市场最低价:{market_ref.min:.2f} 元
|
||||
- 市场最高价:{market_ref.max:.2f} 元
|
||||
- 市场均价:{market_ref.avg:.2f} 元"""
|
||||
|
||||
def _extract_section(self, content: str, keyword: str) -> Optional[str]:
|
||||
"""从 AI 响应中提取特定部分"""
|
||||
lines = content.split('\n')
|
||||
in_section = False
|
||||
section_lines = []
|
||||
|
||||
for line in lines:
|
||||
if keyword in line and ('#' in line or '**' in line):
|
||||
in_section = True
|
||||
continue
|
||||
elif in_section:
|
||||
if line.startswith('#') or (line.startswith('**') and line.endswith('**')):
|
||||
break
|
||||
section_lines.append(line)
|
||||
|
||||
return '\n'.join(section_lines).strip() if section_lines else None
|
||||
|
||||
def _extract_recommendations(self, content: str) -> List[str]:
|
||||
"""从 AI 响应中提取建议列表"""
|
||||
recommendations = []
|
||||
lines = content.split('\n')
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith('- ') or line.startswith('* ') or line.startswith('• '):
|
||||
recommendations.append(line[2:].strip())
|
||||
elif line and line[0].isdigit() and '.' in line:
|
||||
# 处理 "1. xxx" 格式
|
||||
parts = line.split('.', 1)
|
||||
if len(parts) > 1:
|
||||
recommendations.append(parts[1].strip())
|
||||
|
||||
return recommendations[:5] # 最多返回5条
|
||||
513
后端服务/app/services/profit_service.py
Normal file
513
后端服务/app/services/profit_service.py
Normal file
@@ -0,0 +1,513 @@
|
||||
"""利润模拟服务
|
||||
|
||||
实现利润模拟测算核心业务逻辑,包含:
|
||||
- 利润模拟计算
|
||||
- 敏感性分析
|
||||
- 盈亏平衡分析
|
||||
- AI 利润预测分析
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models import (
|
||||
Project,
|
||||
PricingPlan,
|
||||
ProfitSimulation,
|
||||
SensitivityAnalysis,
|
||||
FixedCost,
|
||||
)
|
||||
from app.schemas.profit import (
|
||||
PeriodType,
|
||||
SimulationInput,
|
||||
SimulationResult,
|
||||
BreakevenAnalysis,
|
||||
SimulateProfitResponse,
|
||||
SensitivityResultItem,
|
||||
SensitivityInsights,
|
||||
SensitivityAnalysisResponse,
|
||||
BreakevenResponse,
|
||||
)
|
||||
from app.services.ai_service_wrapper import AIServiceWrapper
|
||||
|
||||
# 导入提示词
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../prompts'))
|
||||
from prompts.profit_forecast_prompts import SYSTEM_PROMPT, USER_PROMPT, PROMPT_META
|
||||
|
||||
|
||||
class ProfitService:
|
||||
"""利润模拟服务"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_pricing_plan(self, plan_id: int) -> PricingPlan:
|
||||
"""获取定价方案"""
|
||||
result = await self.db.execute(
|
||||
select(PricingPlan).options(
|
||||
selectinload(PricingPlan.project)
|
||||
).where(PricingPlan.id == plan_id)
|
||||
)
|
||||
plan = result.scalar_one_or_none()
|
||||
|
||||
if not plan:
|
||||
raise ValueError(f"定价方案不存在: {plan_id}")
|
||||
|
||||
return plan
|
||||
|
||||
async def get_monthly_fixed_cost(self, year_month: Optional[str] = None) -> Decimal:
|
||||
"""获取月度固定成本总额"""
|
||||
if not year_month:
|
||||
year_month = datetime.now().strftime("%Y-%m")
|
||||
|
||||
result = await self.db.execute(
|
||||
select(func.sum(FixedCost.monthly_amount)).where(
|
||||
FixedCost.year_month == year_month,
|
||||
FixedCost.is_active == True
|
||||
)
|
||||
)
|
||||
return result.scalar() or Decimal("0")
|
||||
|
||||
def calculate_profit(
|
||||
self,
|
||||
price: float,
|
||||
cost_per_unit: float,
|
||||
volume: int,
|
||||
) -> tuple[float, float, float, float]:
|
||||
"""计算利润
|
||||
|
||||
Returns:
|
||||
(收入, 成本, 利润, 利润率)
|
||||
"""
|
||||
revenue = price * volume
|
||||
total_cost = cost_per_unit * volume
|
||||
profit = revenue - total_cost
|
||||
margin = (profit / revenue * 100) if revenue > 0 else 0
|
||||
|
||||
return revenue, total_cost, profit, margin
|
||||
|
||||
def calculate_breakeven(
|
||||
self,
|
||||
price: float,
|
||||
variable_cost: float,
|
||||
fixed_cost: float = 0,
|
||||
) -> int:
|
||||
"""计算盈亏平衡点
|
||||
|
||||
盈亏平衡客量 = 固定成本 / (单价 - 单位变动成本)
|
||||
|
||||
Args:
|
||||
price: 单价
|
||||
variable_cost: 单位变动成本
|
||||
fixed_cost: 固定成本(可选)
|
||||
|
||||
Returns:
|
||||
盈亏平衡客量
|
||||
"""
|
||||
contribution_margin = price - variable_cost
|
||||
|
||||
if contribution_margin <= 0:
|
||||
# 边际贡献为负,无法盈利
|
||||
return 999999
|
||||
|
||||
if fixed_cost > 0:
|
||||
breakeven = int(fixed_cost / contribution_margin) + 1
|
||||
else:
|
||||
# 无固定成本时,只要有销量就盈利
|
||||
breakeven = 1
|
||||
|
||||
return breakeven
|
||||
|
||||
async def simulate_profit(
|
||||
self,
|
||||
pricing_plan_id: int,
|
||||
price: float,
|
||||
estimated_volume: int,
|
||||
period_type: PeriodType,
|
||||
created_by: Optional[int] = None,
|
||||
) -> SimulateProfitResponse:
|
||||
"""执行利润模拟
|
||||
|
||||
Args:
|
||||
pricing_plan_id: 定价方案ID
|
||||
price: 模拟价格
|
||||
estimated_volume: 预估客量
|
||||
period_type: 周期类型
|
||||
created_by: 创建人ID
|
||||
|
||||
Returns:
|
||||
利润模拟结果
|
||||
"""
|
||||
# 获取定价方案
|
||||
plan = await self.get_pricing_plan(pricing_plan_id)
|
||||
cost_per_unit = float(plan.base_cost)
|
||||
|
||||
# 计算利润
|
||||
revenue, total_cost, profit, margin = self.calculate_profit(
|
||||
price=price,
|
||||
cost_per_unit=cost_per_unit,
|
||||
volume=estimated_volume,
|
||||
)
|
||||
|
||||
# 计算盈亏平衡点
|
||||
breakeven_volume = self.calculate_breakeven(
|
||||
price=price,
|
||||
variable_cost=cost_per_unit,
|
||||
)
|
||||
|
||||
# 计算安全边际
|
||||
safety_margin = estimated_volume - breakeven_volume
|
||||
safety_margin_pct = (safety_margin / estimated_volume * 100) if estimated_volume > 0 else 0
|
||||
|
||||
# 限制利润率范围以避免数据库溢出 (DECIMAL(5,2) 范围 -999.99 ~ 999.99)
|
||||
clamped_margin = max(-999.99, min(999.99, margin))
|
||||
|
||||
# 创建模拟记录
|
||||
simulation = ProfitSimulation(
|
||||
pricing_plan_id=pricing_plan_id,
|
||||
simulation_name=f"{plan.project.project_name}-{period_type.value}模拟",
|
||||
price=Decimal(str(price)),
|
||||
estimated_volume=estimated_volume,
|
||||
period_type=period_type.value,
|
||||
estimated_revenue=Decimal(str(revenue)),
|
||||
estimated_cost=Decimal(str(total_cost)),
|
||||
estimated_profit=Decimal(str(profit)),
|
||||
profit_margin=Decimal(str(round(clamped_margin, 2))),
|
||||
breakeven_volume=breakeven_volume,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
self.db.add(simulation)
|
||||
await self.db.flush()
|
||||
await self.db.refresh(simulation)
|
||||
|
||||
return SimulateProfitResponse(
|
||||
simulation_id=simulation.id,
|
||||
pricing_plan_id=pricing_plan_id,
|
||||
project_name=plan.project.project_name,
|
||||
input=SimulationInput(
|
||||
price=price,
|
||||
cost_per_unit=cost_per_unit,
|
||||
estimated_volume=estimated_volume,
|
||||
period_type=period_type.value,
|
||||
),
|
||||
result=SimulationResult(
|
||||
estimated_revenue=round(revenue, 2),
|
||||
estimated_cost=round(total_cost, 2),
|
||||
estimated_profit=round(profit, 2),
|
||||
profit_margin=round(margin, 2),
|
||||
profit_per_unit=round(price - cost_per_unit, 2),
|
||||
),
|
||||
breakeven_analysis=BreakevenAnalysis(
|
||||
breakeven_volume=breakeven_volume,
|
||||
current_volume=estimated_volume,
|
||||
safety_margin=safety_margin,
|
||||
safety_margin_percentage=round(safety_margin_pct, 1),
|
||||
),
|
||||
created_at=simulation.created_at,
|
||||
)
|
||||
|
||||
async def sensitivity_analysis(
|
||||
self,
|
||||
simulation_id: int,
|
||||
price_change_rates: List[float],
|
||||
) -> SensitivityAnalysisResponse:
|
||||
"""执行敏感性分析
|
||||
|
||||
分析价格变动对利润的影响
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
price_change_rates: 价格变动率列表
|
||||
|
||||
Returns:
|
||||
敏感性分析结果
|
||||
"""
|
||||
# 获取模拟记录
|
||||
result = await self.db.execute(
|
||||
select(ProfitSimulation).options(
|
||||
selectinload(ProfitSimulation.pricing_plan)
|
||||
).where(ProfitSimulation.id == simulation_id)
|
||||
)
|
||||
simulation = result.scalar_one_or_none()
|
||||
|
||||
if not simulation:
|
||||
raise ValueError(f"模拟记录不存在: {simulation_id}")
|
||||
|
||||
base_price = float(simulation.price)
|
||||
base_profit = float(simulation.estimated_profit)
|
||||
cost_per_unit = float(simulation.pricing_plan.base_cost)
|
||||
volume = simulation.estimated_volume
|
||||
|
||||
sensitivity_results = []
|
||||
|
||||
for rate in sorted(price_change_rates):
|
||||
# 计算调整后价格
|
||||
adjusted_price = base_price * (1 + rate / 100)
|
||||
|
||||
# 计算调整后利润
|
||||
_, _, adjusted_profit, _ = self.calculate_profit(
|
||||
price=adjusted_price,
|
||||
cost_per_unit=cost_per_unit,
|
||||
volume=volume,
|
||||
)
|
||||
|
||||
# 计算利润变动率
|
||||
profit_change_rate = 0
|
||||
if base_profit != 0:
|
||||
profit_change_rate = (adjusted_profit - base_profit) / abs(base_profit) * 100
|
||||
|
||||
# 限制变动率范围以避免数据库溢出
|
||||
clamped_profit_change_rate = max(-999.99, min(999.99, profit_change_rate))
|
||||
|
||||
item = SensitivityResultItem(
|
||||
price_change_rate=rate,
|
||||
adjusted_price=round(adjusted_price, 2),
|
||||
adjusted_profit=round(adjusted_profit, 2),
|
||||
profit_change_rate=round(clamped_profit_change_rate, 2),
|
||||
)
|
||||
sensitivity_results.append(item)
|
||||
|
||||
# 保存到数据库
|
||||
analysis = SensitivityAnalysis(
|
||||
simulation_id=simulation_id,
|
||||
price_change_rate=Decimal(str(rate)),
|
||||
adjusted_price=Decimal(str(adjusted_price)),
|
||||
adjusted_profit=Decimal(str(adjusted_profit)),
|
||||
profit_change_rate=Decimal(str(round(clamped_profit_change_rate, 2))),
|
||||
)
|
||||
self.db.add(analysis)
|
||||
|
||||
await self.db.flush()
|
||||
|
||||
# 生成洞察
|
||||
insights = self._generate_sensitivity_insights(
|
||||
base_price=base_price,
|
||||
base_profit=base_profit,
|
||||
results=sensitivity_results,
|
||||
)
|
||||
|
||||
return SensitivityAnalysisResponse(
|
||||
simulation_id=simulation_id,
|
||||
base_price=base_price,
|
||||
base_profit=base_profit,
|
||||
sensitivity_results=sensitivity_results,
|
||||
insights=insights,
|
||||
)
|
||||
|
||||
def _generate_sensitivity_insights(
|
||||
self,
|
||||
base_price: float,
|
||||
base_profit: float,
|
||||
results: List[SensitivityResultItem],
|
||||
) -> SensitivityInsights:
|
||||
"""生成敏感性分析洞察"""
|
||||
# 计算价格弹性(使用 ±10% 的数据点)
|
||||
elasticity = 0
|
||||
for r in results:
|
||||
if r.price_change_rate == 10:
|
||||
elasticity = r.profit_change_rate / 10
|
||||
break
|
||||
|
||||
# 判断风险等级
|
||||
risk_level = "低"
|
||||
min_profit = min(r.adjusted_profit for r in results)
|
||||
|
||||
if min_profit < 0:
|
||||
risk_level = "高"
|
||||
elif min_profit < base_profit * 0.5:
|
||||
risk_level = "中等"
|
||||
|
||||
# 生成建议
|
||||
recommendation = "价格调整空间较大,经营风险可控。"
|
||||
if risk_level == "高":
|
||||
recommendation = f"价格下降超过某阈值会导致亏损,建议密切关注市场动态。"
|
||||
elif risk_level == "中等":
|
||||
recommendation = f"价格变动对利润影响较大,建议谨慎调价。"
|
||||
|
||||
return SensitivityInsights(
|
||||
price_elasticity=f"价格每变动1%,利润变动约{abs(elasticity):.2f}%",
|
||||
risk_level=risk_level,
|
||||
recommendation=recommendation,
|
||||
)
|
||||
|
||||
async def breakeven_analysis(
|
||||
self,
|
||||
pricing_plan_id: int,
|
||||
target_profit: Optional[float] = None,
|
||||
) -> BreakevenResponse:
|
||||
"""盈亏平衡分析
|
||||
|
||||
Args:
|
||||
pricing_plan_id: 定价方案ID
|
||||
target_profit: 目标利润(可选)
|
||||
|
||||
Returns:
|
||||
盈亏平衡分析结果
|
||||
"""
|
||||
plan = await self.get_pricing_plan(pricing_plan_id)
|
||||
|
||||
price = float(plan.final_price or plan.suggested_price)
|
||||
unit_cost = float(plan.base_cost)
|
||||
|
||||
# 获取月度固定成本
|
||||
monthly_fixed_cost = float(await self.get_monthly_fixed_cost())
|
||||
|
||||
# 计算盈亏平衡点
|
||||
contribution_margin = price - unit_cost
|
||||
breakeven_volume = self.calculate_breakeven(
|
||||
price=price,
|
||||
variable_cost=unit_cost,
|
||||
fixed_cost=monthly_fixed_cost,
|
||||
)
|
||||
|
||||
# 计算达到目标利润的客量
|
||||
target_volume = None
|
||||
if target_profit is not None and contribution_margin > 0:
|
||||
target_volume = int((monthly_fixed_cost + target_profit) / contribution_margin) + 1
|
||||
|
||||
return BreakevenResponse(
|
||||
pricing_plan_id=pricing_plan_id,
|
||||
project_name=plan.project.project_name,
|
||||
price=price,
|
||||
unit_cost=unit_cost,
|
||||
fixed_cost_monthly=monthly_fixed_cost,
|
||||
breakeven_volume=breakeven_volume,
|
||||
current_margin=round(contribution_margin, 2),
|
||||
target_profit_volume=target_volume,
|
||||
)
|
||||
|
||||
async def generate_profit_forecast(
|
||||
self,
|
||||
simulation_id: int,
|
||||
) -> str:
|
||||
"""AI 生成利润预测分析
|
||||
|
||||
遵循瑞小美 AI 接入规范
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
|
||||
Returns:
|
||||
AI 分析内容
|
||||
"""
|
||||
# 获取模拟数据
|
||||
result = await self.db.execute(
|
||||
select(ProfitSimulation).options(
|
||||
selectinload(ProfitSimulation.pricing_plan).selectinload(PricingPlan.project),
|
||||
selectinload(ProfitSimulation.sensitivity_analyses),
|
||||
).where(ProfitSimulation.id == simulation_id)
|
||||
)
|
||||
simulation = result.scalar_one_or_none()
|
||||
|
||||
if not simulation:
|
||||
raise ValueError(f"模拟记录不存在: {simulation_id}")
|
||||
|
||||
# 格式化数据
|
||||
pricing_data = f"""- 定价方案:{simulation.pricing_plan.plan_name}
|
||||
- 策略类型:{simulation.pricing_plan.strategy_type}
|
||||
- 基础成本:{float(simulation.pricing_plan.base_cost):.2f} 元
|
||||
- 建议价格:{float(simulation.pricing_plan.suggested_price):.2f} 元
|
||||
- 目标毛利率:{float(simulation.pricing_plan.target_margin):.1f}%"""
|
||||
|
||||
simulation_data = f"""- 模拟价格:{float(simulation.price):.2f} 元
|
||||
- 预估客量:{simulation.estimated_volume} ({simulation.period_type})
|
||||
- 预估收入:{float(simulation.estimated_revenue):.2f} 元
|
||||
- 预估成本:{float(simulation.estimated_cost):.2f} 元
|
||||
- 预估利润:{float(simulation.estimated_profit):.2f} 元
|
||||
- 利润率:{float(simulation.profit_margin):.1f}%
|
||||
- 盈亏平衡客量:{simulation.breakeven_volume}"""
|
||||
|
||||
# 格式化敏感性数据
|
||||
sensitivity_data = "暂无敏感性分析数据"
|
||||
if simulation.sensitivity_analyses:
|
||||
lines = []
|
||||
for sa in simulation.sensitivity_analyses:
|
||||
lines.append(
|
||||
f" - 价格{float(sa.price_change_rate):+.0f}%: "
|
||||
f"价格{float(sa.adjusted_price):.0f}元, "
|
||||
f"利润{float(sa.adjusted_profit):.0f}元 "
|
||||
f"({float(sa.profit_change_rate):+.1f}%)"
|
||||
)
|
||||
sensitivity_data = "\n".join(lines)
|
||||
|
||||
# 调用 AI
|
||||
ai_service = AIServiceWrapper(db_session=self.db)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": USER_PROMPT.format(
|
||||
project_name=simulation.pricing_plan.project.project_name,
|
||||
pricing_data=pricing_data,
|
||||
simulation_data=simulation_data,
|
||||
sensitivity_data=sensitivity_data,
|
||||
)},
|
||||
]
|
||||
|
||||
try:
|
||||
response = await ai_service.chat(
|
||||
messages=messages,
|
||||
prompt_name=PROMPT_META["name"], # 必填!
|
||||
)
|
||||
return response.content
|
||||
except Exception as e:
|
||||
return f"AI 分析暂不可用: {str(e)}"
|
||||
|
||||
async def get_simulation_list(
|
||||
self,
|
||||
pricing_plan_id: Optional[int] = None,
|
||||
period_type: Optional[PeriodType] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> tuple[List[ProfitSimulation], int]:
|
||||
"""获取模拟列表
|
||||
|
||||
Returns:
|
||||
(模拟列表, 总数)
|
||||
"""
|
||||
query = select(ProfitSimulation).options(
|
||||
selectinload(ProfitSimulation.pricing_plan).selectinload(PricingPlan.project)
|
||||
)
|
||||
|
||||
if pricing_plan_id:
|
||||
query = query.where(ProfitSimulation.pricing_plan_id == pricing_plan_id)
|
||||
|
||||
if period_type:
|
||||
query = query.where(ProfitSimulation.period_type == period_type.value)
|
||||
|
||||
# 计算总数
|
||||
count_query = select(func.count()).select_from(
|
||||
query.subquery()
|
||||
)
|
||||
total_result = await self.db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
# 分页
|
||||
query = query.order_by(ProfitSimulation.created_at.desc())
|
||||
query = query.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
simulations = result.scalars().all()
|
||||
|
||||
return simulations, total
|
||||
|
||||
async def delete_simulation(self, simulation_id: int) -> bool:
|
||||
"""删除模拟记录"""
|
||||
result = await self.db.execute(
|
||||
select(ProfitSimulation).where(ProfitSimulation.id == simulation_id)
|
||||
)
|
||||
simulation = result.scalar_one_or_none()
|
||||
|
||||
if simulation:
|
||||
await self.db.delete(simulation)
|
||||
await self.db.flush()
|
||||
return True
|
||||
|
||||
return False
|
||||
29
后端服务/prompts/__init__.py
Normal file
29
后端服务/prompts/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""AI 提示词模块
|
||||
|
||||
遵循瑞小美 AI 接入规范:
|
||||
- 文件位置:{模块}/后端服务/prompts/{功能名}_prompts.py
|
||||
- 每个文件必须包含:PROMPT_META, SYSTEM_PROMPT, USER_PROMPT
|
||||
"""
|
||||
|
||||
from prompts.pricing_advice_prompts import (
|
||||
PROMPT_META as PRICING_ADVICE_META,
|
||||
SYSTEM_PROMPT as PRICING_ADVICE_SYSTEM,
|
||||
USER_PROMPT as PRICING_ADVICE_USER,
|
||||
)
|
||||
from prompts.market_analysis_prompts import (
|
||||
PROMPT_META as MARKET_ANALYSIS_META,
|
||||
SYSTEM_PROMPT as MARKET_ANALYSIS_SYSTEM,
|
||||
USER_PROMPT as MARKET_ANALYSIS_USER,
|
||||
)
|
||||
from prompts.profit_forecast_prompts import (
|
||||
PROMPT_META as PROFIT_FORECAST_META,
|
||||
SYSTEM_PROMPT as PROFIT_FORECAST_SYSTEM,
|
||||
USER_PROMPT as PROFIT_FORECAST_USER,
|
||||
)
|
||||
|
||||
# 导出所有提示词元数据
|
||||
ALL_PROMPT_METAS = [
|
||||
PRICING_ADVICE_META,
|
||||
MARKET_ANALYSIS_META,
|
||||
PROFIT_FORECAST_META,
|
||||
]
|
||||
63
后端服务/prompts/market_analysis_prompts.py
Normal file
63
后端服务/prompts/market_analysis_prompts.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""市场分析报告提示词
|
||||
|
||||
分析市场数据,生成市场分析报告
|
||||
"""
|
||||
|
||||
# 提示词元数据(必须包含)
|
||||
PROMPT_META = {
|
||||
"name": "market_analysis",
|
||||
"display_name": "市场分析报告",
|
||||
"description": "分析竞品价格和市场行情,生成市场分析报告",
|
||||
"module": "pricing_model",
|
||||
"variables": ["project_name", "competitor_data", "benchmark_data"],
|
||||
}
|
||||
|
||||
# 系统提示词(必须包含)
|
||||
SYSTEM_PROMPT = """你是一位专业的医美行业市场分析师,擅长分析竞争格局和价格趋势。
|
||||
|
||||
你需要根据提供的竞品价格数据和标杆机构数据,分析市场定价情况,给出市场洞察。
|
||||
|
||||
分析维度:
|
||||
1. **价格分布**:分析市场价格的分布特征(高、中、低端)
|
||||
2. **竞争格局**:识别主要竞争对手的定价策略
|
||||
3. **标杆对比**:与行业标杆进行对比分析
|
||||
4. **趋势判断**:分析价格变化趋势
|
||||
5. **机会识别**:发现市场空白和定价机会
|
||||
|
||||
输出要求:
|
||||
- 使用数据支撑分析结论
|
||||
- 提供可视化友好的结构
|
||||
- 给出具体的市场建议"""
|
||||
|
||||
# 用户提示词模板(必须包含)
|
||||
USER_PROMPT = """请分析以下医美项目的市场行情:
|
||||
|
||||
## 项目名称
|
||||
{project_name}
|
||||
|
||||
## 竞品价格数据
|
||||
{competitor_data}
|
||||
|
||||
## 标杆机构参考
|
||||
{benchmark_data}
|
||||
|
||||
---
|
||||
|
||||
请给出以下分析内容:
|
||||
|
||||
### 1. 市场价格概览
|
||||
- 价格区间
|
||||
- 均价和中位价
|
||||
- 价格分布特征
|
||||
|
||||
### 2. 竞争格局分析
|
||||
- 主要竞争对手定位
|
||||
- 定价策略特点
|
||||
|
||||
### 3. 市场定位建议
|
||||
- 建议的定价区间
|
||||
- 定价策略建议
|
||||
|
||||
### 4. 市场机会与风险
|
||||
- 潜在机会
|
||||
- 需要关注的风险"""
|
||||
66
后端服务/prompts/pricing_advice_prompts.py
Normal file
66
后端服务/prompts/pricing_advice_prompts.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""定价建议生成提示词
|
||||
|
||||
综合成本、市场、目标利润率,生成项目定价建议
|
||||
"""
|
||||
|
||||
# 提示词元数据(必须包含)
|
||||
PROMPT_META = {
|
||||
"name": "pricing_advice",
|
||||
"display_name": "智能定价建议",
|
||||
"description": "综合成本、市场、目标利润率,生成项目定价建议",
|
||||
"module": "pricing_model",
|
||||
"variables": ["project_name", "cost_data", "market_data", "target_margin"],
|
||||
}
|
||||
|
||||
# 系统提示词(必须包含)
|
||||
SYSTEM_PROMPT = """你是一位专业的医美行业定价分析师,拥有丰富的市场分析和定价策略经验。
|
||||
|
||||
你需要根据提供的成本数据、市场行情数据,结合目标利润率,给出专业的定价建议。
|
||||
|
||||
分析时请考虑以下维度:
|
||||
1. **成本结构分析**:评估成本构成的合理性,识别成本优化空间
|
||||
2. **市场竞争态势**:分析竞品定价分布,确定市场位置
|
||||
3. **目标客群定位**:根据定价策略匹配目标客群
|
||||
4. **风险评估**:识别定价可能面临的风险和挑战
|
||||
|
||||
输出要求:
|
||||
- 使用清晰的结构化格式
|
||||
- 提供具体的数字建议
|
||||
- 包含不同策略的对比分析
|
||||
- 给出风险提示和注意事项
|
||||
|
||||
请用专业但易懂的语言回复,避免过于学术化的表达。"""
|
||||
|
||||
# 用户提示词模板(必须包含)
|
||||
USER_PROMPT = """请为以下医美项目生成定价建议:
|
||||
|
||||
## 项目信息
|
||||
**项目名称**:{project_name}
|
||||
|
||||
## 成本数据
|
||||
{cost_data}
|
||||
|
||||
## 市场行情
|
||||
{market_data}
|
||||
|
||||
## 目标毛利率
|
||||
{target_margin}%
|
||||
|
||||
---
|
||||
|
||||
请给出以下内容:
|
||||
|
||||
### 1. 定价建议区间
|
||||
分析成本和市场数据,给出合理的定价区间。
|
||||
|
||||
### 2. 策略定价建议
|
||||
针对三种定价策略给出具体价格:
|
||||
- **引流款**:低价引流,适合获客
|
||||
- **利润款**:平衡利润与竞争力
|
||||
- **高端款**:高端定位,高利润
|
||||
|
||||
### 3. 推荐方案
|
||||
给出最推荐的定价方案及理由。
|
||||
|
||||
### 4. 风险提示
|
||||
指出定价时需要注意的风险和问题。"""
|
||||
69
后端服务/prompts/profit_forecast_prompts.py
Normal file
69
后端服务/prompts/profit_forecast_prompts.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""利润预测分析提示词
|
||||
|
||||
分析利润趋势与风险,提供经营建议
|
||||
"""
|
||||
|
||||
# 提示词元数据(必须包含)
|
||||
PROMPT_META = {
|
||||
"name": "profit_forecast",
|
||||
"display_name": "利润预测分析",
|
||||
"description": "分析利润趋势与风险,提供经营建议",
|
||||
"module": "pricing_model",
|
||||
"variables": ["project_name", "pricing_data", "simulation_data", "sensitivity_data"],
|
||||
}
|
||||
|
||||
# 系统提示词(必须包含)
|
||||
SYSTEM_PROMPT = """你是一位专业的医美行业财务分析师,擅长利润分析和经营预测。
|
||||
|
||||
你需要根据提供的定价数据、利润模拟数据和敏感性分析数据,评估项目的盈利能力和风险。
|
||||
|
||||
分析维度:
|
||||
1. **盈利能力评估**:评估项目的利润水平和利润率
|
||||
2. **盈亏平衡分析**:分析达到盈亏平衡所需的客量
|
||||
3. **敏感性分析**:分析价格变动对利润的影响
|
||||
4. **风险评估**:识别可能影响利润的风险因素
|
||||
5. **优化建议**:提供提升利润的具体建议
|
||||
|
||||
输出要求:
|
||||
- 使用清晰的数据对比
|
||||
- 提供具体的经营建议
|
||||
- 给出风险预警信息"""
|
||||
|
||||
# 用户提示词模板(必须包含)
|
||||
USER_PROMPT = """请分析以下医美项目的利润预测:
|
||||
|
||||
## 项目名称
|
||||
{project_name}
|
||||
|
||||
## 定价数据
|
||||
{pricing_data}
|
||||
|
||||
## 利润模拟结果
|
||||
{simulation_data}
|
||||
|
||||
## 敏感性分析
|
||||
{sensitivity_data}
|
||||
|
||||
---
|
||||
|
||||
请给出以下分析内容:
|
||||
|
||||
### 1. 盈利能力评估
|
||||
- 预估利润水平
|
||||
- 利润率分析
|
||||
- 与行业对比
|
||||
|
||||
### 2. 盈亏平衡分析
|
||||
- 盈亏平衡点客量
|
||||
- 安全边际评估
|
||||
- 达标难度分析
|
||||
|
||||
### 3. 敏感性分析解读
|
||||
- 价格弹性分析
|
||||
- 关键影响因素
|
||||
- 临界点提醒
|
||||
|
||||
### 4. 经营建议
|
||||
- 定价优化建议
|
||||
- 成本控制建议
|
||||
- 风险防范建议"""
|
||||
36
后端服务/pytest.ini
Normal file
36
后端服务/pytest.ini
Normal file
@@ -0,0 +1,36 @@
|
||||
[pytest]
|
||||
# pytest 配置
|
||||
# 遵循瑞小美系统技术栈标准
|
||||
|
||||
# 测试路径
|
||||
testpaths = tests
|
||||
|
||||
# 异步测试配置(pytest-asyncio 0.23+)
|
||||
asyncio_mode = auto
|
||||
asyncio_default_fixture_loop_scope = function
|
||||
|
||||
# 测试输出配置
|
||||
addopts =
|
||||
-v
|
||||
--strict-markers
|
||||
--tb=short
|
||||
-p no:warnings
|
||||
|
||||
# 测试标记
|
||||
markers =
|
||||
unit: 单元测试
|
||||
integration: 集成测试
|
||||
slow: 慢速测试
|
||||
api: API 测试
|
||||
|
||||
# 日志配置
|
||||
log_cli = true
|
||||
log_cli_level = INFO
|
||||
log_cli_format = %(asctime)s [%(levelname)s] %(message)s
|
||||
log_cli_date_format = %Y-%m-%d %H:%M:%S
|
||||
|
||||
# 最小版本
|
||||
minversion = 7.4
|
||||
|
||||
# Python 路径
|
||||
pythonpath = .
|
||||
41
后端服务/requirements.txt
Normal file
41
后端服务/requirements.txt
Normal file
@@ -0,0 +1,41 @@
|
||||
# FastAPI 框架
|
||||
fastapi==0.109.2
|
||||
uvicorn[standard]==0.27.1
|
||||
python-multipart==0.0.9
|
||||
|
||||
# 数据库
|
||||
sqlalchemy[asyncio]==2.0.25
|
||||
aiomysql==0.2.0
|
||||
pymysql==1.1.0
|
||||
|
||||
# 配置管理
|
||||
pydantic==2.6.1
|
||||
pydantic-settings==2.1.0
|
||||
|
||||
# 数据处理
|
||||
python-dateutil==2.8.2
|
||||
openpyxl==3.1.2
|
||||
|
||||
# HTTP 客户端
|
||||
httpx==0.26.0
|
||||
aiohttp==3.9.3
|
||||
|
||||
# 工具库
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
|
||||
# 日志
|
||||
structlog==24.1.0
|
||||
|
||||
# 测试
|
||||
pytest==7.4.4
|
||||
pytest-asyncio==0.23.4
|
||||
pytest-cov==4.1.0
|
||||
httpx==0.26.0
|
||||
aiosqlite==0.19.0
|
||||
faker==22.5.1
|
||||
|
||||
# 代码质量
|
||||
black==24.1.1
|
||||
isort==5.13.2
|
||||
flake8==7.0.0
|
||||
40
后端服务/run_tests.sh
Executable file
40
后端服务/run_tests.sh
Executable file
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
# 测试运行脚本
|
||||
# 遵循瑞小美系统技术栈标准
|
||||
|
||||
set -e
|
||||
|
||||
echo "=========================================="
|
||||
echo "智能项目定价模型 - 测试套件"
|
||||
echo "=========================================="
|
||||
|
||||
# 检查虚拟环境
|
||||
if [ -z "$VIRTUAL_ENV" ]; then
|
||||
echo "建议在虚拟环境中运行测试"
|
||||
fi
|
||||
|
||||
# 安装测试依赖
|
||||
echo ""
|
||||
echo "[1/4] 安装测试依赖..."
|
||||
pip install -q pytest pytest-asyncio pytest-cov aiosqlite faker httpx
|
||||
|
||||
# 运行单元测试
|
||||
echo ""
|
||||
echo "[2/4] 运行单元测试..."
|
||||
pytest tests/test_services/ -v --tb=short -m "unit" || true
|
||||
|
||||
# 运行 API 集成测试
|
||||
echo ""
|
||||
echo "[3/4] 运行 API 集成测试..."
|
||||
pytest tests/test_api/ -v --tb=short -m "api" || true
|
||||
|
||||
# 生成覆盖率报告
|
||||
echo ""
|
||||
echo "[4/4] 生成覆盖率报告..."
|
||||
pytest tests/ --cov=app --cov-report=term-missing --cov-report=html:coverage_report || true
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "测试完成!"
|
||||
echo "覆盖率报告: coverage_report/index.html"
|
||||
echo "=========================================="
|
||||
5
后端服务/tests/__init__.py
Normal file
5
后端服务/tests/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""智能项目定价模型 - 测试模块
|
||||
|
||||
遵循瑞小美系统技术栈标准
|
||||
测试框架: pytest + pytest-asyncio
|
||||
"""
|
||||
349
后端服务/tests/conftest.py
Normal file
349
后端服务/tests/conftest.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""pytest 测试配置
|
||||
|
||||
提供测试所需的 fixtures:
|
||||
- 测试数据库会话
|
||||
- 测试客户端
|
||||
- 测试数据工厂
|
||||
|
||||
遵循瑞小美系统技术栈标准
|
||||
"""
|
||||
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# 设置测试环境变量
|
||||
os.environ["APP_ENV"] = "test"
|
||||
os.environ["DEBUG"] = "true"
|
||||
os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
from app.main import app
|
||||
from app.database import Base, get_db
|
||||
from app.models import (
|
||||
Category, Material, Equipment, StaffLevel, FixedCost,
|
||||
Project, ProjectCostItem, ProjectLaborCost, ProjectCostSummary,
|
||||
Competitor, CompetitorPrice, BenchmarkPrice, MarketAnalysisResult,
|
||||
PricingPlan, ProfitSimulation, SensitivityAnalysis
|
||||
)
|
||||
|
||||
|
||||
# 测试数据库引擎(使用 SQLite 内存数据库)
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
test_engine = create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
echo=False,
|
||||
poolclass=StaticPool,
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
|
||||
test_session_maker = async_sessionmaker(
|
||||
test_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取测试数据库会话
|
||||
|
||||
每个测试函数使用独立的数据库会话,测试后自动回滚
|
||||
"""
|
||||
# 创建表
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async with test_session_maker() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.rollback()
|
||||
await session.close()
|
||||
|
||||
# 清理表
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""获取测试客户端
|
||||
|
||||
使用测试数据库会话替换应用的数据库依赖
|
||||
"""
|
||||
async def override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ============ 测试数据工厂 ============
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_category(db_session: AsyncSession) -> Category:
|
||||
"""创建示例项目分类"""
|
||||
category = Category(
|
||||
category_name="光电类",
|
||||
parent_id=None,
|
||||
sort_order=1,
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(category)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(category)
|
||||
return category
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_material(db_session: AsyncSession) -> Material:
|
||||
"""创建示例耗材"""
|
||||
material = Material(
|
||||
material_code="MAT001",
|
||||
material_name="冷凝胶",
|
||||
unit="ml",
|
||||
unit_price=Decimal("2.00"),
|
||||
supplier="供应商A",
|
||||
material_type="consumable",
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(material)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(material)
|
||||
return material
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_equipment(db_session: AsyncSession) -> Equipment:
|
||||
"""创建示例设备"""
|
||||
equipment = Equipment(
|
||||
equipment_code="EQP001",
|
||||
equipment_name="光子仪",
|
||||
original_value=Decimal("100000.00"),
|
||||
residual_rate=Decimal("5.00"),
|
||||
service_years=5,
|
||||
estimated_uses=2000,
|
||||
depreciation_per_use=Decimal("47.50"), # (100000 - 5000) / 2000
|
||||
purchase_date=datetime(2025, 1, 1).date(),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(equipment)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(equipment)
|
||||
return equipment
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_staff_level(db_session: AsyncSession) -> StaffLevel:
|
||||
"""创建示例人员级别"""
|
||||
staff_level = StaffLevel(
|
||||
level_code="L2",
|
||||
level_name="中级美容师",
|
||||
hourly_rate=Decimal("50.00"),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(staff_level)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(staff_level)
|
||||
return staff_level
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_fixed_cost(db_session: AsyncSession) -> FixedCost:
|
||||
"""创建示例固定成本"""
|
||||
fixed_cost = FixedCost(
|
||||
cost_name="房租",
|
||||
cost_type="rent",
|
||||
monthly_amount=Decimal("30000.00"),
|
||||
year_month=datetime.now().strftime("%Y-%m"),
|
||||
allocation_method="count",
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(fixed_cost)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(fixed_cost)
|
||||
return fixed_cost
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_project(
|
||||
db_session: AsyncSession,
|
||||
sample_category: Category
|
||||
) -> Project:
|
||||
"""创建示例服务项目"""
|
||||
project = Project(
|
||||
project_code="PRJ001",
|
||||
project_name="光子嫩肤",
|
||||
category_id=sample_category.id,
|
||||
description="IPL光子嫩肤项目",
|
||||
duration_minutes=60,
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(project)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(project)
|
||||
return project
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_project_with_costs(
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
sample_material: Material,
|
||||
sample_equipment: Equipment,
|
||||
sample_staff_level: StaffLevel,
|
||||
sample_fixed_cost: FixedCost,
|
||||
) -> Project:
|
||||
"""创建带成本明细的示例项目"""
|
||||
# 添加耗材成本
|
||||
cost_item = ProjectCostItem(
|
||||
project_id=sample_project.id,
|
||||
item_type="material",
|
||||
item_id=sample_material.id,
|
||||
quantity=Decimal("20"),
|
||||
unit_cost=Decimal("2.00"),
|
||||
total_cost=Decimal("40.00"),
|
||||
)
|
||||
db_session.add(cost_item)
|
||||
|
||||
# 添加设备折旧成本
|
||||
equip_cost = ProjectCostItem(
|
||||
project_id=sample_project.id,
|
||||
item_type="equipment",
|
||||
item_id=sample_equipment.id,
|
||||
quantity=Decimal("1"),
|
||||
unit_cost=Decimal("47.50"),
|
||||
total_cost=Decimal("47.50"),
|
||||
)
|
||||
db_session.add(equip_cost)
|
||||
|
||||
# 添加人工成本
|
||||
labor_cost = ProjectLaborCost(
|
||||
project_id=sample_project.id,
|
||||
staff_level_id=sample_staff_level.id,
|
||||
duration_minutes=60,
|
||||
hourly_rate=Decimal("50.00"),
|
||||
labor_cost=Decimal("50.00"), # 60分钟 / 60 * 50
|
||||
)
|
||||
db_session.add(labor_cost)
|
||||
|
||||
await db_session.commit()
|
||||
await db_session.refresh(sample_project)
|
||||
return sample_project
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_competitor(db_session: AsyncSession) -> Competitor:
|
||||
"""创建示例竞品机构"""
|
||||
competitor = Competitor(
|
||||
competitor_name="美丽人生医美",
|
||||
address="XX市XX路100号",
|
||||
distance_km=Decimal("2.5"),
|
||||
positioning="medium",
|
||||
contact="13800138000",
|
||||
is_key_competitor=True,
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(competitor)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(competitor)
|
||||
return competitor
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_competitor_price(
|
||||
db_session: AsyncSession,
|
||||
sample_competitor: Competitor,
|
||||
sample_project: Project,
|
||||
) -> CompetitorPrice:
|
||||
"""创建示例竞品价格"""
|
||||
price = CompetitorPrice(
|
||||
competitor_id=sample_competitor.id,
|
||||
project_id=sample_project.id,
|
||||
project_name="光子嫩肤",
|
||||
original_price=Decimal("680.00"),
|
||||
promo_price=Decimal("480.00"),
|
||||
member_price=Decimal("580.00"),
|
||||
price_source="meituan",
|
||||
collected_at=datetime.now().date(),
|
||||
)
|
||||
db_session.add(price)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(price)
|
||||
return price
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_pricing_plan(
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
) -> PricingPlan:
|
||||
"""创建示例定价方案"""
|
||||
plan = PricingPlan(
|
||||
project_id=sample_project.id,
|
||||
plan_name="2026年Q1定价",
|
||||
strategy_type="profit",
|
||||
base_cost=Decimal("280.50"),
|
||||
target_margin=Decimal("50.00"),
|
||||
suggested_price=Decimal("561.00"),
|
||||
final_price=Decimal("580.00"),
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(plan)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(plan)
|
||||
return plan
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sample_cost_summary(
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
) -> ProjectCostSummary:
|
||||
"""创建示例成本汇总"""
|
||||
cost_summary = ProjectCostSummary(
|
||||
project_id=sample_project.id,
|
||||
material_cost=Decimal("100.00"),
|
||||
equipment_cost=Decimal("50.00"),
|
||||
labor_cost=Decimal("50.00"),
|
||||
fixed_cost_allocation=Decimal("30.00"),
|
||||
total_cost=Decimal("230.00"),
|
||||
calculated_at=datetime.now(), # 确保设置计算时间
|
||||
)
|
||||
db_session.add(cost_summary)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(cost_summary)
|
||||
return cost_summary
|
||||
|
||||
|
||||
# ============ 辅助函数 ============
|
||||
|
||||
def assert_response_success(response, expected_code=0):
|
||||
"""断言响应成功"""
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == expected_code
|
||||
assert "data" in data
|
||||
return data["data"]
|
||||
|
||||
|
||||
def assert_response_error(response, expected_code):
|
||||
"""断言响应错误"""
|
||||
data = response.json()
|
||||
assert data["code"] == expected_code
|
||||
return data
|
||||
1
后端服务/tests/test_api/__init__.py
Normal file
1
后端服务/tests/test_api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API 层集成测试"""
|
||||
107
后端服务/tests/test_api/test_categories.py
Normal file
107
后端服务/tests/test_api/test_categories.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""项目分类接口测试"""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.models import Category
|
||||
from tests.conftest import assert_response_success
|
||||
|
||||
|
||||
class TestCategoriesAPI:
|
||||
"""项目分类 API 测试"""
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_category(self, client: AsyncClient):
|
||||
"""测试创建分类"""
|
||||
response = await client.post(
|
||||
"/api/v1/categories",
|
||||
json={
|
||||
"category_name": "测试分类",
|
||||
"sort_order": 1,
|
||||
"is_active": True
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["category_name"] == "测试分类"
|
||||
assert data["sort_order"] == 1
|
||||
assert data["is_active"] is True
|
||||
assert "id" in data
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_categories_list(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_category: Category
|
||||
):
|
||||
"""测试获取分类列表"""
|
||||
response = await client.get("/api/v1/categories")
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert "items" in data
|
||||
assert data["total"] >= 1
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_category_by_id(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_category: Category
|
||||
):
|
||||
"""测试获取单个分类"""
|
||||
response = await client.get(f"/api/v1/categories/{sample_category.id}")
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["id"] == sample_category.id
|
||||
assert data["category_name"] == sample_category.category_name
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_category(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_category: Category
|
||||
):
|
||||
"""测试更新分类"""
|
||||
response = await client.put(
|
||||
f"/api/v1/categories/{sample_category.id}",
|
||||
json={
|
||||
"category_name": "更新后分类",
|
||||
"sort_order": 99
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["category_name"] == "更新后分类"
|
||||
assert data["sort_order"] == 99
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_category(self, client: AsyncClient):
|
||||
"""测试删除分类"""
|
||||
# 先创建
|
||||
create_response = await client.post(
|
||||
"/api/v1/categories",
|
||||
json={"category_name": "待删除分类"}
|
||||
)
|
||||
created = assert_response_success(create_response)
|
||||
|
||||
# 删除
|
||||
delete_response = await client.delete(
|
||||
f"/api/v1/categories/{created['id']}"
|
||||
)
|
||||
|
||||
assert delete_response.status_code == 200
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_category(self, client: AsyncClient):
|
||||
"""测试获取不存在的分类"""
|
||||
response = await client.get("/api/v1/categories/99999")
|
||||
|
||||
# API 返回 HTTP 404 + 错误详情
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert data["detail"]["code"] == 10002 # 数据不存在
|
||||
31
后端服务/tests/test_api/test_health.py
Normal file
31
后端服务/tests/test_api/test_health.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""健康检查接口测试"""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
class TestHealthAPI:
|
||||
"""健康检查 API 测试"""
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, client: AsyncClient):
|
||||
"""测试健康检查端点"""
|
||||
response = await client.get("/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["status"] == "healthy"
|
||||
assert "version" in data["data"]
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_root_endpoint(self, client: AsyncClient):
|
||||
"""测试根路径端点"""
|
||||
response = await client.get("/")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert "智能项目定价模型" in data["data"]["name"]
|
||||
113
后端服务/tests/test_api/test_market.py
Normal file
113
后端服务/tests/test_api/test_market.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""市场行情接口测试"""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.models import Project, Competitor, CompetitorPrice
|
||||
from tests.conftest import assert_response_success
|
||||
|
||||
|
||||
class TestMarketAPI:
|
||||
"""市场行情 API 测试"""
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_competitor(self, client: AsyncClient):
|
||||
"""测试创建竞品机构"""
|
||||
response = await client.post(
|
||||
"/api/v1/competitors",
|
||||
json={
|
||||
"competitor_name": "测试竞品",
|
||||
"address": "测试地址",
|
||||
"distance_km": 3.5,
|
||||
"positioning": "medium",
|
||||
"is_key_competitor": True
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["competitor_name"] == "测试竞品"
|
||||
assert float(data["distance_km"]) == 3.5
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_competitors_list(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_competitor: Competitor
|
||||
):
|
||||
"""测试获取竞品列表"""
|
||||
response = await client.get("/api/v1/competitors")
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert "items" in data
|
||||
assert data["total"] >= 1
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_competitor_price(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_competitor: Competitor,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试添加竞品价格"""
|
||||
response = await client.post(
|
||||
f"/api/v1/competitors/{sample_competitor.id}/prices",
|
||||
json={
|
||||
"project_id": sample_project.id,
|
||||
"project_name": "光子嫩肤",
|
||||
"original_price": 800.00,
|
||||
"promo_price": 600.00,
|
||||
"price_source": "meituan",
|
||||
"collected_at": "2026-01-20"
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert float(data["original_price"]) == 800.00
|
||||
assert float(data["promo_price"]) == 600.00
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_market_analysis(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_project: Project,
|
||||
sample_competitor_price: CompetitorPrice
|
||||
):
|
||||
"""测试市场分析"""
|
||||
response = await client.post(
|
||||
f"/api/v1/projects/{sample_project.id}/market-analysis",
|
||||
json={
|
||||
"include_benchmark": False
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["project_id"] == sample_project.id
|
||||
assert "price_statistics" in data
|
||||
assert "suggested_range" in data
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_market_analysis(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_project: Project,
|
||||
sample_competitor_price: CompetitorPrice
|
||||
):
|
||||
"""测试获取市场分析结果"""
|
||||
# 先执行分析
|
||||
await client.post(
|
||||
f"/api/v1/projects/{sample_project.id}/market-analysis",
|
||||
json={"include_benchmark": False}
|
||||
)
|
||||
|
||||
# 获取结果
|
||||
response = await client.get(
|
||||
f"/api/v1/projects/{sample_project.id}/market-analysis"
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["project_id"] == sample_project.id
|
||||
106
后端服务/tests/test_api/test_materials.py
Normal file
106
后端服务/tests/test_api/test_materials.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""耗材管理接口测试"""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.models import Material
|
||||
from tests.conftest import assert_response_success
|
||||
|
||||
|
||||
class TestMaterialsAPI:
|
||||
"""耗材管理 API 测试"""
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_material(self, client: AsyncClient):
|
||||
"""测试创建耗材"""
|
||||
response = await client.post(
|
||||
"/api/v1/materials",
|
||||
json={
|
||||
"material_code": "MAT_TEST001",
|
||||
"material_name": "测试耗材",
|
||||
"unit": "个",
|
||||
"unit_price": 10.50,
|
||||
"supplier": "测试供应商",
|
||||
"material_type": "consumable"
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["material_code"] == "MAT_TEST001"
|
||||
assert data["material_name"] == "测试耗材"
|
||||
assert float(data["unit_price"]) == 10.50
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_materials_list(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_material: Material
|
||||
):
|
||||
"""测试获取耗材列表"""
|
||||
response = await client.get("/api/v1/materials")
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert "items" in data
|
||||
assert data["total"] >= 1
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_materials_with_filter(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_material: Material
|
||||
):
|
||||
"""测试带筛选的耗材列表"""
|
||||
response = await client.get(
|
||||
"/api/v1/materials",
|
||||
params={"material_type": "consumable"}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["total"] >= 1
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_material(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_material: Material
|
||||
):
|
||||
"""测试更新耗材"""
|
||||
response = await client.put(
|
||||
f"/api/v1/materials/{sample_material.id}",
|
||||
json={
|
||||
"unit_price": 3.00,
|
||||
"supplier": "新供应商"
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert float(data["unit_price"]) == 3.00
|
||||
assert data["supplier"] == "新供应商"
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_material_code(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_material: Material
|
||||
):
|
||||
"""测试创建重复编码的耗材"""
|
||||
response = await client.post(
|
||||
"/api/v1/materials",
|
||||
json={
|
||||
"material_code": sample_material.material_code,
|
||||
"material_name": "重复编码耗材",
|
||||
"unit": "个",
|
||||
"unit_price": 10.00,
|
||||
"material_type": "consumable"
|
||||
}
|
||||
)
|
||||
|
||||
# API 返回 HTTP 400 + 错误详情
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert data["detail"]["code"] == 10003 # 数据已存在
|
||||
134
后端服务/tests/test_api/test_pricing.py
Normal file
134
后端服务/tests/test_api/test_pricing.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""智能定价接口测试"""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
|
||||
from app.models import Project, ProjectCostSummary, PricingPlan
|
||||
from tests.conftest import assert_response_success
|
||||
|
||||
|
||||
class TestPricingAPI:
|
||||
"""智能定价 API 测试"""
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pricing_plans_list(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试获取定价方案列表"""
|
||||
response = await client.get("/api/v1/pricing-plans")
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert "items" in data
|
||||
assert data["total"] >= 1
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pricing_plan_detail(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试获取定价方案详情"""
|
||||
response = await client.get(
|
||||
f"/api/v1/pricing-plans/{sample_pricing_plan.id}"
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["id"] == sample_pricing_plan.id
|
||||
assert data["plan_name"] == sample_pricing_plan.plan_name
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pricing_plan(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
db_session,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试创建定价方案"""
|
||||
# 先添加成本汇总
|
||||
cost_summary = ProjectCostSummary(
|
||||
project_id=sample_project.id,
|
||||
material_cost=Decimal("100"),
|
||||
equipment_cost=Decimal("50"),
|
||||
labor_cost=Decimal("50"),
|
||||
fixed_cost_allocation=Decimal("0"),
|
||||
total_cost=Decimal("200"),
|
||||
calculated_at=datetime.now()
|
||||
)
|
||||
db_session.add(cost_summary)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/pricing-plans",
|
||||
json={
|
||||
"project_id": sample_project.id,
|
||||
"plan_name": "测试定价方案",
|
||||
"strategy_type": "profit",
|
||||
"target_margin": 50
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["project_id"] == sample_project.id
|
||||
assert data["plan_name"] == "测试定价方案"
|
||||
assert data["strategy_type"] == "profit"
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_pricing_plan(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试更新定价方案"""
|
||||
response = await client.put(
|
||||
f"/api/v1/pricing-plans/{sample_pricing_plan.id}",
|
||||
json={
|
||||
"final_price": 599.00,
|
||||
"plan_name": "更新后方案"
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert float(data["final_price"]) == 599.00
|
||||
assert data["plan_name"] == "更新后方案"
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulate_strategy(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
db_session,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试策略模拟"""
|
||||
# 添加成本汇总
|
||||
cost_summary = ProjectCostSummary(
|
||||
project_id=sample_project.id,
|
||||
total_cost=Decimal("200"),
|
||||
material_cost=Decimal("100"),
|
||||
equipment_cost=Decimal("50"),
|
||||
labor_cost=Decimal("50"),
|
||||
fixed_cost_allocation=Decimal("0"),
|
||||
calculated_at=datetime.now()
|
||||
)
|
||||
db_session.add(cost_summary)
|
||||
await db_session.commit()
|
||||
|
||||
response = await client.post(
|
||||
f"/api/v1/projects/{sample_project.id}/simulate-strategy",
|
||||
json={
|
||||
"strategies": ["traffic", "profit", "premium"],
|
||||
"target_margin": 50
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["project_id"] == sample_project.id
|
||||
assert len(data["results"]) == 3
|
||||
133
后端服务/tests/test_api/test_profit.py
Normal file
133
后端服务/tests/test_api/test_profit.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""利润模拟接口测试"""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.models import PricingPlan
|
||||
from tests.conftest import assert_response_success
|
||||
|
||||
|
||||
class TestProfitAPI:
|
||||
"""利润模拟 API 测试"""
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulate_profit(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试利润模拟"""
|
||||
response = await client.post(
|
||||
f"/api/v1/pricing-plans/{sample_pricing_plan.id}/simulate-profit",
|
||||
json={
|
||||
"price": 580.00,
|
||||
"estimated_volume": 100,
|
||||
"period_type": "monthly"
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["pricing_plan_id"] == sample_pricing_plan.id
|
||||
assert "input" in data
|
||||
assert "result" in data
|
||||
assert "breakeven_analysis" in data
|
||||
assert data["input"]["price"] == 580.00
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_profit_simulations_list(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试获取模拟列表"""
|
||||
# 先创建一个模拟
|
||||
await client.post(
|
||||
f"/api/v1/pricing-plans/{sample_pricing_plan.id}/simulate-profit",
|
||||
json={
|
||||
"price": 580.00,
|
||||
"estimated_volume": 100,
|
||||
"period_type": "monthly"
|
||||
}
|
||||
)
|
||||
|
||||
response = await client.get("/api/v1/profit-simulations")
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert "items" in data
|
||||
assert data["total"] >= 1
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_sensitivity_analysis(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试敏感性分析"""
|
||||
# 先创建模拟
|
||||
sim_response = await client.post(
|
||||
f"/api/v1/pricing-plans/{sample_pricing_plan.id}/simulate-profit",
|
||||
json={
|
||||
"price": 580.00,
|
||||
"estimated_volume": 100,
|
||||
"period_type": "monthly"
|
||||
}
|
||||
)
|
||||
sim_data = assert_response_success(sim_response)
|
||||
|
||||
# 执行敏感性分析
|
||||
response = await client.post(
|
||||
f"/api/v1/profit-simulations/{sim_data['simulation_id']}/sensitivity",
|
||||
json={
|
||||
"price_change_rates": [-20, -10, 0, 10, 20]
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert len(data["sensitivity_results"]) == 5
|
||||
assert "insights" in data
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_breakeven_analysis(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试盈亏平衡分析"""
|
||||
response = await client.get(
|
||||
f"/api/v1/pricing-plans/{sample_pricing_plan.id}/breakeven"
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["pricing_plan_id"] == sample_pricing_plan.id
|
||||
assert "breakeven_volume" in data
|
||||
assert "current_margin" in data
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_simulation(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试删除模拟"""
|
||||
# 创建模拟
|
||||
sim_response = await client.post(
|
||||
f"/api/v1/pricing-plans/{sample_pricing_plan.id}/simulate-profit",
|
||||
json={
|
||||
"price": 580.00,
|
||||
"estimated_volume": 100,
|
||||
"period_type": "monthly"
|
||||
}
|
||||
)
|
||||
sim_data = assert_response_success(sim_response)
|
||||
|
||||
# 删除
|
||||
delete_response = await client.delete(
|
||||
f"/api/v1/profit-simulations/{sim_data['simulation_id']}"
|
||||
)
|
||||
|
||||
assert delete_response.status_code == 200
|
||||
152
后端服务/tests/test_api/test_projects.py
Normal file
152
后端服务/tests/test_api/test_projects.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""服务项目接口测试"""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.models import Project, Category, Material, Equipment, StaffLevel
|
||||
from tests.conftest import assert_response_success
|
||||
|
||||
|
||||
class TestProjectsAPI:
|
||||
"""服务项目 API 测试"""
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_project(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_category: Category
|
||||
):
|
||||
"""测试创建项目"""
|
||||
response = await client.post(
|
||||
"/api/v1/projects",
|
||||
json={
|
||||
"project_code": "PRJ_TEST001",
|
||||
"project_name": "测试项目",
|
||||
"category_id": sample_category.id,
|
||||
"description": "测试项目描述",
|
||||
"duration_minutes": 45,
|
||||
"is_active": True
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["project_code"] == "PRJ_TEST001"
|
||||
assert data["project_name"] == "测试项目"
|
||||
assert data["duration_minutes"] == 45
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_projects_list(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试获取项目列表"""
|
||||
response = await client.get("/api/v1/projects")
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert "items" in data
|
||||
assert data["total"] >= 1
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project_detail(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试获取项目详情"""
|
||||
response = await client.get(f"/api/v1/projects/{sample_project.id}")
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["id"] == sample_project.id
|
||||
assert data["project_name"] == sample_project.project_name
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_cost_item(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_project: Project,
|
||||
sample_material: Material
|
||||
):
|
||||
"""测试添加成本明细"""
|
||||
response = await client.post(
|
||||
f"/api/v1/projects/{sample_project.id}/cost-items",
|
||||
json={
|
||||
"item_type": "material",
|
||||
"item_id": sample_material.id,
|
||||
"quantity": 5,
|
||||
"remark": "测试备注"
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["item_type"] == "material"
|
||||
assert float(data["quantity"]) == 5.0
|
||||
assert float(data["total_cost"]) == 10.0 # 5 * 2.00
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_labor_cost(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_project: Project,
|
||||
sample_staff_level: StaffLevel
|
||||
):
|
||||
"""测试添加人工成本"""
|
||||
response = await client.post(
|
||||
f"/api/v1/projects/{sample_project.id}/labor-costs",
|
||||
json={
|
||||
"staff_level_id": sample_staff_level.id,
|
||||
"duration_minutes": 30
|
||||
}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["duration_minutes"] == 30
|
||||
assert float(data["labor_cost"]) == 25.0 # 30/60 * 50
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_project_cost(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_project_with_costs: Project
|
||||
):
|
||||
"""测试计算项目成本"""
|
||||
response = await client.post(
|
||||
f"/api/v1/projects/{sample_project_with_costs.id}/calculate-cost",
|
||||
json={"allocation_method": "count"}
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["project_id"] == sample_project_with_costs.id
|
||||
assert "cost_breakdown" in data
|
||||
assert "total_cost" in data
|
||||
assert data["total_cost"] > 0
|
||||
|
||||
@pytest.mark.api
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_cost_summary(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
sample_project_with_costs: Project
|
||||
):
|
||||
"""测试获取成本汇总"""
|
||||
# 先计算成本
|
||||
await client.post(
|
||||
f"/api/v1/projects/{sample_project_with_costs.id}/calculate-cost",
|
||||
json={"allocation_method": "count"}
|
||||
)
|
||||
|
||||
# 获取汇总
|
||||
response = await client.get(
|
||||
f"/api/v1/projects/{sample_project_with_costs.id}/cost-summary"
|
||||
)
|
||||
|
||||
data = assert_response_success(response)
|
||||
assert data["project_id"] == sample_project_with_costs.id
|
||||
assert "material_cost" in data
|
||||
assert "total_cost" in data
|
||||
1
后端服务/tests/test_services/__init__.py
Normal file
1
后端服务/tests/test_services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""服务层单元测试"""
|
||||
415
后端服务/tests/test_services/test_cost_service.py
Normal file
415
后端服务/tests/test_services/test_cost_service.py
Normal file
@@ -0,0 +1,415 @@
|
||||
"""成本计算服务单元测试
|
||||
|
||||
测试 CostService 的核心业务逻辑
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.cost_service import CostService
|
||||
from app.schemas.project_cost import AllocationMethod, CostItemType
|
||||
from app.models import (
|
||||
Material, Equipment, StaffLevel, Project, FixedCost,
|
||||
ProjectCostItem, ProjectLaborCost, ProjectCostSummary
|
||||
)
|
||||
|
||||
|
||||
class TestCostService:
|
||||
"""成本服务测试类"""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_material_info(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_material: Material
|
||||
):
|
||||
"""测试获取耗材信息"""
|
||||
service = CostService(db_session)
|
||||
|
||||
# 获取存在的耗材
|
||||
material = await service.get_material_info(sample_material.id)
|
||||
assert material is not None
|
||||
assert material.material_name == "冷凝胶"
|
||||
assert material.unit_price == Decimal("2.00")
|
||||
|
||||
# 获取不存在的耗材
|
||||
material = await service.get_material_info(99999)
|
||||
assert material is None
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_equipment_info(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_equipment: Equipment
|
||||
):
|
||||
"""测试获取设备信息"""
|
||||
service = CostService(db_session)
|
||||
|
||||
# 获取存在的设备
|
||||
equipment = await service.get_equipment_info(sample_equipment.id)
|
||||
assert equipment is not None
|
||||
assert equipment.equipment_name == "光子仪"
|
||||
assert equipment.depreciation_per_use == Decimal("47.50")
|
||||
|
||||
# 获取不存在的设备
|
||||
equipment = await service.get_equipment_info(99999)
|
||||
assert equipment is None
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_staff_level_info(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_staff_level: StaffLevel
|
||||
):
|
||||
"""测试获取人员级别信息"""
|
||||
service = CostService(db_session)
|
||||
|
||||
# 获取存在的级别
|
||||
level = await service.get_staff_level_info(sample_staff_level.id)
|
||||
assert level is not None
|
||||
assert level.level_name == "中级美容师"
|
||||
assert level.hourly_rate == Decimal("50.00")
|
||||
|
||||
# 获取不存在的级别
|
||||
level = await service.get_staff_level_info(99999)
|
||||
assert level is None
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_material_cost(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project_with_costs: Project
|
||||
):
|
||||
"""测试耗材成本计算"""
|
||||
service = CostService(db_session)
|
||||
|
||||
total, breakdown = await service.calculate_material_cost(
|
||||
sample_project_with_costs.id
|
||||
)
|
||||
|
||||
assert total == Decimal("40.00") # 20 * 2.00
|
||||
assert len(breakdown) == 1
|
||||
assert breakdown[0]["name"] == "冷凝胶"
|
||||
assert breakdown[0]["quantity"] == 20.0
|
||||
assert breakdown[0]["total"] == 40.0
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_equipment_cost(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project_with_costs: Project
|
||||
):
|
||||
"""测试设备折旧成本计算"""
|
||||
service = CostService(db_session)
|
||||
|
||||
total, breakdown = await service.calculate_equipment_cost(
|
||||
sample_project_with_costs.id
|
||||
)
|
||||
|
||||
assert total == Decimal("47.50") # 1 * 47.50
|
||||
assert len(breakdown) == 1
|
||||
assert breakdown[0]["name"] == "光子仪"
|
||||
assert breakdown[0]["depreciation_per_use"] == 47.5
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_labor_cost(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project_with_costs: Project
|
||||
):
|
||||
"""测试人工成本计算"""
|
||||
service = CostService(db_session)
|
||||
|
||||
total, breakdown = await service.calculate_labor_cost(
|
||||
sample_project_with_costs.id
|
||||
)
|
||||
|
||||
assert total == Decimal("50.00") # 60分钟 / 60 * 50
|
||||
assert len(breakdown) == 1
|
||||
assert breakdown[0]["name"] == "中级美容师"
|
||||
assert breakdown[0]["duration_minutes"] == 60
|
||||
assert breakdown[0]["hourly_rate"] == 50.0
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_fixed_cost_allocation_by_count(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
sample_fixed_cost: FixedCost
|
||||
):
|
||||
"""测试固定成本按项目数量分摊"""
|
||||
service = CostService(db_session)
|
||||
|
||||
allocation, detail = await service.calculate_fixed_cost_allocation(
|
||||
sample_project.id,
|
||||
method=AllocationMethod.COUNT
|
||||
)
|
||||
|
||||
# 只有一个项目,分摊全部固定成本
|
||||
assert allocation == Decimal("30000.00")
|
||||
assert detail["method"] == "count"
|
||||
assert detail["project_count"] == 1
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_fixed_cost_allocation_by_duration(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
sample_fixed_cost: FixedCost
|
||||
):
|
||||
"""测试固定成本按时长分摊"""
|
||||
service = CostService(db_session)
|
||||
|
||||
allocation, detail = await service.calculate_fixed_cost_allocation(
|
||||
sample_project.id,
|
||||
method=AllocationMethod.DURATION
|
||||
)
|
||||
|
||||
# 只有一个项目,占比 100%
|
||||
assert allocation == Decimal("30000.00")
|
||||
assert detail["method"] == "duration"
|
||||
assert detail["project_duration"] == 60
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_project_cost(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project_with_costs: Project,
|
||||
sample_fixed_cost: FixedCost
|
||||
):
|
||||
"""测试项目总成本计算"""
|
||||
service = CostService(db_session)
|
||||
|
||||
result = await service.calculate_project_cost(
|
||||
sample_project_with_costs.id,
|
||||
allocation_method=AllocationMethod.COUNT
|
||||
)
|
||||
|
||||
assert result.project_id == sample_project_with_costs.id
|
||||
assert result.project_name == "光子嫩肤"
|
||||
|
||||
# 验证成本构成
|
||||
breakdown = result.cost_breakdown
|
||||
assert breakdown["material_cost"]["subtotal"] == 40.0
|
||||
assert breakdown["equipment_cost"]["subtotal"] == 47.5
|
||||
assert breakdown["labor_cost"]["subtotal"] == 50.0
|
||||
|
||||
# 总成本 = 耗材40 + 设备47.5 + 人工50 + 固定30000
|
||||
expected_total = 40 + 47.5 + 50 + 30000
|
||||
assert result.total_cost == expected_total
|
||||
assert result.min_price_suggestion == expected_total
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_project_cost_not_found(
|
||||
self,
|
||||
db_session: AsyncSession
|
||||
):
|
||||
"""测试项目不存在时的错误处理"""
|
||||
service = CostService(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="项目不存在"):
|
||||
await service.calculate_project_cost(99999)
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_cost_item_material(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
sample_material: Material
|
||||
):
|
||||
"""测试添加耗材成本明细"""
|
||||
service = CostService(db_session)
|
||||
|
||||
cost_item = await service.add_cost_item(
|
||||
project_id=sample_project.id,
|
||||
item_type=CostItemType.MATERIAL,
|
||||
item_id=sample_material.id,
|
||||
quantity=10,
|
||||
remark="测试备注"
|
||||
)
|
||||
|
||||
assert cost_item.project_id == sample_project.id
|
||||
assert cost_item.item_type == "material"
|
||||
assert cost_item.quantity == Decimal("10")
|
||||
assert cost_item.unit_cost == Decimal("2.00")
|
||||
assert cost_item.total_cost == Decimal("20.00")
|
||||
assert cost_item.remark == "测试备注"
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_cost_item_equipment(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
sample_equipment: Equipment
|
||||
):
|
||||
"""测试添加设备折旧成本明细"""
|
||||
service = CostService(db_session)
|
||||
|
||||
cost_item = await service.add_cost_item(
|
||||
project_id=sample_project.id,
|
||||
item_type=CostItemType.EQUIPMENT,
|
||||
item_id=sample_equipment.id,
|
||||
quantity=1,
|
||||
)
|
||||
|
||||
assert cost_item.item_type == "equipment"
|
||||
assert cost_item.unit_cost == Decimal("47.50")
|
||||
assert cost_item.total_cost == Decimal("47.50")
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_cost_item_not_found(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试添加不存在的耗材/设备时的错误处理"""
|
||||
service = CostService(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="耗材不存在"):
|
||||
await service.add_cost_item(
|
||||
project_id=sample_project.id,
|
||||
item_type=CostItemType.MATERIAL,
|
||||
item_id=99999,
|
||||
quantity=1,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="设备不存在"):
|
||||
await service.add_cost_item(
|
||||
project_id=sample_project.id,
|
||||
item_type=CostItemType.EQUIPMENT,
|
||||
item_id=99999,
|
||||
quantity=1,
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_labor_cost(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
sample_staff_level: StaffLevel
|
||||
):
|
||||
"""测试添加人工成本"""
|
||||
service = CostService(db_session)
|
||||
|
||||
labor_cost = await service.add_labor_cost(
|
||||
project_id=sample_project.id,
|
||||
staff_level_id=sample_staff_level.id,
|
||||
duration_minutes=30,
|
||||
remark="测试人工"
|
||||
)
|
||||
|
||||
assert labor_cost.project_id == sample_project.id
|
||||
assert labor_cost.duration_minutes == 30
|
||||
assert labor_cost.hourly_rate == Decimal("50.00")
|
||||
# 30分钟 / 60 * 50 = 25
|
||||
assert labor_cost.labor_cost == Decimal("25.00")
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_labor_cost_not_found(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试添加不存在的人员级别时的错误处理"""
|
||||
service = CostService(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="人员级别不存在"):
|
||||
await service.add_labor_cost(
|
||||
project_id=sample_project.id,
|
||||
staff_level_id=99999,
|
||||
duration_minutes=30,
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_cost_item(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
sample_material: Material
|
||||
):
|
||||
"""测试更新成本明细"""
|
||||
service = CostService(db_session)
|
||||
|
||||
# 先添加
|
||||
cost_item = await service.add_cost_item(
|
||||
project_id=sample_project.id,
|
||||
item_type=CostItemType.MATERIAL,
|
||||
item_id=sample_material.id,
|
||||
quantity=10,
|
||||
)
|
||||
|
||||
# 更新数量
|
||||
updated = await service.update_cost_item(
|
||||
cost_item=cost_item,
|
||||
quantity=20,
|
||||
remark="更新后备注"
|
||||
)
|
||||
|
||||
assert updated.quantity == Decimal("20")
|
||||
assert updated.total_cost == Decimal("40.00") # 20 * 2
|
||||
assert updated.remark == "更新后备注"
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_labor_cost(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
sample_staff_level: StaffLevel
|
||||
):
|
||||
"""测试更新人工成本"""
|
||||
service = CostService(db_session)
|
||||
|
||||
# 先添加
|
||||
labor = await service.add_labor_cost(
|
||||
project_id=sample_project.id,
|
||||
staff_level_id=sample_staff_level.id,
|
||||
duration_minutes=30,
|
||||
)
|
||||
|
||||
# 更新时长
|
||||
updated = await service.update_labor_cost(
|
||||
labor_item=labor,
|
||||
duration_minutes=60,
|
||||
)
|
||||
|
||||
assert updated.duration_minutes == 60
|
||||
assert updated.labor_cost == Decimal("50.00") # 60/60 * 50
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_project_cost(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试没有成本明细的项目计算"""
|
||||
service = CostService(db_session)
|
||||
|
||||
# 计算空项目成本(无固定成本)
|
||||
total_material, _ = await service.calculate_material_cost(sample_project.id)
|
||||
total_equipment, _ = await service.calculate_equipment_cost(sample_project.id)
|
||||
total_labor, _ = await service.calculate_labor_cost(sample_project.id)
|
||||
|
||||
assert total_material == Decimal("0")
|
||||
assert total_equipment == Decimal("0")
|
||||
assert total_labor == Decimal("0")
|
||||
305
后端服务/tests/test_services/test_market_service.py
Normal file
305
后端服务/tests/test_services/test_market_service.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""市场分析服务单元测试
|
||||
|
||||
测试 MarketService 的核心业务逻辑
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from datetime import date
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.market_service import MarketService
|
||||
from app.models import (
|
||||
Project, Competitor, CompetitorPrice, BenchmarkPrice, Category
|
||||
)
|
||||
|
||||
|
||||
class TestMarketService:
|
||||
"""市场分析服务测试类"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_price_statistics_empty(self):
|
||||
"""测试空价格列表的统计"""
|
||||
service = MarketService(None) # 不需要 db
|
||||
|
||||
stats = service.calculate_price_statistics([])
|
||||
|
||||
assert stats.min_price == 0
|
||||
assert stats.max_price == 0
|
||||
assert stats.avg_price == 0
|
||||
assert stats.median_price == 0
|
||||
assert stats.std_deviation is None or stats.std_deviation == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_price_statistics_single(self):
|
||||
"""测试单个价格的统计"""
|
||||
service = MarketService(None)
|
||||
|
||||
stats = service.calculate_price_statistics([500.0])
|
||||
|
||||
assert stats.min_price == 500.0
|
||||
assert stats.max_price == 500.0
|
||||
assert stats.avg_price == 500.0
|
||||
assert stats.median_price == 500.0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_price_statistics_multiple(self):
|
||||
"""测试多个价格的统计"""
|
||||
service = MarketService(None)
|
||||
|
||||
prices = [300.0, 400.0, 500.0, 600.0, 700.0]
|
||||
stats = service.calculate_price_statistics(prices)
|
||||
|
||||
assert stats.min_price == 300.0
|
||||
assert stats.max_price == 700.0
|
||||
assert stats.avg_price == 500.0 # (300+400+500+600+700)/5
|
||||
assert stats.median_price == 500.0 # 中位数
|
||||
assert stats.std_deviation is not None
|
||||
assert stats.std_deviation > 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_price_distribution(self):
|
||||
"""测试价格分布计算"""
|
||||
service = MarketService(None)
|
||||
|
||||
# 价格范围 300-900,分为三个区间
|
||||
# 低: 300-500, 中: 500-700, 高: 700-900
|
||||
prices = [350.0, 450.0, 550.0, 650.0, 750.0, 850.0]
|
||||
|
||||
distribution = service.calculate_price_distribution(
|
||||
prices=prices,
|
||||
min_price=300.0,
|
||||
max_price=900.0
|
||||
)
|
||||
|
||||
# 验证分布
|
||||
assert distribution.low.count == 2 # 350, 450
|
||||
assert distribution.medium.count == 2 # 550, 650
|
||||
assert distribution.high.count == 2 # 750, 850
|
||||
|
||||
# 验证百分比
|
||||
assert distribution.low.percentage == pytest.approx(33.3, rel=0.1)
|
||||
assert distribution.medium.percentage == pytest.approx(33.3, rel=0.1)
|
||||
assert distribution.high.percentage == pytest.approx(33.3, rel=0.1)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_price_distribution_empty(self):
|
||||
"""测试空价格列表的分布"""
|
||||
service = MarketService(None)
|
||||
|
||||
distribution = service.calculate_price_distribution(
|
||||
prices=[],
|
||||
min_price=0,
|
||||
max_price=0
|
||||
)
|
||||
|
||||
assert distribution.low.count == 0
|
||||
assert distribution.medium.count == 0
|
||||
assert distribution.high.count == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_suggested_range(self):
|
||||
"""测试建议定价区间计算"""
|
||||
service = MarketService(None)
|
||||
|
||||
suggested = service.calculate_suggested_range(
|
||||
avg_price=500.0,
|
||||
min_price=300.0,
|
||||
max_price=700.0,
|
||||
benchmark_avg=None
|
||||
)
|
||||
|
||||
# 以均价为中心 ±20%
|
||||
assert suggested.min == pytest.approx(400.0, rel=0.01) # 500 * 0.8
|
||||
assert suggested.max == pytest.approx(600.0, rel=0.01) # 500 * 1.2
|
||||
assert suggested.recommended == 500.0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_suggested_range_with_benchmark(self):
|
||||
"""测试带标杆参考的建议定价区间"""
|
||||
service = MarketService(None)
|
||||
|
||||
suggested = service.calculate_suggested_range(
|
||||
avg_price=500.0,
|
||||
min_price=300.0,
|
||||
max_price=700.0,
|
||||
benchmark_avg=600.0
|
||||
)
|
||||
|
||||
# 推荐价格 = 市场均价 * 0.6 + 标杆均价 * 0.4
|
||||
expected_recommended = 500 * 0.6 + 600 * 0.4 # 540
|
||||
assert suggested.recommended == pytest.approx(expected_recommended, rel=0.01)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_suggested_range_zero_avg(self):
|
||||
"""测试均价为0时的处理"""
|
||||
service = MarketService(None)
|
||||
|
||||
suggested = service.calculate_suggested_range(
|
||||
avg_price=0,
|
||||
min_price=0,
|
||||
max_price=0,
|
||||
benchmark_avg=None
|
||||
)
|
||||
|
||||
assert suggested.min == 0
|
||||
assert suggested.max == 0
|
||||
assert suggested.recommended == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_competitor_prices_for_project(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_competitor_price: CompetitorPrice,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试获取项目的竞品价格"""
|
||||
service = MarketService(db_session)
|
||||
|
||||
prices = await service.get_competitor_prices_for_project(
|
||||
sample_project.id
|
||||
)
|
||||
|
||||
assert len(prices) == 1
|
||||
assert float(prices[0].original_price) == 680.0
|
||||
assert float(prices[0].promo_price) == 480.0
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_competitor_prices_filter_by_competitor(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_competitor_price: CompetitorPrice,
|
||||
sample_project: Project,
|
||||
sample_competitor: Competitor
|
||||
):
|
||||
"""测试按竞品机构筛选价格"""
|
||||
service = MarketService(db_session)
|
||||
|
||||
# 指定竞品ID
|
||||
prices = await service.get_competitor_prices_for_project(
|
||||
sample_project.id,
|
||||
competitor_ids=[sample_competitor.id]
|
||||
)
|
||||
|
||||
assert len(prices) == 1
|
||||
|
||||
# 指定不存在的竞品ID
|
||||
prices = await service.get_competitor_prices_for_project(
|
||||
sample_project.id,
|
||||
competitor_ids=[99999]
|
||||
)
|
||||
|
||||
assert len(prices) == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_market(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
sample_competitor_price: CompetitorPrice
|
||||
):
|
||||
"""测试市场分析"""
|
||||
service = MarketService(db_session)
|
||||
|
||||
result = await service.analyze_market(
|
||||
project_id=sample_project.id,
|
||||
include_benchmark=False
|
||||
)
|
||||
|
||||
assert result.project_id == sample_project.id
|
||||
assert result.project_name == "光子嫩肤"
|
||||
assert result.competitor_count == 1
|
||||
assert result.price_statistics.min_price == 680.0
|
||||
assert result.price_statistics.max_price == 680.0
|
||||
assert result.price_statistics.avg_price == 680.0
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_market_not_found(
|
||||
self,
|
||||
db_session: AsyncSession
|
||||
):
|
||||
"""测试项目不存在时的错误处理"""
|
||||
service = MarketService(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="项目不存在"):
|
||||
await service.analyze_market(project_id=99999)
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_latest_analysis(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project,
|
||||
sample_competitor_price: CompetitorPrice
|
||||
):
|
||||
"""测试获取最新分析结果"""
|
||||
service = MarketService(db_session)
|
||||
|
||||
# 先执行分析
|
||||
await service.analyze_market(
|
||||
project_id=sample_project.id,
|
||||
include_benchmark=False
|
||||
)
|
||||
|
||||
# 获取最新结果
|
||||
latest = await service.get_latest_analysis(sample_project.id)
|
||||
|
||||
assert latest is not None
|
||||
assert latest.project_id == sample_project.id
|
||||
assert float(latest.market_avg_price) == 680.0
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_benchmark_prices_empty(
|
||||
self,
|
||||
db_session: AsyncSession
|
||||
):
|
||||
"""测试没有标杆价格时的处理"""
|
||||
service = MarketService(db_session)
|
||||
|
||||
benchmarks = await service.get_benchmark_prices_for_category(None)
|
||||
assert benchmarks == []
|
||||
|
||||
benchmarks = await service.get_benchmark_prices_for_category(99999)
|
||||
assert benchmarks == []
|
||||
|
||||
|
||||
class TestMarketServiceEdgeCases:
|
||||
"""市场分析服务边界情况测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_price_distribution_same_min_max(self):
|
||||
"""测试最小最大价相同时的分布"""
|
||||
service = MarketService(None)
|
||||
|
||||
distribution = service.calculate_price_distribution(
|
||||
prices=[500.0, 500.0],
|
||||
min_price=500.0,
|
||||
max_price=500.0
|
||||
)
|
||||
|
||||
# 应返回 N/A
|
||||
assert distribution.low.range == "N/A"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_statistics_with_outliers(self):
|
||||
"""测试包含极端值的统计"""
|
||||
service = MarketService(None)
|
||||
|
||||
# 包含一个极端高价
|
||||
prices = [300.0, 400.0, 500.0, 600.0, 5000.0]
|
||||
stats = service.calculate_price_statistics(prices)
|
||||
|
||||
assert stats.min_price == 300.0
|
||||
assert stats.max_price == 5000.0
|
||||
# 均值会被拉高
|
||||
assert stats.avg_price == 1360.0 # (300+400+500+600+5000)/5
|
||||
# 中位数不受极端值影响
|
||||
assert stats.median_price == 500.0
|
||||
# 标准差会很大
|
||||
assert stats.std_deviation > 1000
|
||||
369
后端服务/tests/test_services/test_pricing_service.py
Normal file
369
后端服务/tests/test_services/test_pricing_service.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""智能定价服务单元测试
|
||||
|
||||
测试 PricingService 的核心业务逻辑
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.pricing_service import PricingService
|
||||
from app.schemas.pricing import (
|
||||
StrategyType, MarketReference, StrategySuggestion, PricingSuggestions
|
||||
)
|
||||
from app.models import Project, ProjectCostSummary, PricingPlan
|
||||
|
||||
|
||||
class TestPricingService:
|
||||
"""智能定价服务测试类"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_strategy_price_traffic(self):
|
||||
"""测试引流款定价策略"""
|
||||
service = PricingService(None)
|
||||
|
||||
suggestion = service.calculate_strategy_price(
|
||||
base_cost=100.0,
|
||||
strategy=StrategyType.TRAFFIC,
|
||||
)
|
||||
|
||||
# 引流款利润率 10%-20%,使用中间值 15%
|
||||
# 价格 = 100 / (1 - 0.15) ≈ 117.65
|
||||
assert suggestion.strategy == "引流款"
|
||||
assert suggestion.suggested_price > 100 # 大于成本
|
||||
assert suggestion.suggested_price < 130 # 利润率适中
|
||||
assert suggestion.margin > 0
|
||||
assert "引流" in suggestion.description
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_strategy_price_profit(self):
|
||||
"""测试利润款定价策略"""
|
||||
service = PricingService(None)
|
||||
|
||||
suggestion = service.calculate_strategy_price(
|
||||
base_cost=100.0,
|
||||
strategy=StrategyType.PROFIT,
|
||||
target_margin=50, # 50% 目标毛利率
|
||||
)
|
||||
|
||||
# 价格 = 100 / (1 - 0.5) = 200
|
||||
assert suggestion.strategy == "利润款"
|
||||
assert suggestion.suggested_price >= 200
|
||||
assert suggestion.margin >= 45 # 接近目标
|
||||
assert "日常" in suggestion.description
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_strategy_price_premium(self):
|
||||
"""测试高端款定价策略"""
|
||||
service = PricingService(None)
|
||||
|
||||
suggestion = service.calculate_strategy_price(
|
||||
base_cost=100.0,
|
||||
strategy=StrategyType.PREMIUM,
|
||||
)
|
||||
|
||||
# 高端款利润率 60%-80%,使用中间值 70%
|
||||
# 价格 = 100 / (1 - 0.7) ≈ 333
|
||||
assert suggestion.strategy == "高端款"
|
||||
assert suggestion.suggested_price > 300
|
||||
assert suggestion.margin > 60
|
||||
assert "高端" in suggestion.description
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_strategy_price_with_market_reference(self):
|
||||
"""测试带市场参考的定价"""
|
||||
service = PricingService(None)
|
||||
|
||||
market_ref = MarketReference(min=80.0, max=150.0, avg=100.0)
|
||||
|
||||
# 引流款应该参考市场最低价
|
||||
suggestion = service.calculate_strategy_price(
|
||||
base_cost=50.0,
|
||||
strategy=StrategyType.TRAFFIC,
|
||||
market_ref=market_ref,
|
||||
)
|
||||
|
||||
# 应该取市场最低价的 90% 和成本定价的较低者
|
||||
assert suggestion.suggested_price <= 100 # 不会太高
|
||||
assert suggestion.suggested_price >= 50 * 1.05 # 不低于成本
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_strategy_price_ensures_profit(self):
|
||||
"""测试确保价格不低于成本"""
|
||||
service = PricingService(None)
|
||||
|
||||
market_ref = MarketReference(min=30.0, max=50.0, avg=40.0)
|
||||
|
||||
# 即使市场价很低,也不能低于成本
|
||||
suggestion = service.calculate_strategy_price(
|
||||
base_cost=100.0, # 成本高于市场价
|
||||
strategy=StrategyType.TRAFFIC,
|
||||
market_ref=market_ref,
|
||||
)
|
||||
|
||||
# 价格至少是成本的 1.05 倍
|
||||
assert suggestion.suggested_price >= 100 * 1.05
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_all_strategies(self):
|
||||
"""测试计算所有策略"""
|
||||
service = PricingService(None)
|
||||
|
||||
suggestions = service.calculate_all_strategies(
|
||||
base_cost=100.0,
|
||||
target_margin=50.0,
|
||||
)
|
||||
|
||||
assert suggestions.traffic is not None
|
||||
assert suggestions.profit is not None
|
||||
assert suggestions.premium is not None
|
||||
|
||||
# 价格应该递增:引流款 < 利润款 < 高端款
|
||||
assert suggestions.traffic.suggested_price < suggestions.profit.suggested_price
|
||||
assert suggestions.profit.suggested_price < suggestions.premium.suggested_price
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_all_strategies_selected(self):
|
||||
"""测试只计算选定的策略"""
|
||||
service = PricingService(None)
|
||||
|
||||
suggestions = service.calculate_all_strategies(
|
||||
base_cost=100.0,
|
||||
target_margin=50.0,
|
||||
strategies=[StrategyType.TRAFFIC, StrategyType.PROFIT],
|
||||
)
|
||||
|
||||
assert suggestions.traffic is not None
|
||||
assert suggestions.profit is not None
|
||||
assert suggestions.premium is None
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project_with_cost(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project_with_costs: Project
|
||||
):
|
||||
"""测试获取项目及成本"""
|
||||
service = PricingService(db_session)
|
||||
|
||||
project, cost_summary = await service.get_project_with_cost(
|
||||
sample_project_with_costs.id
|
||||
)
|
||||
|
||||
assert project.id == sample_project_with_costs.id
|
||||
assert project.project_name == "光子嫩肤"
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project_with_cost_not_found(
|
||||
self,
|
||||
db_session: AsyncSession
|
||||
):
|
||||
"""测试项目不存在时的错误处理"""
|
||||
service = PricingService(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="项目不存在"):
|
||||
await service.get_project_with_cost(99999)
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pricing_plan(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试创建定价方案"""
|
||||
# 先添加成本汇总
|
||||
cost_summary = ProjectCostSummary(
|
||||
project_id=sample_project.id,
|
||||
material_cost=Decimal("40.00"),
|
||||
equipment_cost=Decimal("50.00"),
|
||||
labor_cost=Decimal("60.00"),
|
||||
fixed_cost_allocation=Decimal("30.00"),
|
||||
total_cost=Decimal("180.00"),
|
||||
calculated_at=datetime.now(),
|
||||
)
|
||||
db_session.add(cost_summary)
|
||||
await db_session.commit()
|
||||
|
||||
service = PricingService(db_session)
|
||||
|
||||
plan = await service.create_pricing_plan(
|
||||
project_id=sample_project.id,
|
||||
plan_name="测试定价方案",
|
||||
strategy_type=StrategyType.PROFIT,
|
||||
target_margin=50.0,
|
||||
)
|
||||
|
||||
assert plan.project_id == sample_project.id
|
||||
assert plan.plan_name == "测试定价方案"
|
||||
assert plan.strategy_type == "profit"
|
||||
assert float(plan.target_margin) == 50.0
|
||||
assert float(plan.base_cost) == 180.0
|
||||
assert plan.suggested_price > plan.base_cost
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_pricing_plan(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试更新定价方案"""
|
||||
service = PricingService(db_session)
|
||||
|
||||
updated = await service.update_pricing_plan(
|
||||
plan_id=sample_pricing_plan.id,
|
||||
final_price=599.00,
|
||||
plan_name="更新后方案名",
|
||||
)
|
||||
|
||||
assert float(updated.final_price) == 599.00
|
||||
assert updated.plan_name == "更新后方案名"
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_pricing_plan_not_found(
|
||||
self,
|
||||
db_session: AsyncSession
|
||||
):
|
||||
"""测试更新不存在的方案"""
|
||||
service = PricingService(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="定价方案不存在"):
|
||||
await service.update_pricing_plan(
|
||||
plan_id=99999,
|
||||
final_price=599.00,
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulate_strategies(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试策略模拟"""
|
||||
# 添加成本汇总
|
||||
cost_summary = ProjectCostSummary(
|
||||
project_id=sample_project.id,
|
||||
total_cost=Decimal("200.00"),
|
||||
material_cost=Decimal("100.00"),
|
||||
equipment_cost=Decimal("50.00"),
|
||||
labor_cost=Decimal("50.00"),
|
||||
fixed_cost_allocation=Decimal("0.00"),
|
||||
calculated_at=datetime.now(),
|
||||
)
|
||||
db_session.add(cost_summary)
|
||||
await db_session.commit()
|
||||
|
||||
service = PricingService(db_session)
|
||||
|
||||
response = await service.simulate_strategies(
|
||||
project_id=sample_project.id,
|
||||
strategies=[StrategyType.TRAFFIC, StrategyType.PROFIT, StrategyType.PREMIUM],
|
||||
target_margin=50.0,
|
||||
)
|
||||
|
||||
assert response.project_id == sample_project.id
|
||||
assert response.base_cost == 200.0
|
||||
assert len(response.results) == 3
|
||||
|
||||
# 验证结果排序
|
||||
prices = [r.suggested_price for r in response.results]
|
||||
assert prices == sorted(prices) # 应该是升序
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_format_cost_data(self):
|
||||
"""测试成本数据格式化"""
|
||||
service = PricingService(None)
|
||||
|
||||
# 测试空数据
|
||||
result = service._format_cost_data(None)
|
||||
assert "暂无成本数据" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_format_market_data(self):
|
||||
"""测试市场数据格式化"""
|
||||
service = PricingService(None)
|
||||
|
||||
# 测试空数据
|
||||
result = service._format_market_data(None)
|
||||
assert "暂无市场行情数据" in result
|
||||
|
||||
# 测试有数据
|
||||
market_ref = MarketReference(min=100.0, max=500.0, avg=300.0)
|
||||
result = service._format_market_data(market_ref)
|
||||
assert "100.00" in result
|
||||
assert "500.00" in result
|
||||
assert "300.00" in result
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extract_recommendations(self):
|
||||
"""测试提取 AI 建议列表"""
|
||||
service = PricingService(None)
|
||||
|
||||
content = """
|
||||
根据分析,建议如下:
|
||||
- 建议一:常规定价 580 元
|
||||
- 建议二:新客首单 388 元
|
||||
* 建议三:VIP 会员 520 元
|
||||
1. 定期促销活动
|
||||
2. 会员体系建设
|
||||
"""
|
||||
|
||||
recommendations = service._extract_recommendations(content)
|
||||
|
||||
assert len(recommendations) == 5
|
||||
assert "常规定价" in recommendations[0]
|
||||
|
||||
|
||||
class TestPricingServiceWithAI:
|
||||
"""需要 AI 服务的定价测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
@patch('app.services.pricing_service.AIServiceWrapper')
|
||||
async def test_generate_pricing_advice_ai_failure(
|
||||
self,
|
||||
mock_ai_wrapper,
|
||||
db_session: AsyncSession,
|
||||
sample_project: Project
|
||||
):
|
||||
"""测试 AI 调用失败时的降级处理"""
|
||||
# 添加成本汇总
|
||||
cost_summary = ProjectCostSummary(
|
||||
project_id=sample_project.id,
|
||||
total_cost=Decimal("200.00"),
|
||||
material_cost=Decimal("100.00"),
|
||||
equipment_cost=Decimal("50.00"),
|
||||
labor_cost=Decimal("50.00"),
|
||||
fixed_cost_allocation=Decimal("0.00"),
|
||||
calculated_at=datetime.now(),
|
||||
)
|
||||
db_session.add(cost_summary)
|
||||
await db_session.commit()
|
||||
|
||||
# 模拟 AI 调用失败
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.chat = AsyncMock(side_effect=Exception("AI 服务不可用"))
|
||||
mock_ai_wrapper.return_value = mock_instance
|
||||
|
||||
service = PricingService(db_session)
|
||||
|
||||
# 即使 AI 失败,基本定价计算应该仍然返回
|
||||
response = await service.generate_pricing_advice(
|
||||
project_id=sample_project.id,
|
||||
target_margin=50.0,
|
||||
)
|
||||
|
||||
# 验证基本定价仍然可用
|
||||
assert response.project_id == sample_project.id
|
||||
assert response.cost_base == 200.0
|
||||
assert response.pricing_suggestions is not None
|
||||
# AI 建议可能为空
|
||||
assert response.ai_advice is None or response.ai_usage is None
|
||||
211
后端服务/tests/test_services/test_profit_service.py
Normal file
211
后端服务/tests/test_services/test_profit_service.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""利润模拟服务单元测试
|
||||
|
||||
测试 ProfitService 的核心业务逻辑
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from decimal import Decimal
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.profit_service import ProfitService
|
||||
from app.schemas.profit import PeriodType
|
||||
from app.models import PricingPlan, FixedCost
|
||||
|
||||
|
||||
class TestProfitService:
|
||||
"""利润模拟服务测试类"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_profit_basic(self):
|
||||
"""测试基础利润计算"""
|
||||
service = ProfitService(None)
|
||||
|
||||
revenue, cost, profit, margin = service.calculate_profit(
|
||||
price=100.0,
|
||||
cost_per_unit=60.0,
|
||||
volume=100
|
||||
)
|
||||
|
||||
assert revenue == 10000.0
|
||||
assert cost == 6000.0
|
||||
assert profit == 4000.0
|
||||
assert margin == 40.0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_profit_zero_revenue(self):
|
||||
"""测试零收入时的处理"""
|
||||
service = ProfitService(None)
|
||||
|
||||
revenue, cost, profit, margin = service.calculate_profit(
|
||||
price=100.0,
|
||||
cost_per_unit=60.0,
|
||||
volume=0
|
||||
)
|
||||
|
||||
assert revenue == 0
|
||||
assert cost == 0
|
||||
assert profit == 0
|
||||
assert margin == 0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_profit_negative(self):
|
||||
"""测试亏损情况"""
|
||||
service = ProfitService(None)
|
||||
|
||||
revenue, cost, profit, margin = service.calculate_profit(
|
||||
price=50.0,
|
||||
cost_per_unit=60.0,
|
||||
volume=100
|
||||
)
|
||||
|
||||
assert revenue == 5000.0
|
||||
assert cost == 6000.0
|
||||
assert profit == -1000.0
|
||||
assert margin == -20.0
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_breakeven_basic(self):
|
||||
"""测试基础盈亏平衡计算"""
|
||||
service = ProfitService(None)
|
||||
|
||||
breakeven = service.calculate_breakeven(
|
||||
price=100.0,
|
||||
variable_cost=60.0,
|
||||
fixed_cost=0
|
||||
)
|
||||
|
||||
assert breakeven == 1
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_breakeven_with_fixed_cost(self):
|
||||
"""测试有固定成本的盈亏平衡"""
|
||||
service = ProfitService(None)
|
||||
|
||||
breakeven = service.calculate_breakeven(
|
||||
price=100.0,
|
||||
variable_cost=60.0,
|
||||
fixed_cost=4000.0
|
||||
)
|
||||
|
||||
assert breakeven == 101
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_calculate_breakeven_no_margin(self):
|
||||
"""测试边际贡献为负时的处理"""
|
||||
service = ProfitService(None)
|
||||
|
||||
breakeven = service.calculate_breakeven(
|
||||
price=50.0,
|
||||
variable_cost=60.0,
|
||||
fixed_cost=1000.0
|
||||
)
|
||||
|
||||
assert breakeven == 999999
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pricing_plan(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试获取定价方案"""
|
||||
service = ProfitService(db_session)
|
||||
|
||||
plan = await service.get_pricing_plan(sample_pricing_plan.id)
|
||||
|
||||
assert plan.id == sample_pricing_plan.id
|
||||
assert plan.plan_name == "2026年Q1定价"
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pricing_plan_not_found(
|
||||
self,
|
||||
db_session: AsyncSession
|
||||
):
|
||||
"""测试获取不存在的方案"""
|
||||
service = ProfitService(db_session)
|
||||
|
||||
with pytest.raises(ValueError, match="定价方案不存在"):
|
||||
await service.get_pricing_plan(99999)
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_monthly_fixed_cost(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_fixed_cost: FixedCost
|
||||
):
|
||||
"""测试获取月度固定成本"""
|
||||
service = ProfitService(db_session)
|
||||
|
||||
total = await service.get_monthly_fixed_cost()
|
||||
|
||||
assert total == Decimal("30000.00")
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_simulate_profit(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试利润模拟"""
|
||||
service = ProfitService(db_session)
|
||||
|
||||
response = await service.simulate_profit(
|
||||
pricing_plan_id=sample_pricing_plan.id,
|
||||
price=580.0,
|
||||
estimated_volume=100,
|
||||
period_type=PeriodType.MONTHLY,
|
||||
)
|
||||
|
||||
assert response.pricing_plan_id == sample_pricing_plan.id
|
||||
assert response.input.price == 580.0
|
||||
assert response.input.estimated_volume == 100
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_sensitivity_analysis(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_pricing_plan: PricingPlan
|
||||
):
|
||||
"""测试敏感性分析"""
|
||||
service = ProfitService(db_session)
|
||||
|
||||
sim_response = await service.simulate_profit(
|
||||
pricing_plan_id=sample_pricing_plan.id,
|
||||
price=580.0,
|
||||
estimated_volume=100,
|
||||
period_type=PeriodType.MONTHLY,
|
||||
)
|
||||
|
||||
response = await service.sensitivity_analysis(
|
||||
simulation_id=sim_response.simulation_id,
|
||||
price_change_rates=[-20, -10, 0, 10, 20]
|
||||
)
|
||||
|
||||
assert response.simulation_id == sim_response.simulation_id
|
||||
assert len(response.sensitivity_results) == 5
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_breakeven_analysis(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
sample_pricing_plan: PricingPlan,
|
||||
sample_fixed_cost: FixedCost
|
||||
):
|
||||
"""测试盈亏平衡分析"""
|
||||
service = ProfitService(db_session)
|
||||
|
||||
response = await service.breakeven_analysis(
|
||||
pricing_plan_id=sample_pricing_plan.id
|
||||
)
|
||||
|
||||
assert response.pricing_plan_id == sample_pricing_plan.id
|
||||
assert response.price > 0
|
||||
assert response.breakeven_volume > 0
|
||||
Reference in New Issue
Block a user