As the title of the original Transformer paper points out, attention is all you need.
Once we understand attention, the Transformer itself is just some stacked blocks of code.
import mindspore as ms
from mindspore import nn, ops

class Attention(nn.Cell):
    # multi-head self-attention implemented with MindSpore
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)
        # one Dense layer produces q, k and v at once
        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0 - attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0 - keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x)
        # (b, n, 3*c) -> (3, b, num_heads, n, head_dim)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        # scaled dot-product attention scores, then softmax and dropout
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = self.attn_matmul_v(attn, v)
        # merge the heads back into a single embedding of size c
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)
        return out
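As a quick sanity check, this is how I would call the cell on a dummy batch; the batch size, token count and embedding size are values I made up, roughly matching ViT-Base:

x = ops.ones((2, 197, 768), ms.float32)   # 2 samples, 197 tokens, embedding dim 768
attn = Attention(dim=768, num_heads=8)
y = attn(x)
print(y.shape)                            # (2, 197, 768): attention preserves the sequence shape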
Attention: if you want a solid understanding of the Transformer, do not rely on this blog alone.
This is simply how an author who is new to deep learning understands the concepts.
You define some heads,
you rescale the attention scores, you apply dropout, and you project the output.
Bahdanau attention was an earlier mechanism for handling sequences.
q, k, v are the query, key, and value.
Each input is multiplied by Wq, Wk, and Wv respectively. Then, since there are multiple heads, we also have Wo as a parameter that decides how the heads' outputs are combined.
These weight matrices are what the model learns. This is also part of why training big models is feasible:
each attention head can be computed in parallel.
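Written in the notation of the original paper, each head computes scaled dot-product attention on its own projections of the input X, and W^O mixes the concatenated heads:

$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

$$\mathrm{head}_i=\mathrm{Attention}(XW_i^{Q},\,XW_i^{K},\,XW_i^{V}),\qquad \mathrm{MultiHead}(X)=\mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_h)\,W^{O}$$

In the code above, self.scale is the 1/sqrt(d_k) factor and self.qkv is the three projections fused into one Dense layer.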
Sadly there is no diagram of the attention computation attached here, but a picture alone would not be much help for grasping what these matrices do.
However, I do attach a diagram of the overall Transformer architecture here.
What stands out in it?
Residual connections, layer normalization, and positional encoding, which turns each token's position into a vector (a small sketch follows below).
Nx means repetition, so the Transformer can be deep. The encoder and the decoder interact at certain points (the encoder output feeds into the decoder). Interesting, but it is hard to see at first why the connection sits exactly there.
All in all, everything revolves around the Multi-Head Attention block!
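Since "position as a vector" sounds abstract, here is a small NumPy sketch of the sinusoidal positional encoding from the original paper (it assumes an even dim; note that the ViT code later in this post uses a learned pos_embedding instead):

import numpy as np

def sinusoidal_position_encoding(seq_len, dim):
    # pe[pos, 2i] = sin(pos / 10000^(2i/dim)), pe[pos, 2i+1] = cos(pos / 10000^(2i/dim))
    pos = np.arange(seq_len)[:, None]
    i = np.arange(dim // 2)[None, :]
    angle = pos / np.power(10000, 2 * i / dim)
    pe = np.zeros((seq_len, dim))
    pe[:, 0::2] = np.sin(angle)
    pe[:, 1::2] = np.cos(angle)
    return pe   # shape (seq_len, dim); added element-wise to the token embeddings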
class TransformerEncoder(nn.Cell):
    # a stack of pre-norm attention + feed-forward blocks with residual connections
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []
        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)
            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)
            layers.append(
                nn.SequentialCell([
                    # x + Attention(LayerNorm(x)), then x + FeedForward(LayerNorm(x))
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        return self.layers(x)
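The encoder above refers to FeedForward and ResidualCell, which I did not paste. Here is a minimal sketch of both in the same MindSpore style; the bodies are my reconstruction from the keyword arguments used above, so treat the details as an assumption:

class FeedForward(nn.Cell):
    # two Dense layers with an activation and dropout in between
    def __init__(self, in_features, hidden_features, activation=nn.GELU, keep_prob=1.0):
        super(FeedForward, self).__init__()
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, in_features)
        self.dropout = nn.Dropout(p=1.0 - keep_prob)

    def construct(self, x):
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dense2(x)
        return self.dropout(x)

class ResidualCell(nn.Cell):
    # wraps a cell so that its output is added back to its input: x + cell(x)
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        return self.cell(x) + x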
Remember that everything is connected residually; ResNet already showed how effective that is.
Next, the input picture is cut into patches of size 16*16, giving 14*14 patches in total for a single picture.
Then the cls token and a positional embedding are added, then pos_dropout (a regularization trick), then the sequence goes through the encoder and a final normalization, and finally into an MLP head (Dense). A rough sketch of this embedding step follows below.
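This sketch assumes a 224*224 RGB input, 16*16 patches and a learned positional embedding; the class name PatchAndPosEmbedding and the zero initializations are my own choices, not the tutorial's:

class PatchAndPosEmbedding(nn.Cell):
    # image -> 16x16 patches -> linear projection, prepend cls token, add positional embedding
    def __init__(self, image_size=224, patch_size=16, in_channels=3, dim=768, keep_prob=1.0):
        super(PatchAndPosEmbedding, self).__init__()
        num_patches = (image_size // patch_size) ** 2          # 14*14 = 196
        # a conv with kernel = stride = patch_size is one linear projection per patch
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch_size,
                              stride=patch_size, has_bias=True)
        self.cls_token = ms.Parameter(ops.zeros((1, 1, dim), ms.float32))
        self.pos_embedding = ms.Parameter(ops.zeros((1, num_patches + 1, dim), ms.float32))
        self.pos_dropout = nn.Dropout(p=1.0 - keep_prob)

    def construct(self, x):
        b = x.shape[0]
        x = self.proj(x)                            # (b, dim, 14, 14)
        x = ops.reshape(x, (b, x.shape[1], -1))     # (b, dim, 196)
        x = ops.transpose(x, (0, 2, 1))             # (b, 196, dim)
        cls = ops.tile(self.cls_token, (b, 1, 1))   # one cls token per sample
        x = ops.concat((cls, x), axis=1)            # (b, 197, dim)
        x = x + self.pos_embedding
        return self.pos_dropout(x)                  # this sequence feeds the TransformerEncoder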
For model training, we define the hyperparameters,
we define the model, we define the learning-rate schedule,
we define the optimizer and the cross-entropy loss,
we set checkpoints, and we train. A skeleton of these steps is sketched below.
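For concreteness, here is a rough MindSpore 2.x-style training skeleton for those steps; the epoch count, learning-rate values, network and train_dataset are placeholders I made up, not the actual script:

from mindspore import nn
from mindspore.train import Model, ModelCheckpoint, CheckpointConfig, LossMonitor

epochs, step_per_epoch = 10, 100          # placeholder hyperparameters
network = ...                             # the ViT assembled from the cells above (omitted)
train_dataset = ...                       # an image-classification dataset pipeline (omitted)

# cosine-decayed learning rate, one value per training step
lr = nn.cosine_decay_lr(min_lr=1e-6, max_lr=1e-3,
                        total_step=epochs * step_per_epoch,
                        step_per_epoch=step_per_epoch, decay_epoch=epochs)
loss = nn.CrossEntropyLoss()
opt = nn.Adam(network.trainable_params(), learning_rate=lr)

# checkpoint callback + high-level Model wrapper, then train
ckpt_cb = ModelCheckpoint(prefix="vit",
                          config=CheckpointConfig(save_checkpoint_steps=step_per_epoch))
model = Model(network, loss_fn=loss, optimizer=opt, metrics={"acc"})
model.train(epochs, train_dataset, callbacks=[ckpt_cb, LossMonitor()])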
MORE ABOUT DECODER
masked multi-head self-attention -> addnorm1 -> multi-head attention over the encoder outputs -> addnorm2 -> position-wise FFN -> addnorm3
For the whole decoder:
N such blocks stacked, plus a final Dense layer (see the sketch after the code below).
To finish, I will copy mli's (d2l) TransformerDecoderBlock:
# PyTorch: forward of the decoder block (assumes import torch and the submodules
# attention1/attention2/addnorm1-3/ffn created in __init__; self.i is this block's index)
def forward(self, X, state):
    enc_outputs, enc_valid_lens = state[0], state[1]
    # at prediction time, state[2][self.i] caches the tokens already decoded by this block
    if state[2][self.i] is None:
        key_values = X
    else:
        key_values = torch.cat((state[2][self.i], X), dim=1)
    state[2][self.i] = key_values
    if self.training:
        batch_size, num_steps, _ = X.shape
        # every row is [1, 2, ..., num_steps]: token t may only attend to tokens <= t
        dec_valid_lens = torch.arange(
            1, num_steps + 1, device=X.device).repeat(batch_size, 1)
    else:
        dec_valid_lens = None
    # masked self-attention over the decoder's own sequence
    X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
    Y = self.addnorm1(X, X2)
    # encoder-decoder (cross) attention over the encoder outputs
    Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
    Z = self.addnorm2(Y, Y2)
    return self.addnorm3(Z, self.ffn(Z)), state
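To connect this back to "N blocks + Dense": the state list is created once from the encoder outputs and then threaded through every block; state[2] holds one cache slot per block for prediction. A rough sketch of the surrounding loop (variable names are mine, not a drop-in from d2l):

# sketch: how a full decoder threads state through the stacked blocks
state = [enc_outputs, enc_valid_lens, [None] * num_layers]   # one cache slot per block
for blk in blocks:                # blocks = the stacked TransformerDecoderBlock modules
    X, state = blk(X, state)      # each block updates state[2][blk.i] with its key_values
logits = dense(X)                 # a final Dense layer maps to vocabulary scores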