(PyTorch Advanced Path 4) Vision Transformer

Overall idea

Vision Transformer is an encoder-only architecture: it mainly uses the Transformer encoder module.

The idea behind ViT is to apply the Transformer to image recognition.

Applying it directly runs into a problem, though. In NLP the basic unit is the word, and a sentence contains relatively few of them, but the basic unit of an image is the pixel, and there are far more pixels than words (a 224×224 image already has 50,176 of them); since self-attention scales quadratically with sequence length, the computation would be enormous.

A single pixel carries little information on its own; in an image, the information is mostly concentrated in small local regions.

The natural idea is therefore to group pixels into patches: split the image into many patches and feed each patch into the Transformer as one token.

There are two ways to think about the patches. One is the DNN view: cut the image into patches, then turn each patch into a vector with an affine transformation (a linear layer).

The other is the CNN view: run a convolution with stride = kernel size, then flatten the output feature map into token vectors.

To support classification, ViT borrows the class token embedding placeholder from BERT.

ViT also uses position embeddings; the paper's comparison found that learnable 1-D embeddings work slightly better.

The final state of this class token is what gets used for classification.

Split the image into fixed-size patches, map them to embeddings with a linear layer, add position embeddings, and feed the resulting sequence of vectors into a standard Transformer.
The image size may vary, but the patch size stays fixed; a larger image simply means a longer sequence.
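
For example, the sequence length is just the number of patches plus the CLS token (values below are the paper's settings):

# seq_len = (ih // patch_size) * (iw // patch_size) + 1 (the CLS token)
ih, iw, patch_size = 224, 224, 16                     # base pre-training setting
print((ih // patch_size) * (iw // patch_size) + 1)    # 197
ih, iw = 384, 384                                     # fine-tuning at higher resolution
print((ih // patch_size) * (iw // patch_size) + 1)    # 577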

At the last encoder layer, take the output state of this added token, pass it through an MLP to get a probability distribution over classes, and compute the classification loss with cross-entropy.

Paper

https://openreview.net/pdf?id=YicbFdNTTy

Patch embedding

The first step is turning the image into embeddings. There are two ways to do it: a DNN or a CNN.

First split the image into patches. This is exactly the sliding-window step of a convolution, so we can use F.unfold from torch.nn.functional to extract all convolution regions at once, which is precisely the image-to-patch process.

The size of one flattened patch = ph * pw * ic (patch height × patch width × input channels).

The unfold result has shape [bs, patch_depth, num_patches].

Transpose the unfold result, then multiply by the DNN weight matrix to get the embeddings.

DNN embedding:

import torch.nn.functional as F
import torch as t

def image2emb_naive(image, patch_size, weight):
    """
    image: [bs, ic, ih, iw]
    weight: DNN weight, [patch_depth, model_dim]
    """
    # unfold: [bs, patch_depth, num_patches] -> transpose: [bs, num_patches, patch_depth]
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    # [bs, num_patches, patch_depth] @ [patch_depth, model_dim] -> [bs, num_patches, model_dim]
    patch_embedding = patch @ weight

    print(patch.shape, "# patch.shape")
    print(patch_embedding.shape, "# patch_embedding.shape")

    return patch_embedding


def test_image2emb_naive():
    bs, ic, ih, iw = 1, 3, 8, 8
    patch_size = 4
    model_dim = 8
    patch_depth = patch_size * patch_size * ic
    image = t.randn(bs, ic, ih, iw)
    weight = t.randn(patch_depth, model_dim)
    image2emb_naive(image, patch_size, weight)


test_image2emb_naive()

The key to building the embedding with a CNN is constructing the kernel; model_dim is simply the number of output channels oc:

kernel shape = [oc, ic, kh, kw] = [model_dim, ic, patch_size, patch_size]

stride = patch_size

def image2emb_conv(image, kernel, stride):
    # conv_output: [bs, oc, oh, ow]
    conv_output = F.conv2d(image, kernel, stride=stride)
    bs, oc, oh, ow = conv_output.shape
    # flatten the spatial map and move tokens to dim 1: [bs, oh*ow, oc]
    embedding = conv_output.reshape((bs, oc, oh*ow)).transpose(-1, -2)
    return embedding


def test_image2emb_conv():
    bs, ic, ih, iw = 1, 3, 8, 8
    patch_size = 4
    model_dim = 8
    image = t.randn(bs, ic, ih, iw)
    kernel = t.randn(model_dim, ic, patch_size, patch_size)
    patch_embedding = image2emb_conv(image, kernel, stride=patch_size)
    print(patch_embedding.shape)


test_image2emb_conv()
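
The DNN and CNN paths are really the same computation: flattening the conv kernel into the [patch_depth, model_dim] weight matrix should make the two functions agree. A quick check, reusing the two functions above (the flattening order relies on unfold's channel-major layout, which matches conv2d's):

def test_embedding_equivalence():
    bs, ic, ih, iw = 1, 3, 8, 8
    patch_size, model_dim = 4, 8
    image = t.randn(bs, ic, ih, iw)
    kernel = t.randn(model_dim, ic, patch_size, patch_size)
    # [oc, ic, kh, kw] -> [oc, ic*kh*kw] -> transpose -> [patch_depth, model_dim]
    weight = kernel.reshape(model_dim, -1).T
    emb_conv = image2emb_conv(image, kernel, stride=patch_size)
    emb_naive = image2emb_naive(image, patch_size, weight)
    print(t.allclose(emb_conv, emb_naive, atol=1e-5))  # True


test_embedding_equivalence()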

Class token

Prepend a CLS token to the sequence; it will be used for the classification task.

def append_cls_token(patch_embedding):
    bs, _, model_dim = patch_embedding.shape
    cls_token_embedding = t.randn(bs, 1, model_dim, requires_grad=True)
    # put the CLS token at the first position
    token_embedding = t.cat([cls_token_embedding, patch_embedding], dim=1)
    return token_embedding
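
Note that t.randn(..., requires_grad=True) inside the function re-creates the token on every call, so nothing would actually be learned across training steps; in a real model the CLS token is an nn.Parameter registered on a module and broadcast to the batch. A minimal sketch (the ClsToken wrapper is illustrative):

import torch as t
import torch.nn as nn

class ClsToken(nn.Module):
    def __init__(self, model_dim):
        super().__init__()
        # one learnable token, shared across the whole batch
        self.cls_token = nn.Parameter(t.randn(1, 1, model_dim))

    def forward(self, patch_embedding):
        bs = patch_embedding.shape[0]
        cls = self.cls_token.expand(bs, -1, -1)  # broadcast view, no copy
        return t.cat([cls, patch_embedding], dim=1)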

Position embedding

The position embedding here works like a word embedding: a lookup table provides the embeddings.

t.tile replicates the table, making bs copies along the batch dimension.

def append_position_embedding(max_num_token, token_embedding):
    bs, seq_len, model_dim = token_embedding.shape
    # table shape = [max_num_token, model_dim]
    position_embedding_table = t.randn(max_num_token, model_dim, requires_grad=True)
    # slice the first seq_len rows, repeat bs times: [bs, seq_len, model_dim]
    position_embedding = t.tile(position_embedding_table[:seq_len], [bs, 1, 1])
    token_embedding += position_embedding
    return token_embedding
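
A quick usage check with illustrative sizes: the same table serves any sequence length up to max_num_token, since only the first seq_len rows are sliced out (the nn.Parameter caveat for the CLS token applies to this table as well):

emb = t.randn(1, 5, 8)                     # [bs, seq_len, model_dim]
out = append_position_embedding(16, emb)   # slices the first 5 table rows
print(out.shape)                           # torch.Size([1, 5, 8])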

Encoder

Just call the nn API directly:

import torch.nn as nn

def pass_embedding_to_encoder(token_embedding):
    bs, seq_len, model_dim = token_embedding.shape
    # batch_first=True keeps the [bs, seq_len, model_dim] layout throughout
    encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
    encoder_output = encoder(token_embedding)
    return encoder_output
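
A quick shape check with illustrative sizes; with batch_first=True the output keeps the [bs, seq_len, model_dim] layout:

emb = t.randn(2, 5, 8)                     # [bs, seq_len, model_dim]
out = pass_embedding_to_encoder(emb)
print(out.shape)                           # torch.Size([2, 5, 8])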

Classification MLP

Finally, take the CLS token's output state and run the classification:

def do_classification(encoder_output, num_class, model_dim, label):
    # label = t.randint(10, (bs,))
    cls_token_output = encoder_output[:, 0, :]  # state of the CLS token
    linear_layer = nn.Linear(model_dim, num_class)
    logits = linear_layer(cls_token_output)
    loss_fn = nn.CrossEntropyLoss()
    loss = loss_fn(logits, label)
    return loss

Full code

import torch.nn.functional as F
import torch as t
import torch.nn as nn

def image2emb_conv(image, kernel, stride):
    # conv_output: [bs, oc, oh, ow]
    conv_output = F.conv2d(image, kernel, stride=stride)
    bs, oc, oh, ow = conv_output.shape
    embedding = conv_output.reshape((bs, oc, oh*ow)).transpose(-1, -2)
    return embedding

def append_cls_token(patch_embedding):
    bs, _, model_dim = patch_embedding.shape
    cls_token_embedding = t.randn(bs, 1, model_dim, requires_grad=True)
    # put the CLS token at the first position
    token_embedding = t.cat([cls_token_embedding, patch_embedding], dim=1)
    return token_embedding


def append_position_embedding(max_num_token, token_embedding):
    """
    max_num_token:序列最大长度
    """
    bs, seq_len, model_dim = token_embedding.shape
    # shape = [vocab_size, model_dim]
    position_embedding_table = t.randn(max_num_token, model_dim, requires_grad=True)
    position_embedding = t.tile(position_embedding_table[:seq_len], [bs, 1, 1])
    token_embedding += position_embedding
    return token_embedding


def pass_embedding_to_encoder(token_embedding):
    bs, seq_len, model_dim = token_embedding.shape
    encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
    encoder_output = encoder(token_embedding)
    return encoder_output


def do_classification(encoder_output, num_class, model_dim, label):
    # label = t.randint(10,(bs,))
    cls_token_output = encoder_output[:, 0, :]
    linear_layer = nn.Linear(model_dim, num_class)
    logits = linear_layer(cls_token_output)
    loss_fn = nn.CrossEntropyLoss()
    loss = loss_fn(logits, label)
    return loss

def test_full():
    bs, ic, ih, iw = 1, 3, 8, 8
    patch_size = 4
    model_dim = 8
    max_num_token = 16
    num_class = 10
    label = t.randint(10,(bs,))
    image = t.randn(bs, ic, ih, iw)
    kernel = t.randn(model_dim, ic, patch_size, patch_size)
    patch_embedding = image2emb_conv(image, kernel, stride=patch_size)
    token_embedding = append_position_embedding(max_num_token, append_cls_token(patch_embedding))
    encoder_output = pass_embedding_to_encoder(token_embedding)
    loss = do_classification(encoder_output, num_class, model_dim, label)
    print(loss)


test_full()
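
For actual training, the pieces above would normally live in a single nn.Module so that the conv kernel, CLS token, position table, and classification head are all registered parameters optimized together. A minimal sketch (the SimpleViT class and its defaults are illustrative, mirroring the hyperparameters of test_full):

import torch as t
import torch.nn as nn

class SimpleViT(nn.Module):
    def __init__(self, ic=3, patch_size=4, model_dim=8, max_num_token=16,
                 num_class=10, nhead=8, num_layers=6):
        super().__init__()
        # patch embedding as a convolution with stride = kernel size
        self.proj = nn.Conv2d(ic, model_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(t.randn(1, 1, model_dim))
        self.pos_table = nn.Parameter(t.randn(max_num_token, model_dim))
        layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=nhead,
                                           batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(model_dim, num_class)

    def forward(self, image):
        bs = image.shape[0]
        x = self.proj(image)                      # [bs, model_dim, oh, ow]
        x = x.flatten(2).transpose(-1, -2)        # [bs, num_patches, model_dim]
        cls = self.cls_token.expand(bs, -1, -1)   # shared learnable CLS token
        x = t.cat([cls, x], dim=1)                # prepend CLS
        x = x + self.pos_table[: x.shape[1]]      # add position embedding
        x = self.encoder(x)
        return self.head(x[:, 0, :])              # logits from the CLS state


model = SimpleViT()
logits = model(t.randn(1, 3, 8, 8))
print(logits.shape)  # torch.Size([1, 10])

With everything registered as parameters, a standard optimizer such as torch.optim.Adam(model.parameters()) trains the whole pipeline end to end.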