# 数据集迭代器
def __next__(self):
if self.residue and self.index == self.n_batches:
batches = self.dataset[self.index * self.batch_size : len(self.dataset)]
self.index += 1
batches = self._to_tensor(batches)
return batches
elif self.index > self.n_batches:
self.index = 0
raise StopIteration
else:
batches = self.dataset[self.index * self.batch_size : (self.index + 1) * self.batch_size]
self.index += 1
batches = self._to_tensor(batches)
return batches
这段代码是一个迭代器的__next__()
方法实现。在这个方法中,首先判断是否存在剩余的数据(self.residue
)且当前批次的索引等于总批次数(self.index == self.n_batches
)。如果满足条件,则表示当前是最后一个批次,且存在剩余数据无法构成完整的批次。此时,执行与之前相同的操作,将剩余数据转换为张量格式,并作为最后一个批次返回。
接下来,判断当前批次的索引是否大于总批次数(self.index > self.n_batches
)。如果满足条件,则表示已经迭代完所有批次,需要抛出StopIteration
异常来终止迭代器。
最后,如果以上两个条件都不满足,说明仍然有完整的批次可以迭代。在这种情况下,获取当前批次的数据(batches = self.dataset[self.index * self.batch_size : (self.index + 1) * self.batch_size]
),将索引加1(self.index += 1
),并将获取的数据转换为张量格式(batches = self._to_tensor(batches)
)。最后,返回这个批次的数据。
总之,这段代码实现了对数据集进行批次迭代的功能。当存在剩余数据时,将剩余数据作为最后一个批次返回;当迭代完所有批次时,抛出StopIteration
异常;否则,返回下一个批次的数据。