PyTorch笔记——观察DataLoader&用torch构建LeNet处理CIFAR-10完整代码

参考资料:《深度学习框架PyTorch:入门与实践》

一、简单numpy例子观察DataLoader

创建数据,显示它的shape:

import numpy as np

data = np.array([[1,1,1,1],
                [2,2,2,2],
                [3,3,3,3],
                [4,4,4,4],
               [5,5,5,5],
                [6,6,6,6]])
print(data)
print(data.shape)

输出:

[[1 1 1 1]
 [2 2 2 2]
 [3 3 3 3]
 [4 4 4 4]
 [5 5 5 5]
 [6 6 6 6]]
(6, 4)

然后使用DataLoader处理这份数据,注意这份数据中有“6个”数据:

import torch as t
from tqdm import tqdm
# 每个batch有2个数据,shuffle=False是禁止打乱
dataloader = t.utils.data.DataLoader(data,
                                     batch_size=2,
                                     shuffle=False,
                                     num_workers=1)
for i, data_ in enumerate(tqdm(dataloader)):
    print(i)
    print(data_)

逐个batch输出数据,就会显示如下:

100%|██████████| 3/3 [00:00<00:00, 12.95it/s]
0
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2]])
1
tensor([[3, 3, 3, 3],
        [4, 4, 4, 4]])
2
tensor([[5, 5, 5, 5],
        [6, 6, 6, 6]])

二、两种方式加载CIFAR-10数据

方式1,用torchvision自动下载CIFAR-10

import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage()

# 第一次运行torchvision会自动下载CIFAR-10数据集,大约100MB,
# 如果已经有,可以通过root参数指定

# 定义对数据的预处理
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5))])

# 训练集
# 注意download=True即可
trainset = tv.datasets.CIFAR10(
                    root='填写想存放数据的文件夹',
                    train=True,
                    download=True,
                    transform=transform)

trainloader = t.utils.data.DataLoader(
                trainset,
                batch_size=4,
                shuffle=True,
                num_workers=2)

# 测试集
testset = tv.datasets.CIFAR10(
                    root='填写想存放数据的文件夹',
                    train=False,
                    download=True,
                    transform=transform)

testloader = t.utils.data.DataLoader(
                testset,
                batch_size=4,
                shuffle=False,
                num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

如果上述代码报错:
URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1056)>

那么只需要在开头加入:

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

但是这种方式下载速度非常慢,所以我采用方式2

方式2,自行下载CIFAR-10

上述代码是在https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz中下载数据的,我将链接复制到了Neat download manager中,速度非常快就下载好了。下载之后需要解压,我的文件夹如下:
在这里插入图片描述
然后还是和方式1一样的代码,只不过注意root参数应该是上图文件夹的位置(注意不是精确到图中cifar-10-batches-py,而是它所在的这个文件夹的位置)

三、观察CIFAR-10数据集的size

之前代码有一句:

trainloader = t.utils.data.DataLoader(
                trainset,
                batch_size=4,
                shuffle=True,
                num_workers=2)

这说明batch_size为4,并且允许打乱。我们现在观察一下trainloader是什么(testloader同理):

jupyter输入:

trainloader

输出:

<torch.utils.data.dataloader.DataLoader at 0x12b3c00b8>

这显然不足以让我了解trainloader,那么我要怎么观察trainloader呢?

答案很简单,可以用enumerate,这样可以把trainloader中的数据赋值给data:

for i, data in enumerate(trainloader):
#     print(np.array(data).shape)	# 输出结果为(2, )
    print(data[0].shape)
    print(data[1].shape)
    if i == 2: break

输出结果:

torch.Size([4, 3, 32, 32])	# batch, channel, height, width
torch.Size([4])
torch.Size([4, 3, 32, 32])
torch.Size([4])
torch.Size([4, 3, 32, 32])
torch.Size([4])

输出了三个batch的数据的shape。容易知道,data是[图片, label]的组织形式。上述代码中的data[0]是图片数据,data[1]是对应label。这样一来,我们就通过对trainloader的观察明白了data的形态。图片是3通道,高×宽是32×32;label就是一个数字。

四、LeNet处理CIFAR-10完整代码


import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
from torch import optim
import ssl
show = ToPILImage()

########################
###### 加载数据##########
########################

# 如果选择下载的方式加载数据,
# 报错“URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1056)>”
# 加入下面这句解决了错误
ssl._create_default_https_context = ssl._create_unverified_context

# 第一次运行torchvision会自动下载CIFAR-10数据集,大约100MB,
# 如果已经有,可以通过root参数指定

# 定义对数据的预处理
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5))])

# 训练集
# 如果选择下载的方式加载数据,download设置为True
trainset = tv.datasets.CIFAR10(
                    root='上文讲的cifar-10-batches-py文件夹所在的文件夹',
                    train=True,
                    download=False,
                    transform=transform)

trainloader = t.utils.data.DataLoader(
                trainset,
                batch_size=4,
                shuffle=True,
                num_workers=2)

# 测试集
# 如果选择下载的方式加载数据,download设置为True
testset = tv.datasets.CIFAR10(
                    root='上文讲的cifar-10-batches-py文件夹所在的文件夹',
                    train=False,
                    download=False,
                    transform=transform)

testloader = t.utils.data.DataLoader(
                testset,
                batch_size=4,
                shuffle=False,
                num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

##################
##### 定义网络 ####
##################
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
net = Net()

#########################
# 定义损失函数和优化器######
#########################
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

#########################
## 训练 #################
#########################

# tqdm显示进度条,可以和enumerate结合使用
from tqdm import tqdm
for epoch in range(2):
    
    running_loss = 0.0
    for i, data in enumerate(tqdm(trainloader)):
    	# 前文已经讲过data[0]是图片,data[1]是label
        inputs, labels = data
        # 梯度清零
        optimizer.zero_grad()
        
        # 前向传播
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # 反向传播
        loss.backward()
        optimizer.step()
        
        running_loss += loss.data
        #  每2000个batch打印一次训练状态
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, running_loss / 2000))
            running_loss = 0.0

print("finish training")

因为主要目的是演示代码运行,所以只训练两个epoch,运行结果如下:

16%|█▋        | 2044/12500 [00:04<00:23, 437.93it/s]
[1,  2000] loss: 2.185
 33%|███▎      | 4072/12500 [00:10<00:23, 352.94it/s]
[1,  4000] loss: 1.808
 48%|████▊     | 6062/12500 [00:15<00:17, 367.84it/s]
[1,  6000] loss: 1.637
 65%|██████▍   | 8075/12500 [00:20<00:10, 428.74it/s]
[1,  8000] loss: 1.553
 80%|████████  | 10053/12500 [00:25<00:06, 381.21it/s]
[1, 10000] loss: 1.483
 96%|█████████▋| 12059/12500 [00:29<00:01, 412.74it/s]
[1, 12000] loss: 1.438
100%|██████████| 12500/12500 [00:31<00:00, 401.07it/s]
 16%|█▋        | 2035/12500 [00:04<00:28, 365.83it/s]
[2,  2000] loss: 1.369
 33%|███▎      | 4091/12500 [00:09<00:18, 451.14it/s]
[2,  4000] loss: 1.356
 48%|████▊     | 6062/12500 [00:14<00:15, 409.82it/s]
[2,  6000] loss: 1.330
 64%|██████▍   | 8057/12500 [00:19<00:10, 422.07it/s]
[2,  8000] loss: 1.279
 81%|████████  | 10072/12500 [00:24<00:05, 412.58it/s]
[2, 10000] loss: 1.273
 96%|█████████▌| 12023/12500 [00:29<00:01, 365.16it/s]
[2, 12000] loss: 1.243
100%|██████████| 12500/12500 [00:30<00:00, 406.19it/s]
finish training

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值