pyTorch——训练第一个分类器要点解读

网络构建

数据加载

* 引入函数库
import torch
import torchvision
import torchvision.transforms as transforms

*将读入的数据进行转化:
transform = transforms.Compose(
[transforms.ToTensor(), ***range [0, 255] -> [0.0,1.0]
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) *数据分布归一化到[-1,1]

*利用torch自带的CIFAR10数据集加载训练集
trainset = torchvision.datasets.CIFAR10(root=’./data’, train=True,
download=True, transform=transform)

*生成batch,其中:
*参数:
dataset:Dataset类型,从其中加载数据
batch_size:int,可选。每个batch加载多少样本
shuffle:bool,可选。为True时表示每个epoch都对数据进行洗牌
sampler:Sampler,可选。从数据集中采样样本的方法。
num_workers:int,可选。加载数据时使用多少子进程。默认值为0,表示在主进程中加载数据。
collate_fn:callable,可选。
pin_memory:bool,可选
drop_last:bool,可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃。

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)

*加载测试集
testset = torchvision.datasets.CIFAR10(root=’./data’, train=False,
download=True, transform=transform)

*测试集batch
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)

*定义类别
classes = (‘plane’, ‘car’, ‘bird’, ‘cat’,
‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’)

*显示一些训练集中的图片与标签
import matplotlib.pyplot as plt

  • 2
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值