ResNet学习笔记及仿真

ResNet学习笔记及仿真

前言

深度残差网络(Deep Residual Network)是cnn机器视觉史上的一件里程碑的事件,在2015年,ResNet在ImageNet和COCO数据集上获得了非常优秀的成绩。如下所示:

  • ImageNet Classification: “Ultra-deep”152-layer nets
  • ImageNet Detection: 16% better than 2nd
  • ImageNet Localization: 27% better than 2nd
  • COCO Detection: 11% better than 2nd
  • COCO Segmentation: 12% better than 2nd

ResNet获得了五项第一,再一次的刷新了CNN模型在ImageNet上的历史[1],而论文的作者何凯明也获得了CVPR2016最佳论文奖[2]。以下是记录了解ResNet的一些笔记,主要是来自于15年和16年何凯明所在的微软研究院发表的论文。

深度网络退化问题(degradation problem of deep network)

AlexNetGoogLeNet,看起来好像是网络越深越好,但是直接秉持着We need go deeper的思想构建神经网络,到最后会发现,随着网络越来越深,会出现梯度消失(vanishing gradients)梯度爆炸(exploding gradients)以及网络退化(network degradation)

  • 梯度消失和爆炸(vanishing/exploding gradients):网络层数太深引发的梯度方向传播中的连乘效应引起
  • 网络退化(network degradation):较深的模型可以看作是较浅模型的超空间,理论上较深的模型可以看作是较浅模型的恒等映射,但是实际上较深模型后面添加的不是恒等映射,而是一些非线性层[3]

对于梯度消失和爆炸的应对方法如下:

  • 改换激活函数: 使用reluLeakyReluELU等激活函数可以改善梯度消散或爆炸问题。relu导数的正数部分恒等于1,所以不会产生梯度消失和梯度爆炸
  • BatchNormalization: 对每一层的输入做scale和shift方法,将每层神经元的输入分布强行拉回均值为0、方差为1的标准正态分布,这就使得激活层输入值落入在非线性函数对输入值比较敏感的区域,使得输入的小变化会导致损失函数较大的变化,使得梯度变大,训练速度加快,且避免梯度消失问题
  • 梯度剪切: 该方法主要是针对梯度爆炸提出。其思想是设置一个梯度剪切阈值,更新梯度时,如果梯度超过这个阈值,那么限制其在这个范围之内

但是随着网络的加深,可以看到的是训练误差和测试误差都开始增加,这自然不是过拟合引起的,而是网络出现退化[4],如figure1所示:

fig1

网络退化表明了并非所有系统都同样容易优化。考虑一个较深的网络和一个较浅的网络,更深的网络对应更浅的网络相当于是增加了更多的层,添加的层可以是恒等映射(identity mapping),而其他的层是相当于是更浅的模型中直接复制的,这种构造容易得到,较深的模型不会产生比较浅的模型更高的训练误差,但是实验表明,简单地加深网络模型会出现网络退化的问题。

残差块(Residual block)

fig2

网络退化出现的原因就是现有的网络无法实现恒等映射,将想要得到的恒等映射函数表示为 H ( x ) = x H(x)=x H(x)=x,残差块网络设计为 H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x,即直接将恒等映射作为网络的一部分,就可以将问题转化为学习一个残差函数 F ( x ) = H ( x ) − x F(x)=H(x)-x F(x)=H(x)x,那么只要 F ( x ) = 0 F(x)=0 F(x)=0,就可以构成一个恒等映射 H ( x ) = x H(x)=x H(x)=x,拟合残差比拟合恒等映射容易的多[5]。从数学上更加详细地描述残差结构,可以是:

y = F ( x , W i ) + x y=F(x, W_i)+x y=F(x,Wi)+x

其中 x x x y y y 表示的分别是所考虑的层的输入和输出向量,函数 F ( x , W i ) F(x, W_i) F(x,Wi) 表示要学习的残差映射,操作 F ( x , W i ) + x F(x, W_i)+x F(x,Wi)+x 是通过跳接实现,在方程1 x x x F F F 的维度必须相同,否则,需要对跳接进行线性投影 W s W_s Ws 来匹配维度:

y = F ( x , W i ) + W s x y=F(x, W_i)+W_s x y=F(x,Wi)+Wsx

  • F ( x ) + x F(x)+x F(x)+x 的公式可以通过具有跳接(shortcut connections)的前馈神经网络来实现,跳接可以是跳过一层或者多层的连接,通过跳接在激活函数前,将上一层或几层的输出与本层输出相加,将求和结果输入到激活函数作为本层的输出,残差块示例如figure2所示

  • 跳接只是执行恒等映射,他们的输出被添加到堆叠层的输出中,这不会增加额外的参数,也不会增加计算复杂性

  • 添加了残差网络,可以给神经网络一个多的选择,例如学习到的一层的参数是冗余的,那么就可以直接走跳接路线,跳过这个冗余层,而不用去拟合参数使得输出 H ( x ) = x H(x)=x H(x)=x

  • 学习残差的计算量比学习输出等于输入小:例如普通网络为A,残差网络为B,输入为2,输出为2,那么普通网络就是 A ( 2 ) = 2 A(2)=2 A(2)=2,而残差网络就是 B ( 2 ) = F ( 2 ) + 2 = 2 B(2)=F(2)+2=2 B(2)=F(2)+2=2,即残差网络中 F ( 2 ) = 0 F(2)=0 F(2)=0。一般网络会将权重初始化为0附近的数,所以让 F ( 2 ) F(2) F(2)拟合0会比 A ( 2 ) = 2 A(2)=2 A(2)=2容易

  • ReLU能够将负数激活为0,正数等于输出,这相当于过滤了负数的线性变化,让 F ( x ) = 0 F(x)=0 F(x)=0变的更加容易

请添加图片描述

  • 对残差网络 H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x求梯度,即反向传播时,得到 H ′ ( x ) = F ′ ( x ) + 1 H'(x)=F'(x)+1 H(x)=F(x)+1残差结构的这个常数1能够保证求梯度的时候梯度不会消失

  • 这种结构不仅适用于全连接层,还适用于卷积层,逐通道地对两个特征图执行元素相加

网络架构

ResNet结构

网络结构如figure3所示,从左到右分别是VGG-19(196亿次浮点运算)、34层不带残差的plain net(36亿次浮点运算)、34层的残差网络(36亿次浮点运算)

plain network
  • 主要受到VGGNet的启发,遵循两个简单设计规则:
    • 对于相同的输出特征图大小,层具有相同数量的滤波器
    • 如果特征图大小减半,则过滤器的数量加倍
  • 步长为2的卷积层直接执行下采样
  • 网络结尾是一个整体平均池化层和一个1000路的全连接层和softmax函数
  • 总体的带权重的层是34层
  • 该模型比VGGNet相比具有更少的滤波器和更低的复杂度,plain net 有36亿次浮点运算,而VGG-19有196亿次浮点运算,前者是后者的18%
Residual Network
  • 在plain net网络中引入了跳接(shortcut conncetions),将其转换为了对应的残差版本
  • 跳接引入有两种形式:
    • 实线:跳接引入的输入和输出维度相同,可以直接相加,如公式1
    • 虚线:引入的维度增加时,可以有两种方式
      • 跳接仍然执行恒等映射,填充零元素保证维度相同
      • 利用公式2对跳接进行投影来匹配维度

实现细节

  • 遵循AlexNet数据预处理的方法,对图像进行裁剪和水平翻转得到224 x 224的图像,然后减去每个像素的平均值[6]
  • 每次卷积之后和激活函数之前采用批处理标准化(batch normalization, BN)
  • 批大小(mini-batch )为256
  • 学习率(learning rate) 从 0.1 开始,当误差平稳时,学习率除以10,模型训练了 60 × 1 0 4 60 \times 10^4 60×104 次迭代
  • 权重衰减(weight decay)0.0001,动量(momentum)为 0.9
  • 网络中没有使用到dropout

实验结果

ImageNet分类结果

对比了18层和34层的plain net以及对应的ResNet,可以看到如figure4所示,残差结果确实解决了退化问题

fig4

对比了ImageNet数据集的测试结果汇总如表3所示

基于cifar10数据集的ResNet实现[7]

导入基本模块

from __future__ import print_function
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

torchvision下载cifar10数据集

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_set = torchvision.datasets.CIFAR10('../data', train=True, 
                                         download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10('../data', train=False, 
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, 
                                           shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, 
                                          shuffle=False, num_workers=4)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
Using downloaded and verified file: ../data\cifar-10-python.tar.gz
Extracting ../data\cifar-10-python.tar.gz to ../data
Files already downloaded and verified
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
    img = img / 2 + 0.5
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    plt.show()
image_iter = iter(train_loader)
images, _ = image_iter.next()
imshow(torchvision.utils.make_grid(images[:4]))

在这里插入图片描述

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device
'cuda'

BasicBlock

在这里插入图片描述

class BasicBlock(nn.Module):
    """
    对于浅层网络,我们使用基本的Block
    基础块没有维度压缩,所以expansion=1
    """
    expansion = 1
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

        # 如果输入输出维度不等,则使用1x1卷积层来改变维度
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * out_channels),
            )

            

    def forward(self, x):
        out = self.features(x)
        out += self.shortcut(x)
        out = torch.relu(out)
        return out
# 测试
basic_block = BasicBlock(64, 128)
print(basic_block)
x = torch.randn(2, 64, 32, 32)
y = basic_block(x)
print(y.shape)
BasicBlock(
  (features): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (shortcut): Sequential(
    (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
torch.Size([2, 128, 32, 32])

Bottleneck Block

在这里插入图片描述

class Bottleneck(nn.Module):
    """
    对于深层网络,我们使用BottleNeck,论文中提出其拥有近似的计算复杂度,但能节省很多资源
    zip_channels: 压缩后的维数,最后输出的维数是 expansion * zip_channels
    """
    expansion = 4
    def __init__(self, in_channels, zip_channels, stride=1):
        super(Bottleneck, self).__init__()
        out_channels = self.expansion * zip_channels
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, zip_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(zip_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(zip_channels, zip_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(zip_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(zip_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        
    def forward(self, x):
        out = self.features(x)
        out += self.shortcut(x)
        out = torch.relu(out)
        return out
# 测试
bottleneck = Bottleneck(256, 128)
print(bottleneck)
x = torch.randn(2, 256, 32, 32)
y = bottleneck(x)
print(y.shape)
Bottleneck(
  (features): Sequential(
    (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (shortcut): Sequential(
    (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
torch.Size([2, 512, 32, 32])

ResNet

在这里插入图片描述

class ResNet(nn.Module):
    """
    不同的ResNet架构都是统一的一层特征提取、四层残差,不同点在于每层残差的深度。
    对于cifar10,feature map size的变化如下:
    (32, 32, 3) -> [Conv2d] -> (32, 32, 64) -> [Res1] -> (32, 32, 64) -> [Res2] 
 -> (16, 16, 128) -> [Res3] -> (8, 8, 256) ->[Res4] -> (4, 4, 512) -> [AvgPool] 
 -> (1, 1, 512) -> [Reshape] -> (512) -> [Linear] -> (10)
    """
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # cifar10经过上述结构后,到这里的feature map size是 4 x 4 x 512 x expansion
        # 所以这里用了 4 x 4 的平均池化
        self.avg_pool = nn.AvgPool2d(kernel_size=4)
        self.classifer = nn.Linear(512 * block.expansion, num_classes)

        

    def _make_layer(self, block, out_channels, num_blocks, stride):
        # 第一个block要进行降采样
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            # 如果是Bottleneck Block的话需要对每层输入的维度进行压缩,压缩后再增加维数
            # 所以每层的输入维数也要跟着变
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    

    def forward(self, x):
        out = self.features(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.classifer(out)
        return out
def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])
net = ResNet34().to(device)
print(net)
if device == 'cuda':
    net = nn.DataParallel(net)
    # 当计算图不会改变的时候(每次输入形状相同,模型不改变)的情况下可以提高性能,反之则降低性能
    torch.backends.cudnn.benchmark = True
ResNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (layer1): Sequential(
    (0): BasicBlock(
      (features): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (features): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (2): BasicBlock(
      (features): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (features): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (features): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (2): BasicBlock(
      (features): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (3): BasicBlock(
      (features): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (features): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (features): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (2): BasicBlock(
      (features): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (3): BasicBlock(
      (features): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (4): BasicBlock(
      (features): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (5): BasicBlock(
      (features): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (features): Sequential(
        (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (features): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
    (2): BasicBlock(
      (features): Sequential(
        (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (shortcut): Sequential()
    )
  )
  (avg_pool): AvgPool2d(kernel_size=4, stride=4, padding=0)
  (classifer): Linear(in_features=512, out_features=10, bias=True)
)
# 测试
x = torch.randn(2, 3, 32, 32).to(device)
y = net(x)
print(y.shape)
torch.Size([2, 10])

train model

lr = 1e-1
momentum = 0.9
weight_decay = 5e-4
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, threshold=0.1, patience=3, verbose=True)
# Training
def train(epoch):
    print('\nEpoch: %d' % (epoch))
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 100 == 99:    # print every 100 mini-batches
            print('[%d, %5d] loss: %.6f |  Acc: %.3f%% (%d/%d)' %
                  (epoch + 1, batch_idx + 1, train_loss, 100.*correct/total, correct, total))
    return train_loss
load_model = False
if load_model:
    checkpoint = torch.load('./checkpoint/res18.ckpt')
    net.load_state_dict(checkpoint['net'])
    start_epoch = checkpoint['epoch']
else:
    start_epoch = 0
print('start_epoch: %s' % start_epoch)
start_epoch: 0
for epoch in range(start_epoch, 50):
    loss = train(epoch)
    print('Total loss: %.6f' % loss)
    start_epoch = epoch
    scheduler.step(loss)
Epoch: 0
[1,   100] loss: 118.672118 |  Acc: 56.953% (7290/12800)
[1,   200] loss: 235.403772 |  Acc: 57.355% (14683/25600)
[1,   300] loss: 342.972803 |  Acc: 58.721% (22549/38400)
Total loss: 436.702453

Epoch: 1
[2,   100] loss: 95.490529 |  Acc: 65.930% (8439/12800)
[2,   200] loss: 189.245391 |  Acc: 66.246% (16959/25600)
[2,   300] loss: 279.333860 |  Acc: 66.758% (25635/38400)
Total loss: 357.071455

Epoch: 2
[3,   100] loss: 76.589579 |  Acc: 73.203% (9370/12800)
[3,   200] loss: 151.513288 |  Acc: 73.492% (18814/25600)
[3,   300] loss: 224.068864 |  Acc: 73.836% (28353/38400)
Total loss: 286.139592

Epoch: 3
[4,   100] loss: 62.455524 |  Acc: 78.664% (10069/12800)
[4,   200] loss: 125.946750 |  Acc: 78.246% (20031/25600)
[4,   300] loss: 186.423765 |  Acc: 78.414% (30111/38400)
Total loss: 240.281207

Epoch: 4
[5,   100] loss: 54.547970 |  Acc: 81.414% (10421/12800)
[5,   200] loss: 110.654127 |  Acc: 80.898% (20710/25600)
[5,   300] loss: 166.249208 |  Acc: 80.919% (31073/38400)
Total loss: 215.950420

Epoch: 5
[6,   100] loss: 48.104260 |  Acc: 83.109% (10638/12800)
[6,   200] loss: 99.401246 |  Acc: 82.734% (21180/25600)
[6,   300] loss: 148.791911 |  Acc: 82.807% (31798/38400)
Total loss: 194.453984

Epoch: 6
[7,   100] loss: 44.256123 |  Acc: 85.352% (10925/12800)
[7,   200] loss: 90.796863 |  Acc: 84.582% (21653/25600)
[7,   300] loss: 138.052944 |  Acc: 84.396% (32408/38400)
Total loss: 181.198413

Epoch: 7
[8,   100] loss: 40.750996 |  Acc: 85.906% (10996/12800)
[8,   200] loss: 84.440442 |  Acc: 85.539% (21898/25600)
[8,   300] loss: 127.751372 |  Acc: 85.430% (32805/38400)
Total loss: 168.756287

Epoch: 8
[9,   100] loss: 40.119882 |  Acc: 86.266% (11042/12800)
[9,   200] loss: 79.863018 |  Acc: 86.211% (22070/25600)
[9,   300] loss: 120.620995 |  Acc: 86.188% (33096/38400)
Total loss: 158.876436

Epoch: 9
[10,   100] loss: 35.623312 |  Acc: 87.477% (11197/12800)
[10,   200] loss: 75.740778 |  Acc: 86.777% (22215/25600)
[10,   300] loss: 115.162053 |  Acc: 86.703% (33294/38400)
Total loss: 151.007361

Epoch: 10
[11,   100] loss: 34.881428 |  Acc: 88.039% (11269/12800)
[11,   200] loss: 71.926582 |  Acc: 87.699% (22451/25600)
[11,   300] loss: 109.965547 |  Acc: 87.375% (33552/38400)
Total loss: 145.488317

Epoch: 11
[12,   100] loss: 32.708189 |  Acc: 88.977% (11389/12800)
[12,   200] loss: 66.790455 |  Acc: 88.594% (22680/25600)
[12,   300] loss: 103.832237 |  Acc: 88.125% (33840/38400)
Total loss: 137.913376

Epoch: 12
[13,   100] loss: 31.950675 |  Acc: 89.242% (11423/12800)
[13,   200] loss: 65.730325 |  Acc: 88.820% (22738/25600)
[13,   300] loss: 101.885522 |  Acc: 88.430% (33957/38400)
Total loss: 135.173613

Epoch: 13
[14,   100] loss: 29.611200 |  Acc: 89.805% (11495/12800)
[14,   200] loss: 62.823584 |  Acc: 89.129% (22817/25600)
[14,   300] loss: 97.116191 |  Acc: 88.849% (34118/38400)
Total loss: 129.670478

Epoch: 14
[15,   100] loss: 29.925015 |  Acc: 89.875% (11504/12800)
[15,   200] loss: 62.226747 |  Acc: 89.516% (22916/25600)
[15,   300] loss: 95.177161 |  Acc: 89.206% (34255/38400)
Total loss: 126.196716

Epoch: 15
[16,   100] loss: 28.724815 |  Acc: 90.039% (11525/12800)
[16,   200] loss: 60.983424 |  Acc: 89.551% (22925/25600)
[16,   300] loss: 93.304751 |  Acc: 89.354% (34312/38400)
Total loss: 123.906554

Epoch: 16
[17,   100] loss: 27.764434 |  Acc: 90.305% (11559/12800)
[17,   200] loss: 57.115116 |  Acc: 90.148% (23078/25600)
[17,   300] loss: 89.535789 |  Acc: 89.685% (34439/38400)
Total loss: 118.733271

Epoch: 17
[18,   100] loss: 26.901688 |  Acc: 90.781% (11620/12800)
[18,   200] loss: 56.244663 |  Acc: 90.316% (23121/25600)
[18,   300] loss: 87.834935 |  Acc: 89.872% (34511/38400)
Total loss: 116.597480

Epoch: 18
[19,   100] loss: 27.289408 |  Acc: 90.633% (11601/12800)
[19,   200] loss: 57.436502 |  Acc: 90.137% (23075/25600)
[19,   300] loss: 88.500381 |  Acc: 89.846% (34501/38400)
Total loss: 115.314192

Epoch: 19
[20,   100] loss: 24.068543 |  Acc: 91.852% (11757/12800)
[20,   200] loss: 53.208921 |  Acc: 90.828% (23252/25600)
[20,   300] loss: 84.727040 |  Acc: 90.203% (34638/38400)
Total loss: 112.072869
Epoch    21: reducing learning rate of group 0 to 1.0000e-02.

Epoch: 20
[21,   100] loss: 17.140250 |  Acc: 94.570% (12105/12800)
[21,   200] loss: 30.638147 |  Acc: 95.156% (24360/25600)
[21,   300] loss: 41.904663 |  Acc: 95.573% (36700/38400)
Total loss: 51.213734

Epoch: 21
[22,   100] loss: 7.424102 |  Acc: 97.984% (12542/12800)
[22,   200] loss: 14.609958 |  Acc: 97.918% (25067/25600)
[22,   300] loss: 21.400117 |  Acc: 97.964% (37618/38400)
Total loss: 27.305064

Epoch: 22
[23,   100] loss: 5.123270 |  Acc: 98.586% (12619/12800)
[23,   200] loss: 9.734514 |  Acc: 98.684% (25263/25600)
[23,   300] loss: 14.615595 |  Acc: 98.648% (37881/38400)
Total loss: 19.189702

Epoch: 23
[24,   100] loss: 3.460799 |  Acc: 99.172% (12694/12800)
[24,   200] loss: 6.807557 |  Acc: 99.203% (25396/25600)
[24,   300] loss: 10.049018 |  Acc: 99.211% (38097/38400)
Total loss: 13.315432

Epoch: 24
[25,   100] loss: 2.478284 |  Acc: 99.469% (12732/12800)
[25,   200] loss: 4.640014 |  Acc: 99.492% (25470/25600)
[25,   300] loss: 6.763096 |  Acc: 99.505% (38210/38400)
Total loss: 9.023635

Epoch: 25
[26,   100] loss: 1.528404 |  Acc: 99.680% (12759/12800)
[26,   200] loss: 2.968595 |  Acc: 99.711% (25526/25600)
[26,   300] loss: 4.535004 |  Acc: 99.706% (38287/38400)
Total loss: 5.932488

Epoch: 26
[27,   100] loss: 1.165903 |  Acc: 99.852% (12781/12800)
[27,   200] loss: 2.107836 |  Acc: 99.867% (25566/25600)
[27,   300] loss: 3.091869 |  Acc: 99.875% (38352/38400)
Total loss: 4.153865

Epoch: 27
[28,   100] loss: 0.732892 |  Acc: 99.945% (12793/12800)
[28,   200] loss: 1.651403 |  Acc: 99.883% (25570/25600)
[28,   300] loss: 2.452116 |  Acc: 99.888% (38357/38400)
Total loss: 3.153302

Epoch: 28
[29,   100] loss: 0.618112 |  Acc: 99.945% (12793/12800)
[29,   200] loss: 1.297444 |  Acc: 99.941% (25585/25600)
[29,   300] loss: 1.891233 |  Acc: 99.943% (38378/38400)
Total loss: 2.471666

Epoch: 29
[30,   100] loss: 0.548403 |  Acc: 99.945% (12793/12800)
[30,   200] loss: 1.019150 |  Acc: 99.953% (25588/25600)
[30,   300] loss: 1.501457 |  Acc: 99.958% (38384/38400)
Total loss: 1.972135

Epoch: 30
[31,   100] loss: 0.435647 |  Acc: 99.969% (12796/12800)
[31,   200] loss: 0.805838 |  Acc: 99.977% (25594/25600)
[31,   300] loss: 1.260185 |  Acc: 99.971% (38389/38400)
Total loss: 1.586358

Epoch: 31
[32,   100] loss: 0.347193 |  Acc: 99.992% (12799/12800)
[32,   200] loss: 0.633916 |  Acc: 99.992% (25598/25600)
[32,   300] loss: 0.942637 |  Acc: 99.995% (38398/38400)
Total loss: 1.342172

Epoch: 32
[33,   100] loss: 0.301421 |  Acc: 99.984% (12798/12800)
[33,   200] loss: 0.628001 |  Acc: 99.980% (25595/25600)
[33,   300] loss: 0.972346 |  Acc: 99.982% (38393/38400)
Total loss: 1.273002

Epoch: 33
[34,   100] loss: 0.271426 |  Acc: 100.000% (12800/12800)
[34,   200] loss: 0.522611 |  Acc: 100.000% (25600/25600)
[34,   300] loss: 0.791412 |  Acc: 99.997% (38399/38400)
Total loss: 1.069014

Epoch: 34
[35,   100] loss: 0.272007 |  Acc: 100.000% (12800/12800)
[35,   200] loss: 0.500229 |  Acc: 100.000% (25600/25600)
[35,   300] loss: 0.741914 |  Acc: 99.997% (38399/38400)
Total loss: 0.962418

Epoch: 35
[36,   100] loss: 0.200861 |  Acc: 100.000% (12800/12800)
[36,   200] loss: 0.421179 |  Acc: 100.000% (25600/25600)
[36,   300] loss: 0.637626 |  Acc: 100.000% (38400/38400)
Total loss: 0.835531

Epoch: 36
[37,   100] loss: 0.200766 |  Acc: 100.000% (12800/12800)
[37,   200] loss: 0.397603 |  Acc: 100.000% (25600/25600)
[37,   300] loss: 0.606028 |  Acc: 99.995% (38398/38400)
Total loss: 0.800073

Epoch: 37
[38,   100] loss: 0.178643 |  Acc: 100.000% (12800/12800)
[38,   200] loss: 0.374064 |  Acc: 100.000% (25600/25600)
[38,   300] loss: 0.577130 |  Acc: 100.000% (38400/38400)
Total loss: 0.768444

Epoch: 38
[39,   100] loss: 0.192881 |  Acc: 100.000% (12800/12800)
[39,   200] loss: 0.412415 |  Acc: 99.996% (25599/25600)
[39,   300] loss: 0.607835 |  Acc: 99.997% (38399/38400)
Total loss: 0.769075

Epoch: 39
[40,   100] loss: 0.174156 |  Acc: 100.000% (12800/12800)
[40,   200] loss: 0.356172 |  Acc: 100.000% (25600/25600)
[40,   300] loss: 0.544260 |  Acc: 100.000% (38400/38400)
Total loss: 0.711841

Epoch: 40
[41,   100] loss: 0.197980 |  Acc: 99.992% (12799/12800)
[41,   200] loss: 0.405721 |  Acc: 99.996% (25599/25600)
[41,   300] loss: 0.596260 |  Acc: 99.997% (38399/38400)
Total loss: 0.783890

Epoch: 41
[42,   100] loss: 0.195553 |  Acc: 99.992% (12799/12800)
[42,   200] loss: 0.377932 |  Acc: 99.996% (25599/25600)
[42,   300] loss: 0.565132 |  Acc: 99.997% (38399/38400)
Total loss: 0.740863

Epoch: 42
[43,   100] loss: 0.184922 |  Acc: 100.000% (12800/12800)
[43,   200] loss: 0.370228 |  Acc: 99.996% (25599/25600)
[43,   300] loss: 0.563876 |  Acc: 99.997% (38399/38400)
Total loss: 0.738950

Epoch: 43
[44,   100] loss: 0.188086 |  Acc: 100.000% (12800/12800)
[44,   200] loss: 0.363090 |  Acc: 100.000% (25600/25600)
[44,   300] loss: 0.529876 |  Acc: 100.000% (38400/38400)
Total loss: 0.684271
Epoch    45: reducing learning rate of group 0 to 1.0000e-03.

Epoch: 44
[45,   100] loss: 0.180676 |  Acc: 99.992% (12799/12800)
[45,   200] loss: 0.349191 |  Acc: 99.996% (25599/25600)
[45,   300] loss: 0.512983 |  Acc: 99.997% (38399/38400)
Total loss: 0.664923

Epoch: 45
[46,   100] loss: 0.166781 |  Acc: 100.000% (12800/12800)
[46,   200] loss: 0.320780 |  Acc: 100.000% (25600/25600)
[46,   300] loss: 0.477024 |  Acc: 100.000% (38400/38400)
Total loss: 0.632136

Epoch: 46
[47,   100] loss: 0.153178 |  Acc: 100.000% (12800/12800)
[47,   200] loss: 0.315739 |  Acc: 100.000% (25600/25600)
[47,   300] loss: 0.473674 |  Acc: 100.000% (38400/38400)
Total loss: 0.619973

Epoch: 47
[48,   100] loss: 0.163760 |  Acc: 100.000% (12800/12800)
[48,   200] loss: 0.322436 |  Acc: 100.000% (25600/25600)
[48,   300] loss: 0.487261 |  Acc: 100.000% (38400/38400)
Total loss: 0.619886

Epoch: 48
[49,   100] loss: 0.160341 |  Acc: 100.000% (12800/12800)
[49,   200] loss: 0.328375 |  Acc: 100.000% (25600/25600)
[49,   300] loss: 0.497179 |  Acc: 100.000% (38400/38400)
Total loss: 0.643087

Epoch: 49
[50,   100] loss: 0.156791 |  Acc: 100.000% (12800/12800)
[50,   200] loss: 0.309782 |  Acc: 100.000% (25600/25600)
[50,   300] loss: 0.466129 |  Acc: 100.000% (38400/38400)
Total loss: 0.606150
Epoch    51: reducing learning rate of group 0 to 1.0000e-04.

save model

save_model = True
if save_model:
    state = {
        'net': net.state_dict(),
        'epoch': epoch,
    }
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, './checkpoint/res18.ckpt')

evaluate

dataiter = iter(test_loader)
images, labels = dataiter.next()
images = images[:4]
labels = labels[:4]
# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

outputs = net(images.to(device))
_, predicted = torch.max(outputs.cpu(), 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()



print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))



class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

在这里插入图片描述

GroundTruth:    cat  ship  ship plane
Predicted:    cat  ship   car plane
Accuracy of the network on the 10000 test images: 90 %
Accuracy of plane : 90 %
Accuracy of   car : 100 %
Accuracy of  bird : 81 %
Accuracy of   cat : 72 %
Accuracy of  deer : 89 %
Accuracy of   dog : 87 %
Accuracy of  frog : 94 %
Accuracy of horse : 93 %
Accuracy of  ship : 95 %
Accuracy of truck : 96 %

参考文献

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值