- 从服务器拉取完整代码 - 按框架规范整理项目结构 - 配置 Drone CI 测试环境部署 - 包含后端(FastAPI)、前端(Vue3)、管理端 技术栈: Vue3 + TypeScript + FastAPI + MySQL
167 lines
4.8 KiB
Python
167 lines
4.8 KiB
Python
"""依赖注入模块"""
|
||
from typing import AsyncGenerator, Optional
|
||
from sqlalchemy import select
|
||
import redis.asyncio as redis
|
||
|
||
from fastapi import Depends, HTTPException, status
|
||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.database import AsyncSessionLocal
|
||
from app.core.config import get_settings
|
||
from app.models.user import User
|
||
|
||
# Shared JWT bearer-token scheme used by the authenticated dependencies below.
# ``auto_error=True`` is FastAPI's default, written out here for clarity: a
# request without a usable ``Authorization: Bearer ...`` header is rejected by
# FastAPI itself before any dependency that uses this scheme runs.
security = HTTPBearer(auto_error=True)
|
||
|
||
|
||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session for the duration of one consumer.

    After the code using the session finishes without raising, the session is
    committed. If an ``Exception`` is thrown back into this generator (or the
    commit itself fails), the session is rolled back and the error re-raised.

    The ``async with`` block closes the session on exit, so no separate
    ``close()`` call is needed here.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|
||
|
||
|
||
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the authenticated user from a JWT bearer token.

    - Decodes the ``Authorization: Bearer <token>`` value with ``decode_token``
      and converts its ``sub`` claim to an integer user id.
    - Loads the matching ``User`` row that is not soft-deleted.

    Returns:
        The ``User`` whose id is in the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when credentials are missing, the scheme is not
            Bearer, the token cannot be decoded or has no integer ``sub``, or
            the user does not exist, is soft-deleted, or is inactive. Every
            401 carries ``WWW-Authenticate: Bearer`` as RFC 6750 requires.
    """
    # Imported lazily to avoid a circular import with app.core.security.
    from app.core.security import decode_token

    auth_headers = {"WWW-Authenticate": "Bearer"}

    if not credentials or not credentials.scheme or not credentials.credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未提供认证信息",
            headers=auth_headers,
        )

    if credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="认证方式不支持",
            headers=auth_headers,
        )

    token = credentials.credentials
    try:
        payload = decode_token(token)
        user_id = int(payload.get("sub"))
    except Exception:
        # Any failure decoding the token or turning ``sub`` into an int is
        # reported uniformly as an invalid token; ``from None`` keeps the
        # internal error out of the exception chain.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的令牌",
            headers=auth_headers,
        ) from None

    result = await db.execute(
        # SQL comparison, not Python truthiness.
        select(User).where(User.id == user_id, User.is_deleted == False)  # noqa: E712
    )
    user = result.scalar_one_or_none()
    if not user or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在或已被禁用",
            headers=auth_headers,
        )

    return user
|
||
|
||
|
||
async def require_admin(current_user: User = Depends(get_current_user)) -> User:
    """Allow only users whose ``role`` attribute equals ``"admin"``.

    Returns:
        The authenticated user, unchanged.

    Raises:
        HTTPException: 403 when the user has no ``role`` attribute or the role
            is anything other than ``"admin"``.
    """
    user_role = getattr(current_user, "role", None)
    if user_role != "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="需要管理员权限",
        )
    return current_user
|
||
|
||
|
||
async def require_admin_or_manager(current_user: User = Depends(get_current_user)) -> User:
    """Allow only users whose ``role`` is ``"admin"`` or ``"manager"``.

    Returns:
        The authenticated user, unchanged.

    Raises:
        HTTPException: 403 when the user has no ``role`` attribute or the role
            is neither ``"admin"`` nor ``"manager"``.
    """
    # A tuple (equality-based membership) rather than a set, so roles whose
    # hash differs from their string value still compare by ``==``.
    permitted_roles = ("admin", "manager")
    user_role = getattr(current_user, "role", None)
    if user_role not in permitted_roles:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="需要管理者或管理员权限",
        )
    return current_user
|
||
|
||
# Bearer scheme that does NOT reject header-less requests. The shared
# ``security`` scheme is ``HTTPBearer()`` with FastAPI's default
# ``auto_error=True``, which makes FastAPI refuse any request lacking an
# Authorization header before this dependency runs, so anonymous access
# would never reach the ``return None`` branch.
_optional_security = HTTPBearer(auto_error=False)


async def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(_optional_security),
    db: AsyncSession = Depends(get_db),
) -> Optional[User]:
    """Return the current user if a valid bearer token is present, else ``None``.

    Login is not required: a missing header, or any authentication failure
    reported by ``get_current_user`` (as ``HTTPException``), yields ``None``.
    Other errors (e.g. database failures, task cancellation) propagate
    instead of being silently treated as "anonymous".
    """
    if not credentials:
        return None

    try:
        return await get_current_user(credentials, db)
    except HTTPException:
        return None
|
||
|
||
|
||
async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Return the authenticated, active user.

    No extra "is the account disabled?" check is needed here: when resolved
    through ``get_current_user``, users that are soft-deleted or whose
    ``is_active`` is falsy have already been rejected with a 401.
    NOTE(review): if ``get_current_user`` is ever overridden (e.g. in tests
    via ``dependency_overrides``), that guarantee no longer holds.
    """
    return current_user
|
||
|
||
|
||
async def verify_scrm_api_key(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> bool:
    """Authenticate an internal SCRM-system call by its fixed API key.

    The SCRM system calls the data-query endpoints with
    ``Authorization: Bearer {SCRM_API_KEY}``.

    Returns:
        ``True`` when the presented key matches ``settings.SCRM_API_KEY``.

    Raises:
        HTTPException: 401 when credentials are missing, the scheme is not
            Bearer, the configured key is unset/empty, or the key does not
            match. Each 401 carries ``WWW-Authenticate: Bearer``.
    """
    import secrets  # local import keeps this dependency self-contained

    settings = get_settings()
    auth_headers = {"WWW-Authenticate": "Bearer"}

    if not credentials or not credentials.credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未提供认证信息",
            headers=auth_headers,
        )

    if credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="认证方式不支持,需要 Bearer Token",
            headers=auth_headers,
        )

    expected_key = settings.SCRM_API_KEY
    # An unset, empty or non-string key must never authenticate anyone (the
    # old ``!=`` check also rejected these). Comparing the UTF-8 bytes with
    # ``compare_digest`` keeps the check constant-time, so response timing
    # does not reveal how much of the key matched.
    key_is_valid = (
        isinstance(expected_key, str)
        and bool(expected_key)
        and secrets.compare_digest(
            credentials.credentials.encode("utf-8"),
            expected_key.encode("utf-8"),
        )
    )
    if not key_is_valid:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的 API Key",
            headers=auth_headers,
        )

    return True
|
||
|
||
|
||
# Module-level Redis connection pool, created lazily on the first
# ``get_redis`` call from ``settings.REDIS_URL`` and reused afterwards.
_redis_pool: Optional[redis.ConnectionPool] = None


async def get_redis() -> AsyncGenerator[redis.Redis, None]:
    """Yield a Redis client backed by the shared connection pool.

    Responses are decoded as UTF-8 strings (``decode_responses=True``). The
    client is closed when the consumer finishes.
    """
    global _redis_pool

    if _redis_pool is None:
        settings = get_settings()
        _redis_pool = redis.ConnectionPool.from_url(
            settings.REDIS_URL, encoding="utf-8", decode_responses=True
        )

    client = redis.Redis(connection_pool=_redis_pool)
    try:
        yield client
    finally:
        # redis-py >= 5.0.1 deprecates ``close()`` in favour of ``aclose()``;
        # use the new name when available and fall back for older versions.
        # NOTE(review): since the pool is passed in explicitly, redis-py should
        # not disconnect it here (it only auto-closes pools it created) —
        # confirm against the installed redis-py version.
        closer = getattr(client, "aclose", None) or client.close
        await closer()
|