4_cifar10_tutorial

pytorch-handbook
一篇比较详细的学习博客

训练一个分类器

关于数据?

一般情况下处理图像、文本、音频和视频数据时,可以使用标准的Python包来加载数据到一个numpy数组中。 然后把这个数组转换成 torch.*Tensor

  • 图像可以使用 Pillow, OpenCV
  • 音频可以使用 scipy, librosa
  • 文本可以使用原始Python和Cython来加载,或者使用 NLTK或 SpaCy 处理
    特别的,对于图像任务,我们创建了一个包 torchvision,它包含了处理一些基本图像数据集的方法。这些数据集包括 Imagenet, CIFAR10, MNIST 等。除了数据加载以外,torchvision 还包含了图像转换器, torchvision.datasetstorch.utils.data.DataLoader

torchvision包不仅提供了巨大的便利,也避免了代码的重复。

在这个教程中,我们使用CIFAR10数据集,它有如下10个类别 :‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’。CIFAR-10的图像都是 3x32x32大小的,即,3颜色通道,32x32像素。
在这里插入图片描述

训练一个图像分类器

依次按照下列顺序进行:

  1. 使用torchvision加载和归一化CIFAR10训练集和测试集
  2. 定义一个卷积神经网络
  3. 定义损失函数
  4. 在训练集上训练网络
  5. 在测试集上测试网络

1. 读取和归一化 CIFAR10

使用torchvision可以非常容易地加载CIFAR10。

import torch
import torchvision
import torchvision.transforms as transforms

torchvision的输出是[0,1]的PILImage图像,我们把它转换为归一化范围为[-1, 1]的张量。

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 这里我直接离线下载了,链接下载真是太慢了。
# Using downloaded and verified file: ./data\cifar-10-python.tar.gz
# Extracting ./data\cifar-10-python.tar.gz to ./data
# Files already downloaded and verified

我们展示一些训练图像。

import matplotlib.pyplot as plt
import numpy as np

# 展示图像的函数

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# 获取随机数据
dataiter = iter(trainloader)
images, labels = dataiter.next()

# 展示图像
imshow(torchvision.utils.make_grid(images))
plt.show()
# 显示图像标签
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

报错1:import matplotlib.pyplot as plt
解决方案1:conda activate pytorchconda install matplotlib,重启pycharm

报错2:

RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.
        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:
            if __name__ == '__main__':
                freeze_support()
                ...
        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

解决方案2:将代码放在if __name__ == '__main__':

运行代码如下:

import matplotlib.pyplot as plt
import numpy as np

# 展示图像的函数
def imshow(img):
    img = img / 2 + 0.5         # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

if __name__ == '__main__':
    
    # 获取随机数据
    dataiter = iter(trainloader)
    images, labels = dataiter.next()

    # 展示图像
    imshow(torchvision.utils
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值