1、查看mxnet的版本
import mxnet as mx
# Show the installed MXNet version. A bare `mx.__version__` expression is only
# echoed by an interactive shell/notebook, so print it explicitly to make the
# snippet also work when run as a script.
print(mx.__version__)
2、扩展nd的维度
# Random NDArray of shape (3, 112, 112) -- presumably a CHW image; confirm layout.
image_data = mx.random.normal(shape=(3, 112, 112))
# expand_dims returns a NEW NDArray and does not modify image_data in place,
# so the result must be kept. Adds a leading batch axis -> (1, 3, 112, 112).
image_data = image_data.expand_dims(axis=0)
3、装载网络结构
import mxnet as mx
# Build ResNet-34 (v1) from the Gluon model zoo without pretrained weights.
# Note the keyword spelling: a misspelled `pretrianed` is forwarded to the
# network constructor and rejected as an unexpected keyword argument.
train_net = mx.gluon.model_zoo.get_model("ResNet34_V1", pretrained=False)
4、修改预装载的网络结构
from mxnet.gluon import nn
# Build the base network the same way as in section 3.
train_net = mx.gluon.model_zoo.get_model("ResNet34_V1", pretrained=False)

# Replacement first convolution: 64 filters, 3x3 kernel, stride 1, padding 1,
# which preserves the input's spatial height/width. No bias term.
conv_0 = nn.Conv2D(
    channels=64,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding=(1, 1),
    use_bias=False,
)
# Only conv_0 is initialized here; the rest of train_net is not initialized
# in this snippet and still needs initialize() before a real forward pass.
conv_0.collect_params().initialize(mx.init.Xavier())

# Register under the child name "0" -- presumably this overwrites the existing
# first child of `features` (Gluon keys sequential children by index string);
# verify the resulting architecture, e.g. with the plot in section 5.
train_net.features.register_child(conv_0, "0")
train_net.features.hybridize()
5、画出网络结构
# Trace the network symbolically: push a placeholder named 'data' through it,
# then render the resulting graph for a (1, 3, 112, 112) input and open it.
x = mx.sym.var('data')
net_symbol = train_net(x)
graph = mx.viz.plot_network(net_symbol, shape={'data': (1, 3, 112, 112)})
graph.view()
6、上下文
num_gpus = 0
ctx =