一、动机
复现vision Transformer在CIFAR-10数据集上进行训练,在此过程中发现 vision Transformer (ViT) 似乎容易过拟合,训练集上的acc达到80%,而在验证集上只能达到60%。比原论文报告的99%准确率差距很大。其中原因可能是由于CIFAR-10数据集规模比较小,对于较大参数量的模型容易过拟合。vision Transformer应该更适合用于大规模的数据集,而原论文报告的99%准确率就是将Transformer 在 ImageNet 大规模数据集上进行预训练,而后在CIFAR-10数据集上进行微调。因此想要在 ImageNet 预训练的 vision Transformer 上进行微调。
原论文地址:https://arxiv.org/pdf/1812.01187.pdf
二、数据集
CIFAR-10 简介(以下内容由KIMI生成)
CIFAR-10数据集是一个广泛使用的图像识别数据集,主要用于计算机视觉领域的研究和教育。以下是CIFAR-10数据集的一些关键特点:
-
图像数量:CIFAR-10包含60,000张32x32像素的RGB彩色图像,分为10个类别,每个类别有6,000张图像。
-
类别:数据集中的图像分为10个不同的类别:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)和卡车(truck)。
-
训练与测试集:数据集被分为50,000张训练图像和10,000张测试图像。每个类别在训练集和测试集中分别有5,000张和1,000张图像。
-
数据集分布:CIFAR-10的图像是随机选择的,以确保数据集的多样性和代表性。
-
应用:CIFAR-10数据集常用于评估机器学习模型在图像识别任务上的性能,尤其是在深度学习领域。
-
挑战:由于图像尺寸较小(32x32像素),模型需要能够从有限的像素中提取有用的特征,这为图像识别任务带来了一定的挑战。
-
获取方式:CIFAR-10数据集是公开的,可以在多个在线资源上免费下载,例如加拿大高级研究院(Canadian Institute For Advanced Research)的官方网站。
-
扩展:CIFAR-10还有一个扩展版本,称为CIFAR-100,它包含100个类别,每个类别有600张图像,其中100个类别被分为20个超类别。
CIFAR-10数据集是计算机视觉研究中的一个重要基准,它为研究人员提供了一个标准化的平台来测试和比较不同的图像识别算法。
三、模型介绍
Vision Transformer 简介
Vision Transformer 借鉴了原版 Transformer 的思路,将图片分成多个图片块,将它们看作句子中不同的每个token。将他们 flatten 后进行embedding就获得了每个图片的向量表示。但是每个图片块并没有它们在原始图片上的位置信息,因此需要加入位置编码(position embedding),在VIT中位置编码是可学习的,它们代表着不同位置的的图片块的位置特征而不是图片块本身的特征,例如在边缘位置的图片块可能存在较多的背景等等(个人推测)。额外的类别编码并不是说把标签进行 embedding (这显然不合理),它也是随机初始化,它负责汇总所有图像块地信息进行分类判别。
四、实验步骤
1. 下载预训练模型
从Hugging Face – The AI community building the future.中下载 vit-base-patch16-224,必须下载的是 config.json 和 model 文件。(要梯子)
2. 定义CIFAR-10数据集
import torchvision
import torch
def get_data_loader(data_dir, batch_size, num_workers, transform, istrain, aug=False, download=False):
dataset = torchvision.datasets.CIFAR10(root=data_dir, train=istrain,
download=download, transform=transform)
loader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size, shuffle=aug,
num_workers=num_workers)
return loader, dataset
3. 训练模型
import torch
from torch import optim
from tqdm import tqdm
import torch.optim.lr_scheduler as lr_scheduler
from transformers import ViTImageProcessor, ViTForImageClassification
from dataloader import get_data_loader
from torchvision import transforms
from torchvision.transforms import Resize
if __name__ == "__main__":
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose(
[Resize(224),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
train_dir = '..\\data\\train_dir'
val_dir = '..\\data\\val_dir'
train_loader, train_dataset = get_data_loader(train_dir, 32, 4, transform=transform, istrain=True, aug=True)
val_loader, val_dataset = get_data_loader(val_dir, 32, num_workers=4, istrain=False, transform=transform)
# model_name = "google/vit-base-patch16-224"
stoi = {
0: 'airplane',
1: 'automobile',
2: 'bird',
3: 'cat',
4: 'deer',
5: 'dog',
6: 'frog',
7: 'horse',
8: 'ship',
9: 'truck'
}
itos = {
'airplane': 0,
'automobile': 1,
'bird': 2,
'cat' : 3,
'deer': 4,
'dog' : 5,
'frog': 6,
'horse': 7,
'ship': 8,
'truck': 9
}
model_name = r"D:\PretrainedModel\ViT\vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True, id2label=itos, label2id=stoi)
model = model.to(device)
loss_function = torch.nn.CrossEntropyLoss()
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=1e-4, momentum=0.9, weight_decay=5e-5)
lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - 0.01) + 0.01 # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
for epoch in range(3):
model.train()
train_acc = 0
train_loss = []
train_bar = tqdm(train_loader)
for data in train_bar:
train_bar.set_description("epoch {}".format(epoch))
images, labels = data
images = images.to(device)
labels = labels.to(device)
optimizer.zero_grad()
logits = model(images)
logits = logits.logits
prediction = torch.max(logits, dim=1)[1]
loss = loss_function(logits, labels)
loss.backward()
optimizer.step()
scheduler.step()
train_loss.append(loss.item())
train_bar.set_postfix(loss="{:.4f}".format(loss.item()))
del images, labels
4. 目前在验证集上 acc=0.9870, 训练集上 0.9990.(测试集没测试)