Implementing the torch.nn.Hardswish activation function
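import torch
import torch.nn as nn

class Hardswish(nn.Module):
    def __init__(self):
        super(Hardswish, self).__init__()

    def forward(self, input):
        # Piecewise: 0 for x <= -3, x for x >= 3, x*(x+3)/6 in between
        out = input.clone()  # work on a copy so the caller's tensor is not mutated
        out[input <= -3] = 0
        mid = (input > -3) & (input < 3)  # boolean mask for the middle region
        out[mid] = (torch.pow(input[mid], 2) + 3 * input[mid]) / 6.0
        return out  # entries with x >= 3 keep their original value x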
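The masked assignment is one way to write the piecewise function. The following is a minimal alternative sketch that assumes a PyTorch version shipping torch.nn.functional.hardswish. It builds the result out of place with torch.where and checks it against both the built-in and the closed form x * ReLU6(x + 3) / 6 used in MobileNetV3. The names HardswishWhere and hardswish_relu6 are hypothetical and were chosen only for this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class HardswishWhere(nn.Module):
    """Hypothetical out-of-place Hardswish built with torch.where."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Middle branch x * (x + 3) / 6, used only where -3 < x < 3
        mid = x * (x + 3) / 6.0
        # x <= -3 -> 0; x >= 3 -> x; otherwise the quadratic branch
        return torch.where(x <= -3, torch.zeros_like(x),
                           torch.where(x >= 3, x, mid))


def hardswish_relu6(x: torch.Tensor) -> torch.Tensor:
    # Equivalent closed form: x * ReLU6(x + 3) / 6
    return x * F.relu6(x + 3) / 6.0


if __name__ == "__main__":
    x = torch.linspace(-5, 5, steps=101, requires_grad=True)
    y = HardswishWhere()(x)
    # Compare against PyTorch's built-in functional implementation
    assert torch.allclose(y, F.hardswish(x), atol=1e-6)
    assert torch.allclose(y, hardswish_relu6(x), atol=1e-6)
    # Gradients flow because no tensor is modified in place
    y.sum().backward()
    print(x.grad[:3])
```

Because torch.where never writes into the input tensor, the caller's data stays untouched and autograd does not have to track in-place modifications. The trade-off is that the quadratic branch is evaluated for every element rather than only for the middle range.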
Reference: Howard, Andrew, et al. “Searching for MobileNetV3.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019.