MySQL简单技巧(四):教你轻松用information_schema表+python实现表结构同步(下)——实战篇

前文

  紧接着上篇MySQL简单技巧(三):教你轻松用information_schema表+python实现表结构同步(上)——理论篇,这篇进入实战,也就是基于Python来实现一套表结构同步!
  代码放于github,链接:https://github.com/sandwu/db_diff.git

代码说明

  代码是基于前端代码修改的,为了在前端显示友好,所以显示格式是:

已选择table远程table差异sql差异sql说明
table_nametable_name(无则为空)alter或者create语句提醒是否有drop语句等

  回顾下上文的表结构同步原理六步:
在这里插入图片描述
  首先需要基于pymysql实现一个类,用于从数据库取元数据,代码实现如下:

class EasyPyMySql:
    def __init__(self, config):
        self.conn = None
        self.cursor = None
        self.result = None
        self.errMsg = None #用于记录错误信息
        self.config = {
            "host": None,
            "port": None,
            "user": None,
            "password": None,
            "database": None,
            "charset": "utf8",
        }
        for k, v in config.items():
            if k in self.config:
                self.config[k] = v

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.py_close()

    def getConnection(self):
        config = self.config
        try:
            self.conn = pymysql.connect(**config)
            self.conn.autocommit = True
        except Exception as err:
            print(err)
            self.errMsg = "Connection error, please contact the Administrator !"
            raise
        return self.conn

    # 获取普通游标
    def getCursor(self):
        if self.conn is None:
            self.conn = self.getConnection()
        return self.conn.cursor(cursor=pymysql.cursors.DictCursor)

    # 执行sql语句
    def py_execute(self, sql, param):
        if self.cursor is None:
            self.cursor = self.getCursor()
        try:
            self.cursor.execute(sql, param)
            self.rows = self.cursor.fetchall()
        except Exception as err:
            self.errMsg = str(err)
            print(err)
            raise
        return self.rows

    # 关闭连接
    def py_close(self):
        if self.cursor is not None:
            self.cursor.close()
        if self.conn is not None:
            self.conn.close()

  然后根据上篇,依次取出我们需要的字段:

def get_table_info(table_list, conf):
    """
    根据本地的table_list,依次找到远程数据库的表、字段、索引等信息,通过in table_list只查询存在的表
    :param table_list:
    :return:
    """
    db_name = conf["database"]

    if len(table_list) >1:
        sql_columns = "select TABLE_NAME,COLUMN_NAME,ORDINAL_POSITION,COLUMN_DEFAULT,IS_NULLABLE,DATA_TYPE," \
                      "COLUMN_TYPE,EXTRA,COLUMN_COMMENT " \
                      "from information_schema.columns where table_schema=%s and table_name in %s"
        #生产环境 暂去除索引注释比较:index_comment
        sql_statistics = "select TABLE_NAME,NON_UNIQUE,INDEX_NAME,COLUMN_NAME,NULLABLE,INDEX_TYPE,SUB_PART " \
                     "from information_schema.statistics where table_schema=%s and table_name in %s"

        #单独取出参数,防sql注入
        sql_columns_params = (db_name, table_list)
        sql_statistics_params = (db_name, table_list)
    else:
        sql_columns = "select TABLE_NAME,COLUMN_NAME,ORDINAL_POSITION,COLUMN_DEFAULT,IS_NULLABLE,DATA_TYPE," \
                      "COLUMN_TYPE,EXTRA,COLUMN_COMMENT " \
                      "from information_schema.columns where table_schema=%s and table_name=%s"
        # 生产环境 暂去除索引注释比较:index_comment
        sql_statistics = "select TABLE_NAME,NON_UNIQUE,INDEX_NAME,COLUMN_NAME,NULLABLE,INDEX_TYPE,SUB_PART " \
                         "from information_schema.statistics where table_schema=%s and table_name=%s"

        sql_columns_params = (db_name, table_list[0])
        sql_statistics_params = (db_name, table_list[0])

    try:
        db = EasyPyMySql(conf)
        res_columns = db.py_execute(sql_columns,sql_columns_params)  #防SQL注入
        res_statistics = db.py_execute(sql_statistics,sql_statistics_params)
        db.py_close()
        return res_columns, res_statistics
    except Exception as e:
        print(e)
    return None,None #错误返回

  在取得元数据的时候,可以根据公司的sql规范来取出所需要比对的字段。比如columns表的ORDINAL_POSITION是服务于定位字段位置的,虽然我也取出了,不过在对比的时候我是直接跳过的。再比如statistics表的index_comment字段,因为不care索引注释,所以这个字段我直接不取。所以可以灵活更改需要取的字段,而关于pymysql尽量使用防SQL注入的形式来取数据
  在取出数据后,数据结构是列表嵌套字典,需要转换为{table:{}}这种字典嵌套字典的形式,这样可以通过dict.keys()获取本地、远程所有的table,还可以通过table快速定位到数据,进行对应的比较。转换如下:

def table_columns_to_dict(table_strucs):
    """
    将数据库取出的表结构转换为{table_name:{column_name:{}}}格式
    :param table_strucs:
    :return:
    """
    tmp = {}
    for table_struc in table_strucs:
        if table_struc["TABLE_NAME"] in tmp:
            tmp[table_struc["TABLE_NAME"]][table_struc["COLUMN_NAME"]] = table_struc
        else:
            tmp[table_struc["TABLE_NAME"]] = {}
            tmp[table_struc["TABLE_NAME"]][table_struc["COLUMN_NAME"]] = table_struc
    return tmp

def table_statistics_to_dict(table_strucs):
    """
    将数据库取出的表结构转换为{table_name:{index_name:{}}}格式
    :param table_strucs:
    :return:
    """
    tmp = {}
    for table_struc in table_strucs:
        if table_struc["TABLE_NAME"] in tmp:
            if table_struc["INDEX_NAME"] in tmp[table_struc["TABLE_NAME"]]:
                # 如果是联合索引,则将联合索引的字段合并在一起
                tmp[table_struc["TABLE_NAME"]][table_struc["INDEX_NAME"]]["COLUMN_NAME"] += ",%s" % table_struc["COLUMN_NAME"]
            else:
                tmp[table_struc["TABLE_NAME"]][table_struc["INDEX_NAME"]] = table_struc
        else:
            tmp[table_struc["TABLE_NAME"]] = {}
            tmp[table_struc["TABLE_NAME"]][table_struc["INDEX_NAME"]] = table_struc
    return tmp

  然后进入比对阶段,先比较表是否存在,通过table in remote_dict.keys(),如果不存在则生成create语句,通过show create table table_name获取,这里要注意auto_increment需要舍弃,因为这里是标记当前的自增值到了哪,作为新表是不需要的。如果存在,则进入比较。在比较这块是通过set的对称差集来实现的,原先是想要用hash来实现,即对所有字典通过hash来获取一个值,直接比较该值是否相等就能确认两张表是否有差异(这里利用到了hash的数据标识的功能,关于hash的七大功能,待下篇分享)。但没找到合适的hash函数,所以直接基于set的对称差集来实现,通过set1 ^ set2来获取结果,代码如下:

    for select_column_name in select_column_dict[select_table].keys():
        if select_column_name in remote_column_dict[select_table].keys():
            #执行对比表字段,获取对应table的字段数据和索引数据
            select_column_name_dict = select_column_dict[select_table][select_column_name]
            remote_column_name_dict = remote_column_dict[select_table][select_column_name]
            differ = set(select_column_name_dict.items()) ^ set(remote_column_name_dict.items())

            if not differ:
                continue
            else:
                # mac自带的mysql没有该字段,所以跳过,后续上线可以废除该判断!˚
                if differ == {('GENERATION_EXPRESSION', '')}:
                    continue
                #用于判断diff=ORDINAL_POSITION的情况,该参数服务于列所在的位置,暂不支持跳过
                for diff in differ:
                    if diff[0] =='ORDINAL_POSITION' :
                        ORDINAL_POSITION_flag = 1
                        break
                    else:
                        ORDINAL_POSITION_flag = 0
                if ORDINAL_POSITION_flag:
                    continue
                row = select_column_name_dict
                change_sql += " modify `%(COLUMN_NAME)s` %(COLUMN_TYPE)s" % row
                change_sql += table_diff_create_sql(row)
                alter_modify_column_msg += "`%(COLUMN_NAME)s`, " % row
        else:
            #没有该字段,生成该sql语句
            row = select_column_dict[select_table][select_column_name]
            change_sql += " add column `%(COLUMN_NAME)s` %(COLUMN_TYPE)s" % row
            change_sql += table_diff_create_sql(row)
            alter_add_column_msg += "`%(COLUMN_NAME)s`, " % row

  有差异就是生成modify的结果,无差异则生成add的结果;至于statistics也是一样的道理,详细的可以通过代码查阅。因为MySQL的字段类型过多,所以不同的字段处理方式也是不同的,这里定义了一个函数来处理类型,代码如下:

def table_diff_create_sql(row):
    """
    将表对比的结果生成对应的sql语句
    :param row:
    :return:
    """
    sql = ""
    if "unsigned" in row["COLUMN_TYPE"]:
        sql += " unsigned"
    if "zerofill" in row["COLUMN_TYPE"]:
        sql += " zerofill"
    if row["IS_NULLABLE"] == "NO":
        sql += " not null"
    if row["COLUMN_DEFAULT"]:
        if row.get("DATA_TYPE") in ["char", "varchar", "datetime", "date", "timestamp", "text", "longtext"]:
            if 'CURRENT_TIMESTAMP' in row["COLUMN_DEFAULT"] and row['EXTRA'] == 'on update CURRENT_TIMESTAMP':
                sql += " default CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP "
            elif 'CURRENT_TIMESTAMP' in row["COLUMN_DEFAULT"]:
                sql += " default %(COLUMN_DEFAULT)s" % (row)
            else:
                sql += " default '%(COLUMN_DEFAULT)s'" % (row)
        else:

            sql += " default %(COLUMN_DEFAULT)s" % (row)
    else:
        if row.get("DATA_TYPE") in ["char", "varchar", "datetime", "date", "timestamp", "text", "longtext"]:
            if row.get("COLUMN_DEFAULT") == '""' or row.get("COLUMN_DEFAULT") == "''" or row.get(
                    "COLUMN_DEFAULT") == "" or row.get("COLUMN_DEFAULT") == '':
                sql += " default '' "
            # 注意这里无法区分是否有default值,因为设置default null和不设置default值读取COLUMN_DEFAULT都是None
            # 所以默认添加default Null
            else:
                # timestamp默认值不能是NULL,必须要加上NULL在前
                if row.get("DATA_TYPE") == "timestamp":
                    sql += " NULL default NULL "
                else:
                    sql += " default NULL "
        # 判断是主键的情况或者自增
        elif row.get("COLUMN_KEY") == "PRI" or "auto_increment" in row["EXTRA"]:
            pass
        else:
            sql += " default NULL "
    if "auto_increment" in row["EXTRA"]:
        sql += " auto_increment"
    if row["COLUMN_COMMENT"]:
        sql += " comment '%(COLUMN_COMMENT)s'" % (row)
    sql += ","
    return sql

  基于此就可以实现一套基本的表结构对比~

总结

  代码实现的相对粗糙,加上注释大概400行左右,有所建议,可以互相探讨~

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值