作为深度学习奠基作之一,《ImageNet Classification with Deep Convolutional Neural Networks》由Alex Krizhevsky等人在2012年发表,AlexNet就此横空出世,并以巨大的优势赢得了2012年ImageNet图像识别挑战赛冠军,一举打破了当时计算机视觉研究的现状。
从文章中看
作者在文中对AlexNet的网络结构和各层的基本参数进行了介绍,对于网络的新颖特点展开了详细描述。十多年后,回看这篇经典文章,依然能感到收获颇丰。
网络结构
AlexNet由8层构成,其中前5层为卷积层,后3层为全连接层,作者在文章中给出了具体结构示意图。
作者将AlexNet部署在两块GPU上训练,因此整体模型被分为两部分,在观察结构图时会存在一定障碍。
在这里引用《Dive Into Deep Learning》中给出的ALexNet稍微精简版本的结构表示,去除了当年需要两块GPU同时运算的设计特点,也便于后续对模型代码复现时可以有所参考。
特点
作者在文章中着重介绍了AlexNet的特点,这些特点对模型训练效率和最终效果带来了很大的提升,具体包含以下几个方面。
(1)ReLU
在激活函数的选择上,AlexNet没有选择常规的tanh和sigmoid,而是使用了ReLU,这一选择使得模型在训练速度上快了好几倍。
(2)多GPU训练
单GPU受限于内存,无法满足大模型训练的需要,因此作者使用了两块GPU并行训练的方式,并设计了GPU间的通信策略,有效节约了训练时间开销。
(3)重叠池化
对于传统的池化方法进行了微小改动,使用重叠池化的方式降低过拟合的风险。
应对过拟合
在大模型的训练过程中,过拟合现象是需要想办法来避免的,作者为此使用了以下两种方法。
(1)数据增广
通过对数据集进行数据增广,进而降低过拟合的风险。作者在文中使用了对图像进行平移和水平映射以及对颜色通道进行主成分分析(PCA)的方式完成数据增广。
(2)Dropout
将全连接层中每个神经元的输出以一定概率设置为零,作者在文中设置概率为0.5。因此对于每次输入,神经网络都会对不同的架构进行采样。作者在文中提到,如果没有使用Dropout,网络就会表现出大量的过拟合。
PyTorch简易实现
在此尝试了基于PyTorch对AlexNet模型的实现。由于是简易实现,只保留了AlexNet的基本实现思路,模型的训练使用MNIST数据集,
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
# 定义模型
class AlexNet(nn.Module):
def __init__(self):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 96, kernel_size=11, stride=4,padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2)
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(6400, 4096),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(4096, 4096),
nn.ReLU(),
nn.Dropout(p=0.5),
nn.Linear(4096, 10)
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
# 定义数据预处理函数
transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
# 加载训练集和验证集
train_dataset = MNIST(root='../data/', train=True, transform=transform, download=True)
validation_dataset = MNIST(root='../data/', train=False, transform=transform)
# 定义数据加载器
batch_size = 32
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=False)
model = AlexNet()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# 定义损失函数,随机梯度下降
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
train_corrects = 0
train_loss = 0
model.train()
for data in train_loader:
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
_, preds = torch.max(outputs, 1)
train_corrects += torch.sum(preds == labels.data)
train_loss += loss.item() * inputs.size(0)
train_acc = train_corrects.double() / len(train_dataset)
train_loss = train_loss / len(train_dataset)
print("Epoch:{}/{} Train Loss:{:.4f} Train Acc: {:.4f}".format(epoch+1, num_epochs, train_loss, train_acc))
# 计算验证集准确率和损失
val_corrects = 0
val_loss = 0
model.eval()
for data in validation_loader:
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
with torch.no_grad():
outputs = model(inputs)
loss = criterion(outputs, labels)
_, preds = torch.max(outputs, 1)
val_corrects += torch.sum(preds == labels.data)
val_loss += loss.item() * inputs.size(0)
val_acc = val_corrects.double() / len(validation_dataset)
val_loss = val_loss / len(validation_dataset)
print("Epoch:{}/{} Val Loss:{:.4f} Val Acc: {:.4f}".format(epoch+1, num_epochs, val_loss, val_acc))
参考文献
《ImageNet Classification with Deep Convolutional Neural Networks》
《Dive Into Deep Learning》