Files
111 998211c483 feat: 初始化考培练系统项目
- 从服务器拉取完整代码
- 按框架规范整理项目结构
- 配置 Drone CI 测试环境部署
- 包含后端(FastAPI)、前端(Vue3)、管理端

技术栈: Vue3 + TypeScript + FastAPI + MySQL
2026-01-24 19:33:28 +08:00

34 KiB
Raw Permalink Blame History

考培练系统后端统一基础代码

1. 基础配置文件

1.1 主应用入口 (app/main.py)

"""
考培练系统后端主应用入口
"""
import time
from contextlib import asynccontextmanager
from typing import Dict, Any

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse

from app.api.v1 import api_router
from app.config.settings import settings
from app.core.exceptions import BaseAPIException
from app.core.logger import logger, setup_logging
from app.core.middleware import (
    RequestIDMiddleware,
    LoggingMiddleware,
    RateLimitMiddleware
)
from app.core.events import startup_handler, shutdown_handler

# Configure logging. LOG_DIR is passed explicitly because setup_logging()
# clears every handler on the root logger each time it runs; calling it
# without log_dir would drop the rotating file handler installed when
# app.core.logger was first imported.
setup_logging(log_level=settings.LOG_LEVEL, log_dir=settings.LOG_DIR)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: run startup hooks, serve, then run shutdown hooks.

    The shutdown part sits in ``finally`` so that ``shutdown_handler`` still
    runs if an exception is thrown into the generator while the app is
    serving, instead of being skipped.
    """
    # Startup
    await startup_handler()
    logger.info("应用启动完成", version=settings.VERSION)
    try:
        yield
    finally:
        # Shutdown
        await shutdown_handler()
        logger.info("应用关闭完成")

# Create the FastAPI application. When settings.PRODUCTION is true the
# OpenAPI schema, Swagger UI (/docs) and ReDoc (/redoc) endpoints are disabled.
app = FastAPI(
    title=settings.PROJECT_NAME,
    description="考培练系统后端API",
    version=settings.VERSION,
    openapi_url=f"{settings.API_V1_STR}/openapi.json" if not settings.PRODUCTION else None,
    docs_url="/docs" if not settings.PRODUCTION else None,
    redoc_url="/redoc" if not settings.PRODUCTION else None,
    lifespan=lifespan
)

# CORS: allow the configured origins, with credentials, any method and any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Trusted-host check against settings.ALLOWED_HOSTS (defaults to ["*"]).
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=settings.ALLOWED_HOSTS
)

# Project-specific middleware (request ID, logging, rate limiting — presumably,
# judging by the names; implementations live in app.core.middleware).
# NOTE(review): in Starlette each add_middleware() call wraps everything added
# before it, so an incoming request passes through RateLimitMiddleware ->
# LoggingMiddleware -> RequestIDMiddleware -> TrustedHostMiddleware ->
# CORSMiddleware. With this order the logging middleware runs before a request
# ID exists, and responses produced by the outer middlewares (e.g. rate-limit
# rejections) get no CORS headers. Confirm this order is intended.
app.add_middleware(RequestIDMiddleware)
app.add_middleware(LoggingMiddleware)
app.add_middleware(RateLimitMiddleware)

# Global exception handlers
@app.exception_handler(BaseAPIException)
async def api_exception_handler(request: Request, exc: BaseAPIException):
    """Render a ``BaseAPIException`` as the unified JSON error body.

    The HTTP status code is ``exc.code``. ``request_id`` is read with
    ``getattr`` because the exception can be raised before the request ID has
    been stored on ``request.state`` (presumably by RequestIDMiddleware). A
    plain attribute access would then raise ``AttributeError`` inside this
    handler and replace the intended error response with an unhandled 500.
    """
    return JSONResponse(
        status_code=exc.code,
        content={
            "code": exc.code,
            "message": exc.message,
            "details": exc.details,
            "request_id": getattr(request.state, "request_id", None),
        }
    )

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the unexpected exception and return a generic 500.

    ``request_id`` is read with ``getattr`` because the failure may happen
    before a request ID was stored on ``request.state`` (presumably by
    RequestIDMiddleware). A plain attribute access would raise a second
    exception inside the error handler itself.
    """
    request_id = getattr(request.state, "request_id", None)
    logger.error(
        "未处理的异常",
        exc_info=exc,
        request_id=request_id,
        path=request.url.path
    )
    return JSONResponse(
        status_code=500,
        content={
            "code": 500,
            "message": "服务器内部错误",
            "request_id": request_id
        }
    )

# Mount every v1 API route under settings.API_V1_STR.
app.include_router(router=api_router, prefix=settings.API_V1_STR)


# Root path
@app.get("/", tags=["Root"])
async def root() -> Dict[str, Any]:
    """Return basic service metadata: name, version, status and current Unix time."""
    info: Dict[str, Any] = dict(
        name=settings.PROJECT_NAME,
        version=settings.VERSION,
        status="running",
    )
    info["timestamp"] = time.time()
    return info


# Health check
@app.get("/health", tags=["Health"])
async def health() -> Dict[str, str]:
    """Liveness endpoint; always reports healthy and checks no dependencies."""
    payload = {"status": "healthy"}
    return payload

1.2 系统配置 (app/config/settings.py)

"""
系统配置管理
使用Pydantic进行配置验证和管理
"""
from typing import List, Optional, Union
from pathlib import Path

from pydantic import BaseSettings, AnyHttpUrl, validator


class Settings(BaseSettings):
    """Application settings, read from environment variables and the ``.env`` file.

    Fields without a default (SECRET_KEY, DATABASE_URL and the Coze/Dify
    credentials and IDs) are required: ``Settings()`` fails validation when
    any of them is missing. Variable names are matched case-sensitively.
    """

    # Basic info
    PROJECT_NAME: str = "考培练系统"
    VERSION: str = "1.0.0"
    API_V1_STR: str = "/api/v1"

    # Environment
    PRODUCTION: bool = False
    # NOTE(review): DEBUG defaults to True; app/config/database.py uses it to
    # turn on SQL echo and NullPool, so deployments should set it explicitly.
    DEBUG: bool = True
    LOG_LEVEL: str = "INFO"

    # Security
    SECRET_KEY: str  # required; JWT signing key
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
    REFRESH_TOKEN_EXPIRE_DAYS: int = 7

    # CORS and host checks
    CORS_ORIGINS: List[AnyHttpUrl] = []
    ALLOWED_HOSTS: List[str] = ["*"]

    @validator("CORS_ORIGINS", pre=True)
    def assemble_cors_origins(cls, v: Union[str, List[str]]) -> List[str]:
        """Accept CORS origins as a list, a JSON array string or a comma-separated string.

        - '["http://a.com", "http://b.com"]' is JSON-decoded. Returning it as a
          raw string would fail pydantic's list validation.
        - 'http://a.com, http://b.com' is split on commas. Blank items (from an
          empty value or a trailing comma) are dropped instead of failing URL
          validation.

        Raises:
            ValueError: for any other input type or malformed JSON.
        """
        import json

        if isinstance(v, str):
            text = v.strip()
            if text.startswith("["):
                return json.loads(text)
            return [item.strip() for item in text.split(",") if item.strip()]
        if isinstance(v, list):
            return v
        raise ValueError(v)

    # Database
    DATABASE_URL: str
    DATABASE_POOL_SIZE: int = 10
    DATABASE_MAX_OVERFLOW: int = 20

    # Redis
    REDIS_URL: str = "redis://localhost:6379/0"
    REDIS_PASSWORD: Optional[str] = None

    # AI platforms
    # Coze
    COZE_API_BASE: str = "https://api.coze.cn"
    COZE_WORKSPACE_ID: str
    COZE_API_TOKEN: str
    COZE_TRAINING_BOT_ID: str
    COZE_CHAT_BOT_ID: str

    # Dify
    DIFY_API_BASE: str = "https://api.dify.ai/v1"
    DIFY_API_KEY: str
    DIFY_EXAM_WORKFLOW_ID: str
    DIFY_ASSESSMENT_WORKFLOW_ID: str
    DIFY_RECOMMEND_WORKFLOW_ID: str

    # File uploads
    UPLOAD_DIR: Path = Path("uploads")
    MAX_UPLOAD_SIZE: int = 10 * 1024 * 1024  # 10MB
    ALLOWED_UPLOAD_EXTENSIONS: List[str] = [
        ".pdf", ".doc", ".docx", ".xls", ".xlsx",
        ".png", ".jpg", ".jpeg", ".mp4", ".mp3"
    ]

    # Rate limiting
    RATE_LIMIT_ENABLED: bool = True
    RATE_LIMIT_PER_MINUTE: int = 60

    # Logging
    LOG_DIR: Optional[Path] = Path("logs")
    LOG_ROTATION: str = "1 day"
    LOG_RETENTION: str = "30 days"

    class Config:
        env_file = ".env"
        case_sensitive = True

        @classmethod
        def parse_env_var(cls, field_name: str, raw_val: str):
            """Pass CORS_ORIGINS through raw so a comma-separated value works.

            pydantic v1.10+ JSON-decodes env values for list fields before any
            validator runs, which rejects "http://a.com,http://b.com". For that
            one field the raw string goes to ``assemble_cors_origins``, which
            handles both the JSON and the comma-separated forms. Every other
            field keeps the default JSON decoding. Older pydantic versions do
            not call this hook, and behave as before.
            """
            if field_name == "CORS_ORIGINS":
                return raw_val
            return cls.json_loads(raw_val)

# Global settings instance, built once at import time.
settings = Settings()

# Create the upload directory and, if configured, the log directory up front,
# so later file writes don't fail on a missing parent.
for _required_dir in (settings.UPLOAD_DIR, settings.LOG_DIR):
    if _required_dir:
        _required_dir.mkdir(parents=True, exist_ok=True)

1.3 数据库配置 (app/config/database.py)

"""
数据库配置和会话管理
"""
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy.pool import NullPool

from app.config.settings import settings

# Build the async engine.
# In DEBUG mode NullPool is used: each checkout opens a fresh connection.
# NullPool does not accept pool sizing arguments, and SQLAlchemy raises
# TypeError when pool_size/max_overflow are combined with it. They are
# therefore only passed when the dialect's default (queue) pool is used.
_engine_options = {
    "echo": settings.DEBUG,
    "pool_pre_ping": True,  # test connections for liveness on checkout
}
if settings.DEBUG:
    _engine_options["poolclass"] = NullPool
else:
    _engine_options["pool_size"] = settings.DATABASE_POOL_SIZE
    _engine_options["max_overflow"] = settings.DATABASE_MAX_OVERFLOW

engine = create_async_engine(settings.DATABASE_URL, **_engine_options)

# Factory for AsyncSession objects bound to the engine above.
# expire_on_commit=False keeps loaded attribute values readable after commit.
# With autoflush=False, pending changes are flushed only explicitly or at commit.
async_session = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)

# Declarative base class that all ORM models inherit from.
Base = declarative_base()

# FastAPI dependency that provides one database session per request.
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` and finish its unit of work afterwards.

    After the consumer is done the session is committed. If anything raises,
    including the commit itself, the session is rolled back and the exception
    re-raised. The session is always closed at the end.
    """
    async with async_session() as db_session:
        try:
            yield db_session
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
        finally:
            await db_session.close()

2. 核心模块代码

2.1 基础模型类 (app/models/base.py)

"""
数据库模型基类
"""
from datetime import datetime
from typing import Optional

from sqlalchemy import Column, BigInteger, DateTime, Boolean, String
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.sql import func

from app.config.database import Base


class BaseModel(Base):
    """Abstract declarative base that adds ``id``, ``created_at`` and ``updated_at``.

    Concrete subclasses get a ``__tablename__`` derived from their class name
    unless they define one themselves.
    """
    __abstract__ = True

    # Surrogate primary key.
    # NOTE(review): index=True on a primary key creates a second, redundant
    # index on most databases; left as-is to avoid a schema change.
    id = Column(BigInteger, primary_key=True, index=True, comment="主键ID")
    # Filled in by the database on INSERT.
    created_at = Column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
        comment="创建时间"
    )
    # Filled in by the database on INSERT; set to func.now() by UPDATEs that
    # don't assign it explicitly.
    updated_at = Column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
        onupdate=func.now(),
        comment="更新时间"
    )

    @declared_attr
    def __tablename__(cls) -> str:
        """Derive the table name: CamelCase class name -> snake_case -> plural.

        Examples: ``UserProfile`` -> ``user_profiles``,
        ``Category`` -> ``categories``, ``Class`` -> ``classes``.

        NOTE(review): the pluralisation is naive. Vowel+"y" names become
        "...ies" (``Survey`` -> ``surveies``), and names ending in x/ch/sh only
        get "s" (``Box`` -> ``boxs``). Changing the rule would rename existing
        tables, so override ``__tablename__`` on affected models instead.
        """
        import re

        # Two passes: split before capitalised words, then between a
        # lowercase letter/digit and an uppercase one (handles "HTTPRequest").
        snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', cls.__name__)
        snake = re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower()

        if snake.endswith('y'):
            stem, suffix = snake[:-1], 'ies'
        elif snake.endswith('s'):
            stem, suffix = snake, 'es'
        else:
            stem, suffix = snake, 's'
        return stem + suffix


class SoftDeleteMixin:
    """Mixin adding soft-delete columns (flag, timestamp and deleting user's ID)."""
    # True once the row has been soft-deleted; never NULL.
    is_deleted = Column(Boolean, nullable=False, default=False, comment="是否删除")
    # When the row was soft-deleted (NULL while live).
    deleted_at = Column(DateTime(timezone=True), comment="删除时间")
    # ID of the user who deleted the row, if recorded.
    deleted_by = Column(BigInteger, comment="删除人ID")


class AuditMixin:
    """Mixin adding nullable creator/updater user-ID columns."""
    # ID of the user who created the row.
    created_by = Column(BigInteger, comment="创建人ID")
    # ID of the user who last updated the row.
    updated_by = Column(BigInteger, comment="更新人ID")

2.2 基础Schema (app/schemas/base.py)

"""
Pydantic模型基类
"""
from datetime import datetime
from typing import Optional, Generic, TypeVar, List, Any

from pydantic import BaseModel, Field


# Shared configuration for all API schemas.
class BaseSchema(BaseModel):
    """Base class for API schemas carrying the project's common pydantic (v1) config."""

    class Config:
        # Allow construction from ORM objects (``from_orm``).
        orm_mode = True
        # Allow populating fields by attribute name as well as by alias.
        allow_population_by_field_name = True
        # Re-run validation when attributes are assigned after creation.
        validate_assignment = True
        # Store enum fields as their values rather than enum members.
        use_enum_values = True
        # JSON-encode datetimes as "YYYY-MM-DD HH:MM:SS"; the timezone offset
        # and microseconds are not included.
        json_encoders = {
            datetime: lambda v: v.strftime("%Y-%m-%d %H:%M:%S"),
        }


class TimestampMixin(BaseModel):
    """Schema mixin with required ``created_at`` / ``updated_at`` fields."""
    created_at: datetime = Field(default=..., description="创建时间")
    updated_at: datetime = Field(default=..., description="更新时间")


class IDMixin(BaseModel):
    """Schema mixin with a required integer ``id`` field."""
    id: int = Field(default=..., description="ID")


# Generic response models.
# This project targets pydantic v1 (orm_mode, validator, BaseSettings). Under v1,
# only GenericModel subclasses can be parametrised into concrete pydantic
# models such as ``ResponseModel[UserOut]`` for use as a response_model.
# A plain ``BaseModel, Generic[T]`` subclass cannot. GenericModel is itself a
# BaseModel subclass, so isinstance checks against BaseModel are unaffected.
from pydantic.generics import GenericModel  # noqa: E402

T = TypeVar('T')


class ResponseModel(GenericModel, Generic[T]):
    """Unified response envelope: status code, message, payload and request ID."""
    code: int = Field(200, description="响应码")
    message: str = Field("success", description="响应消息")
    data: Optional[T] = Field(None, description="响应数据")
    request_id: Optional[str] = Field(None, description="请求ID")


class PageModel(GenericModel, Generic[T]):
    """One page of results plus totals (``pages`` is the total page count)."""
    items: List[T] = Field(..., description="数据列表")
    total: int = Field(..., description="总数")
    page: int = Field(..., description="当前页")
    size: int = Field(..., description="每页大小")
    pages: int = Field(..., description="总页数")


class ErrorResponse(BaseModel):
    """Error body shape: code and message required, details and request ID optional."""
    code: int = Field(default=..., description="错误码")
    message: str = Field(default=..., description="错误消息")
    details: Optional[str] = Field(default=None, description="错误详情")
    request_id: Optional[str] = Field(default=None, description="请求ID")


# Common list-query parameters.
class QueryParams(BaseModel):
    """Pagination and sorting parameters shared by list endpoints.

    ``page`` starts at 1, ``size`` is limited to 1..100, and ``order`` must be
    "asc" or "desc".
    """
    page: int = Field(default=1, ge=1, description="页码")
    size: int = Field(default=10, ge=1, le=100, description="每页大小")
    order_by: Optional[str] = Field(default=None, description="排序字段")
    order: Optional[str] = Field(default="desc", regex="^(asc|desc)$", description="排序方向")

    @property
    def skip(self) -> int:
        """Number of rows before the current page (the SQL OFFSET)."""
        pages_before = self.page - 1
        return pages_before * self.size

2.3 认证依赖 (app/api/deps.py)

"""
API依赖注入
"""
from typing import Generator, List, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession

from app.config.database import get_db
from app.config.settings import settings
from app.core.security import verify_token
from app.models.user import User
from app.services.user_service import UserService

# Login endpoint advertised in OpenAPI as the token URL for both schemes.
_LOGIN_URL = f"{settings.API_V1_STR}/auth/login"

# Bearer-token extractor using FastAPI's default auto_error=True.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=_LOGIN_URL)

# Same extractor with auto_error=False: a missing token yields None instead of an error.
optional_oauth2_scheme = OAuth2PasswordBearer(tokenUrl=_LOGIN_URL, auto_error=False)


async def get_current_user(
    db: AsyncSession = Depends(get_db),
    token: str = Depends(oauth2_scheme)
) -> User:
    """Resolve the authenticated user from a bearer access token.

    Raises:
        HTTPException(401): the token is invalid or expired, is not an access
            token, has a missing or non-integer ``sub`` claim, or no user matches.
        HTTPException(403): the user exists but ``is_active`` is false.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="无效的认证凭据",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = verify_token(token)
    except JWTError:
        raise credentials_exception from None

    # Refresh tokens are signed with the same key and carry type="refresh".
    # Without this check they would authenticate API calls for their whole,
    # longer lifetime. Tokens without a "type" claim are still accepted.
    token_type = payload.get("type")
    if token_type is not None and token_type != "access":
        raise credentials_exception

    # create_access_token stores the subject as a string; convert it back to
    # the integer user id, treating a missing or malformed value as invalid.
    try:
        user_id = int(payload["sub"])
    except (KeyError, TypeError, ValueError):
        raise credentials_exception from None

    user_service = UserService(db)
    user = await user_service.get_by_id(user_id)
    if user is None:
        raise credentials_exception

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="用户已被禁用"
        )

    return user


async def get_current_active_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """Return the current user, rejecting inactive accounts with 403.

    NOTE(review): get_current_user already raises 403 for inactive users, so
    this check is redundant but harmless.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(
        detail="用户已被禁用",
        status_code=status.HTTP_403_FORBIDDEN,
    )


async def get_optional_current_user(
    db: AsyncSession = Depends(get_db),
    token: Optional[str] = Depends(optional_oauth2_scheme)
) -> Optional[User]:
    """Identify the caller on public endpoints without requiring authentication.

    Returns the active user for a valid access token. Returns ``None`` when:
    - the token is missing, invalid or expired;
    - the token is not an access token (e.g. a refresh token);
    - the ``sub`` claim is missing or not an integer;
    - the user is unknown or inactive.
    """
    if not token:
        return None

    try:
        payload = verify_token(token)
    except JWTError:
        return None

    # Only access tokens identify a caller; refresh tokens share the signing key.
    token_type = payload.get("type")
    if token_type is not None and token_type != "access":
        return None

    # The subject is stored as a string; convert it back to the integer user id.
    try:
        user_id = int(payload["sub"])
    except (KeyError, TypeError, ValueError):
        return None

    user_service = UserService(db)
    user = await user_service.get_by_id(user_id)
    return user if user and user.is_active else None


class RoleChecker:
    """Dependency that only lets users whose ``role`` is in an allow-list through.

    Use an instance inside ``Depends(...)``. It resolves the active user and
    returns it unchanged when permitted, otherwise it raises 403.
    """

    def __init__(self, allowed_roles: List[str]):
        # Role names accepted by this checker, compared with ``in``.
        self.allowed_roles = allowed_roles

    def __call__(self, user: User = Depends(get_current_active_user)) -> User:
        """Return ``user`` if its role is allowed; raise HTTP 403 otherwise."""
        permitted = user.role in self.allowed_roles
        if not permitted:
            raise HTTPException(
                detail="权限不足",
                status_code=status.HTTP_403_FORBIDDEN,
            )
        return user


# Ready-made checkers, each also admitting the higher-ranked roles listed before it.
require_admin = RoleChecker(["admin"])
require_manager = RoleChecker(["admin", "manager"])
require_trainer = RoleChecker(["admin", "manager", "trainer"])

2.4 统一异常处理 (app/core/exceptions.py)

"""
统一异常定义
"""
from typing import Optional, Dict, Any


class BaseAPIException(Exception):
    """Base class for errors that map directly onto an HTTP error response.

    Attributes:
        code: HTTP status code. The app's exception handler also returns it as
            ``code`` in the JSON body.
        message: Human-readable message; also the exception's ``str()``.
        details: Optional extra detail text.
    """

    def __init__(
        self,
        code: int,
        message: str,
        details: Optional[str] = None
    ):
        self.code = code
        self.message = message
        self.details = details
        super().__init__(message)


class BadRequestError(BaseAPIException):
    """Bad request (400)."""

    def __init__(self, message: str = "请求参数错误", details: Optional[str] = None):
        super().__init__(400, message, details)


class UnauthorizedError(BaseAPIException):
    """Unauthorized (401)."""

    def __init__(self, message: str = "未授权访问", details: Optional[str] = None):
        super().__init__(401, message, details)


class ForbiddenError(BaseAPIException):
    """Forbidden (403)."""

    def __init__(self, message: str = "禁止访问", details: Optional[str] = None):
        super().__init__(403, message, details)


class NotFoundError(BaseAPIException):
    """Not found (404); the message is built as "<resource>不存在"."""

    def __init__(self, resource: str = "资源", details: Optional[str] = None):
        super().__init__(404, f"{resource}不存在", details)


class ConflictError(BaseAPIException):
    """Conflict (409)."""

    def __init__(self, message: str = "资源冲突", details: Optional[str] = None):
        super().__init__(409, message, details)


class ValidationError(BaseAPIException):
    """Validation failed (422).

    NOTE: shares its name with ``pydantic.ValidationError``; import it
    explicitly from this module to avoid mixing the two up.
    """

    def __init__(self, message: str = "数据验证失败", details: Optional[str] = None):
        super().__init__(422, message, details)


class InternalServerError(BaseAPIException):
    """Internal server error (500)."""

    def __init__(self, message: str = "服务器内部错误", details: Optional[str] = None):
        super().__init__(500, message, details)


class ExternalServiceError(BaseAPIException):
    """Upstream/external service failure (502); the message is "<service>服务异常"."""

    def __init__(self, service: str, details: Optional[str] = None):
        super().__init__(502, f"{service}服务异常", details)


class ServiceUnavailableError(BaseAPIException):
    """Service temporarily unavailable (503)."""

    def __init__(self, message: str = "服务暂时不可用", details: Optional[str] = None):
        super().__init__(503, message, details)

2.5 日志配置 (app/core/logger.py)

"""
统一日志配置
"""
import logging
import sys
from pathlib import Path
from typing import Optional

import structlog
from structlog.stdlib import LoggerFactory

from app.config.settings import settings


def setup_logging(
    log_level: str = "INFO",
    log_dir: Optional[Path] = None
) -> structlog.BoundLogger:
    """
    Configure structlog on top of the standard-library ``logging`` module.

    All handlers on the root logger are replaced with:
    - a stdout handler, using the console renderer when settings.DEBUG is true
      and JSON otherwise;
    - when ``log_dir`` is given, a size-rotated JSON file handler writing
      ``log_dir / "app.log"``.

    Each call reconfigures from scratch, so every caller must pass the
    ``log_dir`` it wants to keep.

    Args:
        log_level: Standard logging level name such as "INFO" (case-insensitive).
        log_dir: Directory for ``app.log`` (created if missing); ``None``
            disables file logging.

    Returns:
        A structlog logger.

    Raises:
        ValueError: if ``log_level`` is not a known logging level name.
    """
    level = logging.getLevelName(log_level.upper())
    if not isinstance(level, int):
        # getLevelName returns a "Level X" string for unknown names.
        raise ValueError(f"Unknown log level: {log_level!r}")

    # ISO-8601 timestamps
    timestamper = structlog.processors.TimeStamper(fmt="iso")

    # Processor chain shared by every structlog event
    shared_processors = [
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        timestamper,
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.contextvars.merge_contextvars,
        structlog.processors.CallsiteParameterAdder(
            parameters=[
                structlog.processors.CallsiteParameter.FILENAME,
                structlog.processors.CallsiteParameter.LINENO,
                structlog.processors.CallsiteParameter.FUNC_NAME,
            ]
        ),
    ]

    # Route structlog events through stdlib logging so the handlers below render them
    structlog.configure(
        processors=shared_processors + [
            structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
        ],
        logger_factory=LoggerFactory(),
        cache_logger_on_first_use=True,
    )

    # Console formatter: human-readable in DEBUG, JSON otherwise
    formatter = structlog.stdlib.ProcessorFormatter(
        processors=[
            structlog.stdlib.ProcessorFormatter.remove_processors_meta,
            structlog.dev.ConsoleRenderer() if settings.DEBUG else structlog.processors.JSONRenderer(),
        ],
    )

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    console_handler.setLevel(level)

    # Reset the root logger so repeated calls don't stack handlers
    root_logger = logging.getLogger()
    root_logger.handlers.clear()
    root_logger.addHandler(console_handler)
    root_logger.setLevel(level)

    # Optional file output, honouring the log_dir argument
    if log_dir:
        from logging.handlers import RotatingFileHandler

        log_dir = Path(log_dir)
        log_dir.mkdir(parents=True, exist_ok=True)
        file_handler = RotatingFileHandler(
            log_dir / "app.log",
            maxBytes=10 * 1024 * 1024,  # 10MB per file
            backupCount=10,
            encoding="utf-8",
        )
        file_handler.setFormatter(
            structlog.stdlib.ProcessorFormatter(
                processors=[
                    structlog.stdlib.ProcessorFormatter.remove_processors_meta,
                    structlog.processors.JSONRenderer(),
                ],
            )
        )
        file_handler.setLevel(logging.INFO)
        root_logger.addHandler(file_handler)

    # Quieter third-party loggers
    logging.getLogger("uvicorn").setLevel(logging.INFO)
    logging.getLogger("sqlalchemy").setLevel(logging.WARNING)
    logging.getLogger("httpx").setLevel(logging.WARNING)

    return structlog.get_logger()


# Global logger, configured at import time with file output under settings.LOG_DIR
logger = setup_logging(log_level=settings.LOG_LEVEL, log_dir=settings.LOG_DIR)

3. 业务服务基类

3.1 基础服务类 (app/services/base_service.py)

"""
业务服务基类
"""
from typing import TypeVar, Generic, Type, Optional, List, Dict, Any
from datetime import datetime

from sqlalchemy import select, func, and_, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.base import BaseModel
from app.schemas.base import BaseSchema, QueryParams, PageModel
from app.core.exceptions import NotFoundError, BadRequestError
from app.core.logger import logger

ModelType = TypeVar("ModelType", bound=BaseModel)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseSchema)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseSchema)


class BaseService(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """
    Generic async CRUD service for a single SQLAlchemy model.

    The write methods (create / update / delete / soft_delete) commit the
    session themselves. The read methods do not exclude soft-deleted rows;
    pass an explicit filter where that matters.
    """

    def __init__(
        self,
        model: Type[ModelType],
        db: AsyncSession
    ):
        self.model = model
        self.db = db
        # Logger bound with the concrete service class name
        self.logger = logger.bind(service=self.__class__.__name__)

    # ------------------------------------------------------------------ helpers

    def _sortable_columns(self) -> set:
        """Return the names of the model's mapped column attributes."""
        from sqlalchemy import inspect as sa_inspect

        return set(sa_inspect(self.model).columns.keys())

    def _apply_ordering(self, query, order_by: Optional[str], order: Optional[str]):
        """Add ORDER BY ``order_by`` (ascending if ``order == "asc"``, else descending).

        ``order_by`` usually comes straight from client query parameters, so it
        is only honoured when it names a mapped column. A bare ``hasattr``
        check would also accept names such as ``metadata`` or ``__init__``,
        whose values have no ``.asc()``/``.desc()``, and turn the request into
        a 500. Any other value falls back to ``id DESC``.
        """
        if order_by and order_by in self._sortable_columns():
            column = getattr(self.model, order_by)
            return query.order_by(column.asc() if order == "asc" else column.desc())
        return query.order_by(self.model.id.desc())

    # -------------------------------------------------------------------- reads

    async def get(self, id: int) -> Optional[ModelType]:
        """Return the row with primary key ``id``, or None."""
        query = select(self.model).where(self.model.id == id)
        result = await self.db.execute(query)
        return result.scalar_one_or_none()

    async def get_or_404(self, id: int) -> ModelType:
        """Return the row with primary key ``id``; raise NotFoundError if absent."""
        obj = await self.get(id)
        if not obj:
            raise NotFoundError(f"{self.model.__name__}")
        return obj

    async def get_multi(
        self,
        *,
        skip: int = 0,
        limit: int = 100,
        order_by: Optional[str] = None,
        order: str = "desc"
    ) -> List[ModelType]:
        """Return up to ``limit`` rows after skipping ``skip``, sorted as requested."""
        query = self._apply_ordering(select(self.model), order_by, order)
        query = query.offset(skip).limit(limit)

        result = await self.db.execute(query)
        return list(result.scalars().all())

    async def get_page(
        self,
        params: QueryParams,
        filters: Optional[List] = None
    ) -> PageModel[ModelType]:
        """Return one page of rows matching all ``filters`` (ANDed) plus totals."""
        total = await self.count(filters)

        query = select(self.model)
        if filters:
            query = query.where(*filters)
        query = self._apply_ordering(query, params.order_by, params.order)
        query = query.offset(params.skip).limit(params.size)

        result = await self.db.execute(query)
        items = list(result.scalars().all())

        # Ceiling division: a partial last page still counts as a page
        pages = (total + params.size - 1) // params.size

        return PageModel(
            items=items,
            total=total,
            page=params.page,
            size=params.size,
            pages=pages
        )

    # ------------------------------------------------------------------- writes

    async def create(
        self,
        *,
        obj_in: CreateSchemaType,
        **extra_fields
    ) -> ModelType:
        """Insert a row built from ``obj_in`` plus ``extra_fields``, commit and refresh it."""
        obj_data = obj_in.dict()
        obj_data.update(extra_fields)

        db_obj = self.model(**obj_data)

        self.db.add(db_obj)
        await self.db.commit()
        await self.db.refresh(db_obj)

        self.logger.info(
            "创建对象",
            model=self.model.__name__,
            id=db_obj.id
        )

        return db_obj

    async def update(
        self,
        *,
        db_obj: ModelType,
        obj_in: UpdateSchemaType,
        **extra_fields
    ) -> ModelType:
        """Apply the explicitly-set fields of ``obj_in`` plus ``extra_fields``, then commit.

        Keys that are not attributes of ``db_obj`` are ignored.
        """
        update_data = obj_in.dict(exclude_unset=True)
        update_data.update(extra_fields)

        for field, value in update_data.items():
            if hasattr(db_obj, field):
                setattr(db_obj, field, value)

        if hasattr(db_obj, "updated_at"):
            # Let the database stamp the time, using the same clock as the
            # func.now() server defaults. A naive Python datetime.utcnow()
            # would disagree with those whenever the DB session time zone
            # is not UTC.
            db_obj.updated_at = func.now()

        self.db.add(db_obj)
        await self.db.commit()
        await self.db.refresh(db_obj)

        self.logger.info(
            "更新对象",
            model=self.model.__name__,
            id=db_obj.id
        )

        return db_obj

    async def delete(self, *, id: int) -> bool:
        """Hard-delete the row with primary key ``id`` (404 if absent) and commit."""
        db_obj = await self.get_or_404(id)

        await self.db.delete(db_obj)
        await self.db.commit()

        self.logger.info(
            "删除对象",
            model=self.model.__name__,
            id=id
        )

        return True

    async def soft_delete(
        self,
        *,
        id: int,
        deleted_by: Optional[int] = None
    ) -> ModelType:
        """Mark the row as deleted (and stamp deleted_at / deleted_by when present).

        Raises:
            NotFoundError: no row with this id.
            BadRequestError: the model has no ``is_deleted`` attribute.
        """
        db_obj = await self.get_or_404(id)

        if not hasattr(db_obj, "is_deleted"):
            raise BadRequestError("该对象不支持软删除")

        db_obj.is_deleted = True
        if hasattr(db_obj, "deleted_at"):
            # Database clock, consistent with the other timestamps
            db_obj.deleted_at = func.now()
        if hasattr(db_obj, "deleted_by") and deleted_by is not None:
            db_obj.deleted_by = deleted_by

        await self.db.commit()
        await self.db.refresh(db_obj)

        self.logger.info(
            "软删除对象",
            model=self.model.__name__,
            id=id,
            deleted_by=deleted_by
        )

        return db_obj

    # --------------------------------------------------------------- aggregates

    async def count(self, filters: Optional[List] = None) -> int:
        """Count rows matching all ``filters`` (ANDed)."""
        query = select(func.count()).select_from(self.model)
        if filters:
            query = query.where(*filters)

        result = await self.db.execute(query)
        return result.scalar() or 0

    async def exists(self, **kwargs) -> bool:
        """Return True if a row matches every ``attribute == value`` pair given.

        Raises:
            ValueError: if a keyword is not an attribute of the model. Silently
                skipping it would widen the query, and could report True just
                because the table is non-empty.
        """
        unknown = [key for key in kwargs if not hasattr(self.model, key)]
        if unknown:
            raise ValueError(
                f"{self.model.__name__} has no attribute(s): {', '.join(unknown)}"
            )

        query = select(self.model.id)
        for key, value in kwargs.items():
            query = query.where(getattr(self.model, key) == value)

        result = await self.db.execute(query.limit(1))
        return result.first() is not None

4. 常用工具函数

4.1 安全工具 (app/core/security.py)

"""
安全相关工具函数
"""
from datetime import datetime, timedelta
from typing import Optional, Dict, Any

from jose import JWTError, jwt
from passlib.context import CryptContext

from app.config.settings import settings

# Password hashing context (bcrypt).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def _build_token(
    subject: int,
    token_type: str,
    expires_delta: Optional[timedelta],
    default_lifetime: timedelta,
    extra_claims: Optional[Dict[str, Any]] = None,
) -> str:
    """Sign a JWT carrying ``exp``, ``sub`` (as a string) and ``type`` claims.

    ``extra_claims`` are merged first, so they can never override the
    reserved ``exp``/``sub``/``type`` claims. A falsy ``expires_delta`` (None
    or a zero timedelta) falls back to ``default_lifetime``.
    """
    from datetime import timezone

    expire = datetime.now(timezone.utc) + (expires_delta or default_lifetime)
    claims: Dict[str, Any] = {
        **(extra_claims or {}),
        "exp": expire,
        "sub": str(subject),
        "type": token_type,
    }
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def create_access_token(
    subject: int,
    expires_delta: Optional[timedelta] = None,
    **extra_claims
) -> str:
    """
    Create an access token (``type="access"``).

    Args:
        subject: User ID, stored as the string ``sub`` claim.
        expires_delta: Lifetime; defaults to ACCESS_TOKEN_EXPIRE_MINUTES.
        extra_claims: Additional claims. They cannot replace exp/sub/type.

    Returns:
        Encoded JWT string.
    """
    return _build_token(
        subject,
        "access",
        expires_delta,
        timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
        extra_claims,
    )


def create_refresh_token(
    subject: int,
    expires_delta: Optional[timedelta] = None
) -> str:
    """Create a refresh token (``type="refresh"``); defaults to REFRESH_TOKEN_EXPIRE_DAYS."""
    return _build_token(
        subject,
        "refresh",
        expires_delta,
        timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS),
    )


def verify_token(token: str) -> Dict[str, Any]:
    """
    Decode a JWT, checking its signature with settings.SECRET_KEY.

    Only settings.ALGORITHM is accepted. The ``type`` claim is not checked
    here; callers that need a specific token type must check it themselves.

    Args:
        token: Encoded JWT string.

    Returns:
        The decoded claims dict.

    Raises:
        JWTError: if the token is malformed, badly signed or expired.
    """
    decoded: Dict[str, Any] = jwt.decode(
        token,
        key=settings.SECRET_KEY,
        algorithms=[settings.ALGORITHM],
    )
    return decoded


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if ``plain_password`` matches ``hashed_password`` under pwd_context."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches


def get_password_hash(password: str) -> str:
    """Hash ``password`` with pwd_context (bcrypt) for storage."""
    hashed = pwd_context.hash(password)
    return hashed


def generate_random_password(length: int = 12) -> str:
    """Generate a random password of ``length`` characters.

    Characters are drawn with ``secrets`` from ASCII letters, digits and
    punctuation. For ``length >= 3``, candidates are redrawn until they
    contain at least one uppercase letter, one lowercase letter and one digit.
    Those are the character-class rules of ``validate_password`` in
    app/utils/validators.py, which also requires length >= 8. Without this,
    roughly 1 in 20 default-length passwords failed those rules. Shorter
    lengths get unconstrained characters, and length <= 0 returns "".
    """
    import secrets
    import string

    alphabet = string.ascii_letters + string.digits + string.punctuation
    while True:
        password = ''.join(secrets.choice(alphabet) for _ in range(length))
        if length < 3 or (
            any(c.isupper() for c in password)
            and any(c.islower() for c in password)
            and any(c.isdigit() for c in password)
        ):
            return password

4.2 验证工具 (app/utils/validators.py)

"""
通用验证工具
"""
import re
from typing import Optional

from app.core.exceptions import ValidationError


def validate_email(email: str) -> str:
    """Check the basic shape of an e-mail address and return it lower-cased.

    ``re.fullmatch`` requires the whole string to match. ``re.match`` with a
    ``$`` anchor also accepted a trailing newline ("a@b.com\\n") and returned
    it with the newline.

    Raises:
        ValidationError: if the address does not match.
    """
    pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    if not re.fullmatch(pattern, email):
        raise ValidationError("邮箱格式不正确")
    return email.lower()


def validate_phone(phone: str) -> str:
    """Validate an 11-digit mainland-China mobile number and return it unchanged.

    The number must be "1", then 3-9, then nine more ASCII digits, with no
    surrounding characters. ``[0-9]`` is used instead of ``\\d`` because
    ``\\d`` also matches non-ASCII digits such as full-width "８". ``fullmatch``
    rejects a trailing newline, which ``match`` with ``$`` accepted.

    Raises:
        ValidationError: if the number does not match.
    """
    pattern = r'1[3-9][0-9]{9}'
    if not re.fullmatch(pattern, phone):
        raise ValidationError("手机号格式不正确")
    return phone


def validate_username(username: str) -> str:
    """Validate a username and return it unchanged.

    The username must be 3-20 characters long and consist only of ASCII
    letters, digits, "_" or "-". ``fullmatch`` rejects a trailing newline,
    which ``match`` with ``$`` accepted.

    Raises:
        ValidationError: on a bad length or a disallowed character.
    """
    if len(username) < 3 or len(username) > 20:
        raise ValidationError("用户名长度必须在3-20个字符之间")

    pattern = r'[a-zA-Z0-9_-]+'
    if not re.fullmatch(pattern, username):
        raise ValidationError("用户名只能包含字母、数字、下划线和连字符")

    return username


def validate_password(password: str) -> str:
    """Check password strength and return the password unchanged.

    The password needs at least 8 characters, an uppercase letter, a
    lowercase letter and a digit. Rules are checked in that order and the
    first failure is reported.

    Raises:
        ValidationError: describing the first rule that failed.
    """
    if len(password) < 8:
        raise ValidationError("密码长度至少8个字符")

    required_patterns = (
        (r'[A-Z]', "密码必须包含至少一个大写字母"),
        (r'[a-z]', "密码必须包含至少一个小写字母"),
        (r'\d', "密码必须包含至少一个数字"),
    )
    for pattern, failure_message in required_patterns:
        if not re.search(pattern, password):
            raise ValidationError(failure_message)

    return password


def validate_id_card(id_card: str) -> str:
    """Validate an 18-character mainland-China resident ID number.

    The number must be 17 ASCII digits followed by a digit or "X"/"x", and
    the final character must equal the ISO 7064 MOD 11-2 check code computed
    from the first 17. Returns the number upper-cased.

    ``str.isdigit`` alone is not enough: it accepts characters such as "²",
    for which ``int()`` raises ValueError. Digits are therefore restricted to
    ASCII.

    Raises:
        ValidationError: on any length, character or check-code failure.
    """
    if len(id_card) != 18:
        raise ValidationError("身份证号长度必须为18位")

    body, last = id_card[:17], id_card[-1].upper()

    if not (body.isascii() and body.isdigit()):
        raise ValidationError("身份证号前17位必须是数字")

    if not ((last.isascii() and last.isdigit()) or last == 'X'):
        raise ValidationError("身份证号最后一位必须是数字或X")

    # Position weights and check-code table for the MOD 11-2 checksum
    factors = [7, 9, 10, 5, 8, 4, 2, 1, 6, 3, 7, 9, 10, 5, 8, 4, 2]
    check_codes = ['1', '0', 'X', '9', '8', '7', '6', '5', '4', '3', '2']

    weighted_sum = sum(int(digit) * weight for digit, weight in zip(body, factors))
    if last != check_codes[weighted_sum % 11]:
        raise ValidationError("身份证号校验码错误")

    return id_card.upper()


def sanitize_filename(filename: str, max_length: int = 255) -> str:
    """Make a user-supplied file name safe to use as a single path component.

    Steps:
    - replace path separators, Windows-reserved characters and ASCII control
      characters (including NUL) with "_";
    - strip surrounding whitespace;
    - turn "." / ".." into "_", since joined to a directory they would point
      to that directory or its parent;
    - truncate to ``max_length`` characters, keeping the extension when it
      fits.

    Note that the limit counts characters, not encoded bytes.

    Args:
        filename: Original file name.
        max_length: Maximum length of the result in characters (default 255).

    Returns:
        The sanitised name. An empty or whitespace-only input gives "".
    """
    filename = re.sub(r'[/\\:<>"|?*\x00-\x1f\x7f]', '_', filename)
    filename = filename.strip()

    if filename in (".", ".."):
        return "_"

    if len(filename) > max_length:
        stem, dot, ext = filename.rpartition('.')
        # Keep "<stem>.<ext>" only if there is a real stem and the extension
        # leaves room for at least one stem character; otherwise hard-truncate.
        if dot and stem and ext and len(ext) < max_length - 1:
            filename = f"{stem[:max_length - len(ext) - 1]}.{ext}"
        else:
            filename = filename[:max_length]

    return filename

5. 快速开始命令

5.1 项目初始化脚本

#!/bin/bash
# scripts/init_project.sh
#
# One-time bootstrap of the backend: virtualenv, dependencies, .env,
# Alembic scaffolding and working directories.
# Safe to re-run: an existing .env or migrations/ directory is left untouched.

set -e  # stop at the first failing step instead of carrying on

echo "初始化考培练系统后端项目..."

# Create and activate the virtual environment
python3 -m venv venv
source venv/bin/activate

# Install development dependencies
pip install -r requirements/dev.txt

# Copy the env template only when .env does not exist yet (never overwrite local secrets)
if [ ! -f .env ]; then
    cp .env.example .env
fi
echo "请编辑 .env 文件配置必要的环境变量"

# Scaffold Alembic with its async template: the app connects through
# create_async_engine, so env.py must run migrations on an async engine.
# `alembic init` fails when the target directory already exists and is not
# empty, so it is skipped on re-runs.
if [ ! -d migrations ]; then
    alembic init -t async migrations
fi
echo "数据库迁移已初始化"

# Create working directories
mkdir -p logs uploads

# Code-quality checks: report failures but don't abort the bootstrap
make lint || echo "警告: make lint 未通过"
make type-check || echo "警告: make type-check 未通过"

echo "项目初始化完成!"
echo "下一步:"
echo "1. 编辑 .env 文件"
echo "2. 运行 'alembic revision --autogenerate -m \"initial\"' 创建初始迁移"
echo "3. 运行 'alembic upgrade head' 应用迁移"
echo "4. 运行 'make run-dev' 启动开发服务器"

5.2 开发环境启动脚本

#!/bin/bash
# scripts/start_dev.sh
#
# Start the local development stack: MySQL/Redis containers, database
# migrations, then the auto-reloading API server on port 8000.

set -e  # don't start the server if any earlier step fails

# The env file must exist before anything else runs
if [ ! -f .env ]; then
    echo "错误:.env 文件不存在"
    exit 1
fi

# Activate the virtual environment
source venv/bin/activate

# Start dependency services
docker-compose up -d mysql redis

# Poll for readiness (up to 60s) instead of sleeping a fixed 10s.
# mysqladmin ping exits 0 once the server answers, even without credentials.
# TCP (127.0.0.1) is used so that, presumably with the official mysql image,
# its temporary networking-less init server is not mistaken for readiness.
# NOTE(review): assumes the "mysql" service image ships mysqladmin.
echo "等待数据库启动..."
db_ready=0
for _ in $(seq 1 60); do
    if docker-compose exec -T mysql mysqladmin ping -h 127.0.0.1 --silent >/dev/null 2>&1; then
        db_ready=1
        break
    fi
    sleep 1
done
if [ "$db_ready" -ne 1 ]; then
    echo "错误:数据库在60秒内未就绪"
    exit 1
fi

# Apply database migrations
alembic upgrade head

# Start the development server with auto-reload
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000