在 CIFAR-10 数据集上实现一个“文艺复兴时期”的 CNN 网络:VGG
🙋♂️ 张同学,zhangruiyuan@zju.edu.cn 有问题请联系我~
这里是目录呀~
〇、背景介绍
今夜注定不能好好睡觉了。
我想自己实现一个 CNN 网络,加深一下自己的理解。前段时间学习深度学习的时候,也听过其他 UP 主讲 VGG,这个网络结构比较简单,所以就想用 PyTorch 实现一下。
本文使用的数据集是 CIFAR-10,包含 10 类物体,每类有 6000 张 32 × 32 大小的彩色图片,共 60000 张,其中 50000 张用于训练,10000 张用于测试。
本文结构为:本章下方的基础代码负责数据集 CIFAR-10 的读取、可视化、划分,以及模型的训练与验证;第一、二、三章中实现的网络模型需要替换基础代码中的演示网络,配合本章基础代码一起运行。基础代码非本文重点,在此不过多赘述。
import torch
import torchvision
import torchvision.transforms as transforms

# 1. Download the CIFAR-10 data to a local directory.
# Both the training and the test split live under the same root directory.
DATA_ROOT = './datasets/5f6b1577787e9d5bb70800a4-momodel'

# ToTensor converts images to float tensors; Normalize with mean=std=0.5
# then maps each channel value v to (v - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=32, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=32, shuffle=False, num_workers=2)

# Human-readable class names, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
# 2. Visualize one batch of training images
import matplotlib.pyplot as plt
import numpy as np


def imshow(img):
    """Display a (C, H, W) image tensor with matplotlib.

    Args:
        img: a 3-channel (C, H, W) tensor such as the output of
            torchvision.utils.make_grid. Assumes it was normalized with
            mean=std=0.5 (as in the data-loading code); this is undone first.
    """
    img = img / 2 + 0.5  # unnormalize: inverse of Normalize(mean=0.5, std=0.5)
    npimg = img.numpy()
    # matplotlib expects (H, W, C); PyTorch tensors are (C, H, W).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# Grab one random batch from the shuffled training loader.
dataiter = iter(trainloader)
# DataLoader iterators no longer provide a .next() method in recent PyTorch
# releases; the built-in next() works on every version.
images, labels = next(dataiter)

# Show the whole batch as a single image grid.
imshow(torchvision.utils.make_grid(images))

# Print the label of every image in the batch (not just the first four),
# 8 per line to mirror make_grid's default of 8 images per row.
per_row = 8
for start in range(0, len(labels), per_row):
    print(' '.join('%5s' % classes[label]
                   for label in labels[start:start + per_row]))
# 3. Demo network (replace this with the code from section 1, 2 or 3 of the post)
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Tiny single-conv demo classifier for 3*32*32 CIFAR-10 images.

    forward() returns raw class scores (logits) of shape (N, 10). No softmax
    is applied: nn.CrossEntropyLoss (used by the training code) expects
    unnormalized logits and applies log-softmax itself. Taking the argmax
    for accuracy gives the same prediction either way.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Input shape: 3*32*32
        self.c1 = nn.Conv2d(3, 6, 5, padding=1)    # 6*30*30  (32 + 2*1 - 5 + 1 = 30)
        self.b1 = nn.BatchNorm2d(6)                # 6*30*30
        self.a1 = nn.ReLU()                        # 6*30*30
        self.p1 = nn.MaxPool2d(2, 2, padding=1)    # 6*16*16  ((30 + 2*1 - 2) // 2 + 1 = 16)
        self.d1 = nn.Dropout(p=0.2)                # 6*16*16
        self.flatten = nn.Flatten()                # 6*16*16 = 1536
        self.f1 = nn.Linear(1536, 128)             # 128
        self.a2 = nn.ReLU()                        # 128
        self.f2 = nn.Linear(128, 10)               # 10 logits

    def forward(self, x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)
        x = self.flatten(x)
        x = self.f1(x)
        x = self.a2(x)
        x = self.f2(x)
        return x


net = Net()
print(net)
# 4. Train the model
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.8)

# Report the mean loss every `print_every` mini-batches; increase this when
# training on a GPU, where batches are processed much faster.
print_every = 500

# Make sure BatchNorm uses batch statistics and Dropout is active while
# training (evaluation code may have switched the model to eval mode).
net.train()

print('begin')
for epoch in range(5):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass, backward pass, parameter update.
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % print_every == print_every - 1:
            # Mean loss over the `print_every` batches since the last report.
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / print_every))
            running_loss = 0.0

print('Finished Training')
# 5. Compute accuracy on the test set
# Switch BatchNorm to its running statistics and disable Dropout; without
# this, evaluation would normalize with per-batch statistics and randomly
# drop activations.
net.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # Predicted class = index of the highest score in each row.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
一、只使用 torch.nn.XXX 来构建网络结构
import torch.nn as nn
import torch.nn.functional as F


# VGG-16 (with BatchNorm) implementation.
# Every layer, activations and pooling included, is registered as a module
# attribute, so print(net) shows the model structure in full detail.
class Net(nn.Module):
    """VGG-16-style classifier for 3*32*32 inputs, built only from nn.* modules.

    forward() returns raw logits of shape (N, 10). No softmax is applied:
    nn.CrossEntropyLoss expects unnormalized logits and applies log-softmax
    itself. The argmax used for accuracy is the same either way.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Input shape: 3*32*32. Every 3x3 conv has padding=1 and keeps the
        # spatial size; every MaxPool2d(2) halves it.
        # Stage 1 -> 64*16*16
        self.c1 = nn.Conv2d(3, 64, 3, padding=1)
        self.b1 = nn.BatchNorm2d(64)
        self.a1 = nn.ReLU()
        self.c2 = nn.Conv2d(64, 64, 3, padding=1)
        self.b2 = nn.BatchNorm2d(64)
        self.a2 = nn.ReLU()
        self.p1 = nn.MaxPool2d(2)
        # Stage 2 -> 128*8*8
        self.c3 = nn.Conv2d(64, 128, 3, padding=1)
        self.b3 = nn.BatchNorm2d(128)
        self.a3 = nn.ReLU()
        self.c4 = nn.Conv2d(128, 128, 3, padding=1)
        self.b4 = nn.BatchNorm2d(128)
        self.a4 = nn.ReLU()
        self.p2 = nn.MaxPool2d(2)
        # Stage 3 -> 256*4*4
        self.c5 = nn.Conv2d(128, 256, 3, padding=1)
        self.b5 = nn.BatchNorm2d(256)
        self.a5 = nn.ReLU()
        self.c6 = nn.Conv2d(256, 256, 3, padding=1)
        self.b6 = nn.BatchNorm2d(256)
        self.a6 = nn.ReLU()
        self.c7 = nn.Conv2d(256, 256, 3, padding=1)
        self.b7 = nn.BatchNorm2d(256)
        self.a7 = nn.ReLU()
        self.p3 = nn.MaxPool2d(2)
        # Stage 4 -> 512*2*2
        self.c8 = nn.Conv2d(256, 512, 3, padding=1)
        self.b8 = nn.BatchNorm2d(512)
        self.a8 = nn.ReLU()
        self.c9 = nn.Conv2d(512, 512, 3, padding=1)
        self.b9 = nn.BatchNorm2d(512)
        self.a9 = nn.ReLU()
        self.c10 = nn.Conv2d(512, 512, 3, padding=1)
        self.b10 = nn.BatchNorm2d(512)
        self.a10 = nn.ReLU()
        self.p4 = nn.MaxPool2d(2)
        # Stage 5 -> 512*1*1
        self.c11 = nn.Conv2d(512, 512, 3, padding=1)
        self.b11 = nn.BatchNorm2d(512)
        self.a11 = nn.ReLU()
        self.c12 = nn.Conv2d(512, 512, 3, padding=1)
        self.b12 = nn.BatchNorm2d(512)
        self.a12 = nn.ReLU()
        self.c13 = nn.Conv2d(512, 512, 3, padding=1)
        self.b13 = nn.BatchNorm2d(512)
        self.a13 = nn.ReLU()
        self.p5 = nn.MaxPool2d(2)
        # Classifier: 512*1*1 -> 512 -> 512 -> 512 -> 10 logits
        self.flatten = nn.Flatten()
        self.f1 = nn.Linear(512, 512)
        self.a14 = nn.ReLU()
        self.f2 = nn.Linear(512, 512)
        self.a15 = nn.ReLU()
        self.f3 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.c2(x)
        x = self.b2(x)
        x = self.a2(x)
        x = self.p1(x)
        x = self.c3(x)
        x = self.b3(x)
        x = self.a3(x)
        x = self.c4(x)
        x = self.b4(x)
        x = self.a4(x)
        x = self.p2(x)
        x = self.c5(x)
        x = self.b5(x)
        x = self.a5(x)
        x = self.c6(x)
        x = self.b6(x)
        x = self.a6(x)
        x = self.c7(x)
        x = self.b7(x)
        x = self.a7(x)
        x = self.p3(x)
        x = self.c8(x)
        x = self.b8(x)
        x = self.a8(x)
        x = self.c9(x)
        x = self.b9(x)
        x = self.a9(x)
        x = self.c10(x)
        x = self.b10(x)
        x = self.a10(x)
        x = self.p4(x)
        x = self.c11(x)
        x = self.b11(x)
        x = self.a11(x)
        x = self.c12(x)
        x = self.b12(x)
        x = self.a12(x)
        x = self.c13(x)
        x = self.b13(x)
        x = self.a13(x)
        x = self.p5(x)
        x = self.flatten(x)
        x = self.f1(x)
        x = self.a14(x)
        x = self.f2(x)
        x = self.a15(x)
        x = self.f3(x)
        return x


net = Net()
print(net)
二、使用Pytorch官方推荐的方式
官方推荐如下:
- 在init中推荐初始化可训练的模型组件,如Conv2d、BatchNorm2d、Linear等,即使用nn.XXX表示的类。
- 对于不可训练的组件,如激活函数、池化操作等,即使用nn.functional.xxx表示的函数,建议不在init中初始化,只在forward中使用。
import torch.nn as nn
import torch.nn.functional as F


# VGG-16 (with BatchNorm) implementation.
# Only trainable layers (Conv2d, BatchNorm2d, Linear) are created in __init__;
# parameter-free operations (ReLU, max pooling) are called functionally in
# forward().
class Net(nn.Module):
    """VGG-16-style classifier for 3*32*32 inputs.

    forward() returns raw logits of shape (N, 10). No softmax is applied:
    nn.CrossEntropyLoss expects unnormalized logits and applies log-softmax
    itself. The argmax used for accuracy is the same either way.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Input shape: 3*32*32. Every 3x3 conv has padding=1 and keeps the
        # spatial size; each max-pool in forward() halves it.
        self.c1 = nn.Conv2d(3, 64, 3, padding=1)      # conv layer 1
        self.b1 = nn.BatchNorm2d(64)                  # BN layer 1
        self.c2 = nn.Conv2d(64, 64, 3, padding=1)     # conv layer 2
        self.b2 = nn.BatchNorm2d(64)                  # BN layer 2
        self.c3 = nn.Conv2d(64, 128, 3, padding=1)    # conv layer 3
        self.b3 = nn.BatchNorm2d(128)                 # BN layer 3
        self.c4 = nn.Conv2d(128, 128, 3, padding=1)   # conv layer 4
        self.b4 = nn.BatchNorm2d(128)                 # BN layer 4
        self.c5 = nn.Conv2d(128, 256, 3, padding=1)   # conv layer 5
        self.b5 = nn.BatchNorm2d(256)                 # BN layer 5
        self.c6 = nn.Conv2d(256, 256, 3, padding=1)   # conv layer 6
        self.b6 = nn.BatchNorm2d(256)                 # BN layer 6
        self.c7 = nn.Conv2d(256, 256, 3, padding=1)   # conv layer 7
        self.b7 = nn.BatchNorm2d(256)                 # BN layer 7
        self.c8 = nn.Conv2d(256, 512, 3, padding=1)   # conv layer 8
        self.b8 = nn.BatchNorm2d(512)                 # BN layer 8
        self.c9 = nn.Conv2d(512, 512, 3, padding=1)   # conv layer 9
        self.b9 = nn.BatchNorm2d(512)                 # BN layer 9
        self.c10 = nn.Conv2d(512, 512, 3, padding=1)  # conv layer 10
        self.b10 = nn.BatchNorm2d(512)                # BN layer 10
        self.c11 = nn.Conv2d(512, 512, 3, padding=1)  # conv layer 11
        self.b11 = nn.BatchNorm2d(512)                # BN layer 11
        self.c12 = nn.Conv2d(512, 512, 3, padding=1)  # conv layer 12
        self.b12 = nn.BatchNorm2d(512)                # BN layer 12
        self.c13 = nn.Conv2d(512, 512, 3, padding=1)  # conv layer 13
        self.b13 = nn.BatchNorm2d(512)                # BN layer 13
        # Classifier: 512*1*1 -> 512 -> 512 -> 10 logits
        self.flatten = nn.Flatten()
        self.f1 = nn.Linear(512, 512)
        self.f2 = nn.Linear(512, 512)
        self.f3 = nn.Linear(512, 10)

    def forward(self, x):
        # Stage 1: 3*32*32 -> 64*16*16
        x = self.c1(x)
        x = self.b1(x)
        x = F.relu(x)
        x = self.c2(x)
        x = self.b2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        # Stage 2: -> 128*8*8
        x = self.c3(x)
        x = self.b3(x)
        x = F.relu(x)
        x = self.c4(x)
        x = self.b4(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        # Stage 3: -> 256*4*4
        x = self.c5(x)
        x = self.b5(x)
        x = F.relu(x)
        x = self.c6(x)
        x = self.b6(x)
        x = F.relu(x)
        x = self.c7(x)
        x = self.b7(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        # Stage 4: -> 512*2*2
        x = self.c8(x)
        x = self.b8(x)
        x = F.relu(x)
        x = self.c9(x)
        x = self.b9(x)
        x = F.relu(x)
        x = self.c10(x)
        x = self.b10(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        # Stage 5: -> 512*1*1
        x = self.c11(x)
        x = self.b11(x)
        x = F.relu(x)
        x = self.c12(x)
        x = self.b12(x)
        x = F.relu(x)
        x = self.c13(x)
        x = self.b13(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        # Classifier
        x = self.flatten(x)
        x = self.f1(x)
        x = F.relu(x)
        x = self.f2(x)
        x = F.relu(x)
        x = self.f3(x)
        return x


net = Net()
print(net)
三、使用 nn.Sequential 来简化代码(推荐)
说实话,我觉得对于新手而言,前面两种方式才是最容易理解的。在理解了网络结构之后,我们可以使用本节的代码来优化代码结构。我主观地觉得,如果上课只讲当前这种结构的话,代码会比较简洁,但不利于学生理解。
import torch.nn as nn
import torch.nn.functional as F


class VGG(nn.Module):
    """Configurable VGG (with BatchNorm) builder for 3*32*32 inputs.

    Args:
        arch: sequence of 5 ints giving the number of 3x3 conv layers in each
            stage; the stages output 64, 128, 256, 512 and 512 channels.
        num_classes: number of output classes.

    forward() returns raw logits of shape (N, num_classes). No softmax is
    applied: nn.CrossEntropyLoss expects unnormalized logits and applies
    log-softmax itself.

    fc1 takes 512 input features. That only matches when the five 2x2
    max-pools reduce the feature map to 1x1, i.e. for 32x32 inputs.
    """

    def __init__(self, arch, num_classes=1000) -> None:
        super(VGG, self).__init__()
        # Running channel count, updated by __make_layer so stages chain.
        self.in_channels = 3
        self.conv3_64 = self.__make_layer(64, arch[0])
        self.conv3_128 = self.__make_layer(128, arch[1])
        self.conv3_256 = self.__make_layer(256, arch[2])
        self.conv3_512a = self.__make_layer(512, arch[3])
        self.conv3_512b = self.__make_layer(512, arch[4])
        self.fc1 = nn.Linear(512, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, num_classes)

    def __make_layer(self, channels, num):
        """Build one stage of `num` (conv3x3 -> BatchNorm -> ReLU) blocks.

        Each conv outputs `channels` channels. self.in_channels is updated
        so the next stage starts from this stage's output width.
        """
        layers = []
        for _ in range(num):
            # padding=1 keeps the spatial size ("same" padding). The conv bias
            # is omitted because the following BatchNorm has its own shift.
            layers.append(nn.Conv2d(self.in_channels, channels, 3,
                                    stride=1, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(channels))
            layers.append(nn.ReLU())
            self.in_channels = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        # Each stage is followed by a 2x2 max-pool: 32 -> 16 -> 8 -> 4 -> 2 -> 1.
        out = self.conv3_64(x)
        out = F.max_pool2d(out, 2)
        out = self.conv3_128(out)
        out = F.max_pool2d(out, 2)
        out = self.conv3_256(out)
        out = F.max_pool2d(out, 2)
        out = self.conv3_512a(out)
        out = F.max_pool2d(out, 2)
        out = self.conv3_512b(out)
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)  # (N, 512, 1, 1) -> (N, 512) for 32x32 input
        out = self.fc1(out)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = self.bn2(out)
        out = F.relu(out)
        return self.fc3(out)


def VGG_11():
    """VGG-11 (configuration A): 8 conv layers, 10 classes."""
    return VGG([1, 1, 2, 2, 2], num_classes=10)


def VGG_13():
    """VGG-13 (configuration B): 10 conv layers, 10 classes."""
    return VGG([2, 2, 2, 2, 2], num_classes=10)


def VGG_16():
    """VGG-16 (configuration D): 13 conv layers, 10 classes."""
    return VGG([2, 2, 3, 3, 3], num_classes=10)


def VGG_19():
    """VGG-19 (configuration E): 16 conv layers, 10 classes."""
    return VGG([2, 2, 4, 4, 4], num_classes=10)


net = VGG_16()
print(net)
结、最后一个故事(可以划走了)
我在实现的时候遇见了一个最简单的 bug:我在继承自 nn.Module 的 forward 函数中用 x 保存中间结果,结果到了最后一步,也不知道当时哪根筋搭错了,想着返回值要用 y。本来这也没什么问题,结果改着改着网络结构,就把这些变量搞混了。最后造成的结果是,y 中只进行了一次全连接和一次 ReLU 激活,在训练过程中损失一直不下降,测试集准确率差到离谱,我一度怀疑是自己没看懂 PyTorch 文档,给爷整怀疑了。