Single Headed Attention RNN: Stop ThinkingWith Your Head 论文笔记

1 研究动机

选择这篇论文来读,有一点奇文共欣赏的意思。 区别于现在主流的框架比拼算力,本文重新思考是不是可以通过lstm 和 单头的attention就可以在现在的数据集上完成大型框架类似的指标。

作者在文章里花了很大的篇幅去讨论,如何会去思考来构建sha-rnn这个模型。他类比了计算机的发展史和摩尔定律,讨论了语言模型和tokern。 作者认为减少缓存,让语言模型的实现可以跑在较低的资源上,不失为一个值得去研究的技术方向,就像计算机发展当年的故事,如果所有的研究都投入在集群和大型机,怎么会有二十世纪末期微机的大行其道。作者认为,即使是transformer已经是主流,也可以继续尝试用lstm + attention,通过精心的设计,仔细的调差,一样可以用显存消耗较小的模型达到较好的效果。

2 研究内容和方法

sha-rnn的设计架构,如下图所示,仔细看其实并没有特别出彩的地方。撇除那些各条路线上的FusedLayerNorm (LN)层,其实架构和transformer是非常接近的。 沿用传统的lstm 而不是算力消耗或者说参数量更大的self attention层。 attention的k,q,v其实均来自lstm的输出,然后依然是类似transformer的旁路设计(残差)。具体可以看源码关于这一块的核心设计。
sha-rnn

def forward(self, h, pe, attn_mask, mem=None, hidden=None):
        new_mem = None

        h = self.lnstart(h)

        if self.rnn:
            x, new_hidden = self.rnn(h, None if hidden is None else hidden)
            #x = self.rnn_down(self.drop(x))

            # Trim the end off if the size is different
            ninp = h.shape[-1]
            z = torch.narrow(x, -1, 0, x.shape[-1] // ninp * ninp)
            # Divide the hidden size evenly into chunks
            z = x.view(*x.shape[:-1], x.shape[-1] // ninp, ninp)
            # Collapse the chunks through summation
            #h = h + self.drop(x).sum(dim=-2)
            x = self.drop(z).sum(dim=-2)
            #x = x + z.sum(dim=-2)

            h = h + x if self.residual else x.float()

        focus, new_mem = None, []
        if self.attn is not None:
            mh = self.lnmem(h)
            h = self.lnmid(h)

            if mem is not None:
                bigh = torch.cat([mem, mh], dim=0)
            else:
                bigh = mh
            new_mem = bigh[-len(pe):]

            q, k = h, bigh

            x, focus = checkpoint(self.attn, q, k, bigh, attn_mask)
            #x, focus = tcheckpoint(self.attn, q, k, bigh, attn_mask)
            x = self.drop(x)
            h = x + h

        if self.ff:
            h, x = self.lnff(h), self.lnxff(h)
            x = checkpoint(self.ff, x)
            #x = tcheckpoint(self.ff, h)
            x = self.drop(x)
            h = x + h

        return h, new_mem, new_hidden, focus

sha-rnn关于attention的设计,最主要的着眼点还是减少矩阵乘法带来的消耗,从下图可以看出,整个过程其实只有一次的矩阵乘法
attention

3 实验

对于论文的实验 ,我们主要关注 ENWIK8这个数据集,源码中还包含wikitext-2,wikitext-103和PTB等数据集。下图展示 sha-rnn和其他模型的参数对比:
model

对于sha-rnn训练的实验结果和截图如下:
paramter
开始训练
其实训练的过程,也应用了很多基本的技巧,比如warmup,比如一开始训练(作者建议32个epoch,实际我因为意外大概训练了10个左右,其实bpc和loss基本已经变化很小),我decay一下lr,又先后训练了两个epoch和1个epoch,最后的结果如下:
results

4 创新点和个人点评

本文其实架构的创新不是特别大,但是思路其实有可取之处,特别是坚持保留主流之外其他架构设计的可能性,是非常值得我们研究者学习的一种精神。而且,作者的代码,有大量的工程和试验的部分,都是值得学习和借鉴的,比如boom层的设计中的切块。 最后,其实,文章还有很多的细节,我后续读参考文献及其代码,会补充或者单开文章来写,比如作者用的优化器LAMB,以及英伟达的混合精度和分布式训练的库APEX,当然作者提到的tokenization attack也待补充。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值