Python廖雪峰实战web开发(Day3-编写ORM)

在一个Web App中,所有数据,包括用户信息、发布的日志、评论等,都存储在数据库中。
由于Web框架使用了基于asyncioaiohttp,这是基于协程的异步模型。在协程中,不能调用普通的同步IO操作,因为所有用户都是由一个线程服务的,协程的执行速度必须非常快,才能处理大量用户的请求。而耗时的IO操作不能在协程中以同步的方式调用,否则,等待一个IO操作时,系统无法响应任何其他用户。
这就是异步编程的一个原则:一旦决定使用异步,则系统每一层都必须是异步

编写ORM模块

ORMObject-Relational Mapping),作用是:把关系数据库的表结构映射到对象上。

这次使用的是MySQL作为数据库。并且运用aiomysqlMySQL数据库提供了异步IO的驱动。建立一个web访问的ORM,使每一个web请求被连接之后都要接入数据库进行操作。

1. 创建思路分析

1.1 参考网站:

关于MySQL教程
关于MySQL官方文档
关于创建数据库连接池
关于元类教程
关于元类解释
关于sql语句

1.2 自我思路整理

记得在廖雪峰网站刚学习元类的时候,那边就有讲到ORM是一个典型的,需要通过metaclass修改类定义的一个例子。那时候完全不理解元类的意义及用处,直到刚才从头敲一遍编写ORM的代码,才感觉,好像,有点,理解元类这块硬骨头了。
在制作ORM模块,要使用元类的原因,本人认为是:可以使用OOP编程(Obejct Oriented Programming)

比方说设计一个ORM框架,使用者如果使用这个ORM框架,想定义一个User类来操作对应的数据库表User,我们期待他写出这样的代码:

# 创建实例:
user = User(id=123, name='Michael')
# 存入数据库:
user.insert()
# 查询所有User对象:
users = User.findAll()

仔细想想其中它们究竟干了些什么。大概就是:收集数据;对这些数据进行分类,识别(相对应数据库),生成SQL语句;最后,连接数据库,并执行SQL语句进行操作。
嗯,这样,貌似就可以使用OOP编程,把问题拆分开来了。
比方说:

  • 这刚定义的User类负责收集数据,并尝试归类出这些数据对应数据库表的映射关系,类如对应表的字段(包含名字、类型、是否为表的主键、默认值)等;
  • 它的基类负责执行操作,比如数据库的存储、读取,查找等操作;
  • 它的元类负责分类、整理收集的数据并以此创建一些类属性(如SQL语句)供基类作为参数。

这么想来,在创建User类之前,最好就得先封装数据库表中的每一列的属性(包含名字、类型、是否为表的主键、默认值),以便调用,这里的做法是定义一个Field类来保存每一列的属性。
而且,最好也把操作数据库的操作函数(SELECT、INSERT、UPDATE、DELETE)封装起来,以便调用。
至此,还需要创建一个
全局的连接池
,每个HTTP请求都可以从连接池中直接获取数据库连接。使用连接池的好处是不必频繁地打开和关闭数据库连接,而是能复用就尽量复用。

2. 代码编写

以下是代码:

#!usr/bin/env python3
# -*- coding: utf-8 -*-

__author__="Seiei"

'''
编写orm模块
'''

import asyncio
import logging
import aiomysql

def log(sql):
    logging.info("SQL: %s"%(sql))

#创建一个全局的连接池,每个HTTP请求都从池中获得数据库连接
#连接池由全局变量__pool存储,缺省情况下将编码设置为utf8,自动提交事务
async def create_pool(loop,**kw):#charset参数是utf8
    logging.info('create database connectiong pool...')
    global __pool #全局变量
    __pool = await aiomysql.create_pool(
        host = kw.get('host','localhost'),
        port = kw.get('port',3306),
        user = kw['user'],
        db = kw['db'],
        password = kw['password'],
        charset = kw.get('charset','utf8'),
        autocommit = kw.get('autocommit',True),
        maxsize = kw.get('maxsize',10),
        minsize = kw.get('minsize',1),
        loop=loop
    )#创建连接所需要的参数

#用于输出元类中创建sql_insert语句中的占位符
def create_args_string(num):
    L=[]
    for x in range(num):
        L.append('?')
    return ','.join(L)

#单独封装select,其他insert,update,delete一并封装,理由如下:
#使用Cursor对象执行insert,update,delete语句时,执行结果由rowcount返回影响的行数,就可以拿到执行结果。
#使用Cursor对象执行select语句时,通过featchall()可以拿到结果集。结果集是一个list,每个元素都是一个tuple,对应一行记录。
async def select(sql,args,size=None):
    log(sql)
    global __pool #引入全局变量
    with await __pool as conn: #打开pool的方法,或-->async with __pool.get() as conn:
        cur = await conn.cursor(aiomysql.DictCursor) #创建游标,aiomysql.DictCursor的作用使生成结果是一个dict
        await cur.execute(sql.replace('?',"%s"),args or ()) #执行sql语句,sql语句的占位符是'?',而Mysql的占位符是'%s'
        if size:
            rs = await cur.fetchmany(size)
        else:
            rs = await cur.fetchall()
    await cur.close()
    logging.info('rows returned: %s'%len(rs))
    return rs

#封装INSTERT,UPDATE,DELETE,我一开始的做法
#async def execute(sql,args):
#    log(sql)
#    global __pool
#    with await __pool as conn:
#        try:
#            cur = await conn.cursor()
#            await cur.execute(sql.replace('?','%s'),args)
#            affectline = cur.rowcount
#            await cur.close()
#        except BaseException as e:
#            raise
#        return affectline

#封装INSTERT,UPDATE,DELETE,老师的做法
async def execute(sql, args, autocommit=True):
    log(sql)
    with await __pool as conn:
        if not autocommit:
            await conn.begin()
        try:
            async with conn.cursor(aiomysql.DictCursor) as cur:
                await cur.execute(sql.replace('?', '%s'), args)
                affected = cur.rowcount
            if not autocommit:
                await conn.commit()
        except BaseException as e:
            if not autocommit:
                await conn.rollback()
            raise
        return affected


#定义Field
class Field(object):
    def __init__(self,name,colum_type,primary_key,default):
        self.name = name
        self.colum_type = colum_type
        self.primary_key = primary_key
        self.default = default
    def __str__(self):
        return '<%s,%s:%s>'%(self.__class__.__name__,self.colum_type,self.name)

class StringField(Field):
    def __init__(self,name=None,ddl='varchar(100)',primary_key=False,default=None):
        super(StringField,self).__init__(name,ddl,primary_key,default)

class BooleanField(Field):
    def __init__(self,name=None,ddl='boolean',primary_key=False,default=None):
        super(BooleanField,self).__init__(name,ddl,primary_key,default)

class IntegerField(Field):
    def __init__(self,name=None,ddl='bigint',primary_key=False,default=0):
        super(IntegerField,self).__init__(name,ddl,primary_key,default)

class FloatField(Field):
    def __init__(self,name=None,ddl='real',primary_key=False,default=0.0):
        super(FloatField,self).__init__(name,ddl,primary_key,default)

class TextField(Field):
    def __init__(self,name=None,ddl='Text',primary_key=False,default=None):
        super(TextField,self).__init__(name,ddl,primary_key,default)

#元类
class ModelMetaclass(type):

    def __new__(cls,name,bases,attrs):#当前准备创建的类的对象;类的名字;类继承的父类集合;类的方法集合。
        if name == 'Model': #排除掉对Model类的修改;
            return type.__new__(cls,name,bases,attrs)
        tableName = attrs.get('__table__',None) or name
        logging.info('found a model: %s (table: %s)'%(name,tableName))
        # 获取所有的Field和主键名:
        mappings = dict() #保存映射关系
        fields = [] #保存除主键外的属性
        primarykey = None
        for k,v in attrs.items():
            if isinstance(v,Field):
                logging.info('Found mapping: %s ==> %s'%(k,v))
                mappings[k] = v
                if v.primary_key: #找到主键名
                    if primarykey:
                        raise StandardError('Duplicate primary key for field: k'%k)
                    primarykey = k #此列设为列表的主键
                else:
                    fields.append(k) #保存除主键外的属性
        if not primarykey:
            raise StandardError('primary key not found.')
        for k in mappings.keys():
            attrs.pop(k) #从类属性中删除Field属性,否则,容易造成运行时错误(实例的属性会遮盖类的同名属性)
        escaped_fields = list(map(lambda f: "`%s`"%f,fields))#转换为sql语法
        #创建供Model类使用属性
        attrs['__mappings__'] = mappings # 保存属性和列的映射关系
        attrs['__table__'] = tableName #表的名字
        attrs['__primary_key__'] = primarykey # 主键属性名
        attrs['__fields__'] = fields # 除主键外的属性名
        attrs['__select__'] = 'select `%s`,%s from %s'%(primarykey,','.join(escaped_fields),tableName)
        attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primarykey, create_args_string(len(escaped_fields) + 1))#占位符
        attrs['__update__'] = 'update `%s` set %s where `%s`=?'%(tableName,','.join(map(lambda f: '`%s`=?'%(mappings.get(f).name or f),fields)),primarykey)#查询列的名字,也看一下在Field定义上有没有定义名字,默认None
        attrs['__delete__'] = 'delete from `%s` where `%s`=?'%(tableName,primarykey)
        return type.__new__(cls,name,bases,attrs)

#基类Model
class Model(dict,metaclass=ModelMetaclass):
    
    def __init__(self,**kw):
        super(Model,self).__init__(**kw)

    def __getattr__(self,key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError('"Model" object has no atttibute: %s'%key)

    def __setattr__(self,key,value):
        self[key] = value

    def getValue(self,key):
        return getattr(self,key,None) #直接调回内置函数,注意这里没有下划符,注意这里None的用处,是为了当user没有赋值数据时,返回None,调用于update

    def getValueOrDefault(self,key):
        value = getattr(self,key,None) #第三个参数None,可以在没有返回数值时,返回None,调用于save
        if not value:
            field = self.__mappings__[key]
            if field.default is not None:
                value = field.default() if callable(field.default) else field.default
                logging.debug('using default value for %s: %s'%(key, str(value)))
        return value

    @classmethod
    async def findall(cls,where=None,args=None,**kw):
        sql = [cls.__select__]
        if where:
            sql.append('where')
            sql.append(where)
        if args is None:
            args = []
        orderBy = kw.get('orderBy',None)
        if orderBy:
            sql.append('order by')
            sql.append(orderBy)
        limit = kw.get('limit',None)
        if limit:
            sql.append('limit')
            if isinstance(limit,int):
                sql.append('?')
                args.append(limit)
            elif isinstance(limit,tuple) and len(limit) == 2:
                sql.append('?, ?')
                args.extend(limit) #tuple融入list
            else:
                raise ValueError('Invalid limit value: %s'%str(limit))
        rs = await select(' '.join(sql),args)
        return [cls(**r) for r in rs] #调试的时候尝试了一下return rs,输出结果一样

    @classmethod
    async def findnumber(cls,selectField,where=None,args=None):
        sql = ['select %s __num__ from `%s`'%(selectField,cls.__table__)]
        if where:
            sql.append('where')
            sql.append(where)
        rs = await select(' '.join(sql),args,1)
        if len(rs) == 0:
            return None
        return rs[0]['__num__']

    @classmethod
    async def find(cls,primarykey):
        sql = '%s where `%s`=?'%(cls.__select__,cls.__primary_key__)
        rs = await select(sql,[primarykey],1)
        if len(rs) == 0:
            return None
        return cls(**rs[0])

    async def save(self):
        args = list(map(self.getValueOrDefault, self.__fields__))
        args.append(self.getValueOrDefault(self.__primary_key__))
        rows = await execute(self.__insert__, args)
        if rows != 1:
            logging.warn('failed to insert record: affected rows: %s' % rows)

    async def update(self):
        args = list(map(self.getValue,self.__fields__))
        args.append(self.getValue(self.__primary_key__))
        rows = await execute(self.__update__,args)
        if rows != 1:
            logging.warn('faild to update by primary key: affected rows: %s'%rows)

    async def remove(self):
        args = [self.getValue(self.__primary_key__)]#这里不能使用list()-->'int' object is not iterable
        rows = await execute(self.__delete__,args)
        if rows != 1:
            logging.warn('faild to remove by primary key: affected rows: %s'%rows)

调用时注意,看清是不是协程,注意使用yield from

  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值