网络构建
数据加载
* 引入函数库
import torch
import torchvision
import torchvision.transforms as transforms
*将读入的数据进行转化:
transform = transforms.Compose(
[transforms.ToTensor(), ***range [0, 255] -> [0.0,1.0]
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) *数据分布归一化到[-1,1]
*利用torch自带的CIFAR10数据集加载训练集
trainset = torchvision.datasets.CIFAR10(root=’./data’, train=True,
download=True, transform=transform)
*生成batch,其中:
*参数:
dataset:Dataset类型,从其中加载数据
batch_size:int,可选。每个batch加载多少样本
shuffle:bool,可选。为True时表示每个epoch都对数据进行洗牌
sampler:Sampler,可选。从数据集中采样样本的方法。
num_workers:int,可选。加载数据时使用多少子进程。默认值为0,表示在主进程中加载数据。
collate_fn:callable,可选。
pin_memory:bool,可选
drop_last:bool,可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃。
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
*加载测试集
testset = torchvision.datasets.CIFAR10(root=’./data’, train=False,
download=True, transform=transform)
*测试集batch
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
*定义类别
classes = (‘plane’, ‘car’, ‘bird’, ‘cat’,
‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’)
*显示一些训练集中的图片与标签
import matplotlib.pyplot as plt