Transformer变体层出不穷,它们都长什么样?

©PaperWeekly 原创 · 作者|上杉翔二

单位|悠闲会

研究方向|信息检索


不知不觉 Transformer 已经逐步渗透到了各个领域,就其本身也产生了相当多的变体,如上图。本篇文章想大致按照这个图,选一些比较精彩的变体整理,话不多说直接开始。

Transformer-XL

论文标题:

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

收录会议:

ACL 2019

论文链接:

https://arxiv.org/abs/1901.02860

代码链接:

https://github.com/kimiyoung/transformer-xl

上图上标的是“Recurrence”,首先看看这篇文章聚焦的 2 个问题:

  • 虽然 Transformer 可以学习到输入文本的长距离依赖关系和全局特性,但是!需要事先设定输入长度,这导致了其对于长程关系的捕捉有了一定限制。

  • 出于效率的考虑,需要对输入的整个文档进行分割(固定的),那么每个序列的计算相互独立,所以只能够学习到同个序列内的语义联系,整体上看,这将会导致文档语意上下文的碎片化(context fragmentation)。

那么如何学习更长语义联系?

segment-level Recurrence

segment-level 循环机制。如上图左边为原始 Transformer,右边为 Transformer-XL,Transformer-XL 模型的计算当中加入绿色连线,使得当层的输入取决于本序列和上一个序列前一层的输出。这样每个序列计算后的隐状态会参与到下一个序列的计算当中,使得模型能够学习到跨序列的语义联系(看动图可能更好理解)。

是第 个 segment 的第 n 层隐向量,那么第 r+1 个的第 n 层的隐向量的计算,就是上面这套公式。

  • 其中 SG 是是 stop-gradient,不再对 的隐向量做反向传播(这样虽然在计算中运用了前一个序列的计算结果,但是在反向传播中并不对其进行梯度的更新,毕竟前一个梯度肯定不受影响)。

  • 是对两个隐向量序列沿长度 L 方向的拼接 。3 个 W 分别对应 query,key 和 value 的转化矩阵,需要注意的是!k 和 v 的 W 用的是 ,而 q 是用的 ,即 kv 是用的拼接之后的 h,而 q 用的是原始序列的信息。感觉可以理解为以原始序列查拼接序列,这样可以得到一些前一个序列的部分信息以实现跨语义。

  • 最后的公式是标准的 Transformer。

还有一点设计是,在评估预测模型的时候它是会连续计算前 L 个长度的隐向量的(训练的时候只有前一个,缓存在内存中)。

即每一个位置的隐向量,除了自己的位置,都跟下一层中前(L-1)个位置的 token 存在依赖关系,而且每往下走一层,依赖关系长度会增加(L-1),这样能使跨语义更加的深入。


只看看 XL 多头注意力的 forward 的不同地方吧。

def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
             #w是上一层的输出,r是相对位置嵌入(在下一节),r_w_bias是u,r_r_bias是v向量
            qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

            if mems is not None: #mems就是前一些序列的向量,不为空
                cat = torch.cat([mems, w], 0) #就拼起来
                if self.pre_lnorm: #如果有正则化
                    w_heads = self.qkv_net(self.layer_norm(cat)) #这个net是nn.Linear,即qkv的变换矩阵W参数
                else:
                    w_heads = self.qkv_net(cat)#没有正则就直接投影一下
                r_head_k = self.r_net(r)#也是nn.Linear

                w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) #复制3份
                w_head_q = w_head_q[-qlen:] #q的W不要拼接的mems
            else:#没有mems,就正常的计算
                if self.pre_lnorm:
                    w_heads = self.qkv_net(self.layer_norm(w))
                else:
                    w_heads = self.qkv_net(w)
                r_head_k = self.r_net(r)

                w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

            klen = w_head_k.size(0)
            #qlen是序列长度,bsz是batch size,n_head是注意力头数,d_head是每个头的隐层维度
            w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
            w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
            w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
回答: Transformer有多种变体,其中包括Universal transformers(UT)和压缩(Compressive) Transformer。Universal transformers是对传统transformer结构的改进,使其更加丰富多彩。\[1\]压缩TransformerTransformer-XL模型的延伸,其关键思想是保持对过去段激活的细粒度内存,与Transformer-XL不同,后者在跨段移动时会丢弃过去的激活。\[2\]此外,由于注意力机制忽视了位置信息,所以在Transformer中必须加入位置编码。原始Transformer采用了正弦/余弦函数来编码绝对位置信息,而Transformer-XL采用了相对位置编码来解决不同序列间同一个位置得到相同编码的问题。\[3\] #### 引用[.reference_title] - *1* [transformer变体](https://blog.csdn.net/u013596454/article/details/120530025)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [【CS224n】(lecture9)Transformer变体](https://blog.csdn.net/qq_35812205/article/details/122152418)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值