[pytorch] 自定义激活函数swish(三)

         在神经网络模型中,激活函数多种多样。大体都是,小于0的部分,进行抑制(即,激活函数输出为非常小的数),大于0的部分,进行放大(即,激活函数输出为较大的数)。

         主流的激活函数一般都满足,

         1. 非线性。信号处理里,信号通过非线性系统后,能产生新频率的信号。不妨假定,非线性有相似作用。

         2. 可微性。可求导的,在反向传播中,可以方便使用链式求导的。

         3. 单调性。swish 激活函数在小于0的部分是非单调的。

         为了测试不同激活函数,对神经网络的影响。我们把之前写的CNN模型中,激活函数抽取出来,独立写成一个接口。

         由于pytorch集成了常用的激活函数,对于已经集成好的ReLU等函数。可以使用简单的

def Act_op():
    return nn.ReLU()

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.con_layer1 = nn.Sequential(
           nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
           Act_op()
        )
    def forward(self, x):
        x = self.con_layer1(x)
        return x

        对于Swish = x*sigmod(x) 这种pytorch还没集成的函数,就需要自定义Act_op()。

方法一:使用nn.Function

## 由于 Function 可能需要暂存 input tensor。
## 因此,建议不复用 Function 对象,以避免遇到内存提前释放的问题。
class Swish_act(torch.autograd.Function):
    ## save_for_backward can only!!!! save input or output tensors
    @staticmethod
    def forward(self, input_):
        print('swish act op forward')
        output = input_ * F.sigmoid(input_)
        self.save_for_backward(input_)
        return output

    @staticmethod
    def backward(self, grad_output):
	## according to the chain rule(Backpropagation),
	## d(loss)/d(x) = d(loss)/d(output) * d(output)/d(x)
	## grad_output is the d(loss)/d(output)
	## we calculate and save the d(output)/d(x) in forward
        input_, = self.saved_tensors
        output = input_ * F.sigmoid(input_)
        grad_swish = output + F.sigmoid(input_) * (1 - output)
        print('swish act op backward')
return grad_output * grad_swish

def Act_op():
    return Swish_act()

在使用这种方法写的时候,遇到了几个坑。

首先,save_to_backward()只允许保存输入、输出的张量。比如,输入为a(即,forward(self, a)),那么save_to_backward(a)没问题,save_to_backward(a+1)报错。

其次,根据pytorch的逻辑,nn.Function是在nn.Module的forward()过程中使用,不能在__init__中使用。

其三,如果在模型中,需要重复调用这个Swish_act()接口。会出现前一次使用的内存被提前释放掉,使得在反向传播计算中,需要使用的变量失效。


为了解决不能重复调用问题,可以使用nn.Module,创建CNN网络模型的Module子类。

方法二:

class Act_op(nn.Module):
    def __init__(self):
        super(Act_op, self).__init__()

    def forward(self, x):
        x = x * F.sigmoid(x)
        return x

     简单、快捷、方便。还不用自己写backward。而且,不需要对之前CNN模型里,class Net(nn.Module)做任何改动。


  • 11
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值