SWITCH TRANSFORMER:Transformer类的万亿级别模型
2021年1月,谷歌大脑团队发布了一篇文章“SWITCH TRANSFORMERS: SCALING TO TRILLION PARAMETER MODELS WITH SIMPLE AND EFFICIENT SPARSITY”,文章提出了号称拥有万亿级别的Transformer类模型,命名为Switch Transformer。以下为本文的主要逻辑结构:
- 简述预训练语言模型的发展脉络
- 介绍文章的研究成果介绍Switch Transformer的模型结构
- 补充说明文章所用Tricks
- 其他
预训练语言模型
所谓预训练,即首先在大量的数据上进行训练,对于得到的模型进行微调,应用于下游任务上。第一代预训练语言模型包括经典的Word2Vec、Glove、Fasttext,第二代预训练语言模型囊括了许多如今常见且新颖的模型,包括
- ELMO(2018, BiLSTM)
- Transformer(2017)
- BERT(Transformer Encoder, base-110M, 2018)
- GPT(Transformer Decoder, 2018)
- GPT-2( 1.5B的参数量)、GPT-3(175B的参数量)
- T5(Text-to-Text Transfer Transformer, 11B的参数量, 使用了 C4 大规模数据集)
尤其在2018年Bert被提出后,NLP的各领域从原先的单打独斗,逐渐产生了Bert通吃的局面,Bert类模型的研究热度也愈发高涨。下图截取了GLUE榜单上,预训练语言模型的排名情况,时间截止于2021年4月7日。
研究成果
Baseline中,谷歌团队选取了自家的T5模型作为对比模型,其中FLOPS表示每秒浮点运算次数,是机器算力的量化表现。
下游任务上,团队选取了包括GLUE、SuperGLUE、翻译、文本摘要在内的多个任务的数据集,Fine-tuning表现如下:
此外,实验中Switch模型的参数被大幅增加,达到了1.6万亿:
其中,Switch-C就是万亿级别的模型,衡量指标为负log困惑度,困惑度是衡量语言模型好坏的性能指标,困惑度数值越小表示模型对语言文本越不困惑。由结果可知,Switch-C的表现不如Switch-XXL,证明并不是参数量越大,模型的性能就一定越好,更多的是一种权衡。
模型结构
首先分析Switch Transformer的模型结构,与标准的Transformer Encoder类似,输入token依次经过self-Attention、Layer Normalization、FNN、Layer Normalization。
但是,不同于标准Transformer的一个FNN,Switch使用了一组FNN(最多使用了2048个,即Switch-C),下方输入的token首先经过一个名叫Router路由的可学习的权重矩阵,Router得到每个token的概率值,概率最大的那一个(对应于Router中的直方图),被映射到第几个FNN。图中,Router直方图的第二列概率值最大,因此下层的输入被路由到第二个FNN中。
此过程的数学形式表达如下图所示:
相应伪代码如下:
Tricks
想要训练如此大的模型,绝非易事,此处截取文章提到的并行训练策略:
此过程比较繁琐,在此不做赘述,也欢迎大家私信。
代码
对于最重要的代码部分,谷歌发布的官方代码使用了Mesh Tensorfolw框架,作为一款面向分布式训练模型的框架,并不适合个人使用,在此推荐大家Github上Pytorch版本的实现,可在单GPU下运行。
结语
Switch Transformer作为当前最大的预训练语言模型,选取Transformer 的Encoder部分进行修改,引入了多个FNN。正因如此,大大扩展了参数量,但计算量并未因此增加,因为最终只会路由到一个FNN上,这种思想值得学习借鉴。