Image classification with EANet (External Attention Transformer)
Introduction
This example implements the EANet model for image classification and demonstrates it on the CIFAR-100 dataset. EANet introduces a novel attention mechanism called external attention, based on two external, small, learnable, and shared memories, which can be implemented easily using just two cascaded linear layers and two normalization layers. It conveniently replaces the self-attention used in existing architectures. External attention has linear complexity, as it only implicitly considers the correlations between all samples. This example requires TensorFlow 2.5 or higher, as well as the TensorFlow Addons package, which can be installed with the following command:
pip install -U tensorflow-addons
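Before diving into the full model, the core idea can be sketched in plain NumPy. This is a minimal, illustrative version of external attention (not the example's Keras implementation): the input tokens attend to two small shared memories `m_k` and `m_v` via two cascaded linear maps, with a double normalization (softmax over the token axis, then an l1 normalization over the memory axis). All names and shapes here are illustrative assumptions.

```python
import numpy as np

def external_attention(x, m_k, m_v):
    """Sketch of external attention.

    x:   (num_tokens, dim)      input token features
    m_k: (memory_size, dim)     external key memory (shared, learnable)
    m_v: (memory_size, dim)     external value memory (shared, learnable)
    """
    # First linear map: attention scores against the external key memory.
    attn = x @ m_k.T  # (num_tokens, memory_size)
    # Double normalization: softmax over tokens, then l1 over the memory axis.
    attn = np.exp(attn - attn.max(axis=0, keepdims=True))
    attn = attn / attn.sum(axis=0, keepdims=True)
    attn = attn / (attn.sum(axis=1, keepdims=True) + 1e-9)
    # Second linear map: aggregate the external value memory.
    return attn @ m_v  # (num_tokens, dim)

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 64))
m_k = rng.standard_normal((8, 64))
m_v = rng.standard_normal((8, 64))
out = external_attention(x, m_k, m_v)
print(out.shape)  # (16, 64)
```

Because the memory size is a small constant independent of the number of tokens, the cost is linear in the token count, unlike self-attention's quadratic cost.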
Setup
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
Prepare the data
num_classes = 100
input_shape = (32, 32, 3)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 100)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 100)
Configure the hyperparameters
weight_decay = 0.0001
learning_rate = 0.001
label_smoothing = 0.1
validation_split = 0.2
batch_size = 128
num_epochs = 50
patch_size = 2 # Size of the patches to be extracted from the input images.
num_patches = (input_shape[0] // patch_size) ** 2  # Number of patches.
embedding_dim = 64 # Number of hidden units.
mlp_dim = 64
dim_coefficient = 4
num_heads = 4
attention_dropout = 0.2
projection_dropout = 0.2
num_transformer_blocks = 8 # Number of repetitions of the transformer layer
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
Patch size: 2 X 2 = 4
Patches per image: 256
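The patch arithmetic above can be sanity-checked directly: splitting a 32x32x3 image into non-overlapping 2x2 patches yields (32 // 2) ** 2 = 256 patches, each flattening to 2 * 2 * 3 = 12 values. The sketch below uses a plain NumPy reshape rather than the TensorFlow op the model uses.

```python
import numpy as np

# Toy image with the CIFAR shape used in this example.
image = np.arange(32 * 32 * 3).reshape(32, 32, 3)
p = 2  # patch_size

# Reshape into a grid of non-overlapping p x p patches, then flatten each.
patches = image.reshape(32 // p, p, 32 // p, p, 3).transpose(0, 2, 1, 3, 4)
patches = patches.reshape(-1, p * p * 3)
print(patches.shape)  # (256, 12)
```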
Use data augmentation
data_augmentation = keras.Sequential(
[
layers.Normalization(),
layers.RandomFlip("horizontal"),
layers.RandomRotation(factor=0.1),
layers.RandomContrast(factor=0.1),
layers.RandomZoom(height_factor=0.2, width_factor=0.2),
],
name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
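What `adapt()` does for the `Normalization` layer can be sketched without TensorFlow: the layer computes the mean and variance of the training data once, stores them, and then standardizes every input with those fixed statistics. The toy data below is a made-up illustration.

```python
import numpy as np

# Hypothetical 1-feature training data.
x_train_toy = np.array([[0.0], [2.0], [4.0]])

# adapt(): compute and store the training statistics.
mean = x_train_toy.mean(axis=0)
var = x_train_toy.var(axis=0)

# Calling the layer: standardize with the stored statistics.
normalized = (x_train_toy - mean) / np.sqrt(var)
print(normalized.flatten())  # zero mean, unit variance
```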
Implement the patch extraction and encoding layer
class PatchExtract(layers.Layer):
def __init__(self, patch_size, **kwargs):
super(PatchExtract, self).__init__(**kwargs)
self.patch_size = patch_size