Initial commit: 智能项目定价模型

This commit is contained in:
kuzma
2026-01-31 21:33:06 +08:00
commit ef0824303f
174 changed files with 31705 additions and 0 deletions

68
后端服务/Dockerfile Normal file
View File

@@ -0,0 +1,68 @@
# 智能项目定价模型 - 后端 Dockerfile
# 遵循瑞小美部署规范:使用具体版本号,配置阿里云镜像源
# 构建阶段
FROM python:3.11.9-slim AS builder
# 配置阿里云 APT 源
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources
# 安装构建依赖
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
libffi-dev \
&& rm -rf /var/lib/apt/lists/*
# 配置阿里云 pip 源
RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ \
&& pip config set global.trusted-host mirrors.aliyun.com
# 创建虚拟环境
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# 复制依赖文件并安装
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 运行阶段
FROM python:3.11.9-slim AS runner
# 配置阿里云 APT 源
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources
# 安装运行时依赖
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
&& rm -rf /var/lib/apt/lists/*
# 创建非 root 用户
RUN groupadd -r appgroup && useradd -r -g appgroup appuser
# 设置工作目录
WORKDIR /app
# 从构建阶段复制虚拟环境
COPY --from=builder /opt/venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# 复制应用代码
COPY --chown=appuser:appgroup . .
# 设置环境变量
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
TZ=Asia/Shanghai
# 切换到非 root 用户
USER appuser
# 暴露端口
EXPOSE 8000
# 健康检查
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# 启动命令
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -0,0 +1,29 @@
# 开发环境 Dockerfile
FROM python:3.11.9-slim
# 配置阿里云源
RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debian.sources
# 安装依赖
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
&& rm -rf /var/lib/apt/lists/*
# 配置 pip
RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
WORKDIR /app
# 安装依赖
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 设置环境变量
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
TZ=Asia/Shanghai
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]

View File

@@ -0,0 +1,90 @@
# 安全检查清单
> 智能项目定价模型 - M4 测试优化阶段
> 遵循瑞小美系统技术栈标准与安全规范
---
## 1. API Key 管理
- [x] 未在代码中硬编码 API Key
- [x] 通过 `shared_backend.AIService` 调用 AI 服务
- [x] API Key 从门户系统统一获取
- [x] `.env` 文件在 `.gitignore` 中排除
## 2. 输入验证
- [x] 使用 Pydantic 进行请求参数验证
- [x] 添加 SQL 注入检测
- [x] 添加 XSS 检测
- [x] 限制请求数据嵌套深度
## 3. 身份认证
- [ ] 集成 OAuth 认证(待 M5 部署阶段完成)
- [x] 预留认证中间件接口
- [x] API 路由支持权限控制
## 4. 速率限制
- [x] 实现请求速率限制中间件
- [x] AI 接口有特殊限制10 次/分钟)
- [x] 默认限制 100 次/分钟
## 5. 安全响应头
- [x] X-XSS-Protection
- [x] X-Content-Type-Options
- [x] X-Frame-Options
- [x] Content-Security-Policy
## 6. 数据保护
- [x] 敏感数据使用 DECIMAL 类型存储
- [x] 数据库连接使用连接池
- [x] 支持敏感字段脱敏(待实现具体业务)
## 7. 日志审计
- [x] 实现审计日志记录器
- [x] 记录敏感操作
- [x] 日志格式为 JSON便于分析
## 8. 错误处理
- [x] 统一错误响应格式
- [x] 生产环境不暴露内部错误详情
- [x] 错误码规范化
## 9. 依赖安全
- [x] 使用固定版本的依赖
- [ ] 定期检查依赖漏洞(建议使用 `pip-audit`
## 10. 部署安全
- [x] Docker 容器使用非 root 用户(待验证)
- [x] 只暴露必要端口Nginx 80/443
- [x] 配置健康检查
- [x] 配置资源限制
---
## 安全测试命令
```bash
# 运行安全相关测试
pytest tests/ -m "security" -v
# 检查依赖漏洞
pip install pip-audit
pip-audit
# 检查代码安全问题
pip install bandit
bandit -r app/
```
---
*瑞小美技术团队 · 2026-01-20*

View File

@@ -0,0 +1,3 @@
"""智能项目定价模型 - 后端服务"""
__version__ = "1.0.0"

View File

@@ -0,0 +1,85 @@
"""配置管理模块
使用 Pydantic Settings 管理环境变量配置
遵循瑞小美系统技术栈标准
"""
import os
import secrets
import warnings
from functools import lru_cache
from typing import Optional
from pydantic_settings import BaseSettings
from pydantic import model_validator
class Settings(BaseSettings):
"""应用配置"""
# 应用配置
APP_NAME: str = "智能项目定价模型"
APP_VERSION: str = "1.0.0"
APP_ENV: str = "development"
DEBUG: bool = True
SECRET_KEY: str = "" # 必须通过环境变量设置
# 数据库配置 - MySQL 8.0, utf8mb4
# 开发环境使用默认值,生产环境必须通过环境变量设置
DATABASE_URL: str = "sqlite+aiosqlite:///:memory:" # 安全的开发默认值
# 数据库连接池配置
DB_POOL_SIZE: int = 5
DB_MAX_OVERFLOW: int = 10
DB_POOL_RECYCLE: int = 3600
# 门户系统配置 - AI Key 从门户获取
PORTAL_CONFIG_API: str = "http://portal-backend:8000/api/ai/internal/config"
# AI 服务配置
AI_MODULE_CODE: str = "pricing_model"
# 时区配置 - Asia/Shanghai
TIMEZONE: str = "Asia/Shanghai"
# CORS 配置 - 生产环境应限制为具体域名
CORS_ORIGINS: list[str] = ["http://localhost:5173", "http://127.0.0.1:5173"]
# API 配置
API_V1_PREFIX: str = "/api/v1"
@model_validator(mode='after')
def validate_production_config(self) -> 'Settings':
"""验证生产环境配置安全性"""
if self.APP_ENV == "production":
# 生产环境必须设置安全的 SECRET_KEY
if not self.SECRET_KEY or self.SECRET_KEY == "your-secret-key-change-in-production":
raise ValueError("生产环境必须设置 SECRET_KEY 环境变量")
# 生产环境 CORS 不能使用 "*"
if "*" in self.CORS_ORIGINS:
warnings.warn("生产环境 CORS_ORIGINS 不应使用 '*',请设置具体的允许域名")
# 生产环境不应开启 DEBUG
if self.DEBUG:
warnings.warn("生产环境建议关闭 DEBUG 模式")
# 如果没有设置 SECRET_KEY开发环境自动生成一个
if not self.SECRET_KEY:
self.SECRET_KEY = secrets.token_urlsafe(32)
return self
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
case_sensitive = True
@lru_cache()
def get_settings() -> Settings:
"""获取配置单例"""
return Settings()
settings = get_settings()

View File

@@ -0,0 +1,84 @@
"""数据库连接模块
使用 SQLAlchemy 异步引擎
遵循瑞小美系统技术栈标准MySQL 8.0, utf8mb4, utf8mb4_unicode_ci
"""
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import MetaData
from app.config import settings
# 命名约定,便于数据库迁移
convention = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s"
}
metadata = MetaData(naming_convention=convention)
class Base(DeclarativeBase):
"""SQLAlchemy 模型基类"""
metadata = metadata
# 创建异步引擎
# SQLite 不支持连接池参数,需要区分处理
_engine_kwargs = {
"echo": settings.DEBUG,
}
# 仅在非 SQLite 环境下添加连接池参数
if not settings.DATABASE_URL.startswith("sqlite"):
_engine_kwargs.update({
"pool_size": settings.DB_POOL_SIZE,
"max_overflow": settings.DB_MAX_OVERFLOW,
"pool_recycle": settings.DB_POOL_RECYCLE,
"pool_pre_ping": True,
})
engine = create_async_engine(
settings.DATABASE_URL,
**_engine_kwargs,
)
# 创建异步会话工厂
async_session_maker = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话依赖"""
async with async_session_maker() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def init_db():
"""初始化数据库表"""
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def close_db():
"""关闭数据库连接"""
await engine.dispose()

127
后端服务/app/main.py Normal file
View File

@@ -0,0 +1,127 @@
"""FastAPI 应用入口
智能项目定价模型后端服务
遵循瑞小美系统技术栈标准
"""
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from app.config import settings
from app.database import init_db, close_db
from app.routers import health, categories, materials, equipments, staff_levels, fixed_costs, projects, market, pricing, profit, dashboard
from app.middleware import (
PerformanceMiddleware,
ResponseCacheMiddleware,
RateLimitMiddleware,
SecurityHeadersMiddleware,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
# 启动时初始化数据库
await init_db()
yield
# 关闭时清理资源
await close_db()
# 创建 FastAPI 应用
app = FastAPI(
title=settings.APP_NAME,
version=settings.APP_VERSION,
description="智能项目定价模型 - 帮助机构精准核算成本、分析市场、智能定价",
docs_url="/docs" if settings.DEBUG else None,
redoc_url="/redoc" if settings.DEBUG else None,
lifespan=lifespan,
)
# 配置 CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 性能优化中间件
app.add_middleware(GZipMiddleware, minimum_size=1000) # 压缩大于 1KB 的响应
app.add_middleware(PerformanceMiddleware) # 性能监控
app.add_middleware(ResponseCacheMiddleware) # 响应缓存
# 安全中间件
app.add_middleware(RateLimitMiddleware, enabled=not settings.DEBUG) # 速率限制(生产环境)
app.add_middleware(SecurityHeadersMiddleware) # 安全响应头
# 注册路由
app.include_router(health.router, tags=["健康检查"])
app.include_router(
categories.router,
prefix=f"{settings.API_V1_PREFIX}/categories",
tags=["项目分类"]
)
app.include_router(
materials.router,
prefix=f"{settings.API_V1_PREFIX}/materials",
tags=["耗材管理"]
)
app.include_router(
equipments.router,
prefix=f"{settings.API_V1_PREFIX}/equipments",
tags=["设备管理"]
)
app.include_router(
staff_levels.router,
prefix=f"{settings.API_V1_PREFIX}/staff-levels",
tags=["人员级别"]
)
app.include_router(
fixed_costs.router,
prefix=f"{settings.API_V1_PREFIX}/fixed-costs",
tags=["固定成本"]
)
app.include_router(
projects.router,
prefix=f"{settings.API_V1_PREFIX}/projects",
tags=["服务项目"]
)
app.include_router(
market.router,
prefix=f"{settings.API_V1_PREFIX}",
tags=["市场行情"]
)
app.include_router(
pricing.router,
prefix=f"{settings.API_V1_PREFIX}",
tags=["智能定价"]
)
app.include_router(
profit.router,
prefix=f"{settings.API_V1_PREFIX}",
tags=["利润模拟"]
)
app.include_router(
dashboard.router,
prefix=f"{settings.API_V1_PREFIX}",
tags=["仪表盘"]
)
@app.get("/")
async def root():
"""根路径"""
return {
"code": 0,
"message": "success",
"data": {
"name": settings.APP_NAME,
"version": settings.APP_VERSION,
"docs": "/docs" if settings.DEBUG else None
}
}

View File

@@ -0,0 +1,20 @@
"""中间件模块"""
from .performance import PerformanceMiddleware
from .cache import ResponseCacheMiddleware
from .security import (
RateLimitMiddleware,
SecurityHeadersMiddleware,
InputSanitizer,
validate_request_body,
AuditLogger,
)
__all__ = [
"PerformanceMiddleware",
"ResponseCacheMiddleware",
"RateLimitMiddleware",
"SecurityHeadersMiddleware",
"InputSanitizer",
"validate_request_body",
"AuditLogger",
]

View File

@@ -0,0 +1,121 @@
"""响应缓存中间件
对 GET 请求的响应进行缓存,减少数据库查询
"""
import hashlib
import json
import logging
from typing import Callable, Optional, Set
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.services.cache_service import get_cache, CacheNamespace
logger = logging.getLogger(__name__)
class ResponseCacheMiddleware(BaseHTTPMiddleware):
"""响应缓存中间件
对符合条件的 GET 请求进行响应缓存
"""
# 需要缓存的路径前缀和 TTL 配置
CACHE_CONFIG = {
"/api/v1/categories": {"ttl": 300, "namespace": CacheNamespace.CATEGORIES},
"/api/v1/materials": {"ttl": 300, "namespace": CacheNamespace.MATERIALS},
"/api/v1/equipments": {"ttl": 300, "namespace": CacheNamespace.EQUIPMENTS},
"/api/v1/staff-levels": {"ttl": 300, "namespace": CacheNamespace.STAFF_LEVELS},
}
# 不缓存的路径(精确匹配)
NO_CACHE_PATHS: Set[str] = {
"/health",
"/docs",
"/redoc",
"/openapi.json",
}
def _should_cache(self, request: Request) -> Optional[dict]:
"""判断是否应该缓存"""
# 只缓存 GET 请求
if request.method != "GET":
return None
path = request.url.path
# 排除不缓存的路径
if path in self.NO_CACHE_PATHS:
return None
# 检查是否在缓存配置中
for prefix, config in self.CACHE_CONFIG.items():
if path.startswith(prefix):
return config
return None
def _generate_cache_key(self, request: Request) -> str:
"""生成缓存键"""
# 包含路径和查询参数
key_parts = [
request.method,
request.url.path,
str(sorted(request.query_params.items())),
]
key_str = "|".join(key_parts)
return hashlib.md5(key_str.encode()).hexdigest()
async def dispatch(self, request: Request, call_next: Callable) -> Response:
cache_config = self._should_cache(request)
if not cache_config:
return await call_next(request)
# 生成缓存键
cache_key = self._generate_cache_key(request)
cache = get_cache(cache_config["namespace"])
# 尝试从缓存获取
cached_data = cache.get(cache_key)
if cached_data is not None:
logger.debug(f"Cache hit: {request.url.path}")
response = Response(
content=cached_data["content"],
status_code=cached_data["status_code"],
headers=dict(cached_data["headers"]),
media_type="application/json",
)
response.headers["X-Cache"] = "HIT"
return response
# 执行请求
response = await call_next(request)
# 只缓存成功的响应
if response.status_code == 200:
# 读取响应体
body = b""
async for chunk in response.body_iterator:
body += chunk
# 保存到缓存
cache_data = {
"content": body,
"status_code": response.status_code,
"headers": dict(response.headers),
}
cache.set(cache_key, cache_data, cache_config["ttl"])
# 重新构建响应
response = Response(
content=body,
status_code=response.status_code,
headers=dict(response.headers),
media_type="application/json",
)
response.headers["X-Cache"] = "MISS"
return response

View File

@@ -0,0 +1,50 @@
"""性能监控中间件
记录请求响应时间,用于性能分析和优化
"""
import time
import logging
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
class PerformanceMiddleware(BaseHTTPMiddleware):
"""性能监控中间件
记录每个请求的响应时间,并在响应头中添加 X-Response-Time
"""
# 慢请求阈值(毫秒)
SLOW_REQUEST_THRESHOLD = 1000
async def dispatch(self, request: Request, call_next: Callable) -> Response:
start_time = time.time()
# 执行请求
response = await call_next(request)
# 计算响应时间
process_time = (time.time() - start_time) * 1000
# 添加响应头
response.headers["X-Response-Time"] = f"{process_time:.2f}ms"
# 记录慢请求
if process_time > self.SLOW_REQUEST_THRESHOLD:
logger.warning(
f"Slow request: {request.method} {request.url.path} "
f"took {process_time:.2f}ms"
)
# 记录请求日志(开发环境)
logger.debug(
f"{request.method} {request.url.path} - "
f"{response.status_code} - {process_time:.2f}ms"
)
return response

View File

@@ -0,0 +1,267 @@
"""安全中间件
实现安全相关功能:
- 请求验证
- 速率限制
- 安全头设置
- 敏感数据保护
遵循瑞小美系统技术栈标准
"""
import logging
import time
from collections import defaultdict
from typing import Callable, Dict, Optional
import re
from fastapi import Request, Response, HTTPException
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""速率限制中间件
防止 API 滥用,保护服务稳定性
"""
# 速率限制配置
RATE_LIMITS = {
"default": {"requests": 100, "window": 60}, # 默认100 次/分钟
"/api/v1/projects/*/generate-pricing": {"requests": 10, "window": 60}, # AI 接口10 次/分钟
"/api/v1/projects/*/market-analysis": {"requests": 20, "window": 60}, # 分析接口20 次/分钟
}
def __init__(self, app, enabled: bool = True):
super().__init__(app)
self.enabled = enabled
self._requests: Dict[str, list] = defaultdict(list)
def _get_client_id(self, request: Request) -> str:
"""获取客户端标识"""
# 优先使用 X-Forwarded-For反向代理场景
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
# 使用客户端 IP
return request.client.host if request.client else "unknown"
def _get_rate_limit(self, path: str) -> Dict:
"""获取路径的速率限制配置"""
for pattern, limit in self.RATE_LIMITS.items():
if pattern == "default":
continue
# 简单的路径匹配(* 匹配任意字符)
regex = pattern.replace("*", "[^/]+")
if re.match(regex, path):
return limit
return self.RATE_LIMITS["default"]
def _is_rate_limited(self, client_id: str, path: str) -> bool:
"""检查是否超过速率限制"""
limit_config = self._get_rate_limit(path)
requests = limit_config["requests"]
window = limit_config["window"]
now = time.time()
key = f"{client_id}:{path}"
# 清理过期记录
self._requests[key] = [t for t in self._requests[key] if now - t < window]
# 检查是否超限
if len(self._requests[key]) >= requests:
return True
# 记录请求
self._requests[key].append(now)
return False
async def dispatch(self, request: Request, call_next: Callable) -> Response:
if not self.enabled:
return await call_next(request)
client_id = self._get_client_id(request)
path = request.url.path
if self._is_rate_limited(client_id, path):
logger.warning(f"Rate limit exceeded: {client_id} -> {path}")
return JSONResponse(
status_code=429,
content={
"code": 40001,
"message": "请求过于频繁,请稍后再试",
"data": None
}
)
return await call_next(request)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""安全响应头中间件
添加安全相关的 HTTP 响应头
"""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
response = await call_next(request)
# 防止 XSS 攻击
response.headers["X-XSS-Protection"] = "1; mode=block"
# 防止 MIME 类型嗅探
response.headers["X-Content-Type-Options"] = "nosniff"
# 点击劫持保护
response.headers["X-Frame-Options"] = "DENY"
# 内容安全策略(基础版)
response.headers["Content-Security-Policy"] = "default-src 'self'"
# 严格传输安全(仅在 HTTPS 环境)
# response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
return response
class InputSanitizer:
"""输入清理工具
防止 SQL 注入、XSS 等攻击
"""
# SQL 注入关键字
SQL_KEYWORDS = [
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "UNION",
"OR", "AND", "--", "/*", "*/", "EXEC", "EXECUTE"
]
# XSS 危险模式
XSS_PATTERNS = [
r"<script.*?>",
r"javascript:",
r"on\w+\s*=",
r"eval\s*\(",
]
@classmethod
def check_sql_injection(cls, value: str) -> bool:
"""检查是否包含 SQL 注入特征"""
upper_value = value.upper()
for keyword in cls.SQL_KEYWORDS:
if keyword in upper_value:
return True
return False
@classmethod
def check_xss(cls, value: str) -> bool:
"""检查是否包含 XSS 特征"""
for pattern in cls.XSS_PATTERNS:
if re.search(pattern, value, re.IGNORECASE):
return True
return False
@classmethod
def sanitize(cls, value: str) -> str:
"""清理输入值
移除潜在危险字符
"""
# 移除 HTML 标签
value = re.sub(r'<[^>]+>', '', value)
# 转义特殊字符
value = value.replace("&", "&amp;")
value = value.replace("<", "&lt;")
value = value.replace(">", "&gt;")
value = value.replace('"', "&quot;")
value = value.replace("'", "&#x27;")
return value
def validate_request_body(data: dict, max_depth: int = 10) -> None:
"""验证请求体
检查数据深度和内容安全
Args:
data: 请求数据
max_depth: 最大嵌套深度
Raises:
HTTPException: 验证失败
"""
def check_depth(obj, depth=0):
if depth > max_depth:
raise HTTPException(
status_code=400,
detail={"code": 10001, "message": "请求数据嵌套层级过深"}
)
if isinstance(obj, dict):
for value in obj.values():
check_depth(value, depth + 1)
elif isinstance(obj, list):
for item in obj:
check_depth(item, depth + 1)
elif isinstance(obj, str):
if InputSanitizer.check_sql_injection(obj):
logger.warning(f"Potential SQL injection detected: {obj[:100]}")
if InputSanitizer.check_xss(obj):
logger.warning(f"Potential XSS detected: {obj[:100]}")
check_depth(data)
class AuditLogger:
"""审计日志记录器
记录敏感操作的审计日志
"""
# 需要审计的操作
AUDIT_OPERATIONS = {
("POST", "/api/v1/pricing-plans"): "创建定价方案",
("PUT", "/api/v1/pricing-plans/*"): "更新定价方案",
("DELETE", "/api/v1/pricing-plans/*"): "删除定价方案",
("POST", "/api/v1/projects/*/generate-pricing"): "生成 AI 定价建议",
}
@classmethod
def should_audit(cls, method: str, path: str) -> Optional[str]:
"""检查是否需要审计"""
for (m, p), desc in cls.AUDIT_OPERATIONS.items():
if m != method:
continue
regex = p.replace("*", "[^/]+")
if re.match(regex, path):
return desc
return None
@classmethod
def log(
cls,
operation: str,
user_id: Optional[int],
request: Request,
response_code: int,
details: Optional[Dict] = None
):
"""记录审计日志"""
log_data = {
"operation": operation,
"user_id": user_id,
"method": request.method,
"path": str(request.url.path),
"client_ip": request.client.host if request.client else "unknown",
"response_code": response_code,
"details": details or {},
}
logger.info(f"AUDIT: {log_data}")

View File

@@ -0,0 +1,44 @@
"""SQLAlchemy 数据模型"""
from app.models.base import BaseModel, TimestampMixin
from app.models.category import Category
from app.models.material import Material
from app.models.equipment import Equipment
from app.models.staff_level import StaffLevel
from app.models.fixed_cost import FixedCost
from app.models.project import Project
from app.models.project_cost_item import ProjectCostItem
from app.models.project_labor_cost import ProjectLaborCost
from app.models.project_cost_summary import ProjectCostSummary
from app.models.competitor import Competitor
from app.models.competitor_price import CompetitorPrice
from app.models.benchmark_price import BenchmarkPrice
from app.models.market_analysis_result import MarketAnalysisResult
from app.models.pricing_plan import PricingPlan
from app.models.profit_simulation import ProfitSimulation
from app.models.sensitivity_analysis import SensitivityAnalysis
from app.models.user import User
from app.models.operation_log import OperationLog
__all__ = [
"BaseModel",
"TimestampMixin",
"Category",
"Material",
"Equipment",
"StaffLevel",
"FixedCost",
"Project",
"ProjectCostItem",
"ProjectLaborCost",
"ProjectCostSummary",
"Competitor",
"CompetitorPrice",
"BenchmarkPrice",
"MarketAnalysisResult",
"PricingPlan",
"ProfitSimulation",
"SensitivityAnalysis",
"User",
"OperationLog",
]

View File

@@ -0,0 +1,51 @@
"""模型基类
包含时间戳 Mixin 和通用基类
遵循瑞小美数据库设计规范
"""
from datetime import datetime
from sqlalchemy import BigInteger, Integer, DateTime, func
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
from app.config import settings
# 主键类型SQLite 使用 Integer支持自增MySQL 使用 BigInteger
# SQLite 的 AUTOINCREMENT 只在 INTEGER PRIMARY KEY 时生效
_PrimaryKeyType = Integer if settings.DATABASE_URL.startswith("sqlite") else BigInteger
class TimestampMixin:
"""时间戳 Mixin"""
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
default=func.now(),
server_default=func.now(),
comment="创建时间"
)
updated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
default=func.now(),
server_default=func.now(),
onupdate=func.now(),
comment="更新时间"
)
class BaseModel(Base, TimestampMixin):
"""模型基类"""
__abstract__ = True
id: Mapped[int] = mapped_column(
_PrimaryKeyType,
primary_key=True,
autoincrement=True,
comment="主键ID"
)

View File

@@ -0,0 +1,72 @@
"""标杆价格模型
维护行业标杆机构的价格参考
"""
from typing import Optional, TYPE_CHECKING
from decimal import Decimal
from datetime import date
from sqlalchemy import BigInteger, String, Date, DECIMAL, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.category import Category
class BenchmarkPrice(BaseModel):
"""标杆价格表"""
__tablename__ = "benchmark_prices"
benchmark_name: Mapped[str] = mapped_column(
String(100),
nullable=False,
comment="标杆机构名称"
)
category_id: Mapped[Optional[int]] = mapped_column(
BigInteger,
ForeignKey("categories.id"),
nullable=True,
index=True,
comment="项目分类ID"
)
min_price: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="最低价"
)
max_price: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="最高价"
)
avg_price: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="均价"
)
price_tier: Mapped[str] = mapped_column(
String(20),
nullable=False,
default="medium",
comment="价格带low-低端, medium-中端, high-高端, premium-奢华"
)
effective_date: Mapped[date] = mapped_column(
Date,
nullable=False,
index=True,
comment="生效日期"
)
remark: Mapped[Optional[str]] = mapped_column(
String(200),
nullable=True,
comment="备注"
)
# 关系
category: Mapped[Optional["Category"]] = relationship(
"Category"
)

View File

@@ -0,0 +1,60 @@
"""项目分类模型
支持树形分类结构
"""
from typing import Optional, List
from sqlalchemy import BigInteger, String, Integer, Boolean, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
class Category(BaseModel):
"""项目分类表"""
__tablename__ = "categories"
category_name: Mapped[str] = mapped_column(
String(50),
nullable=False,
comment="分类名称"
)
parent_id: Mapped[Optional[int]] = mapped_column(
BigInteger,
ForeignKey("categories.id"),
nullable=True,
comment="父分类ID"
)
sort_order: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=0,
comment="排序"
)
is_active: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
comment="是否启用"
)
# 关系
parent: Mapped[Optional["Category"]] = relationship(
"Category",
remote_side="Category.id",
back_populates="children"
)
children: Mapped[List["Category"]] = relationship(
"Category",
back_populates="parent"
)
projects: Mapped[List["Project"]] = relationship(
"Project",
back_populates="category"
)
# 避免循环导入
from app.models.project import Project

View File

@@ -0,0 +1,69 @@
"""竞品机构模型
管理周边竞品医美机构信息
"""
from typing import Optional, List, TYPE_CHECKING
from decimal import Decimal
from sqlalchemy import String, Boolean, DECIMAL
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.competitor_price import CompetitorPrice
class Competitor(BaseModel):
"""竞品机构表"""
__tablename__ = "competitors"
competitor_name: Mapped[str] = mapped_column(
String(100),
nullable=False,
comment="机构名称"
)
address: Mapped[Optional[str]] = mapped_column(
String(200),
nullable=True,
comment="地址"
)
distance_km: Mapped[Optional[Decimal]] = mapped_column(
DECIMAL(5, 2),
nullable=True,
comment="距离(公里)"
)
positioning: Mapped[str] = mapped_column(
String(20),
nullable=False,
default="medium",
index=True,
comment="定位high-高端, medium-中端, budget-大众"
)
contact: Mapped[Optional[str]] = mapped_column(
String(50),
nullable=True,
comment="联系方式"
)
is_key_competitor: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=False,
index=True,
comment="是否重点关注"
)
is_active: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
comment="是否启用"
)
# 关系
prices: Mapped[List["CompetitorPrice"]] = relationship(
"CompetitorPrice",
back_populates="competitor",
cascade="all, delete-orphan"
)

View File

@@ -0,0 +1,83 @@
"""竞品价格模型
记录竞品机构的项目价格信息
"""
from typing import Optional, TYPE_CHECKING
from decimal import Decimal
from datetime import date
from sqlalchemy import BigInteger, String, Date, DECIMAL, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.competitor import Competitor
from app.models.project import Project
class CompetitorPrice(BaseModel):
"""竞品价格表"""
__tablename__ = "competitor_prices"
competitor_id: Mapped[int] = mapped_column(
BigInteger,
ForeignKey("competitors.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="竞品机构ID"
)
project_id: Mapped[Optional[int]] = mapped_column(
BigInteger,
ForeignKey("projects.id", ondelete="SET NULL"),
nullable=True,
index=True,
comment="关联本店项目ID"
)
project_name: Mapped[str] = mapped_column(
String(100),
nullable=False,
comment="竞品项目名称"
)
original_price: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="原价"
)
promo_price: Mapped[Optional[Decimal]] = mapped_column(
DECIMAL(12, 2),
nullable=True,
comment="促销价"
)
member_price: Mapped[Optional[Decimal]] = mapped_column(
DECIMAL(12, 2),
nullable=True,
comment="会员价"
)
price_source: Mapped[str] = mapped_column(
String(20),
nullable=False,
comment="来源official-官网, meituan-美团, dianping-大众点评, survey-实地调研"
)
collected_at: Mapped[date] = mapped_column(
Date,
nullable=False,
index=True,
comment="采集日期"
)
remark: Mapped[Optional[str]] = mapped_column(
String(200),
nullable=True,
comment="备注"
)
# 关系
competitor: Mapped["Competitor"] = relationship(
"Competitor",
back_populates="prices"
)
project: Mapped[Optional["Project"]] = relationship(
"Project"
)

View File

@@ -0,0 +1,75 @@
"""设备模型
管理设备基础信息和折旧计算
"""
from typing import Optional
from datetime import date
from sqlalchemy import String, Boolean, Integer, Date, DECIMAL
from sqlalchemy.orm import Mapped, mapped_column
from app.models.base import BaseModel
class Equipment(BaseModel):
"""设备表"""
__tablename__ = "equipments"
equipment_code: Mapped[str] = mapped_column(
String(50),
unique=True,
nullable=False,
index=True,
comment="设备编码"
)
equipment_name: Mapped[str] = mapped_column(
String(100),
nullable=False,
comment="设备名称"
)
original_value: Mapped[float] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="设备原值"
)
residual_rate: Mapped[float] = mapped_column(
DECIMAL(5, 2),
nullable=False,
default=5.00,
comment="残值率(%)"
)
service_years: Mapped[int] = mapped_column(
Integer,
nullable=False,
comment="预计使用年限"
)
estimated_uses: Mapped[int] = mapped_column(
Integer,
nullable=False,
comment="预计使用次数"
)
depreciation_per_use: Mapped[float] = mapped_column(
DECIMAL(12, 4),
nullable=False,
comment="单次折旧成本 = (原值 - 残值) / 总次数"
)
purchase_date: Mapped[Optional[date]] = mapped_column(
Date,
nullable=True,
comment="购入日期"
)
is_active: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
comment="是否启用"
)
def calculate_depreciation(self) -> float:
"""计算单次折旧成本"""
if self.estimated_uses <= 0:
return 0
residual_value = float(self.original_value) * float(self.residual_rate) / 100
return (float(self.original_value) - residual_value) / self.estimated_uses

View File

@@ -0,0 +1,49 @@
"""固定成本模型
管理月度固定成本(房租、水电等)及分摊方式
"""
from sqlalchemy import String, Boolean, DECIMAL
from sqlalchemy.orm import Mapped, mapped_column
from app.models.base import BaseModel
class FixedCost(BaseModel):
"""固定成本表"""
__tablename__ = "fixed_costs"
cost_name: Mapped[str] = mapped_column(
String(100),
nullable=False,
comment="成本名称"
)
cost_type: Mapped[str] = mapped_column(
String(20),
nullable=False,
comment="类型rent-房租, utilities-水电, property-物业, other-其他"
)
monthly_amount: Mapped[float] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="月度金额"
)
year_month: Mapped[str] = mapped_column(
String(7),
nullable=False,
index=True,
comment="年月2026-01"
)
allocation_method: Mapped[str] = mapped_column(
String(20),
nullable=False,
default="count",
comment="分摊方式count-按项目数, revenue-按营收, duration-按时长"
)
is_active: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
comment="是否启用"
)

View File

@@ -0,0 +1,76 @@
"""市场分析结果模型
存储项目的市场价格分析结果
"""
from typing import TYPE_CHECKING
from decimal import Decimal
from datetime import date
from sqlalchemy import BigInteger, Integer, Date, DECIMAL, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.project import Project
class MarketAnalysisResult(BaseModel):
"""市场分析结果表"""
__tablename__ = "market_analysis_results"
project_id: Mapped[int] = mapped_column(
BigInteger,
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="项目ID"
)
analysis_date: Mapped[date] = mapped_column(
Date,
nullable=False,
index=True,
comment="分析日期"
)
competitor_count: Mapped[int] = mapped_column(
Integer,
nullable=False,
comment="样本竞品数量"
)
market_min_price: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="市场最低价"
)
market_max_price: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="市场最高价"
)
market_avg_price: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="市场均价"
)
market_median_price: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="市场中位价"
)
suggested_range_min: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="建议区间下限"
)
suggested_range_max: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="建议区间上限"
)
# 关系
project: Mapped["Project"] = relationship(
"Project"
)

View File

@@ -0,0 +1,57 @@
"""耗材模型
管理耗材基础信息:名称、单位、单价、供应商等
"""
from typing import Optional
from sqlalchemy import String, Boolean, DECIMAL
from sqlalchemy.orm import Mapped, mapped_column
from app.models.base import BaseModel
class Material(BaseModel):
"""耗材表"""
__tablename__ = "materials"
material_code: Mapped[str] = mapped_column(
String(50),
unique=True,
nullable=False,
index=True,
comment="耗材编码"
)
material_name: Mapped[str] = mapped_column(
String(100),
nullable=False,
comment="耗材名称"
)
unit: Mapped[str] = mapped_column(
String(20),
nullable=False,
comment="单位(支/ml/个)"
)
unit_price: Mapped[float] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="单价"
)
supplier: Mapped[Optional[str]] = mapped_column(
String(100),
nullable=True,
comment="供应商"
)
material_type: Mapped[str] = mapped_column(
String(20),
nullable=False,
index=True,
comment="类型consumable-耗材, injectable-针剂, product-产品"
)
is_active: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
comment="是否启用"
)

View File

@@ -0,0 +1,81 @@
"""操作日志模型
记录用户操作审计日志
"""
from typing import Optional, Any
from datetime import datetime
from sqlalchemy import BigInteger, String, DateTime, JSON, ForeignKey, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
class OperationLog(Base):
"""操作日志表"""
__tablename__ = "operation_logs"
id: Mapped[int] = mapped_column(
BigInteger,
primary_key=True,
autoincrement=True,
comment="主键ID"
)
user_id: Mapped[Optional[int]] = mapped_column(
BigInteger,
ForeignKey("users.id"),
nullable=True,
index=True,
comment="用户ID"
)
module: Mapped[str] = mapped_column(
String(50),
nullable=False,
index=True,
comment="模块cost/market/pricing/profit"
)
action: Mapped[str] = mapped_column(
String(50),
nullable=False,
comment="操作create/update/delete/export"
)
target_type: Mapped[str] = mapped_column(
String(50),
nullable=False,
comment="对象类型"
)
target_id: Mapped[Optional[int]] = mapped_column(
BigInteger,
nullable=True,
comment="对象ID"
)
detail: Mapped[Optional[dict]] = mapped_column(
JSON,
nullable=True,
comment="详情"
)
ip_address: Mapped[Optional[str]] = mapped_column(
String(45),
nullable=True,
comment="IP地址"
)
created_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
default=func.now(),
server_default=func.now(),
index=True,
comment="操作时间"
)
# 关系
user: Mapped[Optional["User"]] = relationship(
"User",
back_populates="operation_logs"
)
# 避免循环导入
from app.models.user import User

View File

@@ -0,0 +1,94 @@
"""定价方案模型
管理项目定价方案,支持多种定价策略
"""
from typing import Optional, List, TYPE_CHECKING
from decimal import Decimal
from sqlalchemy import BigInteger, String, Boolean, Text, ForeignKey, DECIMAL
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.project import Project
from app.models.user import User
from app.models.profit_simulation import ProfitSimulation
class PricingPlan(BaseModel):
"""定价方案表"""
__tablename__ = "pricing_plans"
project_id: Mapped[int] = mapped_column(
BigInteger,
ForeignKey("projects.id"),
nullable=False,
index=True,
comment="项目ID"
)
plan_name: Mapped[str] = mapped_column(
String(100),
nullable=False,
comment="方案名称"
)
strategy_type: Mapped[str] = mapped_column(
String(20),
nullable=False,
index=True,
comment="策略类型traffic-引流款, profit-利润款, premium-高端款"
)
base_cost: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="基础成本"
)
target_margin: Mapped[Decimal] = mapped_column(
DECIMAL(5, 2),
nullable=False,
comment="目标毛利率(%)"
)
suggested_price: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="建议价格"
)
final_price: Mapped[Optional[Decimal]] = mapped_column(
DECIMAL(12, 2),
nullable=True,
comment="最终定价"
)
ai_advice: Mapped[Optional[str]] = mapped_column(
Text,
nullable=True,
comment="AI建议内容"
)
is_active: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
comment="是否启用"
)
created_by: Mapped[Optional[int]] = mapped_column(
BigInteger,
ForeignKey("users.id"),
nullable=True,
comment="创建人ID"
)
# 关系
project: Mapped["Project"] = relationship(
"Project",
back_populates="pricing_plans"
)
creator: Mapped[Optional["User"]] = relationship(
"User",
back_populates="created_pricing_plans"
)
profit_simulations: Mapped[List["ProfitSimulation"]] = relationship(
"ProfitSimulation",
back_populates="pricing_plan",
cascade="all, delete-orphan"
)

View File

@@ -0,0 +1,97 @@
"""利润模拟模型
管理利润模拟测算记录
"""
from typing import Optional, List, TYPE_CHECKING
from decimal import Decimal
from sqlalchemy import BigInteger, String, Integer, ForeignKey, DECIMAL
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.pricing_plan import PricingPlan
from app.models.user import User
from app.models.sensitivity_analysis import SensitivityAnalysis
class ProfitSimulation(BaseModel):
"""利润模拟表"""
__tablename__ = "profit_simulations"
pricing_plan_id: Mapped[int] = mapped_column(
BigInteger,
ForeignKey("pricing_plans.id"),
nullable=False,
index=True,
comment="定价方案ID"
)
simulation_name: Mapped[str] = mapped_column(
String(100),
nullable=False,
comment="模拟名称"
)
price: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="模拟价格"
)
estimated_volume: Mapped[int] = mapped_column(
Integer,
nullable=False,
comment="预估客量"
)
period_type: Mapped[str] = mapped_column(
String(20),
nullable=False,
comment="周期类型daily-日, weekly-周, monthly-月"
)
estimated_revenue: Mapped[Decimal] = mapped_column(
DECIMAL(14, 2),
nullable=False,
comment="预估收入"
)
estimated_cost: Mapped[Decimal] = mapped_column(
DECIMAL(14, 2),
nullable=False,
comment="预估成本"
)
estimated_profit: Mapped[Decimal] = mapped_column(
DECIMAL(14, 2),
nullable=False,
comment="预估利润"
)
profit_margin: Mapped[Decimal] = mapped_column(
DECIMAL(5, 2),
nullable=False,
comment="利润率(%)"
)
breakeven_volume: Mapped[int] = mapped_column(
Integer,
nullable=False,
comment="盈亏平衡客量"
)
created_by: Mapped[Optional[int]] = mapped_column(
BigInteger,
ForeignKey("users.id"),
nullable=True,
comment="创建人ID"
)
# 关系
pricing_plan: Mapped["PricingPlan"] = relationship(
"PricingPlan",
back_populates="profit_simulations"
)
creator: Mapped[Optional["User"]] = relationship(
"User",
back_populates="created_profit_simulations"
)
sensitivity_analyses: Mapped[List["SensitivityAnalysis"]] = relationship(
"SensitivityAnalysis",
back_populates="simulation",
cascade="all, delete-orphan"
)

View File

@@ -0,0 +1,102 @@
"""服务项目模型
管理医美服务项目基础信息
"""
from typing import Optional, List, TYPE_CHECKING
from sqlalchemy import BigInteger, String, Boolean, Integer, Text, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.category import Category
from app.models.user import User
from app.models.project_cost_item import ProjectCostItem
from app.models.project_labor_cost import ProjectLaborCost
from app.models.project_cost_summary import ProjectCostSummary
from app.models.pricing_plan import PricingPlan
class Project(BaseModel):
"""服务项目表"""
__tablename__ = "projects"
project_code: Mapped[str] = mapped_column(
String(50),
unique=True,
nullable=False,
index=True,
comment="项目编码"
)
project_name: Mapped[str] = mapped_column(
String(100),
nullable=False,
comment="项目名称"
)
category_id: Mapped[Optional[int]] = mapped_column(
BigInteger,
ForeignKey("categories.id"),
nullable=True,
index=True,
comment="项目分类ID"
)
description: Mapped[Optional[str]] = mapped_column(
Text,
nullable=True,
comment="项目描述"
)
duration_minutes: Mapped[int] = mapped_column(
Integer,
nullable=False,
default=0,
comment="操作时长(分钟)"
)
is_active: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
index=True,
comment="是否启用"
)
created_by: Mapped[Optional[int]] = mapped_column(
BigInteger,
ForeignKey("users.id"),
nullable=True,
comment="创建人ID"
)
# 关系
category: Mapped[Optional["Category"]] = relationship(
"Category",
back_populates="projects"
)
creator: Mapped[Optional["User"]] = relationship(
"User",
back_populates="created_projects"
)
# 成本相关关系
cost_items: Mapped[List["ProjectCostItem"]] = relationship(
"ProjectCostItem",
back_populates="project",
cascade="all, delete-orphan"
)
labor_costs: Mapped[List["ProjectLaborCost"]] = relationship(
"ProjectLaborCost",
back_populates="project",
cascade="all, delete-orphan"
)
cost_summary: Mapped[Optional["ProjectCostSummary"]] = relationship(
"ProjectCostSummary",
back_populates="project",
uselist=False,
cascade="all, delete-orphan"
)
pricing_plans: Mapped[List["PricingPlan"]] = relationship(
"PricingPlan",
back_populates="project",
cascade="all, delete-orphan"
)

View File

@@ -0,0 +1,66 @@
"""项目成本明细模型
管理项目的耗材成本和设备折旧成本
"""
from typing import Optional, TYPE_CHECKING
from decimal import Decimal
from sqlalchemy import BigInteger, String, DECIMAL, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.project import Project
class ProjectCostItem(BaseModel):
"""项目成本明细表(耗材/设备)"""
__tablename__ = "project_cost_items"
project_id: Mapped[int] = mapped_column(
BigInteger,
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="项目ID"
)
item_type: Mapped[str] = mapped_column(
String(20),
nullable=False,
index=True,
comment="类型material-耗材, equipment-设备"
)
item_id: Mapped[int] = mapped_column(
BigInteger,
nullable=False,
comment="耗材/设备ID"
)
quantity: Mapped[Decimal] = mapped_column(
DECIMAL(10, 4),
nullable=False,
comment="用量"
)
unit_cost: Mapped[Decimal] = mapped_column(
DECIMAL(12, 4),
nullable=False,
comment="单位成本"
)
total_cost: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="总成本 = quantity * unit_cost"
)
remark: Mapped[Optional[str]] = mapped_column(
String(200),
nullable=True,
comment="备注"
)
# 关系
project: Mapped["Project"] = relationship(
"Project",
back_populates="cost_items"
)

View File

@@ -0,0 +1,72 @@
"""项目成本汇总模型
存储项目的成本计算结果
"""
from datetime import datetime
from decimal import Decimal
from typing import TYPE_CHECKING
from sqlalchemy import BigInteger, DateTime, DECIMAL, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.project import Project
class ProjectCostSummary(BaseModel):
"""项目成本汇总表"""
__tablename__ = "project_cost_summaries"
project_id: Mapped[int] = mapped_column(
BigInteger,
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
unique=True,
index=True,
comment="项目ID"
)
material_cost: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
default=0,
comment="耗材成本"
)
equipment_cost: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
default=0,
comment="设备折旧成本"
)
labor_cost: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
default=0,
comment="人工成本"
)
fixed_cost_allocation: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
default=0,
comment="固定成本分摊"
)
total_cost: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
default=0,
comment="总成本(最低成本线)"
)
calculated_at: Mapped[datetime] = mapped_column(
DateTime,
nullable=False,
comment="计算时间"
)
# 关系
project: Mapped["Project"] = relationship(
"Project",
back_populates="cost_summary"
)

View File

@@ -0,0 +1,67 @@
"""项目人工成本模型
管理项目的人工成本配置
"""
from typing import Optional, TYPE_CHECKING
from decimal import Decimal
from sqlalchemy import BigInteger, Integer, String, DECIMAL, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.project import Project
from app.models.staff_level import StaffLevel
class ProjectLaborCost(BaseModel):
"""项目人工成本表"""
__tablename__ = "project_labor_costs"
project_id: Mapped[int] = mapped_column(
BigInteger,
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="项目ID"
)
staff_level_id: Mapped[int] = mapped_column(
BigInteger,
ForeignKey("staff_levels.id"),
nullable=False,
index=True,
comment="人员级别ID"
)
duration_minutes: Mapped[int] = mapped_column(
Integer,
nullable=False,
comment="操作时长(分钟)"
)
hourly_rate: Mapped[Decimal] = mapped_column(
DECIMAL(10, 2),
nullable=False,
comment="时薪(记录时的快照)"
)
labor_cost: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="人工成本 = duration/60 * hourly_rate"
)
remark: Mapped[Optional[str]] = mapped_column(
String(200),
nullable=True,
comment="备注"
)
# 关系
project: Mapped["Project"] = relationship(
"Project",
back_populates="labor_costs"
)
staff_level: Mapped["StaffLevel"] = relationship(
"StaffLevel",
back_populates="project_labor_costs"
)

View File

@@ -0,0 +1,55 @@
"""敏感性分析模型
记录价格变动对利润的敏感性分析结果
"""
from typing import TYPE_CHECKING
from decimal import Decimal
from sqlalchemy import BigInteger, ForeignKey, DECIMAL
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.profit_simulation import ProfitSimulation
class SensitivityAnalysis(BaseModel):
"""敏感性分析表"""
__tablename__ = "sensitivity_analyses"
simulation_id: Mapped[int] = mapped_column(
BigInteger,
ForeignKey("profit_simulations.id"),
nullable=False,
index=True,
comment="模拟ID"
)
price_change_rate: Mapped[Decimal] = mapped_column(
DECIMAL(5, 2),
nullable=False,
comment="价格变动率(%):如 -20, -10, 0, 10, 20"
)
adjusted_price: Mapped[Decimal] = mapped_column(
DECIMAL(12, 2),
nullable=False,
comment="调整后价格"
)
adjusted_profit: Mapped[Decimal] = mapped_column(
DECIMAL(14, 2),
nullable=False,
comment="调整后利润"
)
profit_change_rate: Mapped[Decimal] = mapped_column(
DECIMAL(5, 2),
nullable=False,
comment="利润变动率(%)"
)
# 关系
simulation: Mapped["ProfitSimulation"] = relationship(
"ProfitSimulation",
back_populates="sensitivity_analyses"
)

View File

@@ -0,0 +1,49 @@
"""人员级别模型
管理不同岗位/级别的时薪标准
"""
from typing import List, TYPE_CHECKING
from sqlalchemy import String, Boolean, DECIMAL
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.project_labor_cost import ProjectLaborCost
class StaffLevel(BaseModel):
"""人员级别表"""
__tablename__ = "staff_levels"
level_code: Mapped[str] = mapped_column(
String(20),
unique=True,
nullable=False,
comment="级别编码"
)
level_name: Mapped[str] = mapped_column(
String(50),
nullable=False,
comment="级别名称"
)
hourly_rate: Mapped[float] = mapped_column(
DECIMAL(10, 2),
nullable=False,
comment="时薪(元/小时)"
)
is_active: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
comment="是否启用"
)
# 关系
project_labor_costs: Mapped[List["ProjectLaborCost"]] = relationship(
"ProjectLaborCost",
back_populates="staff_level"
)

View File

@@ -0,0 +1,66 @@
"""用户模型
与门户系统关联的用户信息
"""
from typing import List
from sqlalchemy import BigInteger, String, Boolean
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.models.base import BaseModel
class User(BaseModel):
"""用户表"""
__tablename__ = "users"
portal_user_id: Mapped[int] = mapped_column(
BigInteger,
unique=True,
nullable=False,
index=True,
comment="门户用户ID"
)
username: Mapped[str] = mapped_column(
String(50),
nullable=False,
comment="用户名"
)
role: Mapped[str] = mapped_column(
String(20),
nullable=False,
comment="角色admin-管理员, manager-经理, operator-操作员"
)
is_active: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=True,
comment="是否启用"
)
# 关系
created_projects: Mapped[List["Project"]] = relationship(
"Project",
back_populates="creator"
)
operation_logs: Mapped[List["OperationLog"]] = relationship(
"OperationLog",
back_populates="user"
)
created_pricing_plans: Mapped[List["PricingPlan"]] = relationship(
"PricingPlan",
back_populates="creator"
)
created_profit_simulations: Mapped[List["ProfitSimulation"]] = relationship(
"ProfitSimulation",
back_populates="creator"
)
# 避免循环导入
from app.models.project import Project
from app.models.operation_log import OperationLog
from app.models.pricing_plan import PricingPlan
from app.models.profit_simulation import ProfitSimulation

View File

@@ -0,0 +1 @@
"""数据访问层"""

View File

@@ -0,0 +1,29 @@
"""API 路由模块"""
from app.routers import (
health,
categories,
materials,
equipments,
staff_levels,
fixed_costs,
projects,
market,
pricing,
profit,
dashboard,
)
__all__ = [
"health",
"categories",
"materials",
"equipments",
"staff_levels",
"fixed_costs",
"projects",
"market",
"pricing",
"profit",
"dashboard",
]

View File

@@ -0,0 +1,210 @@
"""项目分类路由
实现分类的 CRUD 操作,支持树形结构
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database import get_db
from app.models.category import Category
from app.schemas.common import ResponseModel, PaginatedResponse, PaginatedData, ErrorCode
from app.schemas.category import (
CategoryCreate,
CategoryUpdate,
CategoryResponse,
CategoryTreeResponse,
)
router = APIRouter()
@router.get("", response_model=PaginatedResponse[CategoryResponse])
async def get_categories(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
parent_id: Optional[int] = Query(None, description="父分类ID筛选"),
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
db: AsyncSession = Depends(get_db),
):
"""获取项目分类列表"""
# 构建查询
query = select(Category)
if parent_id is not None:
query = query.where(Category.parent_id == parent_id)
if is_active is not None:
query = query.where(Category.is_active == is_active)
query = query.order_by(Category.sort_order, Category.id)
# 分页
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size)
result = await db.execute(query)
categories = result.scalars().all()
# 统计总数
count_query = select(func.count(Category.id))
if parent_id is not None:
count_query = count_query.where(Category.parent_id == parent_id)
if is_active is not None:
count_query = count_query.where(Category.is_active == is_active)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
return PaginatedResponse(
data=PaginatedData(
items=[CategoryResponse.model_validate(c) for c in categories],
total=total,
page=page,
page_size=page_size,
total_pages=(total + page_size - 1) // page_size,
)
)
@router.get("/tree", response_model=ResponseModel[List[CategoryTreeResponse]])
async def get_category_tree(
is_active: Optional[bool] = Query(True, description="是否只返回启用的分类"),
db: AsyncSession = Depends(get_db),
):
"""获取分类树形结构"""
query = select(Category).options(selectinload(Category.children))
if is_active is not None:
query = query.where(Category.is_active == is_active)
# 只获取顶级分类
query = query.where(Category.parent_id.is_(None))
query = query.order_by(Category.sort_order, Category.id)
result = await db.execute(query)
categories = result.scalars().all()
return ResponseModel(data=[CategoryTreeResponse.model_validate(c) for c in categories])
@router.get("/{category_id}", response_model=ResponseModel[CategoryResponse])
async def get_category(
category_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取单个分类详情"""
result = await db.execute(select(Category).where(Category.id == category_id))
category = result.scalar_one_or_none()
if not category:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
)
return ResponseModel(data=CategoryResponse.model_validate(category))
@router.post("", response_model=ResponseModel[CategoryResponse])
async def create_category(
data: CategoryCreate,
db: AsyncSession = Depends(get_db),
):
"""创建项目分类"""
# 检查父分类是否存在
if data.parent_id:
parent_result = await db.execute(
select(Category).where(Category.id == data.parent_id)
)
if not parent_result.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.PARAM_ERROR, "message": "父分类不存在"}
)
# 创建分类
category = Category(**data.model_dump())
db.add(category)
await db.flush()
await db.refresh(category)
return ResponseModel(message="创建成功", data=CategoryResponse.model_validate(category))
@router.put("/{category_id}", response_model=ResponseModel[CategoryResponse])
async def update_category(
category_id: int,
data: CategoryUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新项目分类"""
result = await db.execute(select(Category).where(Category.id == category_id))
category = result.scalar_one_or_none()
if not category:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
)
# 检查父分类
if data.parent_id is not None and data.parent_id != category.parent_id:
if data.parent_id == category_id:
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.PARAM_ERROR, "message": "不能将自己设为父分类"}
)
parent_result = await db.execute(
select(Category).where(Category.id == data.parent_id)
)
if not parent_result.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.PARAM_ERROR, "message": "父分类不存在"}
)
# 更新字段
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(category, field, value)
await db.flush()
await db.refresh(category)
return ResponseModel(message="更新成功", data=CategoryResponse.model_validate(category))
@router.delete("/{category_id}", response_model=ResponseModel)
async def delete_category(
category_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除项目分类"""
result = await db.execute(select(Category).where(Category.id == category_id))
category = result.scalar_one_or_none()
if not category:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
)
# 检查是否有子分类
children_result = await db.execute(
select(func.count(Category.id)).where(Category.parent_id == category_id)
)
children_count = children_result.scalar() or 0
if children_count > 0:
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.NOT_ALLOWED, "message": "该分类下有子分类,无法删除"}
)
await db.delete(category)
return ResponseModel(message="删除成功")

View File

@@ -0,0 +1,305 @@
"""仪表盘路由
仪表盘数据相关的 API 接口
"""
from datetime import datetime, date, timedelta
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models import (
Project,
ProjectCostSummary,
Competitor,
CompetitorPrice,
PricingPlan,
ProfitSimulation,
OperationLog,
)
from app.schemas.common import ResponseModel
from app.schemas.dashboard import (
DashboardSummaryResponse,
ProjectOverview,
CostOverview,
CostProjectInfo,
MarketOverview,
PricingOverview,
StrategiesDistribution,
AIUsageOverview,
RecentActivity,
CostTrendResponse,
MarketTrendResponse,
TrendDataPoint,
)
router = APIRouter()
@router.get("/dashboard/summary", response_model=ResponseModel[DashboardSummaryResponse])
async def get_dashboard_summary(
db: AsyncSession = Depends(get_db),
):
"""获取仪表盘概览数据"""
# 项目概览
total_projects_result = await db.execute(
select(func.count(Project.id))
)
total_projects = total_projects_result.scalar() or 0
active_projects_result = await db.execute(
select(func.count(Project.id)).where(Project.is_active == True)
)
active_projects = active_projects_result.scalar() or 0
projects_with_pricing_result = await db.execute(
select(func.count(func.distinct(PricingPlan.project_id)))
)
projects_with_pricing = projects_with_pricing_result.scalar() or 0
project_overview = ProjectOverview(
total_projects=total_projects,
active_projects=active_projects,
projects_with_pricing=projects_with_pricing,
)
# 成本概览
avg_cost_result = await db.execute(
select(func.avg(ProjectCostSummary.total_cost))
)
avg_project_cost = float(avg_cost_result.scalar() or 0)
# 最高成本项目
highest_cost_result = await db.execute(
select(ProjectCostSummary).options(
).order_by(ProjectCostSummary.total_cost.desc()).limit(1)
)
highest_cost_summary = highest_cost_result.scalar_one_or_none()
highest_cost_project = None
if highest_cost_summary:
project_result = await db.execute(
select(Project).where(Project.id == highest_cost_summary.project_id)
)
project = project_result.scalar_one_or_none()
if project:
highest_cost_project = CostProjectInfo(
id=project.id,
name=project.project_name,
cost=float(highest_cost_summary.total_cost),
)
# 最低成本项目
lowest_cost_result = await db.execute(
select(ProjectCostSummary).where(
ProjectCostSummary.total_cost > 0
).order_by(ProjectCostSummary.total_cost.asc()).limit(1)
)
lowest_cost_summary = lowest_cost_result.scalar_one_or_none()
lowest_cost_project = None
if lowest_cost_summary:
project_result = await db.execute(
select(Project).where(Project.id == lowest_cost_summary.project_id)
)
project = project_result.scalar_one_or_none()
if project:
lowest_cost_project = CostProjectInfo(
id=project.id,
name=project.project_name,
cost=float(lowest_cost_summary.total_cost),
)
cost_overview = CostOverview(
avg_project_cost=round(avg_project_cost, 2),
highest_cost_project=highest_cost_project,
lowest_cost_project=lowest_cost_project,
)
# 市场概览
competitors_result = await db.execute(
select(func.count(Competitor.id)).where(Competitor.is_active == True)
)
competitors_tracked = competitors_result.scalar() or 0
# 本月价格记录数
this_month_start = date.today().replace(day=1)
price_records_result = await db.execute(
select(func.count(CompetitorPrice.id)).where(
CompetitorPrice.collected_at >= this_month_start
)
)
price_records_this_month = price_records_result.scalar() or 0
# 市场平均价
avg_market_price_result = await db.execute(
select(func.avg(CompetitorPrice.original_price))
)
avg_market_price = avg_market_price_result.scalar()
market_overview = MarketOverview(
competitors_tracked=competitors_tracked,
price_records_this_month=price_records_this_month,
avg_market_price=float(avg_market_price) if avg_market_price else None,
)
# 定价概览
pricing_plans_result = await db.execute(
select(func.count(PricingPlan.id))
)
pricing_plans_count = pricing_plans_result.scalar() or 0
avg_margin_result = await db.execute(
select(func.avg(PricingPlan.target_margin))
)
avg_target_margin = avg_margin_result.scalar()
# 策略分布
traffic_count_result = await db.execute(
select(func.count(PricingPlan.id)).where(PricingPlan.strategy_type == "traffic")
)
profit_count_result = await db.execute(
select(func.count(PricingPlan.id)).where(PricingPlan.strategy_type == "profit")
)
premium_count_result = await db.execute(
select(func.count(PricingPlan.id)).where(PricingPlan.strategy_type == "premium")
)
pricing_overview = PricingOverview(
pricing_plans_count=pricing_plans_count,
avg_target_margin=float(avg_target_margin) if avg_target_margin else None,
strategies_distribution=StrategiesDistribution(
traffic=traffic_count_result.scalar() or 0,
profit=profit_count_result.scalar() or 0,
premium=premium_count_result.scalar() or 0,
),
)
# 最近活动(从操作日志获取)
recent_logs_result = await db.execute(
select(OperationLog).order_by(
OperationLog.created_at.desc()
).limit(10)
)
recent_logs = recent_logs_result.scalars().all()
recent_activities = []
for log in recent_logs:
recent_activities.append(RecentActivity(
type=f"{log.module}_{log.action}",
project_name=log.target_type,
user=None, # 简化处理
time=log.created_at,
))
return ResponseModel(data=DashboardSummaryResponse(
project_overview=project_overview,
cost_overview=cost_overview,
market_overview=market_overview,
pricing_overview=pricing_overview,
ai_usage_this_month=None, # AI 使用统计需要从 ai_call_logs 表获取
recent_activities=recent_activities,
))
@router.get("/dashboard/cost-trend", response_model=ResponseModel[CostTrendResponse])
async def get_cost_trend(
period: str = Query("month", description="统计周期week/month/quarter"),
db: AsyncSession = Depends(get_db),
):
"""获取成本趋势数据"""
# 根据周期确定时间范围
today = date.today()
if period == "week":
start_date = today - timedelta(days=7)
elif period == "quarter":
start_date = today - timedelta(days=90)
else: # month
start_date = today - timedelta(days=30)
# 按日期分组统计平均成本
# 简化实现:返回最近的成本汇总数据
result = await db.execute(
select(ProjectCostSummary).order_by(
ProjectCostSummary.calculated_at.desc()
).limit(30)
)
summaries = result.scalars().all()
# 按日期聚合
date_costs = {}
for summary in summaries:
# 检查 calculated_at 是否为 None
if summary.calculated_at is None:
continue
day = summary.calculated_at.strftime("%Y-%m-%d")
if day not in date_costs:
date_costs[day] = []
date_costs[day].append(float(summary.total_cost))
data = []
total_cost = 0
for day in sorted(date_costs.keys()):
avg = sum(date_costs[day]) / len(date_costs[day])
data.append(TrendDataPoint(date=day, value=round(avg, 2)))
total_cost += avg
avg_cost = total_cost / len(data) if data else 0
return ResponseModel(data=CostTrendResponse(
period=period,
data=data,
avg_cost=round(avg_cost, 2),
))
@router.get("/dashboard/market-trend", response_model=ResponseModel[MarketTrendResponse])
async def get_market_trend(
period: str = Query("month", description="统计周期week/month/quarter"),
db: AsyncSession = Depends(get_db),
):
"""获取市场价格趋势数据"""
# 根据周期确定时间范围
today = date.today()
if period == "week":
start_date = today - timedelta(days=7)
elif period == "quarter":
start_date = today - timedelta(days=90)
else: # month
start_date = today - timedelta(days=30)
# 获取价格记录
result = await db.execute(
select(CompetitorPrice).where(
CompetitorPrice.collected_at >= start_date
).order_by(CompetitorPrice.collected_at.desc())
)
prices = result.scalars().all()
# 按日期聚合
date_prices = {}
for price in prices:
# 检查 collected_at 是否为 None
if price.collected_at is None:
continue
day = price.collected_at.strftime("%Y-%m-%d")
if day not in date_prices:
date_prices[day] = []
date_prices[day].append(float(price.original_price))
data = []
total_price = 0
for day in sorted(date_prices.keys()):
avg = sum(date_prices[day]) / len(date_prices[day])
data.append(TrendDataPoint(date=day, value=round(avg, 2)))
total_price += avg
avg_price = total_price / len(data) if data else 0
return ResponseModel(data=MarketTrendResponse(
period=period,
data=data,
avg_price=round(avg_price, 2),
))

View File

@@ -0,0 +1,188 @@
"""设备管理路由
实现设备的 CRUD 操作,包含折旧计算
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.equipment import Equipment
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
from app.schemas.equipment import (
EquipmentCreate,
EquipmentUpdate,
EquipmentResponse,
)
router = APIRouter()
@router.get("", response_model=ResponseModel[PaginatedData[EquipmentResponse]])
async def get_equipments(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
keyword: Optional[str] = Query(None, description="关键词搜索"),
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
db: AsyncSession = Depends(get_db),
):
"""获取设备列表"""
query = select(Equipment)
if keyword:
query = query.where(
or_(
Equipment.equipment_code.contains(keyword),
Equipment.equipment_name.contains(keyword),
)
)
if is_active is not None:
query = query.where(Equipment.is_active == is_active)
query = query.order_by(Equipment.id.desc())
# 分页
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size)
result = await db.execute(query)
equipments = result.scalars().all()
# 统计总数
count_query = select(func.count(Equipment.id))
if keyword:
count_query = count_query.where(
or_(
Equipment.equipment_code.contains(keyword),
Equipment.equipment_name.contains(keyword),
)
)
if is_active is not None:
count_query = count_query.where(Equipment.is_active == is_active)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
return ResponseModel(
data=PaginatedData(
items=[EquipmentResponse.model_validate(e) for e in equipments],
total=total,
page=page,
page_size=page_size,
total_pages=(total + page_size - 1) // page_size,
)
)
@router.get("/{equipment_id}", response_model=ResponseModel[EquipmentResponse])
async def get_equipment(
equipment_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取单个设备详情"""
result = await db.execute(select(Equipment).where(Equipment.id == equipment_id))
equipment = result.scalar_one_or_none()
if not equipment:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "设备不存在"}
)
return ResponseModel(data=EquipmentResponse.model_validate(equipment))
@router.post("", response_model=ResponseModel[EquipmentResponse])
async def create_equipment(
data: EquipmentCreate,
db: AsyncSession = Depends(get_db),
):
"""创建设备"""
# 检查编码是否已存在
existing = await db.execute(
select(Equipment).where(Equipment.equipment_code == data.equipment_code)
)
if existing.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "设备编码已存在"}
)
# 计算单次折旧成本
residual_value = data.original_value * data.residual_rate / 100
depreciation_per_use = (data.original_value - residual_value) / data.estimated_uses
equipment = Equipment(
**data.model_dump(),
depreciation_per_use=depreciation_per_use,
)
db.add(equipment)
await db.flush()
await db.refresh(equipment)
return ResponseModel(message="创建成功", data=EquipmentResponse.model_validate(equipment))
@router.put("/{equipment_id}", response_model=ResponseModel[EquipmentResponse])
async def update_equipment(
equipment_id: int,
data: EquipmentUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新设备"""
result = await db.execute(select(Equipment).where(Equipment.id == equipment_id))
equipment = result.scalar_one_or_none()
if not equipment:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "设备不存在"}
)
# 检查编码是否重复
if data.equipment_code and data.equipment_code != equipment.equipment_code:
existing = await db.execute(
select(Equipment).where(Equipment.equipment_code == data.equipment_code)
)
if existing.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "设备编码已存在"}
)
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(equipment, field, value)
# 重新计算折旧(如果相关字段更新了)
if any(f in update_data for f in ['original_value', 'residual_rate', 'estimated_uses']):
residual_value = float(equipment.original_value) * float(equipment.residual_rate) / 100
equipment.depreciation_per_use = (float(equipment.original_value) - residual_value) / equipment.estimated_uses
await db.flush()
await db.refresh(equipment)
return ResponseModel(message="更新成功", data=EquipmentResponse.model_validate(equipment))
@router.delete("/{equipment_id}", response_model=ResponseModel)
async def delete_equipment(
equipment_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除设备"""
result = await db.execute(select(Equipment).where(Equipment.id == equipment_id))
equipment = result.scalar_one_or_none()
if not equipment:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "设备不存在"}
)
await db.delete(equipment)
return ResponseModel(message="删除成功")

View File

@@ -0,0 +1,187 @@
"""固定成本路由
实现固定成本的 CRUD 操作
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.fixed_cost import FixedCost
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
from app.schemas.fixed_cost import (
FixedCostCreate,
FixedCostUpdate,
FixedCostResponse,
CostType,
AllocationMethod,
)
router = APIRouter()
@router.get("", response_model=ResponseModel[PaginatedData[FixedCostResponse]])
async def get_fixed_costs(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
year_month: Optional[str] = Query(None, description="年月筛选"),
cost_type: Optional[CostType] = Query(None, description="类型筛选"),
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
db: AsyncSession = Depends(get_db),
):
"""获取固定成本列表"""
query = select(FixedCost)
if year_month:
query = query.where(FixedCost.year_month == year_month)
if cost_type:
query = query.where(FixedCost.cost_type == cost_type.value)
if is_active is not None:
query = query.where(FixedCost.is_active == is_active)
query = query.order_by(FixedCost.year_month.desc(), FixedCost.id)
# 分页
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size)
result = await db.execute(query)
fixed_costs = result.scalars().all()
# 统计总数
count_query = select(func.count(FixedCost.id))
if year_month:
count_query = count_query.where(FixedCost.year_month == year_month)
if cost_type:
count_query = count_query.where(FixedCost.cost_type == cost_type.value)
if is_active is not None:
count_query = count_query.where(FixedCost.is_active == is_active)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
return ResponseModel(
data=PaginatedData(
items=[FixedCostResponse.model_validate(f) for f in fixed_costs],
total=total,
page=page,
page_size=page_size,
total_pages=(total + page_size - 1) // page_size,
)
)
@router.get("/summary", response_model=ResponseModel)
async def get_fixed_costs_summary(
year_month: str = Query(..., description="年月"),
db: AsyncSession = Depends(get_db),
):
"""获取指定月份的固定成本汇总"""
query = select(FixedCost).where(
FixedCost.year_month == year_month,
FixedCost.is_active == True,
)
result = await db.execute(query)
fixed_costs = result.scalars().all()
# 按类型汇总
summary_by_type = {}
total_amount = 0
for cost in fixed_costs:
cost_type = cost.cost_type
if cost_type not in summary_by_type:
summary_by_type[cost_type] = 0
summary_by_type[cost_type] += float(cost.monthly_amount)
total_amount += float(cost.monthly_amount)
return ResponseModel(
data={
"year_month": year_month,
"total_amount": total_amount,
"by_type": summary_by_type,
"count": len(fixed_costs),
}
)
@router.get("/{fixed_cost_id}", response_model=ResponseModel[FixedCostResponse])
async def get_fixed_cost(
fixed_cost_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取单个固定成本详情"""
result = await db.execute(select(FixedCost).where(FixedCost.id == fixed_cost_id))
fixed_cost = result.scalar_one_or_none()
if not fixed_cost:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "固定成本不存在"}
)
return ResponseModel(data=FixedCostResponse.model_validate(fixed_cost))
@router.post("", response_model=ResponseModel[FixedCostResponse])
async def create_fixed_cost(
data: FixedCostCreate,
db: AsyncSession = Depends(get_db),
):
"""创建固定成本"""
fixed_cost = FixedCost(**data.model_dump())
db.add(fixed_cost)
await db.flush()
await db.refresh(fixed_cost)
return ResponseModel(message="创建成功", data=FixedCostResponse.model_validate(fixed_cost))
@router.put("/{fixed_cost_id}", response_model=ResponseModel[FixedCostResponse])
async def update_fixed_cost(
fixed_cost_id: int,
data: FixedCostUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新固定成本"""
result = await db.execute(select(FixedCost).where(FixedCost.id == fixed_cost_id))
fixed_cost = result.scalar_one_or_none()
if not fixed_cost:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "固定成本不存在"}
)
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(fixed_cost, field, value)
await db.flush()
await db.refresh(fixed_cost)
return ResponseModel(message="更新成功", data=FixedCostResponse.model_validate(fixed_cost))
@router.delete("/{fixed_cost_id}", response_model=ResponseModel)
async def delete_fixed_cost(
fixed_cost_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除固定成本"""
result = await db.execute(select(FixedCost).where(FixedCost.id == fixed_cost_id))
fixed_cost = result.scalar_one_or_none()
if not fixed_cost:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "固定成本不存在"}
)
await db.delete(fixed_cost)
return ResponseModel(message="删除成功")

View File

@@ -0,0 +1,44 @@
"""健康检查路由
提供 /health 端点用于 Docker 健康检查
遵循瑞小美部署规范30s interval, 10s timeout, 3 retries
"""
from datetime import datetime
from fastapi import APIRouter, Depends
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import get_db
router = APIRouter()
@router.get("/health")
async def health_check(db: AsyncSession = Depends(get_db)):
"""健康检查端点
检查内容:
- 应用运行状态
- 数据库连接状态
- 当前时间戳
"""
# 检查数据库连接
db_status = "connected"
try:
await db.execute(text("SELECT 1"))
except Exception as e:
db_status = f"error: {str(e)}"
return {
"code": 0,
"message": "success",
"data": {
"status": "healthy" if db_status == "connected" else "unhealthy",
"version": settings.APP_VERSION,
"database": db_status,
"timestamp": datetime.now().isoformat()
}
}

View File

@@ -0,0 +1,733 @@
"""市场行情管理路由
实现竞品机构、竞品价格、标杆价格和市场分析
"""
from typing import Optional
from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database import get_db
from app.models import (
Project,
Competitor,
CompetitorPrice,
BenchmarkPrice,
MarketAnalysisResult,
Category,
)
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
from app.schemas.competitor import (
CompetitorCreate,
CompetitorUpdate,
CompetitorResponse,
CompetitorPriceCreate,
CompetitorPriceUpdate,
CompetitorPriceResponse,
Positioning,
)
from app.schemas.market import (
BenchmarkPriceCreate,
BenchmarkPriceUpdate,
BenchmarkPriceResponse,
MarketAnalysisRequest,
MarketAnalysisResult as MarketAnalysisResultSchema,
MarketAnalysisResponse,
)
from app.services.market_service import MarketService
router = APIRouter()
# ============ 竞品机构 CRUD ============
@router.get("/competitors", response_model=ResponseModel[PaginatedData[CompetitorResponse]])
async def get_competitors(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
positioning: Optional[Positioning] = Query(None, description="定位筛选"),
is_key_competitor: Optional[bool] = Query(None, description="是否重点关注"),
keyword: Optional[str] = Query(None, description="关键词搜索"),
db: AsyncSession = Depends(get_db),
):
"""获取竞品机构列表"""
query = select(Competitor).options(selectinload(Competitor.prices))
if positioning:
query = query.where(Competitor.positioning == positioning.value)
if is_key_competitor is not None:
query = query.where(Competitor.is_key_competitor == is_key_competitor)
if keyword:
query = query.where(
or_(
Competitor.competitor_name.contains(keyword),
Competitor.address.contains(keyword),
)
)
query = query.order_by(Competitor.is_key_competitor.desc(), Competitor.id.desc())
# 分页
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size)
result = await db.execute(query)
competitors = result.scalars().all()
# 统计总数
count_query = select(func.count(Competitor.id))
if positioning:
count_query = count_query.where(Competitor.positioning == positioning.value)
if is_key_competitor is not None:
count_query = count_query.where(Competitor.is_key_competitor == is_key_competitor)
if keyword:
count_query = count_query.where(
or_(
Competitor.competitor_name.contains(keyword),
Competitor.address.contains(keyword),
)
)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# 构建响应
items = []
for c in competitors:
# 获取最新价格更新日期
last_price_update = None
if c.prices:
last_price_update = max(p.collected_at for p in c.prices)
items.append(CompetitorResponse(
id=c.id,
competitor_name=c.competitor_name,
address=c.address,
distance_km=float(c.distance_km) if c.distance_km else None,
positioning=Positioning(c.positioning),
contact=c.contact,
is_key_competitor=c.is_key_competitor,
is_active=c.is_active,
price_count=len(c.prices),
last_price_update=last_price_update,
created_at=c.created_at,
updated_at=c.updated_at,
))
return ResponseModel(
data=PaginatedData(
items=items,
total=total,
page=page,
page_size=page_size,
total_pages=(total + page_size - 1) // page_size,
)
)
@router.get("/competitors/{competitor_id}", response_model=ResponseModel[CompetitorResponse])
async def get_competitor(
competitor_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取竞品机构详情"""
result = await db.execute(
select(Competitor).options(
selectinload(Competitor.prices)
).where(Competitor.id == competitor_id)
)
competitor = result.scalar_one_or_none()
if not competitor:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品机构不存在"}
)
last_price_update = None
if competitor.prices:
last_price_update = max(p.collected_at for p in competitor.prices)
return ResponseModel(
data=CompetitorResponse(
id=competitor.id,
competitor_name=competitor.competitor_name,
address=competitor.address,
distance_km=float(competitor.distance_km) if competitor.distance_km else None,
positioning=Positioning(competitor.positioning),
contact=competitor.contact,
is_key_competitor=competitor.is_key_competitor,
is_active=competitor.is_active,
price_count=len(competitor.prices),
last_price_update=last_price_update,
created_at=competitor.created_at,
updated_at=competitor.updated_at,
)
)
@router.post("/competitors", response_model=ResponseModel[CompetitorResponse])
async def create_competitor(
data: CompetitorCreate,
db: AsyncSession = Depends(get_db),
):
"""创建竞品机构"""
competitor = Competitor(
competitor_name=data.competitor_name,
address=data.address,
distance_km=data.distance_km,
positioning=data.positioning.value,
contact=data.contact,
is_key_competitor=data.is_key_competitor,
is_active=data.is_active,
)
db.add(competitor)
await db.flush()
await db.refresh(competitor)
return ResponseModel(
message="创建成功",
data=CompetitorResponse(
id=competitor.id,
competitor_name=competitor.competitor_name,
address=competitor.address,
distance_km=float(competitor.distance_km) if competitor.distance_km else None,
positioning=Positioning(competitor.positioning),
contact=competitor.contact,
is_key_competitor=competitor.is_key_competitor,
is_active=competitor.is_active,
price_count=0,
last_price_update=None,
created_at=competitor.created_at,
updated_at=competitor.updated_at,
)
)
@router.put("/competitors/{competitor_id}", response_model=ResponseModel[CompetitorResponse])
async def update_competitor(
competitor_id: int,
data: CompetitorUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新竞品机构"""
result = await db.execute(
select(Competitor).options(
selectinload(Competitor.prices)
).where(Competitor.id == competitor_id)
)
competitor = result.scalar_one_or_none()
if not competitor:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品机构不存在"}
)
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
if field == "positioning" and value:
value = value.value
setattr(competitor, field, value)
await db.flush()
await db.refresh(competitor)
last_price_update = None
if competitor.prices:
last_price_update = max(p.collected_at for p in competitor.prices)
return ResponseModel(
message="更新成功",
data=CompetitorResponse(
id=competitor.id,
competitor_name=competitor.competitor_name,
address=competitor.address,
distance_km=float(competitor.distance_km) if competitor.distance_km else None,
positioning=Positioning(competitor.positioning),
contact=competitor.contact,
is_key_competitor=competitor.is_key_competitor,
is_active=competitor.is_active,
price_count=len(competitor.prices),
last_price_update=last_price_update,
created_at=competitor.created_at,
updated_at=competitor.updated_at,
)
)
@router.delete("/competitors/{competitor_id}", response_model=ResponseModel)
async def delete_competitor(
competitor_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除竞品机构"""
result = await db.execute(
select(Competitor).where(Competitor.id == competitor_id)
)
competitor = result.scalar_one_or_none()
if not competitor:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品机构不存在"}
)
await db.delete(competitor)
return ResponseModel(message="删除成功")
# ============ 竞品价格管理 ============
@router.get("/competitors/{competitor_id}/prices", response_model=ResponseModel[list[CompetitorPriceResponse]])
async def get_competitor_prices(
competitor_id: int,
project_id: Optional[int] = Query(None, description="项目筛选"),
db: AsyncSession = Depends(get_db),
):
"""获取竞品价格列表"""
# 检查竞品机构是否存在
competitor_result = await db.execute(
select(Competitor).where(Competitor.id == competitor_id)
)
competitor = competitor_result.scalar_one_or_none()
if not competitor:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品机构不存在"}
)
query = select(CompetitorPrice).where(
CompetitorPrice.competitor_id == competitor_id
)
if project_id:
query = query.where(CompetitorPrice.project_id == project_id)
query = query.order_by(CompetitorPrice.collected_at.desc())
result = await db.execute(query)
prices = result.scalars().all()
response_items = []
for p in prices:
response_items.append(CompetitorPriceResponse(
id=p.id,
competitor_id=p.competitor_id,
competitor_name=competitor.competitor_name,
project_id=p.project_id,
project_name=p.project_name,
original_price=float(p.original_price),
promo_price=float(p.promo_price) if p.promo_price else None,
member_price=float(p.member_price) if p.member_price else None,
price_source=p.price_source,
collected_at=p.collected_at,
remark=p.remark,
created_at=p.created_at,
updated_at=p.updated_at,
))
return ResponseModel(data=response_items)
@router.post("/competitors/{competitor_id}/prices", response_model=ResponseModel[CompetitorPriceResponse])
async def create_competitor_price(
competitor_id: int,
data: CompetitorPriceCreate,
db: AsyncSession = Depends(get_db),
):
"""添加竞品价格"""
# 检查竞品机构是否存在
competitor_result = await db.execute(
select(Competitor).where(Competitor.id == competitor_id)
)
competitor = competitor_result.scalar_one_or_none()
if not competitor:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品机构不存在"}
)
# 检查关联项目是否存在
if data.project_id:
project_result = await db.execute(
select(Project).where(Project.id == data.project_id)
)
if not project_result.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.NOT_FOUND, "message": "关联项目不存在"}
)
price = CompetitorPrice(
competitor_id=competitor_id,
project_id=data.project_id,
project_name=data.project_name,
original_price=data.original_price,
promo_price=data.promo_price,
member_price=data.member_price,
price_source=data.price_source.value,
collected_at=data.collected_at,
remark=data.remark,
)
db.add(price)
await db.flush()
await db.refresh(price)
return ResponseModel(
message="添加成功",
data=CompetitorPriceResponse(
id=price.id,
competitor_id=price.competitor_id,
competitor_name=competitor.competitor_name,
project_id=price.project_id,
project_name=price.project_name,
original_price=float(price.original_price),
promo_price=float(price.promo_price) if price.promo_price else None,
member_price=float(price.member_price) if price.member_price else None,
price_source=price.price_source,
collected_at=price.collected_at,
remark=price.remark,
created_at=price.created_at,
updated_at=price.updated_at,
)
)
@router.put("/competitor-prices/{price_id}", response_model=ResponseModel[CompetitorPriceResponse])
async def update_competitor_price(
price_id: int,
data: CompetitorPriceUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新竞品价格"""
result = await db.execute(
select(CompetitorPrice).options(
selectinload(CompetitorPrice.competitor)
).where(CompetitorPrice.id == price_id)
)
price = result.scalar_one_or_none()
if not price:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品价格不存在"}
)
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
if field == "price_source" and value:
value = value.value
setattr(price, field, value)
await db.flush()
await db.refresh(price)
return ResponseModel(
message="更新成功",
data=CompetitorPriceResponse(
id=price.id,
competitor_id=price.competitor_id,
competitor_name=price.competitor.competitor_name if price.competitor else None,
project_id=price.project_id,
project_name=price.project_name,
original_price=float(price.original_price),
promo_price=float(price.promo_price) if price.promo_price else None,
member_price=float(price.member_price) if price.member_price else None,
price_source=price.price_source,
collected_at=price.collected_at,
remark=price.remark,
created_at=price.created_at,
updated_at=price.updated_at,
)
)
@router.delete("/competitor-prices/{price_id}", response_model=ResponseModel)
async def delete_competitor_price(
price_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除竞品价格"""
result = await db.execute(
select(CompetitorPrice).where(CompetitorPrice.id == price_id)
)
price = result.scalar_one_or_none()
if not price:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "竞品价格不存在"}
)
await db.delete(price)
return ResponseModel(message="删除成功")
# ============ 标杆价格管理 ============
@router.get("/benchmark-prices", response_model=ResponseModel[PaginatedData[BenchmarkPriceResponse]])
async def get_benchmark_prices(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
category_id: Optional[int] = Query(None, description="分类筛选"),
db: AsyncSession = Depends(get_db),
):
"""获取标杆价格列表"""
query = select(BenchmarkPrice).options(selectinload(BenchmarkPrice.category))
if category_id:
query = query.where(BenchmarkPrice.category_id == category_id)
query = query.order_by(BenchmarkPrice.effective_date.desc())
# 分页
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size)
result = await db.execute(query)
benchmarks = result.scalars().all()
# 统计总数
count_query = select(func.count(BenchmarkPrice.id))
if category_id:
count_query = count_query.where(BenchmarkPrice.category_id == category_id)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
items = []
for b in benchmarks:
items.append(BenchmarkPriceResponse(
id=b.id,
benchmark_name=b.benchmark_name,
category_id=b.category_id,
category_name=b.category.category_name if b.category else None,
min_price=float(b.min_price),
max_price=float(b.max_price),
avg_price=float(b.avg_price),
price_tier=b.price_tier,
effective_date=b.effective_date,
remark=b.remark,
created_at=b.created_at,
updated_at=b.updated_at,
))
return ResponseModel(
data=PaginatedData(
items=items,
total=total,
page=page,
page_size=page_size,
total_pages=(total + page_size - 1) // page_size,
)
)
@router.post("/benchmark-prices", response_model=ResponseModel[BenchmarkPriceResponse])
async def create_benchmark_price(
data: BenchmarkPriceCreate,
db: AsyncSession = Depends(get_db),
):
"""创建标杆价格"""
# 检查分类是否存在
category_name = None
if data.category_id:
category_result = await db.execute(
select(Category).where(Category.id == data.category_id)
)
category = category_result.scalar_one_or_none()
if not category:
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
)
category_name = category.category_name
benchmark = BenchmarkPrice(
benchmark_name=data.benchmark_name,
category_id=data.category_id,
min_price=data.min_price,
max_price=data.max_price,
avg_price=data.avg_price,
price_tier=data.price_tier.value,
effective_date=data.effective_date,
remark=data.remark,
)
db.add(benchmark)
await db.flush()
await db.refresh(benchmark)
return ResponseModel(
message="创建成功",
data=BenchmarkPriceResponse(
id=benchmark.id,
benchmark_name=benchmark.benchmark_name,
category_id=benchmark.category_id,
category_name=category_name,
min_price=float(benchmark.min_price),
max_price=float(benchmark.max_price),
avg_price=float(benchmark.avg_price),
price_tier=benchmark.price_tier,
effective_date=benchmark.effective_date,
remark=benchmark.remark,
created_at=benchmark.created_at,
updated_at=benchmark.updated_at,
)
)
@router.put("/benchmark-prices/{benchmark_id}", response_model=ResponseModel[BenchmarkPriceResponse])
async def update_benchmark_price(
benchmark_id: int,
data: BenchmarkPriceUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新标杆价格"""
result = await db.execute(
select(BenchmarkPrice).options(
selectinload(BenchmarkPrice.category)
).where(BenchmarkPrice.id == benchmark_id)
)
benchmark = result.scalar_one_or_none()
if not benchmark:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "标杆价格不存在"}
)
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
if field == "price_tier" and value:
value = value.value
setattr(benchmark, field, value)
await db.flush()
await db.refresh(benchmark)
# 获取分类名称
category_name = None
if benchmark.category_id:
cat_result = await db.execute(
select(Category).where(Category.id == benchmark.category_id)
)
category = cat_result.scalar_one_or_none()
if category:
category_name = category.category_name
return ResponseModel(
message="更新成功",
data=BenchmarkPriceResponse(
id=benchmark.id,
benchmark_name=benchmark.benchmark_name,
category_id=benchmark.category_id,
category_name=category_name,
min_price=float(benchmark.min_price),
max_price=float(benchmark.max_price),
avg_price=float(benchmark.avg_price),
price_tier=benchmark.price_tier,
effective_date=benchmark.effective_date,
remark=benchmark.remark,
created_at=benchmark.created_at,
updated_at=benchmark.updated_at,
)
)
@router.delete("/benchmark-prices/{benchmark_id}", response_model=ResponseModel)
async def delete_benchmark_price(
benchmark_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除标杆价格"""
result = await db.execute(
select(BenchmarkPrice).where(BenchmarkPrice.id == benchmark_id)
)
benchmark = result.scalar_one_or_none()
if not benchmark:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "标杆价格不存在"}
)
await db.delete(benchmark)
return ResponseModel(message="删除成功")
# ============ 市场分析 ============
@router.post("/projects/{project_id}/market-analysis", response_model=ResponseModel[MarketAnalysisResultSchema])
async def analyze_market(
project_id: int,
data: MarketAnalysisRequest = MarketAnalysisRequest(),
db: AsyncSession = Depends(get_db),
):
"""执行市场价格分析"""
market_service = MarketService(db)
try:
result = await market_service.analyze_market(
project_id=project_id,
competitor_ids=data.competitor_ids,
include_benchmark=data.include_benchmark,
)
except ValueError as e:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": str(e)}
)
return ResponseModel(message="分析完成", data=result)
@router.get("/projects/{project_id}/market-analysis", response_model=ResponseModel[MarketAnalysisResponse])
async def get_market_analysis(
project_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取最新市场分析结果"""
# 检查项目是否存在
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
if not project_result.scalar_one_or_none():
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
)
market_service = MarketService(db)
result = await market_service.get_latest_analysis(project_id)
if not result:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "暂无分析结果,请先执行市场分析"}
)
return ResponseModel(
data=MarketAnalysisResponse(
id=result.id,
project_id=result.project_id,
analysis_date=result.analysis_date,
competitor_count=result.competitor_count,
market_min_price=float(result.market_min_price),
market_max_price=float(result.market_max_price),
market_avg_price=float(result.market_avg_price),
market_median_price=float(result.market_median_price),
suggested_range_min=float(result.suggested_range_min),
suggested_range_max=float(result.suggested_range_max),
created_at=result.created_at,
)
)

View File

@@ -0,0 +1,272 @@
"""耗材管理路由
实现耗材的 CRUD 操作和批量导入
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File
from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.material import Material
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
from app.schemas.material import (
MaterialCreate,
MaterialUpdate,
MaterialResponse,
MaterialImportResult,
MaterialType,
)
router = APIRouter()
@router.get("", response_model=ResponseModel[PaginatedData[MaterialResponse]])
async def get_materials(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
keyword: Optional[str] = Query(None, description="关键词搜索"),
material_type: Optional[MaterialType] = Query(None, description="类型筛选"),
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
db: AsyncSession = Depends(get_db),
):
"""获取耗材列表"""
query = select(Material)
if keyword:
query = query.where(
or_(
Material.material_code.contains(keyword),
Material.material_name.contains(keyword),
)
)
if material_type:
query = query.where(Material.material_type == material_type.value)
if is_active is not None:
query = query.where(Material.is_active == is_active)
query = query.order_by(Material.id.desc())
# 分页
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size)
result = await db.execute(query)
materials = result.scalars().all()
# 统计总数
count_query = select(func.count(Material.id))
if keyword:
count_query = count_query.where(
or_(
Material.material_code.contains(keyword),
Material.material_name.contains(keyword),
)
)
if material_type:
count_query = count_query.where(Material.material_type == material_type.value)
if is_active is not None:
count_query = count_query.where(Material.is_active == is_active)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
return ResponseModel(
data=PaginatedData(
items=[MaterialResponse.model_validate(m) for m in materials],
total=total,
page=page,
page_size=page_size,
total_pages=(total + page_size - 1) // page_size,
)
)
@router.get("/{material_id}", response_model=ResponseModel[MaterialResponse])
async def get_material(
material_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取单个耗材详情"""
result = await db.execute(select(Material).where(Material.id == material_id))
material = result.scalar_one_or_none()
if not material:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "耗材不存在"}
)
return ResponseModel(data=MaterialResponse.model_validate(material))
@router.post("", response_model=ResponseModel[MaterialResponse])
async def create_material(
data: MaterialCreate,
db: AsyncSession = Depends(get_db),
):
"""创建耗材"""
# 检查编码是否已存在
existing = await db.execute(
select(Material).where(Material.material_code == data.material_code)
)
if existing.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "耗材编码已存在"}
)
material = Material(**data.model_dump())
db.add(material)
await db.flush()
await db.refresh(material)
return ResponseModel(message="创建成功", data=MaterialResponse.model_validate(material))
@router.put("/{material_id}", response_model=ResponseModel[MaterialResponse])
async def update_material(
material_id: int,
data: MaterialUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新耗材"""
result = await db.execute(select(Material).where(Material.id == material_id))
material = result.scalar_one_or_none()
if not material:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "耗材不存在"}
)
# 检查编码是否重复
if data.material_code and data.material_code != material.material_code:
existing = await db.execute(
select(Material).where(Material.material_code == data.material_code)
)
if existing.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "耗材编码已存在"}
)
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(material, field, value)
await db.flush()
await db.refresh(material)
return ResponseModel(message="更新成功", data=MaterialResponse.model_validate(material))
@router.delete("/{material_id}", response_model=ResponseModel)
async def delete_material(
material_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除耗材"""
result = await db.execute(select(Material).where(Material.id == material_id))
material = result.scalar_one_or_none()
if not material:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "耗材不存在"}
)
await db.delete(material)
return ResponseModel(message="删除成功")
@router.post("/import", response_model=ResponseModel[MaterialImportResult])
async def import_materials(
file: UploadFile = File(..., description="Excel 文件"),
update_existing: bool = Query(False, description="是否更新已存在的数据"),
db: AsyncSession = Depends(get_db),
):
"""批量导入耗材
Excel 格式:耗材编码 | 耗材名称 | 单位 | 单价 | 供应商 | 类型
"""
import openpyxl
from io import BytesIO
if not file.filename.endswith(('.xlsx', '.xls')):
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.PARAM_ERROR, "message": "请上传 Excel 文件"}
)
content = await file.read()
wb = openpyxl.load_workbook(BytesIO(content))
ws = wb.active
total = 0
success = 0
errors = []
for row_num, row in enumerate(ws.iter_rows(min_row=2, values_only=True), start=2):
if not row[0]: # 跳过空行
continue
total += 1
try:
material_code = str(row[0]).strip()
material_name = str(row[1]).strip() if row[1] else ""
unit = str(row[2]).strip() if row[2] else ""
unit_price = float(row[3]) if row[3] else 0
supplier = str(row[4]).strip() if row[4] else None
material_type = str(row[5]).strip() if row[5] else "consumable"
if not material_name or not unit:
errors.append({"row": row_num, "error": "名称或单位不能为空"})
continue
# 检查是否已存在
existing = await db.execute(
select(Material).where(Material.material_code == material_code)
)
existing_material = existing.scalar_one_or_none()
if existing_material:
if update_existing:
existing_material.material_name = material_name
existing_material.unit = unit
existing_material.unit_price = unit_price
existing_material.supplier = supplier
existing_material.material_type = material_type
success += 1
else:
errors.append({"row": row_num, "error": f"耗材编码 {material_code} 已存在"})
else:
material = Material(
material_code=material_code,
material_name=material_name,
unit=unit,
unit_price=unit_price,
supplier=supplier,
material_type=material_type,
)
db.add(material)
success += 1
except Exception as e:
errors.append({"row": row_num, "error": str(e)})
await db.flush()
return ResponseModel(
message="导入完成",
data=MaterialImportResult(
total=total,
success=success,
failed=len(errors),
errors=errors,
)
)

View File

@@ -0,0 +1,327 @@
"""智能定价路由
智能定价建议相关的 API 接口
"""
from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database import get_db
from app.models import PricingPlan, Project
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
from app.schemas.pricing import (
StrategyType,
PricingPlanCreate,
PricingPlanUpdate,
PricingPlanResponse,
PricingPlanListResponse,
PricingPlanQuery,
GeneratePricingRequest,
GeneratePricingResponse,
SimulateStrategyRequest,
SimulateStrategyResponse,
)
from app.services.pricing_service import PricingService
router = APIRouter()
# 定价方案 CRUD
# 定价方案允许的排序字段白名单
PRICING_PLAN_SORT_FIELDS = {"created_at", "updated_at", "plan_name", "base_cost", "target_margin", "suggested_price"}
@router.get("/pricing-plans", response_model=ResponseModel[PaginatedData[PricingPlanListResponse]])
async def list_pricing_plans(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
project_id: Optional[int] = None,
strategy_type: Optional[StrategyType] = None,
is_active: Optional[bool] = None,
sort_by: str = "created_at",
sort_order: str = "desc",
db: AsyncSession = Depends(get_db),
):
"""获取定价方案列表"""
query = select(PricingPlan).options(
selectinload(PricingPlan.project),
selectinload(PricingPlan.creator),
)
if project_id:
query = query.where(PricingPlan.project_id == project_id)
if strategy_type:
query = query.where(PricingPlan.strategy_type == strategy_type.value)
if is_active is not None:
query = query.where(PricingPlan.is_active == is_active)
# 计算总数
count_query = select(func.count()).select_from(query.subquery())
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# 排序 - 使用白名单验证防止注入
if sort_by not in PRICING_PLAN_SORT_FIELDS:
sort_by = "created_at"
sort_column = getattr(PricingPlan, sort_by, PricingPlan.created_at)
if sort_order == "desc":
query = query.order_by(sort_column.desc())
else:
query = query.order_by(sort_column.asc())
# 分页
query = query.offset((page - 1) * page_size).limit(page_size)
result = await db.execute(query)
plans = result.scalars().all()
items = []
for plan in plans:
items.append(PricingPlanListResponse(
id=plan.id,
project_id=plan.project_id,
project_name=plan.project.project_name if plan.project else None,
plan_name=plan.plan_name,
strategy_type=plan.strategy_type,
base_cost=float(plan.base_cost),
target_margin=float(plan.target_margin),
suggested_price=float(plan.suggested_price),
final_price=float(plan.final_price) if plan.final_price else None,
is_active=plan.is_active,
created_at=plan.created_at,
created_by_name=plan.creator.username if plan.creator else None,
))
return ResponseModel(data=PaginatedData(
items=items,
total=total,
page=page,
page_size=page_size,
total_pages=(total + page_size - 1) // page_size,
))
@router.post("/pricing-plans", response_model=ResponseModel[PricingPlanResponse])
async def create_pricing_plan(
data: PricingPlanCreate,
db: AsyncSession = Depends(get_db),
):
"""创建定价方案"""
service = PricingService(db)
try:
plan = await service.create_pricing_plan(
project_id=data.project_id,
plan_name=data.plan_name,
strategy_type=data.strategy_type,
target_margin=data.target_margin,
)
await db.commit()
# 重新加载关系
await db.refresh(plan, ["project", "creator"])
return ResponseModel(
message="创建成功",
data=PricingPlanResponse(
id=plan.id,
project_id=plan.project_id,
project_name=plan.project.project_name if plan.project else None,
plan_name=plan.plan_name,
strategy_type=plan.strategy_type,
base_cost=float(plan.base_cost),
target_margin=float(plan.target_margin),
suggested_price=float(plan.suggested_price),
final_price=float(plan.final_price) if plan.final_price else None,
ai_advice=plan.ai_advice,
is_active=plan.is_active,
created_at=plan.created_at,
updated_at=plan.updated_at,
)
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/pricing-plans/{plan_id}", response_model=ResponseModel[PricingPlanResponse])
async def get_pricing_plan(
plan_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取定价方案详情"""
result = await db.execute(
select(PricingPlan).options(
selectinload(PricingPlan.project),
selectinload(PricingPlan.creator),
).where(PricingPlan.id == plan_id)
)
plan = result.scalar_one_or_none()
if not plan:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "定价方案不存在"}
)
return ResponseModel(data=PricingPlanResponse(
id=plan.id,
project_id=plan.project_id,
project_name=plan.project.project_name if plan.project else None,
plan_name=plan.plan_name,
strategy_type=plan.strategy_type,
base_cost=float(plan.base_cost),
target_margin=float(plan.target_margin),
suggested_price=float(plan.suggested_price),
final_price=float(plan.final_price) if plan.final_price else None,
ai_advice=plan.ai_advice,
is_active=plan.is_active,
created_at=plan.created_at,
updated_at=plan.updated_at,
created_by_name=plan.creator.username if plan.creator else None,
))
@router.put("/pricing-plans/{plan_id}", response_model=ResponseModel[PricingPlanResponse])
async def update_pricing_plan(
plan_id: int,
data: PricingPlanUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新定价方案"""
service = PricingService(db)
try:
plan = await service.update_pricing_plan(
plan_id=plan_id,
plan_name=data.plan_name,
strategy_type=data.strategy_type.value if data.strategy_type else None,
target_margin=data.target_margin,
final_price=data.final_price,
is_active=data.is_active,
)
await db.commit()
await db.refresh(plan, ["project", "creator"])
return ResponseModel(
message="更新成功",
data=PricingPlanResponse(
id=plan.id,
project_id=plan.project_id,
project_name=plan.project.project_name if plan.project else None,
plan_name=plan.plan_name,
strategy_type=plan.strategy_type,
base_cost=float(plan.base_cost),
target_margin=float(plan.target_margin),
suggested_price=float(plan.suggested_price),
final_price=float(plan.final_price) if plan.final_price else None,
ai_advice=plan.ai_advice,
is_active=plan.is_active,
created_at=plan.created_at,
updated_at=plan.updated_at,
)
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/pricing-plans/{plan_id}", response_model=ResponseModel)
async def delete_pricing_plan(
plan_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除定价方案"""
result = await db.execute(
select(PricingPlan).where(PricingPlan.id == plan_id)
)
plan = result.scalar_one_or_none()
if not plan:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "定价方案不存在"}
)
await db.delete(plan)
await db.commit()
return ResponseModel(message="删除成功")
# AI 定价建议
@router.post("/projects/{project_id}/generate-pricing", response_model=ResponseModel[GeneratePricingResponse])
async def generate_pricing(
project_id: int,
request: GeneratePricingRequest,
db: AsyncSession = Depends(get_db),
):
"""AI 生成定价建议
支持流式和非流式两种模式
"""
service = PricingService(db)
# 检查项目是否存在
result = await db.execute(
select(Project).where(Project.id == project_id)
)
if not result.scalar_one_or_none():
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
)
if request.stream:
# 流式返回
return StreamingResponse(
service.generate_pricing_advice_stream(
project_id=project_id,
target_margin=request.target_margin,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
}
)
else:
# 非流式返回
try:
response = await service.generate_pricing_advice(
project_id=project_id,
target_margin=request.target_margin,
strategies=request.strategies,
)
return ResponseModel(data=response)
except Exception as e:
raise HTTPException(
status_code=500,
detail={"code": ErrorCode.AI_SERVICE_ERROR, "message": str(e)}
)
@router.post("/projects/{project_id}/simulate-strategy", response_model=ResponseModel[SimulateStrategyResponse])
async def simulate_strategy(
project_id: int,
request: SimulateStrategyRequest,
db: AsyncSession = Depends(get_db),
):
"""模拟定价策略"""
service = PricingService(db)
try:
response = await service.simulate_strategies(
project_id=project_id,
strategies=request.strategies,
target_margin=request.target_margin,
)
return ResponseModel(data=response)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -0,0 +1,273 @@
"""利润模拟路由
利润模拟测算相关的 API 接口
"""
from typing import Optional, List
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database import get_db
from app.models import ProfitSimulation, PricingPlan
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
from app.schemas.profit import (
PeriodType,
ProfitSimulationResponse,
ProfitSimulationListResponse,
SimulateProfitRequest,
SimulateProfitResponse,
SensitivityAnalysisRequest,
SensitivityAnalysisResponse,
BreakevenRequest,
BreakevenResponse,
)
from app.services.profit_service import ProfitService
router = APIRouter()
# 利润模拟 CRUD
@router.get("/profit-simulations", response_model=ResponseModel[PaginatedData[ProfitSimulationListResponse]])
async def list_profit_simulations(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
pricing_plan_id: Optional[int] = None,
period_type: Optional[PeriodType] = None,
sort_by: str = "created_at",
sort_order: str = "desc",
db: AsyncSession = Depends(get_db),
):
"""获取利润模拟列表"""
service = ProfitService(db)
simulations, total = await service.get_simulation_list(
pricing_plan_id=pricing_plan_id,
period_type=period_type,
page=page,
page_size=page_size,
)
items = []
for sim in simulations:
items.append(ProfitSimulationListResponse(
id=sim.id,
pricing_plan_id=sim.pricing_plan_id,
plan_name=sim.pricing_plan.plan_name if sim.pricing_plan else None,
project_name=sim.pricing_plan.project.project_name if sim.pricing_plan and sim.pricing_plan.project else None,
simulation_name=sim.simulation_name,
price=float(sim.price),
estimated_volume=sim.estimated_volume,
period_type=sim.period_type,
estimated_profit=float(sim.estimated_profit),
profit_margin=float(sim.profit_margin),
created_at=sim.created_at,
))
return ResponseModel(data=PaginatedData(
items=items,
total=total,
page=page,
page_size=page_size,
total_pages=(total + page_size - 1) // page_size,
))
@router.get("/profit-simulations/{simulation_id}", response_model=ResponseModel[ProfitSimulationResponse])
async def get_profit_simulation(
simulation_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取模拟详情"""
result = await db.execute(
select(ProfitSimulation).options(
selectinload(ProfitSimulation.pricing_plan).selectinload(PricingPlan.project),
selectinload(ProfitSimulation.creator),
).where(ProfitSimulation.id == simulation_id)
)
sim = result.scalar_one_or_none()
if not sim:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "模拟记录不存在"}
)
return ResponseModel(data=ProfitSimulationResponse(
id=sim.id,
pricing_plan_id=sim.pricing_plan_id,
plan_name=sim.pricing_plan.plan_name if sim.pricing_plan else None,
project_name=sim.pricing_plan.project.project_name if sim.pricing_plan and sim.pricing_plan.project else None,
simulation_name=sim.simulation_name,
price=float(sim.price),
estimated_volume=sim.estimated_volume,
period_type=sim.period_type,
estimated_revenue=float(sim.estimated_revenue),
estimated_cost=float(sim.estimated_cost),
estimated_profit=float(sim.estimated_profit),
profit_margin=float(sim.profit_margin),
breakeven_volume=sim.breakeven_volume,
created_at=sim.created_at,
created_by_name=sim.creator.username if sim.creator else None,
))
@router.delete("/profit-simulations/{simulation_id}", response_model=ResponseModel)
async def delete_profit_simulation(
simulation_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除模拟记录"""
service = ProfitService(db)
deleted = await service.delete_simulation(simulation_id)
if not deleted:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "模拟记录不存在"}
)
await db.commit()
return ResponseModel(message="删除成功")
# 执行模拟
@router.post("/pricing-plans/{plan_id}/simulate-profit", response_model=ResponseModel[SimulateProfitResponse])
async def simulate_profit(
plan_id: int,
request: SimulateProfitRequest,
db: AsyncSession = Depends(get_db),
):
"""执行利润模拟"""
service = ProfitService(db)
try:
response = await service.simulate_profit(
pricing_plan_id=plan_id,
price=request.price,
estimated_volume=request.estimated_volume,
period_type=request.period_type,
)
await db.commit()
return ResponseModel(data=response)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# 敏感性分析
@router.post("/profit-simulations/{simulation_id}/sensitivity", response_model=ResponseModel[SensitivityAnalysisResponse])
async def create_sensitivity_analysis(
simulation_id: int,
request: SensitivityAnalysisRequest,
db: AsyncSession = Depends(get_db),
):
"""执行敏感性分析"""
service = ProfitService(db)
try:
response = await service.sensitivity_analysis(
simulation_id=simulation_id,
price_change_rates=request.price_change_rates,
)
await db.commit()
return ResponseModel(data=response)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/profit-simulations/{simulation_id}/sensitivity", response_model=ResponseModel[SensitivityAnalysisResponse])
async def get_sensitivity_analysis(
simulation_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取敏感性分析结果"""
result = await db.execute(
select(ProfitSimulation).options(
selectinload(ProfitSimulation.sensitivity_analyses),
selectinload(ProfitSimulation.pricing_plan),
).where(ProfitSimulation.id == simulation_id)
)
sim = result.scalar_one_or_none()
if not sim:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "模拟记录不存在"}
)
if not sim.sensitivity_analyses:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "尚未执行敏感性分析"}
)
from app.schemas.profit import SensitivityResultItem
results = [
SensitivityResultItem(
price_change_rate=float(sa.price_change_rate),
adjusted_price=float(sa.adjusted_price),
adjusted_profit=float(sa.adjusted_profit),
profit_change_rate=float(sa.profit_change_rate),
)
for sa in sorted(sim.sensitivity_analyses, key=lambda x: x.price_change_rate)
]
return ResponseModel(data=SensitivityAnalysisResponse(
simulation_id=simulation_id,
base_price=float(sim.price),
base_profit=float(sim.estimated_profit),
sensitivity_results=results,
))
# 盈亏平衡分析
@router.get("/pricing-plans/{plan_id}/breakeven", response_model=ResponseModel[BreakevenResponse])
async def get_breakeven_analysis(
plan_id: int,
target_profit: Optional[float] = Query(None, description="目标利润"),
db: AsyncSession = Depends(get_db),
):
"""获取盈亏平衡分析"""
service = ProfitService(db)
try:
response = await service.breakeven_analysis(
pricing_plan_id=plan_id,
target_profit=target_profit,
)
return ResponseModel(data=response)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# AI 利润预测
@router.post("/profit-simulations/{simulation_id}/forecast", response_model=ResponseModel[dict])
async def generate_profit_forecast(
simulation_id: int,
db: AsyncSession = Depends(get_db),
):
"""AI 生成利润预测分析"""
service = ProfitService(db)
try:
content = await service.generate_profit_forecast(simulation_id)
return ResponseModel(data={"content": content})
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(
status_code=500,
detail={"code": ErrorCode.AI_SERVICE_ERROR, "message": str(e)}
)

View File

@@ -0,0 +1,904 @@
"""服务项目管理路由
实现服务项目的 CRUD 操作和成本管理
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.database import get_db
from app.models import (
Project,
ProjectCostItem,
ProjectLaborCost,
ProjectCostSummary,
Category,
Material,
Equipment,
StaffLevel,
)
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
from app.schemas.project import (
ProjectCreate,
ProjectUpdate,
ProjectResponse,
ProjectListResponse,
CostSummaryBrief,
)
from app.schemas.project_cost import (
CostItemCreate,
CostItemUpdate,
CostItemResponse,
LaborCostCreate,
LaborCostUpdate,
LaborCostResponse,
CalculateCostRequest,
CostCalculationResult,
CostSummaryResponse,
ProjectDetailResponse,
CostItemType,
AllocationMethod,
)
from app.services.cost_service import CostService
router = APIRouter()
# 项目允许的排序字段白名单
PROJECT_SORT_FIELDS = {"created_at", "updated_at", "project_code", "project_name", "duration_minutes"}
# ============ 项目 CRUD ============
@router.get("", response_model=ResponseModel[PaginatedData[ProjectListResponse]])
async def get_projects(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
category_id: Optional[int] = Query(None, description="分类筛选"),
keyword: Optional[str] = Query(None, description="关键词搜索"),
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
sort_by: str = Query("created_at", description="排序字段"),
sort_order: str = Query("desc", description="排序方向"),
db: AsyncSession = Depends(get_db),
):
"""获取服务项目列表"""
query = select(Project).options(
selectinload(Project.category),
selectinload(Project.cost_summary),
)
if keyword:
query = query.where(
or_(
Project.project_code.contains(keyword),
Project.project_name.contains(keyword),
)
)
if category_id:
query = query.where(Project.category_id == category_id)
if is_active is not None:
query = query.where(Project.is_active == is_active)
# 排序 - 使用白名单验证防止注入
if sort_by not in PROJECT_SORT_FIELDS:
sort_by = "created_at"
sort_column = getattr(Project, sort_by, Project.created_at)
if sort_order == "asc":
query = query.order_by(sort_column.asc())
else:
query = query.order_by(sort_column.desc())
# 分页
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size)
result = await db.execute(query)
projects = result.scalars().all()
# 统计总数
count_query = select(func.count(Project.id))
if keyword:
count_query = count_query.where(
or_(
Project.project_code.contains(keyword),
Project.project_name.contains(keyword),
)
)
if category_id:
count_query = count_query.where(Project.category_id == category_id)
if is_active is not None:
count_query = count_query.where(Project.is_active == is_active)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
# 构建响应
items = []
for p in projects:
item = {
"id": p.id,
"project_code": p.project_code,
"project_name": p.project_name,
"category_id": p.category_id,
"category_name": p.category.category_name if p.category else None,
"description": p.description,
"duration_minutes": p.duration_minutes,
"is_active": p.is_active,
"cost_summary": None,
"created_at": p.created_at,
"updated_at": p.updated_at,
}
if p.cost_summary:
item["cost_summary"] = CostSummaryBrief(
total_cost=float(p.cost_summary.total_cost),
material_cost=float(p.cost_summary.material_cost),
equipment_cost=float(p.cost_summary.equipment_cost),
labor_cost=float(p.cost_summary.labor_cost),
fixed_cost_allocation=float(p.cost_summary.fixed_cost_allocation),
)
items.append(ProjectListResponse(**item))
return ResponseModel(
data=PaginatedData(
items=items,
total=total,
page=page,
page_size=page_size,
total_pages=(total + page_size - 1) // page_size,
)
)
@router.get("/{project_id}", response_model=ResponseModel[ProjectDetailResponse])
async def get_project(
project_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取项目详情(含成本明细)"""
result = await db.execute(
select(Project).options(
selectinload(Project.category),
selectinload(Project.cost_items),
selectinload(Project.labor_costs).selectinload(ProjectLaborCost.staff_level),
selectinload(Project.cost_summary),
).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
)
# 构建成本明细响应
cost_items = []
for item in project.cost_items:
item_name = None
unit = None
if item.item_type == CostItemType.MATERIAL.value:
material_result = await db.execute(
select(Material).where(Material.id == item.item_id)
)
material = material_result.scalar_one_or_none()
if material:
item_name = material.material_name
unit = material.unit
else:
equipment_result = await db.execute(
select(Equipment).where(Equipment.id == item.item_id)
)
equipment = equipment_result.scalar_one_or_none()
if equipment:
item_name = equipment.equipment_name
unit = ""
cost_items.append(CostItemResponse(
id=item.id,
item_type=CostItemType(item.item_type),
item_id=item.item_id,
item_name=item_name,
quantity=float(item.quantity),
unit=unit,
unit_cost=float(item.unit_cost),
total_cost=float(item.total_cost),
remark=item.remark,
created_at=item.created_at,
updated_at=item.updated_at,
))
# 构建人工成本响应
labor_costs = []
for item in project.labor_costs:
labor_costs.append(LaborCostResponse(
id=item.id,
staff_level_id=item.staff_level_id,
level_name=item.staff_level.level_name if item.staff_level else None,
duration_minutes=item.duration_minutes,
hourly_rate=float(item.hourly_rate),
labor_cost=float(item.labor_cost),
remark=item.remark,
created_at=item.created_at,
updated_at=item.updated_at,
))
# 构建成本汇总
cost_summary = None
if project.cost_summary:
cost_summary = CostSummaryResponse(
project_id=project.id,
material_cost=float(project.cost_summary.material_cost),
equipment_cost=float(project.cost_summary.equipment_cost),
labor_cost=float(project.cost_summary.labor_cost),
fixed_cost_allocation=float(project.cost_summary.fixed_cost_allocation),
total_cost=float(project.cost_summary.total_cost),
calculated_at=project.cost_summary.calculated_at,
)
return ResponseModel(
data=ProjectDetailResponse(
id=project.id,
project_code=project.project_code,
project_name=project.project_name,
category_id=project.category_id,
category_name=project.category.category_name if project.category else None,
description=project.description,
duration_minutes=project.duration_minutes,
is_active=project.is_active,
cost_items=cost_items,
labor_costs=labor_costs,
cost_summary=cost_summary,
created_at=project.created_at,
updated_at=project.updated_at,
)
)
@router.post("", response_model=ResponseModel[ProjectResponse])
async def create_project(
data: ProjectCreate,
db: AsyncSession = Depends(get_db),
):
"""创建服务项目"""
# 检查编码是否已存在
existing = await db.execute(
select(Project).where(Project.project_code == data.project_code)
)
if existing.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "项目编码已存在"}
)
# 检查分类是否存在
if data.category_id:
category_result = await db.execute(
select(Category).where(Category.id == data.category_id)
)
if not category_result.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
)
project = Project(**data.model_dump())
db.add(project)
await db.flush()
await db.refresh(project)
# 获取分类名称
category_name = None
if project.category_id:
result = await db.execute(
select(Category).where(Category.id == project.category_id)
)
category = result.scalar_one_or_none()
if category:
category_name = category.category_name
return ResponseModel(
message="创建成功",
data=ProjectResponse(
id=project.id,
project_code=project.project_code,
project_name=project.project_name,
category_id=project.category_id,
category_name=category_name,
description=project.description,
duration_minutes=project.duration_minutes,
is_active=project.is_active,
cost_summary=None,
created_at=project.created_at,
updated_at=project.updated_at,
)
)
@router.put("/{project_id}", response_model=ResponseModel[ProjectResponse])
async def update_project(
project_id: int,
data: ProjectUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新服务项目"""
result = await db.execute(
select(Project).options(
selectinload(Project.category),
selectinload(Project.cost_summary),
).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
)
# 检查编码是否重复
if data.project_code and data.project_code != project.project_code:
existing = await db.execute(
select(Project).where(Project.project_code == data.project_code)
)
if existing.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "项目编码已存在"}
)
# 检查分类是否存在
if data.category_id:
category_result = await db.execute(
select(Category).where(Category.id == data.category_id)
)
if not category_result.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.NOT_FOUND, "message": "分类不存在"}
)
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(project, field, value)
await db.flush()
await db.refresh(project)
# 获取分类名称
category_name = None
if project.category_id:
cat_result = await db.execute(
select(Category).where(Category.id == project.category_id)
)
category = cat_result.scalar_one_or_none()
if category:
category_name = category.category_name
# 构建成本汇总
cost_summary = None
if project.cost_summary:
cost_summary = CostSummaryBrief(
total_cost=float(project.cost_summary.total_cost),
material_cost=float(project.cost_summary.material_cost),
equipment_cost=float(project.cost_summary.equipment_cost),
labor_cost=float(project.cost_summary.labor_cost),
fixed_cost_allocation=float(project.cost_summary.fixed_cost_allocation),
)
return ResponseModel(
message="更新成功",
data=ProjectResponse(
id=project.id,
project_code=project.project_code,
project_name=project.project_name,
category_id=project.category_id,
category_name=category_name,
description=project.description,
duration_minutes=project.duration_minutes,
is_active=project.is_active,
cost_summary=cost_summary,
created_at=project.created_at,
updated_at=project.updated_at,
)
)
@router.delete("/{project_id}", response_model=ResponseModel)
async def delete_project(
project_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除服务项目"""
result = await db.execute(select(Project).where(Project.id == project_id))
project = result.scalar_one_or_none()
if not project:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
)
await db.delete(project)
return ResponseModel(message="删除成功")
# ============ 成本明细(耗材/设备)管理 ============
@router.get("/{project_id}/cost-items", response_model=ResponseModel[list[CostItemResponse]])
async def get_cost_items(
project_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取项目成本明细"""
# 检查项目是否存在
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
if not project_result.scalar_one_or_none():
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
)
result = await db.execute(
select(ProjectCostItem).where(
ProjectCostItem.project_id == project_id
).order_by(ProjectCostItem.id)
)
items = result.scalars().all()
response_items = []
for item in items:
item_name = None
unit = None
if item.item_type == CostItemType.MATERIAL.value:
material_result = await db.execute(
select(Material).where(Material.id == item.item_id)
)
material = material_result.scalar_one_or_none()
if material:
item_name = material.material_name
unit = material.unit
else:
equipment_result = await db.execute(
select(Equipment).where(Equipment.id == item.item_id)
)
equipment = equipment_result.scalar_one_or_none()
if equipment:
item_name = equipment.equipment_name
unit = ""
response_items.append(CostItemResponse(
id=item.id,
item_type=CostItemType(item.item_type),
item_id=item.item_id,
item_name=item_name,
quantity=float(item.quantity),
unit=unit,
unit_cost=float(item.unit_cost),
total_cost=float(item.total_cost),
remark=item.remark,
created_at=item.created_at,
updated_at=item.updated_at,
))
return ResponseModel(data=response_items)
@router.post("/{project_id}/cost-items", response_model=ResponseModel[CostItemResponse])
async def create_cost_item(
project_id: int,
data: CostItemCreate,
db: AsyncSession = Depends(get_db),
):
"""添加成本明细"""
# 检查项目是否存在
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
if not project_result.scalar_one_or_none():
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
)
cost_service = CostService(db)
try:
cost_item = await cost_service.add_cost_item(
project_id=project_id,
item_type=data.item_type,
item_id=data.item_id,
quantity=data.quantity,
remark=data.remark,
)
except ValueError as e:
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.NOT_FOUND, "message": str(e)}
)
# 获取物品名称
item_name = None
unit = None
if data.item_type == CostItemType.MATERIAL:
material_result = await db.execute(
select(Material).where(Material.id == data.item_id)
)
material = material_result.scalar_one_or_none()
if material:
item_name = material.material_name
unit = material.unit
else:
equipment_result = await db.execute(
select(Equipment).where(Equipment.id == data.item_id)
)
equipment = equipment_result.scalar_one_or_none()
if equipment:
item_name = equipment.equipment_name
unit = ""
return ResponseModel(
message="添加成功",
data=CostItemResponse(
id=cost_item.id,
item_type=CostItemType(cost_item.item_type),
item_id=cost_item.item_id,
item_name=item_name,
quantity=float(cost_item.quantity),
unit=unit,
unit_cost=float(cost_item.unit_cost),
total_cost=float(cost_item.total_cost),
remark=cost_item.remark,
created_at=cost_item.created_at,
updated_at=cost_item.updated_at,
)
)
@router.put("/{project_id}/cost-items/{item_id}", response_model=ResponseModel[CostItemResponse])
async def update_cost_item(
project_id: int,
item_id: int,
data: CostItemUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新成本明细"""
result = await db.execute(
select(ProjectCostItem).where(
ProjectCostItem.id == item_id,
ProjectCostItem.project_id == project_id,
)
)
cost_item = result.scalar_one_or_none()
if not cost_item:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "成本明细不存在"}
)
cost_service = CostService(db)
cost_item = await cost_service.update_cost_item(
cost_item=cost_item,
quantity=data.quantity,
remark=data.remark,
)
# 获取物品名称
item_name = None
unit = None
if cost_item.item_type == CostItemType.MATERIAL.value:
material_result = await db.execute(
select(Material).where(Material.id == cost_item.item_id)
)
material = material_result.scalar_one_or_none()
if material:
item_name = material.material_name
unit = material.unit
else:
equipment_result = await db.execute(
select(Equipment).where(Equipment.id == cost_item.item_id)
)
equipment = equipment_result.scalar_one_or_none()
if equipment:
item_name = equipment.equipment_name
unit = ""
return ResponseModel(
message="更新成功",
data=CostItemResponse(
id=cost_item.id,
item_type=CostItemType(cost_item.item_type),
item_id=cost_item.item_id,
item_name=item_name,
quantity=float(cost_item.quantity),
unit=unit,
unit_cost=float(cost_item.unit_cost),
total_cost=float(cost_item.total_cost),
remark=cost_item.remark,
created_at=cost_item.created_at,
updated_at=cost_item.updated_at,
)
)
@router.delete("/{project_id}/cost-items/{item_id}", response_model=ResponseModel)
async def delete_cost_item(
project_id: int,
item_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除成本明细"""
result = await db.execute(
select(ProjectCostItem).where(
ProjectCostItem.id == item_id,
ProjectCostItem.project_id == project_id,
)
)
cost_item = result.scalar_one_or_none()
if not cost_item:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "成本明细不存在"}
)
await db.delete(cost_item)
return ResponseModel(message="删除成功")
# ============ 人工成本管理 ============
@router.get("/{project_id}/labor-costs", response_model=ResponseModel[list[LaborCostResponse]])
async def get_labor_costs(
project_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取项目人工成本"""
# 检查项目是否存在
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
if not project_result.scalar_one_or_none():
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
)
result = await db.execute(
select(ProjectLaborCost).options(
selectinload(ProjectLaborCost.staff_level)
).where(
ProjectLaborCost.project_id == project_id
).order_by(ProjectLaborCost.id)
)
items = result.scalars().all()
response_items = []
for item in items:
response_items.append(LaborCostResponse(
id=item.id,
staff_level_id=item.staff_level_id,
level_name=item.staff_level.level_name if item.staff_level else None,
duration_minutes=item.duration_minutes,
hourly_rate=float(item.hourly_rate),
labor_cost=float(item.labor_cost),
remark=item.remark,
created_at=item.created_at,
updated_at=item.updated_at,
))
return ResponseModel(data=response_items)
@router.post("/{project_id}/labor-costs", response_model=ResponseModel[LaborCostResponse])
async def create_labor_cost(
project_id: int,
data: LaborCostCreate,
db: AsyncSession = Depends(get_db),
):
"""添加人工成本"""
# 检查项目是否存在
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
if not project_result.scalar_one_or_none():
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
)
cost_service = CostService(db)
try:
labor_cost = await cost_service.add_labor_cost(
project_id=project_id,
staff_level_id=data.staff_level_id,
duration_minutes=data.duration_minutes,
remark=data.remark,
)
except ValueError as e:
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.NOT_FOUND, "message": str(e)}
)
# 获取级别名称
level_result = await db.execute(
select(StaffLevel).where(StaffLevel.id == data.staff_level_id)
)
staff_level = level_result.scalar_one_or_none()
return ResponseModel(
message="添加成功",
data=LaborCostResponse(
id=labor_cost.id,
staff_level_id=labor_cost.staff_level_id,
level_name=staff_level.level_name if staff_level else None,
duration_minutes=labor_cost.duration_minutes,
hourly_rate=float(labor_cost.hourly_rate),
labor_cost=float(labor_cost.labor_cost),
remark=labor_cost.remark,
created_at=labor_cost.created_at,
updated_at=labor_cost.updated_at,
)
)
@router.put("/{project_id}/labor-costs/{item_id}", response_model=ResponseModel[LaborCostResponse])
async def update_labor_cost(
project_id: int,
item_id: int,
data: LaborCostUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新人工成本"""
result = await db.execute(
select(ProjectLaborCost).where(
ProjectLaborCost.id == item_id,
ProjectLaborCost.project_id == project_id,
)
)
labor_item = result.scalar_one_or_none()
if not labor_item:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "人工成本记录不存在"}
)
cost_service = CostService(db)
try:
labor_item = await cost_service.update_labor_cost(
labor_item=labor_item,
staff_level_id=data.staff_level_id,
duration_minutes=data.duration_minutes,
remark=data.remark,
)
except ValueError as e:
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.NOT_FOUND, "message": str(e)}
)
# 获取级别名称
level_result = await db.execute(
select(StaffLevel).where(StaffLevel.id == labor_item.staff_level_id)
)
staff_level = level_result.scalar_one_or_none()
return ResponseModel(
message="更新成功",
data=LaborCostResponse(
id=labor_item.id,
staff_level_id=labor_item.staff_level_id,
level_name=staff_level.level_name if staff_level else None,
duration_minutes=labor_item.duration_minutes,
hourly_rate=float(labor_item.hourly_rate),
labor_cost=float(labor_item.labor_cost),
remark=labor_item.remark,
created_at=labor_item.created_at,
updated_at=labor_item.updated_at,
)
)
@router.delete("/{project_id}/labor-costs/{item_id}", response_model=ResponseModel)
async def delete_labor_cost(
project_id: int,
item_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除人工成本"""
result = await db.execute(
select(ProjectLaborCost).where(
ProjectLaborCost.id == item_id,
ProjectLaborCost.project_id == project_id,
)
)
labor_item = result.scalar_one_or_none()
if not labor_item:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "人工成本记录不存在"}
)
await db.delete(labor_item)
return ResponseModel(message="删除成功")
# ============ 成本计算 ============
@router.post("/{project_id}/calculate-cost", response_model=ResponseModel[CostCalculationResult])
async def calculate_cost(
project_id: int,
data: CalculateCostRequest = CalculateCostRequest(),
db: AsyncSession = Depends(get_db),
):
"""计算项目总成本"""
cost_service = CostService(db)
try:
result = await cost_service.calculate_project_cost(
project_id=project_id,
allocation_method=data.fixed_cost_allocation_method,
)
except ValueError as e:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": str(e)}
)
return ResponseModel(message="计算完成", data=result)
@router.get("/{project_id}/cost-summary", response_model=ResponseModel[CostSummaryResponse])
async def get_cost_summary(
project_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取成本汇总"""
# 检查项目是否存在
project_result = await db.execute(
select(Project).where(Project.id == project_id)
)
if not project_result.scalar_one_or_none():
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "项目不存在"}
)
result = await db.execute(
select(ProjectCostSummary).where(
ProjectCostSummary.project_id == project_id
)
)
summary = result.scalar_one_or_none()
if not summary:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "成本汇总不存在,请先计算成本"}
)
return ResponseModel(
data=CostSummaryResponse(
project_id=summary.project_id,
material_cost=float(summary.material_cost),
equipment_cost=float(summary.equipment_cost),
labor_cost=float(summary.labor_cost),
fixed_cost_allocation=float(summary.fixed_cost_allocation),
total_cost=float(summary.total_cost),
calculated_at=summary.calculated_at,
)
)

View File

@@ -0,0 +1,172 @@
"""人员级别路由
实现人员级别的 CRUD 操作
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.staff_level import StaffLevel
from app.schemas.common import ResponseModel, PaginatedData, ErrorCode
from app.schemas.staff_level import (
StaffLevelCreate,
StaffLevelUpdate,
StaffLevelResponse,
)
router = APIRouter()
@router.get("", response_model=ResponseModel[PaginatedData[StaffLevelResponse]])
async def get_staff_levels(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
keyword: Optional[str] = Query(None, description="关键词搜索"),
is_active: Optional[bool] = Query(None, description="是否启用筛选"),
db: AsyncSession = Depends(get_db),
):
"""获取人员级别列表"""
query = select(StaffLevel)
if keyword:
query = query.where(
StaffLevel.level_name.contains(keyword) |
StaffLevel.level_code.contains(keyword)
)
if is_active is not None:
query = query.where(StaffLevel.is_active == is_active)
query = query.order_by(StaffLevel.hourly_rate)
# 分页
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size)
result = await db.execute(query)
staff_levels = result.scalars().all()
# 统计总数
count_query = select(func.count(StaffLevel.id))
if keyword:
count_query = count_query.where(
StaffLevel.level_name.contains(keyword) |
StaffLevel.level_code.contains(keyword)
)
if is_active is not None:
count_query = count_query.where(StaffLevel.is_active == is_active)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
return ResponseModel(
data=PaginatedData(
items=[StaffLevelResponse.model_validate(s) for s in staff_levels],
total=total,
page=page,
page_size=page_size,
total_pages=(total + page_size - 1) // page_size,
)
)
@router.get("/{staff_level_id}", response_model=ResponseModel[StaffLevelResponse])
async def get_staff_level(
staff_level_id: int,
db: AsyncSession = Depends(get_db),
):
"""获取单个人员级别详情"""
result = await db.execute(select(StaffLevel).where(StaffLevel.id == staff_level_id))
staff_level = result.scalar_one_or_none()
if not staff_level:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "人员级别不存在"}
)
return ResponseModel(data=StaffLevelResponse.model_validate(staff_level))
@router.post("", response_model=ResponseModel[StaffLevelResponse])
async def create_staff_level(
data: StaffLevelCreate,
db: AsyncSession = Depends(get_db),
):
"""创建人员级别"""
# 检查编码是否已存在
existing = await db.execute(
select(StaffLevel).where(StaffLevel.level_code == data.level_code)
)
if existing.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "级别编码已存在"}
)
staff_level = StaffLevel(**data.model_dump())
db.add(staff_level)
await db.flush()
await db.refresh(staff_level)
return ResponseModel(message="创建成功", data=StaffLevelResponse.model_validate(staff_level))
@router.put("/{staff_level_id}", response_model=ResponseModel[StaffLevelResponse])
async def update_staff_level(
staff_level_id: int,
data: StaffLevelUpdate,
db: AsyncSession = Depends(get_db),
):
"""更新人员级别"""
result = await db.execute(select(StaffLevel).where(StaffLevel.id == staff_level_id))
staff_level = result.scalar_one_or_none()
if not staff_level:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "人员级别不存在"}
)
# 检查编码是否重复
if data.level_code and data.level_code != staff_level.level_code:
existing = await db.execute(
select(StaffLevel).where(StaffLevel.level_code == data.level_code)
)
if existing.scalar_one_or_none():
raise HTTPException(
status_code=400,
detail={"code": ErrorCode.ALREADY_EXISTS, "message": "级别编码已存在"}
)
update_data = data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(staff_level, field, value)
await db.flush()
await db.refresh(staff_level)
return ResponseModel(message="更新成功", data=StaffLevelResponse.model_validate(staff_level))
@router.delete("/{staff_level_id}", response_model=ResponseModel)
async def delete_staff_level(
staff_level_id: int,
db: AsyncSession = Depends(get_db),
):
"""删除人员级别"""
result = await db.execute(select(StaffLevel).where(StaffLevel.id == staff_level_id))
staff_level = result.scalar_one_or_none()
if not staff_level:
raise HTTPException(
status_code=404,
detail={"code": ErrorCode.NOT_FOUND, "message": "人员级别不存在"}
)
await db.delete(staff_level)
return ResponseModel(message="删除成功")

View File

@@ -0,0 +1,131 @@
"""Pydantic 数据模型"""
from app.schemas.common import ResponseModel, PaginatedResponse
from app.schemas.category import CategoryCreate, CategoryUpdate, CategoryResponse
from app.schemas.material import MaterialCreate, MaterialUpdate, MaterialResponse
from app.schemas.equipment import EquipmentCreate, EquipmentUpdate, EquipmentResponse
from app.schemas.staff_level import StaffLevelCreate, StaffLevelUpdate, StaffLevelResponse
from app.schemas.fixed_cost import FixedCostCreate, FixedCostUpdate, FixedCostResponse
from app.schemas.project import (
ProjectCreate, ProjectUpdate, ProjectResponse,
ProjectListResponse, ProjectQuery
)
from app.schemas.project_cost import (
CostItemCreate, CostItemUpdate, CostItemResponse,
LaborCostCreate, LaborCostUpdate, LaborCostResponse,
CalculateCostRequest, CostCalculationResult, CostSummaryResponse,
ProjectDetailResponse, CostItemType, AllocationMethod
)
from app.schemas.competitor import (
CompetitorCreate, CompetitorUpdate, CompetitorResponse,
CompetitorPriceCreate, CompetitorPriceUpdate, CompetitorPriceResponse,
Positioning, PriceSource
)
from app.schemas.market import (
BenchmarkPriceCreate, BenchmarkPriceUpdate, BenchmarkPriceResponse,
MarketAnalysisRequest, MarketAnalysisResult, MarketAnalysisResponse,
PriceTier
)
from app.schemas.pricing import (
PricingPlanCreate, PricingPlanUpdate, PricingPlanResponse,
PricingPlanListResponse, PricingPlanQuery,
GeneratePricingRequest, GeneratePricingResponse,
SimulateStrategyRequest, SimulateStrategyResponse,
StrategyType
)
from app.schemas.profit import (
ProfitSimulationCreate, ProfitSimulationResponse,
ProfitSimulationListResponse, ProfitSimulationQuery,
SimulateProfitRequest, SimulateProfitResponse,
SensitivityAnalysisRequest, SensitivityAnalysisResponse,
BreakevenRequest, BreakevenResponse,
PeriodType
)
from app.schemas.dashboard import (
DashboardSummaryResponse, CostTrendResponse, MarketTrendResponse,
AIUsageStatsResponse
)
__all__ = [
"ResponseModel",
"PaginatedResponse",
"CategoryCreate",
"CategoryUpdate",
"CategoryResponse",
"MaterialCreate",
"MaterialUpdate",
"MaterialResponse",
"EquipmentCreate",
"EquipmentUpdate",
"EquipmentResponse",
"StaffLevelCreate",
"StaffLevelUpdate",
"StaffLevelResponse",
"FixedCostCreate",
"FixedCostUpdate",
"FixedCostResponse",
# Project schemas
"ProjectCreate",
"ProjectUpdate",
"ProjectResponse",
"ProjectListResponse",
"ProjectQuery",
# Project cost schemas
"CostItemCreate",
"CostItemUpdate",
"CostItemResponse",
"LaborCostCreate",
"LaborCostUpdate",
"LaborCostResponse",
"CalculateCostRequest",
"CostCalculationResult",
"CostSummaryResponse",
"ProjectDetailResponse",
"CostItemType",
"AllocationMethod",
# Competitor schemas
"CompetitorCreate",
"CompetitorUpdate",
"CompetitorResponse",
"CompetitorPriceCreate",
"CompetitorPriceUpdate",
"CompetitorPriceResponse",
"Positioning",
"PriceSource",
# Market schemas
"BenchmarkPriceCreate",
"BenchmarkPriceUpdate",
"BenchmarkPriceResponse",
"MarketAnalysisRequest",
"MarketAnalysisResult",
"MarketAnalysisResponse",
"PriceTier",
# Pricing schemas
"PricingPlanCreate",
"PricingPlanUpdate",
"PricingPlanResponse",
"PricingPlanListResponse",
"PricingPlanQuery",
"GeneratePricingRequest",
"GeneratePricingResponse",
"SimulateStrategyRequest",
"SimulateStrategyResponse",
"StrategyType",
# Profit simulation schemas
"ProfitSimulationCreate",
"ProfitSimulationResponse",
"ProfitSimulationListResponse",
"ProfitSimulationQuery",
"SimulateProfitRequest",
"SimulateProfitResponse",
"SensitivityAnalysisRequest",
"SensitivityAnalysisResponse",
"BreakevenRequest",
"BreakevenResponse",
"PeriodType",
# Dashboard schemas
"DashboardSummaryResponse",
"CostTrendResponse",
"MarketTrendResponse",
"AIUsageStatsResponse",
]

View File

@@ -0,0 +1,53 @@
"""项目分类 Schema"""
from typing import Optional, List
from datetime import datetime
from pydantic import BaseModel, Field
class CategoryBase(BaseModel):
"""分类基础字段"""
category_name: str = Field(..., min_length=1, max_length=50, description="分类名称")
parent_id: Optional[int] = Field(None, description="父分类ID")
sort_order: int = Field(0, ge=0, description="排序")
is_active: bool = Field(True, description="是否启用")
class CategoryCreate(CategoryBase):
"""创建分类请求"""
pass
class CategoryUpdate(BaseModel):
"""更新分类请求"""
category_name: Optional[str] = Field(None, min_length=1, max_length=50, description="分类名称")
parent_id: Optional[int] = Field(None, description="父分类ID")
sort_order: Optional[int] = Field(None, ge=0, description="排序")
is_active: Optional[bool] = Field(None, description="是否启用")
class CategoryResponse(CategoryBase):
"""分类响应"""
id: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class CategoryTreeResponse(CategoryResponse):
"""分类树形响应"""
children: List["CategoryTreeResponse"] = []
class Config:
from_attributes = True
# 解决循环引用
CategoryTreeResponse.model_rebuild()

View File

@@ -0,0 +1,60 @@
"""通用响应模型
遵循瑞小美 API 响应格式规范
"""
from typing import Generic, TypeVar, Optional, List
from pydantic import BaseModel
T = TypeVar("T")
class ResponseModel(BaseModel, Generic[T]):
"""统一响应格式"""
code: int = 0
message: str = "success"
data: Optional[T] = None
class PaginatedData(BaseModel, Generic[T]):
"""分页数据"""
items: List[T]
total: int
page: int
page_size: int
total_pages: int
class PaginatedResponse(BaseModel, Generic[T]):
"""分页响应"""
code: int = 0
message: str = "success"
data: Optional[PaginatedData[T]] = None
class ErrorResponse(BaseModel):
"""错误响应"""
code: int
message: str
data: None = None
# 错误码定义
class ErrorCode:
SUCCESS = 0
PARAM_ERROR = 10001
NOT_FOUND = 10002
ALREADY_EXISTS = 10003
NOT_ALLOWED = 10004
AUTH_FAILED = 20001
PERMISSION_DENIED = 20002
TOKEN_EXPIRED = 20003
INTERNAL_ERROR = 30001
SERVICE_UNAVAILABLE = 30002
AI_SERVICE_ERROR = 40001
AI_SERVICE_TIMEOUT = 40002

View File

@@ -0,0 +1,114 @@
"""竞品机构 Schema"""
from typing import Optional
from datetime import datetime, date
from enum import Enum
from pydantic import BaseModel, Field
class Positioning(str, Enum):
"""机构定位枚举"""
HIGH = "high" # 高端
MEDIUM = "medium" # 中端
BUDGET = "budget" # 大众
class PriceSource(str, Enum):
"""价格来源枚举"""
OFFICIAL = "official" # 官网
MEITUAN = "meituan" # 美团
DIANPING = "dianping" # 大众点评
SURVEY = "survey" # 实地调研
# ============ 竞品机构 Schema ============
class CompetitorBase(BaseModel):
"""竞品机构基础字段"""
competitor_name: str = Field(..., min_length=1, max_length=100, description="机构名称")
address: Optional[str] = Field(None, max_length=200, description="地址")
distance_km: Optional[float] = Field(None, ge=0, description="距离(公里)")
positioning: Positioning = Field(Positioning.MEDIUM, description="定位")
contact: Optional[str] = Field(None, max_length=50, description="联系方式")
is_key_competitor: bool = Field(False, description="是否重点关注")
is_active: bool = Field(True, description="是否启用")
class CompetitorCreate(CompetitorBase):
"""创建竞品机构请求"""
pass
class CompetitorUpdate(BaseModel):
"""更新竞品机构请求"""
competitor_name: Optional[str] = Field(None, min_length=1, max_length=100, description="机构名称")
address: Optional[str] = Field(None, max_length=200, description="地址")
distance_km: Optional[float] = Field(None, ge=0, description="距离(公里)")
positioning: Optional[Positioning] = Field(None, description="定位")
contact: Optional[str] = Field(None, max_length=50, description="联系方式")
is_key_competitor: Optional[bool] = Field(None, description="是否重点关注")
is_active: Optional[bool] = Field(None, description="是否启用")
class CompetitorResponse(CompetitorBase):
"""竞品机构响应"""
id: int
price_count: int = Field(0, description="价格记录数")
last_price_update: Optional[date] = Field(None, description="最后价格更新日期")
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# ============ 竞品价格 Schema ============
class CompetitorPriceBase(BaseModel):
"""竞品价格基础字段"""
project_id: Optional[int] = Field(None, description="关联本店项目ID")
project_name: str = Field(..., min_length=1, max_length=100, description="竞品项目名称")
original_price: float = Field(..., gt=0, description="原价")
promo_price: Optional[float] = Field(None, gt=0, description="促销价")
member_price: Optional[float] = Field(None, gt=0, description="会员价")
price_source: PriceSource = Field(..., description="价格来源")
collected_at: date = Field(..., description="采集日期")
remark: Optional[str] = Field(None, max_length=200, description="备注")
class CompetitorPriceCreate(CompetitorPriceBase):
"""创建竞品价格请求"""
pass
class CompetitorPriceUpdate(BaseModel):
"""更新竞品价格请求"""
project_id: Optional[int] = Field(None, description="关联本店项目ID")
project_name: Optional[str] = Field(None, min_length=1, max_length=100, description="竞品项目名称")
original_price: Optional[float] = Field(None, gt=0, description="原价")
promo_price: Optional[float] = Field(None, gt=0, description="促销价")
member_price: Optional[float] = Field(None, gt=0, description="会员价")
price_source: Optional[PriceSource] = Field(None, description="价格来源")
collected_at: Optional[date] = Field(None, description="采集日期")
remark: Optional[str] = Field(None, max_length=200, description="备注")
class CompetitorPriceResponse(CompetitorPriceBase):
"""竞品价格响应"""
id: int
competitor_id: int
competitor_name: Optional[str] = Field(None, description="竞品机构名称")
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True

View File

@@ -0,0 +1,142 @@
"""仪表盘 Schema
仪表盘数据相关的响应模型
"""
from typing import Optional, List, Dict
from datetime import datetime
from pydantic import BaseModel, Field
class ProjectOverview(BaseModel):
"""项目概览"""
total_projects: int = Field(..., description="总项目数")
active_projects: int = Field(..., description="启用项目数")
projects_with_pricing: int = Field(..., description="已定价项目数")
class CostProjectInfo(BaseModel):
"""成本项目信息"""
id: int
name: str
cost: float
class CostOverview(BaseModel):
"""成本概览"""
avg_project_cost: float = Field(..., description="平均项目成本")
highest_cost_project: Optional[CostProjectInfo] = Field(None, description="最高成本项目")
lowest_cost_project: Optional[CostProjectInfo] = Field(None, description="最低成本项目")
class MarketOverview(BaseModel):
"""市场概览"""
competitors_tracked: int = Field(..., description="跟踪竞品数")
price_records_this_month: int = Field(..., description="本月价格记录数")
avg_market_price: Optional[float] = Field(None, description="市场平均价")
class StrategiesDistribution(BaseModel):
"""策略分布"""
traffic: int = Field(0, description="引流款数量")
profit: int = Field(0, description="利润款数量")
premium: int = Field(0, description="高端款数量")
class PricingOverview(BaseModel):
"""定价概览"""
pricing_plans_count: int = Field(..., description="定价方案总数")
avg_target_margin: Optional[float] = Field(None, description="平均目标毛利率")
strategies_distribution: StrategiesDistribution = Field(..., description="策略分布")
class ProviderDistribution(BaseModel):
"""服务商分布"""
primary: int = Field(0, alias="4sapi", description="4sapi 调用次数")
fallback: int = Field(0, alias="openrouter", description="OpenRouter 调用次数")
class Config:
populate_by_name = True
class AIUsageOverview(BaseModel):
"""AI 使用概览"""
total_calls: int = Field(..., description="总调用次数")
total_tokens: int = Field(..., description="总 Token 消耗")
total_cost_usd: float = Field(..., description="总费用(美元)")
provider_distribution: Dict[str, int] = Field(..., description="服务商分布")
class RecentActivity(BaseModel):
"""最近活动"""
type: str = Field(..., description="活动类型")
project_name: str = Field(..., description="项目名称")
user: Optional[str] = Field(None, description="用户")
time: datetime = Field(..., description="时间")
class DashboardSummaryResponse(BaseModel):
"""仪表盘概览响应"""
project_overview: ProjectOverview
cost_overview: CostOverview
market_overview: MarketOverview
pricing_overview: PricingOverview
ai_usage_this_month: Optional[AIUsageOverview] = None
recent_activities: List[RecentActivity] = Field(default_factory=list)
# 趋势数据
class TrendDataPoint(BaseModel):
"""趋势数据点"""
date: str = Field(..., description="日期")
value: float = Field(..., description="")
class CostTrendResponse(BaseModel):
"""成本趋势响应"""
period: str = Field(..., description="统计周期")
data: List[TrendDataPoint]
avg_cost: float = Field(..., description="平均成本")
class MarketTrendResponse(BaseModel):
"""市场趋势响应"""
period: str = Field(..., description="统计周期")
data: List[TrendDataPoint]
avg_price: float = Field(..., description="平均价格")
# AI 使用统计
class AIUsageStatItem(BaseModel):
"""AI 使用统计项"""
date: str
calls: int
tokens: int
cost: float
class AIUsageStatsResponse(BaseModel):
"""AI 使用统计响应"""
period: str = Field(..., description="统计周期")
total_calls: int
total_tokens: int
total_cost: float
daily_stats: List[AIUsageStatItem]

View File

@@ -0,0 +1,54 @@
"""设备 Schema"""
from typing import Optional
from datetime import datetime, date
from pydantic import BaseModel, Field, field_validator
class EquipmentBase(BaseModel):
"""设备基础字段"""
equipment_code: str = Field(..., min_length=1, max_length=50, description="设备编码")
equipment_name: str = Field(..., min_length=1, max_length=100, description="设备名称")
original_value: float = Field(..., gt=0, description="设备原值")
residual_rate: float = Field(5.00, ge=0, le=100, description="残值率(%)")
service_years: int = Field(..., gt=0, description="预计使用年限")
estimated_uses: int = Field(..., gt=0, description="预计使用次数")
purchase_date: Optional[date] = Field(None, description="购入日期")
is_active: bool = Field(True, description="是否启用")
class EquipmentCreate(EquipmentBase):
"""创建设备请求"""
@property
def depreciation_per_use(self) -> float:
"""计算单次折旧成本"""
residual_value = self.original_value * self.residual_rate / 100
return (self.original_value - residual_value) / self.estimated_uses
class EquipmentUpdate(BaseModel):
"""更新设备请求"""
equipment_code: Optional[str] = Field(None, min_length=1, max_length=50, description="设备编码")
equipment_name: Optional[str] = Field(None, min_length=1, max_length=100, description="设备名称")
original_value: Optional[float] = Field(None, gt=0, description="设备原值")
residual_rate: Optional[float] = Field(None, ge=0, le=100, description="残值率(%)")
service_years: Optional[int] = Field(None, gt=0, description="预计使用年限")
estimated_uses: Optional[int] = Field(None, gt=0, description="预计使用次数")
purchase_date: Optional[date] = Field(None, description="购入日期")
is_active: Optional[bool] = Field(None, description="是否启用")
class EquipmentResponse(EquipmentBase):
"""设备响应"""
id: int
depreciation_per_use: float = Field(..., description="单次折旧成本")
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True

View File

@@ -0,0 +1,79 @@
"""固定成本 Schema"""
from typing import Optional
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field, field_validator
import re
class CostType(str, Enum):
"""成本类型枚举"""
RENT = "rent" # 房租
UTILITIES = "utilities" # 水电
PROPERTY = "property" # 物业
OTHER = "other" # 其他
class AllocationMethod(str, Enum):
"""分摊方式枚举"""
COUNT = "count" # 按项目数量
REVENUE = "revenue" # 按营收占比
DURATION = "duration" # 按时长占比
class FixedCostBase(BaseModel):
"""固定成本基础字段"""
cost_name: str = Field(..., min_length=1, max_length=100, description="成本名称")
cost_type: CostType = Field(..., description="类型")
monthly_amount: float = Field(..., gt=0, description="月度金额")
year_month: str = Field(..., description="年月2026-01")
allocation_method: AllocationMethod = Field(AllocationMethod.COUNT, description="分摊方式")
is_active: bool = Field(True, description="是否启用")
@field_validator("year_month")
@classmethod
def validate_year_month(cls, v: str) -> str:
"""验证年月格式"""
if not re.match(r"^\d{4}-\d{2}$", v):
raise ValueError("年月格式必须为 YYYY-MM")
return v
class FixedCostCreate(FixedCostBase):
"""创建固定成本请求"""
pass
class FixedCostUpdate(BaseModel):
"""更新固定成本请求"""
cost_name: Optional[str] = Field(None, min_length=1, max_length=100, description="成本名称")
cost_type: Optional[CostType] = Field(None, description="类型")
monthly_amount: Optional[float] = Field(None, gt=0, description="月度金额")
year_month: Optional[str] = Field(None, description="年月2026-01")
allocation_method: Optional[AllocationMethod] = Field(None, description="分摊方式")
is_active: Optional[bool] = Field(None, description="是否启用")
@field_validator("year_month")
@classmethod
def validate_year_month(cls, v: Optional[str]) -> Optional[str]:
"""验证年月格式"""
if v is not None and not re.match(r"^\d{4}-\d{2}$", v):
raise ValueError("年月格式必须为 YYYY-MM")
return v
class FixedCostResponse(FixedCostBase):
"""固定成本响应"""
id: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True

View File

@@ -0,0 +1,156 @@
"""市场分析 Schema"""
from typing import Optional, List
from datetime import datetime, date
from enum import Enum
from pydantic import BaseModel, Field
class PriceTier(str, Enum):
"""价格带枚举"""
LOW = "low" # 低端
MEDIUM = "medium" # 中端
HIGH = "high" # 高端
PREMIUM = "premium" # 奢华
# ============ 标杆价格 Schema ============
class BenchmarkPriceBase(BaseModel):
"""标杆价格基础字段"""
benchmark_name: str = Field(..., min_length=1, max_length=100, description="标杆机构名称")
category_id: Optional[int] = Field(None, description="项目分类ID")
min_price: float = Field(..., gt=0, description="最低价")
max_price: float = Field(..., gt=0, description="最高价")
avg_price: float = Field(..., gt=0, description="均价")
price_tier: PriceTier = Field(PriceTier.MEDIUM, description="价格带")
effective_date: date = Field(..., description="生效日期")
remark: Optional[str] = Field(None, max_length=200, description="备注")
class BenchmarkPriceCreate(BenchmarkPriceBase):
"""创建标杆价格请求"""
pass
class BenchmarkPriceUpdate(BaseModel):
"""更新标杆价格请求"""
benchmark_name: Optional[str] = Field(None, min_length=1, max_length=100, description="标杆机构名称")
category_id: Optional[int] = Field(None, description="项目分类ID")
min_price: Optional[float] = Field(None, gt=0, description="最低价")
max_price: Optional[float] = Field(None, gt=0, description="最高价")
avg_price: Optional[float] = Field(None, gt=0, description="均价")
price_tier: Optional[PriceTier] = Field(None, description="价格带")
effective_date: Optional[date] = Field(None, description="生效日期")
remark: Optional[str] = Field(None, max_length=200, description="备注")
class BenchmarkPriceResponse(BenchmarkPriceBase):
"""标杆价格响应"""
id: int
category_name: Optional[str] = Field(None, description="分类名称")
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# ============ 市场分析 Schema ============
class MarketAnalysisRequest(BaseModel):
"""市场分析请求"""
competitor_ids: Optional[List[int]] = Field(None, description="指定竞品机构ID列表")
include_benchmark: bool = Field(True, description="是否包含标杆价格参考")
class PriceStatistics(BaseModel):
"""价格统计"""
min_price: float
max_price: float
avg_price: float
median_price: float
std_deviation: Optional[float] = None
class PriceDistributionItem(BaseModel):
"""价格分布项"""
range: str
count: int
percentage: float
class PriceDistribution(BaseModel):
"""价格分布"""
low: PriceDistributionItem
medium: PriceDistributionItem
high: PriceDistributionItem
class CompetitorPriceSummary(BaseModel):
"""竞品价格摘要"""
competitor_name: str
positioning: str
original_price: float
promo_price: Optional[float] = None
collected_at: date
class BenchmarkReference(BaseModel):
"""标杆参考"""
tier: str
min_price: float
max_price: float
avg_price: float
class SuggestedRange(BaseModel):
"""建议定价区间"""
min: float
max: float
recommended: float
class MarketAnalysisResult(BaseModel):
"""市场分析结果"""
project_id: int
project_name: str
analysis_date: date
competitor_count: int
price_statistics: PriceStatistics
price_distribution: Optional[PriceDistribution] = None
competitor_prices: List[CompetitorPriceSummary] = []
benchmark_reference: Optional[BenchmarkReference] = None
suggested_range: SuggestedRange
class MarketAnalysisResponse(BaseModel):
"""市场分析响应(数据库记录)"""
id: int
project_id: int
analysis_date: date
competitor_count: int
market_min_price: float
market_max_price: float
market_avg_price: float
market_median_price: float
suggested_range_min: float
suggested_range_max: float
created_at: datetime
class Config:
from_attributes = True

View File

@@ -0,0 +1,64 @@
"""耗材 Schema"""
from typing import Optional
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field
class MaterialType(str, Enum):
"""耗材类型枚举"""
CONSUMABLE = "consumable" # 一般耗材
INJECTABLE = "injectable" # 针剂
PRODUCT = "product" # 产品
class MaterialBase(BaseModel):
"""耗材基础字段"""
material_code: str = Field(..., min_length=1, max_length=50, description="耗材编码")
material_name: str = Field(..., min_length=1, max_length=100, description="耗材名称")
unit: str = Field(..., min_length=1, max_length=20, description="单位")
unit_price: float = Field(..., ge=0, description="单价")
supplier: Optional[str] = Field(None, max_length=100, description="供应商")
material_type: MaterialType = Field(..., description="类型")
is_active: bool = Field(True, description="是否启用")
class MaterialCreate(MaterialBase):
"""创建耗材请求"""
pass
class MaterialUpdate(BaseModel):
"""更新耗材请求"""
material_code: Optional[str] = Field(None, min_length=1, max_length=50, description="耗材编码")
material_name: Optional[str] = Field(None, min_length=1, max_length=100, description="耗材名称")
unit: Optional[str] = Field(None, min_length=1, max_length=20, description="单位")
unit_price: Optional[float] = Field(None, ge=0, description="单价")
supplier: Optional[str] = Field(None, max_length=100, description="供应商")
material_type: Optional[MaterialType] = Field(None, description="类型")
is_active: Optional[bool] = Field(None, description="是否启用")
class MaterialResponse(MaterialBase):
"""耗材响应"""
id: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class MaterialImportResult(BaseModel):
"""批量导入结果"""
total: int = Field(..., description="总数")
success: int = Field(..., description="成功数")
failed: int = Field(..., description="失败数")
errors: list[dict] = Field(default_factory=list, description="错误详情")

View File

@@ -0,0 +1,194 @@
"""定价方案 Schema
智能定价建议相关的请求和响应模型
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field
class StrategyType(str, Enum):
"""定价策略类型"""
TRAFFIC = "traffic" # 引流款
PROFIT = "profit" # 利润款
PREMIUM = "premium" # 高端款
class PricingPlanBase(BaseModel):
"""定价方案基础字段"""
project_id: int = Field(..., description="项目ID")
plan_name: str = Field(..., min_length=1, max_length=100, description="方案名称")
strategy_type: StrategyType = Field(..., description="策略类型")
target_margin: float = Field(..., ge=0, le=100, description="目标毛利率(%)")
class PricingPlanCreate(PricingPlanBase):
"""创建定价方案请求"""
pass
class PricingPlanUpdate(BaseModel):
"""更新定价方案请求"""
plan_name: Optional[str] = Field(None, min_length=1, max_length=100, description="方案名称")
strategy_type: Optional[StrategyType] = Field(None, description="策略类型")
target_margin: Optional[float] = Field(None, ge=0, le=100, description="目标毛利率(%)")
final_price: Optional[float] = Field(None, ge=0, description="最终定价")
is_active: Optional[bool] = Field(None, description="是否启用")
class PricingPlanResponse(BaseModel):
"""定价方案响应"""
id: int
project_id: int
project_name: Optional[str] = Field(None, description="项目名称")
plan_name: str
strategy_type: str
base_cost: float
target_margin: float
suggested_price: float
final_price: Optional[float] = None
ai_advice: Optional[str] = None
is_active: bool
created_at: datetime
updated_at: datetime
created_by_name: Optional[str] = Field(None, description="创建人姓名")
class Config:
from_attributes = True
class PricingPlanListResponse(BaseModel):
"""定价方案列表响应"""
id: int
project_id: int
project_name: Optional[str] = None
plan_name: str
strategy_type: str
base_cost: float
target_margin: float
suggested_price: float
final_price: Optional[float] = None
is_active: bool
created_at: datetime
created_by_name: Optional[str] = None
class Config:
from_attributes = True
class PricingPlanQuery(BaseModel):
"""定价方案查询参数"""
page: int = Field(1, ge=1, description="页码")
page_size: int = Field(20, ge=1, le=100, description="每页数量")
project_id: Optional[int] = Field(None, description="项目筛选")
strategy_type: Optional[StrategyType] = Field(None, description="策略类型筛选")
is_active: Optional[bool] = Field(None, description="是否启用")
sort_by: str = Field("created_at", description="排序字段")
sort_order: str = Field("desc", description="排序方向")
# AI 定价建议相关
class GeneratePricingRequest(BaseModel):
"""生成定价建议请求"""
target_margin: float = Field(50, ge=0, le=100, description="目标毛利率(%)默认50")
strategies: Optional[List[StrategyType]] = Field(None, description="策略类型列表,默认全部")
stream: bool = Field(False, description="是否流式返回")
class StrategySuggestion(BaseModel):
"""单个策略的定价建议"""
strategy: str = Field(..., description="策略名称")
suggested_price: float = Field(..., description="建议价格")
margin: float = Field(..., description="毛利率(%)")
description: str = Field(..., description="策略说明")
class AIAdvice(BaseModel):
"""AI 建议内容"""
summary: str = Field(..., description="综合建议摘要")
cost_analysis: str = Field(..., description="成本分析")
market_analysis: str = Field(..., description="市场分析")
risk_notes: str = Field(..., description="风险提示")
recommendations: List[str] = Field(..., description="具体建议列表")
class AIUsage(BaseModel):
"""AI 调用使用情况"""
provider: str = Field(..., description="服务商")
model: str = Field(..., description="模型")
tokens: int = Field(..., description="Token 消耗")
latency_ms: int = Field(..., description="延迟(毫秒)")
class MarketReference(BaseModel):
"""市场参考数据"""
min: float
max: float
avg: float
class PricingSuggestions(BaseModel):
"""定价建议集合"""
traffic: Optional[StrategySuggestion] = None
profit: Optional[StrategySuggestion] = None
premium: Optional[StrategySuggestion] = None
class GeneratePricingResponse(BaseModel):
"""生成定价建议响应"""
project_id: int
project_name: str
cost_base: float = Field(..., description="基础成本")
market_reference: Optional[MarketReference] = Field(None, description="市场参考")
pricing_suggestions: PricingSuggestions = Field(..., description="各策略定价建议")
ai_advice: Optional[AIAdvice] = Field(None, description="AI 建议详情")
ai_usage: Optional[AIUsage] = Field(None, description="AI 使用统计")
class SimulateStrategyRequest(BaseModel):
"""模拟定价策略请求"""
strategies: List[StrategyType] = Field(..., description="要模拟的策略类型")
target_margin: float = Field(50, ge=0, le=100, description="目标毛利率(%)")
class StrategySimulationResult(BaseModel):
"""策略模拟结果"""
strategy_type: str
strategy_name: str
suggested_price: float
margin: float
profit_per_unit: float
market_position: str = Field(..., description="市场位置描述")
class SimulateStrategyResponse(BaseModel):
"""模拟定价策略响应"""
project_id: int
project_name: str
base_cost: float
results: List[StrategySimulationResult]
class ExportReportRequest(BaseModel):
"""导出报告请求"""
format: str = Field("pdf", description="导出格式pdf/excel")

View File

@@ -0,0 +1,194 @@
"""利润模拟 Schema
利润模拟测算相关的请求和响应模型
"""
from typing import Optional, List
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field
class PeriodType(str, Enum):
"""周期类型"""
DAILY = "daily" # 日
WEEKLY = "weekly" # 周
MONTHLY = "monthly" # 月
# 利润模拟相关
class ProfitSimulationBase(BaseModel):
"""利润模拟基础字段"""
pricing_plan_id: int = Field(..., description="定价方案ID")
simulation_name: str = Field(..., min_length=1, max_length=100, description="模拟名称")
price: float = Field(..., gt=0, description="模拟价格")
estimated_volume: int = Field(..., gt=0, description="预估客量")
period_type: PeriodType = Field(..., description="周期类型")
class ProfitSimulationCreate(ProfitSimulationBase):
"""创建利润模拟请求"""
pass
class SimulateProfitRequest(BaseModel):
"""执行利润模拟请求"""
price: float = Field(..., gt=0, description="模拟价格")
estimated_volume: int = Field(..., gt=0, description="预估客量")
period_type: PeriodType = Field(PeriodType.MONTHLY, description="周期类型")
class SimulationInput(BaseModel):
"""模拟输入参数"""
price: float
cost_per_unit: float
estimated_volume: int
period_type: str
class SimulationResult(BaseModel):
"""模拟计算结果"""
estimated_revenue: float = Field(..., description="预估收入")
estimated_cost: float = Field(..., description="预估成本")
estimated_profit: float = Field(..., description="预估利润")
profit_margin: float = Field(..., description="利润率(%)")
profit_per_unit: float = Field(..., description="单位利润")
class BreakevenAnalysis(BaseModel):
"""盈亏平衡分析"""
breakeven_volume: int = Field(..., description="盈亏平衡客量")
current_volume: int = Field(..., description="当前预估客量")
safety_margin: int = Field(..., description="安全边际(客量)")
safety_margin_percentage: float = Field(..., description="安全边际率(%)")
class SimulateProfitResponse(BaseModel):
"""执行利润模拟响应"""
simulation_id: int
pricing_plan_id: int
project_name: str
input: SimulationInput
result: SimulationResult
breakeven_analysis: BreakevenAnalysis
created_at: datetime
class ProfitSimulationResponse(BaseModel):
"""利润模拟响应"""
id: int
pricing_plan_id: int
plan_name: Optional[str] = None
project_name: Optional[str] = None
simulation_name: str
price: float
estimated_volume: int
period_type: str
estimated_revenue: float
estimated_cost: float
estimated_profit: float
profit_margin: float
breakeven_volume: int
created_at: datetime
created_by_name: Optional[str] = None
class Config:
from_attributes = True
class ProfitSimulationListResponse(BaseModel):
"""利润模拟列表响应"""
id: int
pricing_plan_id: int
plan_name: Optional[str] = None
project_name: Optional[str] = None
simulation_name: str
price: float
estimated_volume: int
period_type: str
estimated_profit: float
profit_margin: float
created_at: datetime
class Config:
from_attributes = True
class ProfitSimulationQuery(BaseModel):
"""利润模拟查询参数"""
page: int = Field(1, ge=1, description="页码")
page_size: int = Field(20, ge=1, le=100, description="每页数量")
pricing_plan_id: Optional[int] = Field(None, description="定价方案筛选")
period_type: Optional[PeriodType] = Field(None, description="周期类型筛选")
sort_by: str = Field("created_at", description="排序字段")
sort_order: str = Field("desc", description="排序方向")
# 敏感性分析相关
class SensitivityAnalysisRequest(BaseModel):
"""敏感性分析请求"""
price_change_rates: List[float] = Field(
default=[-20, -15, -10, -5, 0, 5, 10, 15, 20],
description="价格变动率列表(%)"
)
class SensitivityResultItem(BaseModel):
"""敏感性分析单项结果"""
price_change_rate: float = Field(..., description="价格变动率(%)")
adjusted_price: float = Field(..., description="调整后价格")
adjusted_profit: float = Field(..., description="调整后利润")
profit_change_rate: float = Field(..., description="利润变动率(%)")
class SensitivityInsights(BaseModel):
"""敏感性分析洞察"""
price_elasticity: str = Field(..., description="价格弹性描述")
risk_level: str = Field(..., description="风险等级")
recommendation: str = Field(..., description="建议")
class SensitivityAnalysisResponse(BaseModel):
"""敏感性分析响应"""
simulation_id: int
base_price: float
base_profit: float
sensitivity_results: List[SensitivityResultItem]
insights: Optional[SensitivityInsights] = None
# 盈亏平衡分析
class BreakevenRequest(BaseModel):
"""盈亏平衡分析请求"""
target_profit: Optional[float] = Field(None, description="目标利润(可选)")
class BreakevenResponse(BaseModel):
"""盈亏平衡分析响应"""
pricing_plan_id: int
project_name: str
price: float
unit_cost: float
fixed_cost_monthly: float
breakeven_volume: int = Field(..., description="盈亏平衡客量")
current_margin: float = Field(..., description="当前边际贡献")
target_profit_volume: Optional[int] = Field(None, description="达到目标利润所需客量")

View File

@@ -0,0 +1,84 @@
"""服务项目 Schema"""
from typing import Optional, List
from datetime import datetime
from pydantic import BaseModel, Field
class ProjectBase(BaseModel):
"""项目基础字段"""
project_code: str = Field(..., min_length=1, max_length=50, description="项目编码")
project_name: str = Field(..., min_length=1, max_length=100, description="项目名称")
category_id: Optional[int] = Field(None, description="项目分类ID")
description: Optional[str] = Field(None, description="项目描述")
duration_minutes: int = Field(0, ge=0, description="操作时长(分钟)")
is_active: bool = Field(True, description="是否启用")
class ProjectCreate(ProjectBase):
"""创建项目请求"""
pass
class ProjectUpdate(BaseModel):
"""更新项目请求"""
project_code: Optional[str] = Field(None, min_length=1, max_length=50, description="项目编码")
project_name: Optional[str] = Field(None, min_length=1, max_length=100, description="项目名称")
category_id: Optional[int] = Field(None, description="项目分类ID")
description: Optional[str] = Field(None, description="项目描述")
duration_minutes: Optional[int] = Field(None, ge=0, description="操作时长(分钟)")
is_active: Optional[bool] = Field(None, description="是否启用")
class CostSummaryBrief(BaseModel):
"""成本汇总简要信息"""
total_cost: float = Field(..., description="总成本(最低成本线)")
material_cost: float = Field(..., description="耗材成本")
equipment_cost: float = Field(..., description="设备折旧成本")
labor_cost: float = Field(..., description="人工成本")
fixed_cost_allocation: float = Field(..., description="固定成本分摊")
class Config:
from_attributes = True
class ProjectResponse(ProjectBase):
"""项目响应"""
id: int
category_name: Optional[str] = Field(None, description="分类名称")
cost_summary: Optional[CostSummaryBrief] = Field(None, description="成本汇总")
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ProjectListResponse(ProjectBase):
"""项目列表响应"""
id: int
category_name: Optional[str] = Field(None, description="分类名称")
cost_summary: Optional[CostSummaryBrief] = Field(None, description="成本汇总")
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ProjectQuery(BaseModel):
"""项目查询参数"""
page: int = Field(1, ge=1, description="页码")
page_size: int = Field(20, ge=1, le=100, description="每页数量")
category_id: Optional[int] = Field(None, description="分类筛选")
keyword: Optional[str] = Field(None, description="关键词搜索")
is_active: Optional[bool] = Field(None, description="是否启用")
sort_by: str = Field("created_at", description="排序字段")
sort_order: str = Field("desc", description="排序方向")

View File

@@ -0,0 +1,195 @@
"""项目成本相关 Schema"""
from typing import Optional, List
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field
class CostItemType(str, Enum):
"""成本明细类型枚举"""
MATERIAL = "material" # 耗材
EQUIPMENT = "equipment" # 设备
class AllocationMethod(str, Enum):
"""固定成本分摊方式枚举"""
COUNT = "count" # 按项目数量平均分摊
REVENUE = "revenue" # 按项目营收占比分摊
DURATION = "duration" # 按项目时长占比分摊
# ============ 项目成本明细(耗材/设备Schema ============
class CostItemBase(BaseModel):
"""成本明细基础字段"""
item_type: CostItemType = Field(..., description="类型material/equipment")
item_id: int = Field(..., description="耗材/设备ID")
quantity: float = Field(..., gt=0, description="用量")
remark: Optional[str] = Field(None, max_length=200, description="备注")
class CostItemCreate(CostItemBase):
"""创建成本明细请求"""
pass
class CostItemUpdate(BaseModel):
"""更新成本明细请求"""
quantity: Optional[float] = Field(None, gt=0, description="用量")
remark: Optional[str] = Field(None, max_length=200, description="备注")
class CostItemResponse(BaseModel):
"""成本明细响应"""
id: int
item_type: CostItemType
item_id: int
item_name: Optional[str] = Field(None, description="耗材/设备名称")
quantity: float
unit: Optional[str] = Field(None, description="单位")
unit_cost: float
total_cost: float
remark: Optional[str] = None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# ============ 项目人工成本 Schema ============
class LaborCostBase(BaseModel):
"""人工成本基础字段"""
staff_level_id: int = Field(..., description="人员级别ID")
duration_minutes: int = Field(..., gt=0, description="操作时长(分钟)")
remark: Optional[str] = Field(None, max_length=200, description="备注")
class LaborCostCreate(LaborCostBase):
"""创建人工成本请求"""
pass
class LaborCostUpdate(BaseModel):
"""更新人工成本请求"""
staff_level_id: Optional[int] = Field(None, description="人员级别ID")
duration_minutes: Optional[int] = Field(None, gt=0, description="操作时长(分钟)")
remark: Optional[str] = Field(None, max_length=200, description="备注")
class LaborCostResponse(BaseModel):
"""人工成本响应"""
id: int
staff_level_id: int
level_name: Optional[str] = Field(None, description="级别名称")
duration_minutes: int
hourly_rate: float
labor_cost: float
remark: Optional[str] = None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
# ============ 成本计算相关 Schema ============
class CalculateCostRequest(BaseModel):
"""计算成本请求"""
fixed_cost_allocation_method: AllocationMethod = Field(
AllocationMethod.COUNT,
description="固定成本分摊方式"
)
class CostBreakdownItem(BaseModel):
"""成本明细项"""
name: str
quantity: Optional[float] = None
unit: Optional[str] = None
unit_cost: Optional[float] = None
depreciation_per_use: Optional[float] = None
duration_minutes: Optional[int] = None
hourly_rate: Optional[float] = None
total: float
class CostBreakdown(BaseModel):
"""成本分项"""
items: List[CostBreakdownItem]
subtotal: float
class FixedCostAllocationDetail(BaseModel):
"""固定成本分摊详情"""
method: str
total_fixed_cost: float
project_count: Optional[int] = None
total_revenue: Optional[float] = None
total_duration: Optional[int] = None
allocation: float
class CostCalculationResult(BaseModel):
"""成本计算结果"""
project_id: int
project_name: str
cost_breakdown: dict = Field(..., description="成本分项明细")
total_cost: float = Field(..., description="总成本")
min_price_suggestion: float = Field(..., description="建议最低售价(等于总成本)")
calculated_at: datetime
class CostSummaryResponse(BaseModel):
"""成本汇总响应"""
project_id: int
material_cost: float
equipment_cost: float
labor_cost: float
fixed_cost_allocation: float
total_cost: float
calculated_at: datetime
class Config:
from_attributes = True
# ============ 项目详情含成本Schema ============
class ProjectDetailResponse(BaseModel):
"""项目详情响应(含成本明细)"""
id: int
project_code: str
project_name: str
category_id: Optional[int]
category_name: Optional[str]
description: Optional[str]
duration_minutes: int
is_active: bool
cost_items: List[CostItemResponse] = []
labor_costs: List[LaborCostResponse] = []
cost_summary: Optional[CostSummaryResponse] = None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True

View File

@@ -0,0 +1,40 @@
"""人员级别 Schema"""
from typing import Optional
from datetime import datetime
from pydantic import BaseModel, Field
class StaffLevelBase(BaseModel):
"""人员级别基础字段"""
level_code: str = Field(..., min_length=1, max_length=20, description="级别编码")
level_name: str = Field(..., min_length=1, max_length=50, description="级别名称")
hourly_rate: float = Field(..., gt=0, description="时薪(元/小时)")
is_active: bool = Field(True, description="是否启用")
class StaffLevelCreate(StaffLevelBase):
"""创建人员级别请求"""
pass
class StaffLevelUpdate(BaseModel):
"""更新人员级别请求"""
level_code: Optional[str] = Field(None, min_length=1, max_length=20, description="级别编码")
level_name: Optional[str] = Field(None, min_length=1, max_length=50, description="级别名称")
hourly_rate: Optional[float] = Field(None, gt=0, description="时薪(元/小时)")
is_active: Optional[bool] = Field(None, description="是否启用")
class StaffLevelResponse(StaffLevelBase):
"""人员级别响应"""
id: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True

View File

@@ -0,0 +1,13 @@
"""业务逻辑服务"""
from app.services.cost_service import CostService
from app.services.market_service import MarketService
from app.services.pricing_service import PricingService
from app.services.profit_service import ProfitService
__all__ = [
"CostService",
"MarketService",
"PricingService",
"ProfitService",
]

View File

@@ -0,0 +1,257 @@
"""AI 服务封装
遵循瑞小美 AI 接入规范:
- 通过 shared_backend.AIService 调用
- 初始化时传入 db_session用于日志记录
- 调用时传入 prompt_name用于统计
性能优化:
- 相同输入的响应缓存(减少 API 调用成本)
- 缓存键基于消息内容哈希
"""
import hashlib
import json
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.services.cache_service import get_cache, CacheNamespace
class AIServiceWrapper:
"""AI 服务封装类
封装 shared_backend.AIService 的调用
提供统一的接口供业务层使用
支持响应缓存以减少 API 调用
"""
# 默认缓存 TTL1 小时)- AI 响应通常不会频繁变化
DEFAULT_CACHE_TTL = 3600
def __init__(self, db_session: AsyncSession, enable_cache: bool = True):
"""初始化 AI 服务
Args:
db_session: 数据库会话,用于记录 AI 调用日志
enable_cache: 是否启用响应缓存
"""
self.db_session = db_session
self.module_code = settings.AI_MODULE_CODE
self.enable_cache = enable_cache
self._ai_service = None
self._cache = get_cache(CacheNamespace.AI_RESPONSES, maxsize=100, ttl=self.DEFAULT_CACHE_TTL)
def _generate_cache_key(
self,
messages: List[Dict[str, str]],
prompt_name: str,
model: Optional[str] = None,
) -> str:
"""生成缓存键
基于消息内容和参数生成唯一的缓存键
"""
key_data = {
"messages": messages,
"prompt_name": prompt_name,
"model": model or "default",
}
key_str = json.dumps(key_data, sort_keys=True, ensure_ascii=False)
return f"ai:{hashlib.sha256(key_str.encode()).hexdigest()[:16]}"
async def _get_service(self):
"""获取 AIService 实例(延迟加载)"""
if self._ai_service is None:
try:
from shared_backend.services.ai_service import AIService
self._ai_service = AIService(
module_code=self.module_code,
db_session=self.db_session,
)
except ImportError:
# 开发环境可能没有 shared_backend
# 使用 Mock 实现
self._ai_service = MockAIService(self.module_code)
return self._ai_service
async def chat(
self,
messages: List[Dict[str, str]],
prompt_name: str,
model: Optional[str] = None,
use_cache: bool = True,
cache_ttl: Optional[int] = None,
**kwargs,
):
"""调用 AI 聊天接口(带缓存支持)
Args:
messages: 消息列表
prompt_name: 提示词名称(必填,用于统计)
model: 模型名称,默认使用配置的模型
use_cache: 是否使用缓存
cache_ttl: 缓存 TTL默认 3600
**kwargs: 其他参数
Returns:
AIResponse 对象
"""
# 检查缓存
if self.enable_cache and use_cache:
cache_key = self._generate_cache_key(messages, prompt_name, model)
cached_response = self._cache.get(cache_key)
if cached_response is not None:
# 返回缓存的响应(添加标记)
cached_response.from_cache = True
return cached_response
# 调用 AI 服务
service = await self._get_service()
response = await service.chat(
messages=messages,
prompt_name=prompt_name,
model=model,
**kwargs,
)
# 存入缓存
if self.enable_cache and use_cache and response is not None:
cache_key = self._generate_cache_key(messages, prompt_name, model)
response.from_cache = False
self._cache.set(cache_key, response, cache_ttl or self.DEFAULT_CACHE_TTL)
return response
async def chat_stream(
self,
messages: List[Dict[str, str]],
prompt_name: str,
model: Optional[str] = None,
**kwargs,
):
"""调用 AI 聊天流式接口
注意:流式接口不使用缓存
Args:
messages: 消息列表
prompt_name: 提示词名称(必填,用于统计)
model: 模型名称
**kwargs: 其他参数
Yields:
响应片段
"""
service = await self._get_service()
async for chunk in service.chat_stream(
messages=messages,
prompt_name=prompt_name,
model=model,
**kwargs,
):
yield chunk
def clear_cache(self, prompt_name: Optional[str] = None):
"""清除 AI 响应缓存
Args:
prompt_name: 指定提示词名称清除None 则清除全部
"""
# 简单实现:清除整个缓存
# 更精细的实现可以按 prompt_name 过滤
self._cache.clear()
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
return self._cache.stats()
@dataclass
class MockAIResponse:
"""Mock AI 响应"""
content: str = "这是一个 Mock 响应,用于开发测试。实际部署时会使用真实的 AI 服务。"
model: str = "mock-model"
provider: str = "mock"
input_tokens: int = 100
output_tokens: int = 50
total_tokens: int = 150
cost: float = 0.0
latency_ms: int = 100
raw_response: dict = None
images: list = None
annotations: dict = None
from_cache: bool = False
class MockAIService:
"""Mock AI 服务(开发环境使用)"""
def __init__(self, module_code: str):
self.module_code = module_code
async def chat(self, messages, prompt_name, **kwargs):
"""Mock 聊天接口"""
# 生成一个基于输入的简单响应
user_message = ""
for msg in messages:
if msg.get("role") == "user":
user_message = msg.get("content", "")[:100]
break
response = MockAIResponse(
content=f"""## Mock AI 分析报告
根据您提供的数据,以下是分析结果:
### 定价建议
- **推荐价格**: 根据成本和市场分析,建议定价在合理区间内
- **引流款策略**: 适合新客引流,建议价格较低
- **利润款策略**: 适合日常经营,建议价格适中
- **高端款策略**: 适合高端客群,可考虑较高定价
### 风险提示
- 请密切关注市场动态
- 建议定期复核定价策略
*注意:这是开发测试环境的 Mock 响应,实际部署时会使用真实的 AI 服务。*
""",
model="mock-model",
provider="mock",
input_tokens=len(str(messages)),
output_tokens=200,
total_tokens=len(str(messages)) + 200,
cost=0.0,
latency_ms=50,
)
return response
async def chat_stream(self, messages, prompt_name, **kwargs):
"""Mock 流式接口"""
chunks = [
"## Mock AI 分析报告\n\n",
"根据您提供的数据,以下是分析结果:\n\n",
"### 定价建议\n",
"- 推荐价格:合理区间内\n",
"- 引流款策略:适合新客引流\n",
"- 利润款策略:适合日常经营\n\n",
"*这是 Mock 响应,实际部署时会使用真实的 AI 服务。*",
]
for chunk in chunks:
yield chunk
async def get_ai_service(db_session: AsyncSession, enable_cache: bool = True) -> AIServiceWrapper:
"""获取 AI 服务实例(依赖注入)
Args:
db_session: 数据库会话
enable_cache: 是否启用缓存
Returns:
AIServiceWrapper 实例
"""
return AIServiceWrapper(db_session, enable_cache=enable_cache)

View File

@@ -0,0 +1,233 @@
"""缓存服务
实现简单的内存缓存和 LRU 缓存策略
用于优化频繁查询的数据
遵循瑞小美系统技术栈标准
"""
import asyncio
import hashlib
import json
from datetime import datetime, timedelta
from functools import lru_cache, wraps
from typing import Any, Callable, Optional, Dict
from collections import OrderedDict
import threading
class TTLCache:
"""带过期时间的缓存
线程安全的 TTL 缓存实现
"""
def __init__(self, maxsize: int = 1000, ttl: int = 300):
"""
Args:
maxsize: 最大缓存条目数
ttl: 默认过期时间(秒)
"""
self.maxsize = maxsize
self.ttl = ttl
self._cache: OrderedDict = OrderedDict()
self._lock = threading.Lock()
def _is_expired(self, expire_at: datetime) -> bool:
"""检查是否过期"""
return datetime.now() > expire_at
def get(self, key: str) -> Optional[Any]:
"""获取缓存值"""
with self._lock:
if key not in self._cache:
return None
value, expire_at = self._cache[key]
if self._is_expired(expire_at):
del self._cache[key]
return None
# 移动到末尾LRU
self._cache.move_to_end(key)
return value
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
"""设置缓存值"""
with self._lock:
expire_at = datetime.now() + timedelta(seconds=ttl or self.ttl)
if key in self._cache:
self._cache.move_to_end(key)
self._cache[key] = (value, expire_at)
# 超过最大容量时删除最旧的
while len(self._cache) > self.maxsize:
self._cache.popitem(last=False)
def delete(self, key: str) -> bool:
"""删除缓存"""
with self._lock:
if key in self._cache:
del self._cache[key]
return True
return False
def clear(self) -> None:
"""清空缓存"""
with self._lock:
self._cache.clear()
def cleanup(self) -> int:
"""清理过期缓存,返回清理数量"""
with self._lock:
expired_keys = [
key for key, (_, expire_at) in self._cache.items()
if self._is_expired(expire_at)
]
for key in expired_keys:
del self._cache[key]
return len(expired_keys)
def stats(self) -> Dict[str, Any]:
"""获取缓存统计"""
with self._lock:
return {
"size": len(self._cache),
"maxsize": self.maxsize,
"ttl": self.ttl,
}
# 全局缓存实例
_cache_instances: Dict[str, TTLCache] = {}
def get_cache(namespace: str = "default", maxsize: int = 1000, ttl: int = 300) -> TTLCache:
"""获取缓存实例
Args:
namespace: 缓存命名空间
maxsize: 最大缓存条目数
ttl: 默认过期时间(秒)
Returns:
缓存实例
"""
if namespace not in _cache_instances:
_cache_instances[namespace] = TTLCache(maxsize=maxsize, ttl=ttl)
return _cache_instances[namespace]
def cache_key(*args, **kwargs) -> str:
"""生成缓存键"""
key_parts = [str(arg) for arg in args]
key_parts.extend(f"{k}={v}" for k, v in sorted(kwargs.items()))
key_str = "|".join(key_parts)
return hashlib.md5(key_str.encode()).hexdigest()
def cached(
namespace: str = "default",
ttl: int = 300,
key_prefix: str = "",
):
"""缓存装饰器
Args:
namespace: 缓存命名空间
ttl: 过期时间(秒)
key_prefix: 键前缀
Example:
@cached(namespace="projects", ttl=60, key_prefix="project_detail")
async def get_project(project_id: int):
...
"""
def decorator(func: Callable):
@wraps(func)
async def async_wrapper(*args, **kwargs):
cache = get_cache(namespace, ttl=ttl)
# 生成缓存键
key = f"{key_prefix}:{cache_key(*args, **kwargs)}"
# 尝试获取缓存
cached_value = cache.get(key)
if cached_value is not None:
return cached_value
# 执行函数
result = await func(*args, **kwargs)
# 设置缓存
if result is not None:
cache.set(key, result, ttl)
return result
@wraps(func)
def sync_wrapper(*args, **kwargs):
cache = get_cache(namespace, ttl=ttl)
key = f"{key_prefix}:{cache_key(*args, **kwargs)}"
cached_value = cache.get(key)
if cached_value is not None:
return cached_value
result = func(*args, **kwargs)
if result is not None:
cache.set(key, result, ttl)
return result
if asyncio.iscoroutinefunction(func):
return async_wrapper
return sync_wrapper
return decorator
def invalidate_cache(namespace: str, key_prefix: str = "") -> None:
"""使缓存失效
注意:这会清除整个命名空间的缓存
"""
cache = get_cache(namespace)
cache.clear()
# 预定义缓存命名空间
class CacheNamespace:
"""缓存命名空间常量"""
CATEGORIES = "categories"
MATERIALS = "materials"
EQUIPMENTS = "equipments"
STAFF_LEVELS = "staff_levels"
PROJECTS = "projects"
MARKET_ANALYSIS = "market_analysis"
AI_RESPONSES = "ai_responses"
# 预初始化常用缓存
def init_caches():
"""初始化缓存实例"""
# 基础数据缓存(较长 TTL
get_cache(CacheNamespace.CATEGORIES, maxsize=100, ttl=600)
get_cache(CacheNamespace.MATERIALS, maxsize=500, ttl=600)
get_cache(CacheNamespace.EQUIPMENTS, maxsize=200, ttl=600)
get_cache(CacheNamespace.STAFF_LEVELS, maxsize=50, ttl=600)
# 业务数据缓存(较短 TTL
get_cache(CacheNamespace.PROJECTS, maxsize=500, ttl=300)
get_cache(CacheNamespace.MARKET_ANALYSIS, maxsize=200, ttl=180)
# AI 响应缓存(较长 TTL减少 API 调用)
get_cache(CacheNamespace.AI_RESPONSES, maxsize=100, ttl=3600)
# 应用启动时初始化
init_caches()

View File

@@ -0,0 +1,478 @@
"""成本计算服务
实现项目成本计算核心业务逻辑,包含:
- 耗材成本计算
- 设备折旧成本计算
- 人工成本计算
- 固定成本分摊(三种分摊方式)
- 成本汇总
"""
from datetime import datetime
from decimal import Decimal
from typing import Optional, List, Dict, Any
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models import (
Project,
ProjectCostItem,
ProjectLaborCost,
ProjectCostSummary,
Material,
Equipment,
StaffLevel,
FixedCost,
)
from app.schemas.project_cost import (
AllocationMethod,
CostItemType,
CostBreakdownItem,
CostCalculationResult,
)
class CostService:
"""成本计算服务"""
def __init__(self, db: AsyncSession):
self.db = db
async def get_material_info(self, material_id: int) -> Optional[Material]:
"""获取耗材信息"""
result = await self.db.execute(
select(Material).where(Material.id == material_id)
)
return result.scalar_one_or_none()
async def get_equipment_info(self, equipment_id: int) -> Optional[Equipment]:
"""获取设备信息"""
result = await self.db.execute(
select(Equipment).where(Equipment.id == equipment_id)
)
return result.scalar_one_or_none()
async def get_staff_level_info(self, staff_level_id: int) -> Optional[StaffLevel]:
"""获取人员级别信息"""
result = await self.db.execute(
select(StaffLevel).where(StaffLevel.id == staff_level_id)
)
return result.scalar_one_or_none()
async def calculate_material_cost(self, project_id: int) -> tuple[Decimal, List[Dict[str, Any]]]:
"""计算耗材成本
Returns:
(总耗材成本, 耗材明细列表)
"""
result = await self.db.execute(
select(ProjectCostItem).where(
ProjectCostItem.project_id == project_id,
ProjectCostItem.item_type == CostItemType.MATERIAL.value
)
)
items = result.scalars().all()
total = Decimal("0")
breakdown = []
for item in items:
material = await self.get_material_info(item.item_id)
item_detail = {
"name": material.material_name if material else f"耗材#{item.item_id}",
"quantity": float(item.quantity),
"unit": material.unit if material else "",
"unit_cost": float(item.unit_cost),
"total": float(item.total_cost),
}
breakdown.append(item_detail)
total += item.total_cost
return total, breakdown
async def calculate_equipment_cost(self, project_id: int) -> tuple[Decimal, List[Dict[str, Any]]]:
"""计算设备折旧成本
Returns:
(总设备折旧成本, 设备明细列表)
"""
result = await self.db.execute(
select(ProjectCostItem).where(
ProjectCostItem.project_id == project_id,
ProjectCostItem.item_type == CostItemType.EQUIPMENT.value
)
)
items = result.scalars().all()
total = Decimal("0")
breakdown = []
for item in items:
equipment = await self.get_equipment_info(item.item_id)
item_detail = {
"name": equipment.equipment_name if equipment else f"设备#{item.item_id}",
"depreciation_per_use": float(item.unit_cost),
"total": float(item.total_cost),
}
breakdown.append(item_detail)
total += item.total_cost
return total, breakdown
async def calculate_labor_cost(self, project_id: int) -> tuple[Decimal, List[Dict[str, Any]]]:
"""计算人工成本
Returns:
(总人工成本, 人工明细列表)
"""
result = await self.db.execute(
select(ProjectLaborCost).where(
ProjectLaborCost.project_id == project_id
)
)
items = result.scalars().all()
total = Decimal("0")
breakdown = []
for item in items:
staff_level = await self.get_staff_level_info(item.staff_level_id)
item_detail = {
"name": staff_level.level_name if staff_level else f"级别#{item.staff_level_id}",
"duration_minutes": item.duration_minutes,
"hourly_rate": float(item.hourly_rate),
"total": float(item.labor_cost),
}
breakdown.append(item_detail)
total += item.labor_cost
return total, breakdown
async def calculate_fixed_cost_allocation(
self,
project_id: int,
method: AllocationMethod = AllocationMethod.COUNT,
year_month: Optional[str] = None
) -> tuple[Decimal, Dict[str, Any]]:
"""计算固定成本分摊
三种分摊方式:
- COUNT: 按项目数量平均分摊
- REVENUE: 按项目营收占比分摊(当前简化为平均分摊)
- DURATION: 按项目时长占比分摊
Args:
project_id: 项目ID
method: 分摊方式
year_month: 年月,默认当前月份
Returns:
(分摊金额, 分摊详情)
"""
if not year_month:
year_month = datetime.now().strftime("%Y-%m")
# 获取当月固定成本总额
result = await self.db.execute(
select(func.sum(FixedCost.monthly_amount)).where(
FixedCost.year_month == year_month,
FixedCost.is_active == True
)
)
total_fixed_cost = result.scalar() or Decimal("0")
if total_fixed_cost == 0:
return Decimal("0"), {
"method": method.value,
"total_fixed_cost": 0,
"allocation": 0,
}
allocation = Decimal("0")
detail = {
"method": method.value,
"total_fixed_cost": float(total_fixed_cost),
}
if method == AllocationMethod.COUNT:
# 按项目数量平均分摊
count_result = await self.db.execute(
select(func.count(Project.id)).where(Project.is_active == True)
)
project_count = count_result.scalar() or 1
allocation = total_fixed_cost / Decimal(str(project_count))
detail["project_count"] = project_count
elif method == AllocationMethod.DURATION:
# 按项目时长占比分摊
# 获取当前项目时长
project_result = await self.db.execute(
select(Project.duration_minutes).where(Project.id == project_id)
)
project_duration = project_result.scalar() or 0
# 获取所有活跃项目总时长
total_duration_result = await self.db.execute(
select(func.sum(Project.duration_minutes)).where(Project.is_active == True)
)
total_duration = total_duration_result.scalar() or 1
if total_duration > 0:
ratio = Decimal(str(project_duration)) / Decimal(str(total_duration))
allocation = total_fixed_cost * ratio
detail["project_duration"] = project_duration
detail["total_duration"] = total_duration
elif method == AllocationMethod.REVENUE:
# 按营收占比分摊(暂简化为平均分摊,后续可接入实际营收数据)
count_result = await self.db.execute(
select(func.count(Project.id)).where(Project.is_active == True)
)
project_count = count_result.scalar() or 1
allocation = total_fixed_cost / Decimal(str(project_count))
detail["project_count"] = project_count
detail["note"] = "暂按项目数量平均分摊,后续可接入实际营收数据"
detail["allocation"] = float(allocation)
return allocation, detail
async def calculate_project_cost(
self,
project_id: int,
allocation_method: AllocationMethod = AllocationMethod.COUNT
) -> CostCalculationResult:
"""计算项目总成本
总成本 = 耗材成本 + 设备折旧成本 + 人工成本 + 固定成本分摊
Args:
project_id: 项目ID
allocation_method: 固定成本分摊方式
Returns:
成本计算结果
"""
# 获取项目信息
project_result = await self.db.execute(
select(Project).where(Project.id == project_id)
)
project = project_result.scalar_one_or_none()
if not project:
raise ValueError(f"项目不存在: {project_id}")
# 计算各项成本
material_cost, material_breakdown = await self.calculate_material_cost(project_id)
equipment_cost, equipment_breakdown = await self.calculate_equipment_cost(project_id)
labor_cost, labor_breakdown = await self.calculate_labor_cost(project_id)
fixed_allocation, fixed_detail = await self.calculate_fixed_cost_allocation(
project_id, allocation_method
)
# 计算总成本
total_cost = material_cost + equipment_cost + labor_cost + fixed_allocation
# 构建成本分项
cost_breakdown = {
"material_cost": {
"items": material_breakdown,
"subtotal": float(material_cost),
},
"equipment_cost": {
"items": equipment_breakdown,
"subtotal": float(equipment_cost),
},
"labor_cost": {
"items": labor_breakdown,
"subtotal": float(labor_cost),
},
"fixed_cost_allocation": fixed_detail,
}
calculated_at = datetime.now()
# 更新或创建成本汇总记录
await self._save_cost_summary(
project_id=project_id,
material_cost=material_cost,
equipment_cost=equipment_cost,
labor_cost=labor_cost,
fixed_cost_allocation=fixed_allocation,
total_cost=total_cost,
calculated_at=calculated_at,
)
return CostCalculationResult(
project_id=project_id,
project_name=project.project_name,
cost_breakdown=cost_breakdown,
total_cost=float(total_cost),
min_price_suggestion=float(total_cost),
calculated_at=calculated_at,
)
async def _save_cost_summary(
self,
project_id: int,
material_cost: Decimal,
equipment_cost: Decimal,
labor_cost: Decimal,
fixed_cost_allocation: Decimal,
total_cost: Decimal,
calculated_at: datetime,
):
"""保存或更新成本汇总"""
result = await self.db.execute(
select(ProjectCostSummary).where(
ProjectCostSummary.project_id == project_id
)
)
summary = result.scalar_one_or_none()
if summary:
summary.material_cost = material_cost
summary.equipment_cost = equipment_cost
summary.labor_cost = labor_cost
summary.fixed_cost_allocation = fixed_cost_allocation
summary.total_cost = total_cost
summary.calculated_at = calculated_at
else:
summary = ProjectCostSummary(
project_id=project_id,
material_cost=material_cost,
equipment_cost=equipment_cost,
labor_cost=labor_cost,
fixed_cost_allocation=fixed_cost_allocation,
total_cost=total_cost,
calculated_at=calculated_at,
)
self.db.add(summary)
await self.db.flush()
async def add_cost_item(
self,
project_id: int,
item_type: CostItemType,
item_id: int,
quantity: float,
remark: Optional[str] = None,
) -> ProjectCostItem:
"""添加成本明细项
自动计算 unit_cost 和 total_cost
"""
# 根据类型获取单位成本
if item_type == CostItemType.MATERIAL:
material = await self.get_material_info(item_id)
if not material:
raise ValueError(f"耗材不存在: {item_id}")
unit_cost = Decimal(str(material.unit_price))
else:
equipment = await self.get_equipment_info(item_id)
if not equipment:
raise ValueError(f"设备不存在: {item_id}")
unit_cost = equipment.depreciation_per_use
total_cost = unit_cost * Decimal(str(quantity))
cost_item = ProjectCostItem(
project_id=project_id,
item_type=item_type.value,
item_id=item_id,
quantity=Decimal(str(quantity)),
unit_cost=unit_cost,
total_cost=total_cost,
remark=remark,
)
self.db.add(cost_item)
await self.db.flush()
await self.db.refresh(cost_item)
return cost_item
async def update_cost_item(
self,
cost_item: ProjectCostItem,
quantity: Optional[float] = None,
remark: Optional[str] = None,
) -> ProjectCostItem:
"""更新成本明细项"""
if quantity is not None:
cost_item.quantity = Decimal(str(quantity))
cost_item.total_cost = cost_item.unit_cost * cost_item.quantity
if remark is not None:
cost_item.remark = remark
await self.db.flush()
await self.db.refresh(cost_item)
return cost_item
async def add_labor_cost(
self,
project_id: int,
staff_level_id: int,
duration_minutes: int,
remark: Optional[str] = None,
) -> ProjectLaborCost:
"""添加人工成本
自动获取时薪并计算人工成本
"""
staff_level = await self.get_staff_level_info(staff_level_id)
if not staff_level:
raise ValueError(f"人员级别不存在: {staff_level_id}")
hourly_rate = Decimal(str(staff_level.hourly_rate))
labor_cost = hourly_rate * Decimal(str(duration_minutes)) / Decimal("60")
labor_item = ProjectLaborCost(
project_id=project_id,
staff_level_id=staff_level_id,
duration_minutes=duration_minutes,
hourly_rate=hourly_rate,
labor_cost=labor_cost,
remark=remark,
)
self.db.add(labor_item)
await self.db.flush()
await self.db.refresh(labor_item)
return labor_item
async def update_labor_cost(
self,
labor_item: ProjectLaborCost,
staff_level_id: Optional[int] = None,
duration_minutes: Optional[int] = None,
remark: Optional[str] = None,
) -> ProjectLaborCost:
"""更新人工成本"""
if staff_level_id is not None:
staff_level = await self.get_staff_level_info(staff_level_id)
if not staff_level:
raise ValueError(f"人员级别不存在: {staff_level_id}")
labor_item.staff_level_id = staff_level_id
labor_item.hourly_rate = Decimal(str(staff_level.hourly_rate))
if duration_minutes is not None:
labor_item.duration_minutes = duration_minutes
# 重新计算人工成本
labor_item.labor_cost = (
labor_item.hourly_rate * Decimal(str(labor_item.duration_minutes)) / Decimal("60")
)
if remark is not None:
labor_item.remark = remark
await self.db.flush()
await self.db.refresh(labor_item)
return labor_item

View File

@@ -0,0 +1,367 @@
"""市场分析服务
实现市场价格分析核心业务逻辑,包含:
- 竞品价格统计分析
- 价格分布计算
- 建议定价区间生成
"""
from datetime import date
from decimal import Decimal
from typing import Optional, List
from statistics import mean, median, stdev
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models import (
Project,
Competitor,
CompetitorPrice,
BenchmarkPrice,
MarketAnalysisResult,
Category,
)
from app.schemas.market import (
MarketAnalysisResult as MarketAnalysisResultSchema,
PriceStatistics,
PriceDistribution,
PriceDistributionItem,
CompetitorPriceSummary,
BenchmarkReference,
SuggestedRange,
)
class MarketService:
"""市场分析服务"""
def __init__(self, db: AsyncSession):
self.db = db
async def get_competitor_prices_for_project(
self,
project_id: int,
competitor_ids: Optional[List[int]] = None
) -> List[CompetitorPrice]:
"""获取项目的竞品价格数据
Args:
project_id: 项目ID
competitor_ids: 指定竞品机构ID列表
Returns:
竞品价格列表
"""
query = select(CompetitorPrice).options(
selectinload(CompetitorPrice.competitor)
).where(
CompetitorPrice.project_id == project_id
)
if competitor_ids:
query = query.where(CompetitorPrice.competitor_id.in_(competitor_ids))
result = await self.db.execute(query.order_by(CompetitorPrice.collected_at.desc()))
return result.scalars().all()
async def get_benchmark_prices_for_category(
self,
category_id: Optional[int]
) -> List[BenchmarkPrice]:
"""获取分类的标杆价格
Args:
category_id: 分类ID
Returns:
标杆价格列表
"""
if not category_id:
return []
query = select(BenchmarkPrice).where(
BenchmarkPrice.category_id == category_id
).order_by(BenchmarkPrice.effective_date.desc())
result = await self.db.execute(query)
return result.scalars().all()
def calculate_price_statistics(self, prices: List[float]) -> PriceStatistics:
"""计算价格统计数据
Args:
prices: 价格列表
Returns:
价格统计
"""
if not prices:
return PriceStatistics(
min_price=0,
max_price=0,
avg_price=0,
median_price=0,
std_deviation=0
)
std_dev = None
if len(prices) > 1:
try:
std_dev = round(stdev(prices), 2)
except Exception:
pass
return PriceStatistics(
min_price=round(min(prices), 2),
max_price=round(max(prices), 2),
avg_price=round(mean(prices), 2),
median_price=round(median(prices), 2),
std_deviation=std_dev
)
def calculate_price_distribution(
self,
prices: List[float],
min_price: float,
max_price: float
) -> PriceDistribution:
"""计算价格分布
将价格分为低/中/高三个区间
Args:
prices: 价格列表
min_price: 最低价
max_price: 最高价
Returns:
价格分布
"""
if not prices or min_price >= max_price:
return PriceDistribution(
low=PriceDistributionItem(range="N/A", count=0, percentage=0),
medium=PriceDistributionItem(range="N/A", count=0, percentage=0),
high=PriceDistributionItem(range="N/A", count=0, percentage=0),
)
# 计算三个区间的边界
range_size = (max_price - min_price) / 3
low_upper = min_price + range_size
mid_upper = min_price + range_size * 2
# 统计各区间数量
low_count = sum(1 for p in prices if p < low_upper)
mid_count = sum(1 for p in prices if low_upper <= p < mid_upper)
high_count = sum(1 for p in prices if p >= mid_upper)
total = len(prices)
return PriceDistribution(
low=PriceDistributionItem(
range=f"{int(min_price)}-{int(low_upper)}",
count=low_count,
percentage=round(low_count / total * 100, 1) if total > 0 else 0
),
medium=PriceDistributionItem(
range=f"{int(low_upper)}-{int(mid_upper)}",
count=mid_count,
percentage=round(mid_count / total * 100, 1) if total > 0 else 0
),
high=PriceDistributionItem(
range=f"{int(mid_upper)}-{int(max_price)}",
count=high_count,
percentage=round(high_count / total * 100, 1) if total > 0 else 0
),
)
def calculate_suggested_range(
self,
avg_price: float,
min_price: float,
max_price: float,
benchmark_avg: Optional[float] = None
) -> SuggestedRange:
"""计算建议定价区间
Args:
avg_price: 市场均价
min_price: 市场最低价
max_price: 市场最高价
benchmark_avg: 标杆均价(可选)
Returns:
建议定价区间
"""
if avg_price == 0:
return SuggestedRange(min=0, max=0, recommended=0)
# 建议区间以均价为中心±20%
range_factor = 0.2
suggested_min = round(avg_price * (1 - range_factor), 2)
suggested_max = round(avg_price * (1 + range_factor), 2)
# 确保不低于市场最低价的80%不高于市场最高价的120%
suggested_min = max(suggested_min, round(min_price * 0.8, 2))
suggested_max = min(suggested_max, round(max_price * 1.2, 2))
# 推荐价格:如果有标杆价格,取市场均价和标杆均价的加权平均
if benchmark_avg:
recommended = round((avg_price * 0.6 + benchmark_avg * 0.4), 2)
else:
recommended = avg_price
return SuggestedRange(
min=suggested_min,
max=suggested_max,
recommended=recommended
)
async def analyze_market(
self,
project_id: int,
competitor_ids: Optional[List[int]] = None,
include_benchmark: bool = True
) -> MarketAnalysisResultSchema:
"""执行市场价格分析
Args:
project_id: 项目ID
competitor_ids: 指定竞品机构ID列表
include_benchmark: 是否包含标杆参考
Returns:
市场分析结果
"""
# 获取项目信息
project_result = await self.db.execute(
select(Project).where(Project.id == project_id)
)
project = project_result.scalar_one_or_none()
if not project:
raise ValueError(f"项目不存在: {project_id}")
# 获取竞品价格
competitor_prices = await self.get_competitor_prices_for_project(
project_id, competitor_ids
)
# 提取价格列表(使用原价)
prices = [float(cp.original_price) for cp in competitor_prices]
# 计算统计数据
price_stats = self.calculate_price_statistics(prices)
# 计算价格分布
price_distribution = None
if len(prices) >= 3:
price_distribution = self.calculate_price_distribution(
prices, price_stats.min_price, price_stats.max_price
)
# 构建竞品价格摘要
competitor_summaries = []
for cp in competitor_prices[:10]: # 最多返回10条
competitor_summaries.append(CompetitorPriceSummary(
competitor_name=cp.competitor.competitor_name if cp.competitor else "未知",
positioning=cp.competitor.positioning if cp.competitor else "medium",
original_price=float(cp.original_price),
promo_price=float(cp.promo_price) if cp.promo_price else None,
collected_at=cp.collected_at,
))
# 获取标杆参考
benchmark_ref = None
benchmark_avg = None
if include_benchmark and project.category_id:
benchmarks = await self.get_benchmark_prices_for_category(project.category_id)
if benchmarks:
latest_benchmark = benchmarks[0]
benchmark_avg = float(latest_benchmark.avg_price)
benchmark_ref = BenchmarkReference(
tier=latest_benchmark.price_tier,
min_price=float(latest_benchmark.min_price),
max_price=float(latest_benchmark.max_price),
avg_price=benchmark_avg,
)
# 计算建议区间
suggested_range = self.calculate_suggested_range(
price_stats.avg_price,
price_stats.min_price,
price_stats.max_price,
benchmark_avg
)
analysis_date = date.today()
# 保存分析结果到数据库
await self._save_analysis_result(
project_id=project_id,
analysis_date=analysis_date,
competitor_count=len(competitor_prices),
min_price=price_stats.min_price,
max_price=price_stats.max_price,
avg_price=price_stats.avg_price,
median_price=price_stats.median_price,
suggested_min=suggested_range.min,
suggested_max=suggested_range.max,
)
return MarketAnalysisResultSchema(
project_id=project_id,
project_name=project.project_name,
analysis_date=analysis_date,
competitor_count=len(competitor_prices),
price_statistics=price_stats,
price_distribution=price_distribution,
competitor_prices=competitor_summaries,
benchmark_reference=benchmark_ref,
suggested_range=suggested_range,
)
async def _save_analysis_result(
self,
project_id: int,
analysis_date: date,
competitor_count: int,
min_price: float,
max_price: float,
avg_price: float,
median_price: float,
suggested_min: float,
suggested_max: float,
):
"""保存分析结果"""
# 创建新记录(保留历史)
result = MarketAnalysisResult(
project_id=project_id,
analysis_date=analysis_date,
competitor_count=competitor_count,
market_min_price=Decimal(str(min_price)),
market_max_price=Decimal(str(max_price)),
market_avg_price=Decimal(str(avg_price)),
market_median_price=Decimal(str(median_price)),
suggested_range_min=Decimal(str(suggested_min)),
suggested_range_max=Decimal(str(suggested_max)),
)
self.db.add(result)
await self.db.flush()
async def get_latest_analysis(self, project_id: int) -> Optional[MarketAnalysisResult]:
"""获取最新的市场分析结果
Args:
project_id: 项目ID
Returns:
最新分析结果
"""
result = await self.db.execute(
select(MarketAnalysisResult).where(
MarketAnalysisResult.project_id == project_id
).order_by(MarketAnalysisResult.analysis_date.desc()).limit(1)
)
return result.scalar_one_or_none()

View File

@@ -0,0 +1,574 @@
"""智能定价服务
实现智能定价核心业务逻辑,包含:
- 综合定价计算
- AI 定价建议生成(遵循瑞小美 AI 接入规范)
- 定价策略模拟
- 定价报告导出
"""
import json
from datetime import datetime
from decimal import Decimal
from typing import Optional, List, Dict, Any, AsyncIterator
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models import (
Project,
ProjectCostSummary,
MarketAnalysisResult,
PricingPlan,
)
from app.schemas.pricing import (
StrategyType,
PricingSuggestions,
StrategySuggestion,
MarketReference,
AIAdvice,
AIUsage,
GeneratePricingResponse,
StrategySimulationResult,
SimulateStrategyResponse,
)
from app.services.ai_service_wrapper import AIServiceWrapper
from app.services.cost_service import CostService
from app.services.market_service import MarketService
# 导入提示词
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../prompts'))
from prompts.pricing_advice_prompts import SYSTEM_PROMPT, USER_PROMPT, PROMPT_META
class PricingService:
"""智能定价服务"""
# 各策略的利润率范围
STRATEGY_MARGINS = {
StrategyType.TRAFFIC: (0.10, 0.20), # 引流款10%-20%
StrategyType.PROFIT: (0.40, 0.60), # 利润款40%-60%
StrategyType.PREMIUM: (0.60, 0.80), # 高端款60%-80%
}
STRATEGY_NAMES = {
StrategyType.TRAFFIC: "引流款",
StrategyType.PROFIT: "利润款",
StrategyType.PREMIUM: "高端款",
}
def __init__(self, db: AsyncSession):
self.db = db
self.cost_service = CostService(db)
self.market_service = MarketService(db)
async def get_project_with_cost(self, project_id: int) -> tuple[Project, Optional[ProjectCostSummary]]:
"""获取项目及其成本汇总"""
result = await self.db.execute(
select(Project).options(
selectinload(Project.cost_summary)
).where(Project.id == project_id)
)
project = result.scalar_one_or_none()
if not project:
raise ValueError(f"项目不存在: {project_id}")
return project, project.cost_summary
async def get_market_reference(self, project_id: int) -> Optional[MarketReference]:
"""获取市场参考数据"""
result = await self.db.execute(
select(MarketAnalysisResult).where(
MarketAnalysisResult.project_id == project_id
).order_by(MarketAnalysisResult.analysis_date.desc()).limit(1)
)
market_result = result.scalar_one_or_none()
if market_result:
return MarketReference(
min=float(market_result.market_min_price),
max=float(market_result.market_max_price),
avg=float(market_result.market_avg_price),
)
return None
def calculate_strategy_price(
self,
base_cost: float,
strategy: StrategyType,
target_margin: Optional[float] = None,
market_ref: Optional[MarketReference] = None
) -> StrategySuggestion:
"""计算单个策略的定价建议
Args:
base_cost: 基础成本
strategy: 策略类型
target_margin: 自定义目标毛利率(可选)
market_ref: 市场参考(可选)
Returns:
策略定价建议
"""
min_margin, max_margin = self.STRATEGY_MARGINS[strategy]
# 使用策略默认的中间利润率,或自定义目标
if target_margin is not None:
margin = target_margin / 100
else:
margin = (min_margin + max_margin) / 2
# 成本加成定价法:价格 = 成本 / (1 - 毛利率)
suggested_price = base_cost / (1 - margin)
# 如果有市场参考,调整价格
if market_ref:
if strategy == StrategyType.TRAFFIC:
# 引流款:取市场最低价和成本定价的较低者
market_low = market_ref.min * 0.9
suggested_price = min(suggested_price, market_low)
elif strategy == StrategyType.PREMIUM:
# 高端款:取市场高位
market_high = market_ref.max * 1.1
suggested_price = max(suggested_price, market_high * 0.9)
# 确保价格不低于成本(边界处理:零成本时设置最低价格)
if base_cost == 0:
# 零成本时,使用市场参考或设置一个最低保护价格
if market_ref:
suggested_price = market_ref.avg * 0.8
else:
suggested_price = 1.0 # 最低保护价格
else:
suggested_price = max(suggested_price, base_cost * 1.05)
# 计算实际毛利率(防止除零)
if suggested_price > 0:
actual_margin = (suggested_price - base_cost) / suggested_price * 100
else:
actual_margin = 0
descriptions = {
StrategyType.TRAFFIC: "低于市场均价,适合引流获客、新店开业、淡季促销",
StrategyType.PROFIT: "接近市场均价,平衡利润与竞争力,适合日常经营",
StrategyType.PREMIUM: "定位高端,高利润空间,需配套优质服务和品牌溢价",
}
return StrategySuggestion(
strategy=self.STRATEGY_NAMES[strategy],
suggested_price=round(suggested_price, 2),
margin=round(actual_margin, 1),
description=descriptions[strategy],
)
def calculate_all_strategies(
self,
base_cost: float,
target_margin: float,
market_ref: Optional[MarketReference] = None,
strategies: Optional[List[StrategyType]] = None
) -> PricingSuggestions:
"""计算所有策略的定价建议"""
if strategies is None:
strategies = list(StrategyType)
suggestions = PricingSuggestions()
for strategy in strategies:
suggestion = self.calculate_strategy_price(
base_cost=base_cost,
strategy=strategy,
target_margin=target_margin if strategy == StrategyType.PROFIT else None,
market_ref=market_ref,
)
if strategy == StrategyType.TRAFFIC:
suggestions.traffic = suggestion
elif strategy == StrategyType.PROFIT:
suggestions.profit = suggestion
elif strategy == StrategyType.PREMIUM:
suggestions.premium = suggestion
return suggestions
async def generate_pricing_advice(
self,
project_id: int,
target_margin: float = 50,
strategies: Optional[List[StrategyType]] = None,
stream: bool = False,
) -> GeneratePricingResponse:
"""生成智能定价建议
遵循瑞小美 AI 接入规范:
- 通过 AIServiceWrapper 调用
- 必须传入 prompt_name用于统计
Args:
project_id: 项目ID
target_margin: 目标毛利率
strategies: 要计算的策略列表
stream: 是否流式输出(此方法返回完整结果,流式由路由处理)
Returns:
定价建议响应
"""
# 获取项目和成本数据
project, cost_summary = await self.get_project_with_cost(project_id)
if not cost_summary:
# 如果没有成本汇总,先计算
cost_result = await self.cost_service.calculate_project_cost(project_id)
base_cost = cost_result.total_cost
else:
base_cost = float(cost_summary.total_cost)
# 获取市场参考
market_ref = await self.get_market_reference(project_id)
# 计算各策略价格
pricing_suggestions = self.calculate_all_strategies(
base_cost=base_cost,
target_margin=target_margin,
market_ref=market_ref,
strategies=strategies,
)
# 构建 AI 输入数据
cost_data = self._format_cost_data(cost_summary)
market_data = self._format_market_data(market_ref)
# 调用 AI 生成建议(遵循规范)
ai_service = AIServiceWrapper(db_session=self.db)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT.format(
project_name=project.project_name,
cost_data=cost_data,
market_data=market_data,
target_margin=target_margin,
)},
]
ai_advice = None
ai_usage = None
try:
response = await ai_service.chat(
messages=messages,
prompt_name=PROMPT_META["name"], # 必填!用于调用统计
)
ai_advice = AIAdvice(
summary=self._extract_section(response.content, "推荐方案") or response.content[:200],
cost_analysis=self._extract_section(response.content, "成本") or "",
market_analysis=self._extract_section(response.content, "市场") or "",
risk_notes=self._extract_section(response.content, "风险") or "",
recommendations=self._extract_recommendations(response.content),
)
ai_usage = AIUsage(
provider=response.provider,
model=response.model,
tokens=response.total_tokens,
latency_ms=response.latency_ms,
)
except Exception as e:
# AI 调用失败不影响基本定价计算
print(f"AI 调用失败: {e}")
return GeneratePricingResponse(
project_id=project_id,
project_name=project.project_name,
cost_base=base_cost,
market_reference=market_ref,
pricing_suggestions=pricing_suggestions,
ai_advice=ai_advice,
ai_usage=ai_usage,
)
async def generate_pricing_advice_stream(
self,
project_id: int,
target_margin: float = 50,
) -> AsyncIterator[str]:
"""流式生成定价建议
Args:
project_id: 项目ID
target_margin: 目标毛利率
Yields:
SSE 格式的响应片段
"""
# 获取基础数据
project, cost_summary = await self.get_project_with_cost(project_id)
if not cost_summary:
cost_result = await self.cost_service.calculate_project_cost(project_id)
base_cost = cost_result.total_cost
else:
base_cost = float(cost_summary.total_cost)
market_ref = await self.get_market_reference(project_id)
# 先返回基础定价计算结果
pricing_suggestions = self.calculate_all_strategies(
base_cost=base_cost,
target_margin=target_margin,
market_ref=market_ref,
)
# 发送初始数据
initial_data = {
"type": "init",
"project_name": project.project_name,
"cost_base": base_cost,
"market_reference": market_ref.model_dump() if market_ref else None,
"pricing_suggestions": pricing_suggestions.model_dump(),
}
yield f"data: {json.dumps(initial_data, ensure_ascii=False)}\n\n"
# 构建 AI 输入
cost_data = self._format_cost_data(cost_summary)
market_data = self._format_market_data(market_ref)
ai_service = AIServiceWrapper(db_session=self.db)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT.format(
project_name=project.project_name,
cost_data=cost_data,
market_data=market_data,
target_margin=target_margin,
)},
]
# 流式返回 AI 建议
try:
async for chunk in ai_service.chat_stream(
messages=messages,
prompt_name=PROMPT_META["name"],
):
yield f"data: {json.dumps({'type': 'chunk', 'content': chunk}, ensure_ascii=False)}\n\n"
except Exception as e:
yield f"data: {json.dumps({'type': 'error', 'message': str(e)}, ensure_ascii=False)}\n\n"
# 发送完成信号
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
async def simulate_strategies(
self,
project_id: int,
strategies: List[StrategyType],
target_margin: float = 50,
) -> SimulateStrategyResponse:
"""模拟定价策略
Args:
project_id: 项目ID
strategies: 要模拟的策略列表
target_margin: 目标毛利率
Returns:
策略模拟结果
"""
project, cost_summary = await self.get_project_with_cost(project_id)
if not cost_summary:
cost_result = await self.cost_service.calculate_project_cost(project_id)
base_cost = cost_result.total_cost
else:
base_cost = float(cost_summary.total_cost)
market_ref = await self.get_market_reference(project_id)
results = []
for strategy in strategies:
suggestion = self.calculate_strategy_price(
base_cost=base_cost,
strategy=strategy,
target_margin=target_margin if strategy == StrategyType.PROFIT else None,
market_ref=market_ref,
)
# 确定市场位置
market_position = "中等"
if market_ref:
if suggestion.suggested_price < market_ref.avg * 0.8:
market_position = "低于市场均价"
elif suggestion.suggested_price > market_ref.avg * 1.2:
market_position = "高于市场均价"
else:
market_position = "接近市场均价"
results.append(StrategySimulationResult(
strategy_type=strategy.value,
strategy_name=self.STRATEGY_NAMES[strategy],
suggested_price=suggestion.suggested_price,
margin=suggestion.margin,
profit_per_unit=round(suggestion.suggested_price - base_cost, 2),
market_position=market_position,
))
return SimulateStrategyResponse(
project_id=project_id,
project_name=project.project_name,
base_cost=base_cost,
results=results,
)
async def create_pricing_plan(
self,
project_id: int,
plan_name: str,
strategy_type: StrategyType,
target_margin: float,
created_by: Optional[int] = None,
) -> PricingPlan:
"""创建定价方案
Args:
project_id: 项目ID
plan_name: 方案名称
strategy_type: 策略类型
target_margin: 目标毛利率
created_by: 创建人ID
Returns:
创建的定价方案
"""
# 获取成本数据
project, cost_summary = await self.get_project_with_cost(project_id)
if not cost_summary:
cost_result = await self.cost_service.calculate_project_cost(project_id)
base_cost = Decimal(str(cost_result.total_cost))
else:
base_cost = cost_summary.total_cost
# 获取市场参考
market_ref = await self.get_market_reference(project_id)
# 计算建议价格
suggestion = self.calculate_strategy_price(
base_cost=float(base_cost),
strategy=strategy_type,
target_margin=target_margin if strategy_type == StrategyType.PROFIT else None,
market_ref=market_ref,
)
# 创建定价方案
pricing_plan = PricingPlan(
project_id=project_id,
plan_name=plan_name,
strategy_type=strategy_type.value,
base_cost=base_cost,
target_margin=Decimal(str(target_margin)),
suggested_price=Decimal(str(suggestion.suggested_price)),
is_active=True,
created_by=created_by,
)
self.db.add(pricing_plan)
await self.db.flush()
await self.db.refresh(pricing_plan)
return pricing_plan
async def update_pricing_plan(
self,
plan_id: int,
**kwargs
) -> PricingPlan:
"""更新定价方案"""
result = await self.db.execute(
select(PricingPlan).where(PricingPlan.id == plan_id)
)
plan = result.scalar_one_or_none()
if not plan:
raise ValueError(f"定价方案不存在: {plan_id}")
for key, value in kwargs.items():
if value is not None and hasattr(plan, key):
if key in ['target_margin', 'final_price', 'base_cost', 'suggested_price']:
setattr(plan, key, Decimal(str(value)))
else:
setattr(plan, key, value)
await self.db.flush()
await self.db.refresh(plan)
return plan
async def save_ai_advice(self, plan_id: int, advice: str) -> None:
"""保存 AI 建议到定价方案"""
result = await self.db.execute(
select(PricingPlan).where(PricingPlan.id == plan_id)
)
plan = result.scalar_one_or_none()
if plan:
plan.ai_advice = advice
await self.db.flush()
def _format_cost_data(self, cost_summary: Optional[ProjectCostSummary]) -> str:
"""格式化成本数据用于 AI 输入"""
if not cost_summary:
return "暂无成本数据"
return f"""- 耗材成本:{float(cost_summary.material_cost):.2f}
- 设备折旧:{float(cost_summary.equipment_cost):.2f}
- 人工成本:{float(cost_summary.labor_cost):.2f}
- 固定成本分摊:{float(cost_summary.fixed_cost_allocation):.2f}
- **总成本(最低成本线):{float(cost_summary.total_cost):.2f} 元**"""
def _format_market_data(self, market_ref: Optional[MarketReference]) -> str:
"""格式化市场数据用于 AI 输入"""
if not market_ref:
return "暂无市场行情数据"
return f"""- 市场最低价:{market_ref.min:.2f}
- 市场最高价:{market_ref.max:.2f}
- 市场均价:{market_ref.avg:.2f}"""
def _extract_section(self, content: str, keyword: str) -> Optional[str]:
"""从 AI 响应中提取特定部分"""
lines = content.split('\n')
in_section = False
section_lines = []
for line in lines:
if keyword in line and ('#' in line or '**' in line):
in_section = True
continue
elif in_section:
if line.startswith('#') or (line.startswith('**') and line.endswith('**')):
break
section_lines.append(line)
return '\n'.join(section_lines).strip() if section_lines else None
def _extract_recommendations(self, content: str) -> List[str]:
"""从 AI 响应中提取建议列表"""
recommendations = []
lines = content.split('\n')
for line in lines:
line = line.strip()
if line.startswith('- ') or line.startswith('* ') or line.startswith(''):
recommendations.append(line[2:].strip())
elif line and line[0].isdigit() and '.' in line:
# 处理 "1. xxx" 格式
parts = line.split('.', 1)
if len(parts) > 1:
recommendations.append(parts[1].strip())
return recommendations[:5] # 最多返回5条

View File

@@ -0,0 +1,513 @@
"""利润模拟服务
实现利润模拟测算核心业务逻辑,包含:
- 利润模拟计算
- 敏感性分析
- 盈亏平衡分析
- AI 利润预测分析
"""
from datetime import datetime
from decimal import Decimal
from typing import Optional, List, Dict, Any
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models import (
Project,
PricingPlan,
ProfitSimulation,
SensitivityAnalysis,
FixedCost,
)
from app.schemas.profit import (
PeriodType,
SimulationInput,
SimulationResult,
BreakevenAnalysis,
SimulateProfitResponse,
SensitivityResultItem,
SensitivityInsights,
SensitivityAnalysisResponse,
BreakevenResponse,
)
from app.services.ai_service_wrapper import AIServiceWrapper
# 导入提示词
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../prompts'))
from prompts.profit_forecast_prompts import SYSTEM_PROMPT, USER_PROMPT, PROMPT_META
class ProfitService:
"""利润模拟服务"""
def __init__(self, db: AsyncSession):
self.db = db
async def get_pricing_plan(self, plan_id: int) -> PricingPlan:
"""获取定价方案"""
result = await self.db.execute(
select(PricingPlan).options(
selectinload(PricingPlan.project)
).where(PricingPlan.id == plan_id)
)
plan = result.scalar_one_or_none()
if not plan:
raise ValueError(f"定价方案不存在: {plan_id}")
return plan
async def get_monthly_fixed_cost(self, year_month: Optional[str] = None) -> Decimal:
"""获取月度固定成本总额"""
if not year_month:
year_month = datetime.now().strftime("%Y-%m")
result = await self.db.execute(
select(func.sum(FixedCost.monthly_amount)).where(
FixedCost.year_month == year_month,
FixedCost.is_active == True
)
)
return result.scalar() or Decimal("0")
def calculate_profit(
self,
price: float,
cost_per_unit: float,
volume: int,
) -> tuple[float, float, float, float]:
"""计算利润
Returns:
(收入, 成本, 利润, 利润率)
"""
revenue = price * volume
total_cost = cost_per_unit * volume
profit = revenue - total_cost
margin = (profit / revenue * 100) if revenue > 0 else 0
return revenue, total_cost, profit, margin
def calculate_breakeven(
self,
price: float,
variable_cost: float,
fixed_cost: float = 0,
) -> int:
"""计算盈亏平衡点
盈亏平衡客量 = 固定成本 / (单价 - 单位变动成本)
Args:
price: 单价
variable_cost: 单位变动成本
fixed_cost: 固定成本(可选)
Returns:
盈亏平衡客量
"""
contribution_margin = price - variable_cost
if contribution_margin <= 0:
# 边际贡献为负,无法盈利
return 999999
if fixed_cost > 0:
breakeven = int(fixed_cost / contribution_margin) + 1
else:
# 无固定成本时,只要有销量就盈利
breakeven = 1
return breakeven
async def simulate_profit(
self,
pricing_plan_id: int,
price: float,
estimated_volume: int,
period_type: PeriodType,
created_by: Optional[int] = None,
) -> SimulateProfitResponse:
"""执行利润模拟
Args:
pricing_plan_id: 定价方案ID
price: 模拟价格
estimated_volume: 预估客量
period_type: 周期类型
created_by: 创建人ID
Returns:
利润模拟结果
"""
# 获取定价方案
plan = await self.get_pricing_plan(pricing_plan_id)
cost_per_unit = float(plan.base_cost)
# 计算利润
revenue, total_cost, profit, margin = self.calculate_profit(
price=price,
cost_per_unit=cost_per_unit,
volume=estimated_volume,
)
# 计算盈亏平衡点
breakeven_volume = self.calculate_breakeven(
price=price,
variable_cost=cost_per_unit,
)
# 计算安全边际
safety_margin = estimated_volume - breakeven_volume
safety_margin_pct = (safety_margin / estimated_volume * 100) if estimated_volume > 0 else 0
# 限制利润率范围以避免数据库溢出 (DECIMAL(5,2) 范围 -999.99 ~ 999.99)
clamped_margin = max(-999.99, min(999.99, margin))
# 创建模拟记录
simulation = ProfitSimulation(
pricing_plan_id=pricing_plan_id,
simulation_name=f"{plan.project.project_name}-{period_type.value}模拟",
price=Decimal(str(price)),
estimated_volume=estimated_volume,
period_type=period_type.value,
estimated_revenue=Decimal(str(revenue)),
estimated_cost=Decimal(str(total_cost)),
estimated_profit=Decimal(str(profit)),
profit_margin=Decimal(str(round(clamped_margin, 2))),
breakeven_volume=breakeven_volume,
created_by=created_by,
)
self.db.add(simulation)
await self.db.flush()
await self.db.refresh(simulation)
return SimulateProfitResponse(
simulation_id=simulation.id,
pricing_plan_id=pricing_plan_id,
project_name=plan.project.project_name,
input=SimulationInput(
price=price,
cost_per_unit=cost_per_unit,
estimated_volume=estimated_volume,
period_type=period_type.value,
),
result=SimulationResult(
estimated_revenue=round(revenue, 2),
estimated_cost=round(total_cost, 2),
estimated_profit=round(profit, 2),
profit_margin=round(margin, 2),
profit_per_unit=round(price - cost_per_unit, 2),
),
breakeven_analysis=BreakevenAnalysis(
breakeven_volume=breakeven_volume,
current_volume=estimated_volume,
safety_margin=safety_margin,
safety_margin_percentage=round(safety_margin_pct, 1),
),
created_at=simulation.created_at,
)
async def sensitivity_analysis(
self,
simulation_id: int,
price_change_rates: List[float],
) -> SensitivityAnalysisResponse:
"""执行敏感性分析
分析价格变动对利润的影响
Args:
simulation_id: 模拟ID
price_change_rates: 价格变动率列表
Returns:
敏感性分析结果
"""
# 获取模拟记录
result = await self.db.execute(
select(ProfitSimulation).options(
selectinload(ProfitSimulation.pricing_plan)
).where(ProfitSimulation.id == simulation_id)
)
simulation = result.scalar_one_or_none()
if not simulation:
raise ValueError(f"模拟记录不存在: {simulation_id}")
base_price = float(simulation.price)
base_profit = float(simulation.estimated_profit)
cost_per_unit = float(simulation.pricing_plan.base_cost)
volume = simulation.estimated_volume
sensitivity_results = []
for rate in sorted(price_change_rates):
# 计算调整后价格
adjusted_price = base_price * (1 + rate / 100)
# 计算调整后利润
_, _, adjusted_profit, _ = self.calculate_profit(
price=adjusted_price,
cost_per_unit=cost_per_unit,
volume=volume,
)
# 计算利润变动率
profit_change_rate = 0
if base_profit != 0:
profit_change_rate = (adjusted_profit - base_profit) / abs(base_profit) * 100
# 限制变动率范围以避免数据库溢出
clamped_profit_change_rate = max(-999.99, min(999.99, profit_change_rate))
item = SensitivityResultItem(
price_change_rate=rate,
adjusted_price=round(adjusted_price, 2),
adjusted_profit=round(adjusted_profit, 2),
profit_change_rate=round(clamped_profit_change_rate, 2),
)
sensitivity_results.append(item)
# 保存到数据库
analysis = SensitivityAnalysis(
simulation_id=simulation_id,
price_change_rate=Decimal(str(rate)),
adjusted_price=Decimal(str(adjusted_price)),
adjusted_profit=Decimal(str(adjusted_profit)),
profit_change_rate=Decimal(str(round(clamped_profit_change_rate, 2))),
)
self.db.add(analysis)
await self.db.flush()
# 生成洞察
insights = self._generate_sensitivity_insights(
base_price=base_price,
base_profit=base_profit,
results=sensitivity_results,
)
return SensitivityAnalysisResponse(
simulation_id=simulation_id,
base_price=base_price,
base_profit=base_profit,
sensitivity_results=sensitivity_results,
insights=insights,
)
def _generate_sensitivity_insights(
self,
base_price: float,
base_profit: float,
results: List[SensitivityResultItem],
) -> SensitivityInsights:
"""生成敏感性分析洞察"""
# 计算价格弹性(使用 ±10% 的数据点)
elasticity = 0
for r in results:
if r.price_change_rate == 10:
elasticity = r.profit_change_rate / 10
break
# 判断风险等级
risk_level = ""
min_profit = min(r.adjusted_profit for r in results)
if min_profit < 0:
risk_level = ""
elif min_profit < base_profit * 0.5:
risk_level = "中等"
# 生成建议
recommendation = "价格调整空间较大,经营风险可控。"
if risk_level == "":
recommendation = f"价格下降超过某阈值会导致亏损,建议密切关注市场动态。"
elif risk_level == "中等":
recommendation = f"价格变动对利润影响较大,建议谨慎调价。"
return SensitivityInsights(
price_elasticity=f"价格每变动1%,利润变动约{abs(elasticity):.2f}%",
risk_level=risk_level,
recommendation=recommendation,
)
async def breakeven_analysis(
self,
pricing_plan_id: int,
target_profit: Optional[float] = None,
) -> BreakevenResponse:
"""盈亏平衡分析
Args:
pricing_plan_id: 定价方案ID
target_profit: 目标利润(可选)
Returns:
盈亏平衡分析结果
"""
plan = await self.get_pricing_plan(pricing_plan_id)
price = float(plan.final_price or plan.suggested_price)
unit_cost = float(plan.base_cost)
# 获取月度固定成本
monthly_fixed_cost = float(await self.get_monthly_fixed_cost())
# 计算盈亏平衡点
contribution_margin = price - unit_cost
breakeven_volume = self.calculate_breakeven(
price=price,
variable_cost=unit_cost,
fixed_cost=monthly_fixed_cost,
)
# 计算达到目标利润的客量
target_volume = None
if target_profit is not None and contribution_margin > 0:
target_volume = int((monthly_fixed_cost + target_profit) / contribution_margin) + 1
return BreakevenResponse(
pricing_plan_id=pricing_plan_id,
project_name=plan.project.project_name,
price=price,
unit_cost=unit_cost,
fixed_cost_monthly=monthly_fixed_cost,
breakeven_volume=breakeven_volume,
current_margin=round(contribution_margin, 2),
target_profit_volume=target_volume,
)
async def generate_profit_forecast(
self,
simulation_id: int,
) -> str:
"""AI 生成利润预测分析
遵循瑞小美 AI 接入规范
Args:
simulation_id: 模拟ID
Returns:
AI 分析内容
"""
# 获取模拟数据
result = await self.db.execute(
select(ProfitSimulation).options(
selectinload(ProfitSimulation.pricing_plan).selectinload(PricingPlan.project),
selectinload(ProfitSimulation.sensitivity_analyses),
).where(ProfitSimulation.id == simulation_id)
)
simulation = result.scalar_one_or_none()
if not simulation:
raise ValueError(f"模拟记录不存在: {simulation_id}")
# 格式化数据
pricing_data = f"""- 定价方案:{simulation.pricing_plan.plan_name}
- 策略类型:{simulation.pricing_plan.strategy_type}
- 基础成本:{float(simulation.pricing_plan.base_cost):.2f}
- 建议价格:{float(simulation.pricing_plan.suggested_price):.2f}
- 目标毛利率:{float(simulation.pricing_plan.target_margin):.1f}%"""
simulation_data = f"""- 模拟价格:{float(simulation.price):.2f}
- 预估客量:{simulation.estimated_volume} ({simulation.period_type})
- 预估收入:{float(simulation.estimated_revenue):.2f}
- 预估成本:{float(simulation.estimated_cost):.2f}
- 预估利润:{float(simulation.estimated_profit):.2f}
- 利润率:{float(simulation.profit_margin):.1f}%
- 盈亏平衡客量:{simulation.breakeven_volume}"""
# 格式化敏感性数据
sensitivity_data = "暂无敏感性分析数据"
if simulation.sensitivity_analyses:
lines = []
for sa in simulation.sensitivity_analyses:
lines.append(
f" - 价格{float(sa.price_change_rate):+.0f}%: "
f"价格{float(sa.adjusted_price):.0f}元, "
f"利润{float(sa.adjusted_profit):.0f}"
f"({float(sa.profit_change_rate):+.1f}%)"
)
sensitivity_data = "\n".join(lines)
# 调用 AI
ai_service = AIServiceWrapper(db_session=self.db)
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT.format(
project_name=simulation.pricing_plan.project.project_name,
pricing_data=pricing_data,
simulation_data=simulation_data,
sensitivity_data=sensitivity_data,
)},
]
try:
response = await ai_service.chat(
messages=messages,
prompt_name=PROMPT_META["name"], # 必填!
)
return response.content
except Exception as e:
return f"AI 分析暂不可用: {str(e)}"
async def get_simulation_list(
self,
pricing_plan_id: Optional[int] = None,
period_type: Optional[PeriodType] = None,
page: int = 1,
page_size: int = 20,
) -> tuple[List[ProfitSimulation], int]:
"""获取模拟列表
Returns:
(模拟列表, 总数)
"""
query = select(ProfitSimulation).options(
selectinload(ProfitSimulation.pricing_plan).selectinload(PricingPlan.project)
)
if pricing_plan_id:
query = query.where(ProfitSimulation.pricing_plan_id == pricing_plan_id)
if period_type:
query = query.where(ProfitSimulation.period_type == period_type.value)
# 计算总数
count_query = select(func.count()).select_from(
query.subquery()
)
total_result = await self.db.execute(count_query)
total = total_result.scalar() or 0
# 分页
query = query.order_by(ProfitSimulation.created_at.desc())
query = query.offset((page - 1) * page_size).limit(page_size)
result = await self.db.execute(query)
simulations = result.scalars().all()
return simulations, total
async def delete_simulation(self, simulation_id: int) -> bool:
"""删除模拟记录"""
result = await self.db.execute(
select(ProfitSimulation).where(ProfitSimulation.id == simulation_id)
)
simulation = result.scalar_one_or_none()
if simulation:
await self.db.delete(simulation)
await self.db.flush()
return True
return False

View File

@@ -0,0 +1,29 @@
"""AI 提示词模块
遵循瑞小美 AI 接入规范:
- 文件位置:{模块}/后端服务/prompts/{功能名}_prompts.py
- 每个文件必须包含PROMPT_META, SYSTEM_PROMPT, USER_PROMPT
"""
from prompts.pricing_advice_prompts import (
PROMPT_META as PRICING_ADVICE_META,
SYSTEM_PROMPT as PRICING_ADVICE_SYSTEM,
USER_PROMPT as PRICING_ADVICE_USER,
)
from prompts.market_analysis_prompts import (
PROMPT_META as MARKET_ANALYSIS_META,
SYSTEM_PROMPT as MARKET_ANALYSIS_SYSTEM,
USER_PROMPT as MARKET_ANALYSIS_USER,
)
from prompts.profit_forecast_prompts import (
PROMPT_META as PROFIT_FORECAST_META,
SYSTEM_PROMPT as PROFIT_FORECAST_SYSTEM,
USER_PROMPT as PROFIT_FORECAST_USER,
)
# 导出所有提示词元数据
ALL_PROMPT_METAS = [
PRICING_ADVICE_META,
MARKET_ANALYSIS_META,
PROFIT_FORECAST_META,
]

View File

@@ -0,0 +1,63 @@
"""市场分析报告提示词
分析市场数据,生成市场分析报告
"""
# 提示词元数据(必须包含)
PROMPT_META = {
"name": "market_analysis",
"display_name": "市场分析报告",
"description": "分析竞品价格和市场行情,生成市场分析报告",
"module": "pricing_model",
"variables": ["project_name", "competitor_data", "benchmark_data"],
}
# 系统提示词(必须包含)
SYSTEM_PROMPT = """你是一位专业的医美行业市场分析师,擅长分析竞争格局和价格趋势。
你需要根据提供的竞品价格数据和标杆机构数据,分析市场定价情况,给出市场洞察。
分析维度:
1. **价格分布**:分析市场价格的分布特征(高、中、低端)
2. **竞争格局**:识别主要竞争对手的定价策略
3. **标杆对比**:与行业标杆进行对比分析
4. **趋势判断**:分析价格变化趋势
5. **机会识别**:发现市场空白和定价机会
输出要求:
- 使用数据支撑分析结论
- 提供可视化友好的结构
- 给出具体的市场建议"""
# 用户提示词模板(必须包含)
USER_PROMPT = """请分析以下医美项目的市场行情:
## 项目名称
{project_name}
## 竞品价格数据
{competitor_data}
## 标杆机构参考
{benchmark_data}
---
请给出以下分析内容:
### 1. 市场价格概览
- 价格区间
- 均价和中位价
- 价格分布特征
### 2. 竞争格局分析
- 主要竞争对手定位
- 定价策略特点
### 3. 市场定位建议
- 建议的定价区间
- 定价策略建议
### 4. 市场机会与风险
- 潜在机会
- 需要关注的风险"""

View File

@@ -0,0 +1,66 @@
"""定价建议生成提示词
综合成本、市场、目标利润率,生成项目定价建议
"""
# 提示词元数据(必须包含)
PROMPT_META = {
"name": "pricing_advice",
"display_name": "智能定价建议",
"description": "综合成本、市场、目标利润率,生成项目定价建议",
"module": "pricing_model",
"variables": ["project_name", "cost_data", "market_data", "target_margin"],
}
# 系统提示词(必须包含)
SYSTEM_PROMPT = """你是一位专业的医美行业定价分析师,拥有丰富的市场分析和定价策略经验。
你需要根据提供的成本数据、市场行情数据,结合目标利润率,给出专业的定价建议。
分析时请考虑以下维度:
1. **成本结构分析**:评估成本构成的合理性,识别成本优化空间
2. **市场竞争态势**:分析竞品定价分布,确定市场位置
3. **目标客群定位**:根据定价策略匹配目标客群
4. **风险评估**:识别定价可能面临的风险和挑战
输出要求:
- 使用清晰的结构化格式
- 提供具体的数字建议
- 包含不同策略的对比分析
- 给出风险提示和注意事项
请用专业但易懂的语言回复,避免过于学术化的表达。"""
# 用户提示词模板(必须包含)
USER_PROMPT = """请为以下医美项目生成定价建议:
## 项目信息
**项目名称**{project_name}
## 成本数据
{cost_data}
## 市场行情
{market_data}
## 目标毛利率
{target_margin}%
---
请给出以下内容:
### 1. 定价建议区间
分析成本和市场数据,给出合理的定价区间。
### 2. 策略定价建议
针对三种定价策略给出具体价格:
- **引流款**:低价引流,适合获客
- **利润款**:平衡利润与竞争力
- **高端款**:高端定位,高利润
### 3. 推荐方案
给出最推荐的定价方案及理由。
### 4. 风险提示
指出定价时需要注意的风险和问题。"""

View File

@@ -0,0 +1,69 @@
"""利润预测分析提示词
分析利润趋势与风险,提供经营建议
"""
# 提示词元数据(必须包含)
PROMPT_META = {
"name": "profit_forecast",
"display_name": "利润预测分析",
"description": "分析利润趋势与风险,提供经营建议",
"module": "pricing_model",
"variables": ["project_name", "pricing_data", "simulation_data", "sensitivity_data"],
}
# 系统提示词(必须包含)
SYSTEM_PROMPT = """你是一位专业的医美行业财务分析师,擅长利润分析和经营预测。
你需要根据提供的定价数据、利润模拟数据和敏感性分析数据,评估项目的盈利能力和风险。
分析维度:
1. **盈利能力评估**:评估项目的利润水平和利润率
2. **盈亏平衡分析**:分析达到盈亏平衡所需的客量
3. **敏感性分析**:分析价格变动对利润的影响
4. **风险评估**:识别可能影响利润的风险因素
5. **优化建议**:提供提升利润的具体建议
输出要求:
- 使用清晰的数据对比
- 提供具体的经营建议
- 给出风险预警信息"""
# 用户提示词模板(必须包含)
USER_PROMPT = """请分析以下医美项目的利润预测:
## 项目名称
{project_name}
## 定价数据
{pricing_data}
## 利润模拟结果
{simulation_data}
## 敏感性分析
{sensitivity_data}
---
请给出以下分析内容:
### 1. 盈利能力评估
- 预估利润水平
- 利润率分析
- 与行业对比
### 2. 盈亏平衡分析
- 盈亏平衡点客量
- 安全边际评估
- 达标难度分析
### 3. 敏感性分析解读
- 价格弹性分析
- 关键影响因素
- 临界点提醒
### 4. 经营建议
- 定价优化建议
- 成本控制建议
- 风险防范建议"""

36
后端服务/pytest.ini Normal file
View File

@@ -0,0 +1,36 @@
[pytest]
# pytest 配置
# 遵循瑞小美系统技术栈标准
# 测试路径
testpaths = tests
# 异步测试配置pytest-asyncio 0.23+
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function
# 测试输出配置
addopts =
-v
--strict-markers
--tb=short
-p no:warnings
# 测试标记
markers =
unit: 单元测试
integration: 集成测试
slow: 慢速测试
api: API 测试
# 日志配置
log_cli = true
log_cli_level = INFO
log_cli_format = %(asctime)s [%(levelname)s] %(message)s
log_cli_date_format = %Y-%m-%d %H:%M:%S
# 最小版本
minversion = 7.4
# Python 路径
pythonpath = .

View File

@@ -0,0 +1,41 @@
# FastAPI 框架
fastapi==0.109.2
uvicorn[standard]==0.27.1
python-multipart==0.0.9
# 数据库
sqlalchemy[asyncio]==2.0.25
aiomysql==0.2.0
pymysql==1.1.0
# 配置管理
pydantic==2.6.1
pydantic-settings==2.1.0
# 数据处理
python-dateutil==2.8.2
openpyxl==3.1.2
# HTTP 客户端
httpx==0.26.0
aiohttp==3.9.3
# 工具库
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
# 日志
structlog==24.1.0
# 测试
pytest==7.4.4
pytest-asyncio==0.23.4
pytest-cov==4.1.0
httpx==0.26.0
aiosqlite==0.19.0
faker==22.5.1
# 代码质量
black==24.1.1
isort==5.13.2
flake8==7.0.0

40
后端服务/run_tests.sh Executable file
View File

@@ -0,0 +1,40 @@
#!/bin/bash
# 测试运行脚本
# 遵循瑞小美系统技术栈标准
set -e
echo "=========================================="
echo "智能项目定价模型 - 测试套件"
echo "=========================================="
# 检查虚拟环境
if [ -z "$VIRTUAL_ENV" ]; then
echo "建议在虚拟环境中运行测试"
fi
# 安装测试依赖
echo ""
echo "[1/4] 安装测试依赖..."
pip install -q pytest pytest-asyncio pytest-cov aiosqlite faker httpx
# 运行单元测试
echo ""
echo "[2/4] 运行单元测试..."
pytest tests/test_services/ -v --tb=short -m "unit" || true
# 运行 API 集成测试
echo ""
echo "[3/4] 运行 API 集成测试..."
pytest tests/test_api/ -v --tb=short -m "api" || true
# 生成覆盖率报告
echo ""
echo "[4/4] 生成覆盖率报告..."
pytest tests/ --cov=app --cov-report=term-missing --cov-report=html:coverage_report || true
echo ""
echo "=========================================="
echo "测试完成!"
echo "覆盖率报告: coverage_report/index.html"
echo "=========================================="

View File

@@ -0,0 +1,5 @@
"""智能项目定价模型 - 测试模块
遵循瑞小美系统技术栈标准
测试框架: pytest + pytest-asyncio
"""

View File

@@ -0,0 +1,349 @@
"""pytest 测试配置
提供测试所需的 fixtures
- 测试数据库会话
- 测试客户端
- 测试数据工厂
遵循瑞小美系统技术栈标准
"""
import os
from decimal import Decimal
from datetime import datetime
from typing import AsyncGenerator
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.pool import StaticPool
# 设置测试环境变量
os.environ["APP_ENV"] = "test"
os.environ["DEBUG"] = "true"
os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
from app.main import app
from app.database import Base, get_db
from app.models import (
Category, Material, Equipment, StaffLevel, FixedCost,
Project, ProjectCostItem, ProjectLaborCost, ProjectCostSummary,
Competitor, CompetitorPrice, BenchmarkPrice, MarketAnalysisResult,
PricingPlan, ProfitSimulation, SensitivityAnalysis
)
# 测试数据库引擎(使用 SQLite 内存数据库)
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
test_engine = create_async_engine(
TEST_DATABASE_URL,
echo=False,
poolclass=StaticPool,
connect_args={"check_same_thread": False},
)
test_session_maker = async_sessionmaker(
test_engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
@pytest_asyncio.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
"""获取测试数据库会话
每个测试函数使用独立的数据库会话,测试后自动回滚
"""
# 创建表
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async with test_session_maker() as session:
try:
yield session
finally:
await session.rollback()
await session.close()
# 清理表
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest_asyncio.fixture(scope="function")
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
"""获取测试客户端
使用测试数据库会话替换应用的数据库依赖
"""
async def override_get_db():
yield db_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
app.dependency_overrides.clear()
# ============ 测试数据工厂 ============
@pytest_asyncio.fixture
async def sample_category(db_session: AsyncSession) -> Category:
"""创建示例项目分类"""
category = Category(
category_name="光电类",
parent_id=None,
sort_order=1,
is_active=True,
)
db_session.add(category)
await db_session.commit()
await db_session.refresh(category)
return category
@pytest_asyncio.fixture
async def sample_material(db_session: AsyncSession) -> Material:
"""创建示例耗材"""
material = Material(
material_code="MAT001",
material_name="冷凝胶",
unit="ml",
unit_price=Decimal("2.00"),
supplier="供应商A",
material_type="consumable",
is_active=True,
)
db_session.add(material)
await db_session.commit()
await db_session.refresh(material)
return material
@pytest_asyncio.fixture
async def sample_equipment(db_session: AsyncSession) -> Equipment:
"""创建示例设备"""
equipment = Equipment(
equipment_code="EQP001",
equipment_name="光子仪",
original_value=Decimal("100000.00"),
residual_rate=Decimal("5.00"),
service_years=5,
estimated_uses=2000,
depreciation_per_use=Decimal("47.50"), # (100000 - 5000) / 2000
purchase_date=datetime(2025, 1, 1).date(),
is_active=True,
)
db_session.add(equipment)
await db_session.commit()
await db_session.refresh(equipment)
return equipment
@pytest_asyncio.fixture
async def sample_staff_level(db_session: AsyncSession) -> StaffLevel:
"""创建示例人员级别"""
staff_level = StaffLevel(
level_code="L2",
level_name="中级美容师",
hourly_rate=Decimal("50.00"),
is_active=True,
)
db_session.add(staff_level)
await db_session.commit()
await db_session.refresh(staff_level)
return staff_level
@pytest_asyncio.fixture
async def sample_fixed_cost(db_session: AsyncSession) -> FixedCost:
"""创建示例固定成本"""
fixed_cost = FixedCost(
cost_name="房租",
cost_type="rent",
monthly_amount=Decimal("30000.00"),
year_month=datetime.now().strftime("%Y-%m"),
allocation_method="count",
is_active=True,
)
db_session.add(fixed_cost)
await db_session.commit()
await db_session.refresh(fixed_cost)
return fixed_cost
@pytest_asyncio.fixture
async def sample_project(
db_session: AsyncSession,
sample_category: Category
) -> Project:
"""创建示例服务项目"""
project = Project(
project_code="PRJ001",
project_name="光子嫩肤",
category_id=sample_category.id,
description="IPL光子嫩肤项目",
duration_minutes=60,
is_active=True,
)
db_session.add(project)
await db_session.commit()
await db_session.refresh(project)
return project
@pytest_asyncio.fixture
async def sample_project_with_costs(
db_session: AsyncSession,
sample_project: Project,
sample_material: Material,
sample_equipment: Equipment,
sample_staff_level: StaffLevel,
sample_fixed_cost: FixedCost,
) -> Project:
"""创建带成本明细的示例项目"""
# 添加耗材成本
cost_item = ProjectCostItem(
project_id=sample_project.id,
item_type="material",
item_id=sample_material.id,
quantity=Decimal("20"),
unit_cost=Decimal("2.00"),
total_cost=Decimal("40.00"),
)
db_session.add(cost_item)
# 添加设备折旧成本
equip_cost = ProjectCostItem(
project_id=sample_project.id,
item_type="equipment",
item_id=sample_equipment.id,
quantity=Decimal("1"),
unit_cost=Decimal("47.50"),
total_cost=Decimal("47.50"),
)
db_session.add(equip_cost)
# 添加人工成本
labor_cost = ProjectLaborCost(
project_id=sample_project.id,
staff_level_id=sample_staff_level.id,
duration_minutes=60,
hourly_rate=Decimal("50.00"),
labor_cost=Decimal("50.00"), # 60分钟 / 60 * 50
)
db_session.add(labor_cost)
await db_session.commit()
await db_session.refresh(sample_project)
return sample_project
@pytest_asyncio.fixture
async def sample_competitor(db_session: AsyncSession) -> Competitor:
"""创建示例竞品机构"""
competitor = Competitor(
competitor_name="美丽人生医美",
address="XX市XX路100号",
distance_km=Decimal("2.5"),
positioning="medium",
contact="13800138000",
is_key_competitor=True,
is_active=True,
)
db_session.add(competitor)
await db_session.commit()
await db_session.refresh(competitor)
return competitor
@pytest_asyncio.fixture
async def sample_competitor_price(
db_session: AsyncSession,
sample_competitor: Competitor,
sample_project: Project,
) -> CompetitorPrice:
"""创建示例竞品价格"""
price = CompetitorPrice(
competitor_id=sample_competitor.id,
project_id=sample_project.id,
project_name="光子嫩肤",
original_price=Decimal("680.00"),
promo_price=Decimal("480.00"),
member_price=Decimal("580.00"),
price_source="meituan",
collected_at=datetime.now().date(),
)
db_session.add(price)
await db_session.commit()
await db_session.refresh(price)
return price
@pytest_asyncio.fixture
async def sample_pricing_plan(
db_session: AsyncSession,
sample_project: Project,
) -> PricingPlan:
"""创建示例定价方案"""
plan = PricingPlan(
project_id=sample_project.id,
plan_name="2026年Q1定价",
strategy_type="profit",
base_cost=Decimal("280.50"),
target_margin=Decimal("50.00"),
suggested_price=Decimal("561.00"),
final_price=Decimal("580.00"),
is_active=True,
)
db_session.add(plan)
await db_session.commit()
await db_session.refresh(plan)
return plan
@pytest_asyncio.fixture
async def sample_cost_summary(
db_session: AsyncSession,
sample_project: Project,
) -> ProjectCostSummary:
"""创建示例成本汇总"""
cost_summary = ProjectCostSummary(
project_id=sample_project.id,
material_cost=Decimal("100.00"),
equipment_cost=Decimal("50.00"),
labor_cost=Decimal("50.00"),
fixed_cost_allocation=Decimal("30.00"),
total_cost=Decimal("230.00"),
calculated_at=datetime.now(), # 确保设置计算时间
)
db_session.add(cost_summary)
await db_session.commit()
await db_session.refresh(cost_summary)
return cost_summary
# ============ 辅助函数 ============
def assert_response_success(response, expected_code=0):
"""断言响应成功"""
assert response.status_code == 200
data = response.json()
assert data["code"] == expected_code
assert "data" in data
return data["data"]
def assert_response_error(response, expected_code):
"""断言响应错误"""
data = response.json()
assert data["code"] == expected_code
return data

View File

@@ -0,0 +1 @@
"""API 层集成测试"""

View File

@@ -0,0 +1,107 @@
"""项目分类接口测试"""
import pytest
from httpx import AsyncClient
from app.models import Category
from tests.conftest import assert_response_success
class TestCategoriesAPI:
"""项目分类 API 测试"""
@pytest.mark.api
@pytest.mark.asyncio
async def test_create_category(self, client: AsyncClient):
"""测试创建分类"""
response = await client.post(
"/api/v1/categories",
json={
"category_name": "测试分类",
"sort_order": 1,
"is_active": True
}
)
data = assert_response_success(response)
assert data["category_name"] == "测试分类"
assert data["sort_order"] == 1
assert data["is_active"] is True
assert "id" in data
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_categories_list(
self,
client: AsyncClient,
sample_category: Category
):
"""测试获取分类列表"""
response = await client.get("/api/v1/categories")
data = assert_response_success(response)
assert "items" in data
assert data["total"] >= 1
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_category_by_id(
self,
client: AsyncClient,
sample_category: Category
):
"""测试获取单个分类"""
response = await client.get(f"/api/v1/categories/{sample_category.id}")
data = assert_response_success(response)
assert data["id"] == sample_category.id
assert data["category_name"] == sample_category.category_name
@pytest.mark.api
@pytest.mark.asyncio
async def test_update_category(
self,
client: AsyncClient,
sample_category: Category
):
"""测试更新分类"""
response = await client.put(
f"/api/v1/categories/{sample_category.id}",
json={
"category_name": "更新后分类",
"sort_order": 99
}
)
data = assert_response_success(response)
assert data["category_name"] == "更新后分类"
assert data["sort_order"] == 99
@pytest.mark.api
@pytest.mark.asyncio
async def test_delete_category(self, client: AsyncClient):
"""测试删除分类"""
# 先创建
create_response = await client.post(
"/api/v1/categories",
json={"category_name": "待删除分类"}
)
created = assert_response_success(create_response)
# 删除
delete_response = await client.delete(
f"/api/v1/categories/{created['id']}"
)
assert delete_response.status_code == 200
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_nonexistent_category(self, client: AsyncClient):
"""测试获取不存在的分类"""
response = await client.get("/api/v1/categories/99999")
# API 返回 HTTP 404 + 错误详情
assert response.status_code == 404
data = response.json()
assert data["detail"]["code"] == 10002 # 数据不存在

View File

@@ -0,0 +1,31 @@
"""健康检查接口测试"""
import pytest
from httpx import AsyncClient
class TestHealthAPI:
"""健康检查 API 测试"""
@pytest.mark.api
@pytest.mark.asyncio
async def test_health_check(self, client: AsyncClient):
"""测试健康检查端点"""
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["code"] == 0
assert data["data"]["status"] == "healthy"
assert "version" in data["data"]
@pytest.mark.api
@pytest.mark.asyncio
async def test_root_endpoint(self, client: AsyncClient):
"""测试根路径端点"""
response = await client.get("/")
assert response.status_code == 200
data = response.json()
assert data["code"] == 0
assert "智能项目定价模型" in data["data"]["name"]

View File

@@ -0,0 +1,113 @@
"""市场行情接口测试"""
import pytest
from httpx import AsyncClient
from app.models import Project, Competitor, CompetitorPrice
from tests.conftest import assert_response_success
class TestMarketAPI:
"""市场行情 API 测试"""
@pytest.mark.api
@pytest.mark.asyncio
async def test_create_competitor(self, client: AsyncClient):
"""测试创建竞品机构"""
response = await client.post(
"/api/v1/competitors",
json={
"competitor_name": "测试竞品",
"address": "测试地址",
"distance_km": 3.5,
"positioning": "medium",
"is_key_competitor": True
}
)
data = assert_response_success(response)
assert data["competitor_name"] == "测试竞品"
assert float(data["distance_km"]) == 3.5
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_competitors_list(
self,
client: AsyncClient,
sample_competitor: Competitor
):
"""测试获取竞品列表"""
response = await client.get("/api/v1/competitors")
data = assert_response_success(response)
assert "items" in data
assert data["total"] >= 1
@pytest.mark.api
@pytest.mark.asyncio
async def test_add_competitor_price(
self,
client: AsyncClient,
sample_competitor: Competitor,
sample_project: Project
):
"""测试添加竞品价格"""
response = await client.post(
f"/api/v1/competitors/{sample_competitor.id}/prices",
json={
"project_id": sample_project.id,
"project_name": "光子嫩肤",
"original_price": 800.00,
"promo_price": 600.00,
"price_source": "meituan",
"collected_at": "2026-01-20"
}
)
data = assert_response_success(response)
assert float(data["original_price"]) == 800.00
assert float(data["promo_price"]) == 600.00
@pytest.mark.api
@pytest.mark.asyncio
async def test_market_analysis(
self,
client: AsyncClient,
sample_project: Project,
sample_competitor_price: CompetitorPrice
):
"""测试市场分析"""
response = await client.post(
f"/api/v1/projects/{sample_project.id}/market-analysis",
json={
"include_benchmark": False
}
)
data = assert_response_success(response)
assert data["project_id"] == sample_project.id
assert "price_statistics" in data
assert "suggested_range" in data
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_market_analysis(
self,
client: AsyncClient,
sample_project: Project,
sample_competitor_price: CompetitorPrice
):
"""测试获取市场分析结果"""
# 先执行分析
await client.post(
f"/api/v1/projects/{sample_project.id}/market-analysis",
json={"include_benchmark": False}
)
# 获取结果
response = await client.get(
f"/api/v1/projects/{sample_project.id}/market-analysis"
)
data = assert_response_success(response)
assert data["project_id"] == sample_project.id

View File

@@ -0,0 +1,106 @@
"""耗材管理接口测试"""
import pytest
from httpx import AsyncClient
from app.models import Material
from tests.conftest import assert_response_success
class TestMaterialsAPI:
"""耗材管理 API 测试"""
@pytest.mark.api
@pytest.mark.asyncio
async def test_create_material(self, client: AsyncClient):
"""测试创建耗材"""
response = await client.post(
"/api/v1/materials",
json={
"material_code": "MAT_TEST001",
"material_name": "测试耗材",
"unit": "",
"unit_price": 10.50,
"supplier": "测试供应商",
"material_type": "consumable"
}
)
data = assert_response_success(response)
assert data["material_code"] == "MAT_TEST001"
assert data["material_name"] == "测试耗材"
assert float(data["unit_price"]) == 10.50
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_materials_list(
self,
client: AsyncClient,
sample_material: Material
):
"""测试获取耗材列表"""
response = await client.get("/api/v1/materials")
data = assert_response_success(response)
assert "items" in data
assert data["total"] >= 1
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_materials_with_filter(
self,
client: AsyncClient,
sample_material: Material
):
"""测试带筛选的耗材列表"""
response = await client.get(
"/api/v1/materials",
params={"material_type": "consumable"}
)
data = assert_response_success(response)
assert data["total"] >= 1
@pytest.mark.api
@pytest.mark.asyncio
async def test_update_material(
self,
client: AsyncClient,
sample_material: Material
):
"""测试更新耗材"""
response = await client.put(
f"/api/v1/materials/{sample_material.id}",
json={
"unit_price": 3.00,
"supplier": "新供应商"
}
)
data = assert_response_success(response)
assert float(data["unit_price"]) == 3.00
assert data["supplier"] == "新供应商"
@pytest.mark.api
@pytest.mark.asyncio
async def test_create_duplicate_material_code(
self,
client: AsyncClient,
sample_material: Material
):
"""测试创建重复编码的耗材"""
response = await client.post(
"/api/v1/materials",
json={
"material_code": sample_material.material_code,
"material_name": "重复编码耗材",
"unit": "",
"unit_price": 10.00,
"material_type": "consumable"
}
)
# API 返回 HTTP 400 + 错误详情
assert response.status_code == 400
data = response.json()
assert data["detail"]["code"] == 10003 # 数据已存在

View File

@@ -0,0 +1,134 @@
"""智能定价接口测试"""
import pytest
from httpx import AsyncClient
from decimal import Decimal
from datetime import datetime
from app.models import Project, ProjectCostSummary, PricingPlan
from tests.conftest import assert_response_success
class TestPricingAPI:
"""智能定价 API 测试"""
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_pricing_plans_list(
self,
client: AsyncClient,
sample_pricing_plan: PricingPlan
):
"""测试获取定价方案列表"""
response = await client.get("/api/v1/pricing-plans")
data = assert_response_success(response)
assert "items" in data
assert data["total"] >= 1
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_pricing_plan_detail(
self,
client: AsyncClient,
sample_pricing_plan: PricingPlan
):
"""测试获取定价方案详情"""
response = await client.get(
f"/api/v1/pricing-plans/{sample_pricing_plan.id}"
)
data = assert_response_success(response)
assert data["id"] == sample_pricing_plan.id
assert data["plan_name"] == sample_pricing_plan.plan_name
@pytest.mark.api
@pytest.mark.asyncio
async def test_create_pricing_plan(
self,
client: AsyncClient,
db_session,
sample_project: Project
):
"""测试创建定价方案"""
# 先添加成本汇总
cost_summary = ProjectCostSummary(
project_id=sample_project.id,
material_cost=Decimal("100"),
equipment_cost=Decimal("50"),
labor_cost=Decimal("50"),
fixed_cost_allocation=Decimal("0"),
total_cost=Decimal("200"),
calculated_at=datetime.now()
)
db_session.add(cost_summary)
await db_session.commit()
response = await client.post(
"/api/v1/pricing-plans",
json={
"project_id": sample_project.id,
"plan_name": "测试定价方案",
"strategy_type": "profit",
"target_margin": 50
}
)
data = assert_response_success(response)
assert data["project_id"] == sample_project.id
assert data["plan_name"] == "测试定价方案"
assert data["strategy_type"] == "profit"
@pytest.mark.api
@pytest.mark.asyncio
async def test_update_pricing_plan(
self,
client: AsyncClient,
sample_pricing_plan: PricingPlan
):
"""测试更新定价方案"""
response = await client.put(
f"/api/v1/pricing-plans/{sample_pricing_plan.id}",
json={
"final_price": 599.00,
"plan_name": "更新后方案"
}
)
data = assert_response_success(response)
assert float(data["final_price"]) == 599.00
assert data["plan_name"] == "更新后方案"
@pytest.mark.api
@pytest.mark.asyncio
async def test_simulate_strategy(
self,
client: AsyncClient,
db_session,
sample_project: Project
):
"""测试策略模拟"""
# 添加成本汇总
cost_summary = ProjectCostSummary(
project_id=sample_project.id,
total_cost=Decimal("200"),
material_cost=Decimal("100"),
equipment_cost=Decimal("50"),
labor_cost=Decimal("50"),
fixed_cost_allocation=Decimal("0"),
calculated_at=datetime.now()
)
db_session.add(cost_summary)
await db_session.commit()
response = await client.post(
f"/api/v1/projects/{sample_project.id}/simulate-strategy",
json={
"strategies": ["traffic", "profit", "premium"],
"target_margin": 50
}
)
data = assert_response_success(response)
assert data["project_id"] == sample_project.id
assert len(data["results"]) == 3

View File

@@ -0,0 +1,133 @@
"""利润模拟接口测试"""
import pytest
from httpx import AsyncClient
from app.models import PricingPlan
from tests.conftest import assert_response_success
class TestProfitAPI:
"""利润模拟 API 测试"""
@pytest.mark.api
@pytest.mark.asyncio
async def test_simulate_profit(
self,
client: AsyncClient,
sample_pricing_plan: PricingPlan
):
"""测试利润模拟"""
response = await client.post(
f"/api/v1/pricing-plans/{sample_pricing_plan.id}/simulate-profit",
json={
"price": 580.00,
"estimated_volume": 100,
"period_type": "monthly"
}
)
data = assert_response_success(response)
assert data["pricing_plan_id"] == sample_pricing_plan.id
assert "input" in data
assert "result" in data
assert "breakeven_analysis" in data
assert data["input"]["price"] == 580.00
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_profit_simulations_list(
self,
client: AsyncClient,
sample_pricing_plan: PricingPlan
):
"""测试获取模拟列表"""
# 先创建一个模拟
await client.post(
f"/api/v1/pricing-plans/{sample_pricing_plan.id}/simulate-profit",
json={
"price": 580.00,
"estimated_volume": 100,
"period_type": "monthly"
}
)
response = await client.get("/api/v1/profit-simulations")
data = assert_response_success(response)
assert "items" in data
assert data["total"] >= 1
@pytest.mark.api
@pytest.mark.asyncio
async def test_sensitivity_analysis(
self,
client: AsyncClient,
sample_pricing_plan: PricingPlan
):
"""测试敏感性分析"""
# 先创建模拟
sim_response = await client.post(
f"/api/v1/pricing-plans/{sample_pricing_plan.id}/simulate-profit",
json={
"price": 580.00,
"estimated_volume": 100,
"period_type": "monthly"
}
)
sim_data = assert_response_success(sim_response)
# 执行敏感性分析
response = await client.post(
f"/api/v1/profit-simulations/{sim_data['simulation_id']}/sensitivity",
json={
"price_change_rates": [-20, -10, 0, 10, 20]
}
)
data = assert_response_success(response)
assert len(data["sensitivity_results"]) == 5
assert "insights" in data
@pytest.mark.api
@pytest.mark.asyncio
async def test_breakeven_analysis(
self,
client: AsyncClient,
sample_pricing_plan: PricingPlan
):
"""测试盈亏平衡分析"""
response = await client.get(
f"/api/v1/pricing-plans/{sample_pricing_plan.id}/breakeven"
)
data = assert_response_success(response)
assert data["pricing_plan_id"] == sample_pricing_plan.id
assert "breakeven_volume" in data
assert "current_margin" in data
@pytest.mark.api
@pytest.mark.asyncio
async def test_delete_simulation(
self,
client: AsyncClient,
sample_pricing_plan: PricingPlan
):
"""测试删除模拟"""
# 创建模拟
sim_response = await client.post(
f"/api/v1/pricing-plans/{sample_pricing_plan.id}/simulate-profit",
json={
"price": 580.00,
"estimated_volume": 100,
"period_type": "monthly"
}
)
sim_data = assert_response_success(sim_response)
# 删除
delete_response = await client.delete(
f"/api/v1/profit-simulations/{sim_data['simulation_id']}"
)
assert delete_response.status_code == 200

View File

@@ -0,0 +1,152 @@
"""服务项目接口测试"""
import pytest
from httpx import AsyncClient
from app.models import Project, Category, Material, Equipment, StaffLevel
from tests.conftest import assert_response_success
class TestProjectsAPI:
"""服务项目 API 测试"""
@pytest.mark.api
@pytest.mark.asyncio
async def test_create_project(
self,
client: AsyncClient,
sample_category: Category
):
"""测试创建项目"""
response = await client.post(
"/api/v1/projects",
json={
"project_code": "PRJ_TEST001",
"project_name": "测试项目",
"category_id": sample_category.id,
"description": "测试项目描述",
"duration_minutes": 45,
"is_active": True
}
)
data = assert_response_success(response)
assert data["project_code"] == "PRJ_TEST001"
assert data["project_name"] == "测试项目"
assert data["duration_minutes"] == 45
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_projects_list(
self,
client: AsyncClient,
sample_project: Project
):
"""测试获取项目列表"""
response = await client.get("/api/v1/projects")
data = assert_response_success(response)
assert "items" in data
assert data["total"] >= 1
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_project_detail(
self,
client: AsyncClient,
sample_project: Project
):
"""测试获取项目详情"""
response = await client.get(f"/api/v1/projects/{sample_project.id}")
data = assert_response_success(response)
assert data["id"] == sample_project.id
assert data["project_name"] == sample_project.project_name
@pytest.mark.api
@pytest.mark.asyncio
async def test_add_cost_item(
self,
client: AsyncClient,
sample_project: Project,
sample_material: Material
):
"""测试添加成本明细"""
response = await client.post(
f"/api/v1/projects/{sample_project.id}/cost-items",
json={
"item_type": "material",
"item_id": sample_material.id,
"quantity": 5,
"remark": "测试备注"
}
)
data = assert_response_success(response)
assert data["item_type"] == "material"
assert float(data["quantity"]) == 5.0
assert float(data["total_cost"]) == 10.0 # 5 * 2.00
@pytest.mark.api
@pytest.mark.asyncio
async def test_add_labor_cost(
self,
client: AsyncClient,
sample_project: Project,
sample_staff_level: StaffLevel
):
"""测试添加人工成本"""
response = await client.post(
f"/api/v1/projects/{sample_project.id}/labor-costs",
json={
"staff_level_id": sample_staff_level.id,
"duration_minutes": 30
}
)
data = assert_response_success(response)
assert data["duration_minutes"] == 30
assert float(data["labor_cost"]) == 25.0 # 30/60 * 50
@pytest.mark.api
@pytest.mark.asyncio
async def test_calculate_project_cost(
self,
client: AsyncClient,
sample_project_with_costs: Project
):
"""测试计算项目成本"""
response = await client.post(
f"/api/v1/projects/{sample_project_with_costs.id}/calculate-cost",
json={"allocation_method": "count"}
)
data = assert_response_success(response)
assert data["project_id"] == sample_project_with_costs.id
assert "cost_breakdown" in data
assert "total_cost" in data
assert data["total_cost"] > 0
@pytest.mark.api
@pytest.mark.asyncio
async def test_get_cost_summary(
self,
client: AsyncClient,
sample_project_with_costs: Project
):
"""测试获取成本汇总"""
# 先计算成本
await client.post(
f"/api/v1/projects/{sample_project_with_costs.id}/calculate-cost",
json={"allocation_method": "count"}
)
# 获取汇总
response = await client.get(
f"/api/v1/projects/{sample_project_with_costs.id}/cost-summary"
)
data = assert_response_success(response)
assert data["project_id"] == sample_project_with_costs.id
assert "material_cost" in data
assert "total_cost" in data

View File

@@ -0,0 +1 @@
"""服务层单元测试"""

View File

@@ -0,0 +1,415 @@
"""成本计算服务单元测试
测试 CostService 的核心业务逻辑
"""
import pytest
from decimal import Decimal
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.cost_service import CostService
from app.schemas.project_cost import AllocationMethod, CostItemType
from app.models import (
Material, Equipment, StaffLevel, Project, FixedCost,
ProjectCostItem, ProjectLaborCost, ProjectCostSummary
)
class TestCostService:
"""成本服务测试类"""
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_material_info(
self,
db_session: AsyncSession,
sample_material: Material
):
"""测试获取耗材信息"""
service = CostService(db_session)
# 获取存在的耗材
material = await service.get_material_info(sample_material.id)
assert material is not None
assert material.material_name == "冷凝胶"
assert material.unit_price == Decimal("2.00")
# 获取不存在的耗材
material = await service.get_material_info(99999)
assert material is None
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_equipment_info(
self,
db_session: AsyncSession,
sample_equipment: Equipment
):
"""测试获取设备信息"""
service = CostService(db_session)
# 获取存在的设备
equipment = await service.get_equipment_info(sample_equipment.id)
assert equipment is not None
assert equipment.equipment_name == "光子仪"
assert equipment.depreciation_per_use == Decimal("47.50")
# 获取不存在的设备
equipment = await service.get_equipment_info(99999)
assert equipment is None
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_staff_level_info(
self,
db_session: AsyncSession,
sample_staff_level: StaffLevel
):
"""测试获取人员级别信息"""
service = CostService(db_session)
# 获取存在的级别
level = await service.get_staff_level_info(sample_staff_level.id)
assert level is not None
assert level.level_name == "中级美容师"
assert level.hourly_rate == Decimal("50.00")
# 获取不存在的级别
level = await service.get_staff_level_info(99999)
assert level is None
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_material_cost(
self,
db_session: AsyncSession,
sample_project_with_costs: Project
):
"""测试耗材成本计算"""
service = CostService(db_session)
total, breakdown = await service.calculate_material_cost(
sample_project_with_costs.id
)
assert total == Decimal("40.00") # 20 * 2.00
assert len(breakdown) == 1
assert breakdown[0]["name"] == "冷凝胶"
assert breakdown[0]["quantity"] == 20.0
assert breakdown[0]["total"] == 40.0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_equipment_cost(
self,
db_session: AsyncSession,
sample_project_with_costs: Project
):
"""测试设备折旧成本计算"""
service = CostService(db_session)
total, breakdown = await service.calculate_equipment_cost(
sample_project_with_costs.id
)
assert total == Decimal("47.50") # 1 * 47.50
assert len(breakdown) == 1
assert breakdown[0]["name"] == "光子仪"
assert breakdown[0]["depreciation_per_use"] == 47.5
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_labor_cost(
self,
db_session: AsyncSession,
sample_project_with_costs: Project
):
"""测试人工成本计算"""
service = CostService(db_session)
total, breakdown = await service.calculate_labor_cost(
sample_project_with_costs.id
)
assert total == Decimal("50.00") # 60分钟 / 60 * 50
assert len(breakdown) == 1
assert breakdown[0]["name"] == "中级美容师"
assert breakdown[0]["duration_minutes"] == 60
assert breakdown[0]["hourly_rate"] == 50.0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_fixed_cost_allocation_by_count(
self,
db_session: AsyncSession,
sample_project: Project,
sample_fixed_cost: FixedCost
):
"""测试固定成本按项目数量分摊"""
service = CostService(db_session)
allocation, detail = await service.calculate_fixed_cost_allocation(
sample_project.id,
method=AllocationMethod.COUNT
)
# 只有一个项目,分摊全部固定成本
assert allocation == Decimal("30000.00")
assert detail["method"] == "count"
assert detail["project_count"] == 1
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_fixed_cost_allocation_by_duration(
self,
db_session: AsyncSession,
sample_project: Project,
sample_fixed_cost: FixedCost
):
"""测试固定成本按时长分摊"""
service = CostService(db_session)
allocation, detail = await service.calculate_fixed_cost_allocation(
sample_project.id,
method=AllocationMethod.DURATION
)
# 只有一个项目,占比 100%
assert allocation == Decimal("30000.00")
assert detail["method"] == "duration"
assert detail["project_duration"] == 60
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_project_cost(
self,
db_session: AsyncSession,
sample_project_with_costs: Project,
sample_fixed_cost: FixedCost
):
"""测试项目总成本计算"""
service = CostService(db_session)
result = await service.calculate_project_cost(
sample_project_with_costs.id,
allocation_method=AllocationMethod.COUNT
)
assert result.project_id == sample_project_with_costs.id
assert result.project_name == "光子嫩肤"
# 验证成本构成
breakdown = result.cost_breakdown
assert breakdown["material_cost"]["subtotal"] == 40.0
assert breakdown["equipment_cost"]["subtotal"] == 47.5
assert breakdown["labor_cost"]["subtotal"] == 50.0
# 总成本 = 耗材40 + 设备47.5 + 人工50 + 固定30000
expected_total = 40 + 47.5 + 50 + 30000
assert result.total_cost == expected_total
assert result.min_price_suggestion == expected_total
@pytest.mark.unit
@pytest.mark.asyncio
async def test_calculate_project_cost_not_found(
self,
db_session: AsyncSession
):
"""测试项目不存在时的错误处理"""
service = CostService(db_session)
with pytest.raises(ValueError, match="项目不存在"):
await service.calculate_project_cost(99999)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_add_cost_item_material(
self,
db_session: AsyncSession,
sample_project: Project,
sample_material: Material
):
"""测试添加耗材成本明细"""
service = CostService(db_session)
cost_item = await service.add_cost_item(
project_id=sample_project.id,
item_type=CostItemType.MATERIAL,
item_id=sample_material.id,
quantity=10,
remark="测试备注"
)
assert cost_item.project_id == sample_project.id
assert cost_item.item_type == "material"
assert cost_item.quantity == Decimal("10")
assert cost_item.unit_cost == Decimal("2.00")
assert cost_item.total_cost == Decimal("20.00")
assert cost_item.remark == "测试备注"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_add_cost_item_equipment(
self,
db_session: AsyncSession,
sample_project: Project,
sample_equipment: Equipment
):
"""测试添加设备折旧成本明细"""
service = CostService(db_session)
cost_item = await service.add_cost_item(
project_id=sample_project.id,
item_type=CostItemType.EQUIPMENT,
item_id=sample_equipment.id,
quantity=1,
)
assert cost_item.item_type == "equipment"
assert cost_item.unit_cost == Decimal("47.50")
assert cost_item.total_cost == Decimal("47.50")
@pytest.mark.unit
@pytest.mark.asyncio
async def test_add_cost_item_not_found(
self,
db_session: AsyncSession,
sample_project: Project
):
"""测试添加不存在的耗材/设备时的错误处理"""
service = CostService(db_session)
with pytest.raises(ValueError, match="耗材不存在"):
await service.add_cost_item(
project_id=sample_project.id,
item_type=CostItemType.MATERIAL,
item_id=99999,
quantity=1,
)
with pytest.raises(ValueError, match="设备不存在"):
await service.add_cost_item(
project_id=sample_project.id,
item_type=CostItemType.EQUIPMENT,
item_id=99999,
quantity=1,
)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_add_labor_cost(
self,
db_session: AsyncSession,
sample_project: Project,
sample_staff_level: StaffLevel
):
"""测试添加人工成本"""
service = CostService(db_session)
labor_cost = await service.add_labor_cost(
project_id=sample_project.id,
staff_level_id=sample_staff_level.id,
duration_minutes=30,
remark="测试人工"
)
assert labor_cost.project_id == sample_project.id
assert labor_cost.duration_minutes == 30
assert labor_cost.hourly_rate == Decimal("50.00")
# 30分钟 / 60 * 50 = 25
assert labor_cost.labor_cost == Decimal("25.00")
@pytest.mark.unit
@pytest.mark.asyncio
async def test_add_labor_cost_not_found(
self,
db_session: AsyncSession,
sample_project: Project
):
"""测试添加不存在的人员级别时的错误处理"""
service = CostService(db_session)
with pytest.raises(ValueError, match="人员级别不存在"):
await service.add_labor_cost(
project_id=sample_project.id,
staff_level_id=99999,
duration_minutes=30,
)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_update_cost_item(
self,
db_session: AsyncSession,
sample_project: Project,
sample_material: Material
):
"""测试更新成本明细"""
service = CostService(db_session)
# 先添加
cost_item = await service.add_cost_item(
project_id=sample_project.id,
item_type=CostItemType.MATERIAL,
item_id=sample_material.id,
quantity=10,
)
# 更新数量
updated = await service.update_cost_item(
cost_item=cost_item,
quantity=20,
remark="更新后备注"
)
assert updated.quantity == Decimal("20")
assert updated.total_cost == Decimal("40.00") # 20 * 2
assert updated.remark == "更新后备注"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_update_labor_cost(
self,
db_session: AsyncSession,
sample_project: Project,
sample_staff_level: StaffLevel
):
"""测试更新人工成本"""
service = CostService(db_session)
# 先添加
labor = await service.add_labor_cost(
project_id=sample_project.id,
staff_level_id=sample_staff_level.id,
duration_minutes=30,
)
# 更新时长
updated = await service.update_labor_cost(
labor_item=labor,
duration_minutes=60,
)
assert updated.duration_minutes == 60
assert updated.labor_cost == Decimal("50.00") # 60/60 * 50
@pytest.mark.unit
@pytest.mark.asyncio
async def test_empty_project_cost(
self,
db_session: AsyncSession,
sample_project: Project
):
"""测试没有成本明细的项目计算"""
service = CostService(db_session)
# 计算空项目成本(无固定成本)
total_material, _ = await service.calculate_material_cost(sample_project.id)
total_equipment, _ = await service.calculate_equipment_cost(sample_project.id)
total_labor, _ = await service.calculate_labor_cost(sample_project.id)
assert total_material == Decimal("0")
assert total_equipment == Decimal("0")
assert total_labor == Decimal("0")

View File

@@ -0,0 +1,305 @@
"""市场分析服务单元测试
测试 MarketService 的核心业务逻辑
"""
import pytest
from decimal import Decimal
from datetime import date
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.market_service import MarketService
from app.models import (
Project, Competitor, CompetitorPrice, BenchmarkPrice, Category
)
class TestMarketService:
"""市场分析服务测试类"""
@pytest.mark.unit
def test_calculate_price_statistics_empty(self):
"""测试空价格列表的统计"""
service = MarketService(None) # 不需要 db
stats = service.calculate_price_statistics([])
assert stats.min_price == 0
assert stats.max_price == 0
assert stats.avg_price == 0
assert stats.median_price == 0
assert stats.std_deviation is None or stats.std_deviation == 0
@pytest.mark.unit
def test_calculate_price_statistics_single(self):
"""测试单个价格的统计"""
service = MarketService(None)
stats = service.calculate_price_statistics([500.0])
assert stats.min_price == 500.0
assert stats.max_price == 500.0
assert stats.avg_price == 500.0
assert stats.median_price == 500.0
@pytest.mark.unit
def test_calculate_price_statistics_multiple(self):
"""测试多个价格的统计"""
service = MarketService(None)
prices = [300.0, 400.0, 500.0, 600.0, 700.0]
stats = service.calculate_price_statistics(prices)
assert stats.min_price == 300.0
assert stats.max_price == 700.0
assert stats.avg_price == 500.0 # (300+400+500+600+700)/5
assert stats.median_price == 500.0 # 中位数
assert stats.std_deviation is not None
assert stats.std_deviation > 0
@pytest.mark.unit
def test_calculate_price_distribution(self):
"""测试价格分布计算"""
service = MarketService(None)
# 价格范围 300-900分为三个区间
# 低: 300-500, 中: 500-700, 高: 700-900
prices = [350.0, 450.0, 550.0, 650.0, 750.0, 850.0]
distribution = service.calculate_price_distribution(
prices=prices,
min_price=300.0,
max_price=900.0
)
# 验证分布
assert distribution.low.count == 2 # 350, 450
assert distribution.medium.count == 2 # 550, 650
assert distribution.high.count == 2 # 750, 850
# 验证百分比
assert distribution.low.percentage == pytest.approx(33.3, rel=0.1)
assert distribution.medium.percentage == pytest.approx(33.3, rel=0.1)
assert distribution.high.percentage == pytest.approx(33.3, rel=0.1)
@pytest.mark.unit
def test_calculate_price_distribution_empty(self):
"""测试空价格列表的分布"""
service = MarketService(None)
distribution = service.calculate_price_distribution(
prices=[],
min_price=0,
max_price=0
)
assert distribution.low.count == 0
assert distribution.medium.count == 0
assert distribution.high.count == 0
@pytest.mark.unit
def test_calculate_suggested_range(self):
"""测试建议定价区间计算"""
service = MarketService(None)
suggested = service.calculate_suggested_range(
avg_price=500.0,
min_price=300.0,
max_price=700.0,
benchmark_avg=None
)
# 以均价为中心 ±20%
assert suggested.min == pytest.approx(400.0, rel=0.01) # 500 * 0.8
assert suggested.max == pytest.approx(600.0, rel=0.01) # 500 * 1.2
assert suggested.recommended == 500.0
@pytest.mark.unit
def test_calculate_suggested_range_with_benchmark(self):
"""测试带标杆参考的建议定价区间"""
service = MarketService(None)
suggested = service.calculate_suggested_range(
avg_price=500.0,
min_price=300.0,
max_price=700.0,
benchmark_avg=600.0
)
# 推荐价格 = 市场均价 * 0.6 + 标杆均价 * 0.4
expected_recommended = 500 * 0.6 + 600 * 0.4 # 540
assert suggested.recommended == pytest.approx(expected_recommended, rel=0.01)
@pytest.mark.unit
def test_calculate_suggested_range_zero_avg(self):
"""测试均价为0时的处理"""
service = MarketService(None)
suggested = service.calculate_suggested_range(
avg_price=0,
min_price=0,
max_price=0,
benchmark_avg=None
)
assert suggested.min == 0
assert suggested.max == 0
assert suggested.recommended == 0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_competitor_prices_for_project(
self,
db_session: AsyncSession,
sample_competitor_price: CompetitorPrice,
sample_project: Project
):
"""测试获取项目的竞品价格"""
service = MarketService(db_session)
prices = await service.get_competitor_prices_for_project(
sample_project.id
)
assert len(prices) == 1
assert float(prices[0].original_price) == 680.0
assert float(prices[0].promo_price) == 480.0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_competitor_prices_filter_by_competitor(
self,
db_session: AsyncSession,
sample_competitor_price: CompetitorPrice,
sample_project: Project,
sample_competitor: Competitor
):
"""测试按竞品机构筛选价格"""
service = MarketService(db_session)
# 指定竞品ID
prices = await service.get_competitor_prices_for_project(
sample_project.id,
competitor_ids=[sample_competitor.id]
)
assert len(prices) == 1
# 指定不存在的竞品ID
prices = await service.get_competitor_prices_for_project(
sample_project.id,
competitor_ids=[99999]
)
assert len(prices) == 0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_analyze_market(
self,
db_session: AsyncSession,
sample_project: Project,
sample_competitor_price: CompetitorPrice
):
"""测试市场分析"""
service = MarketService(db_session)
result = await service.analyze_market(
project_id=sample_project.id,
include_benchmark=False
)
assert result.project_id == sample_project.id
assert result.project_name == "光子嫩肤"
assert result.competitor_count == 1
assert result.price_statistics.min_price == 680.0
assert result.price_statistics.max_price == 680.0
assert result.price_statistics.avg_price == 680.0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_analyze_market_not_found(
self,
db_session: AsyncSession
):
"""测试项目不存在时的错误处理"""
service = MarketService(db_session)
with pytest.raises(ValueError, match="项目不存在"):
await service.analyze_market(project_id=99999)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_latest_analysis(
self,
db_session: AsyncSession,
sample_project: Project,
sample_competitor_price: CompetitorPrice
):
"""测试获取最新分析结果"""
service = MarketService(db_session)
# 先执行分析
await service.analyze_market(
project_id=sample_project.id,
include_benchmark=False
)
# 获取最新结果
latest = await service.get_latest_analysis(sample_project.id)
assert latest is not None
assert latest.project_id == sample_project.id
assert float(latest.market_avg_price) == 680.0
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_benchmark_prices_empty(
self,
db_session: AsyncSession
):
"""测试没有标杆价格时的处理"""
service = MarketService(db_session)
benchmarks = await service.get_benchmark_prices_for_category(None)
assert benchmarks == []
benchmarks = await service.get_benchmark_prices_for_category(99999)
assert benchmarks == []
class TestMarketServiceEdgeCases:
"""市场分析服务边界情况测试"""
@pytest.mark.unit
def test_price_distribution_same_min_max(self):
"""测试最小最大价相同时的分布"""
service = MarketService(None)
distribution = service.calculate_price_distribution(
prices=[500.0, 500.0],
min_price=500.0,
max_price=500.0
)
# 应返回 N/A
assert distribution.low.range == "N/A"
@pytest.mark.unit
def test_statistics_with_outliers(self):
"""测试包含极端值的统计"""
service = MarketService(None)
# 包含一个极端高价
prices = [300.0, 400.0, 500.0, 600.0, 5000.0]
stats = service.calculate_price_statistics(prices)
assert stats.min_price == 300.0
assert stats.max_price == 5000.0
# 均值会被拉高
assert stats.avg_price == 1360.0 # (300+400+500+600+5000)/5
# 中位数不受极端值影响
assert stats.median_price == 500.0
# 标准差会很大
assert stats.std_deviation > 1000

View File

@@ -0,0 +1,369 @@
"""智能定价服务单元测试
测试 PricingService 的核心业务逻辑
"""
import pytest
from decimal import Decimal
from datetime import datetime
from unittest.mock import AsyncMock, patch, MagicMock
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.pricing_service import PricingService
from app.schemas.pricing import (
StrategyType, MarketReference, StrategySuggestion, PricingSuggestions
)
from app.models import Project, ProjectCostSummary, PricingPlan
class TestPricingService:
"""智能定价服务测试类"""
@pytest.mark.unit
def test_calculate_strategy_price_traffic(self):
"""测试引流款定价策略"""
service = PricingService(None)
suggestion = service.calculate_strategy_price(
base_cost=100.0,
strategy=StrategyType.TRAFFIC,
)
# 引流款利润率 10%-20%,使用中间值 15%
# 价格 = 100 / (1 - 0.15) ≈ 117.65
assert suggestion.strategy == "引流款"
assert suggestion.suggested_price > 100 # 大于成本
assert suggestion.suggested_price < 130 # 利润率适中
assert suggestion.margin > 0
assert "引流" in suggestion.description
@pytest.mark.unit
def test_calculate_strategy_price_profit(self):
"""测试利润款定价策略"""
service = PricingService(None)
suggestion = service.calculate_strategy_price(
base_cost=100.0,
strategy=StrategyType.PROFIT,
target_margin=50, # 50% 目标毛利率
)
# 价格 = 100 / (1 - 0.5) = 200
assert suggestion.strategy == "利润款"
assert suggestion.suggested_price >= 200
assert suggestion.margin >= 45 # 接近目标
assert "日常" in suggestion.description
@pytest.mark.unit
def test_calculate_strategy_price_premium(self):
"""测试高端款定价策略"""
service = PricingService(None)
suggestion = service.calculate_strategy_price(
base_cost=100.0,
strategy=StrategyType.PREMIUM,
)
# 高端款利润率 60%-80%,使用中间值 70%
# 价格 = 100 / (1 - 0.7) ≈ 333
assert suggestion.strategy == "高端款"
assert suggestion.suggested_price > 300
assert suggestion.margin > 60
assert "高端" in suggestion.description
@pytest.mark.unit
def test_calculate_strategy_price_with_market_reference(self):
"""测试带市场参考的定价"""
service = PricingService(None)
market_ref = MarketReference(min=80.0, max=150.0, avg=100.0)
# 引流款应该参考市场最低价
suggestion = service.calculate_strategy_price(
base_cost=50.0,
strategy=StrategyType.TRAFFIC,
market_ref=market_ref,
)
# 应该取市场最低价的 90% 和成本定价的较低者
assert suggestion.suggested_price <= 100 # 不会太高
assert suggestion.suggested_price >= 50 * 1.05 # 不低于成本
@pytest.mark.unit
def test_calculate_strategy_price_ensures_profit(self):
"""测试确保价格不低于成本"""
service = PricingService(None)
market_ref = MarketReference(min=30.0, max=50.0, avg=40.0)
# 即使市场价很低,也不能低于成本
suggestion = service.calculate_strategy_price(
base_cost=100.0, # 成本高于市场价
strategy=StrategyType.TRAFFIC,
market_ref=market_ref,
)
# 价格至少是成本的 1.05 倍
assert suggestion.suggested_price >= 100 * 1.05
@pytest.mark.unit
def test_calculate_all_strategies(self):
"""测试计算所有策略"""
service = PricingService(None)
suggestions = service.calculate_all_strategies(
base_cost=100.0,
target_margin=50.0,
)
assert suggestions.traffic is not None
assert suggestions.profit is not None
assert suggestions.premium is not None
# 价格应该递增:引流款 < 利润款 < 高端款
assert suggestions.traffic.suggested_price < suggestions.profit.suggested_price
assert suggestions.profit.suggested_price < suggestions.premium.suggested_price
@pytest.mark.unit
def test_calculate_all_strategies_selected(self):
"""测试只计算选定的策略"""
service = PricingService(None)
suggestions = service.calculate_all_strategies(
base_cost=100.0,
target_margin=50.0,
strategies=[StrategyType.TRAFFIC, StrategyType.PROFIT],
)
assert suggestions.traffic is not None
assert suggestions.profit is not None
assert suggestions.premium is None
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_project_with_cost(
self,
db_session: AsyncSession,
sample_project_with_costs: Project
):
"""测试获取项目及成本"""
service = PricingService(db_session)
project, cost_summary = await service.get_project_with_cost(
sample_project_with_costs.id
)
assert project.id == sample_project_with_costs.id
assert project.project_name == "光子嫩肤"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_project_with_cost_not_found(
self,
db_session: AsyncSession
):
"""测试项目不存在时的错误处理"""
service = PricingService(db_session)
with pytest.raises(ValueError, match="项目不存在"):
await service.get_project_with_cost(99999)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_create_pricing_plan(
self,
db_session: AsyncSession,
sample_project: Project
):
"""测试创建定价方案"""
# 先添加成本汇总
cost_summary = ProjectCostSummary(
project_id=sample_project.id,
material_cost=Decimal("40.00"),
equipment_cost=Decimal("50.00"),
labor_cost=Decimal("60.00"),
fixed_cost_allocation=Decimal("30.00"),
total_cost=Decimal("180.00"),
calculated_at=datetime.now(),
)
db_session.add(cost_summary)
await db_session.commit()
service = PricingService(db_session)
plan = await service.create_pricing_plan(
project_id=sample_project.id,
plan_name="测试定价方案",
strategy_type=StrategyType.PROFIT,
target_margin=50.0,
)
assert plan.project_id == sample_project.id
assert plan.plan_name == "测试定价方案"
assert plan.strategy_type == "profit"
assert float(plan.target_margin) == 50.0
assert float(plan.base_cost) == 180.0
assert plan.suggested_price > plan.base_cost
@pytest.mark.unit
@pytest.mark.asyncio
async def test_update_pricing_plan(
self,
db_session: AsyncSession,
sample_pricing_plan: PricingPlan
):
"""测试更新定价方案"""
service = PricingService(db_session)
updated = await service.update_pricing_plan(
plan_id=sample_pricing_plan.id,
final_price=599.00,
plan_name="更新后方案名",
)
assert float(updated.final_price) == 599.00
assert updated.plan_name == "更新后方案名"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_update_pricing_plan_not_found(
self,
db_session: AsyncSession
):
"""测试更新不存在的方案"""
service = PricingService(db_session)
with pytest.raises(ValueError, match="定价方案不存在"):
await service.update_pricing_plan(
plan_id=99999,
final_price=599.00,
)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_simulate_strategies(
self,
db_session: AsyncSession,
sample_project: Project
):
"""测试策略模拟"""
# 添加成本汇总
cost_summary = ProjectCostSummary(
project_id=sample_project.id,
total_cost=Decimal("200.00"),
material_cost=Decimal("100.00"),
equipment_cost=Decimal("50.00"),
labor_cost=Decimal("50.00"),
fixed_cost_allocation=Decimal("0.00"),
calculated_at=datetime.now(),
)
db_session.add(cost_summary)
await db_session.commit()
service = PricingService(db_session)
response = await service.simulate_strategies(
project_id=sample_project.id,
strategies=[StrategyType.TRAFFIC, StrategyType.PROFIT, StrategyType.PREMIUM],
target_margin=50.0,
)
assert response.project_id == sample_project.id
assert response.base_cost == 200.0
assert len(response.results) == 3
# 验证结果排序
prices = [r.suggested_price for r in response.results]
assert prices == sorted(prices) # 应该是升序
@pytest.mark.unit
def test_format_cost_data(self):
"""测试成本数据格式化"""
service = PricingService(None)
# 测试空数据
result = service._format_cost_data(None)
assert "暂无成本数据" in result
@pytest.mark.unit
def test_format_market_data(self):
"""测试市场数据格式化"""
service = PricingService(None)
# 测试空数据
result = service._format_market_data(None)
assert "暂无市场行情数据" in result
# 测试有数据
market_ref = MarketReference(min=100.0, max=500.0, avg=300.0)
result = service._format_market_data(market_ref)
assert "100.00" in result
assert "500.00" in result
assert "300.00" in result
@pytest.mark.unit
def test_extract_recommendations(self):
"""测试提取 AI 建议列表"""
service = PricingService(None)
content = """
根据分析,建议如下:
- 建议一:常规定价 580 元
- 建议二:新客首单 388 元
* 建议三VIP 会员 520 元
1. 定期促销活动
2. 会员体系建设
"""
recommendations = service._extract_recommendations(content)
assert len(recommendations) == 5
assert "常规定价" in recommendations[0]
class TestPricingServiceWithAI:
"""需要 AI 服务的定价测试"""
@pytest.mark.unit
@pytest.mark.asyncio
@patch('app.services.pricing_service.AIServiceWrapper')
async def test_generate_pricing_advice_ai_failure(
self,
mock_ai_wrapper,
db_session: AsyncSession,
sample_project: Project
):
"""测试 AI 调用失败时的降级处理"""
# 添加成本汇总
cost_summary = ProjectCostSummary(
project_id=sample_project.id,
total_cost=Decimal("200.00"),
material_cost=Decimal("100.00"),
equipment_cost=Decimal("50.00"),
labor_cost=Decimal("50.00"),
fixed_cost_allocation=Decimal("0.00"),
calculated_at=datetime.now(),
)
db_session.add(cost_summary)
await db_session.commit()
# 模拟 AI 调用失败
mock_instance = MagicMock()
mock_instance.chat = AsyncMock(side_effect=Exception("AI 服务不可用"))
mock_ai_wrapper.return_value = mock_instance
service = PricingService(db_session)
# 即使 AI 失败,基本定价计算应该仍然返回
response = await service.generate_pricing_advice(
project_id=sample_project.id,
target_margin=50.0,
)
# 验证基本定价仍然可用
assert response.project_id == sample_project.id
assert response.cost_base == 200.0
assert response.pricing_suggestions is not None
# AI 建议可能为空
assert response.ai_advice is None or response.ai_usage is None

View File

@@ -0,0 +1,211 @@
"""利润模拟服务单元测试
测试 ProfitService 的核心业务逻辑
"""
import pytest
from decimal import Decimal
from unittest.mock import AsyncMock, patch, MagicMock
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.profit_service import ProfitService
from app.schemas.profit import PeriodType
from app.models import PricingPlan, FixedCost
class TestProfitService:
"""利润模拟服务测试类"""
@pytest.mark.unit
def test_calculate_profit_basic(self):
"""测试基础利润计算"""
service = ProfitService(None)
revenue, cost, profit, margin = service.calculate_profit(
price=100.0,
cost_per_unit=60.0,
volume=100
)
assert revenue == 10000.0
assert cost == 6000.0
assert profit == 4000.0
assert margin == 40.0
@pytest.mark.unit
def test_calculate_profit_zero_revenue(self):
"""测试零收入时的处理"""
service = ProfitService(None)
revenue, cost, profit, margin = service.calculate_profit(
price=100.0,
cost_per_unit=60.0,
volume=0
)
assert revenue == 0
assert cost == 0
assert profit == 0
assert margin == 0
@pytest.mark.unit
def test_calculate_profit_negative(self):
"""测试亏损情况"""
service = ProfitService(None)
revenue, cost, profit, margin = service.calculate_profit(
price=50.0,
cost_per_unit=60.0,
volume=100
)
assert revenue == 5000.0
assert cost == 6000.0
assert profit == -1000.0
assert margin == -20.0
@pytest.mark.unit
def test_calculate_breakeven_basic(self):
"""测试基础盈亏平衡计算"""
service = ProfitService(None)
breakeven = service.calculate_breakeven(
price=100.0,
variable_cost=60.0,
fixed_cost=0
)
assert breakeven == 1
@pytest.mark.unit
def test_calculate_breakeven_with_fixed_cost(self):
"""测试有固定成本的盈亏平衡"""
service = ProfitService(None)
breakeven = service.calculate_breakeven(
price=100.0,
variable_cost=60.0,
fixed_cost=4000.0
)
assert breakeven == 101
@pytest.mark.unit
def test_calculate_breakeven_no_margin(self):
"""测试边际贡献为负时的处理"""
service = ProfitService(None)
breakeven = service.calculate_breakeven(
price=50.0,
variable_cost=60.0,
fixed_cost=1000.0
)
assert breakeven == 999999
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_pricing_plan(
self,
db_session: AsyncSession,
sample_pricing_plan: PricingPlan
):
"""测试获取定价方案"""
service = ProfitService(db_session)
plan = await service.get_pricing_plan(sample_pricing_plan.id)
assert plan.id == sample_pricing_plan.id
assert plan.plan_name == "2026年Q1定价"
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_pricing_plan_not_found(
self,
db_session: AsyncSession
):
"""测试获取不存在的方案"""
service = ProfitService(db_session)
with pytest.raises(ValueError, match="定价方案不存在"):
await service.get_pricing_plan(99999)
@pytest.mark.unit
@pytest.mark.asyncio
async def test_get_monthly_fixed_cost(
self,
db_session: AsyncSession,
sample_fixed_cost: FixedCost
):
"""测试获取月度固定成本"""
service = ProfitService(db_session)
total = await service.get_monthly_fixed_cost()
assert total == Decimal("30000.00")
@pytest.mark.unit
@pytest.mark.asyncio
async def test_simulate_profit(
self,
db_session: AsyncSession,
sample_pricing_plan: PricingPlan
):
"""测试利润模拟"""
service = ProfitService(db_session)
response = await service.simulate_profit(
pricing_plan_id=sample_pricing_plan.id,
price=580.0,
estimated_volume=100,
period_type=PeriodType.MONTHLY,
)
assert response.pricing_plan_id == sample_pricing_plan.id
assert response.input.price == 580.0
assert response.input.estimated_volume == 100
@pytest.mark.unit
@pytest.mark.asyncio
async def test_sensitivity_analysis(
self,
db_session: AsyncSession,
sample_pricing_plan: PricingPlan
):
"""测试敏感性分析"""
service = ProfitService(db_session)
sim_response = await service.simulate_profit(
pricing_plan_id=sample_pricing_plan.id,
price=580.0,
estimated_volume=100,
period_type=PeriodType.MONTHLY,
)
response = await service.sensitivity_analysis(
simulation_id=sim_response.simulation_id,
price_change_rates=[-20, -10, 0, 10, 20]
)
assert response.simulation_id == sim_response.simulation_id
assert len(response.sensitivity_results) == 5
@pytest.mark.unit
@pytest.mark.asyncio
async def test_breakeven_analysis(
self,
db_session: AsyncSession,
sample_pricing_plan: PricingPlan,
sample_fixed_cost: FixedCost
):
"""测试盈亏平衡分析"""
service = ProfitService(db_session)
response = await service.breakeven_analysis(
pricing_plan_id=sample_pricing_plan.id
)
assert response.pricing_plan_id == sample_pricing_plan.id
assert response.price > 0
assert response.breakeven_volume > 0