参数说明:
传入训练数据和标签,设定batch_size数值大小,shuffle表示是否打乱顺序
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
assert len(inputs) == len(targets)
if shuffle:
indices = np.arange(len(inputs))
np.random.shuffle(indices)
for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
if shuffle:
excerpt = indices[start_idx:start_idx + batch_size]
else:
excerpt = slice(start_idx, start_idx + batch_size)
yield inputs[excerpt], targets[excerpt] # 提取相应的样本数据和标签数据
yield解释:
yield
是一个类似return
的关键字,只是这个函数返回的是个生成器- 当你调用这个函数的时候,函数内部的代码并不立马执行 ,这个函数只是返回一个生成器对象
- 当你使用for进行迭代的时候,函数中的代码才会执行
- 例如:
>>> def createGenerator() : ... mylist = range(3) ... for i in mylist : ... yield i*i ... >>> mygenerator = createGenerator() # create a generator >>> print(mygenerator) # mygenerator is an object! <generator object createGenerator at 0xb7555c34> >>> for i in mygenerator: ... print(i) 0 1 4