利用Pytorch对CIFAR10数据集分类(一)

  1. 使用torchvision加载并预处理CIFAR-10数据集
  2. 定义网络(卷及神经)
  3. 定义损失函数和优化器
  4. 训练网络并更新网络参数
  5. 测试网络

Pytorch库中有许多与深度学习有关的代码块,在进行学习时可以直接调用,十分有利于新手学习和使用。本次深度学习我就是采用pytorch库进行变成实现对CIFAR10数据集的分类处理

使用torchvision加载并预处理CIFAR10数据集

直接上python代码(编译器为jupyter)

import torch
import torchvision
import torchvision.transforms as transforms
transform1 = transforms.Compose(
    [
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root=r'C:\Users\dell\Desktop\Python', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                          shuffle=True, num_workers=1)

testset = torchvision.datasets.CIFAR10(root='test_batch', train=False,
                                       download=True, transform=transform1)
testloader = torch.utils.data.DataLoader(testset, batch_size=50,
                                         shuffle=False, num_workers=1)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
transforms.ToTensor();

ToTensor()将shape为(H, W, C)的nump.ndarray或img转为shape为(C, H, W)的tensor,其将每一个数值归一化到[0,1],其归一化方法比较简单,直接除以255即可。

transforms.Normalize

class torchvision.transforms.Normalize(mean, std)
给定均值:(R,G,B) 方差:(R,G,B),将会把Tensor正则化。即:Normalized_image=(image-mean)/std。

1.神经网络中归一化的原因:

归一化是为了加快训练网络的收敛性,可以不进行归一化处理;

归一化的具体作用是归纳统一样本的统计分布性。归一化在0-1之间是统计的概率分布,归一化在-1–+1之间是统计的坐标分布。归一化有同一、统一和合一的意思。无论是为了建模还是为了计算,首先基本度量单位要同一,神经网络是以样本在事件中的统计分别几率来进行训练(概率计算)和预测的,归一化是同一在0-1之间的统计概率分布;当所有样本的输入信号都为正值时,与第一隐含层神经元相连的权值只能同时增加或减小,从而导致学习速度很慢。为了避免出现这种情况,加快网络学习速度,可以对输入信号进行归一化,使得所有样本的输入信号其均值接近于0或与其均方差相比很小。

torchvision.datasets.CIFAR10(root=r'C:\Users\dell\Desktop\Python', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
                                          shuffle=True, num_workers=1)

pytorch中自带加载CIFAR10数据集的模块使用torchvision.datasets.CIFAR10加载CIFAR10数据集root是保存数据集的路径,train=true为训练集download为true要进行联网下载数据集如果路径文件夹中含有已经下载好的数据集则不用下载,transform为归一化处理在前面已经操作过了。
trainloadar作为一个容器在程序运行时装载数据集中的数据,trainset作为训练集,batch_size = 10为minibatch的数据量为100,shuffle = True 表明提取数据时,随机打乱顺序,因为我们都是基于随机梯度下降的方式进行训练优化,但测试的时候因为不需要更新参数,所以就无须打乱顺序了。
num_workers = 2 指定了工作线程的数量。
接下来的testset和testlodar为测试集。

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

classes规定数据集中的种类名称

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值