SpectFormer:频率和注意力是你在ViT中所需要的

SpectFormer是一种新的Transformer架构,结合了频域层和多头自注意力层,以提升图像识别任务的性能。研究表明,这种混合设计能更好地捕获图像的全局和局部特征,提高模型的准确性。在ImageNet上,SpectFormer达到了85.7%的Top-1精度,显示了其潜在优势。
摘要由CSDN通过智能技术生成

★★★ 本文源自AlStudio社区精品项目,【点击此处】查看更多精品内容 >>>

摘要

        视觉Transformer已成功地应用于图像识别任务中。 在文本模型中有类似于最初工作的基于多头自注意力的模型(ViT,DeIT),或者最近基于光谱层的模型(Fnet,GFNet,AFNO)。 我们假设光谱和多头注意力都起了主要作用。 我们通过这项工作研究了这一假设,并观察到实际上结合光谱和多头注意层提供了一个更好的Transformer结构。 因此,我们提出了一种新的Transformer的SPectformer结构,它结合了光谱和多头注意力层。 我们相信,所得到的表示允许转换器适当地捕获特征表示,并且它比其他转换器表示产生了更好的性能。 例如,与GFNET-H和LIT相比,它在ImageNet上将Top-1的准确度提高了2%。 Spectformer-S在ImageNet-1K(小型版本的最新技术)上达到84.25%的Top-1精度。 此外,Spectformer-L实现了85.7%,这是变形金刚可比基础版本的最新技术。 我们进一步确保我们在其他场景中获得合理的结果,如在标准数据集上的迁移学习,如CIFAR-10、CIFAR-100、Oxford-IIIT-Flower和Standford CAR数据集。 然后,我们研究了它在MS-COCO数据集的Ofobject检测和实例分割等下游任务中的应用,并观察到Spectformer显示出与最好的骨干相当的一致性性能,并可以进一步优化和改进。 因此,我们认为,结合光谱和注意力层是视觉Transformer所需要的。

1. SpectFormer

1.1 频域层和自注意力层混合建模的合理性验证

        为了验证SpectFormer架构设计的合理性,作者首先对这种混合建模形式进行实验验证。首先对频域层和多头自注意层的不同组合进行性能对比,这些组合包括:(1)全注意力层,(2)全频域层,(3)频域层在前、注意力层在后,即本文提出的SpectFormer架构,(4)注意力层在前,频域层在后,作者将这种设置称为反向SpectFormer。这四种组合设置的性能对比如下图所示。可以看出,先使用频域层对图像提取浅层特征,然后再使用多头自注意力层进行深层次的特征建模效果更好。因此可以证明本文提出的SpectFormer的架构合理性。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8GGknUJk-1687177709373)(https://ai-studio-static-online.cdn.bcebos.com/acd18acc542d4279a1041bd98cd14091ed85a367452a47d29ceb5a9b3aa59efc)]

1.2 SpectFormer架构

        SpectFormer架构的整体框架如下图所示,在下图左侧作者首先放置了DeiT的分层结构图。可以看出,SpectFormer整体也呈现出分层Transformer的设计,相比于DeiT,SpectFormer划分了四个特征提取stage,每个stage中由多个SpectFormer块堆叠而成,每个SpectFormer块又由若干频域块和注意力块构成。除此之外,SpectFormer与其他Transformer结构类似,包括一个线性图像块嵌入层,后跟一个位置编码嵌入层。并且在注意力建模之后设置一个分类头(分类数量为1000的MLP层)。如果我们仔细观察的话可以发现,在四个特征提取stage中,频域层只出现在了较为浅层的stage1和stage2中,而在stage3和stage4中则完全通过自注意力层完成操作。

1.2.1 频域层设计

        频域层的设计目标是从复数角度审视输入图像,来捕获图像的不同频率分量来提取图像中的局部频域特征。这一操作可以通过一个频域门控网络来实现,该网络由一个快速傅里叶变换层(FFT)、一个加权门控层和一个逆傅里叶层(IFFT)构成。首先通过FFT层来将图像的物理空间转到频域空间中,然后使用具有可学习权重参数的门控层来确定每个频率分量的权重,以便适当地捕获图像的线条和边缘,门控层可以使用网络的反向传播进行参数更新。随后IFFT层再将频域空间转回物理空间中。此外作者提到,在频域层中除了使用FFT和IFFT操作,还可以使用小波变换和逆小波变换来实现。

1.2.2 自注意力层设计

        SpectFormer的自注意力层是一个标准的注意力层实现,由层归一化层、多头自注意力层(multiheaded self-attention,MHSA)和MLP层堆叠构成。SpectFormer中的MHSA使用与DeiT相同的结构,即在自注意力建模阶段先使用MHSA进行token融合,然后再通过MLP层进行通道融合。

1.2.3 整合后的SpectFormer层

        为了实现频域层和自注意力层之间的性能平衡,作者引入了一个 α 参数来控制SpectFormer层中频域层和自注意力层之间的比例。如果α=0,代表SpectFormer层完全使用自注意力层实现,此时的SpectFormer等价于DeiT。当 α=12 时,SpectFormer等价于GFNet,即完全使用频域层构成。需要注意的是,所有的注意力层都存在局部特征捕捉不精确的缺点,而所有的频域层都存在无法准确处理全局图像属性或语义特征的缺点。因此,SpectFormer的这种混合设计具有灵活性,可以动态的改变频域层和注意力层的数量,从而有助于准确捕捉全局属性和局部特征。

2. 代码复现

2.1 下载并导入所需的库

!pip install paddlex
%matplotlib inline
import paddle
import paddle.fluid as fluid
import numpy as np
import matplotlib.pyplot as plt
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import paddlex
import itertools
from functools import partial
import math

2.2 创建数据集

train_tfm = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# 使用Cifar10数据集
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm, )
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)

2.3 模型的创建

2.3.1 标签平滑
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
2.3.2 DropPath
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
2.3.3 SpectFormer模型的创建
class Attention(nn.Layer):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape((B, N, 3, self.num_heads, C // self.num_heads)).transpose([2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Mlp(nn.Layer):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class SpectralGatingNetwork(nn.Layer):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = self.create_parameter(shape=(h, w, dim, 2), default_initializer=nn.initializer.TruncatedNormal(std=.02))
        self.w = w
        self.h = h

    def forward(self, x, spatial_size=None):
        B, N, C = x.shape
        if spatial_size is None:
            a = b = int(math.sqrt(N))
        else:
            a, b = spatial_size

        x = x.reshape((B, a, b, C))

        x = paddle.fft.rfft2(x, axes=(1, 2), norm='ortho')
        weight = paddle.as_complex(self.complex_weight)
        x = x * weight
        x = paddle.fft.irfft2(x, s=(a, b), axes=(1, 2), norm='ortho')

        x = x.reshape((B, N, C))

        return x
class Block(nn.Layer):

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = SpectralGatingNetwork(dim, h=h, w=w)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
        return x
class Block_attention(nn.Layer):

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8, num_heads=4):
        super().__init__()
        num_heads= num_heads # 4 for tiny, 6 for small and 12 for base
        self.norm1 = norm_layer(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.attn = Attention(dim,  num_heads=num_heads, qkv_bias=True, qk_scale=False, attn_drop=drop, proj_drop=drop)

    def forward(self, x):
        # x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
def to_2tuple(x):
    return (x, x)


class PatchEmbed(nn.Layer):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose([0, 2, 1])
        return x
class DownLayer(nn.Layer):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=56, dim_in=64, dim_out=128):
        super().__init__()
        self.img_size = img_size
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.proj = nn.Conv2D(dim_in, dim_out, kernel_size=2, stride=2)
        self.num_patches = img_size * img_size // 4

    def forward(self, x):
        B, N, C = x.size()
        x = x.reshape((B, self.img_size, self.img_size, C)).transpose([0, 3, 1, 2])
        x = self.proj(x).transpose([0, 2, 3, 1])
        x = x.reshape((B, -1, self.dim_out))
        return x
class SpectFormer(nn.Layer):

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, head=4, embed_dim=768, depth=12,
                 mlp_ratio=4., representation_size=None, uniform_drop=False,
                 drop_rate=0., drop_path_rate=0., norm_layer=None,
                 dropcls=0):

        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6)

        self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.pos_embed = self.create_parameter(shape=(1, num_patches, embed_dim), default_initializer=nn.initializer.TruncatedNormal(std=.02))
        self.pos_drop = nn.Dropout(p=drop_rate)

        h = img_size // patch_size
        w = h // 2 + 1

        if uniform_drop:
            # print('using uniform droppath with expect rate', drop_path_rate)
            dpr = [drop_path_rate for _ in range(depth)]  # stochastic depth decay rule
        else:
            # print('using linear droppath with expect rate', drop_path_rate * 0.5)
            dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        # dpr = [drop_path_rate for _ in range(depth)]  # stochastic depth decay rule

        alpha=4
        self.blocks = nn.LayerList()
        for i in range(depth):
            if i<alpha:
                layer = Block(dim=embed_dim, mlp_ratio=mlp_ratio,drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer, h=h, w=w)
                self.blocks.append(layer)
            else:
                layer = Block_attention(dim=embed_dim, mlp_ratio=mlp_ratio,drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer, h=h, w=w, num_heads=head)
                self.blocks.append(layer)

        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        if dropcls > 0:
            print('dropout %.2f before classifier' % dropcls)
            self.final_dropout = nn.Dropout(p=dropcls)
        else:
            self.final_dropout = nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        tn = nn.initializer.TruncatedNormal(std=.02)
        zero = nn.initializer.Constant(0.0)
        one = nn.initializer.Constant(1.0)
        if isinstance(m, nn.Linear):
            tn(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zero(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zero(m.bias)
            one(m.weight)

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x).mean(1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.final_dropout(x)
        x = self.head(x)
        return x
num_classes = 10

def spectformer_ti():

    model = SpectFormer(embed_dim=256, head=4, depth=12, num_classes=num_classes)

    return model

def spectformer_xs():

    model = SpectFormer(embed_dim=384, head=6, depth=12, num_classes=num_classes)

    return model

def spectformer_s():

    model = SpectFormer(embed_dim=384, head=6, depth=19, num_classes=num_classes)

    return model

def spectformer_b():

    model = SpectFormer(embed_dim=512, head=8, depth=19, num_classes=num_classes)

    return model
2.3.4 模型的参数
model = spectformer_ti()
paddle.summary(model, (1, 3, 224, 224))

model = spectformer_xs()
paddle.summary(model, (1, 3, 224, 224))

model = spectformer_s()
paddle.summary(model, (1, 3, 224, 224))

model = spectformer_b()
paddle.summary(model, (1, 3, 224, 224))

2.4 训练

learning_rate = 0.001
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

# SpectFormer-Ti
model = spectformer_ti()

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()

        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)

    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))

2.5 结果分析

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

在这里插入图片描述

plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')

在这里插入图片描述

import time
work_path = 'work/model'
model = spectformer_ti()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughout:{}".format(int(len(val_dataset)//(bb - aa))))
Throughout:792
def get_cifar10_labels(labels):
    """返回CIFAR10数据集的文本标签。"""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = spectformer_ti()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

在这里插入图片描述

!pip install interpretdl
import interpretdl as it
work_path = 'work/model'
model = spectformer_ti()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [00:55<00:00, 178.97it/s]

55<00:00, 178.97it/s]

在这里插入图片描述

总结

        本文对传统Transformer的核心架构进行了分析,并且分别探索了频域和多头自注意力层的作用效果。之前的Transformer网络要么只使用全注意力层,要么只使用频域层,在图像特征提取方面存在各自的局限性。本文提出了一种新型的混合Transformer架构,即将这两个方面结合起来,提出了Spectformer模型。Spectformer显示出比先前模型更加稳定的性能。除了在传统的视觉任务上可以获得SOTA性能之外(在ImageNet-1K数据集上实现了85.7%的Top-1识别准确率),作者还认为,将Spectformer应用到一些频域信息更加丰富的领域上(例如遥感和医学图像数据),可能会激发出混合频域层和注意力层更大的潜力。

参考文献

  1. SpectFormer: Frequency and Attention is what you need in a Vision Transformer
  2. badripatro/SpectFormers
  3. Transformer仅有自注意力还不够?微软联合巴斯大学提出频域混合注意力SpectFormer

此文章为搬运
原项目链接

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值