以下代码可实现全连接神经网络结构的可视化（用 networkx 绘制网络图，并用 matplotlib 保存为图片）：
# 引用模块
from pylab import mpl #matplotlib使用中文
# 自编函数
def ANN_ksh(number_input, number_hidden, number_output,
            filename='全连接层网络可视化.png'):
    """Draw a fully connected feed-forward network and save it as an image.

    Every node of one layer is connected to every node of the next layer
    (input -> hidden 1 -> ... -> hidden k -> output). Layers are spread evenly
    along the x axis, centred on x = 0; the nodes of each layer are spaced one
    unit apart and centred on y = 0, with the first node of a layer at the
    bottom. Nodes are named 'v1', 'v2', ... consecutively, input layer first.

    Parameters
    ----------
    number_input : int
        Number of nodes in the input layer.
    number_hidden : int or sequence of int
        Number of nodes in each hidden layer, in order (e.g. ``[8, 5, 2]``).
        A bare int is treated as a single hidden layer; an empty sequence
        means no hidden layer, so the input layer connects to the output.
    number_output : int
        Number of nodes in the output layer.
    filename : str, optional
        Path of the saved image. Defaults to '全连接层网络可视化.png'
        (relative to the current working directory).

    Returns
    -------
    None
        The figure is written to *filename*; it is not closed here.
    """
    import numpy as np
    import networkx as nx
    import matplotlib.pyplot as plt

    # Use the SimHei (黑体) font so Chinese text renders. plt.rcParams is the
    # same global rcParams object as pylab's mpl.rcParams.
    plt.rcParams['font.sans-serif'] = ['SimHei']

    # Accept a single int as "one hidden layer" in addition to a sequence.
    if isinstance(number_hidden, (int, np.integer)):
        number_hidden = [number_hidden]
    layer_sizes = [number_input, *number_hidden, number_output]

    # Node names: 'v1', 'v2', ... numbered consecutively across all layers.
    layers = []
    next_id = 1
    for size in layer_sizes:
        layers.append(['v' + str(i) for i in range(next_id, next_id + size)])
        next_id += size

    G = nx.DiGraph()
    for layer in layers:
        G.add_nodes_from(layer)
    # Fully connect each layer to the next one (directed edges left -> right).
    for src_layer, dst_layer in zip(layers, layers[1:]):
        G.add_edges_from((u, v) for u in src_layer for v in dst_layer)

    # x coordinates: one evenly spaced column per layer, centred on 0.
    n_layers = len(layers)
    layer_x = np.linspace(-n_layers / 2, n_layers / 2, num=n_layers)

    # y coordinates: a layer of k nodes occupies -(k-1)/2, ..., (k-1)/2,
    # assigned in node order (first node lowest). Using enumerate avoids the
    # deprecated int(<1-element array>) conversion and the O(k^2) lookup.
    pos = {}
    for x, layer in zip(layer_x, layers):
        offset = (len(layer) - 1) / 2
        for idx, node in enumerate(layer):
            pos[node] = (x, idx - offset)

    fig = plt.figure(figsize=(8, 5), dpi=300)
    # Limits are set before drawing; nx.draw reuses these axes.
    plt.xlim(layer_x[0] - 1, layer_x[-1] + 1)
    widest = max(layer_sizes)
    plt.ylim(-widest / 2, widest / 2 + 1 / 2)
    nx.draw(
        G,
        pos=pos,
        node_color='red',
        edge_color='black',
        with_labels=False,
        font_size=10,
        node_size=300,
    )
    fig.savefig(filename)
函数参数说明：
number_input 为输入层的节点个数；number_hidden 为列表，按顺序给出各隐藏层的节点个数（列表长度即隐藏层的层数）；number_output 为输出层的节点个数。函数运行后，图片保存为当前工作目录下的“全连接层网络可视化.png”。
调用函数示例（8 个输入节点，3 个隐藏层分别含 8、5、2 个节点，2 个输出节点）：
# Example: 8 input nodes, hidden layers of 8, 5 and 2 nodes, 2 output nodes.
ANN_ksh(number_input=8, number_hidden=[8, 5, 2], number_output=2)
结果（即生成的图片“全连接层网络可视化.png”）：