1. ViT
论文地址:https://arxiv.org/abs/2010.11929
ViT是2020年Google团队提出的将Transformer应用在图像分类的模型,虽然不是第一篇将transformer应用在视觉任务的论文,但是因为其模型“简单”且效果好,可扩展性强(scalable,模型越大效果越好),成为了transformer在CV领域应用的里程碑著作,也引爆了后续相关研究
ViT将输入图片分为多个patch(16x16),再将每个patch投影为固定长度的向量送入Transformer,后续encoder的操作和原始Transformer中完全相同。但是因为对图片分类,因此在输入序列中加入一个特殊的token,该token对应的输出即为最后的类别预测
需要注意的是,ViT只使用了Transformer的encoder
(1) patch embedding:例如输入图片大小为224x224,将图片分为固定大小的patch,patch大小为16x16,则每张图像会生成(224×224)/(16×16)=196个patch,即输入序列长度为196,每个patch维度16x16x3=768,线性投射层的维度为768xN (N=768),因此输入通过线性投射层之后的维度依然为196x768,即一共有196个token,每个token的维度是768。这里还需要加上一个特殊字符cls,因此最终的维度是197x768。到目前为止,已经通过patch embedding将一个视觉问题转化为了一个seq2seq问题
(2) positional encoding(standard learnable 1D position embeddings):ViT同样需要加入位置编码,位置编码可以理解为一张表,表一共有N行,N的大小和输入序列长度相同,每一行代表一个向量,向量的维度和输入序列embedding的维度相同(768)。注意位置编码的操作是sum,而不是concat。加入位置编码信息之后,维度依然是197x768
(3) LN/multi-head attention/LN:LN输出维度依然是197x768。多头自注意力时,先将输入映射到q,k,v,如果只有一个头,qkv的维度都是197x768,如果有12个头(768/12=64),则qkv的维度是197x64,一共有12组qkv,最后再将12组qkv的输出拼接起来,输出维度是197x768,然后再过一层LN,维度依然是197x768
(4) MLP:将维度放大再缩小回去,197x768放大为197x3072,再缩小变为197x768
2. 代码实现ViT
- 首先对图片进行切片,增加cls_token,添加位置编码
import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
    """Turn an image into a token sequence for a ViT encoder.

    Cuts the image into non-overlapping patches, linearly projects each patch
    to `embed_dim` (implemented as a strided Conv2d), prepends a learnable
    [CLS] token, and adds learnable 1D position embeddings.

    Args:
        in_channels: number of input image channels.
        patch_size: side length of each square patch.
        embed_dim: dimension of each token embedding.
        num_patches: number of patches per image (excluding the [CLS] token).
        dropout: dropout probability applied after position embeddings.

    Forward input:  (B, in_channels, H, W)
    Forward output: (B, num_patches + 1, embed_dim)
    """

    def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout):
        super(PatchEmbedding, self).__init__()
        # A conv with kernel_size == stride == patch_size is exactly
        # "split into patches + shared linear projection".
        self.patcher = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=embed_dim,
                      kernel_size=patch_size, stride=patch_size),
            nn.Flatten(2),  # (B, E, H', W') -> (B, E, num_patches)
        )
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)
        self.pos_embed = nn.Parameter(torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        cls_token = self.cls_token.expand(x.size(0), -1, -1)
        x = self.patcher(x).permute(0, 2, 1)  # (B, num_patches, embed_dim)
        # BUG FIX: the [CLS] token must be PREPENDED (sequence position 0),
        # because the downstream classification head reads x[:, 0, :].
        # The original code appended it at the end, so the head was actually
        # classifying from an ordinary patch token.
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)
        return x
- 对切片后的数据进行编码
class Vit(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> stack of pre-norm Transformer encoder
    layers -> LayerNorm + linear head applied to the token at sequence
    position 0 (intended to be the [CLS] token).
    """

    def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout, num_heads,
                 activation, num_encoders, num_classes):
        super(Vit, self).__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, embed_dim, num_patches, dropout)
        # Pre-norm (norm_first=True) encoder layers; batch_first matches the
        # (B, seq, dim) layout produced by the patch embedding.
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        self.encoder_layer = nn.TransformerEncoder(layer, num_layers=num_encoders)
        # Classification head: LayerNorm followed by a single linear layer.
        self.mlp = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes),
        )

    def forward(self, x):
        tokens = self.patch_embedding(x)
        encoded = self.encoder_layer(tokens)
        # Only the first token is fed to the head; its output is the logits.
        logits = self.mlp(encoded[:, 0, :])
        return logits
- 设置数据集
import pandas as pd
import random
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
class MINISTTrainDataset(Dataset):
    """Training split of MNIST stored as flat pixel rows (Kaggle CSV style).

    Args:
        images: array-like of shape (N, 784); uint8-compatible pixel values.
        labels: array-like of shape (N,) with class ids.
        indicies: original dataframe indices, returned for bookkeeping.

    Each item is a dict: {"image": FloatTensor (1, 28, 28), "label", "index"}.
    """

    def __init__(self, images, labels, indicies):
        self.images = images
        self.labels = labels
        self.indicies = indicies
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomRotation(15),  # light augmentation, training only
            transforms.ToTensor(),
            # BUG FIX: Normalize's signature is (mean, std[, inplace]); the
            # stray third list was being passed as a truthy `inplace` flag.
            # Single-channel images take one mean/std value each.
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, inx):
        # Rows are flat 784-vectors; reshape to 28x28 before PIL transforms.
        image = self.images[inx].reshape((28, 28)).astype(np.uint8)
        label = self.labels[inx]
        index = self.indicies[inx]
        image = self.transform(image)
        return {"image": image, "label": label, "index": index}
class MINISTValDataset(Dataset):
    """Validation split of MNIST stored as flat pixel rows.

    Same layout as MINISTTrainDataset but with no augmentation — validation
    images must be deterministic.

    Each item is a dict: {"image": FloatTensor (1, 28, 28), "label", "index"}.
    """

    def __init__(self, images, labels, indicies):
        self.images = images
        self.labels = labels
        self.indicies = indicies
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # BUG FIX: Normalize's signature is (mean, std[, inplace]); the
            # stray third list was being passed as a truthy `inplace` flag.
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, inx):
        # Rows are flat 784-vectors; reshape to 28x28 for the transforms.
        image = self.images[inx].reshape((28, 28)).astype(np.uint8)
        label = self.labels[inx]
        index = self.indicies[inx]
        image = self.transform(image)
        return {"image": image, "label": label, "index": index}
class MINISTSubmissionDataset(Dataset):
    """Unlabeled test split of MNIST used to produce a submission file.

    Same preprocessing as the validation split; items carry the row index
    instead of a label.

    Each item is a dict: {"image": FloatTensor (1, 28, 28), "index"}.
    """

    def __init__(self, images, indicies):
        self.images = images
        self.indicies = indicies
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # BUG FIX: Normalize's signature is (mean, std[, inplace]); the
            # stray third list was being passed as a truthy `inplace` flag.
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, inx):
        # Rows are flat 784-vectors; reshape to 28x28 for the transforms.
        image = self.images[inx].reshape((28, 28)).astype(np.uint8)
        index = self.indicies[inx]
        image = self.transform(image)
        return {"image": image, "index": index}
- 获取数据集
import pandas as pd
from sklearn.model_selection import train_test_split
from dataset import MINISTTrainDataset, MINISTValDataset, MINISTSubmissionDataset
import numpy as np
from torch.utils.data import DataLoader, Dataset
def get_loaders(train_dir, test_dir, submission_dir, batch_size):
    """Build train/val/test DataLoaders from Kaggle-style MNIST CSV files.

    Args:
        train_dir: path to train.csv (label column first, then 784 pixels).
        test_dir: path to test.csv.
        submission_dir: path to sample_submission.csv.
        batch_size: batch size for all three loaders.

    Returns:
        (train_dataloader, val_dataloader, test_dataloader)
    """
    train_df = pd.read_csv(train_dir)
    test_df = pd.read_csv(test_dir)
    # NOTE(review): submission_df is loaded but never used here — presumably
    # kept for a later submission step; verify and remove if truly unneeded.
    submission_df = pd.read_csv(submission_dir)
    # Hold out 10% of the training rows for validation.
    train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=36)
    train_dataset = MINISTTrainDataset(train_df.iloc[:, 1:].values.astype(np.uint8),
                                       train_df.iloc[:, 0].values.astype(np.uint8),
                                       train_df.index.values)
    # BUG FIX: val labels were cast with astype(np) — not a dtype; use
    # np.uint8 to match the training split.
    val_dataset = MINISTValDataset(val_df.iloc[:, 1:].values.astype(np.uint8),
                                   val_df.iloc[:, 0].values.astype(np.uint8),
                                   val_df.index.values)
    # BUG FIX: `iloc[1:]` dropped the first ROW instead of the first column;
    # use `iloc[:, 1:]` as in the other splits. The id column must not be
    # cast to uint8 (ids above 255 would wrap around).
    # NOTE(review): assumes test.csv's first column is an id column like
    # train.csv's label column — verify against the actual data file.
    test_dataset = MINISTSubmissionDataset(test_df.iloc[:, 1:].values.astype(np.uint8),
                                           test_df.iloc[:, 0].values)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Evaluation and inference order should be deterministic: do not shuffle.
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_dataloader, val_dataloader, test_dataloader
- 编写train.py
数据集:https://pan.baidu.com/s/1jQ3_xGPo9SQg_jIAyIys6w?pwd=1314
import torch
import torch.nn as nn
from torch import optim
import timeit
from tqdm import tqdm
from utils import get_loaders
from model import Vit
# Hyper Parameters
device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
EPOCHES = 50       # number of training epochs
BATCH_SIZE = 16
TRAIN_DF_DIR = r"./dataset/train.csv"              # labeled training data (Kaggle CSV)
TEST_DF_DIR = r"./dataset/test.csv"                # unlabeled test data
SUBMISSION_DF_DIR = r"./dataset/sample_submission.csv"
# Model Parameters
IN_CHANNELS = 1    # MNIST images are grayscale
IMG_SIZE = 28
PATCH_SIZE = 4     # 28 is divisible by 4, so patches tile the image exactly
EMBEDDING_DIM = (PATCH_SIZE**2)*IN_CHANNELS        # 4*4*1 = 16: one dim per patch pixel
NUM_PATCHERS = (IMG_SIZE // PATCH_SIZE) ** 2       # (28/4)^2 = 49 patches per image
DROPOUT_RATE = 0.001
NUM_HEADS = 8      # head dim = 16/8 = 2
ACTIVATION = "gelu"
NUM_ENCODERS = 66  # NOTE(review): unusually deep for MNIST at this width — possibly 6 intended; verify
NUM_CLASSES = 10   # digits 0-9
LEARNING_RATE = 1e-4
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)
if __name__=="__main__":
    # Training entry point: fit the ViT on MNIST and report per-epoch
    # loss/accuracy for the train and validation splits.
    train_dataloader, val_dataloader, test_dataloader = get_loaders(TRAIN_DF_DIR, TEST_DF_DIR, SUBMISSION_DF_DIR, BATCH_SIZE)
    model = Vit(IN_CHANNELS, PATCH_SIZE, EMBEDDING_DIM, NUM_PATCHERS,
                DROPOUT_RATE, NUM_HEADS, ACTIVATION, NUM_ENCODERS, NUM_CLASSES).to(device)
    print("using device:", device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), betas=ADAM_BETAS, lr=LEARNING_RATE,
                           weight_decay=ADAM_WEIGHT_DECAY)
    start = timeit.default_timer()  # wall-clock start; use for total-time reporting
    for epoch in range(EPOCHES):
        # ---- training pass ----
        model.train()
        train_labels = []
        train_preds = []
        train_running_loss = 0
        for idx, image_label in enumerate(tqdm(train_dataloader, position=0, leave=True)):
            img = image_label["image"].float().to(device)
            # BUG FIX: CrossEntropyLoss requires int64 (Long) class-index
            # targets; the original uint8 cast raises a dtype error.
            label = image_label["label"].long().to(device)
            y_pred = model(img)
            y_pred_label = torch.argmax(y_pred, dim=1)
            train_labels.extend(label.cpu().detach())
            train_preds.extend(y_pred_label.cpu().detach())
            loss = criterion(y_pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_running_loss += loss.item()
        train_loss = train_running_loss / (idx + 1)  # mean loss over batches
        # ---- validation pass (no gradients) ----
        model.eval()
        val_labels = []
        val_preds = []
        val_running_loss = 0
        with torch.no_grad():
            for idx, image_label in enumerate(tqdm(val_dataloader, position=0, leave=True)):
                img = image_label["image"].float().to(device)
                label = image_label["label"].long().to(device)  # see dtype note above
                y_pred = model(img)
                y_pred_label = torch.argmax(y_pred, dim=1)
                val_labels.extend(label.cpu().detach())
                val_preds.extend(y_pred_label.cpu().detach())
                loss = criterion(y_pred, label)
                val_running_loss += loss.item()
        val_loss = val_running_loss / (idx + 1)
        print("-"*30)
        print(f"Train Loss Epoch {epoch+1}: {train_loss:.4f}")
        print(f"Val Loss Epoch {epoch+1}: {val_loss:.4f}")
        print(f"train Accuracy Epoch {epoch+1}: {sum(1 for x, y in zip(train_preds, train_labels) if x==y) / len(train_labels):.4f}")
        print(f"val Accuracy Epoch {epoch+1}: {sum(1 for x, y in zip(val_preds, val_labels) if x==y) / len(val_labels):.4f}")
        print("-"*30)