1 研究动机
选择这篇论文来读,有一点奇文共欣赏的意思。 区别于现在主流的框架比拼算力,本文重新思考是不是可以通过lstm 和 单头的attention就可以在现在的数据集上完成大型框架类似的指标。
作者在文章里花了很大的篇幅去讨论,如何会去思考来构建sha-rnn这个模型。他类比了计算机的发展史和摩尔定律,讨论了语言模型和tokern。 作者认为减少缓存,让语言模型的实现可以跑在较低的资源上,不失为一个值得去研究的技术方向,就像计算机发展当年的故事,如果所有的研究都投入在集群和大型机,怎么会有二十世纪末期微机的大行其道。作者认为,即使是transformer已经是主流,也可以继续尝试用lstm + attention,通过精心的设计,仔细的调差,一样可以用显存消耗较小的模型达到较好的效果。
2 研究内容和方法
sha-rnn的设计架构,如下图所示,仔细看其实并没有特别出彩的地方。撇除那些各条路线上的FusedLayerNorm (LN)层,其实架构和transformer是非常接近的。 沿用传统的lstm 而不是算力消耗或者说参数量更大的self attention层。 attention的k,q,v其实均来自lstm的输出,然后依然是类似transformer的旁路设计(残差)。具体可以看源码关于这一块的核心设计。
def forward(self, h, pe, attn_mask, mem=None, hidden=None):
new_mem = None
h = self.lnstart(h)
if self.rnn:
x, new_hidden = self.rnn(h, None if hidden is None else hidden)
#x = self.rnn_down(self.drop(x))
# Trim the end off if the size is different
ninp = h.shape[-1]
z = torch.narrow(x, -1, 0, x.shape[-1] // ninp * ninp)
# Divide the hidden size evenly into chunks
z = x.view(*x.shape[:-1], x.shape[-1] // ninp, ninp)
# Collapse the chunks through summation
#h = h + self.drop(x).sum(dim=-2)
x = self.drop(z).sum(dim=-2)
#x = x + z.sum(dim=-2)
h = h + x if self.residual else x.float()
focus, new_mem = None, []
if self.attn is not None:
mh = self.lnmem(h)
h = self.lnmid(h)
if mem is not None:
bigh = torch.cat([mem, mh], dim=0)
else:
bigh = mh
new_mem = bigh[-len(pe):]
q, k = h, bigh
x, focus = checkpoint(self.attn, q, k, bigh, attn_mask)
#x, focus = tcheckpoint(self.attn, q, k, bigh, attn_mask)
x = self.drop(x)
h = x + h
if self.ff:
h, x = self.lnff(h), self.lnxff(h)
x = checkpoint(self.ff, x)
#x = tcheckpoint(self.ff, h)
x = self.drop(x)
h = x + h
return h, new_mem, new_hidden, focus
sha-rnn关于attention的设计,最主要的着眼点还是减少矩阵乘法带来的消耗,从下图可以看出,整个过程其实只有一次的矩阵乘法
3 实验
对于论文的实验 ,我们主要关注 ENWIK8这个数据集,源码中还包含wikitext-2,wikitext-103和PTB等数据集。下图展示 sha-rnn和其他模型的参数对比:
对于sha-rnn训练的实验结果和截图如下:
其实训练的过程,也应用了很多基本的技巧,比如warmup,比如一开始训练(作者建议32个epoch,实际我因为意外大概训练了10个左右,其实bpc和loss基本已经变化很小),我decay一下lr,又先后训练了两个epoch和1个epoch,最后的结果如下:
4 创新点和个人点评
本文其实架构的创新不是特别大,但是思路其实有可取之处,特别是坚持保留主流之外其他架构设计的可能性,是非常值得我们研究者学习的一种精神。而且,作者的代码,有大量的工程和试验的部分,都是值得学习和借鉴的,比如boom层的设计中的切块。 最后,其实,文章还有很多的细节,我后续读参考文献及其代码,会补充或者单开文章来写,比如作者用的优化器LAMB,以及英伟达的混合精度和分布式训练的库APEX,当然作者提到的tokenization attack也待补充。