3. The Image Classification Dataset
import torch
import torchvision
from torchvision import transforms
from torch.utils import data
import matplotlib.pyplot as plt
%matplotlib inline
3.1 Reading the Dataset
Download the Fashion-MNIST dataset and read it into memory using the framework's built-in functions:
# ToTensor turns a ``PIL Image`` or ``numpy.ndarray`` into a tensor;
# the same transform instance is shared by both splits.
trans = transforms.ToTensor()
minst_train = torchvision.datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=trans,
)
minst_test = torchvision.datasets.FashionMNIST(
    root="./data",
    train=False,
    download=True,
    transform=trans,
)
Fashion-MNIST consists of images from 10 categories,
each represented by 6000 images in the training dataset and by 1000 in the test dataset.
The height and width of each input image are both 28 pixels, and each image has a single (grayscale) channel, so every image tensor has shape `(1, 28, 28)`.
# Number of examples in the training and test splits.
print(*map(len, (minst_train, minst_test)))
# Shape of the first training image (index 0 of the first example).
print(minst_train[0][0].size())
60000 10000
torch.Size([1, 28, 28])
The following function converts numeric label indices into their text names:
def get_fashion_mnist_labels(labels):
    """Map Fashion-MNIST class indices to their human-readable names.

    Args:
        labels: Iterable of class indices; each element must be convertible
            with ``int()`` (plain ints or scalar tensors both work).

    Returns:
        A list of text labels, one per input index, in the same order.
    """
    names = [
        't-shirt', 'trouser', 'pullover', 'dress', 'coat',
        'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot',
    ]
    return [names[int(idx)] for idx in labels]
Define a function that plots a list of images on a grid, then use it to visualize the first 18 training examples:
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images on a ``num_rows`` x ``num_cols`` grid.

    Args:
        imgs: Iterable of images. Torch tensors are converted to NumPy
            arrays before plotting; anything else (e.g. a PIL Image) is
            passed to ``imshow`` unchanged.
        num_rows: Number of rows in the grid.
        num_cols: Number of columns in the grid.
        titles: Optional sequence of subplot titles, indexed in step with
            ``imgs``.
        scale: Width/height of each grid cell (matplotlib ``figsize`` units,
            i.e. inches).

    Only the first ``num_rows * num_cols`` images are drawn; extra images
    are ignored, and unused grid cells are left untouched.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always yields a 2-D array of Axes, so flatten() also works
    # for a 1x1 grid, where subplots would otherwise return a single Axes.
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    for i, (ax, img) in enumerate(zip(axes.flatten(), imgs)):
        if torch.is_tensor(img):
            # detach()/cpu() allow plotting tensors that require grad or live
            # on an accelerator; for plain CPU tensors the array is the same.
            ax.imshow(img.detach().cpu().numpy())
        else:
            # e.g. a PIL Image, which imshow accepts directly
            ax.imshow(img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    plt.show()
# Take the first (unshuffled) batch of 18 training examples.
first_batch_loader = data.DataLoader(minst_train, batch_size=18)
X, y = next(iter(first_batch_loader))
# Drop the singleton channel axis: (18, 1, 28, 28) -> (18, 28, 28).
show_images(X.squeeze(1), num_rows=2, num_cols=9,
            titles=get_fashion_mnist_labels(y))
![](https://i-blog.csdnimg.cn/blog_migrate/94c642e4bcee00797ca56e4157247200.png)
3.2 Reading a Minibatch
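Use the built-in data iterator to read minibatches. Here we build the (shuffled) iterator for the training set; the test-set iterator is built the same way, without shuffling, in Section 3.3: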
def get_dataloader_workers():
    """Return how many worker processes the DataLoader should use."""
    num_workers = 4
    return num_workers
batch_size = 256
# Shuffled minibatch iterator over the training split.
train_iter = data.DataLoader(
    minst_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=get_dataloader_workers(),
)
3.3 Putting All Things Together
def get_dataloader_workers():
    """Return the number of DataLoader worker processes (fixed at 4)."""
    workers = 4
    return workers
def load_data_fashion_mnist(batch_size, resize=None, root="./data"):
    """Download Fashion-MNIST and wrap both splits in DataLoaders.

    Args:
        batch_size: Number of examples per minibatch.
        resize: Optional target size passed to ``transforms.Resize``; when
            falsy, images keep their original size.
        root: Directory the dataset is downloaded to / read from.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader
        shuffles its examples; the test loader does not.
    """
    steps = [transforms.ToTensor()]
    if resize:
        # Resize runs before ToTensor, on the raw image.
        steps.insert(0, transforms.Resize(resize))
    pipeline = transforms.Compose(steps)

    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=pipeline, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=pipeline, download=True)

    num_workers = get_dataloader_workers()
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=num_workers),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=num_workers))
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
# Peek at a single training minibatch to check shapes and dtypes.
X, y = next(iter(train_iter))
print(X.shape, X.dtype, y.shape, y.dtype)
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64