Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

1、写在前面的话

这一篇博文的内容主要来自于pytorch的官方tutorial,然后根据自己的理解把cifar10这个示例讲一遍,权当自己做笔记。因为这个cifar10是官方example,所以适合我们拿来先练手,至少能保证代码的正确性。
之所以第一篇pytorch的博文(其实之前还写了篇如何安装pytorch)就用cifar10做例子,是我个人觉得先从宏观上了解一个例子的样貌是什么样的,然后我们再来针对性的学习相关的知识点,这样可能效率快一点。 所以我不光是在讲cifar10这个例子,而是在剖析这个例子,说明这些知识点属于哪个模块,该去哪儿找后续也会写相关博客进行一些细节性的讲解。
 

2.大致流程

一般来说,使用深度学习框架我们会经过下面几个流程:

模型定义(包括损失函数的选择) --->数据处理和加载 ---> 训练(可能包含训练过程可视化) ---> 测试

所以我们在自己写代码的时候也基本上就按照这四个大模块四步走就ok了

官方给的这个例子呢,是先进行的第二步数据处理和加载,然后定义网络,这其实没什么关系。

所以本篇博文讲解的是  数据处理和加载 这一步的内容,当然会接着在后续博文写其他步骤。

此例的步骤:
A、Load and normalizing the CIFAR10 training and test datasets using torchvision
B、Define a Convolution Neural Network
C、Define a loss function
D、Train the network on the training data
E、Test the network on the test data

下面我就直接上程序,并且添加我自己的一些注解,觉得有问题的欢迎提出,希望和大家多交流。

 

 

3、代码分析

首先使用 torchvision加载和归一化我们的训练数据和测试数据。
a、torchvision这个东西,实现了常用的一些深度学习的相关的图像数据的加载功能,比如cifar10、Imagenet、Mnist等等的,保存在torchvision.datasets模块中。
b、同时,也封装了一些处理数据的方法。保存在torchvision.transforms模块中
c、还封装了一些模型和工具封装在相应模型中。可以从下图一窥大貌:

 

细节见torchvision的官方文档链接:http://pytorch.org/docs/0.3.0/torchvision/index.html 

 

#  首先当然肯定要导入torch和torchvision,至于第三个是用于进行数据预处理的模块
import torch
import torchvision
import torchvision.transforms as transforms

#  **由于torchvision的datasets的输出是[0,1]的PILImage,所以我们先先归一化为[-1,1]的Tensor**
    #  首先定义了一个变换transform,利用的是上面提到的transforms模块中的Compose( )
    #  把多个变换组合在一起,可以看到这里面组合了ToTensor和Normalize这两个变换
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 

    # 定义了我们的训练集,名字就叫trainset,至于后面这一堆,其实就是一个类:
    # torchvision.datasets.CIFAR10( )也是封装好了的,就在我前面提到的torchvision.datasets
    # 模块中,不必深究,如果想深究就看我这段代码后面贴的图1,其实就是在下载数据
    #(不翻墙可能会慢一点吧)然后进行变换,可以看到transform就是我们上面定义的transform
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
    # trainloader其实是一个比较重要的东西,我们后面就是通过trainloader把数据传入网
    # 络,当然这里的trainloader其实是个变量名,可以随便取,重点是他是由后面的
    # torch.utils.data.DataLoader()定义的,这个东西来源于torch.utils.data模块,
    #  网页链接http://pytorch.org/docs/0.3.0/data.html,这个类可见我后面图2
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)
    # 对于测试集的操作和训练集一样,我就不赘述了
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)
    # 类别信息也是需要我们给定的
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
---------- update 2019/11/28 --------
经常有读者问上面的代码中的 Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 这个是什么意思
 
前面的(0.5,0.5,0.5) 是 R G B 三个通道上的均值, 后面(0.5, 0.5, 0.5) 是三个通道的标准差
注意通道顺序是 R G B ,用过opencv的同学应该知道openCV读出来的图像是 BRG顺序。
这两个tuple数据是用来对RGB 图像做归一化的,如其名称 Normalize 所示
这里都取0.5只是一个近似的操作,实际上其均值和方差并不是这么多,但是就这个示例而言 影响可不计。
 精确值是通过分别计算R,G,B三个通道的数据算出来的,
比如你有2张图片,都是100*100大小的,那么两图片的像素点共有2*100*100 = 20 000 个;
 那么这两张图片的mean求法为:
mean_R: 这20000个像素点的R值加起来,除以像素点的总数,这里是20000;
mean_G 和mean_B 两个通道 的计算方法 一样的。
 
标准差求法:
首先标准差就是开了方的方差,所以其实就是求方差,方差公式就是我们数学上的那个求方差的公式:
也是3个通道分开算,
比如算R通道的, 这里X就为20000个像素点 各自的R值,再减去R均值,上面已经算好了;
然后平方;
然后20000个像素点相加,然后求平均除以20000,
得到R的方差,再开方得标准差。
注意!!!
如果你用的是自己创建的数据集,从头训练,那最好还是要自己统计自己数据集的这两个量
 
如果①你加载的的是 pytorch上的预训练模型,自己只是微调模型;
②或者你用了 常见的数据集比如VOC或者COCO之类的,但是用的是自己的网络结构,即pytorch上没有可选的预训练模型
那么可以使用一个 pytorch上给的通用的统计值
mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
参考官网链接:https://pytorch.org/docs/stable/torchvision/models.html   用ctrl+f 搜索mean 第一个位置就有说明

 

 

4、图片

 

图1

root表示存放dataset的位置,本例就是' ./data'

train,如果为True,就创建的是trainning set,可以看到我们的trainset调用它时用的是True

           而testset调用它时,参数里填的是False

transform,这个transform是形参名,由于我们定义的变换也叫transform,所以就有transform = transform,

看起来可能有点怪,其实我们的之前的变换可以随便命名

download,如果为True,就从网上下载,如果已经有下载好的数据就不会重复下载了

------------------------------------------------------------------------------------------------------------------------------------------

图2

数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。

 dataset:就是数据的来源,比如训练集就添入我们定义的trainset

batch_size:每批次进入多少数据,本例中填的是4

shuffle:如果为真,就打乱数据的顺序,本例为True

num_workers:用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)

本例中为2。这个值是什么意思呢,就是数据读入的速度到底有多快,你选的用来加载数据的

子进程越多,那么显然数据读的就越快,这样的话消耗CPU的资源也就越多,所以这个值在自己

跑实验的时候,可以自己试一试,既不要让花在加载数据上的时间太多,也不要占用太多电脑资源

 

 

所以这第一步----数据加载和处理,要注意的就是这些内容,如果程序运行完毕,会显示:

 

这里我提个小建议,就是下载数据的那个root参数,官网代码给的是'./data',这个其实可以改成自己的位置

而且,建议改成 绝对路径 要好一点。然后由于代码可以直接从官网复制粘帖,所以这部分程序的运行的快慢,

基本就取决于下载数据的网速了,建议翻墙,可能翻墙也不见得很快,不如晚上睡觉前开始下,说不定第二天

醒来就下好了呢.......

  • 101
    点赞
  • 413
    收藏
    觉得还不错? 一键收藏
  • 34
    评论
评论 34
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值