To handle 2D images, we reshape an image of size H×W×C into a sequence of flattened 2D patches of size N×(P^2·C), where (P, P) is the patch size and N = HW/P^2 is the number of patches, which also determines the input sequence length. For example, a 224×224×3 image with P = 16 yields N = 196 patches of 16·16·3 = 768 dimensions each. The Transformer uses a constant hidden size D across all of its layers, so we flatten the patches and map them to D dimensions with a trainable linear projection; the output of this projection is called the patch embedding. The corresponding code performs the flattening directly:
# Transformer: flatten the (h, w) patch grid into a sequence of h*w tokens.
n, h, w, c = x.shape
x = jnp.reshape(x, [n, h * w, c])
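In the reference Flax implementation, this reshape is applied to the output of the patch-embedding convolution: the trainable linear projection is realized as a convolution whose kernel size and stride both equal the patch size, so each non-overlapping patch is projected to D dimensions in one pass. A minimal sketch, with PatchEmbedding, hidden_size, and patch_size as illustrative names rather than code quoted from the text above:
import jax.numpy as jnp
from flax import linen as nn

class PatchEmbedding(nn.Module):
  """Projects (P, P) image patches to D dimensions with a single conv."""
  hidden_size: int  # D, the constant hidden size of the Transformer
  patch_size: int   # P

  @nn.compact
  def __call__(self, x):
    # x: (n, H, W, C). A conv with kernel == stride == P is exactly a
    # linear projection applied independently to each non-overlapping patch.
    x = nn.Conv(
        features=self.hidden_size,
        kernel_size=(self.patch_size, self.patch_size),
        strides=(self.patch_size, self.patch_size),
        padding='VALID')(x)
    n, h, w, c = x.shape  # h = H/P, w = W/P, c = D
    return jnp.reshape(x, [n, h * w, c])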
Similar to BERT's [class] token, we prepend a learnable embedding to the sequence of patch embeddings (z_0^0 = x_class); its state at the output of the Transformer encoder (z_L^0) serves as the image representation y. During both pre-training and fine-tuning, the classification head is attached to z_L^0.
# If we want to add a class token, add it here.
if self.classifier == 'token':
  cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
  cls = jnp.tile(cls, [n, 1, 1])
  x = jnp.concatenate([cls, x], axis=1)
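After the concatenation the sequence length grows from N to N + 1. A sketch of how the class token is read out at the other end, following the same classifier convention (the surrounding Encoder call is elided):
x = Encoder(...)(x)  # run the Transformer encoder (details elided)
if self.classifier == 'token':
  x = x[:, 0]        # z_L^0: the class-token state used as the image representation y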
The classification head is implemented by an MLP with one hidden layer at pre-training time and by a single linear layer at fine-tuning time. Note that the MlpBlock shown further below is not this head: per its own docstring it is the feed-forward block used inside each Transformer encoder layer.
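A hedged sketch of the head itself; representation_size, pre_logits, head, and num_classes follow the naming used in the reference repository, but treat this as illustrative rather than the exact code:
# Pre-training head: one hidden layer with tanh ("pre_logits"), then a linear layer.
# Fine-tuning head: representation_size is None, so only the linear layer remains.
if self.representation_size is not None:
  x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
  x = nn.tanh(x)
x = nn.Dense(
    features=self.num_classes,
    name='head',
    kernel_init=nn.initializers.zeros)(x)
The encoder's feed-forward block is defined as follows: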
from typing import Any, Callable, Optional, Tuple

import jax.numpy as jnp
from flax import linen as nn

# Type aliases used in the signatures below.
PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any
Array = Any


class MlpBlock(nn.Module):
  """Transformer MLP / feed-forward block."""

  mlp_dim: int
  dtype: Dtype = jnp.float32
  out_dim: Optional[int] = None
  dropout_rate: float = 0.1
  kernel_init: Callable[[PRNGKey, Shape, Dtype],
                        Array] = nn.initializers.xavier_uniform()
  bias_init: Callable[[PRNGKey, Shape, Dtype],
                      Array] = nn.initializers.normal(stddev=1e-6)

  @nn.compact
  def __call__(self, inputs, *, deterministic):
    """Applies Transformer MlpBlock module."""
    actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
    x = nn.Dense(
        features=self.mlp_dim,
        dtype=self.dtype,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init)(  # pytype: disable=wrong-arg-types
            inputs)
    x = nn.gelu(x)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
    output = nn.Dense(
        features=actual_out_dim,
        dtype=self.dtype,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init)(  # pytype: disable=wrong-arg-types
            x)
    output = nn.Dropout(
        rate=self.dropout_rate)(
            output, deterministic=deterministic)
    return output
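A quick smoke test of the block; the shapes are illustrative (ViT-Base: D = 768, mlp_dim = 4D = 3072, sequence length N + 1 = 197), and with deterministic=True no dropout RNG is needed:
import jax

block = MlpBlock(mlp_dim=3072)
dummy = jnp.ones((2, 197, 768))  # (batch, N + 1 tokens, D)
params = block.init(jax.random.PRNGKey(0), dummy, deterministic=True)
out = block.apply(params, dummy, deterministic=True)
print(out.shape)  # (2, 197, 768): out_dim defaults to the input dimension D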
Position embeddings are added to the patch embeddings to retain positional information. We use standard learnable 1D position embeddings, since we did not observe significant performance gains from more advanced 2D-aware position embeddings. The resulting sequence of embedding vectors serves as the input to the encoder.
class AddPositionEmbs(nn.Module):
  """Adds learned positional embeddings to the inputs.

  Attributes:
    posemb_init: positional embedding initializer.
  """

  posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]

  @nn.compact
  def __call__(self, inputs):
    """Applies the AddPositionEmbs module.

    The embedding table is a learned parameter of shape
    `(1, seq_len, emb_dim)`, created with `posemb_init` and broadcast
    over the batch dimension.

    Args:
      inputs: Inputs to the layer, shape `(batch_size, seq_len, emb_dim)`.

    Returns:
      Output tensor with shape `(batch_size, seq_len, emb_dim)`.
    """
    # inputs.shape is (batch_size, seq_len, emb_dim).
    assert inputs.ndim == 3, ('Number of dimensions should be 3,'
                              ' but it is: %d' % inputs.ndim)
    pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
    pe = self.param('pos_embedding', self.posemb_init, pos_emb_shape)
    return inputs + pe
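In the encoder, this module is typically applied right after the class token is prepended; the reference repository initializes the table with a small normal distribution (stddev=0.02), which this sketch follows. Here dropout_rate and train stand in for the surrounding Encoder's attributes:
# x: (batch, N + 1, D) token sequence, class token included.
x = AddPositionEmbs(
    posemb_init=nn.initializers.normal(stddev=0.02),
    name='posembed_input')(x)
x = nn.Dropout(rate=dropout_rate)(x, deterministic=not train)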