参考代码：
class se_block(nn.Module):
    """Squeeze-and-Excitation (SE) channel-attention block.

    Globally average-pools each channel, feeds the pooled vector through a
    two-layer bottleneck MLP (Linear -> ReLU -> Linear -> Sigmoid), and
    rescales the input channels by the resulting per-channel gates.

    Args:
        chan_num: Number of input (and output) channels ``C``.
        reduction: Bottleneck reduction ratio; the hidden layer has
            ``max(1, chan_num // reduction)`` units.
    """

    def __init__(self, chan_num, reduction=16):
        super().__init__()
        # Clamp to at least one hidden unit: when chan_num < reduction,
        # chan_num // reduction is 0 and the MLP would collapse to
        # zero-width layers whose gates are a constant sigmoid(0) = 0.5.
        hidden = max(1, chan_num // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.se_seq = nn.Sequential(
            nn.Linear(chan_num, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, chan_num, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale the channels of ``x`` by learned SE gates.

        Args:
            x: Input of shape ``(B, C, H, W)`` with ``C == chan_num``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Method 2 (active): flatten / unsqueeze never read the runtime
        # batch and channel sizes out of ``x``.
        #
        # Method 1 (numerically equivalent alternative, kept for the ONNX
        # graph comparison):
        #     b, c, _, _ = x.size()
        #     y = self.avg_pool(x).view(b, c)
        #     y = self.se_seq(y).view(b, c, 1, 1)
        #     return x * y.expand_as(x)
        # It reads b and c from x.size(), which presumably adds
        # shape-extraction ops to the exported ONNX graph -- confirm by
        # exporting both variants.
        gates = self.avg_pool(x).flatten(start_dim=1)             # (B, C)
        gates = self.se_seq(gates).unsqueeze(-1).unsqueeze(-1)   # (B, C, 1, 1)
        return gates * x
两种写法导出的 ONNX 图分别如下：
方法1：
方法2：