★★★ 本文源自AlStudio社区精品项目,【点击此处】查看更多精品内容 >>>
摘要
视觉Transformer已成功地应用于图像识别任务中。 在文本模型中有类似于最初工作的基于多头自注意力的模型(ViT,DeIT),或者最近基于光谱层的模型(Fnet,GFNet,AFNO)。 我们假设光谱和多头注意力都起了主要作用。 我们通过这项工作研究了这一假设,并观察到实际上结合光谱和多头注意层提供了一个更好的Transformer结构。 因此,我们提出了一种新的Transformer的SPectformer结构,它结合了光谱和多头注意力层。 我们相信,所得到的表示允许转换器适当地捕获特征表示,并且它比其他转换器表示产生了更好的性能。 例如,与GFNET-H和LIT相比,它在ImageNet上将Top-1的准确度提高了2%。 Spectformer-S在ImageNet-1K(小型版本的最新技术)上达到84.25%的Top-1精度。 此外,Spectformer-L实现了85.7%,这是变形金刚可比基础版本的最新技术。 我们进一步确保我们在其他场景中获得合理的结果,如在标准数据集上的迁移学习,如CIFAR-10、CIFAR-100、Oxford-IIIT-Flower和Standford CAR数据集。 然后,我们研究了它在MS-COCO数据集的Ofobject检测和实例分割等下游任务中的应用,并观察到Spectformer显示出与最好的骨干相当的一致性性能,并可以进一步优化和改进。 因此,我们认为,结合光谱和注意力层是视觉Transformer所需要的。
1. SpectFormer
1.1 频域层和自注意力层混合建模的合理性验证
为了验证SpectFormer架构设计的合理性,作者首先对这种混合建模形式进行实验验证。首先对频域层和多头自注意层的不同组合进行性能对比,这些组合包括:(1)全注意力层,(2)全频域层,(3)频域层在前、注意力层在后,即本文提出的SpectFormer架构,(4)注意力层在前,频域层在后,作者将这种设置称为反向SpectFormer。这四种组合设置的性能对比如下图所示。可以看出,先使用频域层对图像提取浅层特征,然后再使用多头自注意力层进行深层次的特征建模效果更好。因此可以证明本文提出的SpectFormer的架构合理性。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8GGknUJk-1687177709373)(https://ai-studio-static-online.cdn.bcebos.com/acd18acc542d4279a1041bd98cd14091ed85a367452a47d29ceb5a9b3aa59efc)]
1.2 SpectFormer架构
SpectFormer架构的整体框架如下图所示,在下图左侧作者首先放置了DeiT的分层结构图。可以看出,SpectFormer整体也呈现出分层Transformer的设计,相比于DeiT,SpectFormer划分了四个特征提取stage,每个stage中由多个SpectFormer块堆叠而成,每个SpectFormer块又由若干频域块和注意力块构成。除此之外,SpectFormer与其他Transformer结构类似,包括一个线性图像块嵌入层,后跟一个位置编码嵌入层。并且在注意力建模之后设置一个分类头(分类数量为1000的MLP层)。如果我们仔细观察的话可以发现,在四个特征提取stage中,频域层只出现在了较为浅层的stage1和stage2中,而在stage3和stage4中则完全通过自注意力层完成操作。
1.2.1 频域层设计
频域层的设计目标是从复数角度审视输入图像,来捕获图像的不同频率分量来提取图像中的局部频域特征。这一操作可以通过一个频域门控网络来实现,该网络由一个快速傅里叶变换层(FFT)、一个加权门控层和一个逆傅里叶层(IFFT)构成。首先通过FFT层来将图像的物理空间转到频域空间中,然后使用具有可学习权重参数的门控层来确定每个频率分量的权重,以便适当地捕获图像的线条和边缘,门控层可以使用网络的反向传播进行参数更新。随后IFFT层再将频域空间转回物理空间中。此外作者提到,在频域层中除了使用FFT和IFFT操作,还可以使用小波变换和逆小波变换来实现。
1.2.2 自注意力层设计
SpectFormer的自注意力层是一个标准的注意力层实现,由层归一化层、多头自注意力层(multiheaded self-attention,MHSA)和MLP层堆叠构成。SpectFormer中的MHSA使用与DeiT相同的结构,即在自注意力建模阶段先使用MHSA进行token融合,然后再通过MLP层进行通道融合。
1.2.3 整合后的SpectFormer层
为了实现频域层和自注意力层之间的性能平衡,作者引入了一个 α 参数来控制SpectFormer层中频域层和自注意力层之间的比例。如果α=0,代表SpectFormer层完全使用自注意力层实现,此时的SpectFormer等价于DeiT。当 α=12 时,SpectFormer等价于GFNet,即完全使用频域层构成。需要注意的是,所有的注意力层都存在局部特征捕捉不精确的缺点,而所有的频域层都存在无法准确处理全局图像属性或语义特征的缺点。因此,SpectFormer的这种混合设计具有灵活性,可以动态的改变频域层和注意力层的数量,从而有助于准确捕捉全局属性和局部特征。
2. 代码复现
2.1 下载并导入所需的库
!pip install paddlex
%matplotlib inline
import paddle
import paddle.fluid as fluid
import numpy as np
import matplotlib.pyplot as plt
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import paddlex
import itertools
from functools import partial
import math
2.2 创建数据集
train_tfm = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
transforms.RandomHorizontalFlip(0.5),
transforms.RandomRotation(20),
paddlex.transforms.MixupImage(),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
test_tfm = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# 使用Cifar10数据集
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm, )
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
2.3 模型的创建
2.3.1 标签平滑
class LabelSmoothingCrossEntropy(nn.Layer):
def __init__(self, smoothing=0.1):
super().__init__()
self.smoothing = smoothing
def forward(self, pred, target):
confidence = 1. - self.smoothing
log_probs = F.log_softmax(pred, axis=-1)
idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
nll_loss = paddle.gather_nd(-log_probs, index=idx)
smooth_loss = paddle.mean(-log_probs, axis=-1)
loss = confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
2.3.2 DropPath
def drop_path(x, drop_prob=0.0, training=False):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob)
shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
return output
class DropPath(nn.Layer):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
2.3.3 SpectFormer模型的创建
class Attention(nn.Layer):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.dim = dim
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape((B, N, 3, self.num_heads, C // self.num_heads)).transpose([2, 0, 3, 1, 4])
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
attn = F.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, N, C))
x = self.proj(x)
x = self.proj_drop(x)
return x
class Mlp(nn.Layer):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class SpectralGatingNetwork(nn.Layer):
def __init__(self, dim, h=14, w=8):
super().__init__()
self.complex_weight = self.create_parameter(shape=(h, w, dim, 2), default_initializer=nn.initializer.TruncatedNormal(std=.02))
self.w = w
self.h = h
def forward(self, x, spatial_size=None):
B, N, C = x.shape
if spatial_size is None:
a = b = int(math.sqrt(N))
else:
a, b = spatial_size
x = x.reshape((B, a, b, C))
x = paddle.fft.rfft2(x, axes=(1, 2), norm='ortho')
weight = paddle.as_complex(self.complex_weight)
x = x * weight
x = paddle.fft.irfft2(x, s=(a, b), axes=(1, 2), norm='ortho')
x = x.reshape((B, N, C))
return x
class Block(nn.Layer):
def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8):
super().__init__()
self.norm1 = norm_layer(dim)
self.filter = SpectralGatingNetwork(dim, h=h, w=w)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
return x
class Block_attention(nn.Layer):
def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8, num_heads=4):
super().__init__()
num_heads= num_heads # 4 for tiny, 6 for small and 12 for base
self.norm1 = norm_layer(dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=True, qk_scale=False, attn_drop=drop, proj_drop=drop)
def forward(self, x):
# x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def to_2tuple(x):
return (x, x)
class PatchEmbed(nn.Layer):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose([0, 2, 1])
return x
class DownLayer(nn.Layer):
""" Image to Patch Embedding
"""
def __init__(self, img_size=56, dim_in=64, dim_out=128):
super().__init__()
self.img_size = img_size
self.dim_in = dim_in
self.dim_out = dim_out
self.proj = nn.Conv2D(dim_in, dim_out, kernel_size=2, stride=2)
self.num_patches = img_size * img_size // 4
def forward(self, x):
B, N, C = x.size()
x = x.reshape((B, self.img_size, self.img_size, C)).transpose([0, 3, 1, 2])
x = self.proj(x).transpose([0, 2, 3, 1])
x = x.reshape((B, -1, self.dim_out))
return x
class SpectFormer(nn.Layer):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, head=4, embed_dim=768, depth=12,
mlp_ratio=4., representation_size=None, uniform_drop=False,
drop_rate=0., drop_path_rate=0., norm_layer=None,
dropcls=0):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6)
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.pos_embed = self.create_parameter(shape=(1, num_patches, embed_dim), default_initializer=nn.initializer.TruncatedNormal(std=.02))
self.pos_drop = nn.Dropout(p=drop_rate)
h = img_size // patch_size
w = h // 2 + 1
if uniform_drop:
# print('using uniform droppath with expect rate', drop_path_rate)
dpr = [drop_path_rate for _ in range(depth)] # stochastic depth decay rule
else:
# print('using linear droppath with expect rate', drop_path_rate * 0.5)
dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
# dpr = [drop_path_rate for _ in range(depth)] # stochastic depth decay rule
alpha=4
self.blocks = nn.LayerList()
for i in range(depth):
if i<alpha:
layer = Block(dim=embed_dim, mlp_ratio=mlp_ratio,drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer, h=h, w=w)
self.blocks.append(layer)
else:
layer = Block_attention(dim=embed_dim, mlp_ratio=mlp_ratio,drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer, h=h, w=w, num_heads=head)
self.blocks.append(layer)
self.norm = norm_layer(embed_dim)
# Representation layer
if representation_size:
self.num_features = representation_size
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(embed_dim, representation_size)),
('act', nn.Tanh())
]))
else:
self.pre_logits = nn.Identity()
# Classifier head
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
if dropcls > 0:
print('dropout %.2f before classifier' % dropcls)
self.final_dropout = nn.Dropout(p=dropcls)
else:
self.final_dropout = nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
tn = nn.initializer.TruncatedNormal(std=.02)
zero = nn.initializer.Constant(0.0)
one = nn.initializer.Constant(1.0)
if isinstance(m, nn.Linear):
tn(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
zero(m.bias)
elif isinstance(m, nn.LayerNorm):
zero(m.bias)
one(m.weight)
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x).mean(1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.final_dropout(x)
x = self.head(x)
return x
num_classes = 10
def spectformer_ti():
model = SpectFormer(embed_dim=256, head=4, depth=12, num_classes=num_classes)
return model
def spectformer_xs():
model = SpectFormer(embed_dim=384, head=6, depth=12, num_classes=num_classes)
return model
def spectformer_s():
model = SpectFormer(embed_dim=384, head=6, depth=19, num_classes=num_classes)
return model
def spectformer_b():
model = SpectFormer(embed_dim=512, head=8, depth=19, num_classes=num_classes)
return model
2.3.4 模型的参数
model = spectformer_ti()
paddle.summary(model, (1, 3, 224, 224))
model = spectformer_xs()
paddle.summary(model, (1, 3, 224, 224))
model = spectformer_s()
paddle.summary(model, (1, 3, 224, 224))
model = spectformer_b()
paddle.summary(model, (1, 3, 224, 224))
2.4 训练
learning_rate = 0.001
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'
# SpectFormer-Ti
model = spectformer_ti()
criterion = LabelSmoothingCrossEntropy()
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)
gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}} # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}} # for recording accuracy
loss_iter = 0
acc_iter = 0
for epoch in range(n_epochs):
# ---------- Training ----------
model.train()
train_num = 0.0
train_loss = 0.0
val_num = 0.0
val_loss = 0.0
accuracy_manager = paddle.metric.Accuracy()
val_accuracy_manager = paddle.metric.Accuracy()
print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
for batch_id, data in enumerate(train_loader):
x_data, y_data = data
labels = paddle.unsqueeze(y_data, axis=1)
logits = model(x_data)
loss = criterion(logits, y_data)
acc = paddle.metric.accuracy(logits, labels)
accuracy_manager.update(acc)
if batch_id % 10 == 0:
loss_record['train']['loss'].append(loss.numpy())
loss_record['train']['iter'].append(loss_iter)
loss_iter += 1
loss.backward()
optimizer.step()
scheduler.step()
optimizer.clear_grad()
train_loss += loss
train_num += len(y_data)
total_train_loss = (train_loss / train_num) * batch_size
train_acc = accuracy_manager.accumulate()
acc_record['train']['acc'].append(train_acc)
acc_record['train']['iter'].append(acc_iter)
acc_iter += 1
# Print the information.
print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))
# ---------- Validation ----------
model.eval()
for batch_id, data in enumerate(val_loader):
x_data, y_data = data
labels = paddle.unsqueeze(y_data, axis=1)
with paddle.no_grad():
logits = model(x_data)
loss = criterion(logits, y_data)
acc = paddle.metric.accuracy(logits, labels)
val_accuracy_manager.update(acc)
val_loss += loss
val_num += len(y_data)
total_val_loss = (val_loss / val_num) * batch_size
loss_record['val']['loss'].append(total_val_loss.numpy())
loss_record['val']['iter'].append(loss_iter)
val_acc = val_accuracy_manager.accumulate()
acc_record['val']['acc'].append(val_acc)
acc_record['val']['iter'].append(acc_iter)
print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))
# ===================save====================
if val_acc > best_acc:
best_acc = val_acc
paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))
print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
2.5 结果分析
def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
''' Plot learning curve of your CNN '''
maxtrain = max(map(float, record['train'][title]))
maxval = max(map(float, record['val'][title]))
ymax = max(maxtrain, maxval) * 1.1
mintrain = min(map(float, record['train'][title]))
minval = min(map(float, record['val'][title]))
ymin = min(mintrain, minval) * 0.9
total_steps = len(record['train'][title])
x_1 = list(map(int, record['train']['iter']))
x_2 = list(map(int, record['val']['iter']))
figure(figsize=(10, 6))
plt.plot(x_1, record['train'][title], c='tab:red', label='train')
plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
plt.ylim(ymin, ymax)
plt.xlabel('Training steps')
plt.ylabel(ylabel)
plt.title('Learning curve of {}'.format(title))
plt.legend()
plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')
plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')
import time
work_path = 'work/model'
model = spectformer_ti()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):
x_data, y_data = data
labels = paddle.unsqueeze(y_data, axis=1)
with paddle.no_grad():
logits = model(x_data)
bb = time.time()
print("Throughout:{}".format(int(len(val_dataset)//(bb - aa))))
Throughout:792
def get_cifar10_labels(labels):
"""返回CIFAR10数据集的文本标签。"""
text_labels = [
'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
'horse', 'ship', 'truck']
return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
"""Plot a list of images."""
figsize = (num_cols * scale, num_rows * scale)
_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if paddle.is_tensor(img):
ax.imshow(img.numpy())
else:
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if pred or gt:
ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = spectformer_ti()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
!pip install interpretdl
import interpretdl as it
work_path = 'work/model'
model = spectformer_ti()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [00:55<00:00, 178.97it/s]
55<00:00, 178.97it/s]
总结
本文对传统Transformer的核心架构进行了分析,并且分别探索了频域和多头自注意力层的作用效果。之前的Transformer网络要么只使用全注意力层,要么只使用频域层,在图像特征提取方面存在各自的局限性。本文提出了一种新型的混合Transformer架构,即将这两个方面结合起来,提出了Spectformer模型。Spectformer显示出比先前模型更加稳定的性能。除了在传统的视觉任务上可以获得SOTA性能之外(在ImageNet-1K数据集上实现了85.7%的Top-1识别准确率),作者还认为,将Spectformer应用到一些频域信息更加丰富的领域上(例如遥感和医学图像数据),可能会激发出混合频域层和注意力层更大的潜力。
参考文献
- SpectFormer: Frequency and Attention is what you need in a Vision Transformer
- badripatro/SpectFormers
- Transformer仅有自注意力还不够?微软联合巴斯大学提出频域混合注意力SpectFormer
此文章为搬运
原项目链接