记录新的激活函数和一些tricks

最近看efficientnet代码,发现一些操作不太一样,包括之前看mobv3的代码也是,记录一下这些操作。

 

首先是swish激活函数,efficientnet里面使用了swish激活函数。

函数的过程如下,看代码里面设置beta为1。

对应的代码如下:

# An ordinary implementation of Swish function
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


# A memory-efficient implementation of Swish function
class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)

最简单的实现方式,是直接用torch.sigmoid,让pytorch自己完成反向传播的过程。

比较高级的实现方式是,使用ctx把在前向传播中,把反向传播需要使用的变量save下来,然后在反向传播的时候直接load。

 

不同于efficientnet,mobv3里面使用了hard-swish(hswish),hard-swish使用一个线性变化的relu6来替代了sigmoid function。

mob3 论文中提到,在嵌入式等平台,量化时,sigmoid function开销较大,作者在网络的后半部分阶段替代了relu6,在网络的前半部分使用hsigmoid代替了sigmoid function。

代码如下:

class hswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out

后面发现pytorch 最新的版本1.8已经对于 hard-swish有了支持,api如下: torch.nn.Hardswish 

 

hsigmoid 函数如下

代码如下:

class hsigmoid(nn.Module):
    def forward(self, x):
        out = F.relu6(x + 3, inplace=True) / 6
        return out

pytorch下文档显示如下:

 

 

 

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值