Python 数据库对比工具：导出 SQL 差异脚本

最近项目上线时经常需要编写数据库更新脚本，而数据库自带的对比工具不好用，于是自己写了一个 Python 脚本来完成这项工作。

import datetime
import re
import time
from typing import List

import mysql.connector

"""数据库对比脚本"""


# 新增/修改字段SQL
def addOrModifyColumnSql(isAdd, tableName, columnName, columnType, nullable, defaultValue, comment):
    nullable = " null " if nullable == "YES" else " not null "
    addSql = "alter table {} add column {} {} {}".format(tableName, columnName, columnType, nullable)
    if defaultValue is not None and len(defaultValue) > 0:
        addSql += " default " + defaultValue
    if comment is not None and len(comment) > 0:
        addSql += " comment '{}'".format(comment)
    if isAdd:
        addSql = addSql.replace("add column", "modify column")
    addSql += " ;"
    return addSql


# Convert a bytes value returned by the driver into a string.
def decodeObj(s):
    """Decode ``bytes``/``bytearray`` to ``str`` (UTF-8); return any other value unchanged.

    Used on values read from information_schema, which the MySQL driver may
    return as bytes/bytearray depending on version/configuration — assumption,
    verify against your mysql.connector version.
    """
    if isinstance(s, (bytes, bytearray)):
        return s.decode()
    return s


class DatabaseCompare:
    """Compare the table/column structure of two MySQL schemas and export the differences as SQL.

    The *source* database is the reference. Statements are generated so the
    *target* database can be brought up to the source structure: CREATE TABLE
    for tables missing in the target, and ADD/MODIFY COLUMN for columns that are
    missing or differ. The statements are only written to files; nothing is
    executed against the target.
    """

    def __init__(self):
        self.__sourceDb = None
        self.__targetDb = None
        # Lists the tables of the connected schema. filterTablePrefix /
        # filterTableName append extra conditions whose values are bound
        # through %s placeholders collected in __tableFilterParams.
        self.__getCurrentDbTableSql = "select table_name from information_schema.tables where " \
                                      "table_schema=database() "
        self.__tableFilterParams = []
        # Column attributes of one table (bound by name), excluding a column named `id`.
        self.__getTableColumnSql = "select column_name,column_type,is_nullable,column_default,column_comment  " \
                                   "from information_schema.columns where table_schema=database() " \
                                   "and column_name != 'id' and table_name = %s"

    @staticmethod
    def _connect(config):
        """Open a MySQL connection from a dict with host/port/auth_plugin/user/passwd/database.

        Raises KeyError if one of those keys is missing; any other keys are ignored.
        """
        return mysql.connector.connect(
            host=config["host"],  # database host
            port=config["port"],
            auth_plugin=config["auth_plugin"],
            user=config["user"],  # database user
            passwd=config["passwd"],  # database password
            database=config["database"],  # schema to compare
        )

    def sourceDbConnect(self, **args):
        """Connect to the source (reference) database; keys as for _connect."""
        self.__sourceDb = self._connect(args)

    def targetDbConnect(self, **args):
        """Connect to the target database (the one to be updated); keys as for _connect."""
        self.__targetDb = self._connect(args)

    @staticmethod
    def _quoteIdentifier(name):
        """Wrap *name* in backticks (doubling embedded backticks) so reserved words are safe."""
        return "`" + name.replace("`", "``") + "`"

    @staticmethod
    def _escapeLike(value):
        """Escape LIKE wildcards in *value*, using '!' as the escape character."""
        return value.replace("!", "!!").replace("%", "!%").replace("_", "!_")

    def filterTablePrefix(self, tuples: List[str]):
        """Restrict the table list to names starting with any of the given prefixes.

        ``_`` and ``%`` in a prefix are matched literally. In a plain LIKE
        pattern they are wildcards, so ``act_`` would also match ``actX...``.
        Calling this again adds another AND-ed condition.
        """
        if tuples:
            patterns = [self._escapeLike(v) + "%" for v in tuples]
            self.__getCurrentDbTableSql += " and (" + " or ".join(
                ["table_name like %s escape '!'"] * len(patterns)) + ")"
            self.__tableFilterParams.extend(patterns)
        print("print filter table prefix sql: {} params: {}".format(
            self.__getCurrentDbTableSql, self.__tableFilterParams))

    def filterTableName(self, tuples: List[str]):
        """Restrict the table list to exactly the given table names."""
        if tuples:
            self.__getCurrentDbTableSql += " and table_name in(" + ", ".join(["%s"] * len(tuples)) + ")"
            self.__tableFilterParams.extend(tuples)
        print("print filter table sql: {} params: {}".format(
            self.__getCurrentDbTableSql, self.__tableFilterParams))

    def bindSourceDbConnect(self):
        """Return the source database connection (None until sourceDbConnect is called)."""
        return self.__sourceDb

    def bindTargetDbConnect(self):
        """Return the target database connection (None until targetDbConnect is called)."""
        return self.__targetDb

    def dbClose(self):
        """Close both connections; connections that were never opened are skipped."""
        source, target = self.bindSourceDbConnect(), self.bindTargetDbConnect()
        try:
            if source is not None:
                source.close()
        finally:
            if target is not None:
                target.close()

    def _fetchAllTableName(self, conn):
        """Run the (filtered) table-list query on *conn*.

        Returns a list of 1-tuples holding each table name as ``str``, or None
        (after printing the error) if the query fails.
        """
        cursor = None
        try:
            cursor = conn.cursor()
            cursor.execute(self.__getCurrentDbTableSql, tuple(self.__tableFilterParams) or None)
            # Decode names so they are hashable and usable in SQL text.
            return [(decodeObj(row[0]),) for row in cursor.fetchall()]
        except Exception as e:
            print("发生异常 ", e)
            return None
        finally:
            if cursor is not None:
                cursor.close()

    def getSourceDbAllTableName(self):
        """All (filtered) table names of the source database, or None on error."""
        return self._fetchAllTableName(self.bindSourceDbConnect())

    def getTargetDbAllTableName(self):
        """All (filtered) table names of the target database, or None on error."""
        return self._fetchAllTableName(self.bindTargetDbConnect())

    def _loadTableNameSets(self):
        """Return (source, target) table-name sets.

        Raises RuntimeError if either list could not be read.
        """
        aList = self.getSourceDbAllTableName()
        bList = self.getTargetDbAllTableName()
        if aList is None or bList is None:
            raise RuntimeError("failed to read the table list of the source or target database")
        return set(aList), set(bList)

    def getDiffTableName(self):
        """Tables present in the source but missing in the target, as 1-tuples sorted by name."""
        aSet, bSet = self._loadTableNameSets()
        return sorted(aSet - bSet)

    def getDiffTableCreateSql(self):
        """Return ``CREATE TABLE IF NOT EXISTS`` statements for tables missing in the target."""
        diffTableList = self.getDiffTableName()
        resultList = []
        if diffTableList:
            cursor = None
            try:
                cursor = self.bindSourceDbConnect().cursor()
                for t in diffTableList:
                    cursor.execute("show create table " + self._quoteIdentifier(t[0]))
                    res = cursor.fetchone()
                    createSql = str(decodeObj(res[1]))
                    # Only the first occurrence (the statement's own keyword) is rewritten.
                    resultList.append(createSql.replace("CREATE TABLE", "CREATE TABLE IF NOT EXISTS", 1) + ";")
            except Exception as e:
                print("发生异常 ", e)
            finally:
                if cursor is not None:
                    cursor.close()
        print("diff table count {}".format(len(resultList)))
        return resultList

    def _fetchColumns(self, cursor, tableName):
        """Return the column rows of *tableName* with every value decoded."""
        cursor.execute(self.__getTableColumnSql, (tableName,))
        return [tuple(decodeObj(v) for v in row) for row in cursor.fetchall()]

    def getDiffTableColumnSql(self):
        """Return ADD/MODIFY COLUMN statements for tables that exist in both databases.

        A source column missing in the target yields ``add column``. A column
        whose type, nullability, default or comment differs yields
        ``modify column``. Columns that only exist in the target are ignored.
        Always returns a list (empty when nothing differs).
        """
        aSet, bSet = self._loadTableNameSets()
        intersectionList = sorted(aSet & bSet)
        resultList = []
        if intersectionList:
            aCursor = bCursor = None
            try:
                aCursor = self.bindSourceDbConnect().cursor()
                bCursor = self.bindTargetDbConnect().cursor()
                for t in intersectionList:
                    tableName = t[0]
                    aColumns = self._fetchColumns(aCursor, tableName)
                    # Index target columns by name for O(1) lookup.
                    bColumns = {col[0]: col for col in self._fetchColumns(bCursor, tableName)}
                    for column in aColumns:
                        columnName, columnType, nullable, defaultValue, comment = column
                        bColumn = bColumns.get(columnName)
                        if bColumn is None:  # column missing in target -> add
                            resultList.append(addOrModifyColumnSql(False, tableName, columnName, columnType,
                                                                   nullable, defaultValue, comment))
                        elif column != bColumn:  # same column, attributes differ -> modify
                            resultList.append(addOrModifyColumnSql(True, tableName, columnName, columnType,
                                                                   nullable, defaultValue, comment))
            except Exception as e:
                print("发生异常 ", e)
            finally:
                for cursor in (aCursor, bCursor):
                    if cursor is not None:
                        cursor.close()
        print("diff column count {}".format(len(resultList)))
        return resultList

    @staticmethod
    def _appendLines(path, lines):
        """Append each statement on its own line to *path*, encoded as UTF-8."""
        with open(path, "a", encoding="utf-8") as fo:
            fo.writelines(line + "\n" for line in lines)

    def exportDiffSqlFile(self, aPath, bPath, exportType):
        """Write the diff SQL to file(s), then close both connections.

        Args:
            aPath: file that receives the CREATE TABLE statements (appended to).
            bPath: file that receives the ADD/MODIFY COLUMN statements (appended to).
            exportType: 1 = tables and columns, 2 = tables only, 3 = columns only;
                any other value exports nothing.
        """
        titles = {
            1: "开始导出差异表和表字段SQL ==========>>>> ",
            2: "开始导出差异表SQL ==========>>>> ",
            3: "开始导出差异表字段SQL ==========>>>> ",
        }
        if exportType in (1, 2, 3):
            print(titles[exportType])
        startTime = datetime.datetime.now()
        try:
            if exportType in (1, 2):
                self._appendLines(aPath, self.getDiffTableCreateSql())
                print("Export Success: table file: {}".format(aPath))
            if exportType in (1, 3):
                self._appendLines(bPath, self.getDiffTableColumnSql())
                print("Export Success: column file:{}".format(bPath))
        finally:
            self.dbClose()
            print("总耗时:{}".format((datetime.datetime.now() - startTime).total_seconds()))


if __name__ == "__main__":
    today = time.strftime("%Y%m%d%H%M", time.localtime(time.time()))
    createTableSqlFilePath = "G:\\data\\pytest\\createTableSql-{}.sql".format(today)
    tableColumnSqlFilePath = "G:\\data\\pytest\\tableColumnSql-{}.sql".format(today)

    tool = DatabaseCompare()
    # 过滤指定表前缀
    tool.filterTablePrefix(["act_"])
    # 过滤指定表
    # tool.filterTableName(["rbp_user_column_language"])
    tool.sourceDbConnect(
        host="localhost",  # 数据库主机地址
        port=3306,
        auth_plugin="mysql_native_password",
        user="root",
        passwd="123456",
        database="db01"
    )
    tool.targetDbConnect(
        host="localhost",
        port=3306,
        auth_plugin="mysql_native_password",
        user="root",
        passwd="123456",
        database="db02"
    )
    # 导出差异表:1-全部;2-表;3-字段
    tool.exportDiffSqlFile(createTableSqlFilePath, tableColumnSqlFilePath, 1)

总结

本文简单实现了两个数据库之间表结构和表字段的对比，并导出对应的 SQL 脚本；后续计划增加差异数据的导出。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值