参考廖雪峰的 Python 教程，基于 aiomysql 实现一个 Python 3 异步 MySQL ORM，并演示其用法。
示例使用的数据表 user 结构如下：
mysql> desc user;
+-------+-------------+------+-----+---------+-------+
| Field | Type | Null | Key | Default | Extra |
+-------+-------------+------+-----+---------+-------+
| id | varchar(20) | NO | PRI | NULL | |
| name | varchar(20) | YES | | NULL | |
| age | smallint(6) | YES | | NULL | |
+-------+-------------+------+-----+---------+-------+
#!/usr/bin/python3
# coding: utf-8
import aiomysql
import asyncio
import logging
import re
import time
# Root logging setup for this script: INFO and above, timestamped, and wrapped
# in ANSI SGR escape codes (1;35 = bold magenta, 0 = reset) so the SQL trace
# stands out in a terminal.
logging.basicConfig(
    format="\033[1;35m%(asctime)s %(levelname)s: %(message)s\033[0m",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
def log(sql, args=None):
    """Log a SQL statement and its bound arguments at INFO level.

    The record text is ``"SQL: <sql> <args>"``; the formatting is deferred to
    the logging module instead of being done eagerly here.
    """
    logging.info("SQL: %s %s", sql, args)
async def create_pool(loop, **kw):
    """Create the module-wide aiomysql connection pool used by select()/execute().

    Written as a native coroutine: ``@asyncio.coroutine`` no longer exists
    in Python 3.11+.

    Args:
        loop: event loop passed through to ``aiomysql.create_pool``.
        **kw: connection settings. ``user``, ``password`` and ``db`` are
            required (KeyError if missing). These keys are optional:
            ``host`` (default "localhost"), ``port`` (3306),
            ``charset`` ("utf8"), ``autocommit`` (True), ``maxsize`` (10)
            and ``minsize`` (1).
    """
    logging.info("create database connection pool")
    global __pool
    __pool = await aiomysql.create_pool(
        host=kw.get("host", "localhost"),
        port=kw.get("port", 3306),
        user=kw["user"],
        password=kw["password"],
        db=kw["db"],
        charset=kw.get("charset", "utf8"),
        autocommit=kw.get("autocommit", True),
        maxsize=kw.get("maxsize", 10),
        minsize=kw.get("minsize", 1),
        loop=loop,
    )
async def destroy_pool():
    """Close the module-wide connection pool and wait until it has shut down.

    Safe to call when create_pool() was never awaited, or when the pool was
    already destroyed. In both cases there is nothing to close and the call
    returns immediately (previously this raised NameError or closed twice).
    """
    global __pool
    pool = globals().get("__pool")
    if pool is None:
        return
    # Clear the global before awaiting so a repeated call does not try to
    # close the same pool again.
    __pool = None
    pool.close()
    await pool.wait_closed()
async def select(sql, args, size=None):
    """Run a SELECT on a pooled connection and return the fetched rows.

    Args:
        sql: statement using ``?`` placeholders. Every ``?`` is rewritten to
            ``%s`` for the driver, including any ``?`` inside quoted literals,
            so avoid literal question marks in the SQL text.
        args: values for the placeholders. A falsy value is sent as an empty
            tuple.
        size: when truthy, fetch at most this many rows; otherwise fetch all.

    Returns:
        The rows as returned by ``aiomysql.DictCursor`` (one dict per row).
    """
    log(sql, args)
    # Context managers release the connection and close the cursor even when
    # execute()/fetch*() raises; the old code leaked the cursor in that case.
    async with __pool.acquire() as conn:
        async with conn.cursor(aiomysql.DictCursor) as cursor:
            await cursor.execute(sql.replace("?", "%s"), args or ())
            if size:
                ret = await cursor.fetchmany(size)
            else:
                ret = await cursor.fetchall()
    logging.info("rows returned: %s", len(ret))
    return ret
async def execute(sql, args):
    """Run a data-modifying statement and return the number of affected rows.

    Args:
        sql: statement using ``?`` placeholders (rewritten to ``%s``).
        args: values for the placeholders. A falsy value is sent as an empty
            tuple.

    Returns:
        ``cursor.rowcount`` after execution.

    No explicit commit is issued here. Changes are therefore only persisted
    when the pool was created with autocommit enabled, which is the
    create_pool() default. Exceptions from the driver propagate unchanged.
    """
    log(sql, args)
    # The cursor context manager closes the cursor on success and on error;
    # the previous try/except only re-raised and leaked the cursor.
    async with __pool.acquire() as conn:
        async with conn.cursor() as cursor:
            await cursor.execute(sql.replace("?", "%s"), args or ())
            affected = cursor.rowcount
    return affected
def create_args_string(num):
    """Return ``num`` SQL ``?`` placeholders joined by ", " (3 -> "?, ?, ?")."""
    return ", ".join("?" for _ in range(num))
class ModelMetaclass(type):
    """Metaclass that turns ``Field`` class attributes into ORM metadata and SQL.

    For every class other than ``Model`` itself it removes the Field
    attributes from the class namespace and stores these attributes:

    * ``__mappings__``: attribute name -> Field
    * ``__table__``: the class's ``__table__`` or, if absent, the class name
    * ``__primary_key__``: attribute name of the single primary-key field
    * ``__fields__``: attribute names of the remaining fields, in order
    * ``__select__`` / ``__insert__`` / ``__update__`` / ``__delete__``: SQL
      templates with ``?`` placeholders

    Column names come from ``Field.name`` when set, otherwise from the
    attribute name. This is applied consistently in every statement, and
    renamed columns are aliased back to the attribute name in ``__select__``.

    Raises:
        RuntimeError: if no field, or more than one field, is a primary key.
    """

    def __new__(class_type, class_name, bases, attrs):
        # The Model base class declares no fields, so create it unchanged.
        if class_name == "Model":
            return type.__new__(class_type, class_name, bases, attrs)
        table_name = attrs.get("__table__", None) or class_name
        logging.info("found model: %s(table: %s)" %(class_name, table_name))
        mappings = dict()
        fields = []
        primary_key = None
        for k, v in attrs.items():
            if isinstance(v, Field):
                logging.info("found mapping: %s: %s" %(k, v))
                mappings[k] = v
                if v.primary_key:
                    if primary_key:
                        raise RuntimeError("duplicate primary key for field: %s" %(k))
                    primary_key = k
                else:
                    fields.append(k)
        if not primary_key:
            raise RuntimeError("primary key not found")
        # Drop the Field objects from the class. Otherwise normal attribute
        # lookup would return the Field instead of the row value stored in the
        # instance dict.
        for key in mappings:
            attrs.pop(key)

        all_attrs = [primary_key] + fields

        def column(attr):
            # Database column for an attribute: explicit Field.name, if any.
            return mappings[attr].name or attr

        def select_column(attr):
            # Alias renamed columns back to the attribute name so selected rows
            # are keyed by attribute, matching the model's own keys.
            col = column(attr)
            if col == attr:
                return "`%s`" % col
            return "`%s` as `%s`" % (col, attr)

        attrs["__mappings__"] = mappings
        # Always store the resolved table name (this was a no-op "==" that
        # raised KeyError for models without an explicit __table__).
        attrs["__table__"] = table_name
        attrs["__primary_key__"] = primary_key
        attrs["__fields__"] = fields
        attrs["__select__"] = "select %s from `%s`" % (
            ", ".join(map(select_column, all_attrs)),
            table_name,
        )
        attrs["__insert__"] = "insert into `%s` (%s) values(%s)" % (
            table_name,
            ", ".join("`%s`" % column(a) for a in all_attrs),
            create_args_string(len(all_attrs)),
        )
        attrs["__delete__"] = "delete from `%s` where `%s`=?" % (
            table_name,
            column(primary_key),
        )
        attrs["__update__"] = "update `%s` set %s where `%s`=?" % (
            table_name,
            ", ".join("`%s`=?" % column(f) for f in fields),
            column(primary_key),
        )
        for k, v in attrs.items():
            logging.info("%s: %s" %(k, v))
        return type.__new__(class_type, class_name, bases, attrs)
class Model(dict, metaclass=ModelMetaclass):
    """Base class for ORM models: a dict of column values with attribute access.

    Subclasses declare ``Field`` class attributes. ModelMetaclass converts them
    into the ``__mappings__``, ``__primary_key__`` and ``__fields__`` metadata,
    plus the ``__select__``/``__insert__``/``__update__``/``__delete__`` SQL
    used below. ``obj.name`` and ``obj["name"]`` read and write the same item.
    """

    def __init__(self, **kw):
        super(Model, self).__init__(kw)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails: fall back to items.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(r"Model's object has no attribute %s" %(key)) from None

    def __setattr__(self, key, value):
        # Attribute assignment always stores a dict item.
        self[key] = value

    def getValue(self, key):
        """Return the value stored under ``key``, or None when it is missing."""
        return getattr(self, key, None)

    def getValueOrDefault(self, key):
        """Return the value of field ``key``, filling in the field default if unset.

        If the stored value is None or missing and the field has a default,
        the default is resolved and stored on the instance before being
        returned. A callable default is called with no arguments; any other
        default is used as-is.

        Raises:
            KeyError: if the value is unset and ``key`` is not a mapped field.
        """
        value = getattr(self, key, None)
        if value is None:
            field = self.__mappings__[key]
            if field.default is not None:
                value = field.default() if callable(field.default) else field.default
                logging.info("using default value for %s: %s" %(key, str(value)))
                setattr(self, key, value)
        return value

    async def find(class_type, where=None, args=None, **kw):
        """Select rows of ``class_type``'s table and return them as model instances.

        This is not a classmethod; pass the model class explicitly, e.g.
        ``await User.find(User, where="age > ?", args=[18])``.

        Args:
            class_type: the Model subclass to query.
            where: optional SQL condition placed after ``where``. Bind values
                through ``?`` placeholders and ``args`` rather than formatting
                them into the string.
            args: values for the ``?`` placeholders. The caller's sequence is
                copied, never modified.
            orderBy: optional ``order by`` expression.
            limit: an int (maximum row count) or a 2-tuple
                ``(offset, row_count)``.

        Returns:
            A list of ``class_type`` instances, one per row.

        Raises:
            ValueError: if ``limit`` is truthy but neither an int nor a 2-tuple.
        """
        sql = [class_type.__select__]
        if where:
            sql.append(" where ")
            sql.append(where)
        # Copy: the limit values appended below must not leak into the
        # caller's list, and tuples must be accepted as well.
        args = list(args) if args is not None else []
        orderBy = kw.get("orderBy", None)
        if orderBy:
            sql.append(" order by ")
            sql.append(orderBy)
        limit = kw.get("limit", None)
        if limit:
            if isinstance(limit, int):
                sql.append(" limit ?")
                args.append(limit)
            elif isinstance(limit, tuple) and len(limit) == 2:
                sql.append(" limit ?, ?")
                args.extend(limit)
            else:
                # Previously produced a dangling " limit " clause.
                raise ValueError("invalid limit value: %r" % (limit,))
        ret = await select("".join(sql), args)
        return [class_type(**r) for r in ret]

    async def save(self):
        """INSERT this instance as a new row, filling unset fields from defaults."""
        args = [self.getValueOrDefault(self.__primary_key__)]
        args.extend(map(self.getValueOrDefault, self.__fields__))
        rows = await execute(self.__insert__, args)
        if rows == 1:
            logging.info("insert %s rows" %(rows))
        else:
            logging.warning("failed to insert, affected rows: %s" %(rows))

    async def update(self):
        """UPDATE the row whose primary key equals this instance's primary key.

        Unset fields are also resolved through getValueOrDefault, so a field
        with a default overwrites the stored column with that default.
        """
        args = list(map(self.getValueOrDefault, self.__fields__))
        args.append(self.getValueOrDefault(self.__primary_key__))
        rows = await execute(self.__update__, args)
        if rows == 1:
            logging.info("update %s rows" %(rows))
        else:
            logging.warning("failed to update, affected rows: %s" %(rows))

    async def delete(self):
        """DELETE the row whose primary key equals this instance's primary key."""
        args = [self.getValueOrDefault(self.__primary_key__)]
        rows = await execute(self.__delete__, args)
        if rows == 1:
            logging.info("delete %s rows" %(rows))
        else:
            logging.warning("failed to delete, affected rows: %s" %(rows))
class Field(object):
    """Describe one mapped database column.

    Attributes:
        name: explicit column name, or None to fall back to the attribute name.
        column_type: SQL type of the column, e.g. "varchar(20)".
        primary_key: whether the column is the table's primary key.
        default: value, or zero-argument callable producing one, used when
            the model's value is None.
    """

    def __init__(self, name, column_type, primary_key, default):
        self.name, self.column_type = name, column_type
        self.primary_key, self.default = primary_key, default

    def __str__(self):
        kind = type(self).__name__
        return "<%s %s: %s>" % (kind, self.column_type, self.name)
class StringField(Field):
    """A string column whose SQL type defaults to ``varchar(20)``."""

    def __init__(self, name=None, primary_key=False, default=None, ddl="varchar(20)"):
        super(StringField, self).__init__(
            name=name, column_type=ddl, primary_key=primary_key, default=default
        )
class IntegerField(Field):
    """An integer column whose SQL type defaults to ``smallint``."""

    def __init__(self, name=None, primary_key=False, default=None, ddl="smallint"):
        super(IntegerField, self).__init__(
            name=name, column_type=ddl, primary_key=primary_key, default=default
        )
# Model for the `user` table shown in `desc user`: id varchar(20) primary
# key, name varchar(20), age smallint. The declaration order of the fields
# is the column order in the generated SQL.
class User(Model):
    __table__ = "user"
    id = StringField(primary_key=True, ddl="varchar(20)")
    name = StringField(ddl="varchar(20)")
    # New rows without an explicit age get 99 (applied by getValueOrDefault).
    age = IntegerField(ddl="smallint", default=99)
async def test(loop):
    """End-to-end demo of the ORM against a live MySQL database.

    The demo runs these steps in order:

    * create the pool and run several queries
    * insert, update and delete a freshly generated user
    * delete every user whose name starts with "user"
    * close the pool, which happens even if a step fails
    """
    await create_pool(loop=loop, host="localhost", user="xxx", password="xxx", db="python")
    try:
        ret = await User.find(User)
        print_ret(ret)
        ret = await User.find(User, where="id > 3")
        print_ret(ret)
        ret = await User.find(User, limit=3, orderBy="name desc")
        print_ret(ret)
        ret = await User.find(User, limit=(3, 3), orderBy="id desc")
        print_ret(ret)

        # Timestamp-derived id with the decimal point removed.
        user_id = str(time.time())[4:].replace(".", "")
        user = User(id=user_id, name="user%s" %(user_id))
        await user.save()
        # Bind the id as a parameter instead of formatting it into the SQL.
        ret = await User.find(User, where="id = ?", args=[user_id])
        print_ret(ret)

        user = User(id=ret[0]["id"], name="update_user", age=23)
        await user.update()
        ret = await User.find(User, where="id = ?", args=[user_id])
        print_ret(ret)
        await ret[0].delete()

        # "%%": select() always hands the driver an args tuple, so the SQL
        # goes through %-style substitution and "%%" becomes a literal "%".
        ret = await User.find(User, where="name like 'user%%'")
        print_ret(ret)
        for user in ret:
            await user.delete()
        ret = await User.find(User)
        print_ret(ret)
    finally:
        await destroy_pool()
def print_ret(ret):
    """Print each row as "<id> <name> <age>", one line per row."""
    for row in ret:
        print(row["id"], row["name"], row["age"])
if __name__ == "__main__":
    # Run the demo only when executed as a script. Importing this module must
    # not connect to a database and delete rows.
    # A dedicated loop replaces asyncio.get_event_loop(), which is deprecated
    # when no loop is running. It is closed even if the demo raises.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(test(loop))
    finally:
        loop.close()