Notes on using Dataset and DataLoader, with image display and explanations of the relevant parameters
1. Displaying images
This example from the official PyTorch tutorial downloads the FashionMNIST dataset:
[1] https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
# Fetch the datasets
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
# Randomly pick images from the training set and display them
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
2. Custom datasets
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
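The map-style Dataset protocol only requires `__len__` and `__getitem__`; the class above implements them by reading image files from disk, but the same protocol works for any in-memory data. A minimal sketch (`ToyDataset` is an illustrative name, not part of the tutorial or of torch):

```python
class ToyDataset:
    """Map-style dataset over in-memory (feature, label) pairs."""
    def __init__(self, features, labels):
        assert len(features) == len(labels)
        self.features = features
        self.labels = labels

    def __len__(self):
        # number of samples; used by samplers and DataLoader
        return len(self.features)

    def __getitem__(self, idx):
        # fetch one sample by index, just like CustomImageDataset
        return self.features[idx], self.labels[idx]

ds = ToyDataset([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], [0, 1, 0])
print(len(ds))  # 3
print(ds[1])    # ([0.3, 0.4], 1)
```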
3. Preparing your data for training with DataLoaders
The dataset in section 2 retrieves only one sample at a time, but training usually requires grouping multiple samples into a batch and shuffling their order to reduce overfitting.
The DataLoader is an iterable that abstracts this complexity for us behind an easy API. To use the DataLoader, we need to set the following parameters:
- data: the training data used to train the model, and the test data used to evaluate it
- batch_size: the number of samples processed in each batch
- shuffle: whether to draw samples in random index order
This produces an iterator: the number of samples returned per step is determined by batch_size, and whether the order is shuffled is determined by the shuffle parameter.
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
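Conceptually, the batching and shuffling that DataLoader performs can be sketched in plain Python. This is a simplified sketch ignoring worker processes and tensor collation, and `simple_loader` is a hypothetical name, not a PyTorch API:

```python
import random

def simple_loader(dataset, batch_size, shuffle=False, seed=None):
    """Yield lists of samples, mimicking DataLoader's batching logic."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # random index order
    for start in range(0, len(indices), batch_size):
        # one batch = the samples at the next batch_size indices
        yield [dataset[i] for i in indices[start:start + batch_size]]

data = [(i, i % 2) for i in range(10)]   # toy (feature, label) pairs
batches = list(simple_loader(data, batch_size=4))
print([len(b) for b in batches])         # [4, 4, 2]: the last batch is smaller
```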
Samples can then be displayed by iterating over the DataLoader:
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
Expected output:
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 5
4. Some DataLoader parameters
Besides batch_size and shuffle, other frequently used parameters include:
- num_workers (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process (default: 0).
- collate_fn: takes a function that processes each batch and produces its return value. In other words, whatever collate_fn returns is exactly what each "batch" is when you iterate over the DataLoader.
Note that collate_fn receives the list of individual samples for a batch, not pre-stacked tensors, so it must assemble the batch itself:
def mycollate(batch):
    # batch is a list of (image, label) tuples, one per sample
    images = torch.stack([sample[0] for sample in batch])
    labels = torch.tensor([sample[1] for sample in batch])
    return {'feature': images, 'label': labels}
from torch.utils.data import DataLoader
myDataloader = DataLoader(training_data, shuffle=True, batch_size=8, collate_fn=mycollate)
Now, when iterating over the dataloader, each batch is the return value of mycollate:
for batch in myDataloader:
print(batch)
What we get is no longer the (train_features, train_labels) tuple,
but {'feature': train_features, 'label': train_labels}.
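The same mechanics can be checked without torch: a collate function is handed the list of raw samples for one batch and may return any structure it likes. Here `dict_collate` and the toy samples are illustrative stand-ins, not part of any API:

```python
def dict_collate(batch):
    # batch: list of (feature, label) tuples drawn from the dataset
    features = [sample[0] for sample in batch]
    labels = [sample[1] for sample in batch]
    return {'feature': features, 'label': labels}

# toy samples standing in for (image, label) pairs
samples = [([1, 2], 0), ([3, 4], 1), ([5, 6], 0)]
out = dict_collate(samples)
print(out['feature'])  # [[1, 2], [3, 4], [5, 6]]
print(out['label'])    # [0, 1, 0]
```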