Reproducing and Using the Various Shortcut Connections from "Identity Mappings in Deep Residual Networks"
Importing packages
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
Loading the MNIST dataset
# data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])
# set download=True on the first run if the dataset is not yet on disk
tra_data = datasets.MNIST(root='./datasets/mnist', transform=transform, train=True, download=False)
test_data = datasets.MNIST(root='./datasets/mnist', transform=transform, train=False, download=False)
tra_loader = DataLoader(dataset=tra_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)
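As a quick sanity check that the pipeline is wired correctly (a minimal sketch): each batch should arrive as 64 normalized 1x28x28 images with a matching label tensor.
images, labels = next(iter(tra_loader))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])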
Various types of shortcut connections
original
# original
class Original(nn.Module):
    def __init__(self, channels):
        super(Original, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.relu = nn.ReLU()
        # one BatchNorm per conv, so the two layers do not share running statistics
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return self.relu(x + y)  # the original formulation: ReLU(x + F(x))
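A residual block must preserve the input's shape, otherwise x + F(x) would not be well-defined; a quick shape check (a minimal sketch, and the same check applies to every variant below):
block = Original(16)
out = block(torch.randn(1, 16, 12, 12))
print(out.shape)  # torch.Size([1, 16, 12, 12]), identical to the input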
constant scaling
# constant scaling
class ConstantScaling(nn.Module):
    def __init__(self, channels):
        super(ConstantScaling, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        # scale both the shortcut and the residual branch by a constant 0.5
        return self.relu(0.5 * x + 0.5 * y)
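One way to see why constant scaling tends to hamper optimization, following the paper's analysis (the loop below is only an illustration): the shortcut signal is attenuated by 0.5 in every block, so after k stacked blocks only 0.5**k of it survives.
for k in (1, 5, 10):
    print(k, 0.5 ** k)  # 0.5, 0.03125, ~0.00098: the identity signal decays fast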
exclusive gating
# exclusive gating
class ExclusiveGating(nn.Module):
    def __init__(self, channels):
        super(ExclusiveGating, self).__init__()
        self.conv3x3_1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv3x3_2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv1x1 = nn.Conv2d(channels, channels, kernel_size=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)
        self.bn3 = nn.BatchNorm2d(channels)

    def forward(self, x):
        y1 = self.relu(self.bn1(self.conv3x3_1(x)))
        y1 = self.bn2(self.conv3x3_2(y1))
        # gate g(x) = sigmoid(BN(conv1x1(x)))
        g = self.sigmoid(self.bn3(self.conv1x1(x)))
        # the residual branch is scaled by g, the shortcut by (1 - g)
        return self.relu(g * y1 + (1 - g) * x)
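Because the gate comes from a sigmoid, each branch weight lies strictly in (0, 1), and the two weights are complementary by construction; a tiny numerical check (a minimal sketch):
g = torch.sigmoid(torch.randn(1000))
print(0.0 < g.min().item() and g.max().item() < 1.0)    # True: gate values in (0, 1)
print(torch.allclose(g + (1 - g), torch.ones_like(g)))  # True: weights sum to 1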
shortcut-only gating
# shortcut-only gating
class ShortcutGating(nn.Module):
    def __init__(self, channels):
        super(ShortcutGating, self).__init__()
        self.conv3x3_1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv3x3_2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv1x1 = nn.Conv2d(channels, channels, kernel_size=1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)
        self.bn3 = nn.BatchNorm2d(channels)

    def forward(self, x):
        y1 = self.relu(self.bn1(self.conv3x3_1(x)))
        y1 = self.bn2(self.conv3x3_2(y1))
        # only the shortcut is gated, by (1 - g); the residual branch stays unscaled
        g = self.sigmoid(self.bn3(self.conv1x1(x)))
        return self.relu(y1 + (1 - g) * x)
conv shortcut
# conv shortcut
class ConvShortcut(nn.Module):
    def __init__(self, channels):
        super(ConvShortcut, self).__init__()
        self.conv3x3_1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv3x3_2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv1x1 = nn.Conv2d(channels, channels, kernel_size=1)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)
        self.bn3 = nn.BatchNorm2d(channels)

    def forward(self, x):
        y1 = self.relu(self.bn1(self.conv3x3_1(x)))
        y1 = self.bn2(self.conv3x3_2(y1))
        # the identity shortcut is replaced by a 1x1 convolution
        y2 = self.bn3(self.conv1x1(x))
        return self.relu(y1 + y2)
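Unlike the identity shortcut, the 1x1 conv shortcut adds trainable parameters on the shortcut path; a quick comparison against the Original block (a minimal sketch):
n_orig = sum(p.numel() for p in Original(16).parameters())
n_conv = sum(p.numel() for p in ConvShortcut(16).parameters())
print(n_orig, n_conv)  # the ConvShortcut block is strictly larger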
dropout shortcut
# dropout shortcut
class DropoutShortcut(nn.Module):
    def __init__(self, channels, dropout_rate=0.5):
        super(DropoutShortcut, self).__init__()
        self.conv3x3_1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv3x3_2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        # Dropout randomly zeroes activations during training; p is the drop probability
        self.dropout = nn.Dropout(p=dropout_rate)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        y1 = self.relu(self.bn1(self.conv3x3_1(x)))
        y1 = self.bn2(self.conv3x3_2(y1))
        y2 = self.dropout(x)  # dropout applied on the shortcut path
        return self.relu(y1 + y2)
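Note that nn.Dropout behaves differently in training and evaluation mode, which is one reason the train/test functions below call model.train() and model.eval(); a minimal sketch of the difference:
drop = nn.Dropout(p=0.5)
t = torch.ones(8)
drop.train()
print(drop(t))  # roughly half the entries zeroed, survivors scaled by 1/(1-p) = 2
drop.eval()
print(drop(t))  # identity at inference: all ones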
Defining the network
Only self.rblock1 and self.rblock2 (i.e., the shortcut connections) change between experiments; the rest of the network stays the same.
# Net
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.mp = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        # uncomment the variant under test; everything else stays fixed
        # self.rblock1 = Original(16)
        # self.rblock2 = Original(32)
        # self.rblock1 = ConstantScaling(16)
        # self.rblock2 = ConstantScaling(32)
        # self.rblock1 = ExclusiveGating(16)
        # self.rblock2 = ExclusiveGating(32)
        # self.rblock1 = ShortcutGating(16)
        # self.rblock2 = ShortcutGating(32)
        # self.rblock1 = ConvShortcut(16)
        # self.rblock2 = ConvShortcut(32)
        self.rblock1 = DropoutShortcut(16)
        self.rblock2 = DropoutShortcut(32)
        # 28x28 -> conv5 -> 24x24 -> pool -> 12x12 -> conv5 -> 8x8 -> pool -> 4x4,
        # so the flattened size is 32 * 4 * 4 = 512
        self.linear = nn.Linear(512, 10)

    def forward(self, x):
        batch_size = x.size(0)
        x = self.relu(self.mp(self.conv1(x)))
        x = self.rblock1(x)
        x = self.relu(self.mp(self.conv2(x)))
        x = self.rblock2(x)
        x = x.view(batch_size, -1)
        x = self.linear(x)
        return x
# # Check the flattened size after view(), to set the linear layer's in_features
# x = torch.randn(1, 1, 28, 28)
# model = Net()
# print(model(x).size(1))
model = Net()
model = model.to(device)
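Optionally, the model size can be reported, which is handy when comparing variants that add parameters on the shortcut path (a minimal sketch):
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('trainable parameters:', n_params)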
loss and optimizer
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.5)
train and test functions
def train(epoch):
    model.train()  # ensure Dropout/BatchNorm run in training mode
    running_loss = 0.0
    for i, data in enumerate(tra_loader):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        y_pred = model(inputs)
        l = criterion(y_pred, targets)
        l.backward()
        optimizer.step()
        running_loss += l.item()
        if i % 300 == 299:
            print('[%d %5d]\tloss: %.3f' % (epoch + 1, i + 1, running_loss / 300))
            running_loss = 0.0
def test():
    model.eval()  # switch Dropout/BatchNorm to evaluation mode
    total = 0
    correct = 0
    with torch.no_grad():
        for data in test_loader:
            x, labels = data
            x, labels = x.to(device), labels.to(device)
            outputs = model(x)
            total += labels.size(0)
            _, predicted = torch.max(outputs.data, dim=1)
            correct += (predicted == labels).sum().item()
    print("Accuracy on Test is %.2f %% [%d %d]" % (100 * correct / total, correct, total))
    return 100 * correct / total
Training and testing
if __name__ == '__main__':
    acc_list = []
    for epoch in range(10):
        train(epoch)
        acc = test()
        acc_list.append(acc)
    # save the per-epoch accuracies to a txt file for the comparison plot below;
    # rename the file to match the shortcut connection being tested
    os.makedirs("./acc_list", exist_ok=True)
    acc_list = np.array(acc_list)
    np.savetxt("./acc_list/DropoutShortcut.txt", acc_list)
Plotting
After the accuracy list of every block variant has been saved, plot them together for a direct comparison.
# plotting
import os
fig, ax = plt.subplots()  # create the figure and axes
x = np.arange(1, 11)  # epochs 1..10 on the x-axis (one point per recorded accuracy)
path = "./acc_list/"
for name in os.listdir(path):
    txt_name = path + name
    txt = np.loadtxt(fname=txt_name)
    ax.plot(x, txt, label=name[:-4])  # strip the ".txt" extension for the legend
ax.set_xlabel('epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Accuracy on Different Networks')
ax.legend()  # build the legend from the labeled lines
plt.show()
plt.close()
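To keep a copy of the comparison figure on disk, it can be saved before plt.show() is called (the output path here is just an assumption):
fig.savefig('./acc_list/comparison.png', dpi=150)  # hypothetical output path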