Files
012-kaopeilian/backend/app/core/deps.py
111 998211c483 feat: 初始化考培练系统项目
- 从服务器拉取完整代码
- 按框架规范整理项目结构
- 配置 Drone CI 测试环境部署
- 包含后端(FastAPI)、前端(Vue3)、管理端

技术栈: Vue3 + TypeScript + FastAPI + MySQL
2026-01-24 19:33:28 +08:00

167 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""依赖注入模块"""
from typing import AsyncGenerator, Optional
from sqlalchemy import select
import redis.asyncio as redis
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import AsyncSessionLocal
from app.core.config import get_settings
from app.models.user import User
# JWT Bearer authentication scheme shared by the auth dependencies in this module.
# ``auto_error=True`` is FastAPI's default, spelled out for clarity: requests that lack a
# usable ``Authorization: Bearer ...`` header are rejected by FastAPI itself.
security = HTTPBearer(auto_error=True)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Provide a database session scoped to the consumer of this generator.

    Yields a session obtained from ``AsyncSessionLocal``.  Once the consumer
    resumes the generator normally, the session is committed.  If an exception
    is raised into the generator, or the commit itself fails, the transaction
    is rolled back and the exception is re-raised.  The session is explicitly
    closed in every case.
    """
    async with AsyncSessionLocal() as db_session:
        try:
            yield db_session
            # Only reached when the consumer finished without raising.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
        finally:
            await db_session.close()
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the authenticated user from a JWT bearer token.

    - Decodes the ``Authorization: Bearer`` token with ``decode_token`` and
      reads the user id from its ``sub`` claim.
    - Loads the matching, non-deleted ``User`` row from the database.
    - Fails with HTTP 401 on any problem.

    Args:
        credentials: Bearer credentials extracted by ``security``.
        db: Async database session.

    Returns:
        The active ``User`` the token belongs to.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` challenge) when
            credentials are missing, the scheme is not Bearer, the token cannot
            be decoded or has a missing/non-integer ``sub``, or the user does
            not exist, is soft-deleted, or is not active.
    """
    from app.core.security import decode_token  # deferred import to avoid a circular dependency

    # RFC 7235 §3.1 / RFC 6750 §3: a 401 response must carry a WWW-Authenticate challenge.
    auth_headers = {"WWW-Authenticate": "Bearer"}

    if not credentials or not credentials.scheme or not credentials.credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未提供认证信息",
            headers=auth_headers,
        )
    if credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="认证方式不支持",
            headers=auth_headers,
        )

    token = credentials.credentials
    try:
        payload = decode_token(token)
        # int(None) / int("abc") raise here too, so a missing or non-numeric
        # ``sub`` is reported as an invalid token.
        user_id = int(payload.get("sub"))
    except Exception as exc:
        # NOTE(review): the exact exception types decode_token raises live in
        # app.core.security; they are all mapped to a generic 401 here.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的令牌",
            headers=auth_headers,
        ) from exc

    result = await db.execute(
        # ``== False`` builds a SQL comparison on the column; it is not a Python truth test.
        select(User).where(User.id == user_id, User.is_deleted == False)  # noqa: E712
    )
    user = result.scalar_one_or_none()
    if not user or not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在或已被禁用",
            headers=auth_headers,
        )
    return user
async def require_admin(current_user: User = Depends(get_current_user)) -> User:
    """Allow the request only for users whose role is ``"admin"``.

    Returns:
        The authenticated user, unchanged.

    Raises:
        HTTPException: 403 when the user's ``role`` is not ``"admin"``; a user
            object without a ``role`` attribute is treated as non-admin.
    """
    user_role = getattr(current_user, "role", None)
    if user_role == "admin":
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理员权限")
async def require_admin_or_manager(current_user: User = Depends(get_current_user)) -> User:
    """Allow the request only for users whose role is ``"admin"`` or ``"manager"``.

    Returns:
        The authenticated user, unchanged.

    Raises:
        HTTPException: 403 when the user's ``role`` is neither value; a user
            object without a ``role`` attribute is rejected.
    """
    permitted_roles = ("admin", "manager")
    user_role = getattr(current_user, "role", None)
    if user_role in permitted_roles:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="需要管理者或管理员权限")
# Bearer scheme that does NOT reject the request on its own: with ``auto_error=False``
# FastAPI's HTTPBearer yields ``None`` when the Authorization header is absent or is
# not a Bearer token. The shared ``security`` scheme uses ``auto_error=True``, which
# would reject anonymous requests before get_optional_user could return ``None``.
_optional_security = HTTPBearer(auto_error=False)


async def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(_optional_security),
    db: AsyncSession = Depends(get_db),
) -> Optional[User]:
    """Return the current user when a valid bearer token is sent, otherwise ``None``.

    Login is not required: requests without credentials, and requests whose
    credentials fail authentication in ``get_current_user`` (any
    ``HTTPException`` it raises), are treated as anonymous.  Other errors
    (e.g. database failures or task cancellation) are not swallowed and
    propagate to the caller.

    Args:
        credentials: Bearer credentials, or ``None`` if none were supplied.
        db: Async database session.

    Returns:
        The authenticated ``User``, or ``None`` for anonymous access.
    """
    if not credentials:
        return None
    try:
        return await get_current_user(credentials, db)
    except HTTPException:
        # Invalid token, unknown or disabled user: fall back to anonymous access.
        return None
async def get_current_active_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Return the authenticated user, guaranteeing the account is active.

    ``get_current_user`` already rejects inactive users, so in normal request
    handling this check does not fire.  It is kept so that the guarantee in
    this dependency's name still holds if ``get_current_user`` is overridden
    (e.g. through FastAPI dependency overrides) or its checks change.

    Returns:
        The authenticated, active ``User``.

    Raises:
        HTTPException: 401 if the user is not active.  The status and message
            match ``get_current_user``'s handling of the same condition.
    """
    if not current_user.is_active:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在或已被禁用",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return current_user
async def verify_scrm_api_key(
    credentials: HTTPAuthorizationCredentials = Depends(security),
) -> bool:
    """Authenticate an internal call from the SCRM system by its static API key.

    Used for service-to-service authentication: the SCRM system accesses the
    KaoPeiLian data-query endpoints with a fixed API key
    (``settings.SCRM_API_KEY``).

    Expected request header: ``Authorization: Bearer {SCRM_API_KEY}``

    Returns:
        ``True`` when the presented key matches the configured one.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` challenge) when
            no credentials are sent, the scheme is not Bearer, the configured
            key is missing/empty/not a string, or the presented key does not
            match.
    """
    import secrets  # local import, matching this module's deferred-import style

    settings = get_settings()
    # RFC 7235 §3.1: 401 responses must carry a WWW-Authenticate challenge.
    auth_headers = {"WWW-Authenticate": "Bearer"}

    if not credentials or not credentials.credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未提供认证信息",
            headers=auth_headers,
        )
    if credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="认证方式不支持,需要 Bearer Token",
            headers=auth_headers,
        )

    expected_key = settings.SCRM_API_KEY
    # Fail closed if no usable key is configured. Compare in constant time so
    # response timing does not reveal how much of the key matched; comparing
    # UTF-8 bytes avoids compare_digest's TypeError on non-ASCII str input.
    key_is_valid = (
        isinstance(expected_key, str)
        and bool(expected_key)
        and secrets.compare_digest(
            credentials.credentials.encode("utf-8"),
            expected_key.encode("utf-8"),
        )
    )
    if not key_is_valid:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="无效的 API Key",
            headers=auth_headers,
        )
    return True
# Module-level Redis connection pool, created lazily on first use and then
# shared by every client that get_redis() hands out.
_redis_pool: Optional[redis.ConnectionPool] = None


def _ensure_redis_pool() -> redis.ConnectionPool:
    """Create the shared Redis connection pool on first call, then return it."""
    global _redis_pool
    if _redis_pool is None:
        redis_url = get_settings().REDIS_URL
        _redis_pool = redis.ConnectionPool.from_url(
            redis_url, encoding="utf-8", decode_responses=True
        )
    return _redis_pool


async def get_redis() -> AsyncGenerator[redis.Redis, None]:
    """Yield a Redis client backed by the shared connection pool.

    Responses are decoded to ``str`` (UTF-8).  The client is closed once the
    consumer is done with it, even if an exception was raised.
    """
    redis_client = redis.Redis(connection_pool=_ensure_redis_pool())
    try:
        yield redis_client
    finally:
        await redis_client.close()