GoogLeNet 浅析

1. GoogLeNet 浅析

  GoogLeNet 于论文《Going deeper with convolutions》中提出,并一举斩获 2014 年 ImageNet 挑战赛的冠军。一般而言,增加网络深度和宽度是提升网络性能最直接的方法,但这样也会带来诸多问题:

  • 参数太多,如果训练数据集有限,很容易产生过拟合;
  • 网络越大,参数越多,计算复杂度越大,应用难度越高;
  • 网络越深,越容易出现梯度弥散问题。

  GoogLeNet 共有 22 层网络,但其参数量却比 AlexNet 和 VGG 小很多。GoogLeNet 主要通过以下方法来提升网络的性能:

  • 通过 Inception 结构融合多尺度特征信息;
  • 使用 1x1 卷积进行降维,减少计算量;
  • 添加辅助分类器,缓解梯度弥散问题,帮助训练;
  • 丢弃全连接层,使用平均池化层,大幅减少模型参数量。
GoogLeNet 网络结构

  GoogLeNet 的网络结构如下,共有 22 层(蓝色部分)。关于 GoogLeNet 网络的分析,主要关注在两个方面,分别是 Inception 结构(红框部分)和辅助分类器(绿框部分)。

Inception 结构

  Inception 结构图如下,左图是 Inception 的原始结构,右图是加上降维功能的结构。Inception 结构能够提取不同尺度的特征,同时利用稀疏矩阵计算的原理来加速收敛。此外,论文作者认为池化也具有提取特征的功能,因此在第四个分支也使用了最大池化。

Inception 结构共有 4 个分支,输入特征经由各分支得到 4 个输出,之后在通道维度进行拼接得到最终输出。对于各分支而言,需要通过 stride 和 padding 来保证得到同样大小的输出。对比左图,右图在分支 2,3,4 上加入了 1x1 卷积以降维,在增加非线性表达能力的同时,减少参数量和计算量。下图给出 1x1 卷积降维的图示。

辅助分类器

   辅助分类器主要用以缓解梯度弥散问题,GoogLeNet 网络中使用了两个辅助分类器,二者的结构是一模一样的,结构参数如下:

  • 第一层:平均池化下采样层,池化核大小为 5x5,stride=3;
  • 第二层:卷积层,卷积核大小为 1x1,stride=1,卷积核个数为 128;
  • 第三层:全连接层,共1024 个节点;
  • 第四层:全连接层,节点数为 1000,对应类别数。

2. 代码实现(PyTorch)

2.1 Inception 结构

   具有降维功能的 Inception 结构如上图所示,其具体结构参数如下:

  • 分支 1:是卷积核大小为 1x1 的卷积层,stride=1;
  • 分支 2:是卷积核大小为 3x3 的卷积层,stride=1,padding=1;
  • 分支 3:是卷积核大小为 5x5 的卷积层,stride=1,padding=2;
  • 分支 4:是池化核大小为 3x3 的最大池化下采样,stride=1,padding=1。
import torch
import torch.nn as nn


class Inception(nn.Module):
    def __init__(self, in_c, o_c):
        super(Inception, self).__init__()
        self.conv1 = nn.Conv2d(in_c, o_c, kernel_size=(1, 1))

        self.branch2 = nn.Sequential(
            nn.Conv2d(in_c, 4, kernel_size=(1, 1)),
            nn.Conv2d(4, o_c, kernel_size=(3, 3), padding=1)
        )

        self.branch3 = nn.Sequential(
            nn.Conv2d(in_c, 4, kernel_size=(1, 1)),
            nn.Conv2d(4, o_c, kernel_size=(5, 5), padding=2)
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_c, 4, kernel_size=(1, 1))
        )

    def forward(self, x):
        bran1 = self.conv1(x)
        bran2 = self.branch2(x)
        bran3 = self.branch3(x)
        bran4 = self.branch4(x)
        outputs = [bran1, bran2, bran3, bran4]
        out = torch.cat(outputs, 1)

        return out


if __name__ == "__main__":
    x = torch.rand((8, 3, 32, 32))  # (B, C, H, W)
    net = Inception(3, 8)
    out = net(x)

2.2 GoogLeNet 网络

  本人并未实现 GoogLeNet 网络,但在此处给出 github 地址以供有心人参考。

【参考】

  1. 实现pytorch实现GoogLeNet(CNN经典网络模型详解)
  • 3
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
ThreadLocal 是 Java 中的一个类,它提供了一种线程局部变量的机制。线程局部变量是指每个线程都有自己的变量副本,每个线程对该变量的访问都是独立的,互不影响。 ThreadLocal 主要用于解决多线程并发访问共享变量时的线程安全问题。在多线程环境下,如果多个线程共同访问同一个变量,可能会出现竞争条件,导致数据不一致或者出现线程安全问题。通过使用 ThreadLocal,可以为每个线程提供独立的副本,从而避免了线程安全问题。 ThreadLocal 的工作原理是,每个 Thread 对象内部都维护了一个 ThreadLocalMap 对象,ThreadLocalMap 是一个 key-value 结构,其中 key 是 ThreadLocal 对象,value 是该线程对应的变量副本。当访问 ThreadLocal 的 get() 方法时,会根据当前线程获取到对应的 ThreadLocalMap 对象,并从中查找到与 ThreadLocal 对象对应的值。如果当前线程尚未设置该 ThreadLocal 对象的值,则会通过 initialValue() 方法初始化一个值,并将其存入 ThreadLocalMap 中。当访问 ThreadLocal 的 set() 方法时,会将指定的值存入当前线程对应的 ThreadLocalMap 中。 需要注意的是,ThreadLocal 并不能解决共享资源的并发访问问题,它只是提供了一种线程内部的隔离机制。在使用 ThreadLocal 时,需要注意合理地使用,避免出现内存泄漏或者数据不一致的情况。另外,由于 ThreadLocal 使用了线程的 ThreadLocalMap,因此在使用完 ThreadLocal 后,需要手动调用 remove() 方法清理对应的变量副本,以防止内存泄漏。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值