异步sqlalchemy ORM session使用总结

声明Base和提供异步session

声明Base

Base = declarative_base()

模型类需要集成该Base, 建议所有模型类都统一集成同一个Base, 这样在对模型类的创建和修改统一管理。

sqlalchemy 使用异步ORM, 需要使用到异步的session:

提供异步session

通过装饰器提供异步session, 这样就不需要在操作数据库的方法中每次实例化一个异步session, 需要的地方装饰一下就行了。

database.py:

import contextlib
from typing import Callable
from asyncio import current_task
from functools import wraps

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_scoped_session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

# 将连接数据库的URI 写在配置文件中读取
from dqc import SQLALCHEMY_URI

# 所有模型都要统一继承该Base
Base = declarative_base()


class DatabaseManager:
    """
    连接元数据库的类,在__init__中进行初始化
    """
    def __init__(self):
        self._engine = create_async_engine(
        SQLALCHEMY_URI, echo=False,
    )
        self.session = None
        self.initialize()

    def initialize(self, scope_function: Callable = None):
        """
        Configure class for creating scoped sessions.
        """
        # Create the session factory classes
        async_session_factory  = sessionmaker(
            self._engine, expire_on_commit=False, class_=AsyncSession
        )
        self.session = async_scoped_session(async_session_factory, scopefunc=current_task)

    def cleanup(self):
        """
        Cleans up the database connection pool.
        """
        if self._engine is not None:
            self._engine.dispose()


@contextlib.asynccontextmanager
async def create_session():
    """
    Contextmanager that will create and teardown a session.
    """
    db = DatabaseManager()
    session = db.session
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()

def provide_session(func):
    """
    Function decorator that provides a session if it isn't provided.
    If you want to reuse a session or run the function as part of a
    database transaction, you pass it to the function, if not this wrapper
    will create one and close it for you.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        arg_session = 'session'

        func_params = func.__code__.co_varnames
        session_in_args = arg_session in func_params and \
            func_params.index(arg_session) < len(args)
        session_in_kwargs = arg_session in kwargs

        if session_in_kwargs or session_in_args:
            return await func(*args, **kwargs)
        else:
            async with create_session() as session:
                kwargs[arg_session] = session
                return await func(*args, **kwargs)

    return wrapper

使用异步session案例table:

一个规则模板对应多个规则,三张表:规则模板表,规则表,关联关系表

规则模板表

class RuleTemplate(Base):
    __tablename__ = 'd_rule_template'

    id = Column(Integer, primary_key=True, comment='规则模板id')
    template_name = Column(String(250), nullable=False, comment='规则模板名称')
    rules = relationship('RuleInfo', lazy='joined', secondary=template_rules, backref='template')
    deletion = Column(Integer, nullable=False, comment='是否删除')
    
    @classmethod
    @provide_session
    async def add_template(cls, session=None, **kwargs):
        pass
    
    @classmethod
    @provide_session
    async def delete_template(cls, id, session=None):
        pass
    
    @classmethod
    @provide_session
    async def get_rule_template_by_id(cls, id, session=None):
        pass
    
    @classmethod
    @provide_session
    async def update_template(cls, session=None, **kwargs):
        pass

规则表

from database import Base, provide_session


class RuleInfo(Base):
    __tablename__ = 'd_rule_info'
    id = Column(Integer, primary_key=True)
    rule_name = Column(String(250), comment='规则名称')
    rule_type = Column(String(100), nullable=True, comment="规则类型")
    deletion = Column(Integer, nullable=False, comment='是否删除')
   	
    @classmethod
    @provide_session
    async def add_rule(cls, session=None, **kwargs):
        pass
    
    @classmethod
    @provide_session
    async def delete_rule(cls, session=None, **kwargs):
        pass
    
    @classmethod
    @provide_session
    async def get_rule_by_id(cls, id, session=None):
        pass
    
    @classmethod
    @provide_session
    async def update_rule_info(cls, session=None, **kwargs):
        pass

关联关系表

template_rules = Table(
    'd_template_rule_relation',
    Base.metadata,
    Column('id', Integer, primary_key=True),
    Column('rule_template_id', Integer, ForeignKey('d_rule_template.id')),
    Column('rule_info_id', Integer, ForeignKey('d_rule_info.id')),
    UniqueConstraint('rule_template_id', 'rule_info_id')
)

增加

	@classmethod
    @provide_session
    async def add_template(cls, session=None, **kwargs):
        """
        add rule template
        :param template_name: 模板名称
        :return: rule_template.id
        """
        async with session() as session:
            rule_template = RuleTemplate()
            session.add(rule_template)
            for k, v in kwargs.items():
                if v is None:
                    continue
                setattr(rule_template, k, v)
            await session.commit()
        return rule_template.id
    
    
    @classmethod
    @provide_session
    async def add_rule(cls, session=None, **kwargs):
        """
        add rule
        :param rule_name: 规则名称
        :return: rule.id
        """
        async with session() as session:
            rule = RuleInfo()
            session.add(rule)
            for k, v in kwargs.items():
                if v is None:
                    continue
                setattr(rule, k, v)
            await session.commit()
        return rule.id

删除

    @classmethod
    @provide_session
    async def delete_template(cls, id, session=None):
        """
        软删除模板
        :param id: 模板id
        """
        async with session() as session:
            await session.execute(
                update(RuleTemplate).where(RuleTemplate.id == id).values(deletion=1)
            )
            # 直接删除数据
            # await session.execute(
            #    delete(RuleTemplate).where(RuleTemplate.id == id)
            # )
            await session.commit()
            
            
    @classmethod
    @provide_session
    async def delete_rule(cls, session=None, **kwargs):
        """
        删除规则
        :param id: 规则id
        :return:
        """
        id = kwargs.get("id")

        def fetch_and_update_objects(session2):
            # 这里使用闭包函数来写同步方法
            result = session2.execute(select(RuleInfo).where(RuleInfo.id == id))
            for rule in result.scalars():
                rule.deletion = 1
                try:
                    # 如果该规则引用了规则模板,删除规则的同时需要删除关联表中的数据
                    rule.template[0].rules.remove(rule)
                except IndexError:
                    pass

        async with session() as session2:
            # 删除后将模板与规则桥表删除对应数据
            # 通过rule.template 来获取规则关联的模板对象为同步代码,在异步session 中执行同步代码
            # 需要使用session.run_sync(fetch_and_update_objects) 方法
            # fetch_and_update_objects 为同步代码方法名
            await session2.run_sync(fetch_and_update_objects)
            await session2.commit()

查询

	@classmethod
	@provide_session
	async def get_rule_template_by_id(cls, id, session=None):
    	"""根据模板id查询模板对象"""
    	results = await session.execute(select(RuleTemplate).where(RuleTemplate.id == id))
    	data = results.scalars().first()
    	return data

    @classmethod
    @provide_session
    async def get_rule_by_id(cls, id, session=None):
        """根据规则id查询规则对象"""
        rule_info = None
        def get_template(session2):
            nonlocal rule_info
            results = session2.execute(select(RuleInfo).where(RuleInfo.id == id))
            rows = results.fetchall()
            for row in rows:
                # 由于rule_template 表rules = relationship('RuleInfo', lazy='joined', secondary=template_rules, backref='template')
                # 为了提高查询效率这里关联查询的关系为lazy='joined', 
                # 会导致查询规则对象时不会主动将该规则绑定的模板对象加载出来,
                # 需要通过使用同步代码rule.template 主动加载模板对象,
                # 否在在session 结束后,获取的规则对象将没有模板对象信息
                rule = row.RuleInfo
                template = rule.template
                rule_info = rule
                rule_info.template = template

        async with session() as session2:
            # 加载rule.template 属于同步代码,需要使用session2.run_sync() 方法
            await session2.run_sync(get_template)
        return rule_info

join关联查询

增加一张模型表:规则校验结果表

class CheckResult(Base):
    __tablename__ = 'd_check_result'
    id = Column(Integer, primary_key=True)
    rule_id = Column(Integer, nullable=False, comment='规则id')
    plan_execution_date = Column(Integer, nullable=False, comment='计划执行时间,时间戳格式')
    real_execution_date = Column(Integer, nullable=True, comment='实际执行时间,时间戳格式')
    time_duration = Column(Integer, nullable=False, comment='执行时长')
    check_result = Column(Integer, nullable=True, comment="校验结果,1 通过 2异常 3等待结果")
@classmethod
@provide_session
async def get_rule_result_list(cls, session=None, **kwargs):
    	"""	
        获取结果列表信息
        :param page_num: 页码
        :param page_size: 页面大下
        :param query: 模糊查询条件
        :param plan_execution_date: 过滤条件:计划执行时间
        :param check_result: 过滤条件:校验结果类型
        :return:
        """
        # 获取模糊查询条件
        query = kwargs.get("query")
        # 获取准确过滤条件
        plan_execution_date = kwargs.get('plan_execution_date')
        check_results = kwargs.get('check_results')
        # 分页条件
        page_num = kwargs.get("page_num", 1)
        page_size = kwargs.get("page_size", 40)
        
        # and 拼接准确查询SQL ORM
        base_filter = and_(
            CheckResult.plan_execution_date <= plan_execution_date if plan_execution_date else True,
            CheckResult.check_result.in_(check_results) if state else True,
            RuleInfo.deletion != 1
        )
        # or 拼接模糊查询SQL ORM
        query_rule = or_(
                RuleInfo.rule_name.like('%{}%'.format(query)),
                RuleInfo.rule_type.like('%{}%'.format(query)),
            )
        
        async with session() as session:
        	# func.count() 查询总数
            total = await session.execute(select([func.count()]).select_from(CheckResult, RuleInfo).outerjoin_from(CheckResult, RuleInfo, RuleInfo.id == CheckResult.rule_id).filter(base_filter, query_rule))
            
            total = total.scalar()
            
            # 左连接关联查询
            base_join_select = select(
                CheckResult.rule_id.label('rule_id'),
                CheckResult.check_result.label('check_result')
            ).outerjoin_from(CheckResult, RuleInfo, RuleInfo.id == CheckResult.rule_id)\
            .outerjoin_from(RuleInfo, DataSourceType, RuleInfo.data_source_type_id == DataSourceType.id)\
            .filter(base_filter, query_rule)
            
            # 聚合分组查询不同校验结果数量
            # 子查询, 需要通过sub.c 获取父查询中的label
            sub = base_join_select.subquery()
            group_by_select = select(sub.c.check_result, func.count(sub.c.check_result))\
                .group_by(sub.c.check_result)
            
            state_count_results = await session.execute(group_by_select)
            # join关联查询需要使用fetchall()方法获取所有查询内容
            state_count_rows = state_count_results.fetchall()
            state_info = {}
            for row in state_count_rows:
                state_info[row[0]] = row[1]
                
            # join关联并分页查询
            join_pagination_select = select(CheckResult.check_result,
                                            RuleInfo.id,
                                            RuleInfo.rule_name
                                           )\
            .outerjoin_from(CheckResult, RuleInfo, RuleInfo.id == CheckResult.rule_id)
            .filter(base_filter, query_rule)\
            .limit(page_size).offset((int(page_num) - 1) * page_size )
            
            results = await session.execute(join_pagination_select)
            all_rows = results.fetchall()
            list_info = []
            for row in all_rows:
                result_info = {}
                result_info['check_result'] = row[0]
                result_info['id'] = row[1]
                result_info['rule_name'] = row[2]
                list_info.append(result_info)
		return total, list_info, state_info

修改

    @classmethod
    @provide_session
    async def update_template(cls, session=None, **kwargs):
        """编辑规则模板"""
        id = kwargs.get('id')
        template_name = kwargs.get('template_name')

        async with session() as session:
            # 查询模板对象使用的与接下来修改使用的是同一个session 中
            # 保证一致性
            rule_template = await cls.get_rule_template_by_id(id, session=session)
            for k, v in kwargs.items():
                if v is None or k == 'id':
                    continue
                setattr(rule_template, k, v)
            await session.commit()
  • 7
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一切如来心秘密

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值