绘制网络模型
# Build the AlexNet symbol from the local `alexnet` module and render its
# computation graph with MXNet's graphviz-based visualizer.
import alexnet as alexnet_def  # aliased so the symbol below does not shadow the module
import mxnet as mx

# Number of classes passed to get_symbol; the final layer (fc3) is sized by it.
NUM_CLASSES = 10

alexnet = alexnet_def.get_symbol(NUM_CLASSES)
# plot_network builds a graphviz graph of the symbol; .view() renders it in
# the requested format (JPEG here) and opens the result. hide_weights=True
# leaves the *_weight / *_bias variable nodes out of the drawing.
mx.viz.plot_network(
    alexnet,
    title='alexnet',
    save_format='jpg',
    hide_weights=True,
).view()
显示各层网络参数的形状（shape）
# Build the AlexNet symbol and print the shapes MXNet infers for its outputs
# and for every argument (input data, label, and each layer's weight/bias).
import mxnet as mx
import alexnet as alexnet_def  # aliased so the symbol below does not shadow the module

# Number of classes passed to get_symbol; the final layer (fc3) is sized by it.
NUM_CLASSES = 10
# Input batch shape: (batch, channels, height, width) -- 32 images, 3 channels, 128x128.
BATCH_SHAPE = (32, 3, 128, 128)

alexnet = alexnet_def.get_symbol(NUM_CLASSES)
# infer_shape propagates the known `data` shape through the graph and returns
# (argument shapes, output shapes, auxiliary-state shapes); aux shapes are not used here.
arg_shape, out_shape, aux_shape = alexnet.infer_shape(data=BATCH_SHAPE)
print(out_shape)
# Pair each argument name with its inferred shape; arg_shape follows list_arguments() order.
print(dict(zip(alexnet.list_arguments(), arg_shape)))
输出：
[(32, 10)]
{'data': (32, 3, 128, 128), 'conv1_weight': (96, 3, 11, 11), 'conv1_bias': (96,), 'conv2_weight': (256, 96, 5, 5), 'conv2_bias': (256,), 'conv3_weight': (384, 256, 3, 3), 'conv3_bias': (384,), 'conv4_weight': (384, 384, 3, 3), 'conv4_bias': (384,), 'conv5_weight': (256, 384, 3, 3), 'conv5_bias': (256,), 'fc1_weight': (4096, 1024), 'fc1_bias': (4096,), 'fc2_weight': (4096, 4096), 'fc2_bias': (4096,), 'fc3_weight': (10, 4096), 'fc3_bias': (10,), 'softmax_label': (32,)}