PyTorch CNN Image Preparation Code Project - Learn to Extract, Transform, Load (ETL)

The project (Bird’s-eye view)

There are four general steps that we’ll be following as we move through this project:
1. Prepare the data
2. Build the model
3. Train the model
4. Analyze the model's results

The ETL process

  • Extract data from a data source
  • Transform data into a desirable format
  • Load data into a suitable structure

PyTorch imports

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

The next imports are standard packages used for data science in Python:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix
#from plotcm import plot_confusion_matrix

import pdb

torch.set_printoptions(linewidth=120)

Note that pdb is the Python debugger, the commented-out import is a local file we'll introduce in a future post for plotting the confusion matrix, and the last line sets the print options for PyTorch print statements.

Preparing our data using PyTorch

Extract – Get the Fashion-MNIST image data from the source.
Transform – Put our data into tensor form.
Load – Put our data into an object to make it easily accessible.


PyTorch Dataset class

train_set = torchvision.datasets.FashionMNIST(
    root='./data'  # downloads into a ./data folder in the current directory; the folder is created if it doesn't exist
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

Note that the root argument used to be './data/FashionMNIST'; it has since changed due to torchvision updates.
Since we want our images to be transformed into tensors, we use the built-in transforms.ToTensor() transformation, and since this dataset is going to be used for training, we’ll name the instance train_set.
When we run this code for the first time, the Fashion-MNIST dataset will be downloaded locally. Subsequent calls check for the data before downloading it. Thus, we don’t have to worry about double downloads or repeated network calls.

PyTorch DataLoader class

train_loader = torch.utils.data.DataLoader(train_set
    ,batch_size=1000
    ,shuffle=True
)

The DataLoader above takes the following parameters:
  • batch_size (1000 in our case)
  • shuffle (True in our case)
  • num_workers (default is 0, which means the main process will be used; see the sketch below)
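
num_workers is not used for this project's loader, but as a minimal sketch (the value 2 below is an arbitrary choice, not from the original post), parallel loading would look like this:

train_loader = torch.utils.data.DataLoader(train_set
    ,batch_size=1000
    ,shuffle=True
    ,num_workers=2  # two worker processes load batches in the background
)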

Exploring the data

To see how many images are in our training set, we can check the length of the dataset using the Python len() function:

> len(train_set)
60000

This 60000 number makes sense based on what we learned in the post on the Fashion-MNIST dataset.
Suppose we want to see the labels for each image. This can be done like so:

> train_set.targets
tensor([9, 0, 0, ..., 3, 0, 5])

The first image is a 9 and the next two are zeros. Remember from past posts that these values encode the actual class name or label: the 9, for example, is an ankle boot while the 0 is a t-shirt.
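If you'd rather not memorize the mapping, recent torchvision versions expose the class names directly on the dataset (a quick check, assuming the classes attribute is present in your torchvision release):

> train_set.classes
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
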
If we want to see how many of each label exists in the dataset, we can use the PyTorch bincount() function like so:

> train_set.targets.bincount()
tensor([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000])

Class imbalance: Balanced and unbalanced datasets

This shows us that the Fashion-MNIST dataset is uniform with respect to the number of samples in each class: we have 6000 samples for each class. As a result, this dataset is said to be balanced. If the classes had a varying number of samples, we would call the set an unbalanced dataset.
Class imbalance is a common problem, but in our case, we have just seen that the Fashion-MNIST dataset is indeed balanced, so we need not worry about that for our project.
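As a minimal sketch (not from the original post), the same bincount() call can be turned into a quick balance check for any classification dataset loaded this way:

counts = train_set.targets.bincount()
# a dataset is balanced when every class has the same number of samples
if counts.min() == counts.max():
    print('balanced:', counts.min().item(), 'samples per class')  # 6000 for Fashion-MNIST
else:
    print('unbalanced, counts per class:', counts.tolist())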

Accessing data in the training set

To access an individual element from the training set, we first pass the train_set object to Python's iter() built-in function, which returns an object representing a stream of data.
With the stream of data, we can use Python built-in next() function to get the next data element in the stream of data. From this we are expecting to get a single sample, so we’ll name the result accordingly:

> sample = next(iter(train_set))
> len(sample)
2

After passing the sample to the len() function, we can see that the sample contains two items, and this is because the dataset contains image-label pairs.
Each sample we retrieve from the training set contains the image data as a tensor and the corresponding label.
Since the sample is a sequence type, we can use sequence unpacking to assign the image and the label. We'll now check the types of the image and the label; the image is a torch.Tensor, while the label's type depends on the torchvision version:

> image, label = sample

> type(image)
torch.Tensor

# Before torchvision 0.2.2
> type(label)
torch.Tensor

# Starting at torchvision 0.2.2
> type(label)
int

We’ll check the shape to see that the image is a 1 x 28 x 28 tensor while the label is a scalar valued tensor:

> image.shape
torch.Size([1, 28, 28]) 

> torch.tensor(label).shape
torch.Size([])

> image.squeeze().shape
torch.Size([28, 28])

Note the difference between a one-element tensor shape, torch.Size([1]), and a scalar shape, torch.Size([]).
torch.Size([0]) means a tensor of this size is 1-dimensional but has no elements.
Contrast this to a tensor of size torch.Size([1]), which means it is 1 dimensional and has one element.
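
A small illustrative snippet (not from the original post) showing all three shapes side by side:

> torch.empty(0).shape    # 1-dimensional, zero elements
torch.Size([0])

> torch.tensor([9]).shape # 1-dimensional, one element
torch.Size([1])

> torch.tensor(9).shape   # scalar: 0-dimensional
torch.Size([])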

Let's plot the image now, and we'll see why we squeezed the tensor in the first place. We first squeeze the tensor and then pass it to the imshow() function. If we don't squeeze it, imshow raises an error because the image still has an extra channel dimension.

> plt.imshow(image.squeeze(), cmap="gray")
> torch.tensor(label)
tensor(9)

PyTorch DataLoader: Working with batches of data

We’ll start by creating a new data loader with a smaller batch size of 10 so it’s easy to demonstrate what’s going on:

> display_loader = torch.utils.data.DataLoader(
    train_set, batch_size=10
)

There is one thing to notice when working with the data loader. If shuffle=True, then the batch will be different each time a call to next occurs. With shuffle=False, the first samples in the training set are returned on the first call to next. The shuffle functionality is turned off by default.
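
A small comparison (illustrative, not from the original post) makes the difference concrete; the ordered loader always returns the first ten samples, while the shuffled loader returns a different batch from each fresh iterator:

# shuffle=False (the default): the first batch is always the first ten samples
ordered_loader = torch.utils.data.DataLoader(train_set, batch_size=10, shuffle=False)
_, labels = next(iter(ordered_loader))
print(labels)  # tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5]) every time

# shuffle=True: each new iterator yields a different random batch
shuffled_loader = torch.utils.data.DataLoader(train_set, batch_size=10, shuffle=True)
_, labels = next(iter(shuffled_loader))
print(labels)  # labels differ from run to run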

# note that each batch will be different when shuffle=True
> batch = next(iter(display_loader))
> print('len:', len(batch))
len: 2

Let’s unpack the batch and take a look at the two tensors and their shapes:

> images, labels = batch

> print('types:', type(images), type(labels))
> print('shapes:', images.shape, labels.shape)
types: <class 'torch.Tensor'> <class 'torch.Tensor'>
shapes: torch.Size([10, 1, 28, 28]) torch.Size([10])

Since batch_size=10, we know we are dealing with a batch of 10 images and 10 corresponding labels.
The size of each dimension in the tensor that contains the image data is defined by each of the following values: (batch size, number of color channels, image height, image width)
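
A tiny check (not from the original post) that reads those four dimensions back out of the batch tensor:

> b, c, h, w = images.shape
> print(b, c, h, w)
10 1 28 28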

> images[0].shape
torch.Size([1, 28, 28])

> labels[0]
tensor(9)

To plot a batch of images, we can use the torchvision.utils.make_grid() function to create a grid that can be plotted like so:

> grid = torchvision.utils.make_grid(images, nrow=10)

> plt.figure(figsize=(15,15))
> plt.imshow(np.transpose(grid, (1,2,0)))
# imshow expects the basic image shape [height, width, channels], so we move the channel axis to the end
> print('labels:', labels)
labels: tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])

Another way to do this:

> grid = torchvision.utils.make_grid(images, nrow=10)

> plt.figure(figsize=(15,15))
> plt.imshow(grid.permute(1,2,0))

> print('labels:', labels)
labels: tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])


How to Plot Images Using PyTorch DataLoader

Here is another way to plot the images using the PyTorch DataLoader.

how_many_to_plot = 20

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=1, shuffle=True
)

mapping = {
    0:'Top', 1:'Trousers', 2:'Pullover', 3:'Dress', 4:'Coat'
    ,5:'Sandal', 6:'Shirt', 7:'Sneaker', 8:'Bag', 9:'Ankle Boot'
}

plt.figure(figsize=(50,50))
for i, batch in enumerate(train_loader, start=1):
    image, label = batch
    plt.subplot(10,10,i)
    fig = plt.imshow(image.reshape(28,28), cmap='gray')
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.title(mapping[label.item()], fontsize=28)
    if (i >= how_many_to_plot): break
plt.show()