方案一:定义迭代器类
自定义数据迭代器类时,主要需要声明(实现)如下方法:
(1)__init__:初始化方法,可用于传入数据路径或导入数据;
(2)__iter__:返回迭代器对象;在自定义数据迭代器类时,直接返回本对象self即可;
(3)__next__:定义每次迭代返回的数据,供next()函数或for循环逐个读出数据;数据取完后需抛出StopIteration以结束迭代;
(4)__len__:定义迭代器的长度(即batch总数),供len()函数调用。
下面是一个简单示例:
class DataLoader:
    """Iterate over ``data`` in consecutive mini-batches.

    Each batch is the slice ``data[k * batch_size:(k + 1) * batch_size]``,
    so a batch has whatever type slicing ``data`` produces.

    Args:
        data: A sequence that supports ``len()`` and slicing.
        batch_size: Number of items per batch; must be positive.
        drop_last: If true, a trailing batch with fewer than ``batch_size``
            items is discarded instead of being returned.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, data, batch_size, drop_last=False):
        # Reject sizes that would otherwise surface as ZeroDivisionError
        # (0) or as a negative batch count (< 0).
        if batch_size <= 0:
            raise ValueError(
                "batch_size must be positive, got {!r}".format(batch_size)
            )
        self.data = data
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.index = 0  # position of the next batch to hand out

        # Number of batches actually produced: every full batch, plus the
        # incomplete trailing one unless the caller asked to drop it.
        full_batches, remainder = divmod(len(self.data), self.batch_size)
        if remainder != 0 and not self.drop_last:
            self.batch = full_batches + 1
        else:
            self.batch = full_batches

    def __iter__(self):
        """Return ``self`` as the iterator, rewound to the first batch."""
        # Rewind so that every ``for`` loop (e.g. each training epoch)
        # starts from the beginning; without this the loader is exhausted
        # after the first pass and later loops silently yield nothing.
        self.index = 0
        return self

    def __next__(self):
        """Return the next batch, or raise ``StopIteration`` when done."""
        if self.index >= self.batch:
            raise StopIteration
        start = self.index * self.batch_size
        stop = start + self.batch_size
        self.index += 1
        return self.data[start:stop]

    def __len__(self):
        """Return the number of batches produced by one full pass."""
        return self.batch
验证上述数据迭代器类的可行性
# Demo of the class-based loader on ten consecutive integers.
data = list(range(10))

# batch_size=5 divides the data evenly: two full batches.
dataloader1 = DataLoader(data, 5)
for batch in dataloader1:
    print(batch)
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]

# batch_size=4 leaves a remainder; drop_last defaults to False, so the
# final short batch is printed as well.
dataloader2 = DataLoader(data, 4)
for batch in dataloader2:
    print(batch)
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9]
方案二:直接利用yield关键字(生成器函数)
def DataLoader(data, batch_size, drop_last=False):
    """Yield consecutive mini-batches (lists) of items taken from ``data``.

    Args:
        data: Any iterable; items are consumed in iteration order.
        batch_size: Number of items per batch; must be positive.
        drop_last: If true, a trailing batch with fewer than ``batch_size``
            items is discarded instead of being yielded.

    Yields:
        Lists of ``batch_size`` items; the last list may be shorter when
        ``drop_last`` is false.

    Raises:
        ValueError: If ``batch_size`` is not positive. Because this is a
            generator, the error is raised when iteration starts.
    """
    # A non-positive size can never be reached by the length check below,
    # which would lump all of the data into one oversized final batch.
    if batch_size <= 0:
        raise ValueError(
            "batch_size must be positive, got {!r}".format(batch_size)
        )

    minibatch = []
    for item in data:
        minibatch.append(item)
        if len(minibatch) == batch_size:
            yield minibatch
            # Start a fresh list so the batch already yielded is not mutated.
            minibatch = []

    # Whatever is left over is an incomplete batch.
    if minibatch and not drop_last:
        yield minibatch
验证该函数的可行性
# Demo of the generator-based loader on ten consecutive integers.
data = list(range(10))

# batch_size=5 divides the data evenly: two full batches.
for batch in DataLoader(data, batch_size=5):
    print(batch)
# [0, 1, 2, 3, 4]
# [5, 6, 7, 8, 9]

# drop_last=True discards the incomplete trailing batch [8, 9].
for batch in DataLoader(data, batch_size=4, drop_last=True):
    print(batch)
# [0, 1, 2, 3]
# [4, 5, 6, 7]

# drop_last=False keeps the incomplete trailing batch.
for batch in DataLoader(data, batch_size=4, drop_last=False):
    print(batch)
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9]