fastapi-sqlalchemy2.0-async

1.alembic迁移工具

  • 配置异步数据库连接：修改 alembic 的 env.py 文件
import asyncio
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool

from alembic import context
from sqlalchemy.ext.asyncio import async_engine_from_config, AsyncEngine

from database.configuration import Base
from models import address_mapping
from models import users


from core.setting import *
# Alembic Config object: exposes the values of the .ini file in use.
config = context.config

# Set up Python logging from the .ini file, if Alembic was given one.
ini_file = config.config_file_name
if ini_file is not None:
    fileConfig(ini_file)

# MetaData compared against the database by `alembic revision --autogenerate`.
# The model modules imported above presumably register their tables on Base.
target_metadata = Base.metadata

# Further env.py-specific options can be read from the config when needed, e.g.
# config.get_main_option("my_important_option").


def run_migrations_offline():
    """Emit migrations as a SQL script instead of executing them ('offline' mode).

    Alembic is handed only a database URL, so no Engine is created and no
    DBAPI connection is needed.  Calls to ``context.execute()`` are written to
    the script output; bound parameters are rendered inline (literal_binds).
    """
    offline_url = (
        f'mysql+asyncmy://{MYSQL_USER}:{MYSQL_PASSWORD}'
        f'@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DB}?charset=utf8'
    )
    configure_kwargs = dict(
        url=offline_url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    context.configure(**configure_kwargs)

    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection):
    """Run the migration scripts on an already-open synchronous connection.

    Used as the callback of ``AsyncConnection.run_sync`` in
    ``run_migrations_online`` so synchronous Alembic can drive the async driver.
    """
    context.configure(target_metadata=target_metadata, connection=connection)
    with context.begin_transaction():
        context.run_migrations()


async def run_migrations_online():
    """Run migrations in 'online' mode against a live database.

    Uses the connectable stored in ``config.attributes["connection"]`` when one
    was provided; otherwise builds an async engine (NullPool) from the ini
    section's ``sqlalchemy.*`` options, with ``sqlalchemy.url`` overridden by
    the settings imported from ``core.setting``.  Alembic itself is
    synchronous, so the migrations run through ``AsyncConnection.run_sync``.
    The connectable is disposed afterwards, even when a migration fails.
    """
    # varies between live and test migrations
    DATABASE_URL = f'mysql+asyncmy://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DB}?charset=utf8'

    connectable = context.config.attributes.get("connection", None)
    # Alembic's Config is a ConfigParser with %-interpolation: a literal '%'
    # (e.g. inside the password) must be stored as '%%', otherwise
    # set_main_option raises ValueError.  get_section() un-escapes it again.
    config.set_main_option("sqlalchemy.url", DATABASE_URL.replace("%", "%%"))
    if connectable is None:
        # SQLAlchemy 2.0 helper instead of wrapping a sync engine in
        # AsyncEngine(engine_from_config(...)).
        connectable = async_engine_from_config(
            context.config.get_section(context.config.config_ini_section, {}),
            prefix="sqlalchemy.",
            poolclass=pool.NullPool,
        )

    try:
        async with connectable.connect() as connection:
            await connection.run_sync(do_run_migrations)
    finally:
        await connectable.dispose()


# Alembic decides the mode (`--sql` means offline); online mode is async.
if not context.is_offline_mode():
    asyncio.run(run_migrations_online())
else:
    run_migrations_offline()

  • 生成并执行迁移
alembic revision --autogenerate -m "alembic mapping table init"

alembic upgrade head

2.数据库配置

DATABASE_ASYNC_URL = f'mysql+asyncmy://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}/{MYSQL_DB}'
# echo=True logs every emitted SQL statement; useful while developing.
async_engine = create_async_engine(DATABASE_ASYNC_URL, echo=True)
# Print the engine's URL object instead of the raw string: SQLAlchemy 2.0
# renders str(URL) with the password masked, so credentials don't leak to stdout.
print(async_engine.url)


class Base(AsyncAttrs, DeclarativeBase):
    """Declarative base class shared by all ORM models in this project.

    The AsyncAttrs mixin adds awaitable attribute access for use with async
    sessions.
    """


# Factory producing AsyncSession objects bound to async_engine.
# expire_on_commit=False keeps already-loaded attribute values on ORM objects
# after commit, so they can still be read once the transaction has ended.
async_session = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


async def get_session() -> AsyncSession:
    """Dependency that yields one AsyncSession and closes it afterwards.

    Used with ``Depends(get_session)``.  Nothing is committed here; callers
    must commit their own changes.  NOTE(review): this is an async generator,
    so ``AsyncIterator[AsyncSession]`` would be the accurate return annotation.
    """
    session_factory = async_session
    async with session_factory() as session:
        yield session

3.增删改查语句

# Usage pattern: declare this as a parameter of a FastAPI path function so each
# request receives its own session from get_session, e.g.
#     async def handler(..., db: AsyncSession = Depends(get_session)): ...
db: AsyncSession = Depends(get_session)

async def add(user_id: int, report_name: str, db: AsyncSession):
    """Insert a DashVerify row for ``user_id`` / ``report_name`` and commit.

    Returns the new (committed) object, or ``None`` when the insert violates a
    database constraint.  In that case the session is rolled back so it stays
    usable for the rest of the request.
    """
    new_mapping = DashVerify(user_id=user_id, report_name=report_name)
    db.add(new_mapping)

    try:
        await db.commit()
    except IntegrityError:
        # Without a rollback the session stays in a failed state and every
        # later operation on it raises PendingRollbackError.
        await db.rollback()
        return None
    return new_mapping


async def read(user_id: int, db: AsyncSession):
    """Return all DashVerify rows belonging to ``user_id``, ordered by id, as a list."""
    stmt = (
        select(DashVerify)
        .where(DashVerify.user_id == user_id)
        .order_by(DashVerify.id)
    )
    result = await db.execute(stmt)
    return list(result.scalars().all())


async def update(verify_id: int, report_name: str, db: AsyncSession):
    """Set ``report_name`` on the DashVerify row with id ``verify_id`` and commit.

    Returns ``{"message": "修改成功"}`` on success, or ``None`` when the UPDATE
    violates a database constraint (the session is rolled back in that case).
    """
    # At module level the name ``update`` is this coroutine function itself,
    # not sqlalchemy.update; calling ``update(DashVerify)`` here would raise
    # TypeError.  Import the SQL construct under an alias instead.
    from sqlalchemy import update as sa_update

    stmt = (
        sa_update(DashVerify)
        .where(DashVerify.id == verify_id)
        .values(report_name=report_name)
    )
    try:
        await db.execute(stmt)
        # get_session never commits, so without this the change would be
        # discarded when the session closes.
        await db.commit()
    except IntegrityError:
        await db.rollback()
        return None
    return {"message": "修改成功"}


async def delete(verify_id: int, db: AsyncSession):
    """Delete the DashVerify row with id ``verify_id`` and commit.

    Returns "删除成功" after deleting, or "未查找到" when no row has that id.
    """
    # At module level the name ``delete`` is this coroutine function itself,
    # not sqlalchemy.delete; import the SQL construct under an alias.
    from sqlalchemy import delete as sa_delete

    result = await db.execute(select(DashVerify).where(DashVerify.id == verify_id))
    # one_or_none(): .one() raises NoResultFound, which made the "not found"
    # branch unreachable.
    if result.scalars().one_or_none() is None:
        return "未查找到"

    await db.execute(sa_delete(DashVerify).where(DashVerify.id == verify_id))
    # get_session never commits, so commit here or the delete is discarded.
    await db.commit()
    return "删除成功"

4.树形结构: 数据库配置与增删改查

#model.py
class MappingTree(Base):
    """Adjacency-list tree node: each row references its parent via parent_id.

    ``children`` is the one-to-many collection of direct child nodes and
    ``parent`` is the many-to-one link back to the parent node.
    """
    __tablename__ = 'mapping_tree'

    id = Column(Integer, primary_key=True, index=True)
    # Presumably NULL for root nodes (the column is nullable by default).
    parent_id = Column(Integer, ForeignKey('mapping_tree.id'))
    report_name = Column(String(200), index=True)
    url = Column(String(255), index=True, nullable=True)
    # remote_side=[id] marks the many-to-one (child -> parent) direction, so it
    # belongs on `parent`.  Placing it on `children` turns `children` into the
    # parent reference and the "parent" backref into the child collection.
    children = relationship(
        "MappingTree", back_populates="parent", lazy="selectin"
    )
    parent = relationship(
        "MappingTree", back_populates="children", remote_side=[id]
    )



#service.py
async def create_node(session, parent_id, report_name: str, url: str):
    """Insert a tree node under ``parent_id``, commit, and return the new id."""
    node = MappingTree(report_name=report_name, url=url, parent_id=parent_id)
    session.add(node)
    await session.commit()
    new_id = node.id
    return new_id


async def get_tree_name(session, report_name: str):
    """Return the single node whose report_name equals ``report_name``.

    ``.one()`` raises NoResultFound when nothing matches and
    MultipleResultsFound when several nodes share the name.
    """
    stmt = select(MappingTree).where(MappingTree.report_name == report_name)
    return (await session.execute(stmt)).scalars().one()


async def get_tree_parent(session, parent_id):
    """Return every node whose parent_id equals ``parent_id`` (its direct children).

    ``MappingTree.children`` is eager-loaded with a JOIN; ``.unique()``
    collapses the duplicate entity rows such a joined load can produce.
    """
    stmt = (
        select(MappingTree)
        .where(MappingTree.parent_id == parent_id)
        .options(joinedload(MappingTree.children))
    )
    rows = await session.execute(stmt)
    return rows.unique().scalars().all()


async def find_node_by_name(session: AsyncSession, report_name: str):
    """Look up a node by report_name.

    Returns ``(id, parent_id)`` of the first matching node, or ``None`` when no
    node has that name.
    """
    result = await session.execute(
        select(MappingTree).where(MappingTree.report_name == report_name)
    )
    found = result.scalar()  # first matching node, or None
    if not found:
        return None
    return found.id, found.parent_id


async def update_node(session, report_name, new_report_name, new_url):
    """Rename a node (looked up by current report_name) and replace its url.

    Returns True after committing, False when no node matched.
    ``one_or_none()`` raises MultipleResultsFound if several nodes share the name.
    """
    result = await session.execute(
        select(MappingTree).where(MappingTree.report_name == report_name)
    )
    target = result.scalars().one_or_none()
    if target is None:
        return False
    target.report_name = new_report_name
    target.url = new_url
    await session.commit()
    return True


async def delete_node_and_children(session, node_id: int):
    """Delete node ``node_id`` and all of its descendants in one transaction.

    Descendants are deleted before their parent, so no remaining row points at
    a deleted parent_id.  A single commit at the end means a failure part-way
    through leaves the tree unchanged instead of half-deleted.

    Returns True once the commit succeeds.
    """
    await _delete_subtree(session, node_id)
    await session.commit()
    return True


async def _delete_subtree(session, node_id: int):
    """Delete ``node_id``'s descendants depth-first, then the node itself (no commit)."""
    result = await session.execute(
        select(MappingTree.id).where(MappingTree.parent_id == node_id)
    )
    for child_id in result.scalars().all():
        await _delete_subtree(session, child_id)
    await session.execute(delete(MappingTree).where(MappingTree.id == node_id))


#获取树形结构以树形json的格式返回
async def build_tree_data(nodes, db, contain_id):
    """Convert ``nodes`` and their descendants into nested, JSON-ready dicts.

    Each dict has "name"; "url" is added when the node's url is non-empty,
    "id" when ``contain_id`` is truthy, and "children" when the node has child
    nodes.  Children are fetched with get_tree_parent, one query per node.
    Passing ``None`` yields an empty list.
    """
    if nodes is None:
        return []

    tree = []
    for node in nodes:
        entry = {"name": node.report_name}
        if node.url:
            entry["url"] = node.url
        if contain_id:
            entry["id"] = node.id
        kids = await get_tree_parent(db, node.id)
        if kids:
            entry["children"] = await build_tree_data(kids, db, contain_id)
        tree.append(entry)
    return tree
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值