本文以pytorch1.10进行解读:torch — PyTorch 1.10 documentation
文本的操作在github上都有Shirley-Xie/pytorch_exercise · GitHub,且有运行结果。
1.Dataset和DataLoder介绍
1.1 Dataset
torch.utils.data.
Dataset
(*args, **kwds)
所有表示从键到数据样本映射的数据集都应该将其子类化。所有子类都应该覆盖__getitem__(),支持为给定的键获取数据样本。子类还可以选择性地覆盖__len__(),许多Sampler实现和DataLoader的默认选项都期望它返回数据集的大小。
Dataset定义数据集的内容,类似于列表的数据结构,长度确定,能够用索引获取数据集中的元素。
Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()
这个类方法,作用是接收一个索引, 返回一个样本。
1.2 DataLoader
DataLoader定义了按batch加载数据集的方法,它是一个实现了`__iter__`方法的可迭代对象,每次迭代输出一个batch的数据。Dataset数据集相当于一种列表结构不同,IterableDataset相当于一种迭代器结构。 它更加复杂,一般较少使用。
能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。
在绝大部分情况下,用户只需实现Dataset的`__len__`方法和`__getitem__`方法,就可以轻松构建自己的数据集,并用默认数据管道进行加载。
函数签名如下:
torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
)
常用dataset, batch_size, shuffle, num_workers,pin_memory, drop_last这六个参数。
- - dataset : 数据集
- - batch_size: 批次大小
- - shuffle: 是否乱序
- - sampler: 样本采样函数,一般无需设置。
- - batch_sampler: 批次采样函数,一般无需设置。
- - num_workers: 使用多进程读取数据,设置的进程数。
- - collate_fn: 整理一个批次数据的函数。
- - pin_memory: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。
- - drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。
- - timeout: 加载一个数据批次的最长等待时间,一般无需设置。
- - worker_init_fn: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使用。
一般实现的代码如下:
ds = TensorDataset(torch.randn(1000,3),
torch.randint(low=0,high=2,size=(1000,)).float())
dl = DataLoader(ds,batch_size=4,drop_last = False)
features,labels = next(iter(dl))
print("features = ",features )
print("labels = ",labels )
结果:
features = tensor([[-0.3192, -1.7329, -1.7346],
[-0.7792, 1.2145, -0.5208],
[ 0.5105, -1.4158, 1.0757],
[-1.3785, -1.3909, -0.7086]])
labels = tensor([0., 0., 0., 1.])
2. Dataset和DataLoader操作步骤
获取一个batch数据的步骤
假定数据集的特征和标签分别表示为张量X和Y,数据集可以表示为(X,Y), 假定batch大小为m 。
- 首先我们要确定数据集的长度n。比如:n = 1000。确定数据集的长度由Dataset__len__方法实现的。数据是元组列表,也就是特征和标签。
- 然后我们从0到n-1的范围中抽样出m个数(batch大小)。假定m=4, 拿到的结果是一个索引列表,类似:indices = [1,4,8,9]。从n个中抽出m个数方法由DataLoader的 sampler和 batch_sampler参数指定的。也就是shuffle和drop_last两个参数影响。
- 接着我们从数据集中去取这m个数对应下标的元素。拿到的结果是一个元组列表,类似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]。根据下标取数据集中的元素 是由 Dataset的 __getitem__方法实现的。
- 最后我们将结果整理成两个张量作为输出。拿到的结果是两个张量,类似batch = (features,labels) , 其中 features = torch.stack([X[1],X[4],X[8],X[9]])。labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])。DataLoader的参数collate_fn指定。
总而言之,在一个确定数据集中,按照batch的大小确定索引,然后根据索引取出对应的数据。最后整理成特征和标签在一起的样子。
具体内部方法拆解如下:
# step1: 确定数据集长度 (Dataset的 __len__ 方法实现)
ds = TensorDataset(torch.randn(1000,3),
torch.randint(low=0,high=2,size=(1000,)).float())
print("n = ", len(ds)) # len(ds)等价于 ds.__len__()
# step2: 确定抽样indices (DataLoader中的 Sampler和BatchSampler实现)
sampler = RandomSampler(data_source = ds)
batch_sampler = BatchSampler(sampler = sampler,
batch_size = 4, drop_last = False)
for idxs in batch_sampler:
indices = idxs
break
print("indices = ",indices)
# step3: 取出一批样本batch (Dataset的 __getitem__ 方法实现)
batch = [ds[i] for i in indices] # ds[i] 等价于 ds.__getitem__(i)
print("batch = ", batch)
# step4: 整理成features和labels (DataLoader 的 collate_fn 方法实现)
def collate_fn(batch):
features = torch.stack([sample[0] for sample in batch])
labels = torch.stack([sample[1] for sample in batch])
return features,labels
features,labels = collate_fn(batch)
print("features = ",features)
print("labels = ",labels)
结果:
n = 1000
indices = [426, 137, 471, 292]
batch = [(tensor([1.5614, 0.6875, 1.7250]), tensor(1.)), (tensor([ 0.2853, -1.4416, -0.5672]), tensor(1.)), (tensor([ 0.1800, 0.2652, -0.5301]), tensor(0.)), (tensor([-0.9303, 0.7461, 0.2575]), tensor(1.))]
features = tensor([[ 1.5614, 0.6875, 1.7250],
[ 0.2853, -1.4416, -0.5672],
[ 0.1800, 0.2652, -0.5301],
[-0.9303, 0.7461, 0.2575]])
labels = tensor([1., 1., 0., 1.])
3. 使用Dataset创建数据集
Dataset创建数据集常用的方法有:
- 使用 torch.utils.data.TensorDataset 根据Tensor创建数据集(numpy的array,Pandas的DataFrame需要先转换成Tensor)。
- 使用 torchvision.datasets.ImageFolder 根据图片目录创建图片数据集。
- 继承 torch.utils.data.Dataset 创建自定义数据集。
此外,还可以通过
- torch.utils.data.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集。
- 调用Dataset的加法运算符(+)将多个数据集合并成一个数据集。
此处代码是常见的自定义方法:
class RMBDataset(Dataset):
def __init__(self, data_dir, transform=None):
"""
rmb面额分类任务的Dataset
:param data_dir: str, 数据集所在路径
:param transform: torch.transform,数据预处理
"""
self.label_name = {"1": 0, "100": 1}
self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
self.transform = transform
def __getitem__(self, index):
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.data_info)
@staticmethod
def get_img_info(data_dir):
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
# 遍历图片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = rmb_label[sub_dir]
data_info.append((path_img, int(label)))
return data_info
其中get_img_info做的是拿到数据的位置和标签,也就是元祖列表格式。 有了这个list,然后又给了data_info一个index, data_info[index] 就取出了某个(样本i_loc, label_i)。
__getitem__()
这个方法, 是不是很容易理解了, 第一行我们拿到了一个样本的图片路径和标签。然后第二行就是去找到图片,然后转成RGB数值。 第三行就是做了图片的数据预处理,最后返回了这张图片的张量形式和它的标签。
参考文章:
torch — PyTorch 1.10 documentation
GitHub - lyhue1991/eat_pytorch_in_20_days: Pytorch🍊🍉 is delicious, just eat it! 😋😋系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)_dataloader 输入两个变量-CSDN博客