orm,即Object Relational Mapping,全称对象关系映射。简单来说就是将数据库中的表,字段,行,与我们的面向对象编程中的类,类的属性,以及对象建立一一对应的映射关系,这样我们便可以避免直接操作数据库,而只要调用相应的方法就可以了。
1.创建数据库连接池
#异步协程:创建数据库连接池
@asyncio.coroutine
def create_pool(loop,**kw):
logging.info('start creating database connection pool')
global __pool
#yield from 调用协程函数并返回结果
__pool = yield from aiomysql.create_pool(
#kw.get(key,default):通过key在kw中查找对应的value,如果没有则返回默认值default
host = kw.get('host','localhost'),
port = kw.get('port',3306),
user = kw['user'],
password = kw['password'],
db = kw['db'],
charset = kw.get('charset','utf8'),
autocommit = kw.get('autocommit',True),
maxsize = kw.get('maxsize',10),
minsize = kw.get('minsize',1),
loop = loop
)
2.封装select方法
@asyncio.coroutine
def select(sql,args,size=None):
log(sql,args)
global __pool
#yield from从连接池返回一个连接
with (yield from __pool) as conn:
cur = yield from conn.cursor(aiomysql.DictCursor)
#执行sql语句前,先将sql语句中的占位符?换成mysql中采用的占位符%s
yield from cur.execute(sql.replace('?','%s'),args)
#size表示要返回的结果数,若为None则返回全部查询结果
if size:
rs = yield from cur.fetchmany(size)
else:
rs = yield from cur.fetchall()
yield from cur.close()
logging.info('%s rows have returned' % len(rs))
return rs
3.封装execute方法(update,insert,delete语句操作参数一样,返回影响的行数,直接封装成一个通用的执行函数)
async def execute(sql, args, autocommit=True):
log(sql)
async with __pool.get() as conn:
if not autocommit:
await conn.begin()
try:
async with conn.cursor(aiomysql.DictCursor) as cur:
await cur.execute(sql.replace('?', '%s'), args)
affected = cur.rowcount
if not autocommit:
await conn.commit()
except BaseException as e:
if not autocommit:
await conn.rollback()
raise
return affected
4.字段类的实现
class Field(object):
def __init__(self, name, column_type, primary_key, default):
self.name = name
self.column_type = column_type
self.primary_key = primary_key
self.default = default
def __str__(self):
return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)
class StringField(Field):
def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
super().__init__(name, ddl, primary_key, default)
实现的具体类User、Blog等里面,主要就是属性名与字段类之间的对应关系,具体的属性值还需要使用dict实现
5.元类的使用
学了廖雪峰老师的教程,觉得在编写orm过程中比较难理解的还是元类的使用:
类是对象的模板,元类是类的模板。我们的User类继承自Model类,而Model类的模板是元类ModelMetaClass,所以当使用者实例化一个User对象的时候,User会根据Model去创建,而Model则根据ModelMetaClass动态创建,所以user对象间接的根据ModelMetaClass创建。
#定义Model的metaclass元类
#所有的元类都继承自type
#ModelMetaclass元类定义了所有Model基类(继承ModelMetaclass)的子类实现的操作
# -*-ModelMetaclass:为一个数据库表映射成一个封装的类做准备
# 读取具体子类(eg:user)的映射信息
#创造类的时候,排除对Model类的修改
#在当前类中查找所有的类属性(attrs),如果找到Field属性,就保存在__mappings__的dict里,
#同时从类属性中删除Field(防止实例属性覆盖类的同名属性)
#__table__保存数据库表名
class ModelMetaClass(type):
# 元类必须实现__new__方法,当一个类指定通过某元类来创建,那么就会调用该元类的__new__方法
# 该方法接收4个参数
# cls为当前准备创建的类的对象
# name为类的名字,创建User类,则name便是User
# bases类继承的父类集合,创建User类,则base便是Model
# attrs为类的属性/方法集合,创建User类,则attrs便是一个包含User类属性的dict
def __new__(cls,name,bases,attrs):
# 因为Model类是基类,所以排除掉,如果你print(name)的话,会依次打印出Model,User,Blog,即
# 所有的Model子类,因为这些子类通过Model间接继承元类
if name=="Model":
return type.__new__(cls,name,bases,attrs)
# 取出表名,默认与类的名字相同
tableName=attrs.get('__table__',None) or name
logging.info('found model: %s (table: %s)' % (name, tableName))
# 用于存储所有的字段,以及字段值
mappings=dict()
# 仅用来存储非主键意外的其它字段,而且只存key
fields=[]
# 仅保存主键的key
primaryKey=None
# 注意这里attrs的key是字段名,value是字段实例,不是字段的具体值
# 比如User类的id=StringField(...) 这个value就是这个StringField的一个实例,而不是实例化
# 的时候传进去的具体id值
for k,v in attrs.items():
# attrs同时还会拿到一些其它系统提供的类属性,我们只处理自定义的类属性,所以判断一下
# isinstance 方法用于判断v是否是一个Field
if isinstance(v,Field):
mappings[k]=v
if v.primary_key:
if primaryKey:
raise RuntimeError("Douplicate primary key for field :%s" % key)
primaryKey=k
else:
fields.append(k)
# 保证了必须有一个主键
if not primaryKey:
raise RuntimeError("Primary key not found")
# 记录到了mappings,fields,等变量里,而我们实例化的时候,如
# user=User(id='10001') ,为了防止这个实例变量与类属性冲突,所以将其去掉
for k in mappings.keys():
attrs.pop(k)
#保存非主键属性为字符串列表形式
#将非主键属性变成`id`,`name`这种形式(带反引号)
#repr函数和反引号:取得对象的规范字符串表示
escaped_fields = list(map(lambda f:'`%s`' %f,fields))
# 以下都是要返回的东西了,刚刚记录下的东西,如果不返回给这个类,又谈得上什么动态创建呢?
# 到此,动态创建便比较清晰了,各个子类根据自己的字段名不同,动态创建了自己
# 下面通过attrs返回的东西,在子类里都能通过实例拿到,如self
attrs['__mappings__']=mappings
attrs['__table__']=tableName
attrs['__primaryKey__']=primaryKey
attrs['__fields__']=fields
# 只是为了Model编写方便,放在元类里和放在Model里都可以
attrs['__select__']="select %s ,%s from %s " % (primaryKey,','.join(map(lambda f: '%s' % (mappings.get(f).name or f ),fields )),tableName)
attrs['__update__']="update %s set %s where %s=?" % (tableName,', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)),primaryKey)
attrs['__insert__']="insert into %s (%s,%s) values (%s);" % (tableName,primaryKey,','.join(map(lambda f: '%s' % (mappings.get(f).name or f),fields)),create_args_string(len(fields)+1))
attrs['__delete__']="delete from %s where %s= ? ;" % (tableName,primaryKey)
return type.__new__(cls,name,bases,attrs)
6.定义Models类
#定义ORM所有映射的基类:Model
#Model类的任意子类可以映射一个数据库表
#Model类可以看做是对所有数据库表操作的基本定义的映射
#基于字典查询形式
#Model从dict继承,拥有字典的所有功能,同时实现特殊方法__getattr__和__setattr__,能够实现属性操作
#实现数据库操作的所有方法,定义为class方法,所有继承自Model都具有数据库操作方法
class Model(dict, metaclass=ModelMetaclass):
def __init__(self, **kw):
super(Model, self).__init__(**kw)
def __getattr__(self, key):
try:
return self[key]
except KeyError:
raise AttributeError(r"'Model' object has no attribute '%s'" % key)
def __setattr__(self, key, value):
self[key] = value
def getValue(self, key):
return getattr(self, key, None)
def getValueOrDefault(self, key):
value = getattr(self, key, None)
if value is None:
field = self.__mappings__[key]
if field.default is not None:
value = field.default() if callable(field.default) else field.default
logging.debug('using default value for %s: %s' % (key, str(value)))
setattr(self, key, value)
return value
@classmethod
async def findAll(cls, where=None, args=None, **kw):
' find objects by where clause. '
sql = [cls.__select__]
if where:
sql.append('where')
sql.append(where)
if args is None:
args = []
orderBy = kw.get('orderBy', None)
if orderBy:
sql.append('order by')
sql.append(orderBy)
limit = kw.get('limit', None)
if limit is not None:
sql.append('limit')
if isinstance(limit, int):
sql.append('?')
args.append(limit)
elif isinstance(limit, tuple) and len(limit) == 2:
sql.append('?, ?')
args.extend(limit)
else:
raise ValueError('Invalid limit value: %s' % str(limit))
rs = await select(' '.join(sql), args)
return [cls(**r) for r in rs]
@classmethod
async def findNumber(cls, selectField, where=None, args=None):
' find number by select and where. '
sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)]
if where:
sql.append('where')
sql.append(where)
rs = await select(' '.join(sql), args, 1)
if len(rs) == 0:
return None
return rs[0]['_num_']
@classmethod
async def find(cls, pk):
' find object by primary key. '
rs = await select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)
if len(rs) == 0:
return None
return cls(**rs[0])
async def save(self):
args = list(map(self.getValueOrDefault, self.__fields__))
args.append(self.getValueOrDefault(self.__primary_key__))
rows = await execute(self.__insert__, args)
if rows != 1:
logging.warn('failed to insert record: affected rows: %s' % rows)
async def update(self):
args = list(map(self.getValue, self.__fields__))
args.append(self.getValue(self.__primary_key__))
rows = await execute(self.__update__, args)
if rows != 1:
logging.warn('failed to update by primary key: affected rows: %s' % rows)
async def remove(self):
args = [self.getValue(self.__primary_key__)]
rows = await execute(self.__delete__, args)
if rows != 1:
logging.warn('failed to remove by primary key: affected rows: %s' % rows)