深度学习笔记9_Softmax回归_图像分类(李沐,pytorch)

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()
SVG是一种无损格式 – 意味着它在压缩时不会丢失任何数据,可以呈现无限数量的颜色。
d2l.use_svg_display() 
意思是使用svg来显示图片,这样清晰度高一些。
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0到1之间
trans = transforms.ToTensor()   
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

trans = transforms.ToTensor()  暂且按下不表。看名字大概就是把数据转换成tensor。

torchvision.datasets.FashionMNIST是PyTorch自带的读取MNIST的库,其关键字如下:

  • root (string) – Root directory of dataset where FashionMNIST/raw/train-images-idx3-ubyte and FashionMNIST/raw/t10k-images-idx3-ubyte exist.

用于储存训练数据(以上两个文件)的目录

这里如果直接从资源管理器复制目录会出现以下错误:

SyntaxError: (unicode error) 'unicodeescape' codec can't decode bytes in position 2-3: truncated \xXX escape

百度可知,是因为字符串里面的\右斜杠被识别为转义字符

解决方法有两种:

1.在字符串前,增加r,保持字符串的原始含义

root=r"../data"

2.把右斜杠\改成左斜杠/ 即可

  • train (booloptional) – If True, creates dataset from train-images-idx3-ubyte, otherwise from t10k-images-idx3-ubyte.

这里是用布尔值指定是训练集(train-images-idx3-ubyte)还是测试集(t10k-images-idx3-ubyte

  • download (booloptional) – If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

如果实现准备好了数据集,可以把这个布尔值指定为False,否则就自动下载一份数据集到目录

这里本地下载文件非常麻烦,而且可直接下载,无需特别手段,所以可以指定好目录直接下载。

若非要下载,首先文件会存储在...\data\FashionMNIST\raw目录下,而且必须有八个文件存在(四个数据包和他们的解压文件)才能通过检测。

  • transform (callableoptional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop

与上文的trans = transforms.ToTensor() 对上了

  • target_transform (callableoptional) – A function/transform that takes in the target and transforms it.

print(len(mnist_train), len(mnist_test))
print(mnist_train[0][0].shape)

这里下面的输出为

torch.Size([1, 28, 28])

这表明这第一张图片是一个黑白图片,channel数为1,长宽都是28px

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

get_fashion_mnist_labels(labels)函数

可以输入labels为一个数组,其中的每一项都是0-9的数字,对应标签中的10项,然后输出一个数组,其中每一项都是一个表示类别的字符串

show_images(imgs, num_rows, num_cols, titles=None, scale=1.5)函数

这个函数是用来输出个体对应的图片。

img是输入的图片;

num_rows, num_cols是输出图片时,想把图片输出的行列数,如num_rows=2, num_cols=9就是输出2行9列图片;

titles是显示在图片上方的标签,用于把数据集的标签显示出来;

scale是缩放比例,可以把输入的图片(28*28太小了,放大一些会比较好)进行放缩在输出。

函数内部的流程:

1.先把输出图像的尺寸定下来,也就是num_cols * scale, num_rows * scale,输出给figsize

2.创建画图用的画布。这里注意_, axes 是两个变量!但是_变量是作为占位,不被使用。

这里使用了subplots这个命令,用法:

fig, ax = plt.subplots()这样使用,其中要把图和锚(规定图像的坐标轴大小、图片数量等)一起传输进去。

.flatten是把二维数组变成一维数组,不清楚为什么要这么干

询问了一下Bing,这是为了遍历数组中每个图的坐标。

在下面的这个循环中,i是每个子图的索引,后面的(ax,img)则是每个子图的内容。为了实现这个功能,使用了zip和enumerate两个函数,他们的用法如下:

zip(a,b)可以把两个数组打包为一个数组,例如:a=[a,b,c,d],b=[1,2,3,4],那么list(zip(a,b))=[(a,1),(b,2),(c,3),(d,4)]

enumerate是为了一次遍历两个元素,可以把i作为索引,而后面的(ax,img)是每项的内容。

my_list = ['a', 'b', 'c']
for i, value in enumerate(my_list):
    print(f'Index: {i}, Value: {value}')

输出结果为:

Index: 0, Value: a
Index: 1, Value: b
Index: 2, Value: c

 然后,就把图片展示出来:

在接下来的循环中,判断图片img是否是张量,如果是,则将其转换为 NumPy 数组并使用 imshow 函数绘制;否则,直接使用 imshow 函数绘制。

接下来分别获取了axes的x、y轴,并将其设置为不可见。

最后函数返回axes的值

简单记录一下自己对DataLoader的理解:

DataLoader是torch.utils.data模块的一个类,它的用法是:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

其中比较重要的:

dataset:就是输入的数据集 
batch_size:顾名思义,一次读取的数据个数
shuffle:(数据类型 bool)是否打乱数据?设置为True就会打乱
 

timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'

这里d2l.Timer只是一个计时器,可以使用timer.start()和timer.stop()来记录这两个命令时间代码运行的时间。

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值