- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
1283 lines
34 KiB
Markdown
1283 lines
34 KiB
Markdown
# 考培练系统后端统一基础代码

## 1. 基础配置文件

### 1.1 主应用入口 (app/main.py)

```python
"""
|
||
考培练系统后端主应用入口
|
||
"""
|
||
import time
|
||
from contextlib import asynccontextmanager
|
||
from typing import Dict, Any
|
||
|
||
from fastapi import FastAPI, Request
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||
from fastapi.responses import JSONResponse
|
||
|
||
from app.api.v1 import api_router
|
||
from app.config.settings import settings
|
||
from app.core.exceptions import BaseAPIException
|
||
from app.core.logger import logger, setup_logging
|
||
from app.core.middleware import (
|
||
RequestIDMiddleware,
|
||
LoggingMiddleware,
|
||
RateLimitMiddleware
|
||
)
|
||
from app.core.events import startup_handler, shutdown_handler
|
||
|
||
# 配置日志
|
||
setup_logging(log_level=settings.LOG_LEVEL)
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: startup work before serving, shutdown work after.

    Code before ``yield`` runs once when the server starts; code after it runs
    once when the server stops.
    """
    await startup_handler()
    logger.info("应用启动完成", version=settings.VERSION)

    yield  # requests are served while the generator is suspended here

    await shutdown_handler()
    logger.info("应用关闭完成")


# Create the FastAPI application.
# Outside production the OpenAPI schema, Swagger UI and ReDoc are served; in
# production each of those URLs is None, which disables the endpoint.
app = FastAPI(
    title=settings.PROJECT_NAME,
    description="考培练系统后端API",
    version=settings.VERSION,
    openapi_url=None if settings.PRODUCTION else f"{settings.API_V1_STR}/openapi.json",
    docs_url=None if settings.PRODUCTION else "/docs",
    redoc_url=None if settings.PRODUCTION else "/redoc",
    lifespan=lifespan,
)


# Configure CORS: origins come from settings.CORS_ORIGINS; credentials, all
# methods and all headers are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Only accept requests whose Host header matches settings.ALLOWED_HOSTS
# (the default setting is ["*"], i.e. any host).
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=settings.ALLOWED_HOSTS
)

# Project-specific middleware.
# NOTE(review): in Starlette each add_middleware() call wraps the stack built
# so far, so the middleware added LAST runs FIRST on an incoming request. With
# this order RateLimitMiddleware runs before RequestIDMiddleware, and
# CORSMiddleware is the innermost layer, so responses produced by outer layers
# may lack CORS headers -- confirm this order is intended.
app.add_middleware(RequestIDMiddleware)
app.add_middleware(LoggingMiddleware)
app.add_middleware(RateLimitMiddleware)


# Global exception handlers
#
# Both handlers read request_id with getattr(): the attribute only exists on
# request.state once it has been stored (presumably by RequestIDMiddleware).
# An exception raised before that point would otherwise turn into an
# AttributeError inside the handler itself and hide the original error.


@app.exception_handler(BaseAPIException)
async def api_exception_handler(request: Request, exc: BaseAPIException):
    """Convert a project ``BaseAPIException`` into a JSON error response.

    ``exc.code`` is used as the HTTP status and is also echoed in the body,
    together with the message, details and request ID.
    """
    return JSONResponse(
        status_code=exc.code,
        content={
            "code": exc.code,
            "message": exc.message,
            "details": exc.details,
            "request_id": getattr(request.state, "request_id", None)
        }
    )


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log any otherwise-unhandled exception and return a generic 500 body.

    The exception details go to the log only; the client receives a fixed
    message plus the request ID for correlation.
    """
    request_id = getattr(request.state, "request_id", None)
    logger.error(
        "未处理的异常",
        exc_info=exc,
        request_id=request_id,
        path=request.url.path
    )
    return JSONResponse(
        status_code=500,
        content={
            "code": 500,
            "message": "服务器内部错误",
            "request_id": request_id
        }
    )


# Mount the versioned API router under settings.API_V1_STR.
app.include_router(api_router, prefix=settings.API_V1_STR)


@app.get("/", tags=["Root"])
async def root() -> Dict[str, Any]:
    """Return basic service metadata: name, version, status and epoch time."""
    info: Dict[str, Any] = {
        "name": settings.PROJECT_NAME,
        "version": settings.VERSION,
        "status": "running",
    }
    info["timestamp"] = time.time()
    return info


@app.get("/health", tags=["Health"])
async def health() -> Dict[str, str]:
    """Health-check endpoint; always reports healthy while serving."""
    return dict(status="healthy")
```

### 1.2 系统配置 (app/config/settings.py)

```python
"""
|
||
系统配置管理
|
||
使用Pydantic进行配置验证和管理
|
||
"""
|
||
from typing import List, Optional, Union
|
||
from pathlib import Path
|
||
|
||
from pydantic import BaseSettings, AnyHttpUrl, validator
|
||
|
||
|
||
class Settings(BaseSettings):
|
||
"""系统配置类"""
|
||
|
||
# 基础配置
|
||
PROJECT_NAME: str = "考培练系统"
|
||
VERSION: str = "1.0.0"
|
||
API_V1_STR: str = "/api/v1"
|
||
|
||
# 环境配置
|
||
PRODUCTION: bool = False
|
||
DEBUG: bool = True
|
||
LOG_LEVEL: str = "INFO"
|
||
|
||
# 安全配置
|
||
SECRET_KEY: str # 必须设置
|
||
ALGORITHM: str = "HS256"
|
||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
||
|
||
# 跨域配置
|
||
CORS_ORIGINS: List[AnyHttpUrl] = []
|
||
ALLOWED_HOSTS: List[str] = ["*"]
|
||
|
||
@validator("CORS_ORIGINS", pre=True)
|
||
def assemble_cors_origins(cls, v: Union[str, List[str]]) -> Union[List[str], str]:
|
||
if isinstance(v, str) and not v.startswith("["):
|
||
return [i.strip() for i in v.split(",")]
|
||
elif isinstance(v, (list, str)):
|
||
return v
|
||
raise ValueError(v)
|
||
|
||
# 数据库配置
|
||
DATABASE_URL: str
|
||
DATABASE_POOL_SIZE: int = 10
|
||
DATABASE_MAX_OVERFLOW: int = 20
|
||
|
||
# Redis配置
|
||
REDIS_URL: str = "redis://localhost:6379/0"
|
||
REDIS_PASSWORD: Optional[str] = None
|
||
|
||
# AI平台配置
|
||
# Coze配置
|
||
COZE_API_BASE: str = "https://api.coze.cn"
|
||
COZE_WORKSPACE_ID: str
|
||
COZE_API_TOKEN: str
|
||
COZE_TRAINING_BOT_ID: str
|
||
COZE_CHAT_BOT_ID: str
|
||
|
||
# Dify配置
|
||
DIFY_API_BASE: str = "https://api.dify.ai/v1"
|
||
DIFY_API_KEY: str
|
||
DIFY_EXAM_WORKFLOW_ID: str
|
||
DIFY_ASSESSMENT_WORKFLOW_ID: str
|
||
DIFY_RECOMMEND_WORKFLOW_ID: str
|
||
|
||
# 文件上传配置
|
||
UPLOAD_DIR: Path = Path("uploads")
|
||
MAX_UPLOAD_SIZE: int = 10 * 1024 * 1024 # 10MB
|
||
ALLOWED_UPLOAD_EXTENSIONS: List[str] = [
|
||
".pdf", ".doc", ".docx", ".xls", ".xlsx",
|
||
".png", ".jpg", ".jpeg", ".mp4", ".mp3"
|
||
]
|
||
|
||
# 限流配置
|
||
RATE_LIMIT_ENABLED: bool = True
|
||
RATE_LIMIT_PER_MINUTE: int = 60
|
||
|
||
# 日志配置
|
||
LOG_DIR: Optional[Path] = Path("logs")
|
||
LOG_ROTATION: str = "1 day"
|
||
LOG_RETENTION: str = "30 days"
|
||
|
||
class Config:
|
||
env_file = ".env"
|
||
case_sensitive = True
|
||
|
||
# Global settings instance shared across the application.
settings = Settings()

# Make sure the directories the application writes to exist.
# UPLOAD_DIR is always a Path; LOG_DIR is Optional and is skipped when unset.
for _directory in (settings.UPLOAD_DIR, settings.LOG_DIR):
    if _directory:
        _directory.mkdir(parents=True, exist_ok=True)
del _directory
```

### 1.3 数据库配置 (app/config/database.py)

```python
"""
|
||
数据库配置和会话管理
|
||
"""
|
||
from typing import AsyncGenerator
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||
from sqlalchemy.orm import declarative_base
|
||
from sqlalchemy.pool import NullPool
|
||
|
||
from app.config.settings import settings
|
||
|
||
# 创建异步引擎
|
||
engine = create_async_engine(
|
||
settings.DATABASE_URL,
|
||
echo=settings.DEBUG,
|
||
pool_size=settings.DATABASE_POOL_SIZE,
|
||
max_overflow=settings.DATABASE_MAX_OVERFLOW,
|
||
pool_pre_ping=True, # 启用连接健康检查
|
||
poolclass=NullPool if settings.DEBUG else None, # 开发环境使用NullPool
|
||
)
|
||
|
||
# Factory producing AsyncSession objects bound to the engine above.
# expire_on_commit=False keeps loaded attributes readable after commit;
# autoflush=False means queries do not implicitly flush pending changes.
async_session = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)

# Declarative base class for all ORM models.
Base = declarative_base()


# Database session dependency
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session for one unit of work.

    The session is committed after the consumer finishes without raising, and
    rolled back if an exception propagates back into this generator. The
    ``async with`` block closes the session on exit, so no explicit
    ``close()`` call is needed.
    """
    async with async_session() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
```

## 2. 核心模块代码

### 2.1 基础模型类 (app/models/base.py)

```python
"""
|
||
数据库模型基类
|
||
"""
|
||
from datetime import datetime
|
||
from typing import Optional
|
||
|
||
from sqlalchemy import Column, BigInteger, DateTime, Boolean, String
|
||
from sqlalchemy.ext.declarative import declared_attr
|
||
from sqlalchemy.sql import func
|
||
|
||
from app.config.database import Base
|
||
|
||
|
||
class BaseModel(Base):
|
||
"""模型基类,包含通用字段"""
|
||
__abstract__ = True
|
||
|
||
id = Column(BigInteger, primary_key=True, index=True, comment="主键ID")
|
||
created_at = Column(
|
||
DateTime(timezone=True),
|
||
server_default=func.now(),
|
||
nullable=False,
|
||
comment="创建时间"
|
||
)
|
||
updated_at = Column(
|
||
DateTime(timezone=True),
|
||
server_default=func.now(),
|
||
onupdate=func.now(),
|
||
nullable=False,
|
||
comment="更新时间"
|
||
)
|
||
|
||
@declared_attr
|
||
def __tablename__(cls) -> str:
|
||
"""自动生成表名(类名转换为蛇形命名法的复数形式)"""
|
||
import re
|
||
# CamelCase to snake_case
|
||
name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', cls.__name__)
|
||
name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name).lower()
|
||
# 简单的复数形式(实际项目中可能需要更复杂的规则)
|
||
if name.endswith('y'):
|
||
return name[:-1] + 'ies'
|
||
elif name.endswith('s'):
|
||
return name + 'es'
|
||
else:
|
||
return name + 's'
|
||
|
||
|
||
class SoftDeleteMixin:
    """Adds soft-delete columns: a deleted flag, when, and by whom."""

    # NOT NULL flag; the ORM-side default is False.
    is_deleted = Column(Boolean, nullable=False, default=False, comment="是否删除")
    # Nullable, timezone-aware deletion timestamp.
    deleted_at = Column(DateTime(timezone=True), comment="删除时间")
    # Nullable ID of the user who performed the deletion.
    deleted_by = Column(BigInteger, comment="删除人ID")


class AuditMixin:
    """Adds nullable creator / last-updater user-ID columns."""

    # ID of the user who created the row.
    created_by = Column(BigInteger, comment="创建人ID")
    # ID of the user who last updated the row.
    updated_by = Column(BigInteger, comment="更新人ID")
```

### 2.2 基础Schema (app/schemas/base.py)

```python
"""
|
||
Pydantic模型基类
|
||
"""
|
||
from datetime import datetime
|
||
from typing import Optional, Generic, TypeVar, List, Any
|
||
|
||
from pydantic import BaseModel, Field
|
||
|
||
|
||
# 配置所有模型的通用行为
|
||
class BaseSchema(BaseModel):
|
||
"""Schema基类"""
|
||
|
||
class Config:
|
||
# 允许从ORM模型创建
|
||
orm_mode = True
|
||
# 使用枚举的值而不是名称
|
||
use_enum_values = True
|
||
# 验证赋值
|
||
validate_assignment = True
|
||
# 允许population by field name
|
||
allow_population_by_field_name = True
|
||
# JSON编码器
|
||
json_encoders = {
|
||
datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S")
|
||
}
|
||
|
||
|
||
class TimestampMixin(BaseModel):
|
||
"""时间戳混入类"""
|
||
created_at: datetime = Field(..., description="创建时间")
|
||
updated_at: datetime = Field(..., description="更新时间")
|
||
|
||
|
||
class IDMixin(BaseModel):
|
||
"""ID混入类"""
|
||
id: int = Field(..., description="ID")
|
||
|
||
|
||
# Generic response models
#
# ResponseModel / PageModel are meant to be parametrised, e.g.
# ResponseModel[UserOut]. This codebase uses the pydantic v1 API (orm_mode,
# regex=, BaseSettings), and in pydantic v1 only subclasses of
# pydantic.generics.GenericModel can be parametrised into concrete models; a
# plain "BaseModel, Generic[T]" subclass cannot. GenericModel subclasses
# BaseModel, so isinstance checks keep working.
# NOTE(review): unparametrised use (e.g. PageModel(items=...)) relies on
# pydantic treating a bare TypeVar as Any, which pydantic 1.10 does -- confirm
# the pinned version.
from pydantic.generics import GenericModel

T = TypeVar('T')


class ResponseModel(GenericModel, Generic[T]):
    """Uniform API response envelope: code, message, payload and request ID."""

    code: int = Field(200, description="响应码")
    message: str = Field("success", description="响应消息")
    data: Optional[T] = Field(None, description="响应数据")
    request_id: Optional[str] = Field(None, description="请求ID")


class PageModel(GenericModel, Generic[T]):
    """One page of results plus paging metadata."""

    items: List[T] = Field(..., description="数据列表")
    total: int = Field(..., description="总数")
    page: int = Field(..., description="当前页")
    size: int = Field(..., description="每页大小")
    pages: int = Field(..., description="总页数")


class ErrorResponse(BaseModel):
    """Body returned for a failed request."""

    code: int = Field(..., description="错误码")
    message: str = Field(..., description="错误消息")
    details: Optional[str] = Field(None, description="错误详情")
    request_id: Optional[str] = Field(None, description="请求ID")


# Common query parameters
class QueryParams(BaseModel):
    """Paging and ordering parameters shared by list endpoints.

    ``page`` starts at 1, ``size`` is limited to 1..100, and ``order`` must be
    "asc" or "desc".
    """

    page: int = Field(1, ge=1, description="页码")
    size: int = Field(10, ge=1, le=100, description="每页大小")
    order_by: Optional[str] = Field(None, description="排序字段")
    order: Optional[str] = Field("desc", regex="^(asc|desc)$", description="排序方向")

    @property
    def skip(self) -> int:
        """Row offset of the current page (0 for page 1)."""
        preceding_pages = self.page - 1
        return preceding_pages * self.size
```

### 2.3 认证依赖 (app/api/deps.py)

```python
"""
|
||
API依赖注入
|
||
"""
|
||
from typing import Optional, Generator
|
||
|
||
from fastapi import Depends, HTTPException, status
|
||
from fastapi.security import OAuth2PasswordBearer
|
||
from jose import JWTError, jwt
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.config.database import get_db
|
||
from app.config.settings import settings
|
||
from app.core.security import verify_token
|
||
from app.models.user import User
|
||
from app.services.user_service import UserService
|
||
|
||
# OAuth2 scheme
|
||
oauth2_scheme = OAuth2PasswordBearer(
|
||
tokenUrl=f"{settings.API_V1_STR}/auth/login"
|
||
)
|
||
|
||
# 可选的OAuth2 scheme
|
||
optional_oauth2_scheme = OAuth2PasswordBearer(
|
||
tokenUrl=f"{settings.API_V1_STR}/auth/login",
|
||
auto_error=False
|
||
)
|
||
|
||
|
||
def _user_id_from_token(token: str) -> Optional[int]:
    """Decode an access token and return its user ID, or None if unusable.

    Returns None when the token is not an access token, has no ``sub`` claim,
    or its ``sub`` is not an integer.

    Raises:
        JWTError: signature or standard-claim validation failed in
            ``verify_token``.
    """
    payload = verify_token(token)
    # Refresh tokens are signed with the same key and carry type="refresh";
    # only accept access tokens so a long-lived refresh token cannot be used
    # to call the API.
    # NOTE(review): tokens minted without a "type" claim are now rejected.
    if payload.get("type") != "access":
        return None
    subject = payload.get("sub")
    if subject is None:
        return None
    try:
        # create_access_token stores the subject as a string.
        return int(subject)
    except (TypeError, ValueError):
        return None


async def get_current_user(
    db: AsyncSession = Depends(get_db),
    token: str = Depends(oauth2_scheme)
) -> User:
    """Resolve the logged-in user from the bearer token.

    Raises:
        HTTPException: 401 if the token is invalid, expired, not an access
            token, or refers to a missing user; 403 if the user is disabled.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无效的认证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        user_id = _user_id_from_token(token)
    except JWTError:
        raise credentials_exception
    if user_id is None:
        raise credentials_exception

    user = await UserService(db).get_by_id(user_id)
    if user is None:
        raise credentials_exception

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户已被禁用"
        )

    return user


async def get_current_active_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """Return the current user, raising 403 if the account is disabled."""
    if not current_user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户已被禁用"
        )
    return current_user


async def get_optional_current_user(
    db: AsyncSession = Depends(get_db),
    token: Optional[str] = Depends(optional_oauth2_scheme)
) -> Optional[User]:
    """Identify the caller on public endpoints, if possible.

    Returns None instead of raising when there is no token, the token is
    invalid or not an access token, or the user is missing or disabled.
    """
    if not token:
        return None

    try:
        user_id = _user_id_from_token(token)
    except JWTError:
        return None
    if user_id is None:
        return None

    user = await UserService(db).get_by_id(user_id)
    return user if user and user.is_active else None


class RoleChecker:
    """Dependency that admits only users whose ``role`` is in an allow-list.

    Use as ``Depends(RoleChecker([...]))`` or via the prebuilt checkers below.
    """

    def __init__(self, allowed_roles: List[str]):
        self.allowed_roles = allowed_roles

    def __call__(self, user: User = Depends(get_current_active_user)) -> User:
        """Return the active user if permitted, otherwise raise HTTP 403."""
        if user.role in self.allowed_roles:
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="权限不足"
        )


# Prebuilt checkers, from most to least restrictive.
require_admin = RoleChecker(["admin"])
require_manager = RoleChecker(["admin", "manager"])
require_trainer = RoleChecker(["admin", "manager", "trainer"])
```

### 2.4 统一异常处理 (app/core/exceptions.py)

```python
"""
|
||
统一异常定义
|
||
"""
|
||
from typing import Optional, Dict, Any
|
||
|
||
|
||
class BaseAPIException(Exception):
|
||
"""API异常基类"""
|
||
|
||
def __init__(
|
||
self,
|
||
code: int,
|
||
message: str,
|
||
details: Optional[str] = None
|
||
):
|
||
self.code = code
|
||
self.message = message
|
||
self.details = details
|
||
super().__init__(message)
|
||
|
||
|
||
class BadRequestError(BaseAPIException):
|
||
"""请求错误 400"""
|
||
|
||
def __init__(self, message: str = "请求参数错误", details: str = None):
|
||
super().__init__(400, message, details)
|
||
|
||
|
||
class UnauthorizedError(BaseAPIException):
|
||
"""未授权错误 401"""
|
||
|
||
def __init__(self, message: str = "未授权访问", details: str = None):
|
||
super().__init__(401, message, details)
|
||
|
||
|
||
class ForbiddenError(BaseAPIException):
|
||
"""禁止访问错误 403"""
|
||
|
||
def __init__(self, message: str = "禁止访问", details: str = None):
|
||
super().__init__(403, message, details)
|
||
|
||
|
||
class NotFoundError(BaseAPIException):
|
||
"""资源不存在错误 404"""
|
||
|
||
def __init__(self, resource: str = "资源", details: str = None):
|
||
super().__init__(404, f"{resource}不存在", details)
|
||
|
||
|
||
class ConflictError(BaseAPIException):
|
||
"""冲突错误 409"""
|
||
|
||
def __init__(self, message: str = "资源冲突", details: str = None):
|
||
super().__init__(409, message, details)
|
||
|
||
|
||
class ValidationError(BaseAPIException):
|
||
"""验证错误 422"""
|
||
|
||
def __init__(self, message: str = "数据验证失败", details: str = None):
|
||
super().__init__(422, message, details)
|
||
|
||
|
||
class InternalServerError(BaseAPIException):
|
||
"""服务器内部错误 500"""
|
||
|
||
def __init__(self, message: str = "服务器内部错误", details: str = None):
|
||
super().__init__(500, message, details)
|
||
|
||
|
||
class ExternalServiceError(BaseAPIException):
|
||
"""外部服务错误 502"""
|
||
|
||
def __init__(self, service: str, details: str = None):
|
||
super().__init__(502, f"{service}服务异常", details)
|
||
|
||
|
||
class ServiceUnavailableError(BaseAPIException):
|
||
"""服务不可用错误 503"""
|
||
|
||
def __init__(self, message: str = "服务暂时不可用", details: str = None):
|
||
super().__init__(503, message, details)
|
||
```

### 2.5 日志配置 (app/core/logger.py)

```python
"""
|
||
统一日志配置
|
||
"""
|
||
import logging
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
import structlog
|
||
from structlog.stdlib import LoggerFactory
|
||
|
||
from app.config.settings import settings
|
||
|
||
|
||
def setup_logging(
|
||
log_level: str = "INFO",
|
||
log_dir: Optional[Path] = None
|
||
) -> structlog.BoundLogger:
|
||
"""
|
||
配置结构化日志
|
||
|
||
Args:
|
||
log_level: 日志级别
|
||
log_dir: 日志目录
|
||
|
||
Returns:
|
||
配置好的日志器
|
||
"""
|
||
# 设置时间戳格式
|
||
timestamper = structlog.processors.TimeStamper(fmt="iso")
|
||
|
||
# 配置处理器链
|
||
shared_processors = [
|
||
structlog.stdlib.add_logger_name,
|
||
structlog.stdlib.add_log_level,
|
||
structlog.stdlib.PositionalArgumentsFormatter(),
|
||
timestamper,
|
||
structlog.processors.StackInfoRenderer(),
|
||
structlog.processors.format_exc_info,
|
||
structlog.processors.UnicodeDecoder(),
|
||
structlog.contextvars.merge_contextvars,
|
||
structlog.processors.CallsiteParameterAdder(
|
||
parameters=[
|
||
structlog.processors.CallsiteParameter.FILENAME,
|
||
structlog.processors.CallsiteParameter.LINENO,
|
||
structlog.processors.CallsiteParameter.FUNC_NAME,
|
||
]
|
||
),
|
||
]
|
||
|
||
# 配置structlog
|
||
structlog.configure(
|
||
processors=shared_processors + [
|
||
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
|
||
],
|
||
logger_factory=LoggerFactory(),
|
||
cache_logger_on_first_use=True,
|
||
)
|
||
|
||
# 配置标准库日志
|
||
formatter = structlog.stdlib.ProcessorFormatter(
|
||
processors=[
|
||
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
|
||
structlog.dev.ConsoleRenderer() if settings.DEBUG else structlog.processors.JSONRenderer(),
|
||
],
|
||
)
|
||
|
||
# 控制台处理器
|
||
console_handler = logging.StreamHandler(sys.stdout)
|
||
console_handler.setFormatter(formatter)
|
||
console_handler.setLevel(getattr(logging, log_level.upper()))
|
||
|
||
# 配置根日志器
|
||
root_logger = logging.getLogger()
|
||
root_logger.handlers.clear()
|
||
root_logger.addHandler(console_handler)
|
||
root_logger.setLevel(getattr(logging, log_level.upper()))
|
||
|
||
# 文件处理器(如果指定了日志目录)
|
||
if log_dir and settings.LOG_DIR:
|
||
from logging.handlers import RotatingFileHandler
|
||
|
||
log_file = settings.LOG_DIR / "app.log"
|
||
file_handler = RotatingFileHandler(
|
||
log_file,
|
||
maxBytes=10 * 1024 * 1024, # 10MB
|
||
backupCount=10
|
||
)
|
||
file_handler.setFormatter(
|
||
structlog.stdlib.ProcessorFormatter(
|
||
processors=[
|
||
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
|
||
structlog.processors.JSONRenderer(),
|
||
],
|
||
)
|
||
)
|
||
file_handler.setLevel(logging.INFO)
|
||
root_logger.addHandler(file_handler)
|
||
|
||
# 设置第三方库的日志级别
|
||
logging.getLogger("uvicorn").setLevel(logging.INFO)
|
||
logging.getLogger("sqlalchemy").setLevel(logging.WARNING)
|
||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||
|
||
return structlog.get_logger()
|
||
|
||
|
||
# 创建全局日志器
|
||
logger = setup_logging(log_level=settings.LOG_LEVEL, log_dir=settings.LOG_DIR)
|
||
```

## 3. 业务服务基类

### 3.1 基础服务类 (app/services/base_service.py)

```python
"""
|
||
业务服务基类
|
||
"""
|
||
from typing import TypeVar, Generic, Type, Optional, List, Dict, Any
|
||
from datetime import datetime
|
||
|
||
from sqlalchemy import select, func, and_, or_
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm import selectinload
|
||
|
||
from app.models.base import BaseModel
|
||
from app.schemas.base import BaseSchema, QueryParams, PageModel
|
||
from app.core.exceptions import NotFoundError, BadRequestError
|
||
from app.core.logger import logger
|
||
|
||
ModelType = TypeVar("ModelType", bound=BaseModel)
|
||
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseSchema)
|
||
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseSchema)
|
||
|
||
|
||
class BaseService(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||
"""
|
||
基础CRUD服务类
|
||
提供通用的增删改查功能
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
model: Type[ModelType],
|
||
db: AsyncSession
|
||
):
|
||
self.model = model
|
||
self.db = db
|
||
self.logger = logger.bind(service=self.__class__.__name__)
|
||
|
||
async def get(self, id: int) -> Optional[ModelType]:
|
||
"""根据ID获取单个对象"""
|
||
query = select(self.model).where(self.model.id == id)
|
||
result = await self.db.execute(query)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def get_or_404(self, id: int) -> ModelType:
|
||
"""根据ID获取单个对象,不存在则抛出404错误"""
|
||
obj = await self.get(id)
|
||
if not obj:
|
||
raise NotFoundError(f"{self.model.__name__}")
|
||
return obj
|
||
|
||
async def get_multi(
|
||
self,
|
||
*,
|
||
skip: int = 0,
|
||
limit: int = 100,
|
||
order_by: Optional[str] = None,
|
||
order: str = "desc"
|
||
) -> List[ModelType]:
|
||
"""获取多个对象"""
|
||
query = select(self.model)
|
||
|
||
# 排序
|
||
if order_by and hasattr(self.model, order_by):
|
||
order_column = getattr(self.model, order_by)
|
||
if order == "asc":
|
||
query = query.order_by(order_column.asc())
|
||
else:
|
||
query = query.order_by(order_column.desc())
|
||
else:
|
||
query = query.order_by(self.model.id.desc())
|
||
|
||
# 分页
|
||
query = query.offset(skip).limit(limit)
|
||
|
||
result = await self.db.execute(query)
|
||
return result.scalars().all()
|
||
|
||
async def get_page(
|
||
self,
|
||
params: QueryParams,
|
||
filters: Optional[List] = None
|
||
) -> PageModel[ModelType]:
|
||
"""分页查询"""
|
||
# 构建查询
|
||
query = select(self.model)
|
||
count_query = select(func.count()).select_from(self.model)
|
||
|
||
# 添加过滤条件
|
||
if filters:
|
||
for filter_condition in filters:
|
||
query = query.where(filter_condition)
|
||
count_query = count_query.where(filter_condition)
|
||
|
||
# 获取总数
|
||
total_result = await self.db.execute(count_query)
|
||
total = total_result.scalar()
|
||
|
||
# 排序
|
||
if params.order_by and hasattr(self.model, params.order_by):
|
||
order_column = getattr(self.model, params.order_by)
|
||
if params.order == "asc":
|
||
query = query.order_by(order_column.asc())
|
||
else:
|
||
query = query.order_by(order_column.desc())
|
||
else:
|
||
query = query.order_by(self.model.id.desc())
|
||
|
||
# 分页
|
||
query = query.offset(params.skip).limit(params.size)
|
||
|
||
# 执行查询
|
||
result = await self.db.execute(query)
|
||
items = result.scalars().all()
|
||
|
||
# 计算总页数
|
||
pages = (total + params.size - 1) // params.size
|
||
|
||
return PageModel(
|
||
items=items,
|
||
total=total,
|
||
page=params.page,
|
||
size=params.size,
|
||
pages=pages
|
||
)
|
||
|
||
async def create(
|
||
self,
|
||
*,
|
||
obj_in: CreateSchemaType,
|
||
**extra_fields
|
||
) -> ModelType:
|
||
"""创建对象"""
|
||
# 转换为字典
|
||
obj_data = obj_in.dict()
|
||
obj_data.update(extra_fields)
|
||
|
||
# 创建模型实例
|
||
db_obj = self.model(**obj_data)
|
||
|
||
# 保存到数据库
|
||
self.db.add(db_obj)
|
||
await self.db.commit()
|
||
await self.db.refresh(db_obj)
|
||
|
||
self.logger.info(
|
||
"创建对象",
|
||
model=self.model.__name__,
|
||
id=db_obj.id
|
||
)
|
||
|
||
return db_obj
|
||
|
||
async def update(
|
||
self,
|
||
*,
|
||
db_obj: ModelType,
|
||
obj_in: UpdateSchemaType,
|
||
**extra_fields
|
||
) -> ModelType:
|
||
"""更新对象"""
|
||
# 转换为字典,排除未设置的值
|
||
update_data = obj_in.dict(exclude_unset=True)
|
||
update_data.update(extra_fields)
|
||
|
||
# 更新字段
|
||
for field, value in update_data.items():
|
||
if hasattr(db_obj, field):
|
||
setattr(db_obj, field, value)
|
||
|
||
# 更新时间
|
||
if hasattr(db_obj, "updated_at"):
|
||
db_obj.updated_at = datetime.utcnow()
|
||
|
||
# 保存到数据库
|
||
self.db.add(db_obj)
|
||
await self.db.commit()
|
||
await self.db.refresh(db_obj)
|
||
|
||
self.logger.info(
|
||
"更新对象",
|
||
model=self.model.__name__,
|
||
id=db_obj.id
|
||
)
|
||
|
||
return db_obj
|
||
|
||
async def delete(self, *, id: int) -> bool:
|
||
"""删除对象(物理删除)"""
|
||
db_obj = await self.get_or_404(id)
|
||
|
||
await self.db.delete(db_obj)
|
||
await self.db.commit()
|
||
|
||
self.logger.info(
|
||
"删除对象",
|
||
model=self.model.__name__,
|
||
id=id
|
||
)
|
||
|
||
return True
|
||
|
||
async def soft_delete(
|
||
self,
|
||
*,
|
||
id: int,
|
||
deleted_by: Optional[int] = None
|
||
) -> ModelType:
|
||
"""软删除对象"""
|
||
db_obj = await self.get_or_404(id)
|
||
|
||
if hasattr(db_obj, "is_deleted"):
|
||
db_obj.is_deleted = True
|
||
if hasattr(db_obj, "deleted_at"):
|
||
db_obj.deleted_at = datetime.utcnow()
|
||
if hasattr(db_obj, "deleted_by") and deleted_by:
|
||
db_obj.deleted_by = deleted_by
|
||
|
||
await self.db.commit()
|
||
await self.db.refresh(db_obj)
|
||
|
||
self.logger.info(
|
||
"软删除对象",
|
||
model=self.model.__name__,
|
||
id=id,
|
||
deleted_by=deleted_by
|
||
)
|
||
|
||
return db_obj
|
||
else:
|
||
raise BadRequestError("该对象不支持软删除")
|
||
|
||
async def count(self, filters: Optional[List] = None) -> int:
|
||
"""统计数量"""
|
||
query = select(func.count()).select_from(self.model)
|
||
|
||
if filters:
|
||
for filter_condition in filters:
|
||
query = query.where(filter_condition)
|
||
|
||
result = await self.db.execute(query)
|
||
return result.scalar()
|
||
|
||
async def exists(self, **kwargs) -> bool:
|
||
"""检查是否存在"""
|
||
query = select(self.model)
|
||
|
||
for key, value in kwargs.items():
|
||
if hasattr(self.model, key):
|
||
query = query.where(getattr(self.model, key) == value)
|
||
|
||
query = query.limit(1)
|
||
result = await self.db.execute(query)
|
||
return result.scalar_one_or_none() is not None
|
||
```

## 4. 常用工具函数

### 4.1 安全工具 (app/core/security.py)

```python
"""
|
||
安全相关工具函数
|
||
"""
|
||
from datetime import datetime, timedelta
|
||
from typing import Optional, Dict, Any
|
||
|
||
from jose import JWTError, jwt
|
||
from passlib.context import CryptContext
|
||
|
||
from app.config.settings import settings
|
||
|
||
# 密码加密上下文
|
||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||
|
||
|
||
def create_access_token(
|
||
subject: int,
|
||
expires_delta: Optional[timedelta] = None,
|
||
**extra_claims
|
||
) -> str:
|
||
"""
|
||
创建访问令牌
|
||
|
||
Args:
|
||
subject: 用户ID
|
||
expires_delta: 过期时间
|
||
extra_claims: 额外的声明
|
||
|
||
Returns:
|
||
JWT令牌
|
||
"""
|
||
if expires_delta:
|
||
expire = datetime.utcnow() + expires_delta
|
||
else:
|
||
expire = datetime.utcnow() + timedelta(
|
||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||
)
|
||
|
||
to_encode = {
|
||
"exp": expire,
|
||
"sub": str(subject),
|
||
"type": "access",
|
||
**extra_claims
|
||
}
|
||
|
||
encoded_jwt = jwt.encode(
|
||
to_encode,
|
||
settings.SECRET_KEY,
|
||
algorithm=settings.ALGORITHM
|
||
)
|
||
return encoded_jwt
|
||
|
||
|
||
def create_refresh_token(
|
||
subject: int,
|
||
expires_delta: Optional[timedelta] = None
|
||
) -> str:
|
||
"""创建刷新令牌"""
|
||
if expires_delta:
|
||
expire = datetime.utcnow() + expires_delta
|
||
else:
|
||
expire = datetime.utcnow() + timedelta(
|
||
days=settings.REFRESH_TOKEN_EXPIRE_DAYS
|
||
)
|
||
|
||
to_encode = {
|
||
"exp": expire,
|
||
"sub": str(subject),
|
||
"type": "refresh"
|
||
}
|
||
|
||
encoded_jwt = jwt.encode(
|
||
to_encode,
|
||
settings.SECRET_KEY,
|
||
algorithm=settings.ALGORITHM
|
||
)
|
||
return encoded_jwt
|
||
|
||
|
||
def verify_token(token: str, expected_type: Optional[str] = None) -> Dict[str, Any]:
    """
    Decode and validate a JWT.

    Args:
        token: The encoded JWT.
        expected_type: If given (e.g. "access" or "refresh"), the token's
            ``type`` claim must equal it. Refresh and access tokens share the
            signing key, so callers that accept only one kind should pass it.

    Returns:
        The token payload.

    Raises:
        JWTError: the token is invalid or expired, or its type does not match.
    """
    payload = jwt.decode(
        token,
        settings.SECRET_KEY,
        algorithms=[settings.ALGORITHM]
    )
    if expected_type is not None and payload.get("type") != expected_type:
        raise JWTError("Unexpected token type")
    return payload


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plain-text password against a stored hash."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    """Hash a password for storage."""
    return pwd_context.hash(password)


def generate_random_password(length: int = 12) -> str:
    """Generate a cryptographically random password of ``length`` characters.

    Characters come from ASCII letters, digits and punctuation via
    ``secrets``. When ``length >= 3`` the result always contains at least one
    uppercase letter, one lowercase letter and one digit -- the classes
    ``validate_password`` requires -- by redrawing until it does. Rejection
    sampling keeps the result uniform over all qualifying passwords.
    Shorter lengths cannot meet that rule and are returned unconstrained.
    Meeting ``validate_password``'s minimum length of 8 is the caller's job.
    """
    import secrets
    import string

    alphabet = string.ascii_letters + string.digits + string.punctuation
    required_classes = (string.ascii_uppercase, string.ascii_lowercase, string.digits)

    while True:
        password = ''.join(secrets.choice(alphabet) for _ in range(length))
        if length < len(required_classes):
            return password
        if all(any(ch in chars for ch in password) for chars in required_classes):
            return password
```

### 4.2 验证工具 (app/utils/validators.py)

```python
"""
|
||
通用验证工具
|
||
"""
|
||
import re
|
||
from typing import Optional
|
||
|
||
from app.core.exceptions import ValidationError
|
||
|
||
|
||
# The validators below use re.fullmatch: with re.match, a pattern ending in
# "$" also matches a string that ends in a newline ("a@b.com\n" passed).


def validate_email(email: str) -> str:
    """Check basic e-mail syntax and return the address lower-cased.

    Raises:
        ValidationError: the address does not match the pattern.
    """
    pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
    if not re.fullmatch(pattern, email):
        raise ValidationError("邮箱格式不正确")
    return email.lower()


def validate_phone(phone: str) -> str:
    """Check a mainland-China mobile number (11 digits, 1 then 3-9).

    Raises:
        ValidationError: the number does not match.
    """
    pattern = r'^1[3-9]\d{9}$'
    if not re.fullmatch(pattern, phone):
        raise ValidationError("手机号格式不正确")
    return phone


def validate_username(username: str) -> str:
    """Check a username: 3-20 characters of letters, digits, ``_`` or ``-``.

    Raises:
        ValidationError: wrong length or disallowed characters.
    """
    if len(username) < 3 or len(username) > 20:
        raise ValidationError("用户名长度必须在3-20个字符之间")

    pattern = r'^[a-zA-Z0-9_-]+$'
    if not re.fullmatch(pattern, username):
        raise ValidationError("用户名只能包含字母、数字、下划线和连字符")

    return username


def validate_password(password: str) -> str:
    """Enforce password strength and return the password unchanged.

    Requires at least 8 characters including an uppercase letter, a lowercase
    letter and a digit; the first failing rule (in that order) is reported.

    Raises:
        ValidationError: a rule is not satisfied.
    """
    if len(password) < 8:
        raise ValidationError("密码长度至少8个字符")

    rules = (
        (r'[A-Z]', "密码必须包含至少一个大写字母"),
        (r'[a-z]', "密码必须包含至少一个小写字母"),
        (r'\d', "密码必须包含至少一个数字"),
    )
    for pattern, message in rules:
        if re.search(pattern, password) is None:
            raise ValidationError(message)

    return password


def validate_id_card(id_card: str) -> str:
    """Validate an 18-character mainland-China resident ID number.

    Checks the length, that the first 17 characters are ASCII digits, that
    the last is a digit or X, and the weighted mod-11 check character.
    Returns the ID with a trailing ``x`` upper-cased.

    Raises:
        ValidationError: any of the checks fails.
    """
    if len(id_card) != 18:
        raise ValidationError("身份证号长度必须为18位")

    body, last = id_card[:17], id_card[-1]

    # str.isdigit() also accepts non-ASCII digits: full-width "１" passed and
    # was stored as-is, and superscript "²" crashed int() with a ValueError.
    # Only ASCII 0-9 are valid.
    if not all('0' <= ch <= '9' for ch in body):
        raise ValidationError("身份证号前17位必须是数字")

    if not ('0' <= last <= '9' or last in 'xX'):
        raise ValidationError("身份证号最后一位必须是数字或X")

    # Weighted sum of the first 17 digits, mod 11, indexes the check character.
    factors = (7, 9, 10, 5, 8, 4, 2, 1, 6, 3, 7, 9, 10, 5, 8, 4, 2)
    check_codes = '10X98765432'
    total = sum(int(digit) * factor for digit, factor in zip(body, factors))

    if last.upper() != check_codes[total % 11]:
        raise ValidationError("身份证号校验码错误")

    return id_card.upper()


# File systems such as ext4 limit a single file name to 255 *bytes*, so a
# 255-character name of CJK text (3 bytes per character in UTF-8) is too long.
_MAX_FILENAME_BYTES = 255


def _truncate_utf8(text: str, max_bytes: int) -> str:
    """Cut ``text`` to at most ``max_bytes`` UTF-8 bytes without splitting a character."""
    encoded = text.encode("utf-8")
    if len(encoded) <= max_bytes:
        return text
    return encoded[:max_bytes].decode("utf-8", errors="ignore")


def sanitize_filename(filename: str) -> str:
    """Make a user-supplied file name safe to use as a single path component.

    Strips surrounding whitespace and replaces path separators, characters
    Windows forbids (``:<>"|?*``) and ASCII control characters with ``_``.
    ``.`` and ``..`` (directory references) become underscores. Names longer
    than 255 UTF-8 bytes are truncated, keeping the extension when it fits.
    An empty input stays empty.
    """
    filename = filename.strip()
    filename = re.sub(r'[/\\:<>"|?*\x00-\x1f]', '_', filename)

    if filename in ('.', '..'):
        filename = '_' * len(filename)

    if len(filename.encode("utf-8")) > _MAX_FILENAME_BYTES:
        stem, dot, ext = filename.rpartition('.')
        ext_bytes = len(ext.encode("utf-8")) + 1  # +1 for the dot
        if dot and ext_bytes < _MAX_FILENAME_BYTES:
            filename = _truncate_utf8(stem, _MAX_FILENAME_BYTES - ext_bytes) + '.' + ext
        else:
            # No extension, or one too long to keep: cut the whole name.
            filename = _truncate_utf8(filename, _MAX_FILENAME_BYTES)

    return filename
```

## 5. 快速开始命令

### 5.1 项目初始化脚本

```bash
#!/bin/bash
# scripts/init_project.sh
#
# One-time bootstrap of the backend: virtualenv, dependencies, .env, Alembic
# scaffold and working directories. Safe to re-run: an existing .env and an
# existing migrations/ directory are left untouched.

# Abort on the first failing command instead of reporting success at the end.
set -eo pipefail

echo "初始化考培练系统后端项目..."

# Create and activate the virtual environment
python3 -m venv venv
source venv/bin/activate

# Install dependencies
pip install -r requirements/dev.txt

# Create .env from the template only if it does not exist yet, so re-running
# the script never overwrites configured secrets.
if [ ! -f .env ]; then
    cp .env.example .env
fi
echo "请编辑 .env 文件配置必要的环境变量"

# Initialise Alembic ("alembic init" fails if the directory already exists)
if [ ! -d migrations ]; then
    alembic init migrations
fi
echo "数据库迁移已初始化"

# Create working directories
mkdir -p logs uploads

# Code quality checks
make lint
make type-check

echo "项目初始化完成!"
echo "下一步:"
echo "1. 编辑 .env 文件"
echo "2. 运行 'alembic revision --autogenerate -m \"initial\"' 创建初始迁移"
echo "3. 运行 'alembic upgrade head' 应用迁移"
echo "4. 运行 'make run-dev' 启动开发服务器"
```

### 5.2 开发环境启动脚本

```bash
#!/bin/bash
# scripts/start_dev.sh
#
# Start MySQL and Redis with docker-compose, apply migrations, then run the
# API with auto-reload on port 8000.

# Abort on the first failing command (e.g. a failed migration).
set -eo pipefail

# Check for .env before doing anything else
if [ ! -f .env ]; then
    echo "错误:.env 文件不存在"
    exit 1
fi

# Activate the virtual environment
source venv/bin/activate

# Start dependency services
docker-compose up -d mysql redis

# Poll MySQL until it answers instead of sleeping a fixed 10 s, which wastes
# time when warm and can be too short on a cold start. "mysqladmin ping" exits
# 0 once the server is up, even if the credentials it used are rejected.
# NOTE(review): assumes the mysql service image ships mysqladmin (the official
# image does); on a first-ever start its temporary init server may answer
# slightly early.
echo "等待数据库启动..."
ready=0
for _ in $(seq 1 60); do
    if docker-compose exec -T mysql mysqladmin ping --silent >/dev/null 2>&1; then
        ready=1
        break
    fi
    sleep 1
done
if [ "$ready" -ne 1 ]; then
    echo "错误:数据库在 60 秒内未就绪" >&2
    exit 1
fi

# Apply database migrations
alembic upgrade head

# Start the development server
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```