参考资料:《深度学习框架PyTorch:入门与实践》
目录
一、简单numpy例子观察DataLoader
创建数据,显示它的shape:
import numpy as np
data = np.array([[1,1,1,1],
[2,2,2,2],
[3,3,3,3],
[4,4,4,4],
[5,5,5,5],
[6,6,6,6]])
print(data)
print(data.shape)
输出:
[[1 1 1 1]
[2 2 2 2]
[3 3 3 3]
[4 4 4 4]
[5 5 5 5]
[6 6 6 6]]
(6, 4)
然后使用DataLoader处理这份数据,注意这份数据中有“6个”数据:
import torch as t
from tqdm import tqdm
# 每个batch有2个数据,shuffle=False是禁止打乱
dataloader = t.utils.data.DataLoader(data,
batch_size=2,
shuffle=False,
num_workers=1)
for i, data_ in enumerate(tqdm(dataloader)):
print(i)
print(data_)
逐个batch输出数据,就会显示如下:
100%|██████████| 3/3 [00:00<00:00, 12.95it/s]
0
tensor([[1, 1, 1, 1],
[2, 2, 2, 2]])
1
tensor([[3, 3, 3, 3],
[4, 4, 4, 4]])
2
tensor([[5, 5, 5, 5],
[6, 6, 6, 6]])
二、两种方式加载CIFAR-10数据
方式1,用torchvision自动下载CIFAR-10
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage()
# 第一次运行torchvision会自动下载CIFAR-10数据集,大约100MB,
# 如果已经有,可以通过root参数指定
# 定义对数据的预处理
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))])
# 训练集
# 注意download=True即可
trainset = tv.datasets.CIFAR10(
root='填写想存放数据的文件夹',
train=True,
download=True,
transform=transform)
trainloader = t.utils.data.DataLoader(
trainset,
batch_size=4,
shuffle=True,
num_workers=2)
# 测试集
testset = tv.datasets.CIFAR10(
root='填写想存放数据的文件夹',
train=False,
download=True,
transform=transform)
testloader = t.utils.data.DataLoader(
testset,
batch_size=4,
shuffle=False,
num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
如果上述代码报错:
URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1056)>
那么只需要在开头加入:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
但是这种方式下载速度非常慢,所以我采用方式2
方式2,自行下载CIFAR-10
上述代码是在https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
中下载数据的,我将链接复制到了Neat download manager中,速度非常快就下载好了。下载之后需要解压,我的文件夹如下:
然后还是和方式1一样的代码,只不过注意root参数应该是上图文件夹的位置(注意不是精确到图中cifar-10-batches-py
,而是它所在的这个文件夹的位置)
三、观察CIFAR-10数据集的size
之前代码有一句:
trainloader = t.utils.data.DataLoader(
trainset,
batch_size=4,
shuffle=True,
num_workers=2)
这说明batch_size为4,并且允许打乱。我们现在观察一下trainloader是什么(testloader同理):
jupyter输入:
trainloader
输出:
<torch.utils.data.dataloader.DataLoader at 0x12b3c00b8>
这显然不足以让我了解trainloader,那么我要怎么观察trainloader呢?
答案很简单,可以用enumerate
,这样可以把trainloader中的数据赋值给data:
for i, data in enumerate(trainloader):
# print(np.array(data).shape) # 输出结果为(2, )
print(data[0].shape)
print(data[1].shape)
if i == 2: break
输出结果:
torch.Size([4, 3, 32, 32]) # batch, channel, height, width
torch.Size([4])
torch.Size([4, 3, 32, 32])
torch.Size([4])
torch.Size([4, 3, 32, 32])
torch.Size([4])
输出了三个batch的数据的shape。容易知道,data是[图片, label]的组织形式。上述代码中的data[0]是图片数据,data[1]是对应label。这样一来,我们就通过对trainloader的观察明白了data的形态。图片是3通道,高×宽是32×32;label就是一个数字。
四、LeNet处理CIFAR-10完整代码
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
from torch import optim
import ssl
show = ToPILImage()
########################
###### 加载数据##########
########################
# 如果选择下载的方式加载数据,
# 报错“URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1056)>”
# 加入下面这句解决了错误
ssl._create_default_https_context = ssl._create_unverified_context
# 第一次运行torchvision会自动下载CIFAR-10数据集,大约100MB,
# 如果已经有,可以通过root参数指定
# 定义对数据的预处理
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))])
# 训练集
# 如果选择下载的方式加载数据,download设置为True
trainset = tv.datasets.CIFAR10(
root='上文讲的cifar-10-batches-py文件夹所在的文件夹',
train=True,
download=False,
transform=transform)
trainloader = t.utils.data.DataLoader(
trainset,
batch_size=4,
shuffle=True,
num_workers=2)
# 测试集
# 如果选择下载的方式加载数据,download设置为True
testset = tv.datasets.CIFAR10(
root='上文讲的cifar-10-batches-py文件夹所在的文件夹',
train=False,
download=False,
transform=transform)
testloader = t.utils.data.DataLoader(
testset,
batch_size=4,
shuffle=False,
num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
##################
##### 定义网络 ####
##################
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
#########################
# 定义损失函数和优化器######
#########################
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
#########################
## 训练 #################
#########################
# tqdm显示进度条,可以和enumerate结合使用
from tqdm import tqdm
for epoch in range(2):
running_loss = 0.0
for i, data in enumerate(tqdm(trainloader)):
# 前文已经讲过data[0]是图片,data[1]是label
inputs, labels = data
# 梯度清零
optimizer.zero_grad()
# 前向传播
outputs = net(inputs)
loss = criterion(outputs, labels)
# 反向传播
loss.backward()
optimizer.step()
running_loss += loss.data
# 每2000个batch打印一次训练状态
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss / 2000))
running_loss = 0.0
print("finish training")
因为主要目的是演示代码运行,所以只训练两个epoch,运行结果如下:
16%|█▋ | 2044/12500 [00:04<00:23, 437.93it/s]
[1, 2000] loss: 2.185
33%|███▎ | 4072/12500 [00:10<00:23, 352.94it/s]
[1, 4000] loss: 1.808
48%|████▊ | 6062/12500 [00:15<00:17, 367.84it/s]
[1, 6000] loss: 1.637
65%|██████▍ | 8075/12500 [00:20<00:10, 428.74it/s]
[1, 8000] loss: 1.553
80%|████████ | 10053/12500 [00:25<00:06, 381.21it/s]
[1, 10000] loss: 1.483
96%|█████████▋| 12059/12500 [00:29<00:01, 412.74it/s]
[1, 12000] loss: 1.438
100%|██████████| 12500/12500 [00:31<00:00, 401.07it/s]
16%|█▋ | 2035/12500 [00:04<00:28, 365.83it/s]
[2, 2000] loss: 1.369
33%|███▎ | 4091/12500 [00:09<00:18, 451.14it/s]
[2, 4000] loss: 1.356
48%|████▊ | 6062/12500 [00:14<00:15, 409.82it/s]
[2, 6000] loss: 1.330
64%|██████▍ | 8057/12500 [00:19<00:10, 422.07it/s]
[2, 8000] loss: 1.279
81%|████████ | 10072/12500 [00:24<00:05, 412.58it/s]
[2, 10000] loss: 1.273
96%|█████████▌| 12023/12500 [00:29<00:01, 365.16it/s]
[2, 12000] loss: 1.243
100%|██████████| 12500/12500 [00:30<00:00, 406.19it/s]
finish training