用 Python 实现全连接层网络可视化

以下代码即可实现全连接层网络的可视化:

# 引用模块
from pylab import mpl #matplotlib使用中文

# 自编函数
def ANN_ksh(number_input,number_hidden,number_output):
    import numpy as np
    import networkx as nx
    import matplotlib.pyplot as plt
    
    mpl.rcParams['font.sans-serif']=['SimHei'] #matplotlib使用中文,SimHei为黑体
    
    # number_input为输入层节点个数,number_hidden为隐藏层各层节点个数,number_output为输出层节点个数
    ceng_hidden=len(number_hidden) #隐藏层层数
    G=nx.DiGraph()
    
    # 节点
    vertex_input_list=['v'+str(i) for i in range(1,number_input+1)] #输入层
    vertex_hidden_list=[]
    start=number_input+1
    end=number_input+number_hidden[0]+1
    vertex_hidden_list.append(['v'+str(i) for i in range(start,end)]) #隐藏层
    for j in range(1,ceng_hidden):
        start=end
        end=start+number_hidden[j]
        vertex_hidden_list.append(['v'+str(i) for i in range(start,end)]) #隐藏层
    vertex_output_list=['v'+str(i) for i in range(end,end+number_output)] #输出层
    vertex_list=[]
    vertex_list.extend(vertex_input_list)
    list(map(lambda i:vertex_list.extend(vertex_hidden_list[i]),range(ceng_hidden)))
    vertex_list.extend(vertex_output_list)
    G.add_nodes_from(vertex_list)
    
    # 连接
    edge_input_hidden_list=[]
    edge_input_hidden_list.extend([(i,j) for i in vertex_input_list for j in vertex_hidden_list[0]]) #输入层-隐藏层
    edge_list=[]
    edge_list.extend(edge_input_hidden_list)
    edge_hidden_hidden_list=[]
    if ceng_hidden>1:
        for k in range(ceng_hidden-1):
            edge_hidden_hidden_list.extend([(i,j) for i in vertex_hidden_list[k] for j in vertex_hidden_list[k+1]]) #隐藏层-隐藏层
        edge_list.extend(edge_hidden_hidden_list)
    edge_hidden_output_list=[]
    edge_hidden_output_list.extend([(i,j) for i in vertex_hidden_list[len(vertex_hidden_list)-1] for j in vertex_output_list]) #隐藏层-输出层
    edge_list.extend(edge_hidden_output_list)
    G.add_edges_from(edge_list)
    
    # 位置
    pos={}
    ceng_pos_x=np.linspace(-(ceng_hidden+2)/2,(ceng_hidden+2)/2,num=ceng_hidden+2)
    list(map(lambda i:pos.update({vertex_input_list[int(np.where(np.arange(
        -number_input/2*1+1/2,number_input/2*1+1/2,1)==i)[0])]:(ceng_pos_x[0],i)}),
        np.arange(-number_input/2*1+1/2,number_input/2*1+1/2,1))) #输入层
    list(map(lambda j:list(map(lambda i:pos.update({vertex_hidden_list[j][int(np.where(np.arange(
            -number_hidden[j]/2*1+1/2,number_hidden[j]/2*1+1/2,1)==i)[0])]:(ceng_pos_x[j+1],i)}),
            np.arange(-number_hidden[j]/2*1+1/2,number_hidden[j]/2*1+1/2,1))),range(ceng_hidden))) #隐藏层
    list(map(lambda i:pos.update({vertex_output_list[int(np.where(np.arange(
        -number_output/2*1+1/2,number_output/2*1+1/2,1)==i)[0])]:(ceng_pos_x[len(ceng_pos_x)-1],i)}),
        np.arange(-number_output/2*1+1/2,number_output/2*1+1/2,1))) #输出层
    
    fig=plt.figure(figsize=(8,5),dpi=300)
    plt.xlim(ceng_pos_x[0]-1,ceng_pos_x[len(ceng_pos_x)-1]+1)
    plt.ylim(-max(number_input,max(number_hidden),number_output)/2*1,
             max(number_input,max(number_hidden),number_output)/2*1+1/2)
    
    nx.draw(
            G,
            pos=pos,
            node_color='red',
            edge_color='black',
            with_labels=False,
            font_size=10,
            node_size=300,
           )
    fig.savefig('全连接层网络可视化.png')

函数参数说明:

  number_input 为输入层的节点个数,number_hidden 为隐藏层各层的节点个数,number_output 为输出层的节点个数。

调用函数示例:

ANN_ksh(8,[8,5,2],2)

结果:

图 1 全连接层网络可视化
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值