Daily Attention Study 12: Exterior Contextual-Relation Module

Source

[ISBI 22] [link] [code] Duplex Contextual Relation Network for Polyp Segmentation


Module Name

Exterior Contextual-Relation Module (ECRM)


Purpose

A memory-based feature enhancement module


Module Structure

[Figure: ECRM structure diagram (image not preserved)]


Key Idea

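The paper's reasoning: in clinical settings, polyps show synchronous visual patterns across different samples. Based on this key observation, region features that belong to the same semantic class anywhere in the training data should be contextually related. The authors therefore propose a novel module that explores contextual relations across different samples.
Concretely, the global feature produced by the last encoder layer (the red block in the figure) is enhanced twice:
First, the global feature is fed directly into a $1 \times 1$ convolution (the light purple part of the figure) to obtain a coarse segmentation mask. Multiplying this mask with the global feature yields an enhanced feature with the background filtered out (the part to the left of enqueue in the figure).
Second, the feature is enhanced with historical context that the network has stored from other training samples (the Cross-Batch Memory in the figure). That is, the current feature performs cross attention with the features in the Memory, so that past experience fills in what the current sample lacks.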
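To make the two steps concrete, here is a minimal, parameter-free sketch in PyTorch. The function names `masked_prototype` and `attend_to_memory` are hypothetical, and the attention is plain scaled dot-product with pixels as queries. This is a simplification for illustration, not the learned projections or the exact formulation used in the module. The pooling step assumes the coarse mask is turned into weights with a spatial softmax.

```python
import torch


def masked_prototype(feats, mask_logits):
    """Pool feats [B, C, H, W] into one vector per image, shape [B, C].

    Pixels are weighted by a spatial softmax over the coarse mask logits
    [B, 1, H, W], so likely-foreground positions dominate the average.
    """
    b, c, h, w = feats.shape
    weights = torch.softmax(mask_logits.reshape(b, 1, h * w), dim=-1)  # [B, 1, HW]
    flat = feats.reshape(b, c, h * w)                                   # [B, C, HW]
    return (flat * weights).sum(dim=-1)                                 # [B, C]


def attend_to_memory(feats, memory):
    """Let every pixel of feats [B, C, H, W] read from memory [S, C].

    Returns a tensor of the same shape as feats in which each pixel is a
    softmax-weighted mixture of the S memory entries.
    """
    b, c, h, w = feats.shape
    queries = feats.reshape(b, c, h * w).transpose(1, 2)  # [B, HW, C]
    scores = queries @ memory.t() / c ** 0.5              # [B, HW, S]
    attn = torch.softmax(scores, dim=-1)                  # sums to 1 over S
    read = attn @ memory                                  # [B, HW, C]
    return read.transpose(1, 2).reshape(b, c, h, w)


torch.manual_seed(0)
feats = torch.randn(2, 8, 4, 4)          # toy encoder output
mask_logits = torch.randn(2, 1, 4, 4)    # toy coarse mask logits
# Stand-in for prototypes stored from earlier batches; detached like a memory.
memory = masked_prototype(feats, mask_logits).detach()
enhanced = attend_to_memory(feats, memory)
print(memory.shape, enhanced.shape)
```

In a real model the memory would hold prototypes from earlier batches rather than from the current one, and the read-out would be fused with the original features (for example by concatenation followed by a convolution).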


Code

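A few details of the implementation deserve attention:

  • The aux_out returned by the module should receive side supervision so that the coarse mask stays accurate;
  • The Memory maintains the network's historical information. To keep it from being corrupted, it is excluded from gradient updates;
  • At test time the Memory is no longer updated and the history stored during training is meant to be used directly, an idea similar to BatchNorm. Note that the flag='test' branch in the listing below passes the current batch's prototypes (patch) to the attention rather than self.bank.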
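The second and third notes can be tried in isolation. The sketch below is a hypothetical `FeatureBank` class (not part of the original code) that stores vectors in registered buffers, writes them under `torch.no_grad()`, and skips updates in eval mode, which is the BatchNorm-like behaviour described above. It makes two assumptions that differ from the ECRM listing: it keys on `self.training` instead of a `flag` argument, and it wraps around the end of the bank instead of truncating the batch.

```python
import torch
from torch import nn


class FeatureBank(nn.Module):
    """Hypothetical fixed-size FIFO bank of feature vectors."""

    def __init__(self, size, dim):
        super().__init__()
        # Buffers are saved in state_dict and moved by .to(), but are not
        # parameters, so the optimizer never touches them.
        self.register_buffer("slots", torch.zeros(size, dim))
        self.register_buffer("ptr", torch.zeros(1, dtype=torch.long))
        self.register_buffer("count", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def push(self, feats):
        """Write feats [N, dim] into the bank; a no-op in eval mode."""
        if not self.training:
            return
        size = self.slots.shape[0]
        feats = feats.detach()[-size:]  # keep the newest rows if N > size
        n = feats.shape[0]
        # Circular indices starting at the current write pointer.
        idx = (int(self.ptr) + torch.arange(n, device=self.slots.device)) % size
        self.slots[idx] = feats
        self.ptr[0] = (int(self.ptr) + n) % size
        self.count[0] = min(int(self.count) + n, size)

    def valid(self):
        """Return only the entries that have been written so far."""
        return self.slots[: int(self.count)]


bank = FeatureBank(size=4, dim=2)
bank.push(torch.ones(3, 2))       # fills slots 0..2
bank.push(2 * torch.ones(3, 2))   # wraps: writes slots 3, 0, 1
bank.eval()
bank.push(9 * torch.ones(1, 2))   # ignored in eval mode
print(bank.slots[:, 0].tolist(), int(bank.ptr))
```

Keeping the fill count in a buffer means it survives checkpointing. The ECRM listing that follows tracks fullness with a plain Python attribute (`bank_full`) instead, which is not stored in `state_dict`.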
import torch
from torch import nn

def conv2d(in_channel, out_channel, kernel_size):
    layers = [
        nn.Conv2d(
            in_channel, out_channel, kernel_size, padding=kernel_size // 2, bias=False
        ),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)


def conv1d(in_channel, out_channel):
    layers = [
        nn.Conv1d(in_channel, out_channel, 1, bias=False),
        nn.BatchNorm1d(out_channel),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)


class ECRM(nn.Module):
    def __init__(self, bank_size=20, feat_channels=512, num_classes=1):
        super(ECRM, self).__init__()  
        # BANK CONFIG
        self.bank_size = bank_size
        self.register_buffer("bank_ptr", torch.zeros(1, dtype=torch.long))  # memory bank pointer
        self.register_buffer("bank", torch.zeros(self.bank_size, feat_channels, num_classes))  # memory bank
        self.bank_full = False

        # ATTENTION CONFIG
        self.feat_channels = feat_channels
        self.L = nn.Conv2d(feat_channels, num_classes, 1)
        self.X = conv2d(feat_channels, 512, 3)
        self.phi = conv1d(512, 256)
        self.psi = conv1d(512, 256)
        self.delta = conv1d(512, 256)
        self.rho = conv1d(256, 512)
        self.g = conv2d(512 + 512, 512, 1)

    def init(self):
        self.bank_ptr[0] = 0
        self.bank_full = False

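    @torch.no_grad()
    def update_bank(self, x):
        # x: B * C * K prototypes; written without tracking gradients
        ptr = int(self.bank_ptr)
        batch_size = x.shape[0]
        vacancy = self.bank_size - ptr
        if batch_size >= vacancy:
            self.bank_full = True
        # write only as many entries as fit before the end of the bank
        pos = min(batch_size, vacancy)
        self.bank[ptr:ptr+pos] = x[0:pos].clone()
        # update pointer (wraps to 0 once the end is reached)
        ptr = (ptr + pos) % self.bank_size
        self.bank_ptr[0] = ptr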

    def enhance_by_memory(self, bank, X_flat, X):
        batch, _, height, width = X.shape
        # query = S * C' (S = number of bank entries, C' = 256)
        query = self.phi(bank).squeeze(dim=2)
        # key = B * C' * HW
        key = self.psi(X_flat)
        # matmul broadcasts to B * S * HW; logit = HW * S * B (cross image relation)
        logit = torch.matmul(query, key).transpose(0, 2)
        # attn = HW * S * B, normalized over the last (batch) dimension
        attn = torch.softmax(logit, 2)
        # delta = S * C'
        delta = self.delta(bank).squeeze(dim=2)
        # attn_sum = HW * C' * B (the matmul gives HW * B * C' before the transpose)
        attn_sum = torch.matmul(attn.transpose(1, 2), delta).transpose(1, 2)
        # rho output is HW * C * B; view reinterprets it as B * C * H * W
        X_obj = self.rho(attn_sum).view(batch, -1, height, width)
        concat = torch.cat([X, X_obj], 1)
        out = self.g(concat)
        return out
    
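    def get_prototype(self, input):
        # coarse segmentation logits, B * K * H * W; also returned for side supervision
        L = self.L(input)
        aux_out = L
        batch, n_class, _, _ = L.shape
        l_flat = L.view(batch, n_class, -1)
        # spatial softmax: per-class weights over the HW positions
        M = torch.softmax(l_flat, -1)
        X = self.X(input)
        channel = X.shape[1]
        X_flat = X.view(batch, channel, -1)
        # mask-weighted pooling: one C-dim prototype per class, B * C * K
        f_k = (M @ X_flat.transpose(1, 2)).transpose(1, 2)
        return aux_out, f_k, X_flat, X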

    def forward(self, x, flag='train'):
        # x [3, 512, 11, 11]
        # patch [3, 512, 1]
        aux_out, patch, feats_flat, feats = self.get_prototype(x)
        if flag == 'train':
            self.update_bank(patch)
            ptr = int(self.bank_ptr)
            if self.bank_full:
                out = self.enhance_by_memory(self.bank, feats_flat, feats)
            else:
                # bank not yet full: attend only over the filled entries
                out = self.enhance_by_memory(self.bank[0:ptr], feats_flat, feats)
        elif flag == 'test':
            # the bank is neither updated nor read here; the current prototypes are used
            out = self.enhance_by_memory(patch, feats_flat, feats)
        else:
            raise ValueError(f"flag must be 'train' or 'test', got {flag!r}")
        return out, aux_out
    
if __name__ == '__main__':
    x = torch.randn([3, 512, 11, 11])
    ecrm = ECRM()
    out = ecrm(x)
    print(out[0].shape)  # 3, 512, 11, 11
    print(out[1].shape)  # 3, 1, 11, 11
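
The side supervision mentioned in the notes could be wired into training roughly as follows. This is only a sketch: the function name `loss_with_side_supervision`, the choice of binary cross-entropy, the bilinear upsampling, and the `aux_weight` value are all assumptions made for illustration, not the loss used in the paper. Random tensors stand in for the decoder output and for `aux_out`.

```python
import torch
import torch.nn.functional as F


def loss_with_side_supervision(main_logits, aux_logits, target, aux_weight=1.0):
    """BCE on the final prediction plus BCE on the upsampled coarse mask."""
    # The coarse mask lives at encoder resolution, so bring it up to the
    # resolution of the ground-truth mask before comparing.
    aux_up = F.interpolate(
        aux_logits, size=target.shape[-2:], mode="bilinear", align_corners=False
    )
    main_loss = F.binary_cross_entropy_with_logits(main_logits, target)
    aux_loss = F.binary_cross_entropy_with_logits(aux_up, target)
    return main_loss + aux_weight * aux_loss


torch.manual_seed(0)
target = (torch.rand(3, 1, 44, 44) > 0.5).float()             # toy binary masks
main_logits = torch.randn(3, 1, 44, 44, requires_grad=True)   # stand-in for decoder output
aux_logits = torch.randn(3, 1, 11, 11, requires_grad=True)    # same shape as aux_out
loss = loss_with_side_supervision(main_logits, aux_logits, target)
loss.backward()  # gradients reach both the main and the auxiliary branch
```

Because the auxiliary term backpropagates through the coarse mask, the 1×1 convolution that produces it is trained directly, which in turn affects the quality of the prototypes written into the Memory.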
