最近看efficientnet代码,发现一些操作不太一样,包括之前看mobv3的代码也是,记录一下这些操作。
首先是swish激活函数,efficientnet里面使用了swish激活函数。
函数的过程如下,看代码里面设置beta为1。
对应的代码如下:
# An ordinary implementation of Swish function
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
# A memory-efficient implementation of Swish function
class SwishImplementation(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = i * torch.sigmoid(i)
ctx.save_for_backward(i)
return result
@staticmethod
def backward(ctx, grad_output):
i = ctx.saved_tensors[0]
sigmoid_i = torch.sigmoid(i)
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
class MemoryEfficientSwish(nn.Module):
def forward(self, x):
return SwishImplementation.apply(x)
最简单的实现方式,是直接用torch.sigmoid,让pytorch自己完成反向传播的过程。
比较高级的实现方式是,使用ctx把在前向传播中,把反向传播需要使用的变量save下来,然后在反向传播的时候直接load。
不同于efficientnet,mobv3里面使用了hard-swish(hswish),hard-swish使用一个线性变化的relu6来替代了sigmoid function。
mob3 论文中提到,在嵌入式等平台,量化时,sigmoid function开销较大,作者在网络的后半部分阶段替代了relu6,在网络的前半部分使用hsigmoid代替了sigmoid function。
代码如下:
class hswish(nn.Module):
def forward(self, x):
out = x * F.relu6(x + 3, inplace=True) / 6
return out
后面发现pytorch 最新的版本1.8已经对于 hard-swish有了支持,api如下: torch.nn.Hardswish
hsigmoid 函数如下
代码如下:
class hsigmoid(nn.Module):
def forward(self, x):
out = F.relu6(x + 3, inplace=True) / 6
return out
pytorch下文档显示如下: