GCT代码理解——CVPR 2020
论文链接: https://openaccess.thecvf.com/content_CVPR_2020/html/Yang_Gated_Channel_Transformation_for_Visual_Recognition_CVPR_2020_paper.html
代码链接: https://github.com/z-x-yang/GCT/blob/master/PyTorch/GCT.py
核心代码:
class GCT(nn.Module):
def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False):
super(GCT, self).__init__()
self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))
self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
self.epsilon = epsilon
self.mode = mode
self.after_relu = after_relu
def forward(self, x):
if self.mode == 'l2':
embedding = (x.pow(2).sum((2,3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
elif self.mode == 'l1':
if not self.after_relu:
_x = torch.abs(x)
else:
_x = x
embedding = _x.sum((2,3), keepdim=True) * self.alpha
norm = self.gamma / (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)
else:
print('Unknown mode!')
sys.exit()
gate = 1. + torch.tanh(embedding * norm + self.beta)
return x * gate
1 Normalization
标准化是目前深度学习中必不可少的操作。标准化的公式:
X
i
−
μ
σ
\frac{X_i-\mu}{\sigma}
σXi−μ
数据给定的情况下,
μ
\mu
μ(均值)和
σ
\sigma
σ(方差)都是常数,公式可变形为:
X
i
−
μ
σ
=
X
i
σ
−
μ
σ
=
X
i
σ
−
c
\frac{X_i-\mu}{\sigma}=\frac{X_i}{\sigma}-\frac{\mu}{\sigma}=\frac{X_i}{\sigma}-c
σXi−μ=σXi−σμ=σXi−c
通过上式可以看出标准化就是对数据
X
X
X按照比例压缩
σ
\sigma
σ再平移
c
c
c。所以标准化的本质是一种线性变换。[1]
1.1 常见的Normalization
提到Normalization,一下子可以想到有4种:BN、LN、IN、GN。
这四种可以用同一个公式来描述:
y
=
x
−
E
[
x
]
V
a
r
[
x
]
+
ϵ
∗
γ
+
β
y=\frac{x-E[x]}{\sqrt{Var[x]+\epsilon}}*\gamma+\beta
y=Var[x]+ϵx−E[x]∗γ+β
E
[
x
]
E[x]
E[x]和
V
a
r
[
x
]
Var[x]
Var[x]是数据
x
x
x的均值和方差,
γ
\gamma
γ和
β
\beta
β是两个可学习参数。
四种标准化操作的区别就是
x
x
x取的范围不一样,如下图所示 [2],蓝色部分为
x
x
x取的数据范围:
具体来说,假设输入网络的数据为(N, C, H, W)
- BN中: x x x取(N, H, W),计算C个均值和方差;
- LN中: x x x取(C, H, W),计算N个均值和方差;
- IN中: x x x取(H, W),计算N*C个均值和方差;
- GN中: x x x取( C G \frac{C}{G} GC, H, W),计算N*G个均值和方差;
1.2 Local Response Normalization
前面都是铺垫,这篇论文使用
l
2
l_2
l2normalization 建立channel normalization参考了LRN (Local Response Normalization) [3]。
LRN与上面提到了四种标准化的区别,一是使用位置,LRN一般放在ReLU层之后;二是计算上不减去均值,也不用两个可学参数。LRN公式如下:
其中,
i
i
i表示第
i
i
i张特征图,
a
x
,
y
i
a^{i}_{x,y}
ax,yi表示第
i
i
i张特征图上
(
x
,
y
)
(x,y)
(x,y)处的像素值,
n
n
n表示取得的特征图数目,N表示特征图总数。
由上式可以看出,每次计算取的是通道上第 i i i张特征图前后共 n + 1 n+1 n+1张上 ( x , y ) (x,y) (x,y)处的像素值(共 n + 1 n+1 n+1个像素值),这也是LRN中Local的体现。
由上式可以看出,如果数据的均值为0, k k k为极小值, α = 1 \alpha=1 α=1, β = 0.5 \beta=0.5 β=0.5,上式即为 y = x − E [ x ] V a r [ x ] + ϵ y=\frac{x-E[x]}{\sqrt{Var[x]+\epsilon}} y=Var[x]+ϵx−E[x]的一种特殊情况。
在 [3] 中作者通过实验验证后将超参数确定为: k = 2 k=2 k=2, n = 5 n=5 n=5, α = 1 0 − 4 \alpha=10^{-4} α=10−4, β = 0.75 \beta=0.75 β=0.75
1.3 L2范数
假设X是n维的特征
X
=
(
x
1
,
x
2
,
x
3
,
.
.
.
.
.
.
,
x
n
)
X=(x_1, x_2,x_3,......,x_n)
X=(x1,x2,x3,......,xn)
L2范数:
∣
∣
X
∣
∣
2
=
∑
i
=
1
n
x
i
2
||X||_2=\sqrt{\sum_{i=1}^{n}x_i^2}
∣∣X∣∣2=∑i=1nxi2
2 GCT
GCT公式:
其中,
α
,
γ
,
β
\alpha, \gamma, \beta
α,γ,β是可学习参数。
2.1 Global Context Embedding
Global Context Embedding公式如下:
通过上面公式可以看出Global Context Embedding就是计算每个通道C上的L2范数再乘以系数
α
c
\alpha_c
αc。
对应代码中的embedding = (x.pow(2).sum((2,3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
2.2 Channel Normalization
Channel Normalization公式如下:
通过上面公式可以看出Channel Normalization就是将每个通道上的数值
s
c
s_c
sc除以所有通道上数值的L2范数,再乘以常数
C
\sqrt{C}
C。
对应代码中的norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
3.3 Gating Adaptation
Gating Adaptation公式如下:
这里设计了权重
γ
\gamma
γ和偏置
β
\beta
β来控制通道特征是否激活。当一个通道的特征权重
γ
c
\gamma_c
γc被正激活,GCT将促进这个通道的特征和其它通道的特征“竞争”。当一个通道的特征
γ
c
\gamma_c
γc被负激活,GCT将促进这个通道的特征和其它通道的特征“合作”。
对应代码中的
gate = 1. + torch.tanh(embedding * norm + self.beta)
return x * gate
参考:
[1] https://www.codetd.com/article/3208734
[2] Wu, Yuxin, and Kaiming He. “Group normalization.” Proceedings of the European conference on computer vision (ECCV). 2018.
[3] Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. “Imagenet classification with deep convolutional neural networks.” Advances in neural information processing systems 25 (2012).