徒手撸一个ORM


mysql_single.py 文件内容如下：

import pymysql

class Mysql:
    """Thin pymysql wrapper that returns rows as dicts and caches one shared instance.

    Use ``Mysql.singleton()`` to get the shared instance; calling ``Mysql()``
    directly opens a new connection every time.
    """

    # Shared instance cached by singleton() (name-mangled to _Mysql__instance).
    __instance = None

    def __init__(self, host='127.0.0.1', port=3306, user='root',
                 password='123', database='youku', charset='utf8',
                 autocommit=True):
        """Open a connection plus a DictCursor.

        The defaults are the original hard-coded settings; pass explicit
        values to connect somewhere else.
        """
        self.conn = pymysql.connect(
            host=host,
            port=port,
            user=user,
            password=password,
            database=database,
            charset=charset,
            autocommit=autocommit,
        )
        # DictCursor: fetchall() yields one {column: value} dict per row.
        self.cursor = self.conn.cursor(cursor=pymysql.cursors.DictCursor)

    def select(self, sql, args=None):
        """Execute a query and return all fetched rows."""
        self.cursor.execute(sql, args)
        return self.cursor.fetchall()

    def execute(self, sql, args=None):
        """Execute a write statement and return the affected-row count.

        This is best-effort: on failure the error is printed and 0 is
        returned. The error is not re-raised, and no unbound local is
        touched.
        """
        try:
            self.cursor.execute(sql, args)
        except Exception as exc:  # not BaseException: let Ctrl-C / SystemExit through
            print(exc)
            return 0
        return self.cursor.rowcount

    def close_db(self):
        """Close cursor and connection, and drop them from the singleton cache."""
        self.cursor.close()
        self.conn.close()
        # Otherwise singleton() would keep handing out this closed connection.
        if type(self).__instance is self:
            type(self).__instance = None

    @classmethod
    def singleton(cls):
        """Return the shared instance, creating (and connecting) it on first use."""
        if cls.__instance is None:
            cls.__instance = cls()
        return cls.__instance

orm.py 文件内容如下（它通过 import mysql_single 复用上面的数据库连接类）：

import mysql_single

class BaseField:
    """Metadata for one mapped column: name, SQL type, primary-key flag, default."""

    def __init__(self, name, column_type, primary_key, default):
        # `name` is the column name used in generated SQL and as the row-dict key.
        self.name, self.column_type = name, column_type
        self.primary_key, self.default = primary_key, default


class StringField(BaseField):
    """Text column; by default varchar(200), not a primary key, default value None."""

    def __init__(self, name, column_type='varchar(200)', primary_key=False,
                 default=None):
        super().__init__(name=name, column_type=column_type,
                         primary_key=primary_key, default=default)

class IntegerField(BaseField):
    """Integer column; by default type int, not a primary key, default value 0."""

    def __init__(self, name, column_type='int', primary_key=False, default=0):
        super().__init__(name=name, column_type=column_type,
                         primary_key=primary_key, default=default)


class ModelMeta(type):
    """Metaclass that turns BaseField class attributes into ORM metadata.

    For every class except one literally named 'Model' it:
      * sets ``table_name`` (falling back to the class name when unset/empty),
      * removes the BaseField attributes from the class namespace and stores
        them in ``mappings`` (attribute name -> field),
      * stores the attribute name of the primary-key field in ``primary_key``.

    Raises TypeError('主键重复') for more than one primary key and
    TypeError('没有主键') when there is none.
    """

    def __new__(cls, name, bases, attr):
        # The base class itself is built as-is.
        if name == 'Model':
            return type.__new__(cls, name, bases, attr)

        table_name = attr.get('table_name') or name
        fields = {key: val for key, val in attr.items() if isinstance(val, BaseField)}

        pk_attr = None
        for key, field in fields.items():
            if not field.primary_key:
                continue
            if pk_attr:
                raise TypeError('主键重复')
            pk_attr = key

        # Strip the field objects from the class; otherwise `obj.name` would
        # find the field on the class instead of the row value.
        for key in fields:
            attr.pop(key)
        if not pk_attr:
            raise TypeError('没有主键')

        attr.update(table_name=table_name, primary_key=pk_attr, mappings=fields)
        return type.__new__(cls, name, bases, attr)

class Model(dict, metaclass=ModelMeta):
    """Base class for table rows: a dict whose keys are also attributes.

    Subclasses declare BaseField class attributes; ModelMeta converts them
    into the class-level ``table_name``, ``primary_key`` (attribute name of
    the key field) and ``mappings`` (attribute name -> field).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __getattr__(self, item):
        # Only reached when normal lookup fails, so class attributes such as
        # table_name / mappings never come through here.
        try:
            return self[item]
        except KeyError:
            # Must be AttributeError (not KeyError) so that
            # getattr(obj, name, default) and hasattr() work.
            raise AttributeError(
                '%r object has no attribute %r' % (type(self).__name__, item)
            ) from None

    def __setattr__(self, key, value):
        # Attribute assignment sets a column value in the underlying dict.
        self[key] = value

    @staticmethod
    def _db():
        """Return the shared Mysql wrapper."""
        # Call the classmethod on the class: Mysql().singleton() would first
        # open (and leak) a fresh connection on every call.
        return mysql_single.Mysql.singleton()

    @classmethod
    def _select_rows(cls, filters):
        """Run ``select *`` on this table, AND-ing ``column = value`` for every filter."""
        sql = 'select * from %s' % cls.table_name
        if not filters:
            return cls._db().select(sql)
        # Values travel as query parameters; only column names are inlined.
        conditions = ' and '.join('%s = %%s' % column for column in filters)
        return cls._db().select(sql + ' where ' + conditions, list(filters.values()))

    @classmethod
    def select_one(cls, **kwargs):
        """Return the first row matching all ``column=value`` filters, or None.

        Raises:
            TypeError: if called without any filter.
        """
        if not kwargs:
            raise TypeError('select_one() needs at least one column=value filter')
        rows = cls._select_rows(kwargs)
        return cls(**rows[0]) if rows else None

    @classmethod
    def select_many(cls, **kwargs):
        """Return all rows matching every filter (the whole table when none given).

        Returns None, not an empty list, when nothing matches.
        """
        rows = cls._select_rows(kwargs)
        if not rows:
            return None
        return [cls(**row) for row in rows]

    def update(self):
        """Write all non-key fields back to the row identified by the primary key.

        Unset fields fall back to the field's default value.

        Returns:
            The affected-row count reported by Mysql.execute.

        Raises:
            ValueError: if the primary-key value is not set on this object.
        """
        pk_field = self.mappings[self.primary_key]
        pk_value = getattr(self, pk_field.name, None)
        if pk_value is None:
            raise ValueError('cannot update %s: primary key is not set' % self.table_name)

        assignments = []
        values = []
        for field in self.mappings.values():
            if field.primary_key:
                continue
            assignments.append('%s = %%s' % field.name)
            values.append(getattr(self, field.name, field.default))

        # The key value is bound as a parameter instead of being pasted into
        # the SQL text (which broke on strings and allowed injection).
        sql = 'update %s set %s where %s = %%s' % (
            self.table_name, ', '.join(assignments), pk_field.name)
        values.append(pk_value)
        return self._db().execute(sql, values)

    def save(self):
        """Insert this object as a new row; the primary-key column is omitted.

        Unset fields fall back to the field's default value.

        Returns:
            The affected-row count reported by Mysql.execute.
        """
        columns = []
        values = []
        for field in self.mappings.values():
            if field.primary_key:
                continue  # presumably generated by the DB (auto-increment) — TODO confirm
            columns.append(field.name)
            values.append(getattr(self, field.name, field.default))
        sql = 'insert into %s (%s) values (%s)' % (
            self.table_name, ', '.join(columns), ', '.join(['%s'] * len(columns)))
        return self._db().execute(sql, values)


class User(Model):
    """Row of the `user` table: integer primary key `id`, string `name` and `password`."""

    table_name = 'user'
    id = IntegerField(name='id', primary_key=True)
    name = StringField(name='name')
    password = StringField(name='password')

if __name__ == '__main__':
    # Other usage examples:
    #     user1 = User.select_one(id=1); print(user1)
    #     user1.name = 'ccc'; user1.update()
    #     print(User.select_many())
    User(name='monicx', password='4333').save()







阅读更多

没有更多推荐了,返回首页