【源码解读】Transformer的MultiHeadAttention部分代码解读

1 说明

首先,先给出Transformer的MultiHeadAttention部分的pytorch版本的代码,然后再对于此部分的细节进行解析

2 源码

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0#剖析点1
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, query, key, value, mask=None):
        # 纬度
        # shape:query=key=value--->:[batch_size,max_legnth,embedding_dim=512]
        
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)#剖析点2
        nbatches = query.size(0)
        
        #第一步:将q,k,v分别与Wq,Wk,Wv矩阵进行相乘
        #shape:Wq=Wk=Wv----->[512,512]
        #第二步:将获得的Q、K、V在第三个纬度上进行切分
        #shape:[batch_size,max_length,8,64]
        #第三部:填充到第一个纬度
        #shape:[batch_size,8,max_length,64]
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]#剖析点3
        
        #进入到attention之后纬度不变,shape:[batch_size,8,max_length,64]
        x, self.attn = attention(query, key, value, mask=mask, 
                                 dropout=self.dropout)
        
        # 将纬度进行还原
        # 交换纬度:[batch_size,max_length,8,64]
        # 纬度还原:[batch_size,max_length,512]
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)#剖析点4
        
        # 最后与WO大矩阵相乘 shape:[512,512]
        return self.linears[-1](x)

3 源码剖析

3.1 剖析点1:assert d_model % h == 0

assert断言机制
Python assert(断言)用于判断一个表达式,在表达式条件为 false 的时候触发异常。
语法:

assert expression

等价于(这种方式比较好理解)

if not expression:
    raise AssertionError(arguments)

assert 后面也可以紧跟参数:

assert expression [, arguments]
等价于
if not expression:
    raise AssertionError(arguments)

eg:

assert True#没有任何输出 程序继续向下执行
assert False
#输出
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-13-a871fdc9ebee> in <module>
----> 1 assert False

AssertionError: 
import sys
assert ('linux' in sys.platform), "该代码只能在 Linux 下执行"
#用于验证代码所在平台的系统是否是linux

3.2 剖析点2:mask = mask.unsqueeze(1)

主要是将mask进行一个升维的操作,1表示在第1个维度上升维(从0开始)

b=torch.rand(2,5)
b
#输出
tensor([[0.6956, 0.4611, 0.2149, 0.2581, 0.6836],
        [0.6159, 0.4464, 0.2467, 0.2504, 0.8744]])
b.shape,b.unsqueeze(0),b.unsqueeze(0).shape#在第0个维度进行升维
#输出
(torch.Size([2, 5]), tensor([[[0.6956, 0.4611, 0.2149, 0.2581, 0.6836],
          [0.6159, 0.4464, 0.2467, 0.2504, 0.8744]]]), torch.Size([1, 2, 5]))
b.shape,b.unsqueeze(1),b.unsqueeze(1).shape#在第1个维度进行升维
#输出
(torch.Size([2, 5]), tensor([[[0.6956, 0.4611, 0.2149, 0.2581, 0.6836]], 
         [[0.6159, 0.4464, 0.2467, 0.2504, 0.8744]]]), torch.Size([2, 1, 5]))

3.3 剖析点3:for l, x in zip(self.linears, (query, key, value))

作用:依次取出self.linears[0]和query,self.linears[1]和key,self.linears[2]和value 取名l和x,分别对这三对执行l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)操作
等价于

l,x=self.linears[0],query
l,x=self.linears[1],key
l,x=self.linears[2],value
对每对l,x执行:l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

举个例子:

a=[1,2,3,4]
b=torch.zeros(5,2,2)
c=torch.ones(5,2,2)
d=torch.rand(5,2,2)
print(b,b.shape)
print(c,c.shape)
print(d,d.shape)
print("============================================")
for x,y in zip(a,(b,c,d)):
    print(x,y)
    print("shape:",y.shape)
    print("=======")

输出

tensor([[[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]]]) torch.Size([5, 2, 2])
tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]]) torch.Size([5, 2, 2])
tensor([[[0.0764, 0.8718],
         [0.3432, 0.0081]],

        [[0.8416, 0.9806],
         [0.0932, 0.2501]],

        [[0.7480, 0.3873],
         [0.8147, 0.6484]],

        [[0.6723, 0.1186],
         [0.4056, 0.6158]],

        [[0.6319, 0.5724],
         [0.7458, 0.6811]]]) torch.Size([5, 2, 2])
============================================
1 tensor([[[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]]])
shape: torch.Size([5, 2, 2])
=======
2 tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])
shape: torch.Size([5, 2, 2])
=======
3 tensor([[[0.0764, 0.8718],
         [0.3432, 0.0081]],

        [[0.8416, 0.9806],
         [0.0932, 0.2501]],

        [[0.7480, 0.3873],
         [0.8147, 0.6484]],

        [[0.6723, 0.1186],
         [0.4056, 0.6158]],

        [[0.6319, 0.5724],
         [0.7458, 0.6811]]])
shape: torch.Size([5, 2, 2])
=======

3.4 剖析点4:x.transpose(1, 2).contiguous()

参考:

  • https://zhuanlan.zhihu.com/p/64551412
  • https://blog.csdn.net/gdymind/article/details/82662502
    在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的。换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据。
    这些操作是:
narrow(),view(),expand()和transpose()

举个栗子,在使用transpose()进行转置操作时,pytorch并不会创建新的、转置后的tensor,而是修改了tensor中的一些属性(也就是元数据),使得此时的offset和stride是与转置tensor相对应的。转置的tensor和原tensor的内存是共享的!

x = torch.randn(3, 2)
print(x,x.shape)
y = x.transpose(0, 1)
y,y.shape
#输出
tensor([[-1.9441,  1.5522],
        [ 0.5396, -1.1500],
        [ 1.3438, -2.3227]]) torch.Size([3, 2])

(tensor([[-1.9441,  0.5396,  1.3438],
         [ 1.5522, -1.1500, -2.3227]]), torch.Size([2, 3]))

为了验证x,y是否是共享内存空间,在此我们尝试修改x矩阵的第一个元素,我们发现x修改之后,y中的数据也跟着发生改变

x[0, 0] = 111
print(y)
print(x)
#输出
tensor([[111.0000,   0.5396,   1.3438],
        [  1.5522,  -1.1500,  -2.3227]])
tensor([[111.0000,   1.5522],
        [  0.5396,  -1.1500],
        [  1.3438,  -2.3227]])

当调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一模一样

x = torch.randn(3, 2)
print(x,x.shape)
y = x.transpose(0, 1)
y=y.contiguous()
print(y,y.shape)
x[0, 0] = 111
print(y)
print(x)
#输出
tensor([[ 2.2892, -0.0997],
        [-0.0294,  0.1934],
        [ 0.7963, -0.3681]]) torch.Size([3, 2])
tensor([[ 2.2892, -0.0294,  0.7963],
        [-0.0997,  0.1934, -0.3681]]) torch.Size([2, 3])
tensor([[ 2.2892, -0.0294,  0.7963],
        [-0.0997,  0.1934, -0.3681]])
tensor([[ 1.1100e+02, -9.9713e-02],
        [-2.9446e-02,  1.9339e-01],
        [ 7.9626e-01, -3.6809e-01]])
  • 4
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
嗨!对于Transformer源码解读,我可以给你一些基本的指导。请注意,我不能提供完整的源代码解读,但我可以帮助你理解一些关键概念和模块。 Transformer是一个用于自然语言处理任务的模型,其中最著名的应用是在机器翻译中。如果你想要深入了解Transformer的实现细节,我建议你参考谷歌的Transformer源码,它是用TensorFlow实现的。 在Transformer中,有几个关键的模块需要理解。首先是"self-attention"机制,它允许模型在处理序列中的每个位置时,同时关注其他位置的上下文信息。这个机制在Transformer中被广泛使用,并且被认为是其性能优越的主要原因之一。 另一个重要的模块是"Transformer Encoder"和"Transformer Decoder"。Encoder负责将输入序列转换为隐藏表示,而Decoder则使用这些隐藏表示生成输出序列。Encoder和Decoder都由多个堆叠的层组成,每个层都包含多头自注意力机制和前馈神经网络。 除了这些核心模块外,Transformer还使用了一些辅助模块,如位置编码和残差连接。位置编码用于为输入序列中的每个位置提供位置信息,以便模型能够感知到序列的顺序。残差连接使得模型能够更好地传递梯度,并且有助于避免梯度消失或爆炸的问题。 了解Transformer源码需要一定的数学和深度学习背景知识。如果你对此不太了解,我建议你先学习相关的基础知识,如自注意力机制、多头注意力机制和残差连接等。这样你就能更好地理解Transformer源码中的具体实现细节。 希望这些信息对你有所帮助!如果你有任何进一步的问题,我会尽力回答。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值