Python 数据库对比工具,导出SQL脚本
最近项目上线时经常需要编写数据库更新脚本，数据库自带的对比工具不好用，于是自己写了一个对比脚本。
import datetime
import time
from typing import List
import mysql.connector
"""数据库对比脚本"""
# Base column types whose COLUMN_DEFAULT is a plain string literal that must be
# quoted when written back into DDL. TEXT/BLOB types are excluded: MySQL only
# allows parenthesized expression defaults on them.
_STRING_COLUMN_TYPES = ("char", "varchar", "enum", "set")


def _quoteSqlString(text):
    """Return *text* as a single-quoted MySQL string literal (backslashes and quotes escaped)."""
    return "'" + text.replace("\\", "\\\\").replace("'", "''") + "'"


def addOrModifyColumnSql(isAdd, tableName, columnName, columnType, nullable, defaultValue, comment):
    """Build an ``alter table`` statement that adds or modifies a single column.

    NOTE: the flag is inverted relative to its name -- ``isAdd=True`` emits
    ``modify column`` and ``isAdd=False`` emits ``add column``. Callers in this
    file rely on that (True for a column whose attributes differ, False for a
    column missing from the target), so the meaning is preserved.

    :param tableName: table to alter.
    :param columnName: column to add or modify.
    :param columnType: full column type as reported by information_schema, e.g. ``varchar(64)``.
    :param nullable: IS_NULLABLE value; ``"YES"`` means the column is nullable.
    :param defaultValue: COLUMN_DEFAULT value, or None for no ``default`` clause.
    :param comment: COLUMN_COMMENT value; omitted when None or empty.
    :return: the statement, terminated with ``" ;"``.
    """
    # Choose the verb up front. Replacing text afterwards would also rewrite
    # "add column" occurring inside the comment or the default value.
    action = "modify column" if isAdd else "add column"
    nullSql = " null " if nullable == "YES" else " not null "
    sql = "alter table {} {} {} {} {}".format(tableName, action, columnName, columnType, nullSql)
    if defaultValue is not None:
        baseType = str(columnType).split("(", 1)[0].strip().lower()
        if baseType in _STRING_COLUMN_TYPES and not defaultValue.startswith("'"):
            # MySQL reports string defaults unquoted (and an empty default as ""),
            # so quote them to produce valid DDL. A value that already starts with
            # a quote is assumed to be pre-quoted (MariaDB style) -- TODO confirm.
            # NOTE(review): expression defaults on string columns (DEFAULT_GENERATED)
            # would be quoted as literals here.
            sql += " default " + _quoteSqlString(defaultValue)
        elif len(defaultValue) > 0:
            # Numbers and expressions such as CURRENT_TIMESTAMP are used verbatim.
            sql += " default " + defaultValue
    if comment:
        sql += " comment " + _quoteSqlString(comment)
    return sql + " ;"
# Convert byte strings to text.
def decodeObj(s):
    """Return *s* decoded as UTF-8 text if it is ``bytes``/``bytearray``; otherwise return *s* unchanged.

    Used on values read from the database cursors, which may come back as byte
    strings depending on the connector version/configuration (assumption -- verify).
    """
    if isinstance(s, (bytes, bytearray)):
        return s.decode("utf-8")
    return s
class DatabaseCompare:
    """Compare the schemas of two MySQL databases and export the differences as SQL scripts.

    The *source* database is the reference. The tool reports:

    * tables present in the source but not in the target, as ``CREATE TABLE IF NOT EXISTS``;
    * columns of tables present in both that are missing from, or differ in, the
      target, as ``alter table ... add/modify column`` statements.

    Tables and columns that exist only in the target are not reported, and the
    ``id`` column is always skipped.
    """

    def __init__(self):
        self.__sourceDb = None
        self.__targetDb = None
        # Lists the current database's tables. filterTablePrefix/filterTableName
        # append extra conditions whose bound values live in __tableFilterParams.
        self.__getCurrentDbTableSql = ("select table_name from information_schema.tables where "
                                       "table_schema=database() ")
        self.__tableFilterParams = []
        # Attributes of every column (except `id`) of one table, in definition order.
        self.__getTableColumnSql = ("select column_name,column_type,is_nullable,column_default,column_comment "
                                    "from information_schema.columns where table_schema=database() "
                                    "and column_name != 'id' and table_name = %s order by ordinal_position")

    @staticmethod
    def _openConnection(config):
        """Open a mysql.connector connection from the keyword arguments passed to *DbConnect."""
        return mysql.connector.connect(
            host=config["host"],  # database host
            port=config["port"],
            auth_plugin=config["auth_plugin"],
            user=config["user"],  # database user
            passwd=config["passwd"],  # database password
            database=config["database"],  # schema to compare
        )

    def sourceDbConnect(self, **args):
        """Connect to the source (reference) database.

        Requires keys: host, port, auth_plugin, user, passwd, database.
        """
        self.__sourceDb = self._openConnection(args)

    def targetDbConnect(self, **args):
        """Connect to the target database (same keys as sourceDbConnect)."""
        self.__targetDb = self._openConnection(args)

    @staticmethod
    def _escapeLikePattern(text):
        """Escape LIKE wildcards so *text* matches literally (backslash is MySQL's default LIKE escape)."""
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @staticmethod
    def _quoteIdentifier(name):
        """Return *name* as a backtick-quoted MySQL identifier."""
        return "`" + name.replace("`", "``") + "`"

    # Restrict the compared tables to names starting with one of the given prefixes.
    def filterTablePrefix(self, tuples: List[str]):
        if tuples:
            conditions = " or ".join(["table_name like %s"] * len(tuples))
            self.__getCurrentDbTableSql += " and ({})".format(conditions)
            # Values are bound as parameters, with wildcards escaped, so a prefix
            # such as "act_" matches a literal underscore instead of any character.
            self.__tableFilterParams.extend(self._escapeLikePattern(p) + "%" for p in tuples)
        print("print filter table prefix sql: {}".format(self.__getCurrentDbTableSql))

    # Restrict the compared tables to exactly the given names.
    def filterTableName(self, tuples: List[str]):
        if tuples:
            placeholders = ", ".join(["%s"] * len(tuples))
            self.__getCurrentDbTableSql += " and table_name in({})".format(placeholders)
            self.__tableFilterParams.extend(tuples)
        print("print filter table sql: {}".format(self.__getCurrentDbTableSql))

    # Source database connection (None until sourceDbConnect is called).
    def bindSourceDbConnect(self):
        return self.__sourceDb

    # Target database connection (None until targetDbConnect is called).
    def bindTargetDbConnect(self):
        return self.__targetDb

    def dbClose(self):
        """Close both connections; connections that were never opened are skipped."""
        if self.__sourceDb is not None:
            self.__sourceDb.close()
        if self.__targetDb is not None:
            self.__targetDb.close()

    @staticmethod
    def _fetchAll(conn, sql, params=()):
        """Run one query on *conn* and return all rows, always closing the cursor."""
        cursor = conn.cursor()
        try:
            if params:
                cursor.execute(sql, tuple(params))
            else:
                cursor.execute(sql)
            return cursor.fetchall()
        finally:
            cursor.close()

    def _getAllTableName(self, conn):
        """Return the filtered table rows of *conn*, or None (after printing) on any error."""
        try:
            return self._fetchAll(conn, self.__getCurrentDbTableSql, self.__tableFilterParams)
        except Exception as e:
            print("发生异常 ", e)
            return None

    # All table names of the source database, as rows (1-tuples); None on error.
    def getSourceDbAllTableName(self):
        return self._getAllTableName(self.bindSourceDbConnect())

    # All table names of the target database, as rows (1-tuples); None on error.
    def getTargetDbAllTableName(self):
        return self._getAllTableName(self.bindTargetDbConnect())

    @staticmethod
    def _tableNameSet(rows):
        """Return the set of decoded table names in *rows*, or None when the listing failed."""
        if rows is None:
            return None
        return {decodeObj(row[0]) for row in rows}

    def getDiffTableName(self):
        """Return ``[(name,), ...]`` for tables in the source but not the target, sorted by name.

        Returns an empty list if either table listing failed, rather than
        computing a diff against missing data.
        """
        sourceNames = self._tableNameSet(self.getSourceDbAllTableName())
        targetNames = self._tableNameSet(self.getTargetDbAllTableName())
        if sourceNames is None or targetNames is None:
            return []
        return [(name,) for name in sorted(sourceNames - targetNames)]

    def getDiffTableCreateSql(self):
        """Return a ``CREATE TABLE IF NOT EXISTS ...;`` statement for every table missing from the target."""
        diffTableList = self.getDiffTableName()
        resultList = []
        if diffTableList:
            try:
                conn = self.bindSourceDbConnect()
                for t in diffTableList:
                    rows = self._fetchAll(conn, "show create table " + self._quoteIdentifier(t[0]))
                    # Column 1 of SHOW CREATE TABLE is the DDL; it may arrive as bytes.
                    createSql = decodeObj(rows[0][1])
                    # Only the leading keyword: "CREATE TABLE" may also appear in comments.
                    resultList.append(createSql.replace("CREATE TABLE", "CREATE TABLE IF NOT EXISTS", 1) + ";")
            except Exception as e:
                print("发生异常 ", e)
        print("diff table count {}".format(len(resultList)))
        return resultList

    @staticmethod
    def _columnDiffSql(tableName, sourceColumns, targetColumns):
        """Return add/modify statements for source columns of *tableName* that are missing or different in the target."""
        # Index the target columns by decoded name. Comparing a decoded source
        # name with a raw (bytes) target name would never match.
        targetByName = {}
        for row in targetColumns:
            decoded = tuple(decodeObj(v) for v in row)
            targetByName[decoded[0]] = decoded[1:]
        statements = []
        for row in sourceColumns:
            columnName, columnType, nullable, defaultValue, comment = (decodeObj(v) for v in row)
            targetAttrs = targetByName.get(columnName)
            if targetAttrs is None:
                # Missing in the target: addOrModifyColumnSql emits "add column" for False.
                statements.append(addOrModifyColumnSql(False, tableName, columnName, columnType,
                                                       nullable, defaultValue, comment))
            elif targetAttrs != (columnType, nullable, defaultValue, comment):
                # Present but type/nullability/default/comment differ: "modify column".
                statements.append(addOrModifyColumnSql(True, tableName, columnName, columnType,
                                                       nullable, defaultValue, comment))
        return statements

    def getDiffTableColumnSql(self):
        """Return column add/modify statements for every table present in both databases."""
        sourceNames = self._tableNameSet(self.getSourceDbAllTableName())
        targetNames = self._tableNameSet(self.getTargetDbAllTableName())
        resultList = []
        if sourceNames and targetNames:
            try:
                for tableName in sorted(sourceNames & targetNames):
                    sourceColumns = self._fetchAll(self.bindSourceDbConnect(), self.__getTableColumnSql, (tableName,))
                    targetColumns = self._fetchAll(self.bindTargetDbConnect(), self.__getTableColumnSql, (tableName,))
                    resultList.extend(self._columnDiffSql(tableName, sourceColumns, targetColumns))
            except Exception as e:
                print("发生异常 ", e)
        print("diff column count {}".format(len(resultList)))
        return resultList

    @staticmethod
    def _appendLines(path, lines):
        """Append each statement on its own line to *path* (UTF-8; file created if absent)."""
        # Explicit UTF-8: the platform default (e.g. GBK on Chinese Windows) may
        # fail on or garble characters in table/column comments.
        with open(path, "a", encoding="utf-8") as fo:
            for line in lines or ():
                fo.write(line + "\n")

    def exportDiffSqlFile(self, aPath, bPath, exportType):
        """Export the diff scripts, then close both connections.

        :param aPath: file receiving the CREATE TABLE statements.
        :param bPath: file receiving the column ALTER statements.
        :param exportType: 1 = tables and columns; 2 = tables only; 3 = columns only.
        """
        startMessage = {
            1: "开始导出差异表和表字段SQL ==========>>>> ",
            2: "开始导出差异表SQL ==========>>>> ",
            3: "开始导出差异表字段SQL ==========>>>> ",
        }.get(exportType)
        if startMessage is not None:
            print(startMessage)
        startTime = datetime.datetime.now()
        try:
            if exportType in (1, 2):
                self._appendLines(aPath, self.getDiffTableCreateSql())
                print("Export Success: table file: {}".format(aPath))
            if exportType in (1, 3):
                self._appendLines(bPath, self.getDiffTableColumnSql())
                print("Export Success: column file:{}".format(bPath))
        finally:
            self.dbClose()
            print("总耗时:{}".format((datetime.datetime.now() - startTime).total_seconds()))
if __name__ == "__main__":
    import os

    # Minute-resolution timestamp so each run writes to new files.
    today = time.strftime("%Y%m%d%H%M", time.localtime(time.time()))
    outputDir = "G:\\data\\pytest"
    # open(path, "a") does not create missing parent directories.
    os.makedirs(outputDir, exist_ok=True)
    createTableSqlFilePath = os.path.join(outputDir, "createTableSql-{}.sql".format(today))
    tableColumnSqlFilePath = os.path.join(outputDir, "tableColumnSql-{}.sql".format(today))
    tool = DatabaseCompare()
    # Only compare tables whose names start with these prefixes.
    tool.filterTablePrefix(["act_"])
    # Alternatively, restrict the comparison to specific tables:
    # tool.filterTableName(["rbp_user_column_language"])
    # NOTE(review): demo credentials -- do not commit real passwords.
    # Source = reference schema; statements are generated for tables/columns it has that the target lacks.
    tool.sourceDbConnect(
        host="localhost",  # database host
        port=3306,
        auth_plugin="mysql_native_password",
        user="root",
        passwd="123456",
        database="db01"
    )
    tool.targetDbConnect(
        host="localhost",
        port=3306,
        auth_plugin="mysql_native_password",
        user="root",
        passwd="123456",
        database="db02"
    )
    # Export type: 1 = tables and columns; 2 = tables only; 3 = columns only.
    tool.exportDiffSqlFile(createTableSqlFilePath, tableColumnSqlFilePath, 1)
总结
简单实现了两个数据库之间表结构及表字段结构的对比，并将差异导出为SQL脚本；后续计划增加差异数据的导出。