Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context(2019-1-9)

模型介绍

Transformer最大的问题在于没有办法建模超过最大长度的序列,Transformer-XL主要提出了两个优化点:段级递归和相对位置编码。

段级递归

为了解决固定长度的限制,Transformer-XL提出了一种递归机制,如下图,第一个segment计算完成后,把计算的结果保存下来,在计算第二个片段的时候,把第一个片段的hidden state和第二个片段的hidden state拼接在一起,再进行后续的计算。

在这里插入图片描述
我们看下具体的计算公式,其中h表示的是hidden state, τ \tau τ 表示第 τ \tau τ 个segment,SG函数表示的是不更新梯度,[]表示的是向量的拼接。

在这里插入图片描述
第一个公式的意思即:第 τ + 1 \tau+1 τ+1个segment第n-1层的hidden state 等于第 τ \tau τ 个segment第n - 1层的hidden state拼接上第 τ + 1 \tau +1 τ+1 个segment第n - 1层的hidden state,后续两个公式和vanilla版本类似,但要注意,q是未拼接的hidden state,k、v是拼接过后的,因为q表示的是当前的segment,所以不需要拼接。

可以看到,对于第一个segment来说,hidden state是没有额外需要拼接的值的,从第二个segment开始才需要拼接,在论文中,每次都是和上一个segment进行拼接,理论上来说每次可以拼接多个segment,第n个segment可以和前n-1个segment进行拼接,不过这个就取决于你自己的显存了,并且一个segment通常来说不会像上图中的这么短(一个segment可能长度就512了),文本自身的上下文依赖一般也不会超过一个segment的长度。

实现代码

    def init_mems(self, bsz):
        if self.mem_len > 0:
            mems = []
            for i in range(self.n_layer):
                empty = tf.zeros([self.mem_len, bsz, self.d_model])
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, mlen, qlen):
        # does not deal with None
        if mems is None:
            return None

        # mems is not None
        assert len(hids) == len(mems), "len(hids) != len(mems)"

        # There are `mlen + qlen` steps that can be cached into mems
        new_mems = []
        end_idx = mlen + tf.math.maximum(0, qlen)
        beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
        for i in range(len(hids)):
            mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)
            cat = tf.concat([mems[i], hids[i]], axis=0)
            tf.stop_gradient(cat)
            new_mems.append(cat[beg_idx:end_idx])

        return new_mems

相对位置编码

Vanilla的位置编码是和embedding相加后输入到下一层的,Transformer-XL的位置编码没有在输入上做处理,而是对attention score进行了修改。

在这里插入图片描述
考虑一下,当query与key进行计算时,实际上并不需要知道key的绝对位置编码,模型实际上需要的是一个“时间线索”即字词的一个先后顺序,因此,知道query与key的相对位置即可。根据以上的思路,Transformer-XL做了三个方面的改进,分别如下:

在这里插入图片描述
在新的参数下,每一项都有了一个具体的含义,a表示的是query与key的内容相关性,b表示的是query的内容和key的位置的相关性,c表示的是query的位置与key的内容的相关性,d表示的是quey与key的位置的相关性。

总结一下,对于一个N层1个head的Transformer-XL,其完整步骤如下:

在这里插入图片描述

实现代码

class RelativeMultiHeadAttention(layers.Layer):
    def __init__(self, num_heads, embed_size):
        super(RelativeMultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.hidden_size = embed_size // num_heads

        self.qvk_net = layers.Dense(3 * embed_size)
        self.r_net = layers.Dense(embed_size)
        self.o_net = layers.Dense(embed_size)

        self.layer_norm = layers.LayerNormalization()

    def _rel_shift(self, x):
        x_size = tf.shape(x)
        # shape:(seq_len_q, seq_len_k, batch_size, num_heads)=>(seq_len_q, seq_len_k + 1, batch_size, num_heads)
        x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
        # shape:(seq_len_q, seq_len_k + 1, batch_size, num_heads)=>(seq_len_q + 1, seq_len_k, batch_size, num_heads)
        x = tf.reshape(x, (x_size[0] + 1, x_size[1], x_size[2], x_size[3]))
        # shape:(seq_len_q + 1, seq_len_k, batch_size, num_heads)=>(seq_len_q, seq_len_k, batch_size, num_heads)
        x = tf.slice(x, [0, 1, 0, 0], [-1, -1, -1, -1])
        return x

    # w表示token embedding,r表示relative position embedding
    # r_w_bias表示uT,r_r_bias表示vT,形状和w的形状一致
    def __call__(self, w, r, r_w_bias, r_r_bias, mask=None, mems=None, *args, **kwargs):
        # w
        # shape:(seq_len, batch_size, embed_size)
        # r
        # shape:(seq_len, 1, embed_size)
        seq_len_q, batch_size, seq_len_r = tf.shape(w)[0], tf.shape(w)[1], tf.shape(r)[0]
        if mems is not None:
            cat = tf.concat([mems, w], axis=0)
            w_heads = self.qvk_net(cat)
            # 有mems时:
            # w_head_q
            # shape:(seq_len_q, batch_size, embed_size)
            # w_head_k, w_head_v
            # shape:(seq_len_k, batch_size, embed_size),其中seq_len_k = seq_len_q + seq_len_mems
            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
            w_head_q = w_head_q[-seq_len_q:]
            r_head_k = self.r_net(r)
        else:
            w_heads = self.qvk_net(w)
            # 没有mems时:(seq_len_q = seq_len)
            # w_head_q, w_head_k, w_head_v
            # shape:(seq_len_q, batch_size, embed_size)
            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
            r_head_k = self.r_net(r)
        seq_len_k = tf.shape(w_head_k)[0]
        # w_head_q
        # shape:(seq_len_q, batch_size, embed_size)=>(seq_len_q, batch_size, num_heads, hidden_size)
        # w_head_k, w_head_v
        # shape:(seq_len_k, batch_size, embed_size)=>(seq_len_k, batch_size, num_heads, hidden_size)
        # r_head_k
        # shape:(seq_len_r, 1, embed_size)=>(seq_len_r, num_heads, hidden_size)
        w_head_q = tf.reshape(w_head_q, (seq_len_q, batch_size, self.num_heads, self.hidden_size))
        w_head_k = tf.reshape(w_head_k, (seq_len_k, batch_size, self.num_heads, self.hidden_size))
        w_head_v = tf.reshape(w_head_v, (seq_len_k, batch_size, self.num_heads, self.hidden_size))
        r_head_k = tf.reshape(r_head_k, (seq_len_r, self.num_heads, self.hidden_size))
        # 计算A+C两项,(w_head_q + r_w_bias) * w_head_k = (qT + uT) * k
        # w_head_q
        # shape:(seq_len_q, batch_size, num_heads, hidden_size)
        # r_w_bias
        # shape:(seq_len_q, batch_size, num_heads, hidden_size)
        # w_head_k
        # shape:(seq_len_k, batch_size, num_heads, hidden_size)
        wr_head_q = w_head_q + r_w_bias
        # shape:(seq_len_q, seq_len_k, batch_size, num_heads)
        AC = tf.einsum("ibnh,jbnh->ijbn", wr_head_q, w_head_k)
        # 计算B+D两项,(w_head_q + r_r_bias) * r_head_k = (qT + vT) * r
        wr_head_r = w_head_q + r_r_bias
        # shape:(seq_len_q, seq_len_k, batch_size, num_heads)
        BD = tf.einsum("ibnh,jnh->ijbn", wr_head_r, r_head_k)
        BD = self.rel_shift(BD)
        # 计算attention_score,attention_score = softmax((A+B+C+D)/dk[+mask])
        attention_score = (AC + BD) / tf.sqrt(self.hidden_size)
        # 如果有mask
        if mask is not None:
            attention_score += (mask * 1e-9)
        # shape:(seq_len_q, seq_len_k, batch_size, num_heads)
        attention_score = tf.nn.softmax(attention_score, axis=1)
        # 计算attention,attention = attention_score * v
        # shape:(seq_len_q, batch_size, num_heads, hidden_size)
        attention = tf.einsum("ijbn,jbnh->ibnh", attention_score, w_head_v)
        # shape:(seq_len_q, batch_size, num_heads, hidden_size)=>(seq_len_q, batch_size, embed_size)
        attention = tf.reshape(attention, (seq_len_q, batch_size, self.embed_size))
        attention = self.o_net(attention)
        # residual connection
        output = attention + w
        # layer normalization
        output = self.layer_norm(output)
        return output

模型参考

论文地址:https://arxiv.org/abs/1901.02860

代码地址:https://github.com/kimiyoung/transformer-xl

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

不负韶华ღ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值