论文链接:https://arxiv.org/abs/1711.07971
github地址:https://github.com/lwdoubles/Non-local_pytorch
该模型的思想中借鉴了非局部均值滤波的思想,这里可以先看下非局部均值滤波的原理:
https://www.cnblogs.com/helloforworld/p/5303422.html
一·、方程
non-local的方程可以表示如下:
(1)
Here i is the index of an output position (in space, time, or spacetime) whose response is to be computed and j is the index that enumerates all possible positions. 即j取输入数据的所有位置下标。
x is the input signal (image, sequence, video; often their features) and y is the output signal of the same size as x. x和y的维度是一样的。g则是一元输入函数,目的是进行信息变换,C(x)是归一化函数,保证变换前后整体信息不变。上述公式使得我们对某一局部区域进行变换的时候考虑了整体的区域与该区域的关系。下面是知乎上的一张比较形象的图:
二、实例
原文中给出了上述公式的一个具体的例子:
(2)
(3)
(4)
上述公式(2) 为Gaussian function,公式(3)则为Embedded Gaussian function,这两种f的形式都是可以的,论文中还提出了其他的f的形式,并且论文中提到f和g的形式其实影响并不是很大。公式(4)则是f采用Embedded Gaussian function后y的形式。如果设g为线性函数,那么Non-local模型的示意图如下:
注意θ、g后的1X1X1应该指的是TXHXW,T=1说明我们不考虑时间这个维度,说明输入数据应该是图像之类的。因为这些函数都是线性函数,故直接用1X1的卷积来实现。
三、Non-local 模块
Non-local block可以由如下公式定义:
(4)
上面的公式借鉴了残差网络。该做法的好处是可以随意嵌入到任何一个预训练好的网络中,因为只要设置W_z初始化为0,那么就没有任何影响,然后在迁移学习中学习新的权重。这样就不会因为引入了新的模块而导致预训练权重无法使用。
更多的细节还是需要自己去看论文。