介绍
pytorch中,我们可以使用torch.utils.data.DataLoader
和torch.utils.data.Dataset
加载数据集,具体来说,可以简单理解为Dataset是数据集,他提供数据与索引之间的映射,同时也要有标签。而DataLoader是将Dataset中的数据迭代提取出来,从而能够提供给模型。
所以,具体流程是,我们应该先按照要求先建立一个Dataset,之后再建立一个DataLoader,然后就可以用了。
pytorch中有很多现成的数据集,我们下载就可以使用。但是更多时候我们要建立自己的数据集,我也是入门,所以先建立一个带标签的图像数据集。
参考
- DATASETS & DATALOADERS
- 两文读懂PyTorch中Dataset与DataLoader(一)打造自己的数据集
- 从0开始撸代码–手把手教你搭建AlexNet网络模型训练自己的数据集(猫狗分类
建立Dataset
我们可以继承torch.utils.data.Dataset
类,必须要重写__init__
, __len__
, 和 __getitem__
这三个函数。其中 __len__
能够返回我们数据集中的数据个数,__getitem__
能够根据索引返回数据。
前提
我们有一个文件夹,里面有很多猫、狗和汽车的照片,此外有一个csv文件,里面是每张照片对应的类别,也就是标签。我们根据这个照片文件夹和csv文件,来建立我们的带标签数据集。
- 对于图片文件夹:0——29张图片为猫,30——59张图片为狗,其他为汽车。
- 对于标签csv文件,每一行中首先是图片名,然后是类别。其中0代表猫,1代表狗,2代表汽车。如下图:
具体代码
import os
from torchvision.io import read_image
import pandas as pd
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import numpy as np
class myImageDataset(Dataset):
def __init__(self, img_dir, img_label_dir, transform=None):
super().__init__()
self.img_dir = img_dir
self.img_labels = pd.read_csv(img_label_dir) # 这是一个dataframe,0是文件名,1是类别
self.transform = transform
def __len__(self):
return len(self.img_labels) # 数据集长度
def __getitem__(self, index):
# 拼接得到图片文件路径
# 例如img_dir为'D:/curriculum/2022learning/learnning_dataset/data/'
# img_labels.iloc[index, 0]为5.jpg
# 那么img_path为'D:/curriculum/2022learning/learnning_dataset/data/5.jpg'
img_path = os.path.join(self.img_dir + self.img_labels.iloc[index, 0])
image = read_image(img_path) # tensor类型
label = self.img_labels.iloc[index, 1]
if self.transform is not None:
image = self.transform(image) # 对图片进行某些变换
return image, label
代码中都有注释。
__init__()
类的初始化函数,其中img_dir
为图片文件夹的根目录,img_label_dir
为标签文件路径,transform
为对数据项进行的变换。
__len__()
返回数据集长度。
__getitem__()
根据index
,返回其在数据集中对应的数据和标签。
验证
通过如下代码,我们具体输出一张图片:
# 把图片对应的tensor调整维度,并显示
def tensorToimg(img_tensor):
img = img_tensor.numpy()
img = np.transpose(img_tensor, [1, 2, 0])
plt.imshow(img)
label_dic = {0: 'cat', 1: 'dog', 2: 'car'}
label_path = 'D:/curriculum/2022learning/learnning_dataset/labels.csv'
img_root_path = 'D:/curriculum/2022learning/learnning_dataset/data/'
dataset = myImageDataset(img_root_path, label_path)
image, label = dataset.__getitem__(33)
print(image.shape)
print(label_dic[label])
tensorToimg(image)
可以看到,数据集中,图片变为tensor,维度为[通道数,长,宽]。
DataLoader
之后就可以使用DataLoader对刚刚创建的数据集不断取出样本了。不再赘述。
dataloader = DataLoader(dataset, batch_size=5, shuffle=True)
这样,我们就建立了一个dataLoader。接下来我们输出一下看看:
for imgs, labels in dataloader:
print(imgs.shape)
print(labels)
break
但是这里报错:stack expects each tensor to be equal size, but got [3, 268, 320] at entry 0 and [3, 480, 370] at ...
,查询得知是数据集中图片大小不一,而这时Dataset中定义的参数transfom就派上了用场。我们让每张图片的大小都是224*224
。
from torch.utils.data import DataLoader
from torchvision import transforms
transform = transforms.Resize((224, 224))
dataset = myImageDataset(img_root_path, label_path, transform)
dataloader = DataLoader(dataset, batch_size=5, shuffle=True)
for imgs, labels in dataloader:
print(imgs.shape)
print(labels)
break
结果为:
torch.Size([5, 3, 224, 224])
tensor([0, 2, 2, 2, 1])
由于batch_size
是5,而每个图片的形状为[3, 224, 224]
,因此一个batch的数据形状为:[5, 3, 224, 224]
。
其他使用DataLoader的方法
for index, (imgs, labels) in enumerate(dataloader):
print(index)
print(imgs.shape)
print(labels)
break
结果为:
0
torch.Size([5, 3, 224, 224])
tensor([1, 2, 0, 1, 1])
imgs, label = next(iter(dataloader))
print(imgs.shape)
print(labels)
结果为:
torch.Size([5, 3, 224, 224])
tensor([1, 2, 0, 1, 1])
得到了一批的图片和对应的标签,我们就能将其输入到模型中,并使用标签和预测结果计算损失。