模块出处
[AAAI 23] [link] [code] High-Resolution Iterative Feedback Network for Camouflaged Object Detection
模块名称
Adaptive Feature Fusion
模块作用
多尺度特征融合
模块结构
模块代码
import torch
import torch.nn as nn
class AFF(nn.Module):
    """Adaptive Feature Fusion (AFF) from HitNet (AAAI 2023).

    Each input feature map is re-weighted by a Squeeze-and-Excitation style
    channel attention (``fc``) and additionally scaled by a learned per-sample
    scalar (``fc_wight``; alpha_1 for the high-level input, alpha_2 for the
    low-level input). The two re-weighted maps are then summed. Both inputs
    share the same ``fc`` and ``fc_wight`` sub-networks.

    Args:
        ch_in: Number of channels of both inputs.
        reduction: Channel reduction ratio of the bottleneck MLPs. The hidden
            width is ``max(1, ch_in // reduction)``.
    """

    def __init__(self, ch_in=32, reduction=16):
        super().__init__()
        # Clamp to at least one hidden unit: with ch_in < reduction the plain
        # ch_in // reduction is 0, which makes both MLPs output a constant
        # sigmoid(0) = 0.5 and the attention non-learnable.
        hidden = max(1, ch_in // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Per-channel attention weights in (0, 1).
        self.fc = nn.Sequential(
            nn.Linear(ch_in, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, ch_in, bias=False),
            nn.Sigmoid(),
        )
        # One scalar weight in (0, 1) per sample. The misspelled attribute
        # name is kept on purpose so existing state_dict keys still load.
        self.fc_wight = nn.Sequential(
            nn.Linear(ch_in, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 1, bias=False),
            nn.Sigmoid(),
        )

    def _reweight(self, x):
        """Apply SE channel attention and the learned scalar weight to ``x``.

        ``x`` must be 4-D (batch, channels, H, W) with ``channels == ch_in``.
        """
        b, c, _, _ = x.size()
        # Squeeze: global average pool to one value per channel.
        squeezed = self.avg_pool(x).view(b, c)
        scalar = self.fc_wight(squeezed).view(b, 1, 1, 1)
        channel = self.fc(squeezed).view(b, c, 1, 1)
        # Same multiplication order as the per-branch code this replaces.
        return x * channel * scalar

    def forward(self, x_h, x_l):
        """Fuse a high-level and a low-level feature map.

        Args:
            x_h: High-level features, 4-D (B, ch_in, H, W).
            x_l: Low-level features, 4-D with ``ch_in`` channels. Its
                re-weighted map is added to that of ``x_h``, so the two must be
                broadcast-compatible (in practice, the same shape).

        Returns:
            ``alpha_1 * SE(x_h) + alpha_2 * SE(x_l)``.
        """
        # Enhance high-level features: SE re-weighting plus an extra scalar alpha_1.
        x_fusion_h = self._reweight(x_h)
        # Enhance low-level features: SE re-weighting plus an extra scalar alpha_2.
        x_fusion_l = self._reweight(x_l)
        # Multi-level feature fusion.
        return x_fusion_h + x_fusion_l
if __name__ == '__main__':
    # Smoke test: fuse two same-shaped random feature maps and show the result shape.
    fusion = AFF()
    high, low = (torch.randn([1, 32, 16, 16]) for _ in range(2))
    fused = fusion(high, low)
    print(fused.shape)  # expected: torch.Size([1, 32, 16, 16])