Vision transformer(ViT)输入后维度变化情况 + 代码实现,如何理解ViT?

11.1 Vision Transformer(vit)网络详解_哔哩哔哩_bilibili

文章地址:Vision Transformer详解_霹雳吧啦Wz-CSDN博客

一句话来解释vit:

将图像的这些东西(C,H,W)变为NLP处理的方式(T, D),所以就消除了CV中的通道C,宽高H\W的概念。

从而使C就是D,H和W变为了T。那么原来对CV的某些实现,例如IN, LN..就可以通过这种方式进行替换套用

整体运行:

其中关键的图:

vit-b/16

图像的维度变化:

1、输入 -->(N,C,H,W)

(N,C,H,W)维度的图像(4,3,32,32), N代表批次

2、经过timm.PatchEmbed() --> (N, T, D)

PatchEmbed()的patch size 为 (8, 8),input_dim设置为1024:

使用 patch size 为 (8, 8) 将输入图像划分为图像块,得到(32//8)*(32//8)=16个图像块,对应image_sequence_length,后面的维度就是定义的input_dim,所以此时的维度是(4,16,1024)

添加 Class Token,将维度变为(4,17,1024)

此时这里你也可以加入位置编码,维度不会发生变化(4,17,1024)

self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, input_dim)

3、多头自注意力机制(Multi-Head Self-Attention)--> (N, T, D)

如果默认的头数是 h(通常为 12),并且头维度(head dim)是 d_h(通常为 64),则每个头产生的输出维度为(4,17,64)。embed_dim对应d_h*h = 768

对patch后的token进行线性变换得到 Q, K, V --> (N,T,d_h*h​)

对输入(N, T, D)进行三个线性变换,得到Q,K,V,每个变换的权重矩阵维度为 (d_model​, d_h*h​),即(1024->12x64=768)。h为头的个数

Q 的维度为 (4,17,d_h*h​),K 的维度为 (4,17,d_h*h​),V 的维度为 (4,17,d_h*h​)。

self.W_q = nn.Linear(input_dim, embed_dim)
self.W_k = nn.Linear(input_dim, embed_dim)
self.W_v = nn.Linear(input_dim, embed_dim)

将Q, K, V拆分为多头:--> (N,T,h,d_h​)--> (N,h,T,d_h​)

将得到的 Q,K,V 分别拆分为 h 个头,每个头的维度为 (4,17,d_h)。

这里的拆分是在最后一维上进行的。

# 应用线性变换生成Q, K, V
Q = self.W_q(x).view(batch_size, image_sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
# 实际为(batch_size, self.num_heads, image_sequence_length, self.head_dim)
K = self.W_k(x).view(batch_size, image_sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
V = self.W_v(x).view(batch_size, image_sequence_length, self.num_heads, self.head_dim).transpose(1, 2)

对每个头进行注意力计算:--> (N,h,T,h_d​)

对 Q 和 K 进行点积,然后进行缩放操作,最后应用 softmax 函数。

将注意力分数乘以 V,得到每个头的注意力输出(4,17,d_h)。

# 得到(b,num_heads,image_sequence_length,image_sequence_length)
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

# 应用Softmax获取注意力权重
# 计算块与块之间的注意力
attn_weights = torch.softmax(attn_scores, dim=-1)

# 使用注意力权重加权V
attn_output = torch.matmul(attn_weights, V)

合并多头输出:--> (N,T,D​)

将每个头的注意力输出按最后一维度进行连接(concatenate),得到多头注意力的输出。多头注意力的输出的维度为 (4,17,768)。这里768=h*d_h

映射为(4,17,1024),加残差,维度不变(4,17,1024)

attn_output = attn_output.view(batch_size, image_sequence_length, self.embed_dim)

4、LayerNormalization--> (N, T, D)

对于每个位置,进行 LayerNormalization 维度不变(4,17,1024)。

class PreNorm(nn.Module):
    def __init__(self,dim,fn):#fn指norm之后的函数 如上图,可以接mha或者mlp
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    
    def forward(self,x):
        return self.fn(self.norm(x))

5、Feedforward 网络 --> (N, T, D)

如果默认的 Feedforward 网络的中间维度(为 2048):
输入维度:(4,17,1024)
变化为:(4,17,2048)
输出维度:(4,17,1024)
加残差,维度不变(4,17,1024)

#ffn也就是mlp
class FeedForward(nn.Module):
    def __init__(self,dim,hidden_dim,drop_out=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim,hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_out),
            nn.Linear(hidden_dim,dim),
            nn.Dropout())
    
    def forward(self,x):
        return self.net(x)

6、MLP head --> (N,  D)

(4,17,1024)全局池化或者只取cls token(4,1024)mlp输出1000类->(4,1000)

保姆级VIT详解输入维度变化以及对应代码实现_vit的输入-CSDN博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Pengsen Ma

太谢谢了

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值