mxnet网络是链式结构,pytorch可以是列表结构
引发的问题:mxnet symbol如何打印特征维度?
mxnet设计网络是,不用输入网络输入channel,
pytorch需要输入通道数。
mxnet:
num_classes = config.emb_size
bn_mom = config.bn_mom
workspace = config.workspace
data = mx.symbol.Variable(name="data") # 224
data = data - 127.5
data = data * 0.0078125
fc_type = config.net_output
bf = int(32 * config.net_multiplier)
if config.net_input == 0:
conv_1 = Conv(data, num_filter=bf, kernel=(3, 3), pad=(1, 1), stride=(2, 2), name="conv_1") # 224/112
else:
conv_1 = Conv(data, num_filter=bf, kernel=