数据预处理DataLoader及各参数详解
pytorch关于数据处理的功能模块均在torch.utils.data 中,pytorch输入数据PipeLine一般遵循一个“三步走”的策略,操作顺序是这样的:
① 继承Dataset类,自定义数据处理类。必须重载实现len()、getitem()这两个方法。
其中__len__
返回数据集样本的数量,而__getitem__
应该编写支持数据集索引的函数,例如通过dataset[i]
可以得到数据集中的第i+1
个数据。在实现自定义类时,一般需要对图像数据做增强处理,和标签处理,__getitem__返回图像和对应label,图像增强的方法可以使用pytorch自带的torchvision.transforms内模块,也可以使用自定义或者其他第三方增强库。
② 导入 DataLoader类,传入参数(上面自定义类的对象) 创建一个DataLoader对象。
③ 循环遍历这个 DataLoader 对象。将img, label加载到模型中进行训练
dataset = MyDataset() # 第一步:构造Dataset对象
dataloader = DataLoader(dataset)# 第二步:通过DataLoader来构造迭代对象
num_epoches = 100
for epoch in range(num_epoches):# 第三步:逐步迭代数据
for img, label in dataloader:
# 训练代码
pytorch内部默认的数据处理类有如下:
class Dataset(object):
class IterableDataset(Dataset):
class TensorDataset(Dataset): # 封装成tensor的数据集,每一个样本都通过索引张量来获得。
class ConcatDataset(Dataset): # 连接不同的数据集以构成更大的新数据集
class Subset(Dataset): # 获取指定一个索引序列对应的子数据集
class ChainDataset(IterableDataset):
一般能用到的是ConcatDataset, Subset,其他不常用。
可迭代对象的创建方式:
-
方法一:在python中凡是具有
__iter__
的方法的类,都是可迭代的类。可迭代类创建的对象实现了__iter__
方法,因此就是可迭代对象。from collections import Iterable, Iterator class Student(object): def __init__(self, score): self.score = score def __iter__(self): return iter(self.score) # return 返回的是一个迭代器, test = Student([80, 90, 95]) print(isinstance(test, Iterable)) print(isinstance(test, Iterator)) for i in test: # test可迭代对象,但不是迭代器,所以不能next(test) print(i) for i in test: # 重复遍历试试看,是否有结果 print(i) print("============") test=iter(test) # 对可迭代对象使用内建函数iter(), 使之成为为迭代器,此时可以next(test) print(isinstance(test, Iterable)) print(isinstance(test, Iterator)) for i in test: print(i) for i in test: # 对迭代器重复遍历试试看,有结果过没, print(i) # 没有结果 """ True False 80 90 95 80 90 95 ============ True True 80 90 95 """
从本代码可看出Student类创建的对象是可迭代对象,但不是迭代器(因为没有实现
__next__
方法),且可以实现重复遍历,而迭代器是无法重复遍历的! -
方法二:用list、tuple等容器创建的对象,也都是可迭代对象。如:test=[1,2,3], test就是可迭代对象
迭代器的创建方式:迭代器对象必须同时实现__iter__和__next__方法才是迭代器
-
方法一:自定义类实现__iter__和__next__方法,对于迭代器来说,
__iter__
返回的是它自身 self,__next__
则是返回迭代器中的下一个值。class Student(object): def __init__(self, score): self.score = score def __iter__(self): return self # 对于迭代器来说,__iter__ 返回的是它自身self,也就是返回迭代器。 def __next__(self): if self.score < 100: self.score += 1