SPANet: Spatial Pyramid Attention for Deep Convolutional Neural Networks


Abstract

Attention mechanisms have achieved great success in computer vision. However, the global average pooling commonly used in many implementations aggregates a 3D feature map into a 1D attention vector, causing a significant loss of structural information during attention learning. In this paper, we propose a novel Spatial Pyramid Attention Network (SPANet), which exploits both structural information and channel relationships for better feature representation. SPANet enhances a base network by laterally adding spatial pyramid attention (SPA) blocks. Rethinking the design of the self-attention mechanism, we further propose three topologies of attention-path connections for SPANet, which can be flexibly applied to various CNN architectures. SPANet is conceptually simple but practically powerful: it leverages structural regularization and structural information for better learning capacity. We comprehensively evaluate SPANet on four benchmark datasets across different vision tasks. Experimental results show that SPANet significantly improves recognition accuracy without adding computational overhead. On the ImageNet 2012 benchmark with a ResNet50 backbone, SPANet achieves a 1.6% improvement in Top-1 classification accuracy and outperforms SENet and other attention methods. SPANet also markedly improves object detection with negligible extra computation: applied to RetinaNet with a ResNet50 backbone, it improves the baseline by 2.3 mAP and outperforms the SENet- and GCNet-enhanced models by 1.1 mAP and 1.7 mAP, respectively.

1. SPANet

The journal version differs from the conference version (SPANet: Spatial Pyramid Attention Network) in how the pyramid features are combined. The conference version flattens the pyramid features into 1D vectors, concatenates them, and feeds the result into an SE-like weight-generation network, which destroys the spatial structure of the pyramid features. The journal version therefore first interpolates the pyramid feature maps to a common size, combines them with a learnable weighted sum (which better preserves spatial structural information), and feeds the fused map into an SE-like weight-generation network to modulate the channels. In formulas:

$$
\mathbf{S}=w^{\top}\left[\rho_{\text{fine}}(x_l),\ \rho_{\text{coarse}}(x_l),\ \rho_{\text{global}}(x_l)\right],
\qquad
\mathbf{T}=\mathbf{U}\big(\sigma(\tau(\tau(\mathbf{S})))\big),
$$

where $\rho_{\text{fine}}$, $\rho_{\text{coarse}}$, $\rho_{\text{global}}$ denote pyramid pooling at the three scales (followed by interpolation to a common size), $w$ the learnable fusion weights, $\tau$ a 1×1 convolution transform, $\sigma$ the sigmoid, and $\mathbf{U}$ upsampling back to the input resolution.
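The fusion step can be sketched with plain NumPy (a minimal illustration of the idea, not the Paddle implementation below; the 4×4/2×2/1×1 pool sizes and nearest-neighbor upsampling mirror the SPA block, and `w` stands in for the learnable fusion weights):

```python
import numpy as np

def adaptive_avg_pool(x, out):
    """Average-pool a (C, H, W) map to (C, out, out); assumes H, W divisible by out."""
    c, h, w = x.shape
    return x.reshape(c, out, h // out, out, w // out).mean(axis=(2, 4))

def upsample_nearest(x, factor):
    """Nearest-neighbor upsampling of a (C, H, W) map by an integer factor."""
    return x.repeat(factor, axis=1).repeat(factor, axis=2)

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 8, 8))   # one (C, H, W) feature map

# Pyramid pooling: fine (4x4), coarse (2x2 -> 4x4), global (1x1 -> 4x4)
p_fine   = adaptive_avg_pool(x, 4)
p_coarse = upsample_nearest(adaptive_avg_pool(x, 2), 2)
p_global = upsample_nearest(adaptive_avg_pool(x, 1), 4)

# Learnable fusion weights w (initialized to 1 in the SPA block)
w = np.ones(3)
S = w[0] * p_fine + w[1] * p_coarse + w[2] * p_global
print(S.shape)   # (64, 4, 4): spatial structure is preserved
```

Unlike flattening and concatenation, the weighted sum keeps the fused tensor as a 4×4 spatial map, so the SE-like network that follows still sees spatial structure.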

2. Code Reproduction

2.1 Install and Import the Required Packages

!pip install paddlex
%matplotlib inline
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import paddle
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
from paddle.vision.datasets import Cifar10
from paddle.io import DataLoader
import paddlex

2.2 Create the Dataset

train_tfm = transforms.Compose([
    transforms.Resize((130, 130)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomResizedCrop(128, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the Cifar10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)

2.3 Label Smoothing

class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
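The loss above computes $(1-\varepsilon)\cdot\mathrm{NLL} + \varepsilon\cdot\mathrm{mean}(-\log p)$. A quick NumPy check of the same formula on hypothetical toy logits (a sketch, independent of Paddle):

```python
import numpy as np

def label_smoothing_ce(logits, target, smoothing=0.1):
    """Cross-entropy with label smoothing, mirroring the math of the class above."""
    # numerically stable log-softmax
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(target)), target]   # per-sample NLL
    smooth = -log_probs.mean(axis=-1)                  # uniform-target term
    return ((1.0 - smoothing) * nll + smoothing * smooth).mean()

logits = np.array([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
target = np.array([0, 1])
loss = label_smoothing_ce(logits, target)
# With smoothing=0 this reduces to plain cross-entropy
plain = label_smoothing_ce(logits, target, smoothing=0.0)
```

Setting `smoothing=0.1` redistributes a tenth of the target mass uniformly over all classes, which discourages overconfident predictions.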

2.4 AlexNet-SPA

2.4.1 SPA
class SPA_Pro(nn.Layer):
    def __init__(self, channel, reduction=16):
        super().__init__()
        # Pyramid pooling at three scales: global (1x1), coarse (2x2), fine (4x4)
        self.avg_pool1 = nn.AdaptiveAvgPool2D(1)
        self.avg_pool2 = nn.AdaptiveAvgPool2D(2)
        self.avg_pool4 = nn.AdaptiveAvgPool2D(4)
        # Learnable fusion weights, one per pyramid level
        self.weight = self.create_parameter((1, 3, 1, 1, 1), default_initializer=nn.initializer.Constant(1.0))
        # SE-style bottleneck that turns the fused 4x4 map into channel attention
        self.transform = nn.Sequential(
            nn.Conv2D(channel, channel // reduction, 1, bias_attr=False),
            nn.BatchNorm2D(channel // reduction),
            nn.ReLU(),
            nn.Conv2D(channel // reduction, channel, 1, bias_attr=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        y1 = self.avg_pool1(x)  # (b, c, 1, 1)
        y2 = self.avg_pool2(x)  # (b, c, 2, 2)
        y4 = self.avg_pool4(x)  # (b, c, 4, 4)
        # Upsample the coarser levels to 4x4 and stack along a new pyramid axis
        y = paddle.concat(
            [y4.unsqueeze(1),
             F.interpolate(y2, scale_factor=2).unsqueeze(1),
             F.interpolate(y1, scale_factor=4).unsqueeze(1)],
            axis=1
        )
        # Learnable weighted sum over the pyramid axis keeps the 4x4 spatial structure
        y = (y * self.weight).sum(axis=1, keepdim=False)
        y = self.transform(y)
        # Broadcast the 4x4 attention map back to the input resolution
        y = F.interpolate(y, size=x.shape[2:])

        return x * y
model = SPA_Pro(64)
paddle.summary(model, (1, 64, 224, 224))


-------------------------------------------------------------------------------
   Layer (type)         Input Shape          Output Shape         Param #    
===============================================================================
AdaptiveAvgPool2D-1 [[1, 64, 224, 224]]     [1, 64, 1, 1]            0       
AdaptiveAvgPool2D-2 [[1, 64, 224, 224]]     [1, 64, 2, 2]            0       
AdaptiveAvgPool2D-3 [[1, 64, 224, 224]]     [1, 64, 4, 4]            0       
     Conv2D-1         [[1, 64, 4, 4]]        [1, 4, 4, 4]           256      
   BatchNorm2D-1       [[1, 4, 4, 4]]        [1, 4, 4, 4]           16       
      ReLU-5           [[1, 4, 4, 4]]        [1, 4, 4, 4]            0       
     Conv2D-2          [[1, 4, 4, 4]]       [1, 64, 4, 4]           256      
     Sigmoid-2        [[1, 64, 4, 4]]       [1, 64, 4, 4]            0       
===============================================================================
Total params: 528
Trainable params: 512
Non-trainable params: 16
-------------------------------------------------------------------------------
Input size (MB): 12.25
Forward/backward pass size (MB): 0.03
Params size (MB): 0.00
Estimated Total Size (MB): 12.28
-------------------------------------------------------------------------------






{'total_params': 528, 'trainable_params': 512}
2.4.2 AlexNet-SPA
class AlexNet_SPA_Pro(nn.Layer):
    def __init__(self,num_classes=10):
        super().__init__()
        self.features=nn.Sequential(
            nn.Conv2D(3,48, kernel_size=11, stride=4, padding=11//2),
            SPA_Pro(48),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(48,128, kernel_size=5, padding=2),
            SPA_Pro(128),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
            nn.Conv2D(128, 192,kernel_size=3,stride=1,padding=1),
            SPA_Pro(192),
            nn.ReLU(),
            nn.Conv2D(192,192,kernel_size=3,stride=1,padding=1),
            SPA_Pro(192),
            nn.ReLU(),
            nn.Conv2D(192,128,kernel_size=3,stride=1,padding=1),
            SPA_Pro(128),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=3,stride=2),
        )
        self.classifier=nn.Sequential(
            nn.Linear(3 * 3 * 128,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,2048),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2048,num_classes),
        )
 
 
    def forward(self,x):
        x = self.features(x)
        x = paddle.flatten(x, 1)
        x=self.classifier(x)
 
        return x
model = AlexNet_SPA_Pro(num_classes=10)
paddle.summary(model, (1, 3, 128, 128))

2.5 Training

learning_rate = 0.001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

model = AlexNet_SPA_Pro(num_classes=10)

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)
    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
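The CosineAnnealingDecay schedule used above follows $\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_{\max}-\eta_{\min})(1+\cos(\pi t/T))$. A small NumPy sketch with the hyperparameters from this section ($\eta_{\min}$ assumed to be 0):

```python
import numpy as np

def cosine_lr(step, base_lr=0.001, t_max=50000 // 128 * 50, eta_min=0.0):
    """Cosine-annealed learning rate at a given step."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + np.cos(np.pi * step / t_max))

t_max = 50000 // 128 * 50   # steps per epoch * epochs, as in the scheduler above
lrs = [cosine_lr(t, t_max=t_max) for t in (0, t_max // 2, t_max)]
print(lrs)   # starts at base_lr, halfway at base_lr/2, decays to ~0 at the end
```

Because `scheduler.step()` is called once per batch, `T_max` is set in units of batches rather than epochs.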

2.6 Experimental Results

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

(Figure: training/validation loss learning curve)

plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')

(Figure: training/validation accuracy learning curve)

import time
work_path = 'work/model'
model = AlexNet_SPA_Pro(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:1880
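The wall-clock measurement above includes data loading and first-batch warmup, which can bias the number downward. A slightly more careful harness (a sketch; `run_batch` is a hypothetical stand-in for the `model(x_data)` call) times only steady-state iterations:

```python
import time

def measure_throughput(run_batch, n_items_per_batch, n_warmup=3, n_iters=20):
    """Time steady-state batches only: skip warmup, then average over n_iters."""
    for _ in range(n_warmup):          # warm up caches / lazy initialization
        run_batch()
    start = time.perf_counter()
    for _ in range(n_iters):
        run_batch()
    elapsed = time.perf_counter() - start
    return n_items_per_batch * n_iters / elapsed

# Hypothetical stand-in for one forward pass on a batch of 128 images
tp = measure_throughput(lambda: sum(i * i for i in range(10000)), 128)
```

For GPU inference one would additionally synchronize the device before reading the clock, since kernel launches are asynchronous.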
def get_cifar10_labels(labels):  
    """Return the text labels of the CIFAR10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):  
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = AlexNet_SPA_Pro(num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 128, 128, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

![请添加图片描述](https://img-blog.csdnimg.cn/69f0f96ff1324dacb7a62671ff85cba6.png)

3. Comparative Experimental Results

| Model | Train Acc | Val Acc | Parameters |
| --- | --- | --- | --- |
| AlexNet w/o SPA | 0.7785 | 0.80489 | 7524042 |
| AlexNet w/ SPA | 0.8524 | 0.84968 | 7673642 |
| AlexNet w/ SPA-Pro | 0.8473 | 0.84909 | 7537829 |

Summary

The journal version adjusts how the pyramid features are combined: it is more parameter-efficient, achieves comparable performance, and reduces overfitting.

References

Spatial Pyramid Attention for Deep Convolutional Neural Networks
ma-xu/SPANet_TMM
SPANet: Spatial Pyramid Attention Network
