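shuffle literally means to shuffle a deck of cards, and it shows up in every major machine learning framework. This post covers how to use the shuffle parameter in MXNet to switch between sequential (adjacent) batch sampling and random batch sampling.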
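Sequential batch sampling: mainly used when the model depends on the order of the data, as with RNNs.
Random batch sampling: used when the order does not matter, as with CNNs and fully connected DNNs. Because consecutive mini-batches are no longer correlated, it usually helps training converge faster.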
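The difference between the two strategies comes down to the order in which sample indices are visited. Below is a minimal, framework-independent sketch in plain Python; the helper names sequential_batches and random_batches are hypothetical and are not MXNet APIs. As I understand it, DataLoader does something similar internally by combining a sequential or random sampler with a batch sampler. Both helpers keep a short final batch when the dataset size is not a multiple of batch_size, similar to DataLoader's default last_batch='keep'.

```python
import random

def sequential_batches(num_samples, batch_size):
    """Yield consecutive index batches [0, 1], [2, 3], ... (like shuffle=False)."""
    indices = list(range(num_samples))
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]

def random_batches(num_samples, batch_size, rng=random):
    """Permute all sample indices once, then cut them into batches (like shuffle=True)."""
    indices = list(range(num_samples))
    rng.shuffle(indices)  # shuffles individual samples, not whole batches
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]

print(list(sequential_batches(6, 2)))  # [[0, 1], [2, 3], [4, 5]]
print(list(random_batches(6, 2, random.Random(0))))  # order depends on the seed
```

Note that random sampling mixes samples across batches (for example [3, 2], [4, 0], [1, 5]), rather than just reordering the fixed batches [0, 1], [2, 3], [4, 5].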
1. shuffle=False
No shuffling, i.e. sequential batch sampling. This is the default for gluon.data.DataLoader. Example:
from mxnet import gluon
import mxnet.ndarray as nd

# 6 samples with one feature each; labels equal the features so the order is easy to follow
x = nd.reshape(nd.arange(6), shape=(6, 1))
y = nd.reshape(nd.arange(6), shape=(6, 1))
# shuffle=False: batches are taken in dataset order
train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(x, y), batch_size=2, shuffle=False)
for j, (data, label) in enumerate(train_data):
    print(data, label)
out: # the dataset is traversed in consecutive batches, in its original order
[[ 0.]
[ 1.]]
<NDArray 2x1 @cpu(0)>
[[ 0.]
[ 1.]]
<NDArray 2x1 @cpu(0)>
[[ 2.]
[ 3.]]
<NDArray 2x1 @cpu(0)>
[[ 2.]
[ 3.]]
<NDArray 2x1 @cpu(0)>
[[ 4.]
[ 5.]]
<NDArray 2x1 @cpu(0)>
[[ 4.]
[ 5.]]
<NDArray 2x1 @cpu(0)>
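2. shuffle=True
Shuffling, i.e. random batch sampling: the order of the samples is randomly permuted before they are grouped into batches. Example: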
from mxnet import gluon
import mxnet.ndarray as nd

# same toy data as above
x = nd.reshape(nd.arange(6), shape=(6, 1))
y = nd.reshape(nd.arange(6), shape=(6, 1))
# shuffle=True: samples are drawn in a random order
train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(x, y), batch_size=2, shuffle=True)
for j, (data, label) in enumerate(train_data):
    print(data, label)
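out: # the samples are visited in random order, and each data batch still matches its labels; because the order is random, your output will differ from run to run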
[[ 3.]
[ 2.]]
<NDArray 2x1 @cpu(0)>
[[ 3.]
[ 2.]]
<NDArray 2x1 @cpu(0)>
[[ 4.]
[ 0.]]
<NDArray 2x1 @cpu(0)>
[[ 4.]
[ 0.]]
<NDArray 2x1 @cpu(0)>
[[ 1.]
[ 5.]]
<NDArray 2x1 @cpu(0)>
[[ 1.]
[ 5.]]
<NDArray 2x1 @cpu(0)>
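With shuffle=True the order is redrawn every time the loader is iterated, so each epoch sees a different permutation while still visiting every sample exactly once. The sketch below imitates that behaviour in plain Python; epoch_batches is a hypothetical helper, not an MXNet API, and it uses its own seeded random.Random so the run is reproducible. How DataLoader's own shuffling can be seeded depends on the MXNet version, so check the sampler source for the version you use.

```python
import random

def epoch_batches(num_samples, batch_size, shuffle, rng):
    """Index batches for one epoch; draws a fresh permutation when shuffle=True."""
    indices = list(range(num_samples))
    if shuffle:
        rng.shuffle(indices)  # new sample order for this epoch
    return [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]

rng = random.Random(42)  # hypothetical fixed seed, only to make the sketch repeatable
for epoch in range(3):
    # sequential sampling gives the same batches every epoch
    print("epoch", epoch, "shuffle=False:", epoch_batches(6, 2, shuffle=False, rng=rng))
    # random sampling gives a new grouping every epoch
    print("epoch", epoch, "shuffle=True: ", epoch_batches(6, 2, shuffle=True, rng=rng))
```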