The shape of tensors in a Transformer for translation

The inputs to both the encoder and the decoder are the token embeddings; the embedding layers themselves are skipped here.

Encoder:
Input: (batch_size, seq_len_enc, embed_dim)
Positional Encoding: (batch_size, max_len, embed_dim), sliced to the first seq_len_enc positions before the addition
Input+Positional Encoding: (batch_size, seq_len_enc, embed_dim)
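A minimal PyTorch sketch of this addition, assuming the positional-encoding table is precomputed up to max_len (random values stand in for the sinusoids here; only the shapes matter):

```python
import torch

batch_size, seq_len_enc, embed_dim, max_len = 2, 7, 512, 100

x = torch.randn(batch_size, seq_len_enc, embed_dim)    # input embeddings
pos_enc = torch.randn(batch_size, max_len, embed_dim)  # stand-in for the sinusoidal table

# Only the first seq_len_enc positions of the table are added to the input.
x = x + pos_enc[:, :seq_len_enc, :]
print(x.shape)  # torch.Size([2, 7, 512])
```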

Q input: (batch_size, seq_len_enc, embed_dim)
K input: (batch_size, seq_len_enc, embed_dim)
V input: (batch_size, seq_len_enc, embed_dim)

Q output: (batch_size, seq_len_enc, embed_dim)
K output: (batch_size, seq_len_enc, embed_dim)
V output: (batch_size, seq_len_enc, embed_dim)

Q output after reshape: (batch_size, seq_len_enc, head_num, head_dim)
K output after reshape: (batch_size, seq_len_enc, head_num, head_dim)
V output after reshape: (batch_size, seq_len_enc, head_num, head_dim)

Q output after reshape and transpose: (batch_size, head_num, seq_len_enc, head_dim)
K output after reshape and transpose: (batch_size, head_num, seq_len_enc, head_dim)
V output after reshape and transpose: (batch_size, head_num, seq_len_enc, head_dim)
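A sketch of the projection, head split, and transpose, assuming separate linear layers w_q/w_k/w_v (hypothetical names) and embed_dim = head_num * head_dim; K and V follow the same path as Q:

```python
import torch
import torch.nn as nn

batch_size, seq_len_enc, embed_dim, head_num = 2, 7, 512, 8
head_dim = embed_dim // head_num  # 64

w_q = nn.Linear(embed_dim, embed_dim)  # hypothetical projection layers
w_k = nn.Linear(embed_dim, embed_dim)
w_v = nn.Linear(embed_dim, embed_dim)

x = torch.randn(batch_size, seq_len_enc, embed_dim)

q = w_q(x)                                        # (batch, seq_len_enc, embed_dim)
q = q.view(batch_size, seq_len_enc, head_num, head_dim)  # split embed_dim into heads
q = q.transpose(1, 2)                             # move the head axis before the sequence axis
print(q.shape)  # torch.Size([2, 8, 7, 64])
```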

Q*K transpose=(batch_size, head_num, seq_len_enc, head_dim)*(batch_size, head_num, head_dim, seq_len_enc)=(batch_size, head_num, seq_len_enc, seq_len_enc)
score=softmax(Q*K transpose/sqrt(head_dim))*V=(batch_size, head_num, seq_len_enc, seq_len_enc)*(batch_size, head_num, seq_len_enc, head_dim)=(batch_size, head_num, seq_len_enc, head_dim)
(The scaling by sqrt(head_dim) and the softmax over the last axis do not change the shape.)
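A sketch of the scaled dot-product attention shapes described above:

```python
import math
import torch

batch_size, head_num, seq_len_enc, head_dim = 2, 8, 7, 64

q = torch.randn(batch_size, head_num, seq_len_enc, head_dim)
k = torch.randn(batch_size, head_num, seq_len_enc, head_dim)
v = torch.randn(batch_size, head_num, seq_len_enc, head_dim)

# Q @ K^T: transpose the last two axes of K.
attn = q @ k.transpose(-2, -1) / math.sqrt(head_dim)  # (batch, head, seq, seq)
attn = attn.softmax(dim=-1)                           # softmax keeps the shape
score = attn @ v                                      # (batch, head, seq, head_dim)
print(score.shape)  # torch.Size([2, 8, 7, 64])
```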

score transpose=(batch_size, seq_len_enc, head_num, head_dim)
score concat=(batch_size, seq_len_enc, embed_dim)
Output of multi-head attention=score after linear=(batch_size, seq_len_enc, embed_dim)
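A sketch of the transpose-concat-linear step, with w_o as a hypothetical name for the output projection:

```python
import torch
import torch.nn as nn

batch_size, head_num, seq_len_enc, head_dim = 2, 8, 7, 64
embed_dim = head_num * head_dim  # 512

score = torch.randn(batch_size, head_num, seq_len_enc, head_dim)

# Transpose back, then flatten the head axes into embed_dim ("concat").
score = score.transpose(1, 2)  # (batch, seq_len_enc, head_num, head_dim)
score = score.contiguous().view(batch_size, seq_len_enc, embed_dim)

w_o = nn.Linear(embed_dim, embed_dim)  # hypothetical output projection
out = w_o(score)
print(out.shape)  # torch.Size([2, 7, 512])
```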

Add1=(Input+Positional Encoding)+Output of multi-head attention=(batch_size, seq_len_enc, embed_dim)+(batch_size, seq_len_enc, embed_dim)=(batch_size, seq_len_enc, embed_dim)
Norm(Add1)=(batch_size, seq_len_enc, embed_dim)
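A sketch of the residual-plus-LayerNorm step; LayerNorm normalizes over the last axis (embed_dim), so the shape is unchanged:

```python
import torch
import torch.nn as nn

batch_size, seq_len_enc, embed_dim = 2, 7, 512

x = torch.randn(batch_size, seq_len_enc, embed_dim)         # input + positional encoding
attn_out = torch.randn(batch_size, seq_len_enc, embed_dim)  # multi-head attention output

norm = nn.LayerNorm(embed_dim)
add1 = norm(x + attn_out)  # residual connection, then LayerNorm
print(add1.shape)  # torch.Size([2, 7, 512])
```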

Feed forward input=Norm(Add1)=(batch_size, seq_len_enc, embed_dim)
Feed forward output1=(batch_size, seq_len_enc, feed_forward_dim)
Feed forward output2=(batch_size, seq_len_enc, embed_dim)
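A sketch of the two-layer feed-forward network, assuming feed_forward_dim = 2048 as in the original Transformer:

```python
import torch
import torch.nn as nn

batch_size, seq_len_enc, embed_dim, feed_forward_dim = 2, 7, 512, 2048

ffn = nn.Sequential(
    nn.Linear(embed_dim, feed_forward_dim),  # output1: (batch, seq, feed_forward_dim)
    nn.ReLU(),
    nn.Linear(feed_forward_dim, embed_dim),  # output2: back to (batch, seq, embed_dim)
)

x = torch.randn(batch_size, seq_len_enc, embed_dim)
print(ffn(x).shape)  # torch.Size([2, 7, 512])
```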

Add2=Norm(Add1)+Feed forward output2=(batch_size, seq_len_enc, embed_dim)+(batch_size, seq_len_enc, embed_dim)=(batch_size, seq_len_enc, embed_dim)
Output of encoder=Norm(Add2)=(batch_size, seq_len_enc, embed_dim)
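As a cross-check (not the step-by-step trace above), PyTorch's built-in encoder layer reproduces the same input and output shapes when batch_first=True:

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048,
                                   batch_first=True)
x = torch.randn(2, 7, 512)
print(layer(x).shape)  # torch.Size([2, 7, 512]) -- shape is preserved end to end
```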

Decoder: 
Input: (batch_size, seq_len_dec, embed_dim)
Positional Encoding: (batch_size, max_len, embed_dim), sliced to the first seq_len_dec positions before the addition
Input+Positional Encoding: (batch_size, seq_len_dec, embed_dim)

Q input: (batch_size, seq_len_dec, embed_dim)
K input: (batch_size, seq_len_dec, embed_dim)
V input: (batch_size, seq_len_dec, embed_dim)

Q output: (batch_size, seq_len_dec, embed_dim)
K output: (batch_size, seq_len_dec, embed_dim)
V output: (batch_size, seq_len_dec, embed_dim)

Q output after reshape: (batch_size, seq_len_dec, head_num, head_dim)
K output after reshape: (batch_size, seq_len_dec, head_num, head_dim)
V output after reshape: (batch_size, seq_len_dec, head_num, head_dim)

Q output after reshape and transpose: (batch_size, head_num, seq_len_dec, head_dim)
K output after reshape and transpose: (batch_size, head_num, seq_len_dec, head_dim)
V output after reshape and transpose: (batch_size, head_num, seq_len_dec, head_dim)

Q*K transpose=(batch_size, head_num, seq_len_dec, head_dim)*(batch_size, head_num, head_dim, seq_len_dec)=(batch_size, head_num, seq_len_dec, seq_len_dec)
score=softmax(mask(Q*K transpose/sqrt(head_dim)))*V=(batch_size, head_num, seq_len_dec, seq_len_dec)*(batch_size, head_num, seq_len_dec, head_dim)=(batch_size, head_num, seq_len_dec, head_dim)
(This self-attention is masked: a causal mask is applied before the softmax so each position attends only to itself and earlier positions; the mask does not change the shape. See the sketch below.)
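A sketch of the causal mask, assuming it is built with torch.triu; position i can only attend to positions <= i:

```python
import math
import torch

batch_size, head_num, seq_len_dec, head_dim = 2, 8, 5, 64

q = torch.randn(batch_size, head_num, seq_len_dec, head_dim)
k = torch.randn(batch_size, head_num, seq_len_dec, head_dim)
v = torch.randn(batch_size, head_num, seq_len_dec, head_dim)

attn = q @ k.transpose(-2, -1) / math.sqrt(head_dim)  # (batch, head, dec, dec)

# Causal mask: True above the diagonal marks the future positions to hide.
mask = torch.triu(torch.ones(seq_len_dec, seq_len_dec, dtype=torch.bool), diagonal=1)
attn = attn.masked_fill(mask, float("-inf"))  # -inf becomes 0 after softmax

score = attn.softmax(dim=-1) @ v
print(score.shape)  # torch.Size([2, 8, 5, 64]) -- masking leaves the shape unchanged
```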

score transpose=(batch_size, seq_len_dec, head_num, head_dim)
score concat=(batch_size, seq_len_dec, embed_dim)
Output of multi-head attention=score after linear=(batch_size, seq_len_dec, embed_dim)

Add1=(Input+Positional Encoding)+Output of multi-head attention=(batch_size, seq_len_dec, embed_dim)+(batch_size, seq_len_dec, embed_dim)=(batch_size, seq_len_dec, embed_dim)
Norm(Add1)=(batch_size, seq_len_dec, embed_dim)

Cross attention (Q comes from Norm(Add1); K and V come from the encoder output):
Q input: (batch_size, seq_len_dec, embed_dim)
K input: (batch_size, seq_len_enc, embed_dim)
V input: (batch_size, seq_len_enc, embed_dim)

Q output: (batch_size, seq_len_dec, embed_dim)
K output: (batch_size, seq_len_enc, embed_dim)
V output: (batch_size, seq_len_enc, embed_dim)

Q output after reshape: (batch_size, seq_len_dec, head_num, head_dim)
K output after reshape: (batch_size, seq_len_enc, head_num, head_dim)
V output after reshape: (batch_size, seq_len_enc, head_num, head_dim)

Q output after reshape and transpose: (batch_size, head_num, seq_len_dec, head_dim)
K output after reshape and transpose: (batch_size, head_num, seq_len_enc, head_dim)
V output after reshape and transpose: (batch_size, head_num, seq_len_enc, head_dim)

Q*K transpose=(batch_size, head_num, seq_len_dec, head_dim)*(batch_size, head_num, head_dim, seq_len_enc)=(batch_size, head_num, seq_len_dec, seq_len_enc)
score=softmax(Q*K transpose/sqrt(head_dim))*V=(batch_size, head_num, seq_len_dec, seq_len_enc)*(batch_size, head_num, seq_len_enc, head_dim)=(batch_size, head_num, seq_len_dec, head_dim)
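A sketch showing how the two sequence lengths mix in cross attention: the attention matrix is seq_len_dec x seq_len_enc, and multiplying by V returns to the decoder length:

```python
import math
import torch

batch_size, head_num, head_dim = 2, 8, 64
seq_len_dec, seq_len_enc = 5, 7

q = torch.randn(batch_size, head_num, seq_len_dec, head_dim)  # from the decoder
k = torch.randn(batch_size, head_num, seq_len_enc, head_dim)  # from the encoder output
v = torch.randn(batch_size, head_num, seq_len_enc, head_dim)  # from the encoder output

attn = (q @ k.transpose(-2, -1) / math.sqrt(head_dim)).softmax(dim=-1)
print(attn.shape)        # torch.Size([2, 8, 5, 7]) -- seq_len_dec x seq_len_enc
print((attn @ v).shape)  # torch.Size([2, 8, 5, 64]) -- back to the decoder length
```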

score transpose=(batch_size, seq_len_dec, head_num, head_dim)
score concat=(batch_size, seq_len_dec, embed_dim)
Output of cross attention=score after linear=(batch_size, seq_len_dec, embed_dim)

Add2=Norm(Add1)+Output of cross attention=(batch_size, seq_len_dec, embed_dim)+(batch_size, seq_len_dec, embed_dim)=(batch_size, seq_len_dec, embed_dim)
Norm(Add2)=(batch_size, seq_len_dec, embed_dim)

Feed forward input=(batch_size, seq_len_dec, embed_dim)
Feed forward output1=(batch_size, seq_len_dec, feed_forward_dim)
Feed forward output2=(batch_size, seq_len_dec, embed_dim)

Add3=Norm(Add2)+Feed forward output2=(batch_size, seq_len_dec, embed_dim)+(batch_size, seq_len_dec, embed_dim)=(batch_size, seq_len_dec, embed_dim)
Output of decoder=Norm(Add3)=(batch_size, seq_len_dec, embed_dim)

Final output layer:
Input of the final linear layer=Output of decoder=(batch_size, seq_len_dec, embed_dim)
Output of the final linear layer=(batch_size, seq_len_dec, trg_vocab_size)
Output of Softmax (over the vocabulary axis)=(batch_size, seq_len_dec, trg_vocab_size)
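A sketch of the output head, with proj as a hypothetical name for the final linear layer:

```python
import torch
import torch.nn as nn

batch_size, seq_len_dec, embed_dim, trg_vocab_size = 2, 5, 512, 10000

proj = nn.Linear(embed_dim, trg_vocab_size)  # hypothetical output projection

x = torch.randn(batch_size, seq_len_dec, embed_dim)  # decoder output
logits = proj(x)                # (batch, seq_len_dec, trg_vocab_size)
probs = logits.softmax(dim=-1)  # softmax over the vocabulary axis
print(probs.shape)  # torch.Size([2, 5, 10000])
```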
