TOKEN MERGING: YOUR VIT BUT FASTER

文章提出TokenMerging(ToMe)技术,通过合并相似token以提高Transformer模型如ViT的推理和训练速度,同时保持高精度。ToMe避免了传统剪枝方法的缺点,适用于图片、视频和音频数据,且在多种模型和数据集上表现出色,显著提升了吞吐量并缩短了训练时间。
摘要由CSDN通过智能技术生成

论文地址

ABSTRACT

提出了Token Merging(ToMe),能够提升ViT模型的吞吐量。ToMe通过轻量级算法合并相似的token。原始的ViT-L 512(图片) ,ViT-H 518(图片)和ViT-L(视频),经过token merging操作之后推理的吞吐量成倍增加,并且正确率只有0.2-0.3%的降低。在训练时加入ToMe,可以缩短MAE近两倍微调时间;对音频数据,ToMe可以提高ViT-B两倍的吞吐量,仅降低0.4% mAP。

INTRODUCTION

swin transformer使用vision-specific attention,MViT使用vision-specific pooling,LeViT使用vision-specific conv
modules都想提高模型的效率,使模型在较少计算量的情况下获得最好的结果。

由于token的数量决定了transformer模型的复杂度,所以如何在大量减少token数量的同时,保证模型的性能成为了一项新的挑战。前人们想到了对token进行剪枝进而减少token的数量,但是token剪枝有很多缺点:1)剪枝造成的信息损失限制了减少token的数量。2)部分剪枝的方法由于引进了参数,所以要对模型进行重新训练。3)大部分方法不能进行加速训练。4)输入内容不同可能导致减少的token数量不同,这就意味着无法进行批量的推理。

我们提出了token merging(ToMe),可以加速模型推理和训练速度。并且无需训练就可以直接用于图片、视频、音频的处理。

TOKEN MERGING

我们的目标是将token merging模块插入到Vit模型中。通过合并冗余的token,我们希望能够提高吞吐量,同时不需要行训练。
在这里插入图片描述
融合策略

Token Merging 的基本思路是在一个 ViT 模型中间插入一些 token merging 的模块。基本作法是在每一个层之后减少r个 token, 那么一个有 L层的 Transformer 模型从头到尾减少的 token 数量就是 rL 。这个r值越高, 减少的 token 数量就越多, 但是精度也会越差。并且无论一张输入图片有多少个 tokens, 都会减少rL个 token。这样设计的好处就是模型减少的token数量与输入图像的大小无关。

token相似问题
在这里插入图片描述
把不同 head 的 Key 进行取平均操作,而不是拼接在一起,可以提高效率
在这里插入图片描述

作者通过消融实验来确定衡量相似度最好的办法。作者发现使用 Key矩阵 来衡量相似度对性能最优,因为 Attention 模块中的 Key 已经总结了每个 token 中包含的信息,以便用于 Attention 中的 dot-product 相似度。如右图所示为使用什么距离衡量相似度,作者发现使用余弦距离来衡量 token 之间的相似度可以获得最好的精度-速度权衡。

使用二分软匹配算法进行token融合

1)将输入ToMe模块的所有tokens均分到两个集合A,B(交替取样);
2)对于集合A中的每个token,在集合B中找到与之最相似的一个token,连起来;
3)保留r条最相似的边;
4)将仍然相连的tokens融合(取均值)
5)输出两个集合的并集。
在这里插入图片描述

该算法的pytorch代码

def bipartite_soft_matching(
    metric: torch.Tensor,
    r: int,
    class_token: bool = False,
    distill_token: bool = False,
) -> Tuple[Callable, Callable]:
    """
    Applies ToMe with a balanced matching set (50%, 50%).
    Input size is [batch, tokens, channels].
    r indicates the number of tokens to remove (max 50% of tokens).
    Extra args:
     - class_token: Whether or not there's a class token.
     - distill_token: Whether or not there's also a distillation token.
    When enabled, the class token and distillation tokens won't get merged.
    """
    protected = 0
    if class_token:
        protected += 1
    if distill_token:
        protected += 1

    # We can only reduce by a maximum of 50% tokens
    t = metric.shape[1]
    r = min(r, (t - protected) // 2)

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)

        if class_token:
            scores[..., 0, :] = -math.inf
        if distill_token:
            scores[..., :, 0] = -math.inf

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
        else:
            return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)

        out[..., 1::2, :] = dst
        out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)

        return out

    return merge, unmerge
#其中, 关键变量的含义, 维度和相关的注释如下:
#src: 集合A, shape: (B,N,c)
#dst: 集合B, shape: (B,N,c), 其中第1个 token 是 [distillation] token
#unm: 集合A 中不 merge 的 tokens, shape: (B,N−r,c),其中第1个 token 是 [class] token
#src:集合 A 中要 merge 的 tokens, shape: (B,r,c)
#借助 dst.scatter_reduce() 函数在集合 B 中完成 token merging 操作。
#借助 torch.cat() 函数完成 merge 之后的集合 B 与不 merge 的集合 A 中的 tokens 的拼接工作。

作者通过实验确定了,交替采样效果最好

在这里插入图片描述

作者同时对比了token剪枝和token融合的性能,结果token融合并且采用二分软匹配算法的性能更好。

在这里插入图片描述

一旦token被合并,它们就不再代表一个padding。我们需要改变softmax:一个代表了很多padding的token的权重当然和代表一个padding的token的权重不一样。具体的公式为:

在这里插入图片描述
其中s是一个包含每个token的行向量,其中包含每个token的大小(token表示的padding数)

当token进行融合时,我们应该采用什么样的融合方式?

在这里插入图片描述

图像上的训练

在这里插入图片描述
作者对融合策略进行大量的实验,对于每个融合策略(每一层固定融合r个token还是每一层随机融合不定数量的token),作者使用现成的AugReg ViT-B/16模型在ImageNet-1k val数据集上测试其准确性和fp16吞吐量。我们可以看到固定token数量的融合策略比平均值要好。

监督模型

作者在11个 SOTA 的预训练 ViT 模型 (直接下载开源模型,不进行任何额外的训练) 上使用了本文提出的 ToMe 方法。AugReg 实验结果如所示,为在大规模数据集预训练的模型,再在 ImageNet-1K 上 fine-tune 得到的结果。
在这里插入图片描述
SWAG 实验结果如下图所示,为在大规模数据集弱监督预训练的模型,再在 ImageNet-1K 上 fine-tune 得到的结果。

在这里插入图片描述
结果表明,无论模型的尺寸和类型,ToMe 都能够带来约2倍的吞吐量加速。即使减少 96-98% 的 tokens,最大的模型几乎没有任何精度下降:在2倍吞吐量的设置下,AugReg 得到的 ViT-B,ViT-S 和 ViT-Ti 都有大约 4-5% 的精度下降。ViT-L 在 224px 图像上仅下降 2%,在 384px 图像上下降 0.7%,可能是因为更大的输入图片有更多的 tokens。

自监督模型

MAE 实验结果如下图所示,为在大规模数据集弱监督预训练的模型,再在 ImageNet-1K 上 fine-tune 得到的结果。结果显示,在2倍吞吐量的设置下,MAE 得到的 ViT-H,ViT-L 和 ViT-TB 分别有 0.4%,0.6% 和 1.7% 的精度下降。

在这里插入图片描述

作者对MAE微调模型与在没有额外数据的ImageNet-1k上训练的最先进的模型进行了比较。我们可以看到ToMe合并提高了ViT模型的吞吐量,这样ViT-L和ViT-H在速度上可以与较低层的模型相媲美的同时正确率也有所保障。

在这里插入图片描述

下图所示是 ToMe 方法与 Token Pruning 方法 (DynamicViT,A-ViT,SPViT) 在 DeiT-S 模型上的对比结果。ToMe 方法可以再不使用梯度技巧,如 gumbel softmax 等,不添加额外的参数,以及不使用额外的训练技巧的情况下匹配性能,并超过现有的 Token Pruning 方法。而且,Token Pruning 方法通过由于自身的限制往往使用 token padding 或者 attention 掩码的方法,使得剪枝带来的好处没法发挥出来。但是,ToMe 方法不受这个问题的影响。
在这里插入图片描述
下图是经过ToMe方法后图像的可视化结果

在这里插入图片描述

视频上的训练

作者用两种方法将ToMe方法应用到模型中。一种是是直接把 ToMe 方法应用在现成的训练好的模型中,另一种是在 MAE 进行微调的环节用上 ToMe 方法。将这两种方法和各种视频分类模型做对比,结果如图所示。将 ToMe 方法应用在 ViT-L 上之后,吞吐量与 Swin-B 接近,同时性能更好。而且,将 ToMe 方法应用在 ViT-L 上之后,使用 Spatiotemporal MAE的方式,性能明显优于 MAE 方式训练的 ViT-B 模型,说明 token 融合的方法比 model scaling 更好。

在这里插入图片描述

我们显示了我们的方法应用于ViT-L的吞吐量和训练时间。在恒定的策略下,我们可以将吞吐量增加2.2×,准确率下降0.2%。此外,这种设置将训练时间缩短了一半。

在这里插入图片描述

下图的实验说明了,当利用较多的clips对视频进行训练时,基准模型和ToMe方法的模型正确率相差不大。
在这里插入图片描述
可视化结果

在这里插入图片描述

结论

我们引入了Token Merging(ToMe),通过逐步合并token来提高ViT模型的吞吐量。ToMe自然地利用融合中的冗余。ToMe在跨领域的大型模型上工作得很好,并减少了训练时间和内存的使用,这意味着它可能成为训练大型模型的核心组成部分。

ToMe可以在不经过训练的情况下,大量融合冗余的token,使得模型训练和推理速度大幅提升的同时保证准确率不会大幅度下降。

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值