Python3 MySQL ORM使用

参考廖雪峰的 Python 教程，基于 aiomysql 实现一个简单的 Python 3 异步 MySQL ORM，并演示其用法。

数据表 user 的结构如下：

mysql> desc user;
+-------+-------------+------+-----+---------+-------+
| Field | Type        | Null | Key | Default | Extra |
+-------+-------------+------+-----+---------+-------+
| id    | varchar(20) | NO   | PRI | NULL    |       |
| name  | varchar(20) | YES  |     | NULL    |       |
| age   | smallint(6) | YES  |     | NULL    |       |
+-------+-------------+------+-----+---------+-------+

#!/usr/bin/python3
# coding: utf-8

import aiomysql 
import asyncio
import logging
import re 
import time 

# Configure the root logger: INFO and above, timestamped, with every line
# wrapped in the ANSI escape "\033[1;35m" (bold magenta) ... "\033[0m" (reset).
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="\033[1;35m%(asctime)s %(levelname)s: %(message)s\033[0m",
)

def log(sql, args=None):
    """Emit *sql* and its bound *args* as a single INFO-level log line."""
    text = "SQL: %s %s" % (sql, args)
    logging.info(text)

async def create_pool(loop, **kw):
    """Create the module-wide aiomysql connection pool.

    The pool is stored in the module global ``__pool``, which ``select``,
    ``execute`` and ``destroy_pool`` read.

    Args:
        loop: event loop handed through to ``aiomysql.create_pool``.
        **kw: connection settings. ``user``, ``password`` and ``db`` are
            required (a missing one raises ``KeyError``). Optional keys and
            their defaults: ``host`` ("localhost"), ``port`` (3306),
            ``charset`` ("utf8"), ``autocommit`` (True), ``maxsize`` (10),
            ``minsize`` (1).
    """
    logging.info("create database connection pool")
    global __pool
    # Native ``async def`` replaces the generator-based ``@asyncio.coroutine``
    # decorator, which was deprecated in Python 3.8 and removed in 3.11.
    __pool = await aiomysql.create_pool(
        host=kw.get("host", "localhost"),
        port=kw.get("port", 3306),
        user=kw["user"],
        password=kw["password"],
        db=kw["db"],
        charset=kw.get("charset", "utf8"),
        autocommit=kw.get("autocommit", True),
        maxsize=kw.get("maxsize", 10),
        minsize=kw.get("minsize", 1),
        loop=loop,
    )

async def destroy_pool():
    """Close the module-wide connection pool, if one exists.

    Safe to call when ``create_pool`` was never called or the pool has
    already been destroyed. The global is reset to None before closing, so
    the closed pool is not handed out again.
    """
    global __pool
    # globals().get avoids the NameError the original raised when the pool
    # had never been created.
    pool = globals().get("__pool")
    # Explicit None check: the original used the pool object's truthiness.
    if pool is None:
        return
    __pool = None
    pool.close()
    await pool.wait_closed()

async def select(sql, args, size=None):
    """Run a SELECT and return its rows as dicts (``aiomysql.DictCursor``).

    Args:
        sql: query using ``?`` placeholders. They are rewritten to the
            driver's ``%s`` by plain text replacement, so a literal ``?``
            anywhere in the SQL text is rewritten as well.
        args: values for the placeholders; None/empty is passed as ``()``.
        size: if truthy, fetch at most this many rows (``fetchmany``);
            otherwise fetch all rows.

    Returns:
        The rows fetched by the cursor.
    """
    log(sql, args)
    # ``async with`` returns the connection to the pool and closes the cursor
    # even if execute/fetch raises; the original skipped ``cursor.close()``
    # on error.
    async with __pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cursor:
            await cursor.execute(sql.replace("?", "%s"), args or ())
            if size:
                ret = await cursor.fetchmany(size)
            else:
                ret = await cursor.fetchall()
    logging.info("rows returned: %s" % (len(ret)))
    return ret

async def execute(sql, args):
    """Run a data-modifying statement and return the affected row count.

    Args:
        sql: statement using ``?`` placeholders (rewritten to the driver's
            ``%s`` by plain text replacement).
        args: values for the placeholders; None/empty is passed as ``()``.

    Returns:
        ``cursor.rowcount`` for the statement.

    No explicit commit is issued here. Changes persist only when the pool was
    created with autocommit enabled, which is ``create_pool``'s default.
    """
    log(sql, args)
    # ``async with`` closes the cursor and releases the connection even when
    # execute() raises. The former ``except BaseException: raise`` added
    # nothing and let the cursor leak on error.
    async with __pool.acquire() as conn:
        async with conn.cursor() as cursor:
            await cursor.execute(sql.replace("?", "%s"), args or ())
            affected = cursor.rowcount
    return affected

def create_args_string(num):
    """Return *num* ``?`` placeholders joined by ", " (e.g. 3 -> "?, ?, ?").

    Zero or a negative count yields an empty string.
    """
    placeholders = ["?"] * num
    return ", ".join(placeholders)

class ModelMetaclass(type):
    """Metaclass that turns ``Field`` class attributes into SQL mappings.

    For every class other than ``Model`` itself, it collects the ``Field``
    attributes and removes them from the class namespace. It then attaches:

    - ``__mappings__``: attribute name -> ``Field``
    - ``__table__``: the class's ``__table__`` value, or the class name
    - ``__primary_key__``: attribute name of the single primary-key field
    - ``__fields__``: the non-primary-key attribute names, in declaration order
    - ``__select__`` / ``__insert__`` / ``__update__`` / ``__delete__``: SQL
      templates that use ``?`` placeholders.

    Raises:
        RuntimeError: if no primary key, or more than one, is declared.
    """

    def __new__(class_type, class_name, bases, attrs):
        # The ``Model`` base class declares no fields; build it unchanged.
        if class_name == "Model":
            return type.__new__(class_type, class_name, bases, attrs)
        table_name = attrs.get("__table__", None) or class_name
        logging.info("found model: %s(table: %s)" % (class_name, table_name))
        mappings = dict()
        fields = []
        primary_key = None

        for k, v in attrs.items():
            if isinstance(v, Field):
                logging.info("found mapping: %s: %s" % (k, v))
                mappings[k] = v
                if v.primary_key:
                    if primary_key:
                        raise RuntimeError("duplicate primary key for field: %s" % (k))
                    primary_key = k
                else:
                    fields.append(k)

        if not primary_key:
            raise RuntimeError("primary key not found")

        # Drop the Field objects from the class namespace so that
        # ``instance.column`` is not shadowed by a class attribute and falls
        # through to Model.__getattr__ (the dict lookup).
        for key in mappings:
            attrs.pop(key)

        escaped_fields = ["`%s`" % f for f in fields]
        # Primary key first, then the other fields. Joining them as one list
        # avoids a dangling ", " when a model has no non-key fields.
        columns = ", ".join(["`%s`" % primary_key] + escaped_fields)

        attrs["__mappings__"] = mappings
        # Bug fix: this line was ``attrs["__table__"] == table_name``, a no-op
        # comparison. It raised KeyError for models without an explicit
        # ``__table__`` and never stored the class-name fallback.
        attrs["__table__"] = table_name
        attrs["__primary_key__"] = primary_key
        attrs["__fields__"] = fields
        attrs["__select__"] = "select %s from `%s`" % (columns, table_name)
        attrs["__insert__"] = "insert into `%s` (%s) values(%s)" % (table_name, columns, create_args_string(len(fields) + 1))
        attrs["__delete__"] = "delete from `%s` where `%s`=?" % (table_name, primary_key)
        # NOTE(review): UPDATE uses ``Field.name`` as the column name when it
        # is set, while SELECT/INSERT always use the attribute name. The two
        # agree only when ``name`` is None, as it is for every field here.
        attrs["__update__"] = "update `%s` set %s where `%s`=?" % (table_name, ", ".join("`%s`=?" % (mappings[f].name or f) for f in fields), primary_key)
        for k, v in attrs.items():
            logging.info("%s: %s" % (k, v))
        return type.__new__(class_type, class_name, bases, attrs)

class Model(dict, metaclass=ModelMetaclass):
    """Base class for ORM models: a dict whose keys are also attributes.

    Subclasses declare ``Field`` class attributes. ``ModelMetaclass`` turns
    them into ``__mappings__``, ``__primary_key__``, ``__fields__`` and the
    ``__select__`` / ``__insert__`` / ``__update__`` / ``__delete__`` SQL
    templates that the methods below rely on.
    """

    def __init__(self, **kw):
        super().__init__(kw)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails, so values stored in
        # the dict are readable as ``model.column``.
        try:
            return self[key]
        except KeyError:
            # ``from None`` keeps the internal KeyError out of the traceback.
            raise AttributeError(r"Model's object has no attribute %s" % (key)) from None

    def __setattr__(self, key, value):
        # Every attribute assignment is stored as a dict item.
        self[key] = value

    def getValue(self, key):
        """Return the value stored under *key*, or None if it is absent."""
        return getattr(self, key, None)

    def getValueOrDefault(self, key):
        """Return the value for field *key*, filling in the field's default.

        If the stored value is None or missing and the field declares a
        default, the default is used. A callable default is called first. The
        value is then stored on the instance and returned. Raises KeyError if
        the value is None and *key* is not a mapped field.
        """
        value = getattr(self, key, None)
        if value is None:
            field = self.__mappings__[key]
            if field.default is not None:
                value = field.default() if callable(field.default) else field.default
                logging.info("using default value for %s: %s" % (key, str(value)))
                setattr(self, key, value)
        return value

    async def find(class_type, where=None, args=None, **kw):
        """Run ``__select__`` and return the matching rows as model instances.

        This is a plain function rather than a classmethod, so it is called
        as ``User.find(User, ...)``. Converting it to a classmethod would
        break that existing call form.

        Args:
            class_type: the model class to query and instantiate.
            where: optional SQL condition appended after ``where``. Use ``?``
                for bound values; the text itself is inserted verbatim.
            args: values bound to the ``?`` placeholders. Lists and tuples are
                accepted and never modified.
            orderBy (keyword): optional ``order by`` clause, inserted verbatim.
            limit (keyword): an int row count, or a 2-item tuple/list
                ``(offset, count)``.

        Raises:
            ValueError: if *limit* is neither an int nor a 2-item tuple/list.
        """
        sql = [class_type.__select__]
        if where:
            sql.append(" where ")
            sql.append(where)

        # Copy so the caller's list is never mutated (limit values used to be
        # appended to it in place) and tuple args work too.
        args = list(args) if args else []

        orderBy = kw.get("orderBy", None)
        if orderBy:
            sql.append(" order by ")
            sql.append(orderBy)

        limit = kw.get("limit", None)
        # ``is not None`` so that limit=0 really asks for zero rows instead of
        # being silently dropped.
        if limit is not None:
            sql.append(" limit ")
            if isinstance(limit, int):
                sql.append("?")
                args.append(limit)
            elif isinstance(limit, (tuple, list)) and len(limit) == 2:
                sql.append("?, ?")
                args.extend(limit)
            else:
                # Previously this left a dangling " limit " in the SQL, which
                # surfaced later only as a database syntax error.
                raise ValueError("Invalid limit value: %s" % (limit,))

        ret = await select("".join(sql), args)
        return [class_type(**r) for r in ret]

    async def save(self):
        """INSERT this instance using ``__insert__`` (primary key first).

        Missing values are filled from field defaults. A warning is logged
        unless exactly one row was affected.
        """
        args = [self.getValueOrDefault(self.__primary_key__)]
        args += [self.getValueOrDefault(f) for f in self.__fields__]
        rows = await execute(self.__insert__, args)
        if rows == 1:
            logging.info("insert %s rows" % (rows))
        else:
            logging.warning("failed to insert, affected rows: %s" % (rows))

    async def update(self):
        """UPDATE every non-key field of the row matching the primary key.

        Missing values are filled from field defaults. A warning is logged
        unless exactly one row was affected.
        """
        args = [self.getValueOrDefault(f) for f in self.__fields__]
        # The primary key binds the trailing ``where `pk`=?`` placeholder.
        args.append(self.getValueOrDefault(self.__primary_key__))
        rows = await execute(self.__update__, args)
        if rows == 1:
            logging.info("update %s rows" % (rows))
        else:
            logging.warning("failed to update, affected rows: %s" % (rows))

    async def delete(self):
        """DELETE the row matching this instance's primary key.

        A warning is logged unless exactly one row was affected.
        """
        args = [self.getValueOrDefault(self.__primary_key__)]
        rows = await execute(self.__delete__, args)
        if rows == 1:
            logging.info("delete %s rows" % (rows))
        else:
            logging.warning("failed to delete, affected rows: %s" % (rows))

class Field:
    """One mapped column of a Model.

    Attributes:
        name: optional explicit column name (None means the attribute name
            is used).
        column_type: the column's DDL type string, e.g. "varchar(20)".
        primary_key: True for the table's primary-key column.
        default: a value, or a zero-argument callable producing one. It is
            used when the model's value for this column is None.
    """

    def __init__(self, name, column_type, primary_key, default):
        self.name, self.column_type = name, column_type
        self.primary_key, self.default = primary_key, default

    def __str__(self):
        kind = type(self).__name__
        return "<%s %s: %s>" % (kind, self.column_type, self.name)

class StringField(Field):
    """A string column; the DDL type defaults to ``varchar(20)``."""

    def __init__(self, name=None, primary_key=False, default=None, ddl="varchar(20)"):
        super().__init__(
            name=name,
            column_type=ddl,
            primary_key=primary_key,
            default=default,
        )

class IntegerField(Field):
    """An integer column; the DDL type defaults to ``smallint``."""

    def __init__(self, name=None, primary_key=False, default=None, ddl="smallint"):
        super().__init__(
            name=name,
            column_type=ddl,
            primary_key=primary_key,
            default=default,
        )

class User(Model):
    # Model for the `user` table. The attribute declaration order below fixes
    # the column order of the generated SELECT/INSERT/UPDATE statements, so
    # keep it stable. The DDL values spelled out here are the field defaults.
    __table__ = "user"
    id = StringField(primary_key=True, ddl="varchar(20)")
    name = StringField(ddl="varchar(20)")
    age = IntegerField(default=99, ddl="smallint")

async def test(loop):
    """Demonstrate the ORM end to end against the ``python`` database.

    Queries ``user`` several ways: all rows, a WHERE filter, ORDER BY with
    LIMIT, and LIMIT with offset. It then inserts a new user, updates it,
    deletes it, and deletes every user whose name starts with "user",
    printing rows after each step. The connection pool is closed even if a
    step raises.

    Args:
        loop: event loop forwarded to ``create_pool``.
    """
    await create_pool(loop=loop, host="localhost", user="xxx", password="xxx", db="python")
    try:
        ret = await User.find(User)
        print_ret(ret)

        ret = await User.find(User, where="id > 3")
        print_ret(ret)

        ret = await User.find(User, limit=3, orderBy="name desc")
        print_ret(ret)

        # limit=(offset, count)
        ret = await User.find(User, limit=(3, 3), orderBy="id desc")
        print_ret(ret)

        # Build a fresh id from the timestamp digits: drop the first four
        # characters and the decimal point. Named ``user_id`` so it does not
        # shadow the builtin ``id``.
        user_id = re.sub(r"\.", "", str(time.time())[4:])
        user = User(id=user_id, name="user%s" % (user_id))
        await user.save()

        # Bind the id as a parameter instead of formatting it into the SQL.
        # ``id`` is varchar(20), and the old unquoted number forced a numeric
        # comparison.
        ret = await User.find(User, where="id = ?", args=[user_id])
        print_ret(ret)
        user = ret[0]

        user = User(id=user["id"], name="update_user", age=23)
        await user.update()

        ret = await User.find(User, where="id = ?", args=[user_id])
        print_ret(ret)
        user = ret[0]
        await user.delete()

        # The LIKE pattern is bound as a value, so no "%%" escaping is needed.
        ret = await User.find(User, where="name like ?", args=["user%"])
        print_ret(ret)
        for user in ret:
            await user.delete()

        ret = await User.find(User)
        print_ret(ret)
    finally:
        await destroy_pool()

def print_ret(ret):
    """Print each row as "<id> <name> <age>" on its own line."""
    for row in ret:
        print(" ".join(str(row[column]) for column in ("id", "name", "age")))

# Run the demo. ``asyncio.run`` creates a fresh event loop, runs the
# coroutine, then closes the loop. The previous ``get_event_loop()`` /
# ``run_until_complete`` / ``close`` sequence relied on ``get_event_loop()``
# implicitly creating a loop when none is running, which is deprecated.
async def _main():
    """Run ``test`` with the currently running event loop."""
    await test(asyncio.get_running_loop())


asyncio.run(_main())

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值