(2) Linear Neural Networks -- 3. The Image Classification Dataset

3. The Image Classification Dataset

import torch
import torchvision
from torchvision import transforms
from torch.utils import data

import matplotlib.pyplot as plt
%matplotlib inline

3.1 Reading the Dataset

Download the Fashion-MNIST dataset and read it into memory using the framework's built-in functions:

# Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor
trans = transforms.ToTensor()

minst_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)
minst_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)

Fashion-MNIST consists of images from 10 categories,
each represented by 6000 images in the training dataset and by 1000 in the test dataset.

The height and width of each input image are both 28 pixels.

print(len(minst_train), len(minst_test))
print(minst_train[0][0].shape)
60000 10000
torch.Size([1, 28, 28])

Convert between numeric label indices and their names in text:

def get_fashion_mnist_labels(labels):
    """Return text labels for the Fashion-MNIST dataset."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
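The function above maps indices to names. The reverse direction can be sketched in the same way. Note that `get_fashion_mnist_indices` and `FASHION_MNIST_TEXT_LABELS` are hypothetical names introduced here for illustration; only the label list itself comes from the code above.

```python
# Same label order as in get_fashion_mnist_labels
FASHION_MNIST_TEXT_LABELS = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                             'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

def get_fashion_mnist_indices(names):
    """Return numeric label indices for Fashion-MNIST text labels (hypothetical helper)."""
    name_to_index = {name: i for i, name in enumerate(FASHION_MNIST_TEXT_LABELS)}
    return [name_to_index[name] for name in names]

print(get_fashion_mnist_indices(['ankle boot', 'trouser', 'bag']))  # [9, 1, 8]
```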

Visualize examples:

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Tensor Image
            ax.imshow(img.numpy())
        else:
            # PIL Image
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    # return axes
    plt.show()
X, y = next(iter(data.DataLoader(minst_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))

3.2 Reading a Minibatch

Use the built-in data iterator to read the data in minibatches (shown here for the training set):

def get_dataloader_workers():
    """Use 4 processes to read the data."""
    return 4

batch_size = 256
train_iter = data.DataLoader(minst_train, batch_size=batch_size, shuffle=True, num_workers=get_dataloader_workers())
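A rough sketch of how the iterator splits a dataset into minibatches, using a synthetic `TensorDataset` as a stand-in for Fashion-MNIST. The 1000-example size and the all-zero images are assumptions made for illustration, and `num_workers=0` keeps loading in the main process so the snippet runs anywhere.

```python
import torch
from torch.utils import data

# Synthetic stand-in: 1000 blank 1x28x28 "images" with random labels in 0..9
features = torch.zeros(1000, 1, 28, 28)
labels = torch.randint(0, 10, (1000,))
dataset = data.TensorDataset(features, labels)

# Same batch size as in the text; loading happens in the main process here
loader = data.DataLoader(dataset, batch_size=256, shuffle=True, num_workers=0)

# Number of examples in each minibatch of one pass over the data
batch_sizes = [X.shape[0] for X, y in loader]
print(len(loader), batch_sizes)
```

Because 1000 is not a multiple of 256, the last minibatch is smaller than the others. Passing `drop_last=True` to `DataLoader` would discard it instead.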

3.3 Putting All Things Together

def get_dataloader_workers():
    return 4

def load_data_fashion_mnist(batch_size, resize=None):
    """Download Fashion-MNIST and return (train_iter, test_iter) data loaders."""
    trans = [transforms.ToTensor()]

    if resize:
        # Resize goes first so that it runs before the conversion to a tensor
        trans.insert(0, transforms.Resize(resize))

    trans = transforms.Compose(trans)

    mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)

    return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64



