85 lines
2.0 KiB
Python
85 lines
2.0 KiB
Python
"""数据库连接模块

使用 SQLAlchemy 异步引擎
遵循瑞小美系统技术栈标准:MySQL 8.0, utf8mb4, utf8mb4_unicode_ci
"""
|
||
|
||
from typing import AsyncGenerator
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||
from sqlalchemy.orm import DeclarativeBase
|
||
from sqlalchemy import MetaData
|
||
|
||
from app.config import settings
|
||
|
||
|
||
# Constraint/index naming convention, which makes database migrations easier.
convention = dict(
    ix="ix_%(column_0_label)s",
    uq="uq_%(table_name)s_%(column_0_name)s",
    ck="ck_%(table_name)s_%(constraint_name)s",
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    pk="pk_%(table_name)s",
)

# Shared MetaData object carrying the naming convention above.
metadata = MetaData(naming_convention=convention)
|
||
|
||
|
||
class Base(DeclarativeBase):
    """Declarative base class for the SQLAlchemy models.

    Binds the class to the module-level ``metadata`` object, which carries
    the naming convention.
    """

    # The right-hand side resolves to the module-level ``metadata``.
    metadata = metadata
|
||
|
||
|
||
# Create the async engine.
# SQLite does not support the connection-pool parameters, so they are handled
# separately from the options common to every backend.
_engine_kwargs = {"echo": settings.DEBUG}

# Add the connection-pool parameters only when not running against SQLite.
if not settings.DATABASE_URL.startswith("sqlite"):
    _engine_kwargs["pool_size"] = settings.DB_POOL_SIZE
    _engine_kwargs["max_overflow"] = settings.DB_MAX_OVERFLOW
    _engine_kwargs["pool_recycle"] = settings.DB_POOL_RECYCLE
    _engine_kwargs["pool_pre_ping"] = True

engine = create_async_engine(settings.DATABASE_URL, **_engine_kwargs)
|
||
|
||
# Create the async session factory, bound to the module-level engine.
async_session_maker = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
|
||
|
||
|
||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session for use as a dependency.

    The session is committed once the consumer hands control back without an
    error. If an exception is raised while the session is in use, or by the
    commit itself, the session is rolled back and the exception is re-raised.

    Yields:
        AsyncSession: a session created by ``async_session_maker``.

    Raises:
        Exception: whatever the consumer or the commit raised, re-raised
            after the rollback.
    """
    # ``async with`` closes the session on exit, so no explicit close() call
    # is needed in a ``finally`` clause.
    async with async_session_maker() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|
||
|
||
|
||
async def init_db():
    """Initialise the database tables.

    Opens a transaction on the module-level engine with ``engine.begin()``
    and runs ``Base.metadata.create_all`` through ``run_sync``.
    """
    async with engine.begin() as connection:
        create_tables = Base.metadata.create_all
        await connection.run_sync(create_tables)
|
||
|
||
|
||
async def close_db():
    """Close the database connections by awaiting ``engine.dispose()``."""
    await engine.dispose()
|