pytorch unsqueeze_Transformer后生仔:Star-Transformer 剖析和pytorch实现

29221595599cba4e204e9d6983c66912.png

Transformer后生仔:Star-Transformer 剖析和pytorch实现

Transformer作为目前"地表最强"特征抽取器,风头正盛,不乏有人对其进行改进,最出名的便是XLNET使用到的Transformer-XL。 vanilla Transformer的时间复杂度为

,N为序列长度,可见处理长序列耗时会急速增加。Transformer-XL将长序列分割成K个segment,降低每一次encode的长度。同时,模型记忆并使用上一次的隐态,循环encode,复杂度为
。对于长序列,合理的split可以获得较快的加速比。

无独有偶,Star-Transformer同样优化了性能,时间复杂度降低到线性O(n),在合成数据集中,比vanilla Transformer平均快了4.5倍,同时在部分任务中效果更好。

获的这样的效果,Star-Transformer 其精髓便是Star!

一、全连接拓扑优化成星型拓扑

在vanilla Transformer中,每一个Token都可以和任意Token直接进行交互。Token是以全连接的方式组成拓扑结构,这有点像早期的计算机网络。这样的交互方式,作者认为不太合理,并认为将vanilla Transformer的全连接拓扑改成星型结构可能会更好,作者由此出发,著称此文。

2c50ac099fde278b1f5e038e0838196f.png
星型拓扑

如图所示,Star-Transformer 的星型结构中,有两种节点:

  1. 卫星节点(Satellite nodes):类似计算机网络中终端,卫星节点每一个节点都和中继节点链接,在序列表示中,一个Token代表一个卫星节点。
  2. 中继节点(Relay Node):类似计算网络中的中继器,充当Token交互的桥梁。

同时,结构中有两种链接:

  1. 辐射链接(Radical Connections):终端和中继的链接。有了中继节点,每一个Token可以通过两步的方式链接到任意非邻居节点
  2. 环形链接(Ring Connections):左右相邻的Token互相链接,而且需要将第一个Token和最后一个Token链起来,组成一个环。环形链接的角色类似CNN或者双向LSTM。

以上就是Star-Transformer 的结构,个人认为和vanilla Transformer相比主要有两个差异 1. 星型代替全连接,非邻居Token需要通过中继节点,两步更新才能互相交互到。降低了交互的次数,对于一些重要特征,会比较集中在中继节点的表示。 2. 邻居节点是直接交互,非邻居交互成本较大,因此Star-Transformer 比vanilla Transformer更容易学习到局部信息。

二、更新机制和输出

算法流程:

09bfdee8dbafd48757b892316215a484.png

Star-Transformer 提出了基于time step的循环更新方式: 每一个Token(H)由embedding 初始化,中继节点初始为每一个Token的平均值。每一个Token依次通过attention机制更新,值得注意的是attention中的K,V是通过几个向量concat而来

4a61ed56339960f5d5ab5712206e8cd3.png

即当前的h值是根据:上一轮的上一个节点隐态,上一轮的该节点隐态,上一轮下一个节点态隐,本节点的embedding,上一轮的中继节点共同决定。在每次更新都直连了embedding ,相当于和embedding开了high way。 当每一个卫星节点更新完比之后,就更新中继节点。

d8f792911bc9ec54fe866c8507d5e7f3.png

即当前的s值根据:上一轮的s,和本次更新的H(所有已更新的Token)

在多轮更新之后,start-Transformer使用 sT + max(HT)作为序列的特征表示。同时每个h可以单独来表示一个token,直接使用在序列标注任务。

三、效果

9d592b1b9f4bfc289ee6bac4792e55aa.png

本文只列出了一个分类实验,更多实验效果可以参考论文。通过POST出来的分类实验数据来看,star-transformer 准确率更高,而且推理速度更快。

四、pytorch实现

star-transformer的整体并不复杂,其中一些组件已经有公开源码,比如:multi-head-attiontion。只需要额外处理star-transformer特有的更新方式。

我把最核心的代码post在下方 ,其他代码可以参考我的github:star_transformer_pytorch

有几个点需要注意的是 1. h层初始化的时候,需要clone emb的输出,这里需要切断梯度的传递,代码里面有clone同理。 2. 在单一逐次更新token的时候,并没有使用循环,是将依赖的变量批量concat成一个大矩阵,通过矩阵一次计算到一个完整 time step。其实一开始我也是按照论文中的两层循环做,不过无法获得GPU加速,该方式比套循环更加高效。这也说明了,复现模型的时候,尽量不要使用循环和分支,GPU处理很慢。

class StarTransformerLayer(nn.Module):

    def __init__(self, cycle_num, hidden_size, num_attention_heads, attention_dropout_prob):
        super().__init__()
        self.cycle_num = cycle_num
        self.multi_att_satellite = MultiAttention(hidden_size, num_attention_heads, attention_dropout_prob)
        self.multi_att_relay = copy.deepcopy(self.multi_att_satellite)
        self.ln_satellite = LayerNorm(hidden_size)
        self.ln_relay = copy.deepcopy(self.ln_satellite)

    def cycle_shift(self, e: torch.Tensor, forward=True):
        b, l, d = e.size()

        if forward:
            temp = e[:, -1, :]
            for i in range(l - 1):
                e[:, i + 1, :] = e[:, i, :]
            e[:, 0, :] = temp
        else:
            temp = e[:, 0, :]
            for i in range(1, l):
                e[:, i - 1, :] = e[:, i, :]
            e[:, -1, :] = temp

        return e

    def forward(self, e: torch.Tensor):
        # Initialization
        h = e.clone()
        b, l, d = h.size()
        s = F.avg_pool2d(h, (h.shape[1], 1)).squeeze(1)
        for _ in range(self.cycle_num):
            # update the satellite nodes
            h_last, h_next = self.cycle_shift(h.clone(), False), self.cycle_shift(h.clone(), True)
            s_m = s.unsqueeze(1).expand_as(h)
            c = torch.cat(
                [h_last.unsqueeze(-2), h.unsqueeze(-2), h_next.unsqueeze(-2), e.unsqueeze(-2), s_m.unsqueeze(-2)],
                dim=-2)
            c = c.view(b * l, -1, d)
            h = h.unsqueeze(-2).view(b * l, -1, d)
            h = self.ln_satellite(F.relu(self.multi_att_satellite(h, c, c))).squeeze(-2).view(b, l, -1)
            # update the relay node
            s = s.unsqueeze(1)
            m_c = torch.cat([s, h], dim=1)
            s = self.ln_relay(F.relu(self.multi_att_relay(s, m_c, m_c))).squeeze(1)

        return h, s
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值