Transformer代码详解

本教程适用于对Transformer理论有一定理解的朋友。理论部分请看其他教程,本文详解代码。

Embedding

Embedding很好理解,vocab表示词表大小,d_model表示embedding大小。至于返回值为什么乘上sqrt(self.d_model) 目前还不是很理解。

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

位置编码

因为子注意力机制在算注意力权重的时候,并没有考虑到词语前后关系,而是考虑了整体的上下文,因此需要加入位置编码。主要的数学公式如下所示:
在这里插入图片描述
pos可以理解成每个字符在一句话的位置,i可以理解为在embedding向量的位置。这里假设betch_size设置为1,那么一句话的矩阵表达就是:[seq_len, d_model]。在我们一开始得到了embedding后的矩阵,需要再加上PositionalEncoding矩阵,它的维度也是[seq_len, d_model],下面让我们看看如何得到这个矩阵。

# 位置编码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.pe = torch.zeros(max_len, d_model)
        # [man_len, 1]
        position = torch.arange(0, max_len).unsqueeze(1)
        # (d_model/2, )
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        # pe: [max_len, d_model]  矩阵的扩充运算
        self.pe[:, 0::2] = torch.sin(position * div_term)
        self.pe[:, 1::2] = torch.cos(position * div_term)
        # pe : [1, max_len, d_model] 第一个维度是batch_size
        self.pe = self.pe.unsqueeze(0)

    def forward(self, x):
        # 输入的x维度: [batch_size, seq_len, d_model]
        # 因为输入的句子的长度会比设定的max_len小,因为需要切片操作
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

画出图形看一下形状:

import matplotlib.pyplot as plt
plt.figure(figsize=(15,5))
model = PositionalEncoding(20, 0)
x = torch.zeros(1, 100, 20)
plt.plot(range(100), model(x)[0,:, 4:8])
plt.show()

可以看到周期性的变化,这让每个位置的值都有了自己的position
在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值