关于transformer-xl中rel-shift实现的解读

8 篇文章 5 订阅


 

方法

抽象地看,我们要做的事情就是,给定一个矩阵,每行都进行左移,而移动的个数随行数递增而递减。

我目前想到的一种方法是使用gather,将想要的index提前定好,然后使用Pytorch的gather就能够实现。

而transformer-xl实现了另一种更好的方法:_rel_shift

def _rel_shift(self, x, zero_triu=False):
    # x: q,k,bs,n_head
    zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                           device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=1)

    x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

    x = x_padded[1:].view_as(x)

    return x

第一步是,将x的第一列填上padding,此时x.size()=q,k+1,bs,n_head,接下来将其重新reshape,则变成了x.size()=k+1,q,bs,n_head,最后将第一行去掉,变成x.size()=k,q,bs,n_head,再将其reshape回x原来的样子。

为什么这么做实现了我们想要的左移的功能?我们应该从一维的角度去理解。因为实际上在内存中所有元素都是按照一维去排列的。

原来的矩阵:

实际上就是有q个key按照一行去排列。

在做完padding之后,则:

实际上就是在每个key前面插入了0。

接下来view,实际上数据的先后顺序还是没有变(因为不是transpose):

实际上只是强行将该行切成一个一个q而已。

那么最后一个操作,将第一行丢掉,实际上就是要把原来的x的第一行强行左移q-1个(因为有padding)。那么为什么后面的行能够左移的个数依次减少?别忘了padding,第一行左移了q-1个,但第二个key前面也有一个padding,所以相当于将其向右推了一格;第三个又有一个padding,就在原来的基础上又推了一格,也即推了两格。因此最后达到了我们想要的目的。

实际上要理解该方法,需要牢牢把握数据存储的本质是一整行。

该方法没有数据的拷贝,全部都是view操作,因此更高效。

不得不佩服想到该方法的人的工程能力,同时也感谢戴宁带我理解该方法的本质,一开始我是死活不理解的。以后或许可以将该思想灵活应用到其他方面。

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值