本文参考并修改了廖雪峰老师的 ORM 代码，
只实现了最基本的功能。
数据库连接模块
import pymysql
class BaseDB:
    """Thin wrapper around a single PyMySQL connection.

    The connection is opened eagerly in ``__init__`` and reused by every
    ``execute`` call. Rows come back in the shape produced by
    ``cursor_class`` (dicts with the default ``DictCursor``).
    """

    def __init__(self, user, password, database='', host='127.0.0.1', port=3306, charset='utf8', cursor_class=pymysql.cursors.DictCursor):
        self.user = user
        self.password = password
        self.host = host
        self.database = database
        self.port = port
        self.charset = charset
        self.cursor_class = cursor_class
        self.conn = self.connect()

    def connect(self):
        """Open and return a new PyMySQL connection from the stored settings."""
        # ``password``/``database`` are the current keyword names; ``passwd``
        # and ``db`` are deprecated aliases in PyMySQL 1.x.
        return pymysql.connect(
            host=self.host,
            user=self.user,
            port=self.port,
            password=self.password,
            database=self.database,
            charset=self.charset,
            cursorclass=self.cursor_class,
        )

    def execute(self, sql, params=None):
        """Run one statement and return ``(affected_rows, fetched_rows)``.

        Args:
            sql: SQL text using PyMySQL's ``%s`` placeholders.
            params: sequence/mapping of values to bind, or None for no binding.

        The transaction is committed on success and rolled back (then the
        exception re-raised) if executing or fetching fails.
        """
        # Do not use ``with self.conn as cursor``: since PyMySQL 1.0 a
        # Connection used as a context manager yields the connection itself
        # (not a cursor) and closes it on exit, breaking every later call.
        try:
            with self.conn.cursor() as cursor:
                rows = cursor.execute(sql, params)
                result = cursor.fetchall()
        except Exception:
            self.conn.rollback()
            raise
        self.conn.commit()
        return rows, result
Field 模块
class Field(object):
    """Base description of one mapped column.

    Holds the column name (may be None), the SQL column type string, whether
    the column is the primary key, and its default value (a constant or a
    zero-argument callable, as consumed by ``Model.getValueOrDefault``).
    """

    def __init__(self, name, column_type, primary_key, default):
        self.name, self.column_type = name, column_type
        self.primary_key, self.default = primary_key, default

    def __str__(self):
        kind = type(self).__name__
        return '<%s, %s:%s>' % (kind, self.column_type, self.name)
class StringField(Field):
    """String column; its SQL type is ``ddl`` (``varchar(100)`` by default)."""

    def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):
        # ``ddl`` becomes the base Field's ``column_type``.
        super().__init__(
            name=name,
            column_type=ddl,
            primary_key=primary_key,
            default=default,
        )
Model 部分（核心）：
class ModelMetaclass(type):
    """Metaclass that maps ``Field`` class attributes onto a database table.

    For every class other than ``Model`` itself it:

    * collects the ``Field`` attributes into ``__mappings__`` and removes them
      from the class namespace, so attribute reads fall through to the
      instance's dict values;
    * requires exactly one field declared with ``primary_key=True``;
    * builds SQL templates (``__select__``, ``__insert__``, ``__update__``,
      ``__delete__``) that use ``?`` placeholders;
    * attaches a ``BaseDB`` connection as the class attribute ``db``.

    Optional class attributes read here: ``__table__`` (defaults to the class
    name), ``__database__`` (defaults to ``'default'``), and
    ``__db_user__`` / ``__db_password__`` (default ``'root'`` / ``'123456'``).

    Raises:
        RuntimeError: if no primary key, or more than one, is declared.
    """

    def __new__(cls, name, bases, attrs):
        # The abstract Model base class itself maps to no table.
        if name == 'Model':
            return type.__new__(cls, name, bases, attrs)
        tableName = attrs.get('__table__', None) or name
        mappings = dict()
        fields = []
        primaryKey = None
        for k, v in attrs.items():
            if isinstance(v, Field):
                mappings[k] = v
                if v.primary_key:
                    if primaryKey:
                        raise RuntimeError('Duplicate primary key for field: %s' % k)
                    primaryKey = k
                else:
                    fields.append(k)
        if not primaryKey:
            raise RuntimeError('Primary key not found.')
        # Remove the Field objects so `instance.attr` resolves via
        # Model.__getattr__ to the stored dict value.
        for k in mappings:
            attrs.pop(k)
        escaped_pk = '`%s`' % primaryKey
        escaped_fields = ['`%s`' % f for f in fields]
        # Select lists the key first; insert lists the key last (save() passes
        # its arguments in that same order). Built as lists so the SQL stays
        # valid even when the model has no non-key fields.
        select_columns = [escaped_pk] + escaped_fields
        insert_columns = escaped_fields + [escaped_pk]
        # One "?" per insert column; built inline instead of relying on a
        # helper that is not defined in this module.
        insert_placeholders = ', '.join(['?'] * len(insert_columns))
        attrs['__mappings__'] = mappings  # attribute name -> Field
        attrs['__table__'] = tableName
        attrs['__primary_key__'] = primaryKey  # primary-key attribute name
        attrs['__fields__'] = fields  # non-key attribute names, in declaration order
        attrs['__select__'] = 'select %s from `%s`' % (', '.join(select_columns), tableName)
        attrs['__insert__'] = 'insert into `%s` (%s) values (%s)' % (tableName, ', '.join(insert_columns), insert_placeholders)
        attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join('`%s`=?' % (mappings[f].name or f) for f in fields), primaryKey)
        attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)
        attrs.setdefault('__database__', 'default')
        # NOTE(review): each model class opens its own connection at
        # class-definition time; credentials may be overridden per class.
        attrs['db'] = BaseDB(
            attrs.get('__db_user__', 'root'),
            attrs.get('__db_password__', '123456'),
            database=attrs['__database__'],
        )
        return type.__new__(cls, name, bases, attrs)
class Model(dict, metaclass=ModelMetaclass):
    """Base class for ORM models; each instance is a dict of column values.

    Attribute access is forwarded to dict items, so ``obj.name`` and
    ``obj['name']`` are equivalent. Subclasses declare ``Field`` class
    attributes; ``ModelMetaclass`` turns them into SQL templates and a class
    level ``db`` connection.
    """

    def __init__(self, **kw):
        super(Model, self).__init__(**kw)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails: read the dict item.
        try:
            return self[key]
        except KeyError:
            raise AttributeError("'%s' instance has no attribute '%s'" % (self.__class__.__name__, key)) from None

    def __setattr__(self, key, value):
        # Attribute assignment stores into the dict, not the instance __dict__.
        self[key] = value

    def getValue(self, key):
        """Return the value stored under ``key``, or None if it is absent."""
        return getattr(self, key, None)

    def getValueOrDefault(self, key):
        """Return the value for ``key``, filling in the field default if unset.

        A callable default is called; the resulting value is stored on the
        instance so later reads see the same value.

        Raises:
            KeyError: if the value is missing/None and ``key`` is not a
                mapped field.
        """
        value = getattr(self, key, None)
        if value is None:
            field = self.__mappings__[key]
            if field.default is not None:
                value = field.default() if callable(field.default) else field.default
                setattr(self, key, value)
        return value

    @classmethod
    def filter(cls, where='', *args):
        """Select rows from the model's table, optionally filtered.

        Args:
            where: raw SQL condition placed after ``where`` (empty = all rows).
            *args: values bound to placeholders in ``where``. When given,
                ``?`` placeholders are converted to PyMySQL's ``%s`` (the same
                convention ``save`` uses), so a literal ``%`` in ``where`` must
                then be written as ``%%``. Without args the SQL is sent as-is.

        Returns:
            The fetched rows as produced by the connection's cursor class
            (dicts with the default ``DictCursor``).
        """
        sql = 'select * from `%s`' % cls.__table__
        if where:
            sql += ' where %s' % where
        params = None
        if args:
            sql = sql.replace('?', '%s')
            params = args
        print(sql)
        _, res = cls.db.execute(sql, params)
        print(res)
        return res

    def save(self):
        """Insert this instance as a new row and return the affected-row count.

        Values are collected in ``__fields__`` order followed by the primary
        key, matching the column order of ``__insert__``; unset values are
        filled from field defaults.
        """
        args = [self.getValueOrDefault(k) for k in self.__fields__]
        args.append(self.getValueOrDefault(self.__primary_key__))
        print(self.__insert__, args)
        # The template uses "?" placeholders; PyMySQL expects "%s".
        sql = self.__insert__.replace('?', '%s')
        rows, _ = self.db.execute(sql, args)
        return rows
测试
SQL 脚本：新建数据库和表（保存为 schema.sql 后，可通过 `mysql -u root -p < schema.sql` 执行）
-- Recreate the demo database from scratch (drops any existing test3!).
-- utf8mb4 is MySQL's full UTF-8; plain `utf8` is the 3-byte utf8mb3 alias
-- that cannot store 4-byte characters such as emoji.
drop database if exists test3;
create database test3 default character set utf8mb4;
use test3;
-- Table backing the User model: every column is a non-null varchar(50).
create table users (
    `id` varchar(50) not null,
    `passwd` varchar(50) not null,
    `name` varchar(50) not null,
    primary key (`id`)
) engine=innodb default charset=utf8mb4;
Python 测试脚本（假设上面的 ORM 代码保存为 orm.py）：
from orm import Model,StringField
class User(Model):
    # Model for the `users` table in database `test3`.
    # ModelMetaclass reads these class attributes: `id` is the primary key,
    # and `passwd`/`name` become __fields__ in declaration order.
    __table__ = 'users'
    __database__='test3'
    id = StringField(primary_key=True, ddl='varchar(50)')
    passwd = StringField(ddl='varchar(50)')
    name = StringField(ddl='varchar(50)')
if __name__ == '__main__':
    # Demo: insert one user, then read back the whole table.
    # `id` is the primary key, so running this a second time without
    # recreating the table fails with a duplicate-key error.
    # The id is passed as a string to match the varchar(50) column.
    u = User(id='0', passwd='123456', name='xiaoming')
    u.save()
    a = User.filter()
效果截图: