最近在用mxnet重写代码,因为之前的代码因为集群问题没法跑了,想方设法找到一个可以在集群跑的框架,只有mxnet。于是
就开始了填坑之旅。想着开篇文章记录一下这个过程,用于以后查阅,反省。
关于symbol.infer_shape
infer_shape可以用于测试自己写的symbol,而这里有些参数说明。文档中有说明但是不够具体,说白了就一句话,里面的
参数就是构建symbol时的入口variable的name,当然大部分情况name是默认data,label;有些时候我们自定义DataIter后,这里
name就是我们自定义的name.来个例子说一下吧:
audio_data = mx.sym.Variable('audio_data')
face_data = mx.sym.Variable('face_data')
softmax_label = mx.sym.Variable('softmax_label')
# LSTM1, LSTM2
rnn1 = mx.rnn.LSTMCell(num_hidden=256, prefix='lstm1_')
rnn1_outputs, rnn1_states = rnn1.unroll(length=49, inputs=audio_data, merge_outputs=False)
rnn2 = mx.rnn.LSTMCell(num_hidden=256, prefix='lstm2_')
rnn2_outputs, rnn2_states = rnn2.unroll(length=49, inputs=rnn1_outputs, merge_outputs=False)
rnn1_last_out = mx.sym.Reshape(data=rnn1_outputs[-1], shape=(-1, 256), name='rnn1_last_out_reshape')
rnn2_last_out = mx.sym.Reshape(data=rnn2_outputs[-1], shape=(-1, 256), name='rnn2_last_out_reshape')
rnn_outputs = mx.sym.Concat(rnn1_last_out, rnn2_last_out, num_args=2, dim=1, name='rnn_outputs')
audio_data_shape = (128, 49, 75)
face_data_shape = (128, 512, 14, 14)
label_data_shape = (128, 6)
rnn_outputs_shape = rnn_outputs.infer_shape(audio_data=audio_data_shape)
这里rnn_outputs_shape
如下所示:
INFO:root:rnn_outputs_shape:([(128L, 49L, 75L), (1024L, 75L), (1024L,), (1024L, 256L), (1024L,), (1024L, 256L), (1024L,), (1024L, 256L), (1024L,)], [(128L, 512L)], [])
分别指代input_shape, output_shape, aux_shape,为了获得对应的shape,只需要取对应的index即可。
关于自定义data_names, label_names
在自定义DataIter后,对于数据的入口可能由默认的data, label变成了自定义的形式,因此需要在mod.fit
之前声明,通常格式
为:
Module(label_names=('new_label',), data_names=('new_data_name',)...)
这里需要注意的是,元祖若只有一个元素,需要有逗号:(‘new_label’,),否则会报错: