detach_channel_max试验记录

实验原理

设输入 X i X_{i} Xi为输入 X ∈ X\in X{B,C,W,H}的第i个channel。
Y i = X i / m a x ( a b s ( X i ) ) ∗ d e t a c h ( m a x ( a b s ( X i ) ) Y_i=X_i/max(abs(X_i))*detach(max(abs(X_i)) Yi=Xi/max(abs(Xi))detach(max(abs(Xi))
代码如下:

class Detach_max(nn.Module):
    def __init__(self):
        super(Detach_max,self).__init__()

    def forward(self, input):
        max_value = torch.max(torch.max(torch.max(torch.abs(input),\
                            0,True)[0], 2, True)[0], -1, True)[0]
        out = input / max_value * torch.detach(max_value)

        return out

单元测试代码:

import torch 
import torch.nn as nn 
class Detach_max(nn.Module):
    def __init__(self):
        super(Detach_max,self).__init__()

    def forward(self, input):
        max_value = torch.max(torch.max(torch.max(torch.abs(input),\
                            0,True)[0], 2, True)[0], -1, True)[0]

        out = input / max_value * torch.detach(max_value)
        return out
x = torch.tensor(torch.range(1.,16.).reshape(1,2,8,1),requires_grad=True)

detach_max = Detach_max()

y = torch.sum(detach_max(x))

y.backward()
print(x.grad)

在这里插入图片描述
经过校验,代码没有问题。

实验结果

若直接把Detach_max 接到BN后面,会造成网络性能的下降。cifar10数据集下降1.5个点左右。
若将Detach_max接到BN前面,也会造成网络性能的下降,但是下降的相对较少。cifar10数据集下降0.5个点左右。

项目test_acc
origin_BN92.83%
origin_BN + detach_max91.57%
detach_max + origin_BN92.30%

若去掉BN,只要Detach_max,则网络非常不稳定,会出现NAN。
出现NAN的原因,经过测试,发现是因为LOSS过大导致。LOSS过大会使得导数中出现nan或者前向传播的数据中出现nan。
在这里插入图片描述
但是这个现象并不是因为detach_max导致的。事实上在resnet20中去掉了bn,就会出现这个情况。而且往往一两个batch就会出现nan。


下面的实验结果是在没有bn没有detachmax的情况下做的,可以看到loss在4个batch后就超出了系统表达能力。
在这里插入图片描述
这时,在最后一个block中,取两个卷积层,看他们的权重及输出。
在这里插入图片描述
在这里插入图片描述
可以看到,两者的方差都特别大,达到了10^6的两级。
第二个卷积层的输出:
在这里插入图片描述
输出方差也极大。
这也解释了为什么会出现nan值。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值