UACANet分割模型搭建

本文介绍了UACANet(Uncertainty Augmented Context Attention Network),一个以Res2Net为编码器的编码-解码分割网络。其核心包括并行轴向注意力(Parallel Axial Attention,PAA)模块——PAA_kernel、PAA_e和PAA_d——以及不确定性增强的上下文注意力(UACA)模块,用于增强特征学习和上下文信息处理。网络通过多尺度特征融合和注意力机制提高图像分割性能。
摘要由CSDN通过智能技术生成

原论文:https://arxiv.org/abs/2107.02368
源码:https://github.com/plemeri/UACANet

直接步入正题~~~

一、轴注意力

class self_attn(nn.Module):
    """Axial (or full spatial) self-attention with a zero-initialised residual gate.

    Query/key projections reduce to ``in_channels // 8`` channels; the value
    projection keeps ``in_channels``. The output is ``gamma * attention(x) + x``
    with ``gamma`` starting at 0, so the block initially returns its input.

    Args:
        in_channels: channels of the input feature map.
        mode: spatial axes the attention runs over. ``'h'`` builds one (H x H)
            attention map per sample, ``'w'`` one (W x W) map and ``'hw'`` one
            (H*W x H*W) map. In the axial modes the query/key similarity is summed
            over the channels and the other spatial axis, and the same map is
            applied to every row/column.
    """

    def __init__(self, in_channels, mode='hw'):
        super(self_attn, self).__init__()

        self.mode = mode

        self.query_conv = conv(in_channels, in_channels // 8, kernel_size=(1, 1))
        self.key_conv = conv(in_channels, in_channels // 8, kernel_size=(1, 1))
        self.value_conv = conv(in_channels, in_channels, kernel_size=(1, 1))

        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply gated self-attention to ``x`` of shape (b, c, h, w); output has the same shape."""
        batch_size, channel, height, width = x.size()

        axis = 1
        if 'h' in self.mode:
            axis *= height
        if 'w' in self.mode:
            axis *= width

        # view()/reshape() flatten trailing dims in memory order (C, H, W), so the
        # last axis of a (b, -1, axis) view is W. For pure height attention, H must
        # be moved last first; without this swap, mode 'h' attended along width
        # (square inputs) or over scrambled positions (non-square inputs).
        swap_hw = 'h' in self.mode and 'w' not in self.mode

        def flatten(t):
            if swap_hw:
                t = t.transpose(2, 3)
            return t.reshape(batch_size, -1, axis)

        projected_query = flatten(self.query_conv(x)).permute(0, 2, 1)  # (b, axis, -1)
        projected_key = flatten(self.key_conv(x))                      # (b, -1, axis)

        attention = self.softmax(torch.bmm(projected_query, projected_key))  # (b, axis, axis)
        projected_value = flatten(self.value_conv(x))                        # (b, -1, axis)

        out = torch.bmm(projected_value, attention.permute(0, 2, 1))  # (b, -1, axis)
        if swap_hw:
            # Undo the H/W swap performed in flatten().
            out = out.view(batch_size, channel, width, height).transpose(2, 3)
        else:
            out = out.view(batch_size, channel, height, width)

        return self.gamma * out + x

参考:http://t.csdn.cn/Ag463 

二、PAA_kernel模块

class PAA_kernel(nn.Module):
    """Parallel axial attention kernel.

    Pipeline: 1x1 conv -> (1 x k) conv -> (k x 1) conv -> height attention plus
    width attention (summed) -> 3x3 conv with dilation k.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels produced by every internal layer.
        receptive_size: k, the length of the asymmetric convolutions and the
            dilation of the final 3x3 convolution.
    """

    def __init__(self, in_channel, out_channel, receptive_size=3):
        super(PAA_kernel, self).__init__()
        k = receptive_size
        # Creation order is kept as-is so parameter registration / init order is unchanged.
        self.conv0 = conv(in_channel, out_channel, 1)
        self.conv1 = conv(out_channel, out_channel, kernel_size=(1, k))
        self.conv2 = conv(out_channel, out_channel, kernel_size=(k, 1))
        self.conv3 = conv(out_channel, out_channel, 3, dilation=k)
        self.Hattn = self_attn(out_channel, mode='h')
        self.Wattn = self_attn(out_channel, mode='w')

    def forward(self, x):
        # The post's trace (bs, 32, 44, 44 throughout) suggests conv() preserves
        # spatial size — confirm against conv's padding rule.
        reduced = self.conv0(x)
        strip = self.conv2(self.conv1(reduced))
        axial = self.Hattn(strip) + self.Wattn(strip)
        return self.conv3(axial)

三、PAA_e模块

class PAA_e(nn.Module):
    """Multi-branch context encoder.

    A 1x1 branch and three ``PAA_kernel`` branches (k = 3, 5, 7) run in parallel
    on the same input. Their outputs are concatenated, fused by a 3x3 conv, added
    to a 1x1 residual projection of the input, and passed through ReLU.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels of every branch and of the output.
    """

    def __init__(self, in_channel, out_channel):
        super(PAA_e, self).__init__()
        self.relu = nn.ReLU(True)

        self.branch0 = conv(in_channel, out_channel, 1)
        # branch1..branch3 keep their attribute names so state_dict keys are unchanged.
        for idx, kernel in enumerate((3, 5, 7), start=1):
            setattr(self, f'branch{idx}', PAA_kernel(in_channel, out_channel, kernel))

        self.conv_cat = conv(4 * out_channel, out_channel, 3)
        self.conv_res = conv(in_channel, out_channel, 1)

    def forward(self, x):
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        stacked = torch.cat([branch(x) for branch in branches], 1)
        fused = self.conv_cat(stacked)
        return self.relu(fused + self.conv_res(x))

提示:这个结构类似于PraNet中的RFB模块——多个不同感受野的分支并行处理同一输入,拼接后经卷积融合,再与一条1×1残差支路相加。http://t.csdn.cn/cJfm2

四、PAA_d模块

class PAA_d(nn.Module):
    """Dense decoder that fuses three context features into a coarse prediction.

    It can be replaced by other decoders (e.g. DSS, Amulet); it is used after MSF.
    f1 and f2 are resized to f3's resolution and concatenated with f3. The result
    is reduced by a 3x3 conv, refined by summed height/width axial attention and
    three more 3x3 convs, and projected to one channel.

    Args:
        channel: channels of each of the three input features and of the fused feature.
    """

    def __init__(self, channel):
        super(PAA_d, self).__init__()
        self.conv1 = conv(channel * 3, channel, 3)
        self.conv2 = conv(channel, channel, 3)
        self.conv3 = conv(channel, channel, 3)
        self.conv4 = conv(channel, channel, 3)
        self.conv5 = conv(channel, 1, 3, bn=False)

        self.Hattn = self_attn(channel, mode='h')
        self.Wattn = self_attn(channel, mode='w')

    @staticmethod
    def upsample(img, size):
        """Bilinearly resize ``img`` to ``size`` (align_corners=True).

        A static method instead of a lambda attribute, so the module stays picklable
        (``torch.save(model)``, multiprocessing); call sites are unchanged.
        """
        return F.interpolate(img, size=size, mode='bilinear', align_corners=True)

    def forward(self, f1, f2, f3):
        """Fuse f1, f2, f3 at f3's spatial resolution.

        Returns:
            (f3, out): the fused feature with ``channel`` channels and the
            one-channel output of ``conv5``.
        """
        # e.g. f1: bs, 32, 11, 11   f2: bs, 32, 22, 22   f3: bs, 32, 44, 44
        f1 = self.upsample(f1, f3.shape[-2:])  # bs, 32, 44, 44
        f2 = self.upsample(f2, f3.shape[-2:])  # bs, 32, 44, 44
        f3 = torch.cat([f1, f2, f3], dim=1)    # bs, 32*3, 44, 44
        f3 = self.conv1(f3)                    # bs, 32, 44, 44

        Hf3 = self.Hattn(f3)
        Wf3 = self.Wattn(f3)

        f3 = self.conv2(Hf3 + Wf3)
        f3 = self.conv3(f3)
        f3 = self.conv4(f3)
        out = self.conv5(f3)                   # bs, 1, 44, 44

        return f3, out

五、UACA模块

class UACA(nn.Module):
    """Uncertainty Augmented Context Attention.

    A coarse prediction ``map`` is split into foreground, background and
    uncertain ("confusion") weight maps. The input feature is pooled into one
    context vector per region, and every pixel attends to these three vectors.
    The attended context is fused back into the feature, which also predicts a
    residual correction of ``map``.

    Args:
        in_channel: channels of the input feature ``x``.
        channel: channels of the attention embedding and of the returned feature.
    """

    def __init__(self, in_channel, channel):
        super(UACA, self).__init__()
        self.channel = channel

        self.conv_query = nn.Sequential(conv(in_channel, channel, 3, relu=True),
                                        conv(channel, channel, 3, relu=True))
        self.conv_key = nn.Sequential(conv(in_channel, channel, 1, relu=True),
                                      conv(channel, channel, 1, relu=True))
        self.conv_value = nn.Sequential(conv(in_channel, channel, 1, relu=True),
                                        conv(channel, channel, 1, relu=True))

        self.conv_out1 = conv(channel, channel, 3, relu=True)
        self.conv_out2 = conv(in_channel + channel, channel, 3, relu=True)
        self.conv_out3 = conv(channel, channel, 3, relu=True)
        self.conv_out4 = conv(channel, 1, 1)

    def forward(self, x, map):
        """Refine ``x`` and ``map`` with region-level context attention.

        Args:
            x: feature map, (b, in_channel, h, w).
            map: one-channel coarse prediction (passed through sigmoid here),
                (b, 1, H', W'); it is resized to (h, w). The name shadows the
                builtin but is kept for keyword-argument compatibility.

        Returns:
            (x, out): refined feature (b, channel, h, w) and the corrected
            prediction (b, 1, h, w), i.e. ``conv_out4(feature) + resized map``.
        """
        b, c, h, w = x.shape

        # Bring the coarse prediction to the feature resolution.
        map = F.interpolate(map, size=x.shape[-2:], mode='bilinear', align_corners=False)
        p = torch.sigmoid(map) - .5

        # clip keeps values in [0, 1]: negatives become 0, the rest are unchanged.
        fg = torch.clip(p, 0, 1)   # foreground: non-zero where sigmoid > 0.5
        bg = torch.clip(-p, 0, 1)  # background: non-zero where sigmoid < 0.5
        cg = .5 - torch.abs(p)     # confusion area: largest where sigmoid is near 0.5

        prob = torch.cat([fg, bg, cg], dim=1).view(b, 3, h * w)  # (b, 3, hw)

        # Flatten positions without scrambling channels: (b, c, h, w) -> (b, hw, c).
        # (A plain x.view(b, h * w, -1) reinterprets the C-major buffer and mixes
        # channels with positions.)
        f = x.reshape(b, c, h * w).permute(0, 2, 1)

        # One context vector per region, (b, 3, c) -> (b, c, 3, 1) so 1x1 convs apply.
        context = torch.bmm(prob, f).permute(0, 2, 1).unsqueeze(3)

        query = self.conv_query(x).view(b, self.channel, -1).permute(0, 2, 1)         # (b, hw, channel)
        key = self.conv_key(context).view(b, self.channel, -1)                        # (b, channel, 3)
        value = self.conv_value(context).view(b, self.channel, -1).permute(0, 2, 1)   # (b, 3, channel)

        # Scaled dot-product similarity of each pixel to the three region vectors.
        sim = torch.bmm(query, key)          # (b, hw, 3)
        sim = (self.channel ** -.5) * sim
        sim = F.softmax(sim, dim=-1)

        # Attended context, back to (b, channel, h, w).
        context = torch.bmm(sim, value).permute(0, 2, 1).contiguous().view(b, -1, h, w)
        context = self.conv_out1(context)

        x = torch.cat([x, context], dim=1)   # (b, in_channel + channel, h, w)
        x = self.conv_out2(x)
        x = self.conv_out3(x)
        out = self.conv_out4(x)              # (b, 1, h, w)
        out = out + map                      # residual refinement of the coarse map
        return x, out

六、整体网络结构

class UACANet(nn.Module):
    """Res2Net-based encoder-decoder with PAA context modules and UACA refinement.

    Backbone stages 2-4 go through ``PAA_e`` context modules. ``PAA_d`` produces a
    coarse map, which three ``UACA`` stages refine from the deepest to the
    shallowest feature. Every stage's prediction is resized to the input size.

    Args:
        channels: width of the context, decoder and attention features.
        output_stride: forwarded to ``res2net50_v1b_26w_4s``.
        pretrained: forwarded to ``res2net50_v1b_26w_4s``.
    """

    def __init__(self, channels=256, output_stride=16, pretrained=True):
        super(UACANet, self).__init__()
        self.resnet = res2net50_v1b_26w_4s(pretrained=pretrained, output_stride=output_stride)

        self.context2 = PAA_e(512, channels)
        self.context3 = PAA_e(1024, channels)
        self.context4 = PAA_e(2048, channels)

        self.decoder = PAA_d(channels)

        self.attention2 = UACA(channels * 2, channels)
        self.attention3 = UACA(channels * 2, channels)
        self.attention4 = UACA(channels * 2, channels)

        self.loss_fn = bce_iou_loss

    # Static methods instead of lambda attributes keep the model picklable
    # (torch.save(model), multiprocessing); self.ret / self.res calls are unchanged.
    @staticmethod
    def ret(x, target):
        """Bilinearly resize ``x`` to ``target``'s spatial size (align_corners=False)."""
        return F.interpolate(x, size=target.shape[-2:], mode='bilinear', align_corners=False)

    @staticmethod
    def res(x, size):
        """Bilinearly resize ``x`` to ``size`` (align_corners=False)."""
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

    def forward(self, sample):
        """Run the network on ``sample['image']``.

        Args:
            sample: mapping with ``'image'`` and, optionally, ``'gt'``.

        Returns:
            dict with ``'pred'`` (finest prediction at input size), ``'loss'``
            (sum of the four side losses, or 0 when there is no ground truth) and
            ``'debug'`` ([out5, out4, out3] when ground truth is given, else []).
        """
        x = sample['image']
        y = sample['gt'] if 'gt' in sample.keys() else None

        base_size = x.shape[-2:]

        # Shape comments follow the original post's trace (352x352 input,
        # channels=32); actual sizes depend on the input and the backbone stride.
        x = self.resnet.conv1(x)    # bs, 64, 176, 176
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)  # bs, 64, 88, 88

        x1 = self.resnet.layer1(x)   # bs, 256, 88, 88
        x2 = self.resnet.layer2(x1)  # bs, 512, 44, 44
        x3 = self.resnet.layer3(x2)  # bs, 1024, 22, 22
        x4 = self.resnet.layer4(x3)  # bs, 2048, 11, 11

        x2 = self.context2(x2)  # bs, 32, 44, 44
        x3 = self.context3(x3)  # bs, 32, 22, 22
        x4 = self.context4(x4)  # bs, 32, 11, 11

        # Coarse prediction from the dense decoder.
        f5, a5 = self.decoder(x4, x3, x2)  # f5: bs, 32, 44, 44   a5: bs, 1, 44, 44
        out5 = self.res(a5, base_size)

        # Deep-to-shallow refinement: each stage sees its context feature
        # concatenated with the previous stage's feature, plus the previous map.
        f4, a4 = self.attention4(torch.cat([x4, self.ret(f5, x4)], dim=1), a5)  # f4: bs, 32, 11, 11   a4: bs, 1, 11, 11
        out4 = self.res(a4, base_size)

        f3, a3 = self.attention3(torch.cat([x3, self.ret(f4, x3)], dim=1), a4)  # f3: bs, 32, 22, 22   a3: bs, 1, 22, 22
        out3 = self.res(a3, base_size)

        _, a2 = self.attention2(torch.cat([x2, self.ret(f3, x2)], dim=1), a3)  # a2: bs, 1, 44, 44
        out2 = self.res(a2, base_size)

        if y is not None:
            loss5 = self.loss_fn(out5, y)
            loss4 = self.loss_fn(out4, y)
            loss3 = self.loss_fn(out3, y)
            loss2 = self.loss_fn(out2, y)

            loss = loss2 + loss3 + loss4 + loss5
            debug = [out5, out4, out3]
        else:
            loss = 0
            debug = []

        return {'pred': out2, 'loss': loss, 'debug': debug}

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值