Set Transformer原理以及源码解读

《Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks》,本文的作者来自牛津大学。
Transformer借助Self-attention在顺序数据的处理上取得了巨大的成功,本文旨在处理集合数据(其出发点和Pointer Net是一样一样的),并且强调了对于集合的置换不变性。换句话说,模型的输出不依赖于输入的顺序,因此Set Transformer被用于处理集合数据。但是,对于大量数据的集合,自注意力的时间复杂度是 O ( n 2 ) O(n^2) O(n2),因此本文又采取了新的方法把时间复杂度降低到了 O ( m n ) O(mn) O(mn)

Background

Pooling Architecture for Sets

联想到图池化的操作,置换不变性的模型的一个概述性的式子为:
在这里插入图片描述
[1]已经证明过当 p o o l pool pool是sum池化 β β β ϕ ϕ ϕ是法人一连续函数的时候,所有的置换不变函数都可以表示成(1)的形式。(1)继续细分可以分为encoder ( ϕ ) (ϕ) (ϕ) 和decoder β ( p o o l ( ⋅ ) ) β(pool(·)) β(pool())两个部分,这就和Transformer的的架构一致。此外[1]还观察到尽管编码器是permutation-equivariant层的堆叠,该模型仍然保持置换不变性。反正甭管怎么说吧,此处为接下来的Set Transformer提供了理论支持。

Attention

在此也复习一下自注意力。
在这里插入图片描述
其中 K Q KQ KQ的维度一样,都是 n v × d q n_v×d_q nv×dq的,因此 Q K T QK^T QKT的结果为 n v × n v n_v×n_v nv×nv的。然后 w w w是激活函数, V V V的维度是 n v × d v n_v×d_v nv×dv,因此(3)输出的shape=[ n v × d v n_v×d_v nv×dv]。多头注意力则是多次注意力的拼接之后的线性变换:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
其中 O j O_j Oj就是一次注意力之后的结果,为了保证多层堆叠的时候不调节每一层之间的参数,Transformer中强制设置 d q M = d q M = d / h d_q^M=d_q^M=d/h dqM=dqM=d/h,也就是输出的维度和输入的维度一致。

Set Transformer

先用图给出下面几种不同的结构的示意:
在这里插入图片描述

首先,定义多头注意力Block(MAB):
在这里插入图片描述
H H H是输入的特征与其多头自注意力变换之后的特征的和再过一个Norm层。然后,再对 H H H做了rFF(row-wise feedforward layer)再和 H H H相加。借助MAB,Set Attention Block (SAB)被定义为:
在这里插入图片描述
SAB接受一个集合,并在集合中的元素之间执行自我注意,从而产生一个大小相同的集合。但是与Transformer不同的是,它缺少了position embedding以及dropout。由于SAB的输出包含关于输入集合X中元素之间成对交互的信息,因此可以和transformer堆叠多个SAB来编码更高阶的交互,这就是简化版本的transformer。之后,由于时间复杂度的关系,提出了一个简化的版本ISAB:
在这里插入图片描述
相比于SAB,注意力不在对两两输入节点(数据)计算,ISAB首先根据输入集将I转换成H,这类似于低秩投影或自编码器模型,其中输入(X)首先投影到低维对象(H),然后重构产生输出。 I I I的大小为 m m m,这是需要调节的超参数。

Pooling by Multihead Attention

上述的SAB的多层堆叠可以看做是encoder的过程。之后,对encoder的输出 Z ∈ R n ∗ d Z∈R^{n*d} ZRnd进行解码:
在这里插入图片描述
这里通过对一组可学习的k个种子向量S应用多头注意力来聚合特征,因此 P M A k PMA_k PMAk的输出是 k k k个元素的集合。在大多数情况下,我们使用一个种子向量(k = 1),但对于需要k个相关输出的平摊聚类这样的问题,自然的做法是使用k个种子向量。而为了模拟这 k k k个种子之间的交互,又用了一个SAB(真不愧是Attention is all you need):
在这里插入图片描述
所以整体的模型的框架为:
在这里插入图片描述
在这里插入图片描述
至此模型的定义结束。实验我还就来一手不写。

Code!

来看代码。github地址为:https://github.com/juho-lee/set_transformer。

class SetTransformer(nn.Module):
    def __init__(self, dim_input, num_outputs, dim_output,
            num_inds=32, dim_hidden=128, num_heads=4, ln=False):
        super(SetTransformer, self).__init__()
        self.enc = nn.Sequential(
                ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
                ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
        self.dec = nn.Sequential(
                PMA(dim_hidden, num_heads, num_outputs, ln=ln),
                SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                nn.Linear(dim_hidden, dim_output))

    def forward(self, X):
        return self.dec(self.enc(X))

可以看到,encoder由两层ISAB组成,decoder由PMA以及两层SAB组成。接下来拆解看不同的Block:
MAB:

class MAB(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads   # 多头注意力头数
        self.fc_q = nn.Linear(dim_Q, dim_V)   # 自注意力KQV参数
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)   # 对应公式(4)中的输出参数O

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)
		# 注意此处是采用了Transformer原文中的自注意力的定义:softmax(QK^T/√d)V的方式,与本论文中的公式稍有不同
        A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))   # 对应公式(6)
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O

SAB:。和论文中描述的一样,SAB只是输入都是X的MAB

class SAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super(SAB, self).__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(X, X)

ISAB:,相比于SAB只是多了一个根据预定义参数 I I I求H的操作:

class ISAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(ISAB, self).__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)  # 公式(10)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)  # 公式(9)

    def forward(self, X):
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
        return self.mab1(X, H)

PMA:,根据一组seeds计算多个MAB,对应文中的公式(11)

class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)

参考文献

Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., and Smola, A. J. Deep sets. In Advances in Neural Information Processing Systems (NeurIPS), 2017.

嗨!对于Transformer源码解读,我可以给你一些基本的指导。请注意,我不能提供完整的源代码解读,但我可以帮助你理解一些关键概念和模块。 Transformer是一个用于自然语言处理任务的模型,其中最著名的应用是在机器翻译中。如果你想要深入了解Transformer的实现细节,我建议你参考谷歌的Transformer源码,它是用TensorFlow实现的。 在Transformer中,有几个关键的模块需要理解。首先是"self-attention"机制,它允许模型在处理序列中的每个位置时,同时关注其他位置的上下文信息。这个机制在Transformer中被广泛使用,并且被认为是其性能优越的主要原因之一。 另一个重要的模块是"Transformer Encoder"和"Transformer Decoder"。Encoder负责将输入序列转换为隐藏表示,而Decoder则使用这些隐藏表示生成输出序列。Encoder和Decoder都由多个堆叠的层组成,每个层都包含多头自注意力机制和前馈神经网络。 除了这些核心模块外,Transformer还使用了一些辅助模块,如位置编码和残差连接。位置编码用于为输入序列中的每个位置提供位置信息,以便模型能够感知到序列的顺序。残差连接使得模型能够更好地传递梯度,并且有助于避免梯度消失或爆炸的问题。 了解Transformer源码需要一定的数学和深度学习背景知识。如果你对此不太了解,我建议你先学习相关的基础知识,如自注意力机制、多头注意力机制和残差连接等。这样你就能更好地理解Transformer源码中的具体实现细节。 希望这些信息对你有所帮助!如果你有任何进一步的问题,我会尽力回答。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

五月的echo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值