Implementing ViT (Vision Transformer) by Hand

1. ViT

Paper: https://arxiv.org/abs/2010.11929
ViT is a model proposed by the Google team in 2020 that applies the Transformer to image classification. It was not the first paper to use Transformers for vision tasks, but because the model is "simple", works well, and is scalable (the larger the model, the better the results), it became a milestone for Transformers in computer vision and sparked a wave of follow-up research.

ViT splits the input image into multiple patches (16x16), projects each patch into a fixed-length vector, and feeds the sequence into a Transformer; the encoder that follows is identical to the one in the original Transformer. Because the task is image classification, a special token is added to the input sequence, and the output corresponding to that token serves as the final class prediction.

Note that ViT uses only the Transformer encoder.
(1) Patch embedding: suppose the input image is 224x224 and is split into fixed-size 16x16 patches. Each image then yields (224/16) x (224/16) = 196 patches, so the input sequence length is 196, and each patch has dimension 16x16x3 = 768. The linear projection layer has shape 768xN (N = 768), so after the projection the output is still 196x768: 196 tokens, each of dimension 768. A special cls token is also added, giving a final shape of 197x768. At this point, patch embedding has turned a vision problem into a seq2seq problem.
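
A minimal shape walkthrough of this bookkeeping (a throwaway PyTorch snippet, not part of the implementation below; it uses the same Conv2d-as-patchifier trick and the 224/16/768 numbers from the example):
import torch
import torch.nn as nn

img = torch.randn(1, 3, 224, 224)                       # one 224x224 RGB image
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)     # per-patch linear projection
tokens = proj(img).flatten(2).transpose(1, 2)           # (1, 196, 768): (224/16)^2 = 196 patches
cls_token = torch.zeros(1, 1, 768)                      # learnable in the real model
tokens = torch.cat([cls_token, tokens], dim=1)          # prepend cls -> (1, 197, 768)
print(tokens.shape)                                     # torch.Size([1, 197, 768])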

(2) Positional encoding (standard learnable 1D position embeddings): ViT also adds position embeddings. The position embedding can be thought of as a table with N rows, where N equals the input sequence length; each row is a vector whose dimension matches the token embedding dimension (768). Note that the position embedding is summed with the token embedding, not concatenated. After adding it, the shape is still 197x768.
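
For illustration, the position embedding is just a learnable table summed onto the tokens (a tiny sketch with the 197x768 shapes assumed from the example above):
import torch
import torch.nn as nn

tokens = torch.randn(1, 197, 768)                    # cls token + 196 patch tokens
pos_embed = nn.Parameter(torch.randn(1, 197, 768))   # one learnable row per position
tokens = tokens + pos_embed                          # summed, not concatenated
print(tokens.shape)                                  # torch.Size([1, 197, 768])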

(3) LN / multi-head attention / LN: the LN output is still 197x768. In multi-head self-attention, the input is first projected to q, k, and v. With a single head, q, k, and v each have shape 197x768; with 12 heads (768/12 = 64), each of q, k, and v has shape 197x64, and there are 12 such sets of q, k, v. The outputs of the 12 heads are concatenated back together, giving an output of 197x768, which then passes through another LN, keeping the shape at 197x768.
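
A rough sketch of the head-splitting arithmetic (hand-rolled for illustration only; the implementation below delegates all of this to nn.TransformerEncoderLayer):
import torch
import torch.nn as nn

num_heads, head_dim = 12, 64                                   # 12 heads x 64 dims = 768
x = torch.randn(1, 197, 768)
q, k, v = nn.Linear(768, 768 * 3)(x).chunk(3, dim=-1)          # each (1, 197, 768)

def split_heads(t):                                            # 768 -> 12 heads of 64
    b, n, d = t.shape
    return t.view(b, n, num_heads, head_dim).transpose(1, 2)   # (1, 12, 197, 64)

q, k, v = map(split_heads, (q, k, v))
attn = (q @ k.transpose(-2, -1)) / head_dim ** 0.5             # (1, 12, 197, 197)
out = (attn.softmax(dim=-1) @ v).transpose(1, 2)               # (1, 197, 12, 64)
out = out.reshape(1, 197, 768)                                 # concat the 12 heads back to 768
print(out.shape)                                               # torch.Size([1, 197, 768])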

(4) MLP: the dimension is expanded and then projected back: 197x768 is expanded to 197x3072, then reduced back to 197x768.
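
The MLP block itself is just two linear layers around a GELU; a sketch with the 768/3072 sizes from the example:
import torch
import torch.nn as nn

mlp = nn.Sequential(
    nn.Linear(768, 3072),    # expand: 197x768 -> 197x3072
    nn.GELU(),
    nn.Linear(3072, 768),    # project back: 197x3072 -> 197x768
)
print(mlp(torch.randn(1, 197, 768)).shape)   # torch.Size([1, 197, 768])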

2. Implementing ViT in Code

  1. First, split the image into patches, prepend the cls_token, and add the position embedding
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout):
        super(PatchEmbedding, self).__init__()
        # A Conv2d with kernel_size == stride == patch_size splits the image into
        # non-overlapping patches and linearly projects each one to embed_dim.
        self.patcher = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size),
            nn.Flatten(2)
        )
        # Learnable cls token and 1D position embeddings (num_patches + 1 positions).
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)
        self.pos_embed = nn.Parameter(torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        cls_token = self.cls_token.expand(x.size(0), -1, -1)
        x = self.patcher(x).permute(0, 2, 1)   # (B, num_patches, embed_dim)
        # Prepend the cls token so that index 0 is the classification token,
        # matching the x[:, 0, :] readout in the Vit head below.
        x = torch.cat([cls_token, x], dim=1)   # (B, num_patches + 1, embed_dim)
        x = x + self.pos_embed                 # position embedding is added, not concatenated
        x = self.dropout(x)
        return x
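
A quick shape check of PatchEmbedding (a throwaway snippet, using the MNIST-sized hyperparameters from train.py further below: 28x28 grayscale images, patch size 4, embed_dim 16, 49 patches):
patch_embed = PatchEmbedding(in_channels=1, patch_size=4, embed_dim=16, num_patches=49, dropout=0.1)
dummy = torch.randn(8, 1, 28, 28)        # a batch of 8 MNIST-sized images
print(patch_embed(dummy).shape)          # torch.Size([8, 50, 16]) -> 49 patches + 1 cls token
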
  2. Encode the patch sequence with a Transformer encoder
class Vit(nn.Module):
    def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout, num_heads,
                 activation, num_encoders, num_classes):
        super(Vit, self).__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, embed_dim, num_patches, dropout)
        # ViT uses only the Transformer encoder; norm_first=True gives the pre-norm
        # LN -> MHA -> LN -> MLP layout described above.
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout,
                                                   activation=activation, batch_first=True, norm_first=True)
        self.encoder_layer = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)
        # Classification head applied to the cls token.
        self.mlp = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes)
        )

    def forward(self, x):
        x = self.patch_embedding(x)   # (B, num_patches + 1, embed_dim)
        x = self.encoder_layer(x)
        x = self.mlp(x[:, 0, :])      # take the cls token (index 0) as the image representation
        return x
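
And a quick smoke test of the full model (the layer count here is kept small just for the check; the other values match the training script below):
model = Vit(in_channels=1, patch_size=4, embed_dim=16, num_patches=49, dropout=0.1,
            num_heads=8, activation="gelu", num_encoders=4, num_classes=10)
dummy = torch.randn(8, 1, 28, 28)
print(model(dummy).shape)                # torch.Size([8, 10]) -> one logit per class
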
  3. Define the datasets

import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms


class MINISTTrainDataset(Dataset):

    def __init__(self, images, labels, indicies):
        self.images = images
        self.labels = labels
        self.indicies = indicies
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            # MNIST images have a single channel, so one mean/std value each.
            transforms.Normalize([0.5], [0.5])
        ])


    def __len__(self):
        return len(self.images)

    def __getitem__(self, inx):
        image = self.images[inx].reshape((28, 28)).astype(np.uint8)
        label = self.labels[inx]
        index = self.indicies[inx]
        image = self.transform(image)
        return {"image": image, "label": label, "index": index}


class MINISTValDataset(Dataset):

    def __init__(self, images, labels, indicies):
        self.images = images
        self.labels = labels
        self.indicies = indicies
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])


    def __len__(self):
        return len(self.images)

    def __getitem__(self, inx):
        image = self.images[inx].reshape((28, 28)).astype(np.uint8)
        label = self.labels[inx]
        index = self.indicies[inx]
        image = self.transform(image)
        return {"image": image, "label": label, "index": index}


class MINISTSubmissionDataset(Dataset):

    def __init__(self, images, indicies):
        self.images = images
        self.indicies = indicies
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])


    def __len__(self):
        return len(self.images)

    def __getitem__(self, inx):
        image = self.images[inx].reshape((28, 28)).astype(np.uint8)
        index = self.indicies[inx]
        image = self.transform(image)
        return {"image": image, "index": index}
  4. Build the data loaders
import pandas as pd
from sklearn.model_selection import train_test_split
from dataset import MINISTTrainDataset, MINISTValDataset, MINISTSubmissionDataset

import numpy as np
from torch.utils.data import DataLoader


def get_loaders(train_dir, test_dir, submission_dir, batch_size):
    train_df = pd.read_csv(train_dir)
    test_df = pd.read_csv(test_dir)
    submission_df = pd.read_csv(submission_dir)
    train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=36)
    train_dataset = MINISTTrainDataset(train_df.iloc[:, 1:].values.astype(np.uint8),
                                       train_df.iloc[:, 0].values.astype(np.uint8),
                                       train_df.index.values)
    val_dataset = MINISTValDataset(val_df.iloc[:, 1:].values.astype(np.uint8),
                                   val_df.iloc[:, 0].values.astype(np.uint8),
                                   val_df.index.values)
    # test.csv is assumed to contain only the 784 pixel columns (no label),
    # so every column is image data and the row index serves as the image id.
    test_dataset = MINISTSubmissionDataset(test_df.values.astype(np.uint8),
                                           test_df.index.values)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Validation and submission data do not need to be shuffled.
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_dataloader, val_dataloader, test_dataloader


  5. Write train.py
    Dataset: https://pan.baidu.com/s/1jQ3_xGPo9SQg_jIAyIys6w?pwd=1314
import torch
import torch.nn as nn
from torch import optim
import timeit
from tqdm import tqdm
from utils import get_loaders
from model import Vit



# Hyper Parameters
device = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHES = 50
BATCH_SIZE = 16
TRAIN_DF_DIR = r"./dataset/train.csv"
TEST_DF_DIR = r"./dataset/test.csv"
SUBMISSION_DF_DIR = r"./dataset/sample_submission.csv"

# Model Parameters
IN_CHANNELS = 1
IMG_SIZE = 28
PATCH_SIZE = 4
EMBEDDING_DIM = (PATCH_SIZE**2)*IN_CHANNELS
NUM_PATCHERS = (IMG_SIZE // PATCH_SIZE) ** 2
DROPOUT_RATE = 0.001

NUM_HEADS = 8
ACTIVATION = "gelu"
NUM_ENCODERS = 66
NUM_CLASSES = 10

LEARNING_RATE = 1e-4
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)

if __name__=="__main__":
    train_dataloader, val_dataloader, test_dataloader = get_loaders(TRAIN_DF_DIR, TEST_DF_DIR, SUBMISSION_DF_DIR, BATCH_SIZE)

    model = Vit(IN_CHANNELS, PATCH_SIZE, EMBEDDING_DIM, NUM_PATCHERS,
                DROPOUT_RATE, NUM_HEADS, ACTIVATION, NUM_ENCODERS, NUM_CLASSES).to(device)

    print("using device:", device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), betas=ADAM_BETAS, lr=LEARNING_RATE,
                           weight_decay=ADAM_WEIGHT_DECAY)

    start = timeit.default_timer()

    for epoch in range(EPOCHES):
        model.train()
        train_labels = []
        train_preds = []
        train_running_loss = 0

        for idx, image_label in enumerate(tqdm(train_dataloader, position=0, leave=True)):
            img = image_label["image"].float().to(device)
            label = image_label["label"].long().to(device)  # CrossEntropyLoss expects int64 class indices
            y_pred = model(img)
            y_pred_label = torch.argmax(y_pred, dim=1)

            train_labels.extend(label.cpu().detach())
            train_preds.extend(y_pred_label.cpu().detach())

            loss = criterion(y_pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_running_loss += loss.item()

        train_loss = train_running_loss / (idx + 1)

        model.eval()
        val_labels = []
        val_preds = []
        val_running_loss = 0
        with torch.no_grad():
            for idx, image_label in enumerate(tqdm(val_dataloader, position=0, leave=True)):
                img = image_label["image"].float().to(device)
                label = image_label["label"].long().to(device)  # CrossEntropyLoss expects int64 class indices
                y_pred = model(img)
                y_pred_label = torch.argmax(y_pred, dim=1)

                val_labels.extend(label.cpu().detach())
                val_preds.extend(y_pred_label.cpu().detach())

                loss = criterion(y_pred, label)
                val_running_loss += loss.item()

        val_loss = val_running_loss / (idx + 1)

        print("-"*30)
        print(f"Train Loss Epoch {epoch+1}: {train_loss:.4f}")
        print(f"Val Loss Epoch {epoch+1}: {val_loss:.4f}")

        print(f"train Accuracy Epoch {epoch+1}: {sum(1 for x, y in zip (train_preds, train_labels) if x==y) / len(train_labels):.4f}")
        print(f"val Accuracy Epoch {epoch+1}: {sum(1 for x, y in zip (val_preds, val_labels) if x==y) / len(val_labels):.4f}")
        print("-"*30)

