feat: 初始化考培练系统项目

- 从服务器拉取完整代码
- 按框架规范整理项目结构
- 配置 Drone CI 测试环境部署
- 包含后端(FastAPI)、前端(Vue3)、管理端

技术栈: Vue3 + TypeScript + FastAPI + MySQL
This commit is contained in:
111
2026-01-24 19:33:28 +08:00
commit 998211c483
1197 changed files with 228429 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
"""
核心功能模块
"""

323
backend/app/core/config.py Normal file
View File

@@ -0,0 +1,323 @@
"""
系统配置
支持两种配置来源:
1. 环境变量 / .env 文件(传统方式,向后兼容)
2. 数据库 tenant_configs 表(新方式,支持热更新)
配置优先级:数据库 > 环境变量 > 默认值
"""
import os
import json
from functools import lru_cache
from typing import Optional, Any
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""系统配置"""
# 应用基础配置
APP_NAME: str = "KaoPeiLian"
APP_VERSION: str = "1.0.0"
DEBUG: bool = Field(default=True)
# 租户配置(用于多租户部署)
TENANT_CODE: str = Field(default="demo", description="租户编码,如 hua, yy, hl")
# 服务器配置
HOST: str = Field(default="0.0.0.0")
PORT: int = Field(default=8000)
# 数据库配置
DATABASE_URL: Optional[str] = Field(default=None)
MYSQL_HOST: str = Field(default="localhost")
MYSQL_PORT: int = Field(default=3306)
MYSQL_USER: str = Field(default="root")
MYSQL_PASSWORD: str = Field(default="password")
MYSQL_DATABASE: str = Field(default="kaopeilian")
@property
def database_url(self) -> str:
"""构建数据库连接URL"""
if self.DATABASE_URL:
return self.DATABASE_URL
# 使用urllib.parse.quote_plus来正确编码特殊字符
import urllib.parse
password = urllib.parse.quote_plus(self.MYSQL_PASSWORD)
return f"mysql+aiomysql://{self.MYSQL_USER}:{password}@{self.MYSQL_HOST}:{self.MYSQL_PORT}/{self.MYSQL_DATABASE}?charset=utf8mb4"
# Redis配置
REDIS_URL: str = Field(default="redis://localhost:6379/0")
# JWT配置
SECRET_KEY: str = Field(default="your-secret-key-here")
ALGORITHM: str = Field(default="HS256")
ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(default=30)
REFRESH_TOKEN_EXPIRE_DAYS: int = Field(default=7)
# 跨域配置
CORS_ORIGINS: list[str] = Field(
default=[
"http://localhost:3000",
"http://localhost:3001",
"http://localhost:5173",
"http://127.0.0.1:3000",
"http://127.0.0.1:3001",
"http://127.0.0.1:5173",
]
)
@field_validator('CORS_ORIGINS', mode='before')
@classmethod
def parse_cors_origins(cls, v):
"""解析 CORS_ORIGINS 环境变量(支持 JSON 格式字符串)"""
if isinstance(v, str):
try:
return json.loads(v)
except json.JSONDecodeError:
# 如果不是 JSON 格式,尝试按逗号分割
return [origin.strip() for origin in v.split(',')]
return v
# 日志配置
LOG_LEVEL: str = Field(default="INFO")
LOG_FORMAT: str = Field(default="text") # text 或 json
LOG_DIR: str = Field(default="logs")
# 上传配置
UPLOAD_DIR: str = Field(default="uploads")
MAX_UPLOAD_SIZE: int = Field(default=15 * 1024 * 1024) # 15MB
@property
def UPLOAD_PATH(self) -> str:
"""获取上传文件的完整路径"""
import os
return os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), self.UPLOAD_DIR)
# Coze 平台配置(陪练对话、播课等)
COZE_API_BASE: Optional[str] = Field(default="https://api.coze.cn")
COZE_WORKSPACE_ID: Optional[str] = Field(default=None)
COZE_API_TOKEN: Optional[str] = Field(default="pat_Sa5OiuUl0gDflnKstQTToIz0sSMshBV06diX0owOeuI1ZK1xDLH5YZH9fSeuKLIi")
COZE_TRAINING_BOT_ID: Optional[str] = Field(default=None)
COZE_CHAT_BOT_ID: Optional[str] = Field(default=None)
COZE_PRACTICE_BOT_ID: Optional[str] = Field(default="7560643598174683145") # 陪练专用Bot ID
# 播课工作流配置(多租户需在环境变量中覆盖,参见:应用配置清单.md
COZE_BROADCAST_WORKFLOW_ID: str = Field(default="7577983042284486666") # 默认:演示版播课工作流
COZE_BROADCAST_SPACE_ID: str = Field(default="7474971491470688296") # 播课工作流空间ID
COZE_BROADCAST_BOT_ID: Optional[str] = Field(default=None) # 播课工作流专用Bot ID
# OAuth配置可选
COZE_OAUTH_CLIENT_ID: Optional[str] = Field(default=None)
COZE_OAUTH_PUBLIC_KEY_ID: Optional[str] = Field(default=None)
COZE_OAUTH_PRIVATE_KEY_PATH: Optional[str] = Field(default=None)
# WebSocket语音配置
COZE_WS_BASE_URL: str = Field(default="wss://ws.coze.cn")
COZE_AUDIO_FORMAT: str = Field(default="pcm") # 音频格式
COZE_SAMPLE_RATE: int = Field(default=16000) # 采样率Hz
COZE_AUDIO_CHANNELS: int = Field(default=1) # 声道数(单声道)
COZE_AUDIO_BIT_DEPTH: int = Field(default=16) # 位深度
# 服务器公开访问域名
PUBLIC_DOMAIN: str = Field(default="http://aiedu.ireborn.com.cn")
# 言迹智能工牌API配置
YANJI_API_BASE: str = Field(default="https://open.yanjiai.com") # 正式环境
YANJI_CLIENT_ID: str = Field(default="1Fld4LCWt2vpJNG5")
YANJI_CLIENT_SECRET: str = Field(default="XE8w413qNtJBOdWc2aCezV0yMIHpUuTZ")
YANJI_TENANT_ID: str = Field(default="516799409476866048")
YANJI_ESTATE_ID: str = Field(default="516799468310364162")
# SCRM 系统对接 API Key用于内部服务间调用
SCRM_API_KEY: str = Field(default="scrm-kpl-api-key-2026-ruixiaomei")
# AI 服务配置(知识点分析 V2 使用)
# 首选服务商4sapi.com国内优化
AI_PRIMARY_API_KEY: str = Field(default="sk-9yMCXjRGANbacz20kJY8doSNy6Rf446aYwmgGIuIXQ7DAyBw") # 测试阶段 Key
AI_PRIMARY_BASE_URL: str = Field(default="https://4sapi.com/v1")
# 备选服务商OpenRouter模型全稳定性好
AI_FALLBACK_API_KEY: str = Field(default="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") # 测试阶段 Key
AI_FALLBACK_BASE_URL: str = Field(default="https://openrouter.ai/api/v1")
# 默认模型
AI_DEFAULT_MODEL: str = Field(default="gemini-3-flash-preview")
# 请求超时(秒)
AI_TIMEOUT: float = Field(default=120.0)
model_config = {
"env_file": ".env",
"env_file_encoding": "utf-8",
"case_sensitive": True,
"extra": "allow", # 允许额外的环境变量
}
@lru_cache()
def get_settings() -> Settings:
"""获取系统配置(缓存)"""
return Settings()
settings = get_settings()
# ============================================
# 动态配置获取(支持从数据库读取)
# ============================================
class DynamicConfig:
"""
动态配置管理器
用于在运行时从数据库获取配置,支持热更新。
向后兼容:如果数据库不可用,回退到环境变量配置。
"""
_tenant_loader = None
_initialized = False
@classmethod
async def init(cls, redis_url: Optional[str] = None):
"""
初始化动态配置管理器
Args:
redis_url: Redis URL可选用于缓存
"""
if cls._initialized:
return
try:
from app.core.tenant_config import TenantConfigManager
if redis_url:
await TenantConfigManager.init_redis(redis_url)
cls._initialized = True
except Exception as e:
import logging
logging.getLogger(__name__).warning(f"动态配置初始化失败: {e}")
@classmethod
async def get(cls, key: str, default: Any = None, tenant_code: Optional[str] = None) -> Any:
"""
获取配置值
Args:
key: 配置键(如 AI_PRIMARY_API_KEY
default: 默认值
tenant_code: 租户编码(可选,默认使用环境变量中的 TENANT_CODE
Returns:
配置值
"""
# 确定租户编码
if tenant_code is None:
tenant_code = settings.TENANT_CODE
# 配置键到分组的映射
config_mapping = {
# 数据库
"MYSQL_HOST": ("database", "MYSQL_HOST"),
"MYSQL_PORT": ("database", "MYSQL_PORT"),
"MYSQL_USER": ("database", "MYSQL_USER"),
"MYSQL_PASSWORD": ("database", "MYSQL_PASSWORD"),
"MYSQL_DATABASE": ("database", "MYSQL_DATABASE"),
# Redis
"REDIS_HOST": ("redis", "REDIS_HOST"),
"REDIS_PORT": ("redis", "REDIS_PORT"),
"REDIS_DB": ("redis", "REDIS_DB"),
# 安全
"SECRET_KEY": ("security", "SECRET_KEY"),
"CORS_ORIGINS": ("security", "CORS_ORIGINS"),
# Coze
"COZE_PRACTICE_BOT_ID": ("coze", "COZE_PRACTICE_BOT_ID"),
"COZE_BROADCAST_WORKFLOW_ID": ("coze", "COZE_BROADCAST_WORKFLOW_ID"),
"COZE_BROADCAST_SPACE_ID": ("coze", "COZE_BROADCAST_SPACE_ID"),
"COZE_OAUTH_CLIENT_ID": ("coze", "COZE_OAUTH_CLIENT_ID"),
"COZE_OAUTH_PUBLIC_KEY_ID": ("coze", "COZE_OAUTH_PUBLIC_KEY_ID"),
# AI
"AI_PRIMARY_API_KEY": ("ai", "AI_PRIMARY_API_KEY"),
"AI_PRIMARY_BASE_URL": ("ai", "AI_PRIMARY_BASE_URL"),
"AI_FALLBACK_API_KEY": ("ai", "AI_FALLBACK_API_KEY"),
"AI_FALLBACK_BASE_URL": ("ai", "AI_FALLBACK_BASE_URL"),
"AI_DEFAULT_MODEL": ("ai", "AI_DEFAULT_MODEL"),
"AI_TIMEOUT": ("ai", "AI_TIMEOUT"),
# 言迹
"YANJI_CLIENT_ID": ("yanji", "YANJI_CLIENT_ID"),
"YANJI_CLIENT_SECRET": ("yanji", "YANJI_CLIENT_SECRET"),
"YANJI_TENANT_ID": ("yanji", "YANJI_TENANT_ID"),
"YANJI_ESTATE_ID": ("yanji", "YANJI_ESTATE_ID"),
}
# 尝试从数据库获取
if cls._initialized and key in config_mapping:
try:
from app.core.tenant_config import TenantConfigManager
config_group, config_key = config_mapping[key]
loader = TenantConfigManager.get_loader(tenant_code)
value = await loader.get_config(config_group, config_key)
if value is not None:
return value
except Exception:
pass
# 回退到环境变量 / Settings
env_value = getattr(settings, key, None)
if env_value is not None:
return env_value
return default
@classmethod
async def is_feature_enabled(cls, feature_code: str, tenant_code: Optional[str] = None) -> bool:
"""
检查功能是否启用
Args:
feature_code: 功能编码
tenant_code: 租户编码
Returns:
是否启用
"""
if tenant_code is None:
tenant_code = settings.TENANT_CODE
if cls._initialized:
try:
from app.core.tenant_config import TenantConfigManager
loader = TenantConfigManager.get_loader(tenant_code)
return await loader.is_feature_enabled(feature_code)
except Exception:
pass
return True # 默认启用
@classmethod
async def refresh_cache(cls, tenant_code: Optional[str] = None):
"""
刷新配置缓存
Args:
tenant_code: 租户编码(为空则刷新所有)
"""
if not cls._initialized:
return
try:
from app.core.tenant_config import TenantConfigManager
if tenant_code:
await TenantConfigManager.refresh_tenant_cache(tenant_code)
else:
await TenantConfigManager.refresh_all_cache()
except Exception:
pass

View File

@@ -0,0 +1,31 @@
"""
数据库配置
"""
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from .config import settings
# 创建异步引擎
engine = create_async_engine(
settings.database_url,
echo=settings.DEBUG,
pool_pre_ping=True,
pool_size=10,
max_overflow=20,
# 确保 MySQL 连接使用 UTF-8 字符集
connect_args={
"charset": "utf8mb4",
"use_unicode": True,
"autocommit": False,
"init_command": "SET character_set_client=utf8mb4, character_set_connection=utf8mb4, character_set_results=utf8mb4, collation_connection=utf8mb4_unicode_ci",
} if "mysql" in settings.database_url else {},
)
# 创建异步会话工厂
AsyncSessionLocal = sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
)

166
backend/app/core/deps.py Normal file
View File

@@ -0,0 +1,166 @@
"""依赖注入模块"""
from typing import AsyncGenerator, Optional
from sqlalchemy import select
import redis.asyncio as redis
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import AsyncSessionLocal
from app.core.config import get_settings
from app.models.user import User
# JWT Bearer认证
security = HTTPBearer()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""
获取数据库会话
"""
async with AsyncSessionLocal() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db),
) -> User:
"""
获取当前用户基于JWT
- 从 Authorization Bearer Token 中解析用户ID
- 查询数据库返回完整的 User 对象
- 失败时抛出 401 未授权
"""
from app.core.security import decode_token # 延迟导入避免循环依赖
if not credentials or not credentials.scheme or not credentials.credentials:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="未提供认证信息")
if credentials.scheme.lower() != "bearer":
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="认证方式不支持")
token = credentials.credentials
try:
payload = decode_token(token)
user_id = int(payload.get("sub"))
except Exception:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的令牌")
result = await db.execute(
select(User).where(User.id == user_id, User.is_deleted == False)
)
user = result.scalar_one_or_none()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或已被禁用"
)
return user
async def require_admin(current_user: User = Depends(get_current_user)) -> User:
"""
需要管理员权限
"""
if getattr(current_user, "role", None) != "admin":
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
return current_user
async def require_admin_or_manager(current_user: User = Depends(get_current_user)) -> User:
"""
需要管理者或管理员权限
"""
if getattr(current_user, "role", None) not in ("admin", "manager"):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理者或管理员权限")
return current_user
async def get_optional_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
db: AsyncSession = Depends(get_db),
) -> Optional[User]:
"""
获取可选的当前用户(不强制登录)
"""
if not credentials:
return None
try:
return await get_current_user(credentials, db)
except:
return None
async def get_current_active_user(
current_user: User = Depends(get_current_user),
) -> User:
"""
获取当前活跃用户
"""
# TODO: 检查用户是否被禁用
return current_user
async def verify_scrm_api_key(
credentials: HTTPAuthorizationCredentials = Depends(security),
) -> bool:
"""
验证 SCRM 系统 API Key
用于内部服务间调用认证SCRM 系统通过固定 API Key 访问考陪练数据查询接口
请求头格式: Authorization: Bearer {SCRM_API_KEY}
"""
settings = get_settings()
if not credentials or not credentials.credentials:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未提供认证信息"
)
if credentials.scheme.lower() != "bearer":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="认证方式不支持,需要 Bearer Token"
)
if credentials.credentials != settings.SCRM_API_KEY:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的 API Key"
)
return True
# Redis 连接池
_redis_pool: Optional[redis.ConnectionPool] = None
async def get_redis() -> AsyncGenerator[redis.Redis, None]:
"""
获取 Redis 连接
"""
global _redis_pool
if _redis_pool is None:
settings = get_settings()
_redis_pool = redis.ConnectionPool.from_url(
settings.REDIS_URL, encoding="utf-8", decode_responses=True
)
client = redis.Redis(connection_pool=_redis_pool)
try:
yield client
finally:
await client.close()

View File

@@ -0,0 +1,28 @@
"""
应用生命周期事件处理
"""
from app.core.logger import logger
async def startup_handler():
"""应用启动时执行的任务"""
logger.info("执行启动任务...")
# TODO: 初始化数据库连接池
# TODO: 初始化Redis连接
# TODO: 初始化AI平台客户端
# TODO: 加载缓存数据
logger.info("启动任务完成")
async def shutdown_handler():
"""应用关闭时执行的任务"""
logger.info("执行关闭任务...")
# TODO: 关闭数据库连接池
# TODO: 关闭Redis连接
# TODO: 清理临时文件
# TODO: 保存应用状态
logger.info("关闭任务完成")

View File

@@ -0,0 +1,89 @@
"""统一异常定义"""
from typing import Optional, Dict, Any
from fastapi import HTTPException, status
class BusinessError(HTTPException):
"""业务异常基类"""
def __init__(
self,
message: str,
code: int = status.HTTP_400_BAD_REQUEST,
error_code: Optional[str] = None,
detail: Optional[Dict[str, Any]] = None,
):
super().__init__(
status_code=code,
detail={
"message": message,
"error_code": error_code or f"ERR_{code}",
"detail": detail,
},
)
self.message = message
self.code = code
self.error_code = error_code
class BadRequestError(BusinessError):
"""400 错误请求"""
def __init__(self, message: str = "错误的请求", **kwargs):
super().__init__(message, status.HTTP_400_BAD_REQUEST, **kwargs)
class UnauthorizedError(BusinessError):
"""401 未授权"""
def __init__(self, message: str = "未授权", **kwargs):
super().__init__(message, status.HTTP_401_UNAUTHORIZED, **kwargs)
class ForbiddenError(BusinessError):
"""403 禁止访问"""
def __init__(self, message: str = "禁止访问", **kwargs):
super().__init__(message, status.HTTP_403_FORBIDDEN, **kwargs)
class NotFoundError(BusinessError):
"""404 未找到"""
def __init__(self, message: str = "资源未找到", **kwargs):
super().__init__(message, status.HTTP_404_NOT_FOUND, **kwargs)
class ConflictError(BusinessError):
"""409 冲突"""
def __init__(self, message: str = "资源冲突", **kwargs):
super().__init__(message, status.HTTP_409_CONFLICT, **kwargs)
class ValidationError(BusinessError):
"""422 验证错误"""
def __init__(self, message: str = "验证失败", **kwargs):
super().__init__(message, status.HTTP_422_UNPROCESSABLE_ENTITY, **kwargs)
class InternalServerError(BusinessError):
"""500 内部服务器错误"""
def __init__(self, message: str = "内部服务器错误", **kwargs):
super().__init__(message, status.HTTP_500_INTERNAL_SERVER_ERROR, **kwargs)
class InsufficientPermissionsError(ForbiddenError):
"""权限不足"""
def __init__(self, message: str = "权限不足", **kwargs):
super().__init__(message, error_code="INSUFFICIENT_PERMISSIONS", **kwargs)
class ExternalServiceError(BusinessError):
"""外部服务错误"""
def __init__(self, message: str = "外部服务异常", **kwargs):
super().__init__(message, status.HTTP_502_BAD_GATEWAY, error_code="EXTERNAL_SERVICE_ERROR", **kwargs)

View File

@@ -0,0 +1,76 @@
"""
日志配置
"""
import logging
import sys
from typing import Any
import structlog
from structlog.stdlib import LoggerFactory
from app.core.config import get_settings
settings = get_settings()
def setup_logging():
"""
配置日志系统
"""
# 设置日志级别
log_level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)
# 配置标准库日志
logging.basicConfig(
format="%(message)s",
stream=sys.stdout,
level=log_level,
)
# 配置处理器
processors = [
structlog.stdlib.filter_by_level,
structlog.stdlib.add_logger_name,
structlog.stdlib.add_log_level,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
structlog.processors.UnicodeDecoder(),
]
# 根据配置选择输出格式
if getattr(settings, "LOG_FORMAT", "text") == "json":
processors.append(structlog.processors.JSONRenderer())
else:
processors.append(structlog.dev.ConsoleRenderer())
# 配置 structlog
structlog.configure(
processors=processors,
context_class=dict,
logger_factory=LoggerFactory(),
cache_logger_on_first_use=True,
)
# 设置日志
setup_logging()
# 获取日志器
def get_logger(name: str = __name__) -> Any:
"""
获取日志器
Args:
name: 日志器名称
Returns:
日志器实例
"""
return structlog.get_logger(name)
# 默认日志器
logger = get_logger("app")

View File

@@ -0,0 +1,64 @@
"""
中间件定义
"""
import time
import uuid
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.logger import logger
class RequestIDMiddleware(BaseHTTPMiddleware):
"""请求ID中间件"""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 生成请求ID
request_id = str(uuid.uuid4())
# 将请求ID添加到request状态
request.state.request_id = request_id
# 记录请求开始
start_time = time.time()
# 处理请求
response = await call_next(request)
# 计算处理时间
process_time = time.time() - start_time
# 添加响应头
response.headers["X-Request-ID"] = request_id
response.headers["X-Process-Time"] = str(process_time)
# 记录请求日志
logger.info(
"HTTP请求",
method=request.method,
url=str(request.url),
status_code=response.status_code,
process_time=process_time,
request_id=request_id,
)
return response
class GlobalContextMiddleware(BaseHTTPMiddleware):
"""全局上下文中间件"""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 设置追踪ID用于分布式追踪
trace_id = request.headers.get("X-Trace-ID", str(uuid.uuid4()))
request.state.trace_id = trace_id
# 处理请求
response = await call_next(request)
# 添加追踪ID到响应头
response.headers["X-Trace-ID"] = trace_id
return response

44
backend/app/core/redis.py Normal file
View File

@@ -0,0 +1,44 @@
"""
Redis连接管理
"""
from typing import Optional
from redis import asyncio as aioredis
from app.core.config import settings
from app.core.logger import logger
# 全局Redis连接实例
redis_client: Optional[aioredis.Redis] = None
async def init_redis() -> aioredis.Redis:
"""初始化Redis连接"""
global redis_client
try:
redis_client = await aioredis.from_url(
settings.REDIS_URL, encoding="utf-8", decode_responses=True
)
# 测试连接
await redis_client.ping()
logger.info("Redis连接成功", url=settings.REDIS_URL)
return redis_client
except Exception as e:
logger.error("Redis连接失败", error=str(e), url=settings.REDIS_URL)
raise
async def close_redis():
"""关闭Redis连接"""
global redis_client
if redis_client:
await redis_client.close()
logger.info("Redis连接已关闭")
redis_client = None
def get_redis_client() -> aioredis.Redis:
"""获取Redis客户端实例"""
if not redis_client:
raise RuntimeError("Redis client not initialized")
return redis_client

View File

@@ -0,0 +1,72 @@
"""
安全相关功能
"""
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Union
import bcrypt
from jose import JWTError, jwt
from .config import settings
def create_access_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
) -> str:
"""创建访问令牌"""
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode = {"exp": expire, "sub": str(subject), "type": "access"}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def create_refresh_token(
subject: Union[str, Any],
expires_delta: Optional[timedelta] = None,
) -> str:
"""创建刷新令牌"""
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"}
encoded_jwt = jwt.encode(
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
)
return encoded_jwt
def decode_token(token: str) -> Dict[str, Any]:
"""解码令牌"""
try:
payload = jwt.decode(
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
)
return payload
except JWTError:
raise ValueError("Invalid token")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
return bcrypt.checkpw(
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
)
def get_password_hash(password: str) -> str:
"""生成密码哈希"""
salt = bcrypt.gensalt()
hashed_password = bcrypt.hashpw(password.encode("utf-8"), salt)
return hashed_password.decode("utf-8")

View File

@@ -0,0 +1,81 @@
"""
简化认证中间件 - 支持 API Key 和长期 Token
用于内部服务间调用
"""
from typing import Optional
from fastapi import HTTPException, Header, status
from app.models.user import User
# 配置 API Keys用于内部服务调用
API_KEYS = {
"internal-service-2025-kaopeilian": {
"service": "internal",
"user_id": 1,
"username": "internal_service",
"role": "admin"
}
}
# 长期有效的 Token用于内部服务调用
LONG_TERM_TOKENS = {
"permanent-token-for-internal-2025": {
"service": "internal",
"user_id": 1,
"username": "internal_service",
"role": "admin"
}
}
def get_current_user_by_api_key(
x_api_key: Optional[str] = Header(None),
authorization: Optional[str] = Header(None)
) -> Optional[User]:
"""
通过 API Key 或长期 Token 获取用户
支持两种方式:
1. X-API-Key: internal-service-2025-kaopeilian
2. Authorization: Bearer permanent-token-for-internal-2025
"""
# 方式1检查 API Key
if x_api_key and x_api_key in API_KEYS:
api_key_info = API_KEYS[x_api_key]
# 创建一个虚拟用户对象
user = User()
user.id = api_key_info["user_id"]
user.username = api_key_info["username"]
user.role = api_key_info["role"]
return user
# 方式2检查长期 Token
if authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if token in LONG_TERM_TOKENS:
token_info = LONG_TERM_TOKENS[token]
user = User()
user.id = token_info["user_id"]
user.username = token_info["username"]
user.role = token_info["role"]
return user
return None
def get_current_user_simple(
x_api_key: Optional[str] = Header(None),
authorization: Optional[str] = Header(None)
) -> User:
"""
简化的用户认证依赖项
"""
# 尝试 API Key 或长期 Token 认证
user = get_current_user_by_api_key(x_api_key, authorization)
if user:
return user
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)

View File

@@ -0,0 +1,421 @@
"""
租户配置加载器
功能:
1. 从数据库 tenant_configs 表加载租户配置
2. 支持 Redis 缓存
3. 数据库不可用时回退到环境变量
4. 支持配置热更新
"""
import os
import json
import logging
from typing import Optional, Dict, Any
from functools import lru_cache
import aiomysql
import redis.asyncio as redis
logger = logging.getLogger(__name__)
# ============================================
# 平台管理库连接配置
#
# 注意:敏感信息必须通过环境变量传递,禁止硬编码
# 参考:瑞小美系统技术栈标准与字符标准.md - 敏感信息管理
# ============================================
ADMIN_DB_CONFIG = {
"host": os.getenv("ADMIN_DB_HOST", "prod-mysql"),
"port": int(os.getenv("ADMIN_DB_PORT", "3306")),
"user": os.getenv("ADMIN_DB_USER", "root"),
"password": os.getenv("ADMIN_DB_PASSWORD"), # 必须从环境变量获取
"db": os.getenv("ADMIN_DB_NAME", "kaopeilian_admin"),
"charset": "utf8mb4",
}
# 校验必填环境变量
if not ADMIN_DB_CONFIG["password"]:
logger.warning(
"ADMIN_DB_PASSWORD 环境变量未设置,租户配置加载功能将不可用。"
"请在 .env.admin 文件中配置此变量。"
)
# Redis 缓存配置
CACHE_PREFIX = "tenant_config:"
CACHE_TTL = 300 # 5分钟缓存
class TenantConfigLoader:
"""租户配置加载器"""
def __init__(self, tenant_code: str, redis_client: Optional[redis.Redis] = None):
"""
初始化租户配置加载器
Args:
tenant_code: 租户编码(如 hua, yy, hl
redis_client: Redis 客户端(可选)
"""
self.tenant_code = tenant_code
self.redis_client = redis_client
self._config_cache: Dict[str, Any] = {}
self._tenant_id: Optional[int] = None
async def get_config(self, config_group: str, config_key: str, default: Any = None) -> Any:
"""
获取配置项
优先级:
1. 内存缓存
2. Redis 缓存
3. 数据库
4. 环境变量
5. 默认值
Args:
config_group: 配置分组database, redis, coze, ai, yanji, security
config_key: 配置键
default: 默认值
Returns:
配置值
"""
cache_key = f"{config_group}.{config_key}"
# 1. 内存缓存
if cache_key in self._config_cache:
return self._config_cache[cache_key]
# 2. Redis 缓存
if self.redis_client:
try:
redis_key = f"{CACHE_PREFIX}{self.tenant_code}:{cache_key}"
cached_value = await self.redis_client.get(redis_key)
if cached_value:
value = json.loads(cached_value)
self._config_cache[cache_key] = value
return value
except Exception as e:
logger.warning(f"Redis 缓存读取失败: {e}")
# 3. 数据库
try:
value = await self._get_from_database(config_group, config_key)
if value is not None:
self._config_cache[cache_key] = value
# 写入 Redis 缓存
if self.redis_client:
try:
redis_key = f"{CACHE_PREFIX}{self.tenant_code}:{cache_key}"
await self.redis_client.setex(
redis_key,
CACHE_TTL,
json.dumps(value)
)
except Exception as e:
logger.warning(f"Redis 缓存写入失败: {e}")
return value
except Exception as e:
logger.warning(f"数据库配置读取失败: {e}")
# 4. 环境变量
env_value = os.getenv(config_key)
if env_value is not None:
return env_value
# 5. 默认值
return default
async def _get_from_database(self, config_group: str, config_key: str) -> Optional[Any]:
"""从数据库获取配置"""
conn = None
try:
conn = await aiomysql.connect(**ADMIN_DB_CONFIG)
async with conn.cursor(aiomysql.DictCursor) as cursor:
# 获取租户 ID
if self._tenant_id is None:
await cursor.execute(
"SELECT id FROM tenants WHERE code = %s AND status = 'active'",
(self.tenant_code,)
)
row = await cursor.fetchone()
if row:
self._tenant_id = row['id']
else:
return None
# 获取配置值
await cursor.execute(
"""
SELECT config_value, value_type, is_encrypted
FROM tenant_configs
WHERE tenant_id = %s AND config_group = %s AND config_key = %s
""",
(self._tenant_id, config_group, config_key)
)
row = await cursor.fetchone()
if row:
return self._parse_value(row['config_value'], row['value_type'], row['is_encrypted'])
# 如果租户没有配置,获取默认值
await cursor.execute(
"""
SELECT default_value, value_type
FROM config_templates
WHERE config_group = %s AND config_key = %s
""",
(config_group, config_key)
)
row = await cursor.fetchone()
if row and row['default_value']:
return self._parse_value(row['default_value'], row['value_type'], False)
return None
finally:
if conn:
conn.close()
def _parse_value(self, value: str, value_type: str, is_encrypted: bool) -> Any:
"""解析配置值"""
if value is None:
return None
# TODO: 如果是加密值,先解密
if is_encrypted:
# 这里可以实现解密逻辑
pass
if value_type == 'int':
return int(value)
elif value_type == 'bool':
return value.lower() in ('true', '1', 'yes')
elif value_type == 'json':
return json.loads(value)
elif value_type == 'float':
return float(value)
else:
return value
async def get_all_configs(self) -> Dict[str, Any]:
"""获取租户的所有配置"""
configs = {}
conn = None
try:
conn = await aiomysql.connect(**ADMIN_DB_CONFIG)
async with conn.cursor(aiomysql.DictCursor) as cursor:
# 获取租户 ID
await cursor.execute(
"SELECT id FROM tenants WHERE code = %s AND status = 'active'",
(self.tenant_code,)
)
row = await cursor.fetchone()
if not row:
return configs
tenant_id = row['id']
# 获取所有配置
await cursor.execute(
"""
SELECT config_group, config_key, config_value, value_type, is_encrypted
FROM tenant_configs
WHERE tenant_id = %s
""",
(tenant_id,)
)
rows = await cursor.fetchall()
for row in rows:
key = f"{row['config_group']}.{row['config_key']}"
configs[key] = self._parse_value(
row['config_value'],
row['value_type'],
row['is_encrypted']
)
return configs
finally:
if conn:
conn.close()
async def refresh_cache(self):
"""刷新缓存"""
self._config_cache.clear()
if self.redis_client:
try:
# 删除该租户的所有缓存
pattern = f"{CACHE_PREFIX}{self.tenant_code}:*"
cursor = 0
while True:
cursor, keys = await self.redis_client.scan(cursor, match=pattern, count=100)
if keys:
await self.redis_client.delete(*keys)
if cursor == 0:
break
except Exception as e:
logger.warning(f"Redis 缓存刷新失败: {e}")
async def is_feature_enabled(self, feature_code: str) -> bool:
"""
检查功能是否启用
Args:
feature_code: 功能编码
Returns:
是否启用
"""
conn = None
try:
conn = await aiomysql.connect(**ADMIN_DB_CONFIG)
async with conn.cursor(aiomysql.DictCursor) as cursor:
# 获取租户 ID
if self._tenant_id is None:
await cursor.execute(
"SELECT id FROM tenants WHERE code = %s AND status = 'active'",
(self.tenant_code,)
)
row = await cursor.fetchone()
if row:
self._tenant_id = row['id']
# 先查租户级别的配置
if self._tenant_id:
await cursor.execute(
"""
SELECT is_enabled FROM feature_switches
WHERE tenant_id = %s AND feature_code = %s
""",
(self._tenant_id, feature_code)
)
row = await cursor.fetchone()
if row:
return bool(row['is_enabled'])
# 再查全局默认配置
await cursor.execute(
"""
SELECT is_enabled FROM feature_switches
WHERE tenant_id IS NULL AND feature_code = %s
""",
(feature_code,)
)
row = await cursor.fetchone()
if row:
return bool(row['is_enabled'])
return True # 默认启用
except Exception as e:
logger.warning(f"功能开关查询失败: {e}, 默认启用")
return True
finally:
if conn:
conn.close()
class TenantConfigManager:
"""租户配置管理器(单例)"""
_instances: Dict[str, TenantConfigLoader] = {}
_redis_client: Optional[redis.Redis] = None
@classmethod
async def init_redis(cls, redis_url: str):
"""初始化 Redis 连接"""
try:
cls._redis_client = redis.from_url(redis_url)
await cls._redis_client.ping()
logger.info("TenantConfigManager Redis 连接成功")
except Exception as e:
logger.warning(f"TenantConfigManager Redis 连接失败: {e}")
cls._redis_client = None
@classmethod
def get_loader(cls, tenant_code: str) -> TenantConfigLoader:
"""获取租户配置加载器"""
if tenant_code not in cls._instances:
cls._instances[tenant_code] = TenantConfigLoader(
tenant_code,
cls._redis_client
)
return cls._instances[tenant_code]
@classmethod
async def refresh_tenant_cache(cls, tenant_code: str):
"""刷新指定租户的缓存"""
if tenant_code in cls._instances:
await cls._instances[tenant_code].refresh_cache()
@classmethod
async def refresh_all_cache(cls):
"""刷新所有租户的缓存"""
for loader in cls._instances.values():
await loader.refresh_cache()
# ============================================
# 辅助函数
# ============================================
def get_tenant_code_from_domain(domain: str) -> str:
"""
从域名提取租户编码
Examples:
hua.ireborn.com.cn -> hua
yy.ireborn.com.cn -> yy
aiedu.ireborn.com.cn -> demo
"""
if not domain:
return "demo"
# 移除 https:// 或 http://
domain = domain.replace("https://", "").replace("http://", "")
# 获取子域名
parts = domain.split(".")
if len(parts) >= 3:
subdomain = parts[0]
# 特殊处理
if subdomain == "aiedu":
return "demo"
return subdomain
return "demo"
async def get_tenant_config(tenant_code: str, config_group: str, config_key: str, default: Any = None) -> Any:
"""
快捷函数:获取租户配置
Args:
tenant_code: 租户编码
config_group: 配置分组
config_key: 配置键
default: 默认值
Returns:
配置值
"""
loader = TenantConfigManager.get_loader(tenant_code)
return await loader.get_config(config_group, config_key, default)
async def is_tenant_feature_enabled(tenant_code: str, feature_code: str) -> bool:
"""
快捷函数:检查租户功能是否启用
Args:
tenant_code: 租户编码
feature_code: 功能编码
Returns:
是否启用
"""
loader = TenantConfigManager.get_loader(tenant_code)
return await loader.is_feature_enabled(feature_code)