Fully Attentional Network for Semantic Segmentation代码

官方链接
1:repeat是如何实现的?
2:cut是如何实现的?
3:整个block是如何实现的?
在这里插入图片描述

class FullyAttentionalBlock(nn.Module):
    def __init__(self, plane, norm_layer=SyncBatchNorm):
        super(FullyAttentionalBlock, self).__init__()
        self.conv1 = nn.Linear(plane, plane)  # 改变最后一个维度
        self.conv2 = nn.Linear(plane, plane)
        self.conv = nn.Sequential(nn.Conv2d(plane, plane, 3, stride=1, padding=1, bias=False),
                                  norm_layer(plane),
                                  nn.ReLU())

        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, _, height, width = x.size()

        feat_h = x.permute(0, 3, 1, 2).contiguous().view(batch_size * width, -1, height)  # [b*w, c, h]
        feat_w = x.permute(0, 2, 1, 3).contiguous().view(batch_size * height, -1, width)  # [b*h, c, w]
        # (b,c,w)--->(b,h,c)
        encode_h = self.conv1(F.avg_pool2d(x, [1, width]).view(batch_size, -1, height).permute(0, 2, 1).contiguous())
        # (b,c,w)--->(b,w,c)
        encode_w = self.conv2(F.avg_pool2d(x, [height, 1]).view(batch_size, -1, width).permute(0, 2, 1).contiguous())
        # (b*w,c,h) * (b*w,h,c) = (b*w,c,c)
        energy_h = torch.matmul(feat_h, encode_h.repeat(width, 1, 1))
        # (b*h,c,w) * (b*h,w,c) = (b*w,c,c)
        energy_w = torch.matmul(feat_w, encode_w.repeat(height, 1, 1))
        full_relation_h = self.softmax(energy_h)  # [b*w, c, c]
        full_relation_w = self.softmax(energy_w)  # [b*h, c, c]
        # [b*w, c, c]*[b*w, c, h] = [b*w, c, h]---->[b,w,c,h]--->[b,c,h,w]
        full_aug_h = torch.bmm(full_relation_h, feat_h).view(batch_size, width, -1, height).permute(0, 2, 3, 1)
        full_aug_w = torch.bmm(full_relation_w, feat_w).view(batch_size, height, -1, width).permute(0, 2, 1, 3)
        out = self.gamma * (full_aug_h + full_aug_w) + x
        out = self.conv(out)
        return out

首先看一下construction里维度是如何变化的:
Fin(b,c,h,w)经过大小为【h x 1】和【1 x w】大小的卷积核后,维度变为【b,c,h】和【b,c,w】经过线性层唯独不变,再沿着h和w维度进行复制,维度变为【wb,h,c】,【hb,w,c】这里用一个demo演示repeat函数:

import torch.nn.functional as F
x = torch.rand(2,3,5,5)
batch_size, channel, height, width = x.size()
# (1,3,5)--->(1,15)
encode_h = F.avg_pool2d(x, [1, width]).view(batch_size, -1, height).permute(0, 2, 1).contiguous()
v = encode_h.repeat(width, 1, 1)
print(encode_h)
print(encode_h.shape)  # torch.Size([1, 5, 3])
print(v)
print(v.shape) # torch.Size([10, 5, 3])

然后进行cut操作,然后和feat_h,feat_w进行矩阵相乘,然后维度变为【bw,c,c】和【bh,c,c】经过softmax维度不变,生成的相似度map与v相乘,维度变为【b*w, c, h】再reshape维度变为【b,w,c,h】—>【b,c,h,w】

对于问题一,二:
在代码中使用了两次repeat函数,encode_h.repeat(width, 1, 1)encode_w.repeat(height, 1, 1)
分别在w和h方向上进行复制,
在这里插入图片描述
在原文中:
维度c x 1x w和 c x h x1变成c x h x w就是在h和w维度进行复制,然后又在h和w维度进行cut,cut是为了形成slices。在这里插入图片描述
在代码中并没有展现cut是如何操作的。
因为经过repeat之后就直接和K相乘了,而且也并没有表现merge部分。在图中A的维度为【(h+w),c,c】,在代码中直接repeat之后就和K相乘,生成了两个【bw,c,c】和【bh,c,c】,其中【bw,c,c】和【bh,c,c】相加,就生成了文中【(h+w),c,c】代码中是分离开来的。

        # (b*w,c,h) * (b*w,h,c) = (b*w,c,c)
        energy_h = torch.matmul(feat_h, encode_h.repeat(width, 1, 1))
        # (b*h,c,w) * (b*h,w,c) = (b*h,c,c)
        energy_w = torch.matmul(feat_w, encode_w.repeat(height, 1, 1))

最后生成的A,分别和V相乘(在这里K应该是等于V的),再相加。对应于文中的:
在这里插入图片描述

full_aug_h = torch.bmm(full_relation_h, feat_h).view(batch_size, width, -1, height).permute(0, 2, 3, 1)
full_aug_w = torch.bmm(full_relation_w, feat_w).view(batch_size, height, -1, width).permute(0, 2, 1, 3)
out = self.gamma * (full_aug_h + full_aug_w) + x

其余的细节参考官方代码。

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值