feat: 初始化考培练系统项目
- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
This commit is contained in:
3
backend/app/core/__init__.py
Normal file
3
backend/app/core/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
核心功能模块
|
||||
"""
|
||||
323
backend/app/core/config.py
Normal file
323
backend/app/core/config.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
系统配置
|
||||
|
||||
支持两种配置来源:
|
||||
1. 环境变量 / .env 文件(传统方式,向后兼容)
|
||||
2. 数据库 tenant_configs 表(新方式,支持热更新)
|
||||
|
||||
配置优先级:数据库 > 环境变量 > 默认值
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""系统配置"""
|
||||
|
||||
# 应用基础配置
|
||||
APP_NAME: str = "KaoPeiLian"
|
||||
APP_VERSION: str = "1.0.0"
|
||||
DEBUG: bool = Field(default=True)
|
||||
|
||||
# 租户配置(用于多租户部署)
|
||||
TENANT_CODE: str = Field(default="demo", description="租户编码,如 hua, yy, hl")
|
||||
|
||||
# 服务器配置
|
||||
HOST: str = Field(default="0.0.0.0")
|
||||
PORT: int = Field(default=8000)
|
||||
|
||||
# 数据库配置
|
||||
DATABASE_URL: Optional[str] = Field(default=None)
|
||||
MYSQL_HOST: str = Field(default="localhost")
|
||||
MYSQL_PORT: int = Field(default=3306)
|
||||
MYSQL_USER: str = Field(default="root")
|
||||
MYSQL_PASSWORD: str = Field(default="password")
|
||||
MYSQL_DATABASE: str = Field(default="kaopeilian")
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
"""构建数据库连接URL"""
|
||||
if self.DATABASE_URL:
|
||||
return self.DATABASE_URL
|
||||
|
||||
# 使用urllib.parse.quote_plus来正确编码特殊字符
|
||||
import urllib.parse
|
||||
password = urllib.parse.quote_plus(self.MYSQL_PASSWORD)
|
||||
|
||||
return f"mysql+aiomysql://{self.MYSQL_USER}:{password}@{self.MYSQL_HOST}:{self.MYSQL_PORT}/{self.MYSQL_DATABASE}?charset=utf8mb4"
|
||||
|
||||
# Redis配置
|
||||
REDIS_URL: str = Field(default="redis://localhost:6379/0")
|
||||
|
||||
# JWT配置
|
||||
SECRET_KEY: str = Field(default="your-secret-key-here")
|
||||
ALGORITHM: str = Field(default="HS256")
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(default=30)
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = Field(default=7)
|
||||
|
||||
# 跨域配置
|
||||
CORS_ORIGINS: list[str] = Field(
|
||||
default=[
|
||||
"http://localhost:3000",
|
||||
"http://localhost:3001",
|
||||
"http://localhost:5173",
|
||||
"http://127.0.0.1:3000",
|
||||
"http://127.0.0.1:3001",
|
||||
"http://127.0.0.1:5173",
|
||||
]
|
||||
)
|
||||
|
||||
@field_validator('CORS_ORIGINS', mode='before')
|
||||
@classmethod
|
||||
def parse_cors_origins(cls, v):
|
||||
"""解析 CORS_ORIGINS 环境变量(支持 JSON 格式字符串)"""
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return json.loads(v)
|
||||
except json.JSONDecodeError:
|
||||
# 如果不是 JSON 格式,尝试按逗号分割
|
||||
return [origin.strip() for origin in v.split(',')]
|
||||
return v
|
||||
|
||||
# 日志配置
|
||||
LOG_LEVEL: str = Field(default="INFO")
|
||||
LOG_FORMAT: str = Field(default="text") # text 或 json
|
||||
LOG_DIR: str = Field(default="logs")
|
||||
|
||||
# 上传配置
|
||||
UPLOAD_DIR: str = Field(default="uploads")
|
||||
MAX_UPLOAD_SIZE: int = Field(default=15 * 1024 * 1024) # 15MB
|
||||
|
||||
@property
|
||||
def UPLOAD_PATH(self) -> str:
|
||||
"""获取上传文件的完整路径"""
|
||||
import os
|
||||
return os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), self.UPLOAD_DIR)
|
||||
|
||||
# Coze 平台配置(陪练对话、播课等)
|
||||
COZE_API_BASE: Optional[str] = Field(default="https://api.coze.cn")
|
||||
COZE_WORKSPACE_ID: Optional[str] = Field(default=None)
|
||||
COZE_API_TOKEN: Optional[str] = Field(default="pat_Sa5OiuUl0gDflnKstQTToIz0sSMshBV06diX0owOeuI1ZK1xDLH5YZH9fSeuKLIi")
|
||||
COZE_TRAINING_BOT_ID: Optional[str] = Field(default=None)
|
||||
COZE_CHAT_BOT_ID: Optional[str] = Field(default=None)
|
||||
COZE_PRACTICE_BOT_ID: Optional[str] = Field(default="7560643598174683145") # 陪练专用Bot ID
|
||||
# 播课工作流配置(多租户需在环境变量中覆盖,参见:应用配置清单.md)
|
||||
COZE_BROADCAST_WORKFLOW_ID: str = Field(default="7577983042284486666") # 默认:演示版播课工作流
|
||||
COZE_BROADCAST_SPACE_ID: str = Field(default="7474971491470688296") # 播课工作流空间ID
|
||||
COZE_BROADCAST_BOT_ID: Optional[str] = Field(default=None) # 播课工作流专用Bot ID
|
||||
# OAuth配置(可选)
|
||||
COZE_OAUTH_CLIENT_ID: Optional[str] = Field(default=None)
|
||||
COZE_OAUTH_PUBLIC_KEY_ID: Optional[str] = Field(default=None)
|
||||
COZE_OAUTH_PRIVATE_KEY_PATH: Optional[str] = Field(default=None)
|
||||
|
||||
# WebSocket语音配置
|
||||
COZE_WS_BASE_URL: str = Field(default="wss://ws.coze.cn")
|
||||
COZE_AUDIO_FORMAT: str = Field(default="pcm") # 音频格式
|
||||
COZE_SAMPLE_RATE: int = Field(default=16000) # 采样率(Hz)
|
||||
COZE_AUDIO_CHANNELS: int = Field(default=1) # 声道数(单声道)
|
||||
COZE_AUDIO_BIT_DEPTH: int = Field(default=16) # 位深度
|
||||
|
||||
# 服务器公开访问域名
|
||||
PUBLIC_DOMAIN: str = Field(default="http://aiedu.ireborn.com.cn")
|
||||
|
||||
# 言迹智能工牌API配置
|
||||
YANJI_API_BASE: str = Field(default="https://open.yanjiai.com") # 正式环境
|
||||
YANJI_CLIENT_ID: str = Field(default="1Fld4LCWt2vpJNG5")
|
||||
YANJI_CLIENT_SECRET: str = Field(default="XE8w413qNtJBOdWc2aCezV0yMIHpUuTZ")
|
||||
YANJI_TENANT_ID: str = Field(default="516799409476866048")
|
||||
YANJI_ESTATE_ID: str = Field(default="516799468310364162")
|
||||
|
||||
# SCRM 系统对接 API Key(用于内部服务间调用)
|
||||
SCRM_API_KEY: str = Field(default="scrm-kpl-api-key-2026-ruixiaomei")
|
||||
|
||||
# AI 服务配置(知识点分析 V2 使用)
|
||||
# 首选服务商:4sapi.com(国内优化)
|
||||
AI_PRIMARY_API_KEY: str = Field(default="sk-9yMCXjRGANbacz20kJY8doSNy6Rf446aYwmgGIuIXQ7DAyBw") # 测试阶段 Key
|
||||
AI_PRIMARY_BASE_URL: str = Field(default="https://4sapi.com/v1")
|
||||
# 备选服务商:OpenRouter(模型全,稳定性好)
|
||||
AI_FALLBACK_API_KEY: str = Field(default="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") # 测试阶段 Key
|
||||
AI_FALLBACK_BASE_URL: str = Field(default="https://openrouter.ai/api/v1")
|
||||
# 默认模型
|
||||
AI_DEFAULT_MODEL: str = Field(default="gemini-3-flash-preview")
|
||||
# 请求超时(秒)
|
||||
AI_TIMEOUT: float = Field(default=120.0)
|
||||
|
||||
model_config = {
|
||||
"env_file": ".env",
|
||||
"env_file_encoding": "utf-8",
|
||||
"case_sensitive": True,
|
||||
"extra": "allow", # 允许额外的环境变量
|
||||
}
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""获取系统配置(缓存)"""
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
# ============================================
|
||||
# 动态配置获取(支持从数据库读取)
|
||||
# ============================================
|
||||
|
||||
class DynamicConfig:
|
||||
"""
|
||||
动态配置管理器
|
||||
|
||||
用于在运行时从数据库获取配置,支持热更新。
|
||||
向后兼容:如果数据库不可用,回退到环境变量配置。
|
||||
"""
|
||||
|
||||
_tenant_loader = None
|
||||
_initialized = False
|
||||
|
||||
@classmethod
|
||||
async def init(cls, redis_url: Optional[str] = None):
|
||||
"""
|
||||
初始化动态配置管理器
|
||||
|
||||
Args:
|
||||
redis_url: Redis URL(可选,用于缓存)
|
||||
"""
|
||||
if cls._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
from app.core.tenant_config import TenantConfigManager
|
||||
|
||||
if redis_url:
|
||||
await TenantConfigManager.init_redis(redis_url)
|
||||
|
||||
cls._initialized = True
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"动态配置初始化失败: {e}")
|
||||
|
||||
@classmethod
|
||||
async def get(cls, key: str, default: Any = None, tenant_code: Optional[str] = None) -> Any:
|
||||
"""
|
||||
获取配置值
|
||||
|
||||
Args:
|
||||
key: 配置键(如 AI_PRIMARY_API_KEY)
|
||||
default: 默认值
|
||||
tenant_code: 租户编码(可选,默认使用环境变量中的 TENANT_CODE)
|
||||
|
||||
Returns:
|
||||
配置值
|
||||
"""
|
||||
# 确定租户编码
|
||||
if tenant_code is None:
|
||||
tenant_code = settings.TENANT_CODE
|
||||
|
||||
# 配置键到分组的映射
|
||||
config_mapping = {
|
||||
# 数据库
|
||||
"MYSQL_HOST": ("database", "MYSQL_HOST"),
|
||||
"MYSQL_PORT": ("database", "MYSQL_PORT"),
|
||||
"MYSQL_USER": ("database", "MYSQL_USER"),
|
||||
"MYSQL_PASSWORD": ("database", "MYSQL_PASSWORD"),
|
||||
"MYSQL_DATABASE": ("database", "MYSQL_DATABASE"),
|
||||
# Redis
|
||||
"REDIS_HOST": ("redis", "REDIS_HOST"),
|
||||
"REDIS_PORT": ("redis", "REDIS_PORT"),
|
||||
"REDIS_DB": ("redis", "REDIS_DB"),
|
||||
# 安全
|
||||
"SECRET_KEY": ("security", "SECRET_KEY"),
|
||||
"CORS_ORIGINS": ("security", "CORS_ORIGINS"),
|
||||
# Coze
|
||||
"COZE_PRACTICE_BOT_ID": ("coze", "COZE_PRACTICE_BOT_ID"),
|
||||
"COZE_BROADCAST_WORKFLOW_ID": ("coze", "COZE_BROADCAST_WORKFLOW_ID"),
|
||||
"COZE_BROADCAST_SPACE_ID": ("coze", "COZE_BROADCAST_SPACE_ID"),
|
||||
"COZE_OAUTH_CLIENT_ID": ("coze", "COZE_OAUTH_CLIENT_ID"),
|
||||
"COZE_OAUTH_PUBLIC_KEY_ID": ("coze", "COZE_OAUTH_PUBLIC_KEY_ID"),
|
||||
# AI
|
||||
"AI_PRIMARY_API_KEY": ("ai", "AI_PRIMARY_API_KEY"),
|
||||
"AI_PRIMARY_BASE_URL": ("ai", "AI_PRIMARY_BASE_URL"),
|
||||
"AI_FALLBACK_API_KEY": ("ai", "AI_FALLBACK_API_KEY"),
|
||||
"AI_FALLBACK_BASE_URL": ("ai", "AI_FALLBACK_BASE_URL"),
|
||||
"AI_DEFAULT_MODEL": ("ai", "AI_DEFAULT_MODEL"),
|
||||
"AI_TIMEOUT": ("ai", "AI_TIMEOUT"),
|
||||
# 言迹
|
||||
"YANJI_CLIENT_ID": ("yanji", "YANJI_CLIENT_ID"),
|
||||
"YANJI_CLIENT_SECRET": ("yanji", "YANJI_CLIENT_SECRET"),
|
||||
"YANJI_TENANT_ID": ("yanji", "YANJI_TENANT_ID"),
|
||||
"YANJI_ESTATE_ID": ("yanji", "YANJI_ESTATE_ID"),
|
||||
}
|
||||
|
||||
# 尝试从数据库获取
|
||||
if cls._initialized and key in config_mapping:
|
||||
try:
|
||||
from app.core.tenant_config import TenantConfigManager
|
||||
|
||||
config_group, config_key = config_mapping[key]
|
||||
loader = TenantConfigManager.get_loader(tenant_code)
|
||||
value = await loader.get_config(config_group, config_key)
|
||||
|
||||
if value is not None:
|
||||
return value
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 回退到环境变量 / Settings
|
||||
env_value = getattr(settings, key, None)
|
||||
if env_value is not None:
|
||||
return env_value
|
||||
|
||||
return default
|
||||
|
||||
@classmethod
|
||||
async def is_feature_enabled(cls, feature_code: str, tenant_code: Optional[str] = None) -> bool:
|
||||
"""
|
||||
检查功能是否启用
|
||||
|
||||
Args:
|
||||
feature_code: 功能编码
|
||||
tenant_code: 租户编码
|
||||
|
||||
Returns:
|
||||
是否启用
|
||||
"""
|
||||
if tenant_code is None:
|
||||
tenant_code = settings.TENANT_CODE
|
||||
|
||||
if cls._initialized:
|
||||
try:
|
||||
from app.core.tenant_config import TenantConfigManager
|
||||
|
||||
loader = TenantConfigManager.get_loader(tenant_code)
|
||||
return await loader.is_feature_enabled(feature_code)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return True # 默认启用
|
||||
|
||||
@classmethod
|
||||
async def refresh_cache(cls, tenant_code: Optional[str] = None):
|
||||
"""
|
||||
刷新配置缓存
|
||||
|
||||
Args:
|
||||
tenant_code: 租户编码(为空则刷新所有)
|
||||
"""
|
||||
if not cls._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
from app.core.tenant_config import TenantConfigManager
|
||||
|
||||
if tenant_code:
|
||||
await TenantConfigManager.refresh_tenant_cache(tenant_code)
|
||||
else:
|
||||
await TenantConfigManager.refresh_all_cache()
|
||||
except Exception:
|
||||
pass
|
||||
31
backend/app/core/database.py
Normal file
31
backend/app/core/database.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
数据库配置
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from .config import settings
|
||||
|
||||
# 创建异步引擎
|
||||
engine = create_async_engine(
|
||||
settings.database_url,
|
||||
echo=settings.DEBUG,
|
||||
pool_pre_ping=True,
|
||||
pool_size=10,
|
||||
max_overflow=20,
|
||||
# 确保 MySQL 连接使用 UTF-8 字符集
|
||||
connect_args={
|
||||
"charset": "utf8mb4",
|
||||
"use_unicode": True,
|
||||
"autocommit": False,
|
||||
"init_command": "SET character_set_client=utf8mb4, character_set_connection=utf8mb4, character_set_results=utf8mb4, collation_connection=utf8mb4_unicode_ci",
|
||||
} if "mysql" in settings.database_url else {},
|
||||
)
|
||||
|
||||
# 创建异步会话工厂
|
||||
AsyncSessionLocal = sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
166
backend/app/core/deps.py
Normal file
166
backend/app/core/deps.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""依赖注入模块"""
|
||||
from typing import AsyncGenerator, Optional
|
||||
from sqlalchemy import select
|
||||
import redis.asyncio as redis
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import AsyncSessionLocal
|
||||
from app.core.config import get_settings
|
||||
from app.models.user import User
|
||||
|
||||
# JWT Bearer认证
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
获取数据库会话
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""
|
||||
获取当前用户(基于JWT)
|
||||
|
||||
- 从 Authorization Bearer Token 中解析用户ID
|
||||
- 查询数据库返回完整的 User 对象
|
||||
- 失败时抛出 401 未授权
|
||||
"""
|
||||
from app.core.security import decode_token # 延迟导入避免循环依赖
|
||||
|
||||
if not credentials or not credentials.scheme or not credentials.credentials:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="未提供认证信息")
|
||||
|
||||
if credentials.scheme.lower() != "bearer":
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="认证方式不支持")
|
||||
|
||||
token = credentials.credentials
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
user_id = int(payload.get("sub"))
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的令牌")
|
||||
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == user_id, User.is_deleted == False)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在或已被禁用"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def require_admin(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
需要管理员权限
|
||||
"""
|
||||
if getattr(current_user, "role", None) != "admin":
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
|
||||
return current_user
|
||||
|
||||
|
||||
async def require_admin_or_manager(current_user: User = Depends(get_current_user)) -> User:
|
||||
"""
|
||||
需要管理者或管理员权限
|
||||
"""
|
||||
if getattr(current_user, "role", None) not in ("admin", "manager"):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理者或管理员权限")
|
||||
return current_user
|
||||
|
||||
async def get_optional_user(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Optional[User]:
|
||||
"""
|
||||
获取可选的当前用户(不强制登录)
|
||||
"""
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
try:
|
||||
return await get_current_user(credentials, db)
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> User:
|
||||
"""
|
||||
获取当前活跃用户
|
||||
"""
|
||||
# TODO: 检查用户是否被禁用
|
||||
return current_user
|
||||
|
||||
|
||||
async def verify_scrm_api_key(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
) -> bool:
|
||||
"""
|
||||
验证 SCRM 系统 API Key
|
||||
|
||||
用于内部服务间调用认证,SCRM 系统通过固定 API Key 访问考陪练数据查询接口
|
||||
请求头格式: Authorization: Bearer {SCRM_API_KEY}
|
||||
"""
|
||||
settings = get_settings()
|
||||
|
||||
if not credentials or not credentials.credentials:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="未提供认证信息"
|
||||
)
|
||||
|
||||
if credentials.scheme.lower() != "bearer":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="认证方式不支持,需要 Bearer Token"
|
||||
)
|
||||
|
||||
if credentials.credentials != settings.SCRM_API_KEY:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的 API Key"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Redis 连接池
|
||||
_redis_pool: Optional[redis.ConnectionPool] = None
|
||||
|
||||
|
||||
async def get_redis() -> AsyncGenerator[redis.Redis, None]:
|
||||
"""
|
||||
获取 Redis 连接
|
||||
"""
|
||||
global _redis_pool
|
||||
|
||||
if _redis_pool is None:
|
||||
settings = get_settings()
|
||||
_redis_pool = redis.ConnectionPool.from_url(
|
||||
settings.REDIS_URL, encoding="utf-8", decode_responses=True
|
||||
)
|
||||
|
||||
client = redis.Redis(connection_pool=_redis_pool)
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
await client.close()
|
||||
28
backend/app/core/events.py
Normal file
28
backend/app/core/events.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
应用生命周期事件处理
|
||||
"""
|
||||
from app.core.logger import logger
|
||||
|
||||
|
||||
async def startup_handler():
|
||||
"""应用启动时执行的任务"""
|
||||
logger.info("执行启动任务...")
|
||||
|
||||
# TODO: 初始化数据库连接池
|
||||
# TODO: 初始化Redis连接
|
||||
# TODO: 初始化AI平台客户端
|
||||
# TODO: 加载缓存数据
|
||||
|
||||
logger.info("启动任务完成")
|
||||
|
||||
|
||||
async def shutdown_handler():
|
||||
"""应用关闭时执行的任务"""
|
||||
logger.info("执行关闭任务...")
|
||||
|
||||
# TODO: 关闭数据库连接池
|
||||
# TODO: 关闭Redis连接
|
||||
# TODO: 清理临时文件
|
||||
# TODO: 保存应用状态
|
||||
|
||||
logger.info("关闭任务完成")
|
||||
89
backend/app/core/exceptions.py
Normal file
89
backend/app/core/exceptions.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""统一异常定义"""
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
class BusinessError(HTTPException):
|
||||
"""业务异常基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: int = status.HTTP_400_BAD_REQUEST,
|
||||
error_code: Optional[str] = None,
|
||||
detail: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
status_code=code,
|
||||
detail={
|
||||
"message": message,
|
||||
"error_code": error_code or f"ERR_{code}",
|
||||
"detail": detail,
|
||||
},
|
||||
)
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class BadRequestError(BusinessError):
|
||||
"""400 错误请求"""
|
||||
|
||||
def __init__(self, message: str = "错误的请求", **kwargs):
|
||||
super().__init__(message, status.HTTP_400_BAD_REQUEST, **kwargs)
|
||||
|
||||
|
||||
class UnauthorizedError(BusinessError):
|
||||
"""401 未授权"""
|
||||
|
||||
def __init__(self, message: str = "未授权", **kwargs):
|
||||
super().__init__(message, status.HTTP_401_UNAUTHORIZED, **kwargs)
|
||||
|
||||
|
||||
class ForbiddenError(BusinessError):
|
||||
"""403 禁止访问"""
|
||||
|
||||
def __init__(self, message: str = "禁止访问", **kwargs):
|
||||
super().__init__(message, status.HTTP_403_FORBIDDEN, **kwargs)
|
||||
|
||||
|
||||
class NotFoundError(BusinessError):
|
||||
"""404 未找到"""
|
||||
|
||||
def __init__(self, message: str = "资源未找到", **kwargs):
|
||||
super().__init__(message, status.HTTP_404_NOT_FOUND, **kwargs)
|
||||
|
||||
|
||||
class ConflictError(BusinessError):
|
||||
"""409 冲突"""
|
||||
|
||||
def __init__(self, message: str = "资源冲突", **kwargs):
|
||||
super().__init__(message, status.HTTP_409_CONFLICT, **kwargs)
|
||||
|
||||
|
||||
class ValidationError(BusinessError):
|
||||
"""422 验证错误"""
|
||||
|
||||
def __init__(self, message: str = "验证失败", **kwargs):
|
||||
super().__init__(message, status.HTTP_422_UNPROCESSABLE_ENTITY, **kwargs)
|
||||
|
||||
|
||||
class InternalServerError(BusinessError):
|
||||
"""500 内部服务器错误"""
|
||||
|
||||
def __init__(self, message: str = "内部服务器错误", **kwargs):
|
||||
super().__init__(message, status.HTTP_500_INTERNAL_SERVER_ERROR, **kwargs)
|
||||
|
||||
|
||||
class InsufficientPermissionsError(ForbiddenError):
|
||||
"""权限不足"""
|
||||
|
||||
def __init__(self, message: str = "权限不足", **kwargs):
|
||||
super().__init__(message, error_code="INSUFFICIENT_PERMISSIONS", **kwargs)
|
||||
|
||||
|
||||
class ExternalServiceError(BusinessError):
|
||||
"""外部服务错误"""
|
||||
|
||||
def __init__(self, message: str = "外部服务异常", **kwargs):
|
||||
super().__init__(message, status.HTTP_502_BAD_GATEWAY, error_code="EXTERNAL_SERVICE_ERROR", **kwargs)
|
||||
76
backend/app/core/logger.py
Normal file
76
backend/app/core/logger.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
日志配置
|
||||
"""
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from structlog.stdlib import LoggerFactory
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
def setup_logging():
|
||||
"""
|
||||
配置日志系统
|
||||
"""
|
||||
# 设置日志级别
|
||||
log_level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)
|
||||
|
||||
# 配置标准库日志
|
||||
logging.basicConfig(
|
||||
format="%(message)s",
|
||||
stream=sys.stdout,
|
||||
level=log_level,
|
||||
)
|
||||
|
||||
# 配置处理器
|
||||
processors = [
|
||||
structlog.stdlib.filter_by_level,
|
||||
structlog.stdlib.add_logger_name,
|
||||
structlog.stdlib.add_log_level,
|
||||
structlog.stdlib.PositionalArgumentsFormatter(),
|
||||
structlog.processors.TimeStamper(fmt="iso"),
|
||||
structlog.processors.StackInfoRenderer(),
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
]
|
||||
|
||||
# 根据配置选择输出格式
|
||||
if getattr(settings, "LOG_FORMAT", "text") == "json":
|
||||
processors.append(structlog.processors.JSONRenderer())
|
||||
else:
|
||||
processors.append(structlog.dev.ConsoleRenderer())
|
||||
|
||||
# 配置 structlog
|
||||
structlog.configure(
|
||||
processors=processors,
|
||||
context_class=dict,
|
||||
logger_factory=LoggerFactory(),
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
|
||||
|
||||
# 设置日志
|
||||
setup_logging()
|
||||
|
||||
|
||||
# 获取日志器
|
||||
def get_logger(name: str = __name__) -> Any:
|
||||
"""
|
||||
获取日志器
|
||||
|
||||
Args:
|
||||
name: 日志器名称
|
||||
|
||||
Returns:
|
||||
日志器实例
|
||||
"""
|
||||
return structlog.get_logger(name)
|
||||
|
||||
|
||||
# 默认日志器
|
||||
logger = get_logger("app")
|
||||
64
backend/app/core/middleware.py
Normal file
64
backend/app/core/middleware.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
中间件定义
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.core.logger import logger
|
||||
|
||||
|
||||
class RequestIDMiddleware(BaseHTTPMiddleware):
|
||||
"""请求ID中间件"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 生成请求ID
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
# 将请求ID添加到request状态
|
||||
request.state.request_id = request_id
|
||||
|
||||
# 记录请求开始
|
||||
start_time = time.time()
|
||||
|
||||
# 处理请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 计算处理时间
|
||||
process_time = time.time() - start_time
|
||||
|
||||
# 添加响应头
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
|
||||
# 记录请求日志
|
||||
logger.info(
|
||||
"HTTP请求",
|
||||
method=request.method,
|
||||
url=str(request.url),
|
||||
status_code=response.status_code,
|
||||
process_time=process_time,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class GlobalContextMiddleware(BaseHTTPMiddleware):
|
||||
"""全局上下文中间件"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 设置追踪ID(用于分布式追踪)
|
||||
trace_id = request.headers.get("X-Trace-ID", str(uuid.uuid4()))
|
||||
request.state.trace_id = trace_id
|
||||
|
||||
# 处理请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 添加追踪ID到响应头
|
||||
response.headers["X-Trace-ID"] = trace_id
|
||||
|
||||
return response
|
||||
44
backend/app/core/redis.py
Normal file
44
backend/app/core/redis.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
Redis连接管理
|
||||
"""
|
||||
from typing import Optional
|
||||
from redis import asyncio as aioredis
|
||||
from app.core.config import settings
|
||||
from app.core.logger import logger
|
||||
|
||||
# 全局Redis连接实例
|
||||
redis_client: Optional[aioredis.Redis] = None
|
||||
|
||||
|
||||
async def init_redis() -> aioredis.Redis:
|
||||
"""初始化Redis连接"""
|
||||
global redis_client
|
||||
|
||||
try:
|
||||
redis_client = await aioredis.from_url(
|
||||
settings.REDIS_URL, encoding="utf-8", decode_responses=True
|
||||
)
|
||||
# 测试连接
|
||||
await redis_client.ping()
|
||||
logger.info("Redis连接成功", url=settings.REDIS_URL)
|
||||
return redis_client
|
||||
except Exception as e:
|
||||
logger.error("Redis连接失败", error=str(e), url=settings.REDIS_URL)
|
||||
raise
|
||||
|
||||
|
||||
async def close_redis():
|
||||
"""关闭Redis连接"""
|
||||
global redis_client
|
||||
|
||||
if redis_client:
|
||||
await redis_client.close()
|
||||
logger.info("Redis连接已关闭")
|
||||
redis_client = None
|
||||
|
||||
|
||||
def get_redis_client() -> aioredis.Redis:
|
||||
"""获取Redis客户端实例"""
|
||||
if not redis_client:
|
||||
raise RuntimeError("Redis client not initialized")
|
||||
return redis_client
|
||||
72
backend/app/core/security.py
Normal file
72
backend/app/core/security.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
安全相关功能
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import bcrypt
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
) -> str:
|
||||
"""创建访问令牌"""
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
|
||||
to_encode = {"exp": expire, "sub": str(subject), "type": "access"}
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: Union[str, Any],
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
) -> str:
|
||||
"""创建刷新令牌"""
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"}
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def decode_token(token: str) -> Dict[str, Any]:
|
||||
"""解码令牌"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
|
||||
)
|
||||
return payload
|
||||
except JWTError:
|
||||
raise ValueError("Invalid token")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
return bcrypt.checkpw(
|
||||
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""生成密码哈希"""
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_password = bcrypt.hashpw(password.encode("utf-8"), salt)
|
||||
return hashed_password.decode("utf-8")
|
||||
81
backend/app/core/simple_auth.py
Normal file
81
backend/app/core/simple_auth.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
简化认证中间件 - 支持 API Key 和长期 Token
|
||||
用于内部服务间调用
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Header, status
|
||||
from app.models.user import User
|
||||
|
||||
# 配置 API Keys(用于内部服务调用)
|
||||
API_KEYS = {
|
||||
"internal-service-2025-kaopeilian": {
|
||||
"service": "internal",
|
||||
"user_id": 1,
|
||||
"username": "internal_service",
|
||||
"role": "admin"
|
||||
}
|
||||
}
|
||||
|
||||
# 长期有效的 Token(用于内部服务调用)
|
||||
LONG_TERM_TOKENS = {
|
||||
"permanent-token-for-internal-2025": {
|
||||
"service": "internal",
|
||||
"user_id": 1,
|
||||
"username": "internal_service",
|
||||
"role": "admin"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_current_user_by_api_key(
|
||||
x_api_key: Optional[str] = Header(None),
|
||||
authorization: Optional[str] = Header(None)
|
||||
) -> Optional[User]:
|
||||
"""
|
||||
通过 API Key 或长期 Token 获取用户
|
||||
支持两种方式:
|
||||
1. X-API-Key: internal-service-2025-kaopeilian
|
||||
2. Authorization: Bearer permanent-token-for-internal-2025
|
||||
"""
|
||||
|
||||
# 方式1:检查 API Key
|
||||
if x_api_key and x_api_key in API_KEYS:
|
||||
api_key_info = API_KEYS[x_api_key]
|
||||
# 创建一个虚拟用户对象
|
||||
user = User()
|
||||
user.id = api_key_info["user_id"]
|
||||
user.username = api_key_info["username"]
|
||||
user.role = api_key_info["role"]
|
||||
return user
|
||||
|
||||
# 方式2:检查长期 Token
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
if token in LONG_TERM_TOKENS:
|
||||
token_info = LONG_TERM_TOKENS[token]
|
||||
user = User()
|
||||
user.id = token_info["user_id"]
|
||||
user.username = token_info["username"]
|
||||
user.role = token_info["role"]
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_current_user_simple(
|
||||
x_api_key: Optional[str] = Header(None),
|
||||
authorization: Optional[str] = Header(None)
|
||||
) -> User:
|
||||
"""
|
||||
简化的用户认证依赖项
|
||||
"""
|
||||
# 尝试 API Key 或长期 Token 认证
|
||||
user = get_current_user_by_api_key(x_api_key, authorization)
|
||||
if user:
|
||||
return user
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or missing authentication credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
421
backend/app/core/tenant_config.py
Normal file
421
backend/app/core/tenant_config.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""
|
||||
租户配置加载器
|
||||
|
||||
功能:
|
||||
1. 从数据库 tenant_configs 表加载租户配置
|
||||
2. 支持 Redis 缓存
|
||||
3. 数据库不可用时回退到环境变量
|
||||
4. 支持配置热更新
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from functools import lru_cache
|
||||
|
||||
import aiomysql
|
||||
import redis.asyncio as redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================
|
||||
# 平台管理库连接配置
|
||||
#
|
||||
# 注意:敏感信息必须通过环境变量传递,禁止硬编码
|
||||
# 参考:瑞小美系统技术栈标准与字符标准.md - 敏感信息管理
|
||||
# ============================================
|
||||
ADMIN_DB_CONFIG = {
|
||||
"host": os.getenv("ADMIN_DB_HOST", "prod-mysql"),
|
||||
"port": int(os.getenv("ADMIN_DB_PORT", "3306")),
|
||||
"user": os.getenv("ADMIN_DB_USER", "root"),
|
||||
"password": os.getenv("ADMIN_DB_PASSWORD"), # 必须从环境变量获取
|
||||
"db": os.getenv("ADMIN_DB_NAME", "kaopeilian_admin"),
|
||||
"charset": "utf8mb4",
|
||||
}
|
||||
|
||||
# 校验必填环境变量
|
||||
if not ADMIN_DB_CONFIG["password"]:
|
||||
logger.warning(
|
||||
"ADMIN_DB_PASSWORD 环境变量未设置,租户配置加载功能将不可用。"
|
||||
"请在 .env.admin 文件中配置此变量。"
|
||||
)
|
||||
|
||||
# Redis 缓存配置
|
||||
CACHE_PREFIX = "tenant_config:"
|
||||
CACHE_TTL = 300 # 5分钟缓存
|
||||
|
||||
|
||||
class TenantConfigLoader:
|
||||
"""租户配置加载器"""
|
||||
|
||||
def __init__(self, tenant_code: str, redis_client: Optional[redis.Redis] = None):
|
||||
"""
|
||||
初始化租户配置加载器
|
||||
|
||||
Args:
|
||||
tenant_code: 租户编码(如 hua, yy, hl)
|
||||
redis_client: Redis 客户端(可选)
|
||||
"""
|
||||
self.tenant_code = tenant_code
|
||||
self.redis_client = redis_client
|
||||
self._config_cache: Dict[str, Any] = {}
|
||||
self._tenant_id: Optional[int] = None
|
||||
|
||||
async def get_config(self, config_group: str, config_key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
获取配置项
|
||||
|
||||
优先级:
|
||||
1. 内存缓存
|
||||
2. Redis 缓存
|
||||
3. 数据库
|
||||
4. 环境变量
|
||||
5. 默认值
|
||||
|
||||
Args:
|
||||
config_group: 配置分组(database, redis, coze, ai, yanji, security)
|
||||
config_key: 配置键
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
配置值
|
||||
"""
|
||||
cache_key = f"{config_group}.{config_key}"
|
||||
|
||||
# 1. 内存缓存
|
||||
if cache_key in self._config_cache:
|
||||
return self._config_cache[cache_key]
|
||||
|
||||
# 2. Redis 缓存
|
||||
if self.redis_client:
|
||||
try:
|
||||
redis_key = f"{CACHE_PREFIX}{self.tenant_code}:{cache_key}"
|
||||
cached_value = await self.redis_client.get(redis_key)
|
||||
if cached_value:
|
||||
value = json.loads(cached_value)
|
||||
self._config_cache[cache_key] = value
|
||||
return value
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis 缓存读取失败: {e}")
|
||||
|
||||
# 3. 数据库
|
||||
try:
|
||||
value = await self._get_from_database(config_group, config_key)
|
||||
if value is not None:
|
||||
self._config_cache[cache_key] = value
|
||||
# 写入 Redis 缓存
|
||||
if self.redis_client:
|
||||
try:
|
||||
redis_key = f"{CACHE_PREFIX}{self.tenant_code}:{cache_key}"
|
||||
await self.redis_client.setex(
|
||||
redis_key,
|
||||
CACHE_TTL,
|
||||
json.dumps(value)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis 缓存写入失败: {e}")
|
||||
return value
|
||||
except Exception as e:
|
||||
logger.warning(f"数据库配置读取失败: {e}")
|
||||
|
||||
# 4. 环境变量
|
||||
env_value = os.getenv(config_key)
|
||||
if env_value is not None:
|
||||
return env_value
|
||||
|
||||
# 5. 默认值
|
||||
return default
|
||||
|
||||
async def _get_from_database(self, config_group: str, config_key: str) -> Optional[Any]:
|
||||
"""从数据库获取配置"""
|
||||
conn = None
|
||||
try:
|
||||
conn = await aiomysql.connect(**ADMIN_DB_CONFIG)
|
||||
async with conn.cursor(aiomysql.DictCursor) as cursor:
|
||||
# 获取租户 ID
|
||||
if self._tenant_id is None:
|
||||
await cursor.execute(
|
||||
"SELECT id FROM tenants WHERE code = %s AND status = 'active'",
|
||||
(self.tenant_code,)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
self._tenant_id = row['id']
|
||||
else:
|
||||
return None
|
||||
|
||||
# 获取配置值
|
||||
await cursor.execute(
|
||||
"""
|
||||
SELECT config_value, value_type, is_encrypted
|
||||
FROM tenant_configs
|
||||
WHERE tenant_id = %s AND config_group = %s AND config_key = %s
|
||||
""",
|
||||
(self._tenant_id, config_group, config_key)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row:
|
||||
return self._parse_value(row['config_value'], row['value_type'], row['is_encrypted'])
|
||||
|
||||
# 如果租户没有配置,获取默认值
|
||||
await cursor.execute(
|
||||
"""
|
||||
SELECT default_value, value_type
|
||||
FROM config_templates
|
||||
WHERE config_group = %s AND config_key = %s
|
||||
""",
|
||||
(config_group, config_key)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row and row['default_value']:
|
||||
return self._parse_value(row['default_value'], row['value_type'], False)
|
||||
|
||||
return None
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
def _parse_value(self, value: str, value_type: str, is_encrypted: bool) -> Any:
|
||||
"""解析配置值"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# TODO: 如果是加密值,先解密
|
||||
if is_encrypted:
|
||||
# 这里可以实现解密逻辑
|
||||
pass
|
||||
|
||||
if value_type == 'int':
|
||||
return int(value)
|
||||
elif value_type == 'bool':
|
||||
return value.lower() in ('true', '1', 'yes')
|
||||
elif value_type == 'json':
|
||||
return json.loads(value)
|
||||
elif value_type == 'float':
|
||||
return float(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
async def get_all_configs(self) -> Dict[str, Any]:
|
||||
"""获取租户的所有配置"""
|
||||
configs = {}
|
||||
conn = None
|
||||
try:
|
||||
conn = await aiomysql.connect(**ADMIN_DB_CONFIG)
|
||||
async with conn.cursor(aiomysql.DictCursor) as cursor:
|
||||
# 获取租户 ID
|
||||
await cursor.execute(
|
||||
"SELECT id FROM tenants WHERE code = %s AND status = 'active'",
|
||||
(self.tenant_code,)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if not row:
|
||||
return configs
|
||||
|
||||
tenant_id = row['id']
|
||||
|
||||
# 获取所有配置
|
||||
await cursor.execute(
|
||||
"""
|
||||
SELECT config_group, config_key, config_value, value_type, is_encrypted
|
||||
FROM tenant_configs
|
||||
WHERE tenant_id = %s
|
||||
""",
|
||||
(tenant_id,)
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
for row in rows:
|
||||
key = f"{row['config_group']}.{row['config_key']}"
|
||||
configs[key] = self._parse_value(
|
||||
row['config_value'],
|
||||
row['value_type'],
|
||||
row['is_encrypted']
|
||||
)
|
||||
|
||||
return configs
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
async def refresh_cache(self):
|
||||
"""刷新缓存"""
|
||||
self._config_cache.clear()
|
||||
|
||||
if self.redis_client:
|
||||
try:
|
||||
# 删除该租户的所有缓存
|
||||
pattern = f"{CACHE_PREFIX}{self.tenant_code}:*"
|
||||
cursor = 0
|
||||
while True:
|
||||
cursor, keys = await self.redis_client.scan(cursor, match=pattern, count=100)
|
||||
if keys:
|
||||
await self.redis_client.delete(*keys)
|
||||
if cursor == 0:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis 缓存刷新失败: {e}")
|
||||
|
||||
async def is_feature_enabled(self, feature_code: str) -> bool:
|
||||
"""
|
||||
检查功能是否启用
|
||||
|
||||
Args:
|
||||
feature_code: 功能编码
|
||||
|
||||
Returns:
|
||||
是否启用
|
||||
"""
|
||||
conn = None
|
||||
try:
|
||||
conn = await aiomysql.connect(**ADMIN_DB_CONFIG)
|
||||
async with conn.cursor(aiomysql.DictCursor) as cursor:
|
||||
# 获取租户 ID
|
||||
if self._tenant_id is None:
|
||||
await cursor.execute(
|
||||
"SELECT id FROM tenants WHERE code = %s AND status = 'active'",
|
||||
(self.tenant_code,)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
self._tenant_id = row['id']
|
||||
|
||||
# 先查租户级别的配置
|
||||
if self._tenant_id:
|
||||
await cursor.execute(
|
||||
"""
|
||||
SELECT is_enabled FROM feature_switches
|
||||
WHERE tenant_id = %s AND feature_code = %s
|
||||
""",
|
||||
(self._tenant_id, feature_code)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return bool(row['is_enabled'])
|
||||
|
||||
# 再查全局默认配置
|
||||
await cursor.execute(
|
||||
"""
|
||||
SELECT is_enabled FROM feature_switches
|
||||
WHERE tenant_id IS NULL AND feature_code = %s
|
||||
""",
|
||||
(feature_code,)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return bool(row['is_enabled'])
|
||||
|
||||
return True # 默认启用
|
||||
except Exception as e:
|
||||
logger.warning(f"功能开关查询失败: {e}, 默认启用")
|
||||
return True
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
class TenantConfigManager:
|
||||
"""租户配置管理器(单例)"""
|
||||
|
||||
_instances: Dict[str, TenantConfigLoader] = {}
|
||||
_redis_client: Optional[redis.Redis] = None
|
||||
|
||||
@classmethod
|
||||
async def init_redis(cls, redis_url: str):
|
||||
"""初始化 Redis 连接"""
|
||||
try:
|
||||
cls._redis_client = redis.from_url(redis_url)
|
||||
await cls._redis_client.ping()
|
||||
logger.info("TenantConfigManager Redis 连接成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"TenantConfigManager Redis 连接失败: {e}")
|
||||
cls._redis_client = None
|
||||
|
||||
@classmethod
|
||||
def get_loader(cls, tenant_code: str) -> TenantConfigLoader:
|
||||
"""获取租户配置加载器"""
|
||||
if tenant_code not in cls._instances:
|
||||
cls._instances[tenant_code] = TenantConfigLoader(
|
||||
tenant_code,
|
||||
cls._redis_client
|
||||
)
|
||||
return cls._instances[tenant_code]
|
||||
|
||||
@classmethod
|
||||
async def refresh_tenant_cache(cls, tenant_code: str):
|
||||
"""刷新指定租户的缓存"""
|
||||
if tenant_code in cls._instances:
|
||||
await cls._instances[tenant_code].refresh_cache()
|
||||
|
||||
@classmethod
|
||||
async def refresh_all_cache(cls):
|
||||
"""刷新所有租户的缓存"""
|
||||
for loader in cls._instances.values():
|
||||
await loader.refresh_cache()
|
||||
|
||||
|
||||
# ============================================
|
||||
# 辅助函数
|
||||
# ============================================
|
||||
|
||||
def get_tenant_code_from_domain(domain: str) -> str:
|
||||
"""
|
||||
从域名提取租户编码
|
||||
|
||||
Examples:
|
||||
hua.ireborn.com.cn -> hua
|
||||
yy.ireborn.com.cn -> yy
|
||||
aiedu.ireborn.com.cn -> demo
|
||||
"""
|
||||
if not domain:
|
||||
return "demo"
|
||||
|
||||
# 移除 https:// 或 http://
|
||||
domain = domain.replace("https://", "").replace("http://", "")
|
||||
|
||||
# 获取子域名
|
||||
parts = domain.split(".")
|
||||
if len(parts) >= 3:
|
||||
subdomain = parts[0]
|
||||
# 特殊处理
|
||||
if subdomain == "aiedu":
|
||||
return "demo"
|
||||
return subdomain
|
||||
|
||||
return "demo"
|
||||
|
||||
|
||||
async def get_tenant_config(tenant_code: str, config_group: str, config_key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
快捷函数:获取租户配置
|
||||
|
||||
Args:
|
||||
tenant_code: 租户编码
|
||||
config_group: 配置分组
|
||||
config_key: 配置键
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
配置值
|
||||
"""
|
||||
loader = TenantConfigManager.get_loader(tenant_code)
|
||||
return await loader.get_config(config_group, config_key, default)
|
||||
|
||||
|
||||
async def is_tenant_feature_enabled(tenant_code: str, feature_code: str) -> bool:
|
||||
"""
|
||||
快捷函数:检查租户功能是否启用
|
||||
|
||||
Args:
|
||||
tenant_code: 租户编码
|
||||
feature_code: 功能编码
|
||||
|
||||
Returns:
|
||||
是否启用
|
||||
"""
|
||||
loader = TenantConfigManager.get_loader(tenant_code)
|
||||
return await loader.is_feature_enabled(feature_code)
|
||||
|
||||
Reference in New Issue
Block a user