MXNet Internals - Source Analysis of mx.io.NDArrayIter

Introduction

The goal is to understand how MXNet feeds data through an iterator during training and validation.

A few questions to answer:

  • How is the iterator constructed?
  • During training or testing a batch is fetched with data_batch = next(iterator). What exactly is the returned data_batch, and how is it produced?
  • How should provide_data and provide_label be designed, and how are they used when initializing a module or executor?

To answer these, we walk through MXNet's built-in mx.io.NDArrayIter and see how a plain NDArray is turned into an iterator that can be passed to module.fit().
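
Before diving into the source, a minimal sketch (toy arrays; all names are the library defaults, so treat the printed values as illustrative) previews the answers to the first two questions from the user's side:

import numpy as np
import mxnet as mx

# Toy dataset: 10 samples of shape (1, 4, 4) with scalar labels.
x = np.random.uniform(size=(10, 1, 4, 4)).astype('float32')
y = np.arange(10, dtype='float32')

it = mx.io.NDArrayIter(x, y, batch_size=5, shuffle=True)

# provide_data / provide_label describe the name, shape and dtype of each batch;
# module.bind() reads these descriptions to allocate the executor's input arrays.
print(it.provide_data)    # [DataDesc(name='data', shape=(5, 1, 4, 4), ...)]
print(it.provide_label)   # [DataDesc(name='softmax_label', shape=(5,), ...)]

# Each call to next() returns a DataBatch whose .data and .label are lists of NDArrays.
batch = next(it)
print(batch.data[0].shape, batch.label[0].shape)   # (5, 1, 4, 4) (5,)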

The code used for the walkthrough is the standard example of training an MLP on MNIST:

'''
Loading Data
'''
import mxnet as mx
from collections import OrderedDict
from mxnet.ndarray import array
mnist = mx.test_utils.get_mnist()  # dict of numpy arrays
# 'train_data'  ndarray, shape (60000, 1, 28, 28)
# 'train_label' ndarray, shape (60000,)
# 'test_data'   ndarray, shape (10000, 1, 28, 28)
# 'test_label'  ndarray, shape (10000,)
# Fix the seed
mx.random.seed(42)

# Set the compute context: GPU if available, otherwise CPU
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()


batch_size = 100
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
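
A side note on last_batch_handle (default 'pad'): when the number of samples is not a multiple of batch_size, the final batch is filled up to batch_size and batch.pad records how many of its samples are filler. MNIST divides evenly by 100, so padding never occurs in this example, but the toy sketch below (hypothetical arrays, not part of the original example) shows the behaviour; module.fit() calls reset() on the iterator at the end of every epoch so it can be iterated again:

import numpy as np
import mxnet as mx

x = np.zeros((7, 2), dtype='float32')   # 7 samples with batch_size 3 -> last batch incomplete
it_pad = mx.io.NDArrayIter(x, None, batch_size=3, last_batch_handle='pad')

# Count the real (non-padded) samples in each batch.
print([b.data[0].shape[0] - b.pad for b in it_pad])   # [3, 3, 1]

it_pad.reset()   # rewind for the next pass, as fit() does between epochs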

'''
Training
'''

'''
The name 'data' must not be changed here: it matches the default of the
data_name argument of mx.io.NDArrayIter, which is 'data' (passed on to
_init_data as default_name). This becomes clear further below; you can also
rename it to see the resulting error. A sketch with custom names follows the
training call below.
'''
data = mx.sym.var('data')
# Flatten the data from 4-D shape into 2-D (batch_size, num_channel*width*height)
data = mx.sym.flatten(data=data)


# The first fully-connected layer and the corresponding activation function
fc1  = mx.sym.FullyConnected(data=data, num_hidden=128)
act1 = mx.sym.Activation(data=fc1, act_type="relu")

# The second fully-connected layer and the corresponding activation function
fc2  = mx.sym.FullyConnected(data=act1, num_hidden = 64)
act2 = mx.sym.Activation(data=fc2, act_type="relu")
# MNIST has 10 classes
fc3  = mx.sym.FullyConnected(data=act2, num_hidden=10)
# Softmax with cross entropy loss
mlp  = mx.sym.SoftmaxOutput(data=fc3, name='softmax')

import logging
logging.getLogger().setLevel(logging.DEBUG)  # logging to stdout
# create a trainable module on compute context
mlp_model = mx.mod.Module(symbol=mlp, context=ctx)
mlp_model.fit(train_iter,  # train data
              eval_data=val_iter,  # validation data
              optimizer='sgd',  # use SGD to train
              optimizer_params={'learning_rate':0.1},  # use fixed learning rate
              eval_metric='acc',  # report accuracy during training
              batch_end_callback = mx.callback.Speedometer(batch_size, 100), # output progress for each 100 data batches
              num_epoch=10)  # train for at most 10 dataset passes
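
As noted in the comment above the symbol definition, the names are not arbitrary: the iterator advertises its arrays under data_name/label_name, the Module looks them up under data_names/label_names, and the symbol must declare variables with the same names. A hypothetical sketch with custom names (not part of the original example; the exact error wording may differ by version) looks roughly like this:

# Hypothetical: rename the input to 'img' and the output layer to 'cls'.
img = mx.sym.var('img')
net = mx.sym.FullyConnected(data=mx.sym.flatten(data=img), num_hidden=10)
net = mx.sym.SoftmaxOutput(data=net, name='cls')   # its label variable is named 'cls_label'

custom_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size,
                                data_name='img', label_name='cls_label')

custom_model = mx.mod.Module(symbol=net, context=ctx,
                             data_names=['img'], label_names=['cls_label'])
custom_model.fit(custom_iter, num_epoch=1)
# If the three places disagree, binding fails with an error about an input name
# that cannot be found in the iterator's provide_data / provide_label.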

Now look at mx.io.NDArrayIter.__init__():

def __init__(self, data, label=None, batch_size=1, shuffle=False,
             last_batch_handle='pad', data_name='data',
             label_name='softmax_label'):
    super(NDArrayIter, self).__init__(batch_size)
    # Normalize the input into a list of (key, value) tuples:
    # [(key, val), (key, val), ...]
    # Key point: these keys are the names the executor's symbol inputs must match.
    self.data = _init_data(data, allow_empty=False, default_name=data_name)
    self.label = _init_data(label, allow_empty=True, default_name=label_name)
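
The effect of this normalization can be observed through the public API: roughly, a bare ndarray falls back to default_name for its key, while a dict keeps the user-supplied keys (a small sketch based on the behaviour of the code above):

import numpy as np
import mxnet as mx

x = np.zeros((6, 2), dtype='float32')
y = np.zeros((6,), dtype='float32')

# A bare ndarray: keys fall back to the data_name / label_name defaults.
it1 = mx.io.NDArrayIter(x, y, batch_size=3)
print([d.name for d in it1.provide_data])    # ['data']
print([d.name for d in it1.provide_label])   # ['softmax_label']

# A dict: the supplied keys are kept, and must match the symbol's input names.
it2 = mx.io.NDArrayIter({'img': x}, {'cls_label': y}, batch_size=3)
print([d.name for d in it2.provide_data])    # ['img']
print([d.name for d in it2.provide_label])   # ['cls_label']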