最近做了一个挺有意思的小工具,就是用Python读取数据库,然后绘制出一个ER图,可以体现出来主外键等信息。
主要用到的两个库:pymysql和graphviz
注意,graphviz在anaconda的虚拟环境里是不能直接用conda install指令安装的,需要在官网下载安装包,配置环境变量,然后利用pip install指令安装。
接下来是代码部分:
首先是创建3个类:数据库类、表类和属性类:
class database:
def __init__(self,name,tables,tables_name,foreign_keys):
self.name=name
self.tables=tables
self.tables_name=tables_name
self.foreign_keys=foreign_keys
class table:
def __init__(self,name):
self.name=name
self.columns=[]
def printinf(self):
strr='表名:'+self.name+':'+'\n'
for i in range(len(self.columns)):
strr+=self.columns[i].printinf()+'\n'
print(strr)
return strr
class column:
def __init__(self,data):
self.name=data[0]
self.columntype=data[1]
self.extra=data[2]
def printinf(self):
strr=str(self.name)+' '+str(self.columntype)+' '+self.extra
return strr
定义访问数据库的函数:
def getconnection(username,password,dbname):
#初始化几个数组存储对象
tables_name=[]
tables=[]
foregin_keys=[]
#打开数据库连接
conn = pymysql.connect(user = username,passwd = password,database = dbname)
#创建游标进行数据库操作
cursor=conn.cursor()
#获取所有表名
sqlstring_gettablename="select table_name from information_schema.tables where table_schema='test_database'"
cursor.execute(sqlstring_gettablename)
while 1:
res=cursor.fetchone()
if res is None:
#表示已经取完结果集
break
tables_name.append(res[0])
for i in tables_name:
temp_table=table(i)
tables.append(temp_table)
#获取属性
for i in range(len(tables)):
#print(tables[i].name)
sqlstring_getcolumn='SELECT COLUMN_NAME "字段名称",COLUMN_TYPE "字段类型长度",IF(EXTRA="auto_increment",CONCAT(COLUMN_KEY,"(", IF(EXTRA="auto_increment","自增长",EXTRA),")"),COLUMN_KEY) "主外键" FROM information_schema. COLUMNS WHERE TABLE_SCHEMA = "test_database" AND TABLE_NAME = "'+tables[i].name+'"';
cursor.execute(sqlstring_getcolumn)
while 1:
res=cursor.fetchone()
if res is None:
#表示已经取完结果集
break
temp_column=column(res)
tables[i].columns.append(temp_column)
#获取外键
for i in range(len(tables)):
#print(tables[i].name)
sqlstring_getfk='SELECT C.REFERENCED_TABLE_NAME 父表名称 ,C.TABLE_NAME 子表名称 FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE C JOIN INFORMATION_SCHEMA. TABLES T ON T.TABLE_NAME = C.TABLE_NAME JOIN INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS R ON R.TABLE_NAME = C.TABLE_NAME AND R.CONSTRAINT_NAME = C.CONSTRAINT_NAME AND R.REFERENCED_TABLE_NAME = C.REFERENCED_TABLE_NAME WHERE C.REFERENCED_TABLE_NAME IS NOT NULL AND C.REFERENCED_TABLE_NAME="'+tables[i].name+'"'
cursor.execute(sqlstring_getfk)
while 1:
res=cursor.fetchone()
if res is None:
#表示已经取完结果集
break
foregin_keys.append(res)
cursor.close()
conn.close()
db=database('MyDB',tables,tables_name,foregin_keys)
return db
接下来就是用graphviz绘制ER图:
def draw_ER(db):
tables_name=db.tables_name
tables=db.tables
foreign_keys=db.foreign_keys
#创建一个空白图
dot = gz.Graph()
for i in range(len(tables)):
dot.node(str(i+1),tables[i].printinf(),shape='rectangle')
edges=[]
for i in range(len(foreign_keys)):
temp_edge=str(tables_name.index(foreign_keys[i][0])+1)+str(tables_name.index(foreign_keys[i][1])+1)
edges.append(temp_edge)
#创建边(表示外键关系)
dot.edge(str(tables_name.index(foreign_keys[i][0])+1),str(tables_name.index(foreign_keys[i][1])+1))
dot.view()
dot.save('before_xml.dot')
最后绘制出来的效果是这样的:
乱码是中文的“表名”,应该是因为编码问题显示有问题。
这个不如自己用专门的ER图软件走出来精美,但是挺方便的。