模块出处
[TCSVT 24] [link] [code] DSNet: A Novel Way to Use Atrous Convolutions in Semantic Segmentation
模块名称
Multi-Scale Attention Fusion (MSAF)
模块作用
双级特征融合
模块结构
模块思想
MSAF 的主要思想是让网络根据损失自适应地学习特征权重，从而使模型能够有选择地融合来自不同尺度的信息。
模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class MSAF(nn.Module):
    """Multi-Scale Attention Fusion (MSAF) of two feature maps.

    Four parallel branches compute attention logits from ``x + residual``:

    * ``local_att``  -- per-pixel 1x1 bottleneck (no pooling),
    * ``context1``   -- same bottleneck after adaptive average pooling to 4x4,
    * ``context2``   -- same bottleneck after adaptive average pooling to 8x8,
    * ``global_att`` -- same bottleneck after adaptive average pooling to 1x1.

    The logits are summed, squashed with a sigmoid into a weight map
    ``wei``, and the output is ``2 * x * wei + 2 * residual * (1 - wei)``.

    Args:
        channels: Number of channels of both inputs (and of the output).
        r: Channel reduction ratio of the bottleneck. The bottleneck width is
            ``channels // r``, clamped to at least 1 so that
            ``channels < r`` does not produce a zero-channel convolution.

    Note:
        ``global_att`` applies batch normalisation to a 1x1 spatial map, so
        in training mode PyTorch is expected to reject a batch of size 1
        ("Expected more than 1 value per channel when training").
    """

    def __init__(self, channels=64, r=4):
        super().__init__()
        # Clamp so a small ``channels`` (or large ``r``) cannot collapse the
        # bottleneck to zero channels.
        inter_channels = max(1, int(channels // r))

        # Branch order is kept stable: parameters are created (and randomly
        # initialised) in this order, and the attribute names / layer indices
        # determine the state_dict keys.
        self.local_att = self._bottleneck(channels, inter_channels)
        self.context1 = self._bottleneck(
            channels, inter_channels, nn.AdaptiveAvgPool2d((4, 4))
        )
        self.context2 = self._bottleneck(
            channels, inter_channels, nn.AdaptiveAvgPool2d((8, 8))
        )
        self.global_att = self._bottleneck(
            channels, inter_channels, nn.AdaptiveAvgPool2d(1)
        )
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _bottleneck(channels, inter_channels, pool=None):
        """Build ``[pool ->] Conv1x1-BN-ReLU-Conv1x1-BN`` as a Sequential."""
        layers = [] if pool is None else [pool]
        layers += [
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        ]
        return nn.Sequential(*layers)

    def forward(self, x, residual):
        """Fuse ``x`` and ``residual`` with a learned attention map.

        Args:
            x: Tensor indexed as (N, C, H, W) with ``C == channels``.
            residual: Tensor that is added to ``x``; the demo in this file
                uses the same shape as ``x``.

        Returns:
            ``2 * x * wei + 2 * residual * (1 - wei)`` where ``wei`` is the
            sigmoid of the summed branch outputs.
        """
        h, w = x.shape[2], x.shape[3]
        xa = x + residual

        xl = self.local_att(xa)
        c1 = self.context1(xa)
        c2 = self.context2(xa)
        xg = self.global_att(xa)  # (N, C, 1, 1); broadcasts in the sum below

        # Bring the pooled context maps back to the input resolution.
        c1 = F.interpolate(c1, size=[h, w], mode='nearest')
        c2 = F.interpolate(c2, size=[h, w], mode='nearest')

        xlg = xl + xg + c1 + c2
        wei = self.sigmoid(xlg)

        xo = 2 * x * wei + 2 * residual * (1 - wei)
        return xo
if __name__ == '__main__':
    # Smoke test: fuse two random 64-channel 16x16 feature maps (batch of 2)
    # and print the shape of the fused result.
    fusion = MSAF()
    feat_a = torch.randn(2, 64, 16, 16)
    feat_b = torch.randn(2, 64, 16, 16)
    fused = fusion(feat_a, feat_b)
    print(fused.shape)  # torch.Size([2, 64, 16, 16])