在torch中获取训练数据常用到迭代类,有以下两种实现方式,self.info相当于处理好的所有数据,一般是一个list。
**
1. 类中使用__getitem__方法
**
class Person:
def __init__(self):
self.info = [[i] for i in range(10)]
def __getitem__(self, index):
print(index) # index会自动加
return self.info[index]
p = Person()
for item in p:
print(item, 666)
执行结果
**
2. 类中使用__iter__和__next__方法
**
class Person:
def __init__(self):
self.info = [[i] for i in range(10)]
self.index = -1
def __iter__(self):
print('iter')
return self
def __next__(self):
self.index += 1
if self.index < len(self.info):
return self.info[self.index]
else:
raise StopIteration('end')
p = Person()
print(888)
for item in p:
print(item, 999)
执行结果
从执行结果可以看出执行迭代之前会调用一次__iter__,之后才开始迭代;另外,不同于第一种方式,需要自定义索引