Pytorch学习(二)使用 torchvision

训练图像分类器

官方教程
我们将按顺序执行以下步骤:
1.使用使用 torchvision
2.定义卷积神经网络
3.定义损耗函数
4.根据培训数据对网络进行训练
5.在测试数据上测试网络
这篇博文为第一步

准备数据集

在训练神经网络前,必须有数据。可以使用以下几个数据提供源。
准备图片数据集
一、CIFAR-10CIFAR-10
二、ImageNetImageNet
三、ImageFolderImageFolder
四、LSUN ClassificationLSUN Classification
五、COCO (Captioning and Detection)COCO

torchvision

为了方便加载以上五种数据库的数据,torchvision 是 PyTorch 中专门用来处理图像的库。使用torchvision就可以轻松实现数据的加载和预处理。
我们以使用CIFAR10为例:
在这里插入图片描述

一. 导入torchvision的库

import torchvision
import torchvision.transforms as transforms   # transforms用于数据预处理

二. 使用datasets.CIFAR10()函数加载数据库

CIFAR10有60000张图片,其中50000张是训练集,10000张是测试集。

#训练集,将相对目录./data下的cifar-10-batches-py文件夹中的全部数据
#(50000张图片作为训练数据)加载到内存中,若download为True时,会自动从网上下载数据并解压
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
download=False, transform=None)

下面简单讲解root、train、download、transform这四个参数
1.root,表示cifar10数据的加载的相对目录
2.train,表示是否加载数据库的训练集,false的时候加载测试集
3.download,表示是否自动下载cifar数据集
4.transform,表示是否需要对数据进行预处理,none为不进行预处理
由于美帝路途遥远,靠命令台进程下载100多M的数据速度很慢,所以我们可以自己去到cifar10的官网上把CIFAR-10 python version下载下来,然后解压为cifar-10-batches-py文件夹,并复制到相对目录./data下。(若设置download=True,则程序会自动从网上下载cifar10数据到相对目录./data下,但这样小伙伴们可能要等一个世纪了),并对训练集进行加载(train=True)
在脚本文件下建一个data文件夹,然后把数据集文件夹丢到里面去就好了,注意cifar-10-batches-py文件夹名字不能自己任意改。
接下来看一下trainset的大小

print len(trainset)
#结果:50000

在这里插入图片描述

三. DataLoader用多进程加速batch data的处理

我们在训练神经网络时,使用的是mini-batch(一次输入多张图片),所以我们在使用一个叫DataLoader的工具为我们将50000张图分成每四张图一分,一共12500份的数据包。

#将训练集的50000张图片划分成12500份,每份4张图,用于mini-batch输入。
#shffule=True在表示不同批次的数据遍历时,打乱顺序(这个需要在训练神经网络时再来讲)。
#num_workers=2表示使用两个子进程来加载数据

import torch
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
shuffle=False, num_workers=2)

四. 对数据预处理

接下来需要对数据进行预处理,预处理会帮助我们加快神经网络的训练。
预处理用到了transforms函数:

transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

compose函数会将多个transforms包在一起。
1.ToTensor是指把PIL.Image(RGB) 或者numpy.ndarray(H x W x C) 从0到255的值映射到0到1的范围内,并转化成Tensor格式。
2.Normalize(mean,std)是通过下面公式实现数据归一化
transforms.Normalize( mean = (0.5,0.5,0.5), std = (0.5,0.5,0.5) )并不是指将张量的均值和标准差设为0.5,而是做这么一个运算:输入的每个channel做 ( [0, 1] - mean(0.5) )/ std(0.5)= [-1, 1] 的运算,所以这一句的实际结果是将[0,1]的张量归一化到[-1, 1]上。

channel=(channel-mean)/std 

经过上面两个操作,我们的数据中的每个值就变成了[-1,1]的数了。

完整代码

1.先引入库

import torch
import torchvision
import torchvision.transforms as transforms

2.数据预处理,归化为tensor数据,并归一化为[-1, 1]

# torchvision输出的是PILImage,值的范围是[0, 1].
# 我们将其转化为tensor数据,并归一化为[-1, 1]。
transform=transforms.Compose([transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                             ])

3.加载训练集数据

#训练集,将相对目录./data下的cifar-10-batches-py文件夹中的全部数据
#(50000张图片作为训练数据)加载到内存中,若download为True时,会自动从网上下载数据并解压
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
download=False, transform=transform)

4.划分数据,加速处理

#将训练集的50000张图片划分成12500份,每份4张图,用于mini-batch输入。
#shffule=True在表示不同批次的数据遍历时,打乱顺序。num_workers=2表示使用两个子进程来加载数据
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
shuffle=False, num_workers=2)

5.展示训练图像,直观感受
对代码中一部分的理解,参考plt.imshow(np.transpose(npimg, (1, 2, 0)))在pytorch中,读入图片并进行显示的方式有两种,见博文。

#添加CIFAR10的标签
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize 非标准的反归一化
    #因为归一话的时候是先减去平均值0.5 ,然后再除以标准偏差0.5那么反归一化就是先乘以0.5,再加0.5。
    npimg = img.numpy()#Image互转化为numpy,将torch.FloatTensor 转换为numpy
    #因为在plt.imshow现在实际输入的是(imagesize,imagesize,channels),
    #而定义中参数img的格式为(channels,imagesize,imagesize),这两者的格式不一致
    #我们需要调用np.transpose函数,即np.transpose(npimg,(1,2,0))
    #将npimg的数据格式由(channels,imagesize,imagesize)转化为(imagesize,imagesize,channels),进行格式的转换后方可进行显示。
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s'%classes[labels[j]] for j in range(4)))

输出图片:
在这里插入图片描述
输出标签:
在这里插入图片描述

6.运行时出现的问题
在这里插入图片描述
原因:多进程需要在main函数中运行
解决办法:参考博客
1.加main函数,在main中调用
2.num_workers改为0,单进程加载

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值