同样是跟着Tutorial学的,博客主要是给自己看笔记。其他人首次学习可能还是直接看Tutorials效果更好一点。
Pytorch官方Totorial Datasets & DataLoaders
数据集
Pytorch提供了两个数据基元(不知道这样翻译准不准确,原文是data primitives
)分别是torch.utils.data.DataLoader
和torch.utils.data.Dataset
,这两个基元允许你使用(pytorch)预先加载好的数据和你自己的数据。其中,后者Dataset
存储着样本和对应的标签,DataLoader
在Dataset
外封装一个可迭代对象,使我们方便获取样本。
另外Pytorch还提供了一些继承自torch.utils.data.Dataset
的预加载好的数据(如FashionMNIST
),这些数据本质上就是那个XXX.Datase
t的子类,而且有很多方法。这些数据可以用来训练和测试我们的模型。
上面说的可能有点抽象,而且解释得不是很清楚,给个实在点的例子。
从TorchVision中加载Fasion-MNIST数据集:
Fasion-MNIST里有六万个训练样本和一万个训练样本,每个都是28*28灰度图像,共被分成十类。
加载FashionMNIST Dataset需要以下参数:
root:我们所训练/测试数据的路径
train:指定训练集
download=True:如果root中没有我们想要的数据,则从互联网上下载数据集。
transform和target_transform:指定标签和数据的转换。(这里transform可能有点模糊,在下一章中有transform的介绍)
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
这样之后数据就会被下载到同一个目录下的"data"文件中了。
数据的可视化
我们也可以将这些数据集可视化:
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
这串代码使接在上一串代码下面的,简单解释一下它在干嘛。
label_map
是用来索引的。
sample_idx
是随机选中训练集中所有照片中的其中一个,后面的.item()方法在pytorch官方tutorial中的tensor板块有介绍,但在我的上一份博客中没有讲,简单来说就是把一个1*1大小的tensor格式的数据转化成一个python类型的数据,如float32,float16,int16啥的。在这段代码中,转变成python的数据后就可以用来索引了。
training_data[0]
是图片的像素信息,training_data[1]
是图片的类别,用它通过label_map进行索引得到label
稍微解释一下再看一下代码就好理解了。
最后得到的图片也是随机的,随便放两个:
自定义数据集
本质上就是自定义类,但由于作者非cs科班学生,pyfthon并没有仔细学过类,自学C++最后由于时间原因中道崩殂,这里可能无法讲得太仔细。
自定义的数据集必须包含三个类方法: __init__
,__len__
和__getitem__
另外在定义类之前,需要把图像信息放在img_dir
文件中,数据情况放在annotations_file
中
先上代码:
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
下面一个一个看过去,首先是init方法:
每次实例化一个数据集对象,__init__
方法都会被调用。它初始化了图像所在文件img_dir,标签文件annotations_file和转换方式(下一节中有详细介绍)
len方法:
len方法比较简单,这还看不懂可以直接入土了。
getitem方法:
getitem方法如其名,作用就是通过给的idx,加载和返回图片和标签。第一句img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
通过os库确定索引的图片所在的路径。其中iloc是panda库里的一个索引方法。self.img_labels.iloc[idx,0]返回的是idx图像的文件名。利用os库的join把img_dir和后面得到的图像文件名给链接起来,就得到了该图像的路径。
第二句image = read_image(img_path)
获得图像信息。
第三句label = self.img_labels.iloc[idx, 1]
得到图像的标签,同样是用的pandas库里的方法。
下面几句就不解释啦。
DataLoaders
Dataset能一次性获取所有图像的特征(就是图像数据,以后都叫特征啦,features)和标签。但在训练时,我们常常需要把样本切分成几个batch分堆送去训练,对每一次epoch都会对数据重新洗牌,来防止过拟合。
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
现在把数据装进train_dataloader和test_dataloader之后,需要把他们用迭代的方式取出。
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
最终打印的结果为:
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 3
因为上面设置中shuffle=True所以我们每次得到的label和图像都是不一样的。