机器学习中的小批量(mini-batch)数据生成器,代码如下,以备日后使用:
# inputs 和 targets 分别对应数据 X 和标签 y
def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    """Yield ``(inputs, targets)`` mini-batches of exactly ``batchsize`` rows, once.

    Only full batches are produced: the trailing ``len(inputs) % batchsize``
    samples are never yielded. If ``batchsize`` exceeds ``len(inputs)``,
    nothing is yielded at all. When ``shuffle`` is true, a single random
    permutation of the row indices is drawn with ``np.random`` before the
    first batch.

    Args:
        inputs: array-like of samples (e.g. ``X``) supporting slice indexing
            and, when ``shuffle`` is true, integer-array indexing (e.g. NumPy
            arrays).
        targets: array-like of labels (e.g. ``y``); must have the same length
            as ``inputs``.
        batchsize: number of samples per batch.
        shuffle: whether to visit the samples in a random order.

    Yields:
        ``(inputs[idx], targets[idx])`` tuples, one per full batch.

    Raises:
        AssertionError: (on first ``next``) if the lengths of ``inputs`` and
            ``targets`` differ.
    """
    n_samples = len(inputs)
    assert n_samples == len(targets)

    order = None
    if shuffle:
        order = np.arange(n_samples)
        np.random.shuffle(order)

    # Upper bound chosen so that only starts with a full batch remaining are used.
    last_start = n_samples - batchsize + 1
    for begin in range(0, last_start, batchsize):
        end = begin + batchsize
        picked = order[begin:end] if shuffle else slice(begin, end)
        yield inputs[picked], targets[picked]
用法:
>>>X = np.random.randn(6,4)
>>>y = np.random.randint(0,2,6)
>>>gen = iterate_minibatches(X, y, 2, shuffle=False)
>>>next(gen)
Out[113]:
(array([[-1.29894375, 1.64519902, 0.449616 , 0.2820526 ],
[ 0.49223679, 0.4552031 , 0.40525815, 0.22591931]]), array([0, 0]))
>>>next(gen)
Out[114]:
(array([[ 1.46478028, 0.26404314, 0.79638976, 0.223665 ],
[ 0.08203239, 1.52025888, 1.2913706 , 0.38070632]]), array([0, 0]))
>>>next(gen)
Out[115]:
(array([[-0.53161078, -0.3045918 , 0.63480717, -0.12489774],
[-0.41806273, 0.09459128, -0.4725088 , -0.83306486]]), array([1, 0]))
>>>next(gen)
Traceback (most recent call last):
File "<ipython-input-116-b2c61ce5e131>", line 1, in <module>
next(gen)
StopIteration
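也就是说,所有数据迭代完一遍后,这个生成器就耗尽了,再调用 next 会抛出 StopIteration。另外,末尾凑不满一个 batchsize 的剩余样本会被直接丢弃,因此 batchsize 不能大于样本数,否则一个完整批次都取不出来。如果想要循环迭代,需要在外层再套一个 while True 循环。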
循环迭代版:
def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    """Yield ``(inputs, targets)`` mini-batches forever, cycling over the data.

    Each pass ("epoch") walks the samples in chunks of exactly ``batchsize``
    rows; the trailing ``len(inputs) % batchsize`` samples of a pass are
    skipped. When ``shuffle`` is true, a fresh random permutation of the row
    indices is drawn with ``np.random`` at the start of every pass.

    Args:
        inputs: array-like of samples (e.g. ``X``) supporting slice indexing
            and, when ``shuffle`` is true, integer-array indexing (e.g. NumPy
            arrays).
        targets: array-like of labels (e.g. ``y``) with the same length as
            ``inputs``.
        batchsize: number of samples per batch; must satisfy
            ``1 <= batchsize <= len(inputs)``.
        shuffle: whether to reshuffle the sample order at the start of each
            pass.

    Yields:
        ``(inputs[idx], targets[idx])`` tuples, indefinitely.

    Raises:
        AssertionError: (on first ``next``) if ``inputs`` and ``targets``
            differ in length.
        ValueError: (on first ``next``) if ``batchsize`` is outside
            ``[1, len(inputs)]``.
    """
    assert len(inputs) == len(targets)
    n_samples = len(inputs)
    # Without this check, a batchsize larger than the data (or a negative one)
    # makes the inner range empty, and ``while True`` would spin forever
    # without ever yielding.
    if not 1 <= batchsize <= n_samples:
        raise ValueError(
            "batchsize must be between 1 and len(inputs)={0}, got {1!r}".format(
                n_samples, batchsize
            )
        )
    while True:
        if shuffle:
            # Start from a fresh arange each pass, so each epoch gets a new permutation.
            indices = np.arange(n_samples)
            np.random.shuffle(indices)
        # Only starts that leave a full batch are visited; leftovers are skipped.
        for start_idx in range(0, n_samples - batchsize + 1, batchsize):
            if shuffle:
                excerpt = indices[start_idx:start_idx + batchsize]
            else:
                excerpt = slice(start_idx, start_idx + batchsize)
            yield inputs[excerpt], targets[excerpt]