A Simple Implementation of the Vision Transformer

1 The Basic Structure of the Transformer

A Transformer consists of an encoding stage and a decoding stage. In NLP, the Transformer embeds words (tokens); in the Vision Transformer, the image is split into patches, each patch is flattened, and the flattened patches are embedded.

Embedding: a mapping from a low-dimensional space to a high-dimensional one, or the other way around.
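To make the patch idea concrete, here is a minimal sketch (separate from the Conv2d-based implementation in the next section) that splits a 28x28 single-channel image into 7x7 patches and flattens them; the image size and patch size are chosen only to match the example below:

import torch

img = torch.randn(1, 1, 28, 28)                    # [B, C, H, W]
patches = img.unfold(2, 7, 7).unfold(3, 7, 7)      # [1, 1, 4, 4, 7, 7]: a 4x4 grid of 7x7 patches
patches = patches.contiguous().view(1, -1, 7 * 7)  # [1, 16, 49]: 16 patches, each flattened to 49 pixels
print(patches.shape)                               # torch.Size([1, 16, 49])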

2 A Simple Patch Embedding Implementation

import torch
import torch.nn as nn

class MLP(nn.Module):
    # Feed-forward block: expand to mlp_ratio * embed_dim, then project back to embed_dim
    def __init__(self, embed_dim, mlp_ratio=4.0, drop_out=0):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, int(embed_dim*mlp_ratio))
        self.fc2 = nn.Linear(int(embed_dim*mlp_ratio), embed_dim)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(drop_out)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.act(x)
        return x
    
class patch_embedding(nn.Module):
    def __init__(self, patch_size, in_channels, embed_dim, drop_out=0):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
        self.dropout = nn.Dropout(drop_out)
    def forward(self, x):
        # x:[1,1,28,28] -> x:[1, embed_dim, 28/patch_size, 28/patch_size]
        x = self.patch_embed(x)
        x = x.flatten(2) # x: [1, embed_dim, (28/patch_size) * (28/patch_size)]
        x = x.transpose(2,1) # x: [1, num_patches, embed_dim]
        x = self.dropout(x)
        return x
    
# Simulate reading a 28x28 single-channel image (random data stands in for a real image)
input = torch.randn((28, 28, 1))
print(input.shape)
input = input.reshape((1, 1, 28, 28))

# patch embedding
patch_embed = patch_embedding(patch_size=7, in_channels=1, embed_dim=8)
out = patch_embed(input)
print(out.size())

# MLP
mlp = MLP(8)
out = mlp(out)
print(out.size())
# Output:
# torch.Size([28, 28, 1])
# torch.Size([1, 16, 8])
# torch.Size([1, 16, 8])

With patch_size=7 on a 28x28 input, the convolution produces a 4x4 grid of patches, i.e. 16 tokens, each embedded into 8 dimensions, which gives the [1, 16, 8] shape; the MLP projects back to embed_dim, so the shape is unchanged.
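The strided convolution above is just a compact way of writing "flatten each patch, then apply a linear projection". A minimal check of this equivalence, reusing the `input` and `patch_embed` defined above (the unfold-based reshaping is only an illustration, not part of the original code):

# Flatten each 7x7 patch by hand: [1, 1, 28, 28] -> [1, 16, 49]
patches = input.unfold(2, 7, 7).unfold(3, 7, 7).contiguous().view(1, 16, -1)
# Project with the same weights the convolution uses: [1, 16, 49] @ [49, 8] -> [1, 16, 8]
linear_out = patches @ patch_embed.patch_embed.weight.view(8, -1).t()
print(torch.allclose(patch_embed(input), linear_out, atol=1e-5))  # True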

3 The Attention Mechanism

An attention mechanism that relates different positions of a single sequence in order to compute a representation of that sequence ("Attention Is All You Need").

For a more detailed walkthrough, the article 详解Transformer is recommended; here only the points that tend to raise questions are explained.

The attention computation can be broken into 7 steps (a runnable sketch of steps 2-7 follows this list):
1. Convert the input tokens into embedding vectors;
2. From each embedding vector derive the three vectors q, k and v;
3. Compute a score for each vector: $score = q \cdot k$;
4. For gradient stability, the Transformer normalizes the score by dividing it by $\sqrt{d_k}$;
5. Apply the softmax function to the scores;
6. Multiply the softmax weights by the value vectors v, giving a weighted version of each input vector;
7. Sum the weighted vectors to obtain the final output: $z = \sum v$.
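A minimal, self-contained sketch of steps 2-7 for a single attention head; the sequence length 16, the dimension d_k = 8, and the random projection matrices are arbitrary choices for illustration only:

import torch
import torch.nn.functional as F

# Toy setup: a sequence of 16 token embeddings of dimension 8
d_k = 8
x = torch.randn(1, 16, d_k)

# Step 2: derive q, k, v (random projection matrices stand in for learned ones)
Wq, Wk, Wv = (torch.randn(d_k, d_k) for _ in range(3))
q, k, v = x @ Wq, x @ Wk, x @ Wv

# Steps 3-4: dot-product scores, scaled by sqrt(d_k) for gradient stability
scores = q @ k.transpose(-2, -1) / d_k ** 0.5  # [1, 16, 16]

# Step 5: softmax over the key dimension gives the attention weights
attn_weight = F.softmax(scores, dim=-1)

# Steps 6-7: weight the value vectors and sum them to get the output z
z = attn_weight @ v                            # [1, 16, 8]
print(z.shape)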

Multiplying q with each k gives the score s, which is already close to the attention; scaling it and applying softmax yields the attention weights:

$\mathrm{Attention\ weight}(Q,K)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$

Why scale by $\sqrt{d_k}$ ($d_k$ is the length of q or k)?
Variance (var) measures how much a sequence fluctuates. The larger the variance of the scores, the more the softmax concentrates on the largest value. Assume every entry of the sequences (features) Q and K is an i.i.d. random variable with mean 0 and std 1; then each entry of $QK^T$ has variance $d_k$, so we divide by $\sqrt{d_k}$ to bring the variance back to 1.0. A quick numerical check is sketched below.
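A minimal check of this claim, assuming standard-normal i.i.d. entries; the dimension d_k = 64 and the sample size 10000 are arbitrary choices:

import torch

d_k = 64
Q = torch.randn(10000, d_k)   # i.i.d. entries, mean 0, std 1
K = torch.randn(10000, d_k)

scores = (Q * K).sum(dim=-1)         # one q.k dot product per row
print(scores.var())                  # close to d_k (about 64)
print((scores / d_k ** 0.5).var())   # close to 1.0 after scaling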

Putting the scaling, softmax, and value weighting together gives the full attention formula:

$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
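For reference, recent PyTorch releases (2.0 and later) expose this same computation as torch.nn.functional.scaled_dot_product_attention; a minimal comparison against the manual formula, with arbitrary toy shapes:

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 16, 8)  # [batch, heads, seq_len, d_k]
k = torch.randn(1, 2, 16, 8)
v = torch.randn(1, 2, 16, 8)

manual = F.softmax(q @ k.transpose(-2, -1) / 8 ** 0.5, dim=-1) @ v
builtin = F.scaled_dot_product_attention(q, k, v)   # PyTorch 2.0+
print(torch.allclose(manual, builtin, atol=1e-5))   # True (up to floating-point error)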

Multi-head self-attention: several attention heads contribute to the decision together; their outputs are concatenated and passed through a projection matrix W (which learns how trustworthy each head's output is, i.e. which head matters more) to map back to the output dimension.

4 A Simple Attention Implementation

class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads, qkv_bias=False, qk_scale=None):
        super().__init__()
        self.num_head = num_heads
        # Simplification: each head keeps the full embedding dimension
        # (a standard ViT would use head_dim = embed_dim // num_heads)
        self.head_dim = embed_dim
        self.all_head_dim = int(self.head_dim * num_heads)
        self.qkv = nn.Linear(embed_dim, self.all_head_dim * 3, bias=qkv_bias)
        self.scale = self.head_dim ** -0.5 if qk_scale is None else qk_scale
        self.softmax = nn.Softmax(-1)
        self.proj = nn.Linear(self.all_head_dim, embed_dim)

    def transpose_multi_head(self, x):
        # [B, N, all_head_dim] -> [B, num_head, N, head_dim]
        B, N, _ = x.shape
        x = x.view(B, N, self.num_head, self.head_dim)
        x = x.transpose(1, 2)
        return x

    def forward(self, x):
        B, N, _ = x.shape
        # Generate Q, K, V
        qkv = self.qkv(x).chunk(3, -1) # [B, N, all_head_dim] * 3
        q, k, v = map(self.transpose_multi_head, qkv) # each [B, num_head, N, head_dim]

        # Compute the attention weights
        attn = q @ k.transpose(2, 3)
        attn_weight = self.softmax(attn * self.scale)

        # Weight the value vectors
        out = attn_weight @ v
        out = out.transpose(1, 2).contiguous().view(B, N, -1)
        # Merge the heads back to embed_dim
        out = self.proj(out)

        return out
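A quick shape check, feeding a tensor shaped like the [1, 16, 8] patch embeddings from section 2 through this module; num_heads=2 is an arbitrary choice for illustration:

attn = Attention(embed_dim=8, num_heads=2)
out = attn(torch.randn(1, 16, 8))
print(out.size())  # torch.Size([1, 16, 8])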

5 A Simple Implementation of the Full ViT

class Encoder(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.att = Attention(embed_dim, num_heads=2)
        self.att_norm = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim)
        self.mlp_norm = nn.LayerNorm(embed_dim)
    def forward(self, x):
        # attention block with a residual connection
        h = x
        x = self.att(x)
        x = self.att_norm(x)
        x = h + x

        # MLP block with a residual connection
        h = x
        x = self.mlp(x)
        x = self.mlp_norm(x)
        x = h + x
        return x

class ViT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = patch_embedding(7, 1, 16)
        layer_list = [Encoder(16) for i in range(5)]
        self.encoder = nn.ModuleList(layer_list)
        self.avgpool = nn.AdaptiveAvgPool1d(1)  # average over the patch dimension
        self.head = nn.Linear(16, 10)           # classification head for 10 classes
       
    def forward(self, x):
        x = self.patch_embed(x)
        for encoder in self.encoder:
            x = encoder(x)
        x = x.transpose(2, 1)  # [B, embed_dim, num_patches]
        x = self.avgpool(x)    # [B, embed_dim, 1]
        x = x.flatten(1)       # [B, embed_dim]
        x = self.head(x)       # [B, num_classes]

        return x


# Simulate a batch of 4 single-channel 224x224 images (random data stands in for real images)
input = torch.randn((224, 224, 4))
print(input.shape)
input = input.reshape((4, 1, 224, 224))
out = ViT().cuda()(input.cuda())  # requires a CUDA-capable GPU
print(out.size())
# Output:
# torch.Size([224, 224, 4])
# torch.Size([4, 10])

References
Vision Transformer打卡营
