编写ORM框架

124 篇文章 0 订阅

ORM:Object-Relational Mapping,把关系数据库中的表结构映射到对象上。然后操作数据库就不需要构造SQL语句,而是直接调用相应的方法。ORM框架可以方便的完成这些转换,然后,数据库表中的一行记录就对应着python中的一个对象,就不需要使用SQL语句,可以调用方法直接操作数据库。


Python中有名的ORM框架是SQLAlchemySQLAlchemy的用法:

from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

# 创建对象的基类:
Base = declarative_base()

# 定义User对象,继承上面的基类。对应着数据库中的一张表
class User(Base): 
    __tablename__ = 'user' # 表的名字:

    # 表的结构:
    id = Column(String(20), primary_key=True)
    name = Column(String(20))

# 初始化数据库连接:
engine = create_engine('mysql+mysqlconnector://root:password@localhost:3306/test')
# 创建DBSession类型:
DBSession = sessionmaker(bind=engine)

往数据库中写入数据:
session = DBSession()  # 创建session对象(连接数据库)
new_user = User(id='5', name='Bob')  # 创建新User对象
session.add(new_user) # 添加到session
session.commit() # 提交即保存到数据库
session.close() # # 关闭session

查询数据:
session = DBSession() # 创建Session
user = session.query(User).filter(User.id=='5').one()  # 创建Query查询,filter是where条件,最后调用one()返回唯一行,如果调用all()则返回所有行
print('type:', type(user)) # 打印类型和对象的name属性
print('name:', user.name)
session.close() # 关闭Session

廖雪峰老师Python教程实战部分使用异步,另外写了一个ORM框架.可以参考SQLAlchemy框架使用来编写新的ORM框架。


orm.py代码总共两百多行,整体结构如下:

import asyncio, logging

import aiomysql

async def create_pool(loop, **kw):  # 创建全局的连接池,每个HTTP请求都能从连接池中直接获取数据库连接,不必频繁地打开、关闭数据库连接。
    ...

async def select(sql, args, size=None):  # 用select()函数来执行SELECT语句,需要传入SQL语句和SQL参数。
    ...

async def execute(sql, args, autocommit=True): # UPDATE,INSERT,DELETE不需要详细的查询结果,封装在一个execute()函数中。
    ...


class Field(object):  # 定义字段基类
    ...

class StringField(Field): #继承Field定义不同类型字段类
    ...
...

class ModelMetaclass(type):  # 定义一个元类。每个表(Model对象)需要不同的继承模板,这里通过元类动态创建类。
    ...

class Model(dict, metaclass=ModelMetaclass): # ORM映射的基类。
    ...

然后,完成了一个简单ORM框架,使用时参照sqlalchemy框架。

具体细节:

数据库连接池

百度百科中关于数据库连接池的解释:

数据库连接是一种关键的、有限的、昂贵的资源,这一点在多用户的网页应用程序中体现得尤为突出。
数据库连接池在初始化时将创建一定数量的数据库连接放到连接池中,这些数据库连接的数量是由最小数据库连接数制约。无论这些数据库连接是否被使用,连接池都将一直保证至少拥有这么多的连接数量。连接池的最大数据库连接数量限定了这个连接池能占有的最大连接数,当应用程序向连接池请求的连接数超过最大连接数量时,这些请求将被加入到等待队列中。
连接池基本的思想是在系统初始化的时候,将数据库连接作为对象存储在内存中,当用户需要访问数据库时,并非建立一个新的连接,而是从连接池中取出一个已建立的空闲连接对象。使用完毕后,用户也并非将连接关闭,而是将连接放回连接池中,以供下一个请求访问使用。而连接的建立、断开都由连接池自身来管理。同时,还可以通过设置连接池的参数来控制连接池中的初始连接数、连接的上下限数以及每个连接的最大使用次数、最大空闲时间等等。也可以通过其自身的管理机制来监视数据库连接的数量、使用情况等。

python中的aiomysql为MySQL提供了异步IO的驱动。
aiomysql中有create_pool()方法,这里有create_pool()的源码。前辈们准备好了工具,现在先学会使用再说。

@asyncio.coroutine  # 表明create_pool()为协程
def create_pool(loop, **kw):
    logging.info('create database connection pool...')
    global __pool   # 全局变量__pool来存储连接池。
    __pool = yield from aiomysql.create_pool(
        host=kw.get('host', 'localhost'),
        port=kw.get('port', 3306),
        user=kw['user'],  # 从参数中获取
        password=kw['password'],
        db=kw['db'],
        charset=kw.get('charset', 'utf8'),
        autocommit=kw.get('autocommit', True),  # 自动连接
        maxsize=kw.get('maxsize', 10),  # 最多10个连接对象
        minsize=kw.get('minsize', 1),  # 最少1个
        loop=loop
    )

封装SELECT方法:

查找是数据库最重要的一部分。这里写了select()来执行查找语句。

@asyncio.coroutine
def select(sql, args, size=None):  # sql指SQL语句,传递参数指定查找什么,size规定查找几条,默认None,会查找所有数据
    log(sql, args)  #记录日志
    global __pool
    with (yield from __pool) as conn:
        cur = yield from conn.cursor(aiomysql.DictCursor)  #创建游标来操作数据库。
        yield from cur.execute(sql.replace('?', '%s'), args or ())  #SQL的占位符为?MySQL占位符为%s,然后执行SQL语句。
        if size:
            rs = yield from cur.fetchmany(size)  # yield
            from,协程中调用另一个协程
        else:
            rs = yield from cur.fetchall()
        yield from cur.close()  #关闭游标
        logging.info('rows returned: %s' % len(rs))  # 再记录一下
        return rs  # 返回查找结果

Insert, Update, Delete

这三个方法,Cursor操作完数据库,不用返回详细结果,封装在了一个execute()函数中。

@asyncio.coroutine
def execute(sql, args):
    log(sql)
    with (yield from __pool) as conn:
        try:
            cur = yield from conn.cursor()
            yield from cur.execute(sql.replace('?', '%s'), args)  # 这里执行数据库操作
            affected = cur.rowcount
            yield from cur.close()
        except BaseException as e:
            raise
        return affected  # 只返回影响数据库结果数

Field

有了直接操作数据库的方法,还需要定义数据库表中对应的字段。数据库中一张表有任意行,固定列,每一列的字段类型可能不同。

首先定义Field:

class Field(object):

    def __init__(self, name, column_type, primary_key, default):
        self.name = name   # 对应着数据库表中的字段名
        self.column_type = column_type  #字段数据类型
        self.primary_key = primary_key # 是否为主键
        self.default = default # 有无默认值

    def __str__(self):  #返回对象的字符串形式
        return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)

Field子类:

class StringField(Field):

    def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
        super().__init__(name, ddl, primary_key, default) #  初始化self。

Model

开始定义所有ORM映射的基类Model

首先要有这样熟悉的功能:

>>> user['id']
123
>>> user.id
123

然后要有find(),findAll(),remove(),update(),save(),这些方便的方法。

class Model(dict, metaclass=ModelMetaclass):  # #拥有dict的功能,同时继承自元类`ModelMetaclass`动态生成Model对象。 

    def __init__(self, **kw):
        super(Model, self).__init__(**kw)

    def __getattr__(self, key): #从对象中读取某个属性
        try:
            return self[key]
        except KeyError:
            raise AttributeError(r"'Model' object has no attribute '%s'" % key)

    def __setattr__(self, key, value): #设置对象的属性
        self[key] = value

    def getValue(self, key):
        return getattr(self, key, None)

    def getValueOrDefault(self, key):  # 取默认值,定义字段类设置了默认值属性,默认值也可以是函数
        value = getattr(self, key, None)
        if value is None:
            field = self.__mappings__[key]
            if field.default is not None:
                value = field.default() if callable(field.default) else field.default
                logging.debug('using default value for %s: %s' % (key, str(value)))
                setattr(self, key, value)
        return value

## 然后,find(),findAll(),remove(),update(),save()等好记又好用的方法。
    @classmethod  #将方法变成属性
    @asyncio.coroutine  # 这些方法都要是协程
    def findAll(cls, where=None, args=None, **kw):
        ' find objects by where clause. '
        sql = [cls.__select__]  # cls for clause,每个表名都不相同,这里的__select__方法是动态生成的。
        if where:
            sql.append('where')  # 以下都是为了得到完整的SQL查询语句。
            sql.append(where)
        if args is None:
            args = []
        orderBy = kw.get('orderBy', None)
        if orderBy:
            sql.append('order by')
            sql.append(orderBy)
        limit = kw.get('limit', None)
        if limit is not None:
            sql.append('limit')
            if isinstance(limit, int):
                sql.append('?')
                args.append(limit)
            elif isinstance(limit, tuple) and len(limit) == 2:
                sql.append('?, ?')
                args.extend(limit)
            else:
                raise ValueError('Invalid limit value: %s' % str(limit))
        rs = yield from select(' '.join(sql), args)  # 调用一开始定义的select()查询记录。
        return [cls(**r) for r in rs]  # 将所有结果以列表形式返回。

    @classmethod
    @asyncio.coroutine
    def findNumber(cls, selectField, where=None, args=None):
        ' find number by select and where. '
        sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)]  # 这个__table__也各不相同。
        if where:
            sql.append('where')
            sql.append(where)
        rs = yield from select(' '.join(sql), args, 1)
        if len(rs) == 0:
            return None
        return rs[0]['_num_']

    @classmethod
    @asyncio.coroutine
    def find(cls, pk):
        ' find object by primary key. '
        rs = yield from select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)
        if len(rs) == 0:
            return None
        return cls(**rs[0])  #返回一个实例对象引用

    @asyncio.coroutine
    def save(self):
        args = list(map(self.getValueOrDefault, self.__fields__))  # 需要传递到SQL语句中的参数
        args.append(self.getValueOrDefault(self.__primary_key__))
        rows = yield from execute(self.__insert__, args)  # 调用上面定义的execute()方法,返回影响数
        if rows != 1:
            logging.warn('failed to insert record: affected rows: %s' % rows)

    @asyncio.coroutine
    def update(self):
        args = list(map(self.getValue, self.__fields__))
        args.append(self.getValue(self.__primary_key__))
        rows = yield from execute(self.__update__, args)
        if rows != 1:
            logging.warn('failed to update by primary key: affected rows: %s' % rows)

    @asyncio.coroutine
    def remove(self):
        args = [self.getValue(self.__primary_key__)]
        rows = yield from execute(self.__delete__, args)
        if rows != 1:
            logging.warn('failed to remove by primary key: affected rows: %s' % rows)

元类ModelMetaclass

创建一个元类让Model继承,这样,对象需要不同的继承模板。使用元类,通过继承Model就能继承ModelMetaclass,就能动态生成一个对象。

class ModelMetaclass(type):  # 类是对象的模板,元类是类的模板。type看成类工厂,制造各种类。

    def __new__(cls, name, bases, attrs):  # 当一个类指定通过莫元类来创建,会调用该元类的__new__方法。
    # cls 参数为当前准备创建类的对象 name 为类的名字, bases为继承的父类集合, attrs为类的属性/方法集合。
    # 创建User=Model(),name就是User, bases就是Model, attrs就是一个包含User类属性的dict

        if name=='Model': # Model是基类,要排除掉
            return type.__new__(cls, name, bases, attrs) # 直接返回就行
        # 获取table名称:
        tableName = attrs.get('__table__', None) or name
        logging.info('found model: %s (table: %s)' % (name, tableName))
        mappings = dict()  # 用于存储所有的字段名和字段的映射
        fields = []  # 用于存储非主键以外的其他字段,而且只存key
        primaryKey = None
        # 这里k for key, 是字段名, v for vale, 是字段实例,例如StringField
        for k, v in attrs.items():
            if isinstance(v, Field):
                logging.info('  found mapping: %s ==> %s' % (k, v))
                mappings[k] = v  # 储存到mappings字典中。
                if v.primary_key: # 创建字段会设置primary_key=True
                    # 找到主键:
                    if primaryKey:
                        raise StandardError('Duplicate primary key for field: %s' % k)
                    primaryKey = k # 上述条件成立,把这个字段名赋值给primaryKey变量。
                else:
                    fields.append(k)  # 非主键保存再fields中。
        if not primaryKey:  # 一个主键都没有,报错
            raise StandardError('Primary key not found.')
        for k in mappings.keys():
            attrs.pop(k)  # 去除掉不需要的字段名,返回下面的属性。
        escaped_fields = list(map(lambda f: '`%s`' % f, fields))

        #通过attrs返回的东西,子类中都能通过实例获取,例如self.__table__
        # 这样,任何继承自Model的类(比如User),会自动通过ModelMetaclass扫描映射关系,并存储到自身的类属性如__table__、__mappings__中。
        attrs['__mappings__'] = mappings # 保存属性和列的映射关系
        attrs['__table__'] = tableName
        attrs['__primary_key__'] = primaryKey # 主键属性名
        attrs['__fields__'] = fields # 除主键外的属性名
        # 在这里定义这些属性,Model看起来更简单些
        attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)
        attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))
        attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)
        attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
        return type.__new__(cls, name, bases, attrs)

以上,完成了一个简单的ORM 框架。
廖雪峰老师教程中的源码:orm.py


这篇文章很详细:跟着廖大学python之orm框架实现

深刻理解Python中的元类(metaclass)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值