【科研基础】前沿AI方法(VIT、ConvMixer等)

关于前沿AI方法transformer、VIT和Swin T的总结

[1].Attention Is All You Need
[2].https://zhuanlan.zhihu.com/p/366592542
[3].代码实现:https://zhuanlan.zhihu.com/p/653170203
[4].An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
[5] Swin Transformer 时间复杂度的分析 https://blog.csdn.net/weixin_45943887/article/details/127881179
[6]. ConvMixer, Trockman, Asher, and J. Zico Kolter. “Patches are all you need?.” arXiv preprint arXiv:2201.09792 (2022).
CSDN: https://blog.csdn.net/baidu_36913330/article/details/120655407
[7]. Yadav, Saurabh, and Koteswar Rao Jerripothula. “FCCNs: Fully Complex-valued Convolutional Networks using Complex-valued Color Model and Loss Function.” Proceedings of the IEEE/CVF International Conference on Computer Vision. 2023.

1.transformer

1.1.注意力机制

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.[1]
输入是query和 key-value,注意力机制首先计算query与每个key的关联性(compatibility)每个关联性作为每个value的权重(weight),各个权重与value的乘积相加得到输出

Attention Is All You Need 中用到的attention叫做“Scaled Dot-Product Attention”,具体过程如下图所示:
在这里插入图片描述
代码实现:

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs  to  be div by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # the number of training examples
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim)
        # keys shape: (N, key_len, heads, heads_dim)
        # energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
            # Fills elements of self tensor with value where mask is True

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # after einsum (N, query_len, heads, head_dim) then flatten last two dimensions

        out = self.fc_out(out)
        return out

1.为什么有mask?
NLP处理不定长文本需要padding,但是padding的内容无意义,所以处理时需要mask.
2.关于qkv
qkv是相同的,需要查询的q,与每一个key相乘得到权重信息,权重与v相乘,这样结果受权重大的v影响
3.为什么除以根号dk

We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients 4. To counteract this effect, we scale the dot products by 1 √dk
点积过大,经过softmax,进入饱和区,梯度很小

4.为什么需要多头
在这里插入图片描述
不同头部的output就是从不同层面(representation subspace)考虑关联性而得到的输出。
4.复杂度分析
在这里插入图片描述

1.2.TransformerBlock

解码端的后面两部分和编码段一样,所以打包成一个类
在这里插入图片描述

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

1.3.Encoder

关键的就是位置编码

class Encoder(nn.Module):
    def __init__(self,
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_length
                 ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_lengh = x.shape
        positions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

2.VIT

在这里插入图片描述
序列化
原始图像尺寸:3×224×224
patch size:16×16(×3)
token数:(224/16)^2 = 14 ^ 2 =196
所以最终序列: (196,768)

在Swin T中,
patch size:4×4(×3)
token数:(224/4)^2 = 56 ^ 2 =3136
所以最终序列: (3136,48)

但是Swin T引入了window的概念,是在window中作self-attention,window size = 7,所以一共有56/7 = 8个窗口,每个窗口有7×7个patch, 每个patch是4×4个像素。
之后Swin T在窗口里做self-attention,序列长度从3136降为49

3.Swin T

在这里插入图片描述
Patch partition和Linear Embedding
Patch partition和Linear Embedding是以此卷积就完成的,Patch partition将原始图片分块,Linear Embedding将分块后的48投射到指定维度

x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
self.proj = nn.Conv2d(in_chans=3, embed_dim=256, kernel_size=patch_size(2,2), stride=patch_size)(2,2)
x = self.proj(x) #(32,3,32,32) --> (32,256,16,16) if patch_size = 4, (32,256,8,8)
x = x.flatten(2)#(32,256,256)
x = x.transpose(1,2)#(32,256,256) (32,64,256) (B,N,C)
 
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 
# (3,128,8,64,32)
q, k, v = qkv[0], qkv[1], qkv[2]  
# (128, 8,         64,32)
# (B_,  num_heads, N, C // self.num_heads)

patch merge的含义和作用?
上采样,获得多尺度的信息,本质是用空间的维度的减少,换取更多的channel维度。
首先合并patch(并不是相邻的patch合并,而是隔一个),这样HW都减半,C变成4C。为了和ResNet中的池化对应(HW都减半,C变成2C),patch合并后又加了一个1×1的卷积将4C变为2C

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

移动窗口
提出窗口做self-attention,但是窗口间没有通信无法全局建模(没有上下文信息),所以移动窗口。
所以每个基本的Swin T block都包括两次多头self-attention, 一次是W-MSA, 一次是SW-MSA

4.ConvMixer

  • 24
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

|7_7|

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值