可逆神经网络(Invertible Neural Networks)详细解析:让神经网络更加轻量化

4bf876880695bba3a816fd249f0b0007.png

来源:PaperWeekly
本文约3600字,建议阅读7分钟
本文以可逆残差网络(The Reversible Residual Network: Backpropagation Without Storing Activations)作为基础进行分析。

为什么要用可逆网络呢?

  1. 因为编码和解码使用相同的参数,所以 model 是轻量级的。可逆的降噪网络 InvDN 只有 DANet 网络参数量的 4.2%,但是 InvDN 的降噪性能更好。

  2. 由于可逆网络是信息无损的,所以它能保留输入数据的细节信息。

  3. 无论网络的深度如何,可逆网络都使用恒定的内存来计算梯度。

其中最主要目的就是为了减少内存的消耗,当前所有的神经网络都采用反向传播的方式来训练,反向传播算法需要存储网络的中间结果来计算梯度,而且其对内存的消耗与网络单元数成正比。这也就意味着,网络越深越广,对内存的消耗越大,这将成为很多应用的瓶颈。

下面是 Pytorch summary 的结果,Forward/backward pass size(MB): 218.59 就是需要保存的中间变量大小,可以看出这部分占据了很大部分显存(随着网络深度的增加,中间变量占据显存量会一直增加,resnet152(size=224)的中间变量更是占据总共内存的 606.6÷836.79≈0.725 )。如果不存储中间层结果,那么就可以大幅减少 GPU 的显存占用,有助于训练更深更广的网络。

import torch
from torchvision import models
from torchsummary import summary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg = models.vgg16().to(device)

summary(vgg, (3, 224, 224))

结果:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Linear-32                 [-1, 4096]     102,764,544
             ReLU-33                 [-1, 4096]               0
          Dropout-34                 [-1, 4096]               0
           Linear-35                 [-1, 4096]      16,781,312
             ReLU-36                 [-1, 4096]               0
          Dropout-37                 [-1, 4096]               0
           Linear-38                 [-1, 1000]       4,097,000
================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.59
Params size (MB): 527.79
Estimated Total Size (MB): 746.96
----------------------------------------------------------------

接下来我将先从可逆神经网络讲起,然后是神经网络的反向传播,最后是标准残差网络。对反向传播算法和标准残差网络比较熟悉的小伙伴,可以只看第一节:可逆神经网络。如果各位小伙伴不熟悉反向传播算法和标准残差网络,建议先看第二节:反向传播(BP)算法和第三节:残差网络(Residual Network)。本文1.2和1.3.4摘录自 @阿亮。

可逆神经网络

可逆网络具有的性质:

  1. 网络的输入、输出的大小必须一致。

  2. 网络的雅可比行列式不为 0。


1.1 什么是雅可比行列式?

雅可比行列式通常称为雅可比式(Jacobian),它是以 n 个 n 元函数的偏导数为元素的行列式 。事实上,在函数都连续可微(即偏导数都连续)的前提之下,它就是函数组的微分形式下的系数矩阵(即雅可比矩阵)的行列式。若因变量对自变量连续可微,而自变量对新变量连续可微,则因变量也对新变量连续可微。这可用行列式的乘法法则和偏导数的连锁法则直接验证。也类似于导数的连锁法则。偏导数的连锁法则也有类似的公式;这常用于重积分的计算中。

a9957c868e7eb796fc0eaf932384b03f.png

f348b9212853ba87742022812430b555.png

e27af405c878a22c2db7e0b5507472af.png

1.2 雅可比行列式与神经网络的关系

为什么神经网络会与雅可比行列式有关系?这里我借用李宏毅老师的 ppt(12-14页)。想看视频的可以到 b 站上看。

2fa47717dede8b8b0ea04eef12a0f3ac.png

f0c8e524c1ca5051ac7e0310d95bc597.png

c1e689b4eca4203262824ddefcf6e7a7.png

简单的来讲就是 ,他们的分布之间的关系就变为 ,又因为有 ,所以  这个网络的雅可比行列式不为 0 才行。

顺便提一下,flow-based Model 优化的损失函数如下:

7070563b5152550138df68a6c70087a2.png

其实这里跟矩阵运算很像,矩阵可逆的条件也是矩阵的雅可比行列式不为 0,雅可比矩阵可以理解为矩阵的一阶导数。

假设可逆网络的表达式为:

3897d18815f7217bd8135e2394eae177.png

c91a2ba013d07c7b94fb04d76dcde200.png

它的雅可比矩阵为:

0708e1cc0f1b7eb4c327d7fe2b276b61.png

其行列式为 1。

1.3 可逆残差网络(Reversible Residual Network)

f4b4cd702a375e5c0dc9c8e8cb08ca62.png

论文标题:

The Reversible Residual Network: Backpropagation Without Storing Activations

论文链接:

https://arxiv.org/abs/1707.04585

多伦多大学的 Aidan N.Gomez 和 Mengye Ren 提出了可逆残差神经网络,当前层的激活结果可由下一层的结果计算得出,也就是如果我们知道网络层最后的结果,就可以反推前面每一层的中间结果。这样我们只需要存储网络的参数和最后一层的结果即可,激活结果的存储与网络的深度无关了,将大幅减少显存占用。令人惊讶的是,实验结果显示,可逆残差网络的表现并没有显著下降,与之前的标准残差网络实验结果基本旗鼓相当。

1.3.1 可逆块结构

可逆神经网络将每一层分割成两部分,分别为  和 ,每一个可逆块的输入是 ,输出是 。其结构如下:

正向计算图示:

02c355752b6a94e5f39a33348cc3fda1.png

公式表示:

a61c1ffbc2381cb6fe8cc49bf492173a.png

逆向计算图示:

96a406bfb32e6ba77b4b92c00092ade4.png

公式表示:

4c6ac3a99202c2ae9a1bf1ee8b340cfa.png

其中 F 和 G 都是相似的残差函数,参考上图残差网络。可逆块的跨距只能为 1,也就是说可逆块必须一个接一个连接,中间不能采用其它网络形式衔接,否则的话就会丢失信息,并且无法可逆计算了,这点与残差块不一样。如果一定要采取跟残差块相似的结构,也就是中间一部分采用普通网络形式衔接,那中间这部分的激活结果就必须显式的存起来。

1.3.2 不用存储激活结果的反向传播

为了更好地计算反向传播的步骤,我们修改一下上述正向计算和逆向计算的公式:

0f3130af6dc650b94b1c1ef2a1a2c856.png

尽管  和  的值是相同的,但是两个变量在图中却代表不同的节点,所以在反向传播中它们的总体导数是不一样的。 的导数包含通过  产生的间接影响,而  的导数却不受  的任何影响。

在反向传播计算流程中,先给出最后一层的激活值  和误差传播的总体导数 ,然后要计算出其输入值  和对应的导数 ,以及残差函数 F 和 G 中权重参数的总体导数,求解步骤如下:

46c0eba7386311a55d34a37600f8b790.png

1.3.3 计算开销

一个 N 个连接的神经网络,正向计算的理论加乘开销为 N,反向传播求导的理论加乘开销为 2N(反向求导包含复合函数求导连乘),而可逆网络多一步需要反向计算输入值的操作,所以理论计算开销为 4N,比普通网络开销约多出 33% 左右。但是在实际操作中,正向和反向的计算开销在 GPU 上差不多,可以都理解为 N。那么这样的话,普通网络的整体计算开销为 2N,可逆网络的整体开销为 3N,也就是多出了约 50%。

1.3.4 雅可比行列式的计算

7c0be42166a22a64592a6096771fe846.png

其编码公式如下:

09ed302a3d7140a52ee7fd4461d9fdf1.png

ed8b1d45aec493eab925d5a1490f56b5.png

其解码公式如下:

82f49ceced37c1d428a22336f087270d.png

837e0d12983d7fd552418ebe1cf12bd4.png

为了计算雅可比矩阵,我们更直观的写成下面的编码公式:

3d08e1cf91cf82b528c982bf11287f41.png

3c9b4a1a4f77792b8748f74d9512a869.png

它的雅可比矩阵为:

b1c1829adea9b845429c4a87a1623173.png

其实上面这个雅可比行列式也是1,因为这里 ,它们的系数是一样的。

有另外一种解释方式就是把这种对偶的形式切成两半:

f2b7422e16f56ca96dd36b809cb74621.png

f1791575122e0aaa8b21a0217b25f1a7.png

3095bfc666d2c2c51812f8996473dcb1.png

其行列式为 1.

3857e2fd277d266a66b2aff161e956ee.png

aa777451ded8fe8f939230d9d7cc1775.png

因为是对偶的形式,所以这里的行列式也为 1.

因为 ,所以其行列式也为 1。

反向传播(BP)算法

dcf79ab3642cfa245e29d9781c8b3345.png

上图中符号的含义:

  • x1,x2,x3:表示 3 个输入层节点。

  • :表示从 t-1 层到 t 层的权重参数,j 表示 t 层的第 j 个节点,i 表示 t-1 层的第 i 个节点。

  • :表示 t 层的第 i 个激活后输出结果。

  • g(x):表示激活函数。

正向传播计算过程:

  • 隐藏层(网络的第二层)

a93f70af2c4796b327e8c6064ca45756.png

  • 输出层(网络的最后一层)

7fd419023ab973f261e5c864e9a7974a.png

反向传播计算过程:

以单个样本为例,假设输入向量是 [x1,x2,x3],目标输出值是 [y1,y2],代价函数用 L 表示。反向传播的总体原理就是根据总体输出误差,反向传播回网络,通过计算每一层节点的梯度,利用梯度下降法原理,更新每一层的网络权重 w 和偏置 b,这也是网络学习的过程。误差反向传播的优点就是可以把繁杂的导数计算以数列递推的形式来表示, 简化了计算过程。

以平方误差来计算反向传播的过程,代价函数表示如下:

703e1906422024f6558eeb057951d635.png

根据导数的链式法则反向求解隐藏 -> 输出层、输入层 -> 隐藏层的权重表示:

d1f4a146f2be5ceebde7658f38fb4791.png

引入新的误差求导表示形式,称为神经单元误差:

c6f5086ccc0ac7deb40ce3f8054a3d8b.png

l=2,3 表示第几层,j 表示某一层的第几个节点。替换表示后如下:

816f24dc7cf7eb352ee64740d2846370.png

所以我们可以归纳出一般的计算公式:

a2be640fb6a0a9b11e87e25b2210b0ce.png

从上述公式可以看出,如果神经单元误差 δ 可以求出来,那么总误差对每一层的权重 w 和偏置 b 的偏导数就可以求出来,接下来就可以利用梯度下降法来优化参数了。

求解每一层的 δ:

  • 输出层

0a19c4e4083a412db54820d52ca51e12.png

  • 隐藏层

8cf12e3a3aa275b6fd12f48868e86df8.png

也就是说,我们根据输出层的神经误差单元 δ 就可以直接求出隐藏层的神经误差单元,进而省去了隐藏层的繁杂的求导过程,我们可以得出更一般的计算过程:

84273a9fbec54c4e0be666da749a376e.png

从而得出 l 层神经单元误差和 l+1 层神经单元误差的关系。这就是误差反向传播算法,只要求出输出层的神经单元误差,其它层的神经单元误差就不需要计算偏导数了,而可以直接通过上述公式得出。

残差网络(Residual Network)

残差网络主要可以解决两个问题(其结构如下图):

1)梯度消失问题;

2)网络退化问题。

a8642ba989593c05a070560d1b1ccb90.png

上述结构就是一个两层网络组成的残差块,残差块可以由 2、3 层甚至更多层组成,但是如果是一层的,就变成线性变换了,没什么意义了。上述图可以写成公式如下:

4fea4486cc612d425cab566aca12e30d.png

所以在第二层进入激活函数ReLU之 前 F(x)+x 组成新的输入,也叫恒等映射。

恒等映射就是在这个残差块输入是 x 的情况下输出依然是 x,这样其目标就是学习让 F(X)=0。

这里有一个问题哈,为什么要额外加一个 x 呢,而不是让模型直接学习 F(x)=x?

因为让 F(x)=0 比较容易,初始化参数 W 非常小接近 0,就可以让输出接近 0,同时输出如果是负数,经过第一层 Relu 后输出依然 0,都能使得最后的 F(x)=0,也就是有多种情况都可以使得 F(x)=0;但是让 F(x)=x 确实非常难的,因为参数都必须刚刚好才能使得最后输出为 x。

恒等映射有什么作用?

恒等映射就可以解决网络退化的问题,当网络层数越来越深的时候,网络的精度却在下降,也就是说网络自身存在一个最优的层度结构,太深太浅都能使得模型精度下降。有了恒等映射存在,网络就能够自己学习到哪些层是冗余的,就可以无损通过这些层,理论上讲再深的网络都不影响其精度,解决了网络退化问题。

为什么可以解决梯度消失问题呢?

以两个残差块的结构实例图来分析,其中每个残差块有 2 层神经网络组成,如下图:

6bb24242855fefccba7d014565a5fe34.png

假设激活函数 ReLU 用 g(x) 函数来表示,样本实例是 [x1,y1],即输入是 x1,目标值是 y1,损失函数还是采用平方损失函数,则每一层的计算如下:

6cc168744f3a80d208780a15ad16779b.png

下面我们对第一个残差块的权重参数求导,根据链式求导法则,公式如下:

db2819549ebd4b8834d00a6fac7c40dc.png

我们可以看到求导公式中多了一个+1项,这就将原来的链式求导中的连乘变成了连加状态,可以有效避免梯度消失了。

参考文献

[1]PPT

 https://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/FLOW%20(v7).pdf

[2] 神经网络的可逆形式 

https://zhuanlan.zhihu.com/p/268242678

[3] 大幅减少GPU显存占用:

可逆残差网络(The Reversible Residual Network) 

https://www.cnblogs.com/gczr/p/12181354.html

[4] 雅可比行列式 

https://baike.baidu.com/item/雅可比行列式/4709261?fr=aladdin

[5] The Reversible Residual Network: 

Backpropagation Without Storing Activations

[6] pytorch-summary 

https://github.com/sksq96/pytorch-summary

编辑:王菁

校对:杨学俊

4d1dc153d1821243d26d8ba68c24a6c8.png

f625098a06a4c69f61fd727508930b3f.png

基于PyTorch的可逆神经网络Invertible Neural Network, INN)是一种设计成可以精确地逆向其操作的网络。这意味着对于每个输出,网络都能够推断出一个唯一的输入。在点云特征提取和坐标回归的任务中,使用INN可以有助于模型学习更加精准的特征表达和坐标变换。 以下是一个简化的例子,用于演示如何构建一个基本的可逆神经网络结构,该结构包含可逆层,用于点云数据的特征提取和坐标回归。请注意,这里仅提供一个基本的框架,具体实现可能需要根据你的任务和数据进行调整。 ```python import torch import torch.nn as nn import torch.nn.functional as F class InvertibleBlock(nn.Module): """ 可逆网络层的基本块,假设输入输出维度相等。 """ def __init__(self, dim): super(InvertibleBlock, self).__init__() self.dim = dim # 初始化可逆层中的权重,例如可以使用1x1卷积、可逆矩阵等 self.weight_matrix = nn.Parameter(torch.randn(dim, dim)) def forward(self, x): """ 正向传播,计算输出。 """ # 这里使用了简单的线性变换作为示例,实际应用中应替换为可逆操作 y = torch.matmul(x, self.weight_matrix) return y def inverse(self, y): """ 反向传播,计算输入。 """ # 使用相同的权重矩阵进行反向操作,实际应用中应确保可逆操作的反向一致性 x = torch.matmul(y, torch.inverse(self.weight_matrix)) return x class InvertiblePointNet(nn.Module): """ 点云特征提取和坐标回归的可逆神经网络。 """ def __init__(self): super(InvertiblePointNet, self).__init__() self.invertible_block = InvertibleBlock(dim=3) # 假设点云数据是3维的 # 其他的可逆层或网络结构可以在这里添加 def forward(self, points): """ 对点云数据进行特征提取和坐标回归的正向操作。 """ # 通过一个或多个可逆层进行特征提取和变换 transformed_points = self.invertible_block(points) # 可以添加更多的处理步骤,例如特征提取、聚合等 # 这里简化为直接返回变换后的点云 return transformed_points def inverse(self, transformed_points): """ 根据变换后的点云数据进行反向操作,恢复原始坐标。 """ # 通过相同的可逆层进行坐标恢复 original_points = self.invertible_block.inverse(transformed_points) # 其他的反向操作可以在这里添加 # 这里简化为直接返回恢复后的点云 return original_points # 假设我们有一些点云数据 points = torch.randn(10, 3) # 10个点,每个点3个坐标值 # 创建并使用网络 model = InvertiblePointNet() transformed_points = model(points) original_points = model.inverse(transformed_points) # 确认逆过程是否可以恢复原始数据 assert torch.allclose(points, original_points) ``` 这段代码展示了一个非常基础的可逆神经网络结构,其中包括了一个可逆层的基本块(`InvertibleBlock`)和一个点云处理的网络(`InvertiblePointNet`)。在实际应用中,你可能需要实现更复杂的可逆层,比如基于耦合层(Coupling Layers)的结构,并添加更多的网络层来处理特征提取和点云数据。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值