实验原理
设输入
X
i
X_{i}
Xi为输入
X
∈
X\in
X∈{B,C,W,H}的第i个channel。
Y
i
=
X
i
/
m
a
x
(
a
b
s
(
X
i
)
)
∗
d
e
t
a
c
h
(
m
a
x
(
a
b
s
(
X
i
)
)
Y_i=X_i/max(abs(X_i))*detach(max(abs(X_i))
Yi=Xi/max(abs(Xi))∗detach(max(abs(Xi))
代码如下:
class Detach_max(nn.Module):
def __init__(self):
super(Detach_max,self).__init__()
def forward(self, input):
max_value = torch.max(torch.max(torch.max(torch.abs(input),\
0,True)[0], 2, True)[0], -1, True)[0]
out = input / max_value * torch.detach(max_value)
return out
单元测试代码:
import torch
import torch.nn as nn
class Detach_max(nn.Module):
def __init__(self):
super(Detach_max,self).__init__()
def forward(self, input):
max_value = torch.max(torch.max(torch.max(torch.abs(input),\
0,True)[0], 2, True)[0], -1, True)[0]
out = input / max_value * torch.detach(max_value)
return out
x = torch.tensor(torch.range(1.,16.).reshape(1,2,8,1),requires_grad=True)
detach_max = Detach_max()
y = torch.sum(detach_max(x))
y.backward()
print(x.grad)
经过校验,代码没有问题。
实验结果
若直接把Detach_max 接到BN后面,会造成网络性能的下降。cifar10数据集下降1.5个点左右。
若将Detach_max接到BN前面,也会造成网络性能的下降,但是下降的相对较少。cifar10数据集下降0.5个点左右。
项目 | test_acc |
---|---|
origin_BN | 92.83% |
origin_BN + detach_max | 91.57% |
detach_max + origin_BN | 92.30% |
若去掉BN,只要Detach_max,则网络非常不稳定,会出现NAN。
出现NAN的原因,经过测试,发现是因为LOSS过大导致。LOSS过大会使得导数中出现nan或者前向传播的数据中出现nan。
但是这个现象并不是因为detach_max导致的。事实上在resnet20中去掉了bn,就会出现这个情况。而且往往一两个batch就会出现nan。
下面的实验结果是在没有bn没有detachmax的情况下做的,可以看到loss在4个batch后就超出了系统表达能力。
这时,在最后一个block中,取两个卷积层,看他们的权重及输出。
可以看到,两者的方差都特别大,达到了10^6的两级。
第二个卷积层的输出:
输出方差也极大。
这也解释了为什么会出现nan值。