Rotate to Attend: an attention mechanism that captures cross-dimension interaction with a three-branch structure

Rotate to Attend: Convolutional Triplet Attention Module

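  • The paper's title is well chosen and reflects its core idea: triplet attention, a new method that computes attention weights by capturing cross-dimension interaction with a three-branch structure. For an input tensor, triplet attention builds inter-dimensional dependencies by applying a rotation operation followed by a residual transformation, and encodes inter-channel and spatial information with negligible computational overhead.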

paper:https://arxiv.org/pdf/2010.03045.pdf

github:https://github.com/landskape-ai/triplet-attention/

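Preface

  • This post reproduces a WACV 2021 paper. Its attention is very simple: a small plug-and-play module that is nearly parameter-free and can be used in many networks.

  • The main contribution is a new attention mechanism that combines channel and spatial attention. Its performance on various CV tasks is as follows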

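Code

The network structure is shown in the figure above:

  • 1. First branch: captures the interaction between the channel dimension C and the spatial dimension W. The input is first permuted to H X C X W, Z-Pool is applied along the H dimension, and the remaining steps follow the same pattern as the other branches. The result is permuted back to C X H X W so that it can be added element-wise.

  • 2. Second branch: captures the interaction between the channel dimension C and the spatial dimension H. The input is first permuted to W X H X C, Z-Pool is applied along the W dimension, and the remaining steps follow the same pattern. The result is permuted back to C X H X W so that it can be added element-wise.

  • 3. Third branch: the spatial attention branch. The input goes through Z-Pool (along C), then a 7 x 7 convolution, and finally a Sigmoid activation that produces the spatial attention weights.

Finally, the outputs of the three branches are summed and averaged.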
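The data flow of the three branches can be sketched in NumPy to make the shapes concrete. This is only an illustration under stated assumptions: `toy_gate` and `toy_triplet_attention` are hypothetical names, and the learned 7 x 7 convolution and BatchNorm are replaced by a plain average of the two pooled maps, so the numbers are not those of the real module.

```python
import numpy as np


def z_pool(x):
    # Concatenate the max and the mean over axis 1: [N, D, A, B] -> [N, 2, A, B].
    return np.concatenate(
        [x.max(axis=1, keepdims=True), x.mean(axis=1, keepdims=True)], axis=1
    )


def toy_gate(x):
    # Stand-in for the learned 7x7 conv + BatchNorm (assumption, illustration
    # only): average the two pooled maps, then squash with a sigmoid.
    logits = z_pool(x).mean(axis=1, keepdims=True)   # [N, 1, A, B]
    scale = 1.0 / (1.0 + np.exp(-logits))
    return x * scale                                 # broadcast over axis 1


def toy_triplet_attention(x):
    # x has layout [N, C, H, W].
    # Branch 1: swap C and H, pool over H, attend over the (C, W) plane.
    b1 = toy_gate(x.transpose(0, 2, 1, 3)).transpose(0, 2, 1, 3)
    # Branch 2: swap C and W, pool over W, attend over the (H, C) plane.
    b2 = toy_gate(x.transpose(0, 3, 2, 1)).transpose(0, 3, 2, 1)
    # Branch 3: no rotation, pool over C, attend over the (H, W) plane.
    b3 = toy_gate(x)
    return (b1 + b2 + b3) / 3.0


x = np.random.rand(2, 8, 5, 6).astype(np.float32)
pooled_branch1 = z_pool(x.transpose(0, 2, 1, 3))
out = toy_triplet_attention(x)
print(pooled_branch1.shape)  # (2, 2, 8, 6)
print(out.shape)             # (2, 8, 5, 6)
```

Both permutations only swap two axes, so applying the same permutation again restores the original C X H X W layout, which is why the three outputs can be averaged element-wise.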

import paddle
import paddle.nn as nn

class BasicConv(nn.Layer):
    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        relu=True,
        bn=True,
        bias_attr=False,
    ):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2D(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias_attr,
        )
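        # Paddle's `momentum` is the weight kept on the running statistic (the
        # opposite of PyTorch's convention), so momentum=0.01 in the PyTorch
        # reference implementation corresponds to 0.99 here.
        self.bn = (
            nn.BatchNorm2D(out_planes, epsilon=1e-5, momentum=0.99)
            if bn
            else None
        )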
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class ZPool(nn.Layer):
    def forward(self, x):
        # Reduce axis 1 to two channels, its max and its mean, concatenated
        # along axis 1: [N, D, A, B] -> [N, 2, A, B].
        return paddle.concat(
            (paddle.max(x, 1).unsqueeze(1), paddle.mean(x, 1).unsqueeze(1)),
            axis=1,
        )


class AttentionGate(nn.Layer):
    def __init__(self):
        super(AttentionGate, self).__init__()
        kernel_size = 7
        self.compress = ZPool()
        self.conv = BasicConv(
            2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
        )

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.conv(x_compress)
        scale = paddle.nn.functional.sigmoid(x_out)
        return x * scale


class TripletAttention(nn.Layer):
    def __init__(self, no_spatial=False):
        super(TripletAttention, self).__init__()
        self.cw = AttentionGate()
        self.hc = AttentionGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.hw = AttentionGate()

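    def forward(self, x):
        # Branch 1: swap C and H ([N, C, H, W] -> [N, H, C, W]) so that Z-Pool
        # runs over H and the gate attends over the (C, W) plane; swap back.
        x_perm1 = x.transpose([0, 2, 1, 3])
        x_out1 = self.cw(x_perm1)
        x_out11 = x_out1.transpose([0, 2, 1, 3])
        # Branch 2: swap C and W ([N, C, H, W] -> [N, W, H, C]) so that Z-Pool
        # runs over W and the gate attends over the (H, C) plane; swap back.
        x_perm2 = x.transpose([0, 3, 2, 1])
        x_out2 = self.hc(x_perm2)
        x_out21 = x_out2.transpose([0, 3, 2, 1])
        if not self.no_spatial:
            # Branch 3: plain spatial attention over (H, W); average all three.
            x_out = self.hw(x)
            x_out = 1 / 3 * (x_out + x_out11 + x_out21)
        else:
            x_out = 1 / 2 * (x_out11 + x_out21)
        return x_out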



Verification

input size = 64,512,16,16 --> TA --> output size = 64,512,16,16

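if __name__=="__main__":
    a = paddle.rand([64,512,16,16])
    # TripletAttention takes no channel argument; passing 512 would be read
    # as no_spatial and silently drop the spatial branch.
    model = TripletAttention()
    a = model(a)
    print(a.shape)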
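The output keeps the input shape because each gate's 7 x 7 convolution uses stride 1 and padding `(kernel_size - 1) // 2`. A small sketch with the standard convolution output-size formula shows this; the helper name `conv_out_size` is made up for illustration.

```python
def conv_out_size(size, kernel_size, padding, stride=1, dilation=1):
    # Standard output-size formula for a convolution along one spatial axis.
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


kernel_size = 7
padding = (kernel_size - 1) // 2  # 3, as in AttentionGate

same_16 = conv_out_size(16, kernel_size, padding)
same_512 = conv_out_size(512, kernel_size, padding)
no_pad_16 = conv_out_size(16, kernel_size, padding=0)

print(same_16)    # 16
print(same_512)   # 512
print(no_pad_16)  # 10
```

Without the padding the attention map would shrink and could no longer be multiplied element-wise with the feature map it is meant to rescale.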

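Evaluating TA

In the paper the authors test on ResNet34, but they report no experiments on ResNet18. Here we build a ResNet18 to check the performance; the TA module is inserted at the position shown below.

Building TA_ResNet18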

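import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url

# No pretrained weights are listed for the TA variants, so the URL table is
# empty; calling a builder with pretrained=True fails the assert in _resnet.
model_urls = {}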

class BasicBlock(nn.Layer):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2D

        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")

        self.conv1 = nn.Conv2D(
            inplanes, planes, 3, padding=1, stride=stride, bias_attr=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
        self.bn2 = norm_layer(planes)
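        self.downsample = downsample
        self.stride = stride
        # ResNet18/34 are built from BasicBlock, so TA has to be inserted here
        # as well; otherwise TA_ResNet18 would contain no attention at all.
        self.attention = TripletAttention()

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        # Apply TA before the residual addition, as in BottleneckBlock.
        out = self.attention(out)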

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class BottleneckBlock(nn.Layer):

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super(BottleneckBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2D
        width = int(planes * (base_width / 64.)) * groups

        self.conv1 = nn.Conv2D(inplanes, width, 1, bias_attr=False)
        self.bn1 = norm_layer(width)

        self.conv2 = nn.Conv2D(
            width,
            width,
            3,
            padding=dilation,
            stride=stride,
            groups=groups,
            dilation=dilation,
            bias_attr=False)
        self.bn2 = norm_layer(width)

        self.conv3 = nn.Conv2D(
            width, planes * self.expansion, 1, bias_attr=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride
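        # TripletAttention takes no channel argument; a positional value would
        # be read as no_spatial and disable the spatial branch.
        self.attention = TripletAttention()
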
    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.attention(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Layer):

    def __init__(self,
                 block,
                 depth=50,
                 width=64,
                 num_classes=1000,
                 with_pool=True):
        super(ResNet, self).__init__()
        layer_cfg = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3]
        }
        layers = layer_cfg[depth]
        self.groups = 1
        self.base_width = width
        self.num_classes = num_classes
        self.with_pool = with_pool
        self._norm_layer = nn.BatchNorm2D

        self.inplanes = 64
        self.dilation = 1

        self.conv1 = nn.Conv2D(
            3,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias_attr=False)
        self.bn1 = self._norm_layer(self.inplanes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        if with_pool:
            self.avgpool = nn.AdaptiveAvgPool2D((1, 1))

        if num_classes > 0:
            self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2D(
                    self.inplanes,
                    planes * block.expansion,
                    1,
                    stride=stride,
                    bias_attr=False),
                norm_layer(planes * block.expansion), )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.with_pool:
            x = self.avgpool(x)

        if self.num_classes > 0:
            x = paddle.flatten(x, 1)
            x = self.fc(x)

        return x


def _resnet(arch, Block, depth, pretrained, **kwargs):
    model = ResNet(Block, depth, **kwargs)
    if pretrained:
        assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
            arch)
        weight_path = get_weights_path_from_url(model_urls[arch][0],
                                                model_urls[arch][1])

        param = paddle.load(weight_path)
        model.set_dict(param)

    return model


def resnet18(pretrained=False, **kwargs):

    return _resnet('resnet18', BasicBlock, 18, pretrained, **kwargs)


def resnet34(pretrained=False, **kwargs):

    return _resnet('resnet34', BasicBlock, 34, pretrained, **kwargs)


def resnet50(pretrained=False, **kwargs):

    return _resnet('resnet50', BottleneckBlock, 50, pretrained, **kwargs)


def resnet101(pretrained=False, **kwargs):

    return _resnet('resnet101', BottleneckBlock, 101, pretrained, **kwargs)


def resnet152(pretrained=False, **kwargs):

    return _resnet('resnet152', BottleneckBlock, 152, pretrained, **kwargs)


def wide_resnet50_2(pretrained=False, **kwargs):
    kwargs['width'] = 64 * 2
    return _resnet('wide_resnet50_2', BottleneckBlock, 50, pretrained, **kwargs)


def wide_resnet101_2(pretrained=False, **kwargs):

    kwargs['width'] = 64 * 2
    return _resnet('wide_resnet101_2', BottleneckBlock, 101, pretrained,
                   **kwargs)


# Print a layer-by-layer summary of the TA ResNet50 variant.
Ta_res50 = resnet50(num_classes=10)
paddle.Model(Ta_res50).summary((1,3,224,224))

Preparing the Cifar10 data

import paddle.vision.transforms as T
from paddle.vision.datasets import Cifar10
paddle.set_device('gpu')

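# Data preparation
# Normalize runs before ToTensor here, i.e. on HWC pixels in [0, 255], so the
# ImageNet statistics (0.485, 0.229, ...) are scaled by 255 to match that range.
transform = T.Compose([
    T.Resize(size=(224,224)),
    T.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375],data_format='HWC'),
    T.ToTensor()
])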

train_dataset = Cifar10(mode='train', transform=transform)
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
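Normalization statistics have to be expressed in the same range as the pixels they are applied to. The commonly used ImageNet mean `[0.485, 0.456, 0.406]` and std `[0.229, 0.224, 0.225]` are defined for pixels in [0, 1]. The plain-Python sketch below (the helper `normalize` and the sample pixel are illustrative, not part of Paddle) shows what happens when they are applied to pixels in [0, 255], and that scaling the statistics by 255 gives the same result as scaling the pixels down.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize(pixel, mean, std):
    # Per-channel (value - mean) / std for a single RGB pixel.
    return [(p - m) / s for p, m, s in zip(pixel, mean, std)]


pixel_255 = [128.0, 128.0, 128.0]             # mid-grey in [0, 255]
pixel_01 = [p / 255.0 for p in pixel_255]     # the same pixel in [0, 1]

# Pixels and statistics in the same range: values end up close to zero.
matched = normalize(pixel_01, IMAGENET_MEAN, IMAGENET_STD)

# [0, 255] pixels with [0, 1] statistics: values are in the hundreds.
mismatched = normalize(pixel_255, IMAGENET_MEAN, IMAGENET_STD)

# Scaling the statistics by 255 is equivalent to scaling the pixels by 1/255.
mean_255 = [m * 255.0 for m in IMAGENET_MEAN]
std_255 = [s * 255.0 for s in IMAGENET_STD]
rescaled = normalize(pixel_255, mean_255, std_255)

print([round(v, 3) for v in matched])
print([round(v, 1) for v in mismatched])
print([round(v, 3) for v in rescaled])
```

Inputs in the hundreds tend to make early training unstable, so it is worth checking which range the data is in at the point where `Normalize` runs.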

Training ResNet18 on Cifar10

# Model preparation (the variable is named res50, but the model is a ResNet18)
res50 = paddle.vision.models.resnet18(num_classes=10)
res50.train()


# Training setup
epoch_num = 10
optim = paddle.optimizer.Adam(learning_rate=0.001,parameters=res50.parameters())
loss_fn = paddle.nn.CrossEntropyLoss()

res50_loss = []
res50_acc = []

for epoch in range(epoch_num):
    for batch_id, data in enumerate(train_loader):
        inputs = data[0]            
        labels = data[1].unsqueeze(1)            
        predicts = res50(inputs)    

        loss = loss_fn(predicts, labels)
        acc = paddle.metric.accuracy(predicts, labels)
        loss.backward()

        if batch_id % 100 == 0: 
            print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))

        if batch_id % 20 == 0:

            res50_loss.append(loss.numpy())
            res50_acc.append(acc.numpy())

        optim.step()
        optim.clear_grad()

Training TA_ResNet18 on Cifar10

# Model preparation (the variable is named ta_res50, but the model is the ResNet18 built above)
ta_res50 = resnet18(num_classes=10)
ta_res50.train()

# Training setup
epoch_num = 10
optim = paddle.optimizer.Adam(learning_rate=0.001,parameters=ta_res50.parameters())
loss_fn = paddle.nn.CrossEntropyLoss()

ta_res50_loss = []
ta_res50_acc = []

for epoch in range(epoch_num):
    for batch_id, data in enumerate(train_loader):
        inputs = data[0]            
        labels = data[1].unsqueeze(1)            
        predicts = ta_res50(inputs)    

        loss = loss_fn(predicts, labels)
        acc = paddle.metric.accuracy(predicts, labels)
        loss.backward()

        if batch_id % 100 == 0: 
            print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
            
        if batch_id % 20 == 0:
            ta_res50_loss.append(loss.numpy())
            ta_res50_acc.append(acc.numpy())

        optim.step()
        optim.clear_grad()

Plotting the ResNet18 and TA_ResNet18 training curves

import matplotlib.pyplot as plt

plt.figure(figsize=(18,12))
plt.subplot(211)

plt.xlabel('iter')
plt.ylabel('loss')
plt.title('train loss')

x=range(len(ta_res50_loss))
plt.plot(x,res50_loss,color='b',label='ResNet18')
plt.plot(x,ta_res50_loss,color='r',label='ResNet18 + TA')

plt.legend()
plt.grid()

plt.subplot(212)
plt.xlabel('iter')
plt.ylabel('acc')
plt.title('train acc')

x=range(len(ta_res50_acc))
plt.plot(x, res50_acc, color='b',label='ResNet18')
plt.plot(x, ta_res50_acc, color='r',label='ResNet18 + TA')

plt.legend()
plt.grid()

plt.show()
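  • Training summary: the curves show that, with TA added, the model converges faster and reaches higher training accuracy, which indicates that TA attention is effective.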
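Per-batch loss and accuracy are noisy, so two raw curves can be hard to compare by eye. One option, sketched here as an assumption rather than something the original notebook does, is to smooth each list with a trailing moving average before plotting. The helper `moving_average` is hypothetical.

```python
def moving_average(values, window=5):
    # Trailing moving average; the first points average over fewer samples.
    smoothed = []
    running = 0.0
    for i, v in enumerate(values):
        running += float(v)
        if i >= window:
            # Drop the value that just left the window.
            running -= float(values[i - window])
        smoothed.append(running / min(i + 1, window))
    return smoothed


demo = moving_average([1, 2, 3, 4, 5], window=2)
print(demo)  # [1.0, 1.5, 2.5, 3.5, 4.5]
```

The smoothed list has the same length as the input, so it can be passed to `plt.plot` in place of the raw loss or accuracy list.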


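Summary

  • The authors observe that the channel attention in CBAM, although it brings significant performance improvements, does not account for cross-dimension interaction.

  • They therefore propose triplet attention, which captures cross-dimension interaction effectively. Compared with earlier attention methods it has two main advantages:

  • 1. Negligible computational overhead.

  • 2. It stresses the importance of multi-dimensional interaction without dimensionality reduction, and so removes the indirect correspondence between channels and weights.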
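To put the "negligible overhead" point in numbers, the learnable parameters of the module as implemented in this post can be counted by hand: each gate has one bias-free 7 x 7 convolution from 2 channels to 1, plus a one-channel BatchNorm scale and shift. The sketch below does that arithmetic. The comparison with a squeeze-and-excitation style block (two bias-free fully connected layers with reduction ratio 16) is an assumption added for illustration, and both function names are hypothetical.

```python
def triplet_attention_params(kernel_size=7, no_spatial=False):
    # One gate: bias-free conv with 2 input channels and 1 output channel,
    # plus the BatchNorm scale and shift for that single channel.
    conv = 2 * 1 * kernel_size * kernel_size
    bn = 2
    branches = 2 if no_spatial else 3
    return branches * (conv + bn)


def se_block_params(channels, reduction=16):
    # Assumed SE-style block: two bias-free FC layers, C -> C/r -> C.
    hidden = channels // reduction
    return 2 * channels * hidden


ta_params = triplet_attention_params()
se_params = se_block_params(512)
print(ta_params)  # 300
print(se_params)  # 32768
```

The triplet attention count does not depend on the number of channels, whereas the SE-style count grows quadratically with it.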

Special thanks: 仰世而来丶 (this post draws on https://aistudio.baidu.com/aistudio/projectdetail/1884947?channelType=0&channel=0)
