UFO-ViT: High Performance Linear Vision Transformer without Softmax

UFO-ViT是一种新型的线性复杂度VisionTransformer,通过消除自注意力机制中的softmax非线性和采用XNorm约束,提高了模型的效率和性能。该模型在图像分类和目标检测等任务上表现出色,同时减少了对GPU内存的需求和计算资源的消耗。XNorm防止了模型对初始化的依赖,并允许线性计算,适合密集预测任务。
摘要由CSDN通过智能技术生成

paper链接: https://arxiv.org/pdf/2109.14382.pdf

(一)、引言

本文提出了单元强制操作Vision Transformer(UFO-ViT),这是一种具有线性复杂度的新型SA机制。这项工作的主要方法是从原来的SA。我们分解了SA机构的矩阵乘法消除非线性,没有复杂的线性逼近。仅修改原始SA中的几行代码,所提出的模型在大多数图像分类和密集预测任务上优于基于Transformer的模型。
原始的自注意(SA)机制尽管取得了巨大的成功,但由于 σ ( Q K T ) ∈ R N × N σ(QK^T)∈R^{N×N} σ(QKT)RN×N和V的矩阵乘法,其时间和计算复杂度为 O ( n 2 ) O(n^2) O(n2)。这是传统Transformer的缺点之一。对于视觉任务,N与输入分辨率成正比。这意味着如果输入图像的宽度和高度加倍,SA将消耗16倍的计算资源。
1、提出了一种新的约束方案XNorm,它生成一个单元来提取关系特征。该方案可以防止SA依赖于初始化。此外,通过替换softmax函数,消除了SA的非线性。
2、经验表明,UFO-ViT模型具有更快的推理速度和更少的GPU内存需求。对于不同的分辨率,所需的计算资源并没有显著增加。此外,模型中使用的权重与分辨率无关。这对于密集的预测任务(如目标检测和语义分割)是一个有用的特征。大多数密集预测任务需要比预训练阶段更高的分辨率,即基于mlp的结构需要额外的后处理以适应各种分辨率。

(二)、实现细节

本文模型结构如下所示。它混合了卷积层、UFO模块和简单的前馈MLP层。
在这里插入图片描述

(一)、基础结构

对于输入 x ∈ R N × C x∈R^{N×C} xRN×C,传统SA机制表述如下:
在这里插入图片描述
其中A表示注意算子。如果消除了softmax的非线性, σ ( Q K T ) V σ(QK^T)V σ(QKT)V可分解为 O ( N × h + h × N ) O(N × h + h × N) O(N×h+h×N)。本文使用XNorm代替softmax,它允许SA模块首先计算 K T V K^TV KTV
XNorm的定义如下:
在这里插入图片描述
其中 γ γ γ是一个可学习的参数,h是嵌入维数。它是一个常见的l2范数,但它是沿着两个维度应用的: K T V K^TV KTV的空间维度和q的通道维度。因此,它被称为交叉归一化。
使用结合律,键和值首先相乘,然后查询相乘。下图描述了这一点。这两个乘法运算的复杂度都是 O ( h N d ) O(hNd) O(hNd),所以这个过程对N是线性的。
在这里插入图片描述

(二)、XNorm

在XNorm中,自注意力的键和值直接相乘。通过线性核的方法生成h个聚类:
在这里插入图片描述
XNorm应用于查询和输出。
在这里插入图片描述
其中x表示输入。最后,投影权重使用加权和缩放和聚集点积项。
在这里插入图片描述
在这个公式中,关系特征是由嵌入块和簇之间的余弦相似度定义的。XNorm将查询和聚类中的每个像素的特征限制为单位向量。这可以防止它们的值通过将它们正则化为有限的长度来抑制关系属性。如果它们具有任意值,则注意区域依赖于初始化。
残差连接,任意一个模块的输出公式如下:
在这里插入图片描述
其中n和x分别表示当前层和输入图像的索引。假设x为某物体的位移,n为时间,则上式可以重新定义为:
在这里插入图片描述
大多数神经网络是离散的,因此∆t是常数。(为简便起见,设∆t = 1。)残差项表示速度,所以当粒子有单位质量且∆t = 1时,这一项表示权重项。
在物理学中,胡克定律被定义为弹性向量k和位移向量x的点积。弹性力产生谐波势U,是x2的函数。物理上,势能会干扰粒子运动的路径。(想象一个球在抛物线轨道上运动。)
在这里插入图片描述
以上公式一般用于近似分子在x≈0处的势能。对于多个分子,可以利用弹性线性度:
在这里插入图片描述
XNorm不是规范化,而是约束。这就是为什么被称为单位权重操作,或简称为UFO。

(三)、实验

(一)、消融实验

在这里插入图片描述
大多数其他归一化方法都无法训练。有趣的是,单一l2范数的应用也表现出较差的性能。所有结果见表3。
在这里插入图片描述
在这里插入图片描述

(二)、目标检测

在这里插入图片描述
代码:

import numpy as np
import torch
from torch import nn
from torch.functional import norm
from torch.nn import init


def XNorm(x,gamma):
    norm_tensor=torch.norm(x,2,-1,True)
    return x*gamma/norm_tensor


class UFOAttention(nn.Module):
    '''
    Scaled dot-product attention
    '''

    def __init__(self, d_model, d_k, d_v, h,dropout=.1):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''
        super(UFOAttention, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout=nn.Dropout(dropout)
        self.gamma=nn.Parameter(torch.randn((1,h,1,1)))

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values):
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        kv=torch.matmul(k, v) #bs,h,c,c
        kv_norm=XNorm(kv,self.gamma) #bs,h,c,c
        q_norm=XNorm(q,self.gamma) #bs,h,n,c
        out=torch.matmul(q_norm,kv_norm).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)

        
        return out


if __name__ == '__main__':
    input=torch.randn(50,49,512)
    ufo = UFOAttention(d_model=512, d_k=512, d_v=512, h=8)
    output=ufo(input,input,input)
    print(output.shape)
`
``

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值