ORM(Object Relational Mapping,对象关系映射),是一种程序设计技术,用于实现面向对象编程语言里不同类型系统的数据之间的转换。从效果上来说,它其实创建了一个可在编程语言里使用的“虚拟对象数据库”。
上面是维基百科的解释,但是为什么要用ORM这种编程技术呢?
就这个实战作业来看:
博客——标题、摘要、内容、评论、作者、创作时间
评论——内容、评论人、评论文章、评论时间
用户——姓名、邮箱、口令、权限
上述信息,都需要有组织的存储在数据库中。数据库方面很简单,只需要维护三张表,如果想要加强各表间的联系,可以使用外键。但是,Python该如何组织这些信息呢?每篇博客有不同标题、摘要、内容…,每篇评论和用户信息也个不相同。像C这种面向过程的语言必须创建高级数据结构,而Python这种面向对象的语言真是有天然的优势,我们只需要把每篇博客、评论或者用户看作对象,使用属性表示其蕴含的信息。最后,我们还要解决一个问题:Python和数据库如何高效有组织的交换数据呢?数据库表是由一条条记录组成的,每条记录又包含不同字段。记录和字段,对象和属性…看起来两者关系类不类似?这就是我们的思路——
将数据库表的每条记录映射为对象,每条记录的字段和对象的属性相应;同时透过对象方法执行SQL命令。
我们编写的ORM框架就是实现上述想法。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#Program:
# This is a ORM for MySQL.
#History:
#2017/06/29 smile First release
import logging
import asyncio
import aiomysql
def log(sql,args=()):
logging.info('SQL: %s' % sql)
#Close pool
async def destory_pool():
global __pool
__pool.close()
await __pool.wait_closed()
#Create connect pool
#Parameter: host,port,user,password,db,charset,autocommit
# maxsize,minsize,loop
async def create_pool(loop,**kw):
logging.info('Create database connection pool...')
global __pool
__pool = await aiomysql.create_pool(
host = kw.get('host', 'localhost'),
port = kw.get('port', 3306),
user = kw['user'],
password = kw['password'],
db = kw['db'],
charset = kw.get('charset', 'utf8'),
autocommit = kw.get('autocommit', 'True'),
maxsize = kw.get('maxsize', 10),
minsize = kw.get('minsize', 1),
loop = loop
)
#Package SELECT function that can execute SELECT command.
#Setup 1:acquire connection from connection pool.
#Setup 2:create a cursor to execute MySQL command.
#Setup 3:execute MySQL command with cursor.
#Setup 4:return query result.
async def select(sql,args,size=None):
log(sql,args)
global __pool
async with __pool.acquire() as conn:
async with conn.cursor(aiomysql.DictCursor) as cur:
await cur.execute(sql.replace('?','%s'),args or ())
if size:
rs = await cur.fetchmany(size)
else:
rs = await cur.fetchall()
logging.info('rows returned: %s' % len(rs))
return rs
#Package execute function that can execute INSERT,UPDATE and DELETE command
async def execute(sql,args,autocommit=True):
global __pool
#acquire connection from connection pool
async with __pool.acquire() as conn:
#如果MySQL禁止隐式提交,则标记事务开始
if not autocommit:
await conn.begin()
try:
#create cursor to execute MySQL command
async with conn.cursor(aiomysql.DictCursor) as cur:
await cur.execute(sql.replace('?','%s'),args or ())
affectrow = cur.rowcount
#如果MySQL禁止隐式提交,手动提交事务
if not autocommit:
await cur.commit()
#如果事务处理出现错误,则回退
except BaseException as e:
await conn.rollback()
raise
#return number of affected rows
return affectrow
#Create placeholder with '?'
def create_args_string(num):
L = []
for i in range(num):
L.append('?')
return ', '.join(L)
#A base class about Field
#描述字段的字段名,数据类型,键信息,默认值
class Field(object):
def __init__(self,name,column_type,primary_key,default):
self.name = name
self.column_type = column_type
self.primary_key = primary_key
self.default = default
def __str__(self):
return '<%s,%s:%s>' % (self.__class__.__name__,self.column_type,self.name)
#String Field
class StringField(Field):
def __init__(self,name=None,ddl='varchar(100)',default=None,primary_key=False):
super(StringField,self).__init__(name,ddl,primary_key,default)
#Bool Fileed
class BooleanField(Field):
def __init__(self,name=None,ddl='boolean',default=False,primary_key=False):
super(BooleanField,self).__init__(name,ddl,primary_key,default)
#Integer Field
class IntegerField(Field):
def __init__(self,name=None,ddl='bigint',default=None,primary_key=None):
super(IntegerField,self).__init__(name,ddl,primary_key,default)
#Float Field
class FloatField(Field):
def __init__(self,name=None,ddl='real',default=None,primary_key=None):
super(FloatField,self).__init__(name,ddl,primary_key,default)
#Text Field
class TextField(Field):
def __init__(self,name=None,ddl='text',default=None,primary_key=None):
super(TextField,self).__init__(name,ddl,primary_key,default)
#Meatclass about ORM
#作用:
#首先,拦截类的创建
#然后,修改类
#最后,返回修改后的类
class ModelMetaclass(type):
#采集应用元类的子类属性信息
#将采集的信息作为参数传入__new__方法
#应用__new__方法修改类
def __new__(cls,name,bases,attrs):
#不对Model类应用元类
if name == 'Model':
return type.__new__(cls,name,bases,attrs)
#获取数据库表名。若__table__为None,则取用类名
tablename = attrs.get('__table__',None) or name
logging.info('Found model: %s (table: %s)' % (name,tablename))
#存储映射表类的属性(键-值)
mappings = dict()
#存储映射表类的非主键属性(仅键)
fields = []
#主键对应字段
primarykey = None
for k,v in attrs.items():
if isinstance(v,Field):
logging.info('Found mapping: %s ==> %s' % (k,v))
mappings[k] = v
if v.primary_key:
logging.info('Found primary key')
if primarykey:
raise Exception('Duplicate primary key for field:%s' % k)
primarykey = k
else:
fields.append(k)
#如果没有主键抛出异常
if not primarykey:
raise Exception('Primary key not found')
#删除映射表类的属性,以便应用新的属性
for i in mappings.keys():
attrs.pop(i)
#使用反单引号" ` "区别MySQL保留字,提高兼容性
escaped_fields = list(map(lambda f:'`%s`' % f,fields))
#重写属性
attrs['__mappings__'] = mappings
attrs['__table__'] = tablename
attrs['__primary_key__'] = primarykey
attrs['__fields__'] = fields
attrs['__select__'] = 'SELECT `%s`, %s FROM `%s`' % (primarykey,','.join(escaped_fields),tablename)
attrs['__insert__'] = 'INSERT `%s` (%s,`%s`) VALUES (%s)' % (tablename,','.join(escaped_fields),primarykey,create_args_string(len(escaped_fields) + 1))
attrs['__update__'] = 'UPDATE `%s` SET %s WHERE `%s` = ?' % (tablename,','.join(map(lambda f:'`%s` = ?' % (mappings.get(f).name or f),fields)),primarykey)
attrs['__delete__'] = 'DELETE FROM `%s` WHERE `%s` = ?' % (tablename,primarykey)
#返回修改后的类
return type.__new__(cls,name,bases,attrs)
#A base class about Model
#继承dict类特性
#附加方法:
# 以属性形式获取值
# 拦截私设属性
class Model(dict,metaclass=ModelMetaclass):
def __init__(self,**kw):
super(Model,self).__init__(**kw)
def __getattr__(self,key):
try:
return self[key]
except KeyError:
raise AttributeError(r"'Model' object has no attribute '%s'" % key)
def __setattr__(self,key,value):
self[key] = value
def getValue(self,key):
return getattr(self,key,None)
def getValueorDefault(self,key):
value = getattr(self,key,None)
if value is None:
field = self.__mappings__[key]
if field.default is not None:
value = field.default() if callable(field.default) else field.default
logging.debug('using default value for %s: %s' % (key,str(value)))
setattr(self,key,value)
return value
#ORM框架下,每条记录作为对象返回
#@classmethod定义类方法,类对象cls便可完成某些操作
@classmethod
async def findAll(cls,where=None,args=None,**kw):
sql = [cls.__select__]
#添加WHERE子句
if where:
sql.append('WHERE')
sql.append(where)
if args is None:
args = []
orderby = kw.get('orderby',None)
#添加ORDER BY子句
if orderby:
sql.append('ORDER BY')
sql.append(orderby)
limit = kw.get('limit',None)
#添加LIMIT子句
if limit:
sql.append('LIMIT')
if isinstance(limit,int):
sql.append('?')
args.append(limit)
elif isinstance(limit,tuple):
sql.append('?, ?')
args.extend(limit)
else:
raise ValueError('Invalid limit value: %s' % str(limit))
#execute SQL
rs = await select(' '.join(sql),args)
#将每条记录作为对象返回
return [cls(**r) for r in rs]
#过滤结果数量
@classmethod
async def findNumber(cls,selectField,where=None,args=None):
sql = ['SELECT %s _num_ from `%s`' % (selectField,cls.__table__)]
#添加WHERE子句
if where:
sql.append('WHERE')
sql.append(where)
rs = await select(' '.join(sql),args)
if len(rs) == 0:
return None
return rs[0]['_num_']
#返回主键的一条记录
@classmethod
async def find(cls,pk):
rs = await select('%s WHERE `%s` = ?' % (cls.__select__,cls.__primary_key__),[pk],1)
if len(rs) == 0:
return None
return cls(**rs[0])
#INSERT command
async def save(self):
args = list(map(self.getValueorDefault,self.__fields__))
args.append(self.getValueorDefault(self.__primary_key__))
rows = await execute(self.__insert__,args)
if rows != 1:
logging.warn('Faield to insert record:affected rows: %s' % rows)
#UPDATE command
async def update(self):
args = list(map(self.getValue,self.__fields__))
args.append(self.getValue(self.__primary_key__))
rows = await execute(self.__update__,args)
if rows != 1:
logging.warn('Faield to update by primary_key:affectesd rows: %s' % rows)
#DELETE command
async def remove(self):
args = [self.getValue(self.__primary_key__)]
rows = await execute(self.__delete__,args)
if rows != 1:
logging.warn('Faield to remove by primary key:affected: %s' % rows)