homework7——shortcut connections

《Identity Mappings in Deep Residual Networks》中的多种shortcut connections的复现和使用

导入包库

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

MNIST数据集加载

# data
# Images are converted to tensors and normalized with (0.1307, 0.3081), the
# commonly used MNIST mean/std. download=False: the dataset must already be
# present under root.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))]
)

tra_data = datasets.MNIST(root='./datasets/mnist', train=True, transform=transform, download=False)
test_data = datasets.MNIST(root='./datasets/mnist', train=False, transform=transform, download=False)

# Training batches are reshuffled every epoch; the test order stays fixed.
tra_loader = DataLoader(tra_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

Various types of shortcut connections

original

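(原图缺失, 对应原论文 Fig. 2(a) original) $x_{l+1} = \mathrm{ReLU}\big(x_l + \mathcal{F}(x_l)\big)$, 其中 $\mathcal{F}(x) = \mathrm{BN}\big(W_2 * \mathrm{ReLU}(\mathrm{BN}(W_1 * x))\big)$, $W_1, W_2$ 为 $3\times3$ 卷积。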

# original
class Original(nn.Module):
    """Original residual block: ReLU(x + F(x)).

    F(x) = BN(conv3x3(ReLU(BN(conv3x3(x))))). Both convs use stride 1 and
    padding 1 and keep the channel count, so the identity shortcut can be
    added directly.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super(Original, self).__init__()

        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)

        self.relu = nn.ReLU()
        # One BatchNorm per conv: a single shared BN would mix the running
        # statistics of two different layers (wrong in eval mode) and tie
        # their affine parameters together.
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return self.relu(x + y)

constant scaling

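(原图缺失, 对应原论文 Fig. 2(b) constant scaling) $x_{l+1} = \mathrm{ReLU}\big(\lambda\, x_l + \lambda\, \mathcal{F}(x_l)\big)$, 此处 $\lambda = 0.5$。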

# constant scaling
class ConstantScaling(nn.Module):
    """Constant-scaling residual block: ReLU(scale * x + scale * F(x)).

    F(x) = BN(conv3x3(ReLU(BN(conv3x3(x))))), shape-preserving.

    Args:
        channels: number of input (and output) channels.
        scale: constant applied to both the shortcut and the residual
            branch (0.5 by default, as in the original implementation).
    """

    def __init__(self, channels, scale=0.5):
        super(ConstantScaling, self).__init__()

        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)

        # Separate BatchNorms so each conv keeps its own statistics/affine params.
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.scale = scale

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return self.relu(self.scale * x + self.scale * y)

exclusive gating

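(原图缺失, 对应原论文 Fig. 2(c) exclusive gating) $x_{l+1} = \mathrm{ReLU}\big(g(x_l) \odot \mathcal{F}(x_l) + (1 - g(x_l)) \odot x_l\big)$, 门控 $g(x) = \sigma(W_g x + b_g)$, $W_g$ 为 $1\times1$ 卷积。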

# exclusive gating
class ExclusiveGating(nn.Module):
    """Exclusive-gating residual block: ReLU(g(x) * F(x) + (1 - g(x)) * x).

    F(x) = BN(conv3x3(ReLU(BN(conv3x3(x))))) and the gate is
    g(x) = sigmoid(BN(conv1x1(x))), applied element-wise. The same gate
    value weights the residual branch and (as its complement) the shortcut.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super(ExclusiveGating, self).__init__()

        self.conv3x3_1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv3x3_2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)

        # 1x1 conv producing the gate logits.
        self.conv1x1 = nn.Conv2d(channels, channels, kernel_size=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # One BatchNorm per conv output (residual convs and gate).
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)
        self.bn_gate = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv3x3_1(x)))
        residual = self.bn2(self.conv3x3_2(residual))

        # Squash the gate to (0, 1) once and use it for BOTH paths; multiplying
        # the residual by the raw logits would not be a gate at all.
        gate = self.sigmoid(self.bn_gate(self.conv1x1(x)))

        return self.relu(gate * residual + (1 - gate) * x)

shortcut-only gating

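(原图缺失, 对应原论文 Fig. 2(d) shortcut-only gating) $x_{l+1} = \mathrm{ReLU}\big(\mathcal{F}(x_l) + (1 - g(x_l)) \odot x_l\big)$, 门控 $g(x) = \sigma(W_g x + b_g)$ 只作用于 shortcut。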

# shortcut-only gating
class ShortcutGating(nn.Module):
    """Shortcut-only gating residual block: ReLU(F(x) + (1 - g(x)) * x).

    F(x) = BN(conv3x3(ReLU(BN(conv3x3(x))))) and g(x) = sigmoid(BN(conv1x1(x))).
    Only the shortcut is gated; the residual branch is added unscaled.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super(ShortcutGating, self).__init__()

        self.conv3x3_1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv3x3_2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)

        # 1x1 conv producing the gate logits.
        self.conv1x1 = nn.Conv2d(channels, channels, kernel_size=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # One BatchNorm per conv output so their statistics are not mixed.
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)
        self.bn_gate = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv3x3_1(x)))
        residual = self.bn2(self.conv3x3_2(residual))

        gate = self.sigmoid(self.bn_gate(self.conv1x1(x)))
        shortcut = x * (1 - gate)

        return self.relu(residual + shortcut)

conv shortcut

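(原图缺失, 对应原论文 Fig. 2(e) conv shortcut) $x_{l+1} = \mathrm{ReLU}\big(\mathcal{F}(x_l) + W_s x_l\big)$, $W_s$ 为 $1\times1$ 卷积 (本文实现中其后接 BN)。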

# conv shortcut
class ConvShortcut(nn.Module):
    """Residual block with a 1x1-conv shortcut: ReLU(F(x) + BN(conv1x1(x))).

    F(x) = BN(conv3x3(ReLU(BN(conv3x3(x))))); all convs keep the channel
    count and spatial size.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super(ConvShortcut, self).__init__()

        self.conv3x3_1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv3x3_2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)

        # Learned projection replacing the identity shortcut.
        self.conv1x1 = nn.Conv2d(channels, channels, kernel_size=1)

        self.relu = nn.ReLU()
        # One BatchNorm per conv output so their statistics are not mixed.
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)
        self.bn_shortcut = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv3x3_1(x)))
        residual = self.bn2(self.conv3x3_2(residual))

        shortcut = self.bn_shortcut(self.conv1x1(x))

        return self.relu(residual + shortcut)

dropout shortcut

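(原图缺失, 对应原论文 Fig. 2(f) dropout shortcut) $x_{l+1} = \mathrm{ReLU}\big(\mathcal{F}(x_l) + \mathrm{dropout}(x_l)\big)$, dropout 比例为 0.5。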

# dropout shortcut
class DropoutShortcut(nn.Module):
    """Residual block with dropout on the shortcut: ReLU(F(x) + dropout(x)).

    F(x) = BN(conv3x3(ReLU(BN(conv3x3(x))))), shape-preserving.

    Args:
        channels: number of input (and output) channels.
        dropout_rate: probability of zeroing each shortcut element in training.
    """

    def __init__(self, channels, dropout_rate=0.5):
        super(DropoutShortcut, self).__init__()

        self.conv3x3_1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv3x3_2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, stride=1)

        # Dropout randomly zeroes elements with probability p during training.
        # NOTE: nn.Dropout rescales the kept elements by 1/(1-p) in training and
        # is the identity in eval mode, so the shortcut's expectation stays x
        # (it is not a fixed 0.5 down-scaling of the shortcut).
        self.dropout = nn.Dropout(p=dropout_rate)

        self.relu = nn.ReLU()
        # One BatchNorm per conv so their statistics are not mixed.
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv3x3_1(x)))
        residual = self.bn2(self.conv3x3_2(residual))

        shortcut = self.dropout(x)

        return self.relu(residual + shortcut)

定义网络结构

只有self.rblock1和self.rblock2(也就是shortcut connection)改变,其他的网络结构保持不变

# Net
class Net(nn.Module):
    """Small MNIST classifier with two interchangeable residual blocks.

    conv5x5(1->16) -> maxpool -> ReLU -> block(16)
    -> conv5x5(16->32) -> maxpool -> ReLU -> block(32) -> Linear(512, 10).
    Only the residual blocks change between experiments; everything else
    stays fixed.

    Args:
        block: callable mapping a channel count to a shape-preserving
            nn.Module (e.g. Original, ConstantScaling, ExclusiveGating,
            ShortcutGating, ConvShortcut, DropoutShortcut). Defaults to
            DropoutShortcut.
    """

    def __init__(self, block=None):
        super(Net, self).__init__()
        if block is None:
            # Resolved at call time so the default matches the previous behavior.
            block = DropoutShortcut

        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.mp = nn.MaxPool2d(2)
        self.relu = nn.ReLU()

        self.rblock1 = block(16)
        self.rblock2 = block(32)

        # For a 28x28 input: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4,
        # so the flattened feature size is 32 * 4 * 4 = 512.
        self.linear = nn.Linear(512, 10)

    def forward(self, x):
        batch_size = x.size(0)

        x = self.relu(self.mp(self.conv1(x)))
        x = self.rblock1(x)
        x = self.relu(self.mp(self.conv2(x)))
        x = self.rblock2(x)

        x = x.view(batch_size, -1)
        return self.linear(x)

# The Linear layer's in_features (512) is the flattened size after rblock2:
# 32 channels * 4 * 4 for a 28x28 MNIST image. To double-check it, print
# the flattened tensor size inside Net.forward right before self.linear.

# Build the network and move its parameters to the selected device
# (nn.Module.to returns the module itself).
model = Net().to(device)

loss and optimizer

# Cross-entropy on the raw logits; SGD with momentum over all model parameters.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-2, momentum=0.5)

train和test函数

def train(epoch, log_interval=300):
    """Run one training epoch over ``tra_loader``.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device``. Every ``log_interval`` batches it prints the mean loss of
    those batches.

    Args:
        epoch: zero-based epoch index (printed 1-based).
        log_interval: number of batches between loss reports.
    """
    # Ensure BatchNorm uses batch statistics and Dropout is active, even if a
    # previous evaluation left the model in eval mode.
    model.train()

    running_loss = 0.0
    for i, (inputs, targets) in enumerate(tra_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % log_interval == log_interval - 1:
            print('[%d %5d]\tloss: %.3f' % (epoch + 1, i + 1, running_loss / log_interval))
            running_loss = 0.0


def test():
    """Evaluate ``model`` on ``test_loader`` and return the accuracy in percent.

    The model is switched to eval mode for the pass (BatchNorm uses its
    running statistics, Dropout is disabled), then restored to whatever mode
    it was in before, so calling this between training epochs is safe.
    """
    was_training = model.training
    model.eval()

    total = 0
    correct = 0
    try:
        with torch.no_grad():
            for x, labels in test_loader:
                x, labels = x.to(device), labels.to(device)

                predicted = model(x).argmax(dim=1)

                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    accuracy = 100 * correct / total
    print("Accuracy on Test is %.2f %% [%d %d]" % (accuracy, correct, total))
    return accuracy

训练和测试

if __name__ == '__main__':
    import os

    acc_list = []
    for epoch in range(10):
        train(epoch)
        acc = test()
        acc_list.append(acc)

    # Save the per-epoch test accuracies so all variants can be plotted
    # together later. The file is named after the residual block class the
    # model actually uses, so it cannot get out of sync with Net.
    acc_dir = "./acc_list"
    os.makedirs(acc_dir, exist_ok=True)  # np.savetxt does not create directories
    block_name = type(model.rblock1).__name__
    np.savetxt(os.path.join(acc_dir, block_name + ".txt"), np.array(acc_list))

画图

将每个块的accuracy保存下来以后画图

# Plot the saved test-accuracy curves of every shortcut variant together.
import os

fig, ax = plt.subplots()  # one figure with a single axes

path = "./acc_list/"
# Sorted so the legend order is deterministic (os.listdir order is arbitrary).
for name in sorted(os.listdir(path)):
    stem, ext = os.path.splitext(name)
    if ext != ".txt":
        continue  # only the saved accuracy lists
    # atleast_1d: loadtxt returns a 0-d array for a single-value file.
    acc = np.atleast_1d(np.loadtxt(fname=os.path.join(path, name)))
    # x values are epoch numbers 1..N, matching the axis label and the
    # actual number of saved epochs.
    epochs = np.arange(1, len(acc) + 1)
    ax.plot(epochs, acc, label=stem)

ax.set_xlabel('epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Accuracy on Different Networks')
ax.legend()  # show a legend entry for every plotted curve

plt.show()
plt.close()
  • 11
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值