MAE Experiment Test

MAE: Masked Autoencoders Are Scalable Vision Learners

Data preparation: cats and dogs, https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition (see the layout sketch below)
Code source: https://github.com/lucidrains/vit-pytorch ; the code is adapted from that repository's example files.
Experiment: use the cats_and_dogs data to train an unsupervised MAE model with a ViT encoder, then fine-tune a classification model from it (20 epochs).
Baseline: train a classification model directly with the ViT (20 epochs).
Result: compare how quickly training converges in the two settings.
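
For reference, a minimal layout sketch, assuming the Kaggle archives train.zip and test.zip are downloaded into the project root and contain top-level train/ and test/ folders (as this competition's archives do); adjust names and paths as needed:

import zipfile

# unpack the Kaggle archives so the scripts below find
# data/train/*.jpg (named like cat.0.jpg / dog.0.jpg) and data/test/*.jpg
for name in ['train.zip', 'test.zip']:
    with zipfile.ZipFile(name) as zf:
        zf.extractall('data')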

Code 1: train the unsupervised model with MAE

from __future__ import print_function

import glob
import os
import random
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.notebook import tqdm



# from vit_pytorch.efficient import ViT 
from vit_pytorch import ViT, MAE

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

os.environ["CUDA_VISIBLE_DEVICES"] = '3'  # set before any CUDA call so it takes effect
seed_everything(1000)
device = 'cuda'
epochs = 100
train_list = glob.glob(os.path.join('data/train', '*.jpg'))


train_transforms = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

class CatsDogsDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)

        label = img_path.split("/")[-1].split(".")[0]
        label = 1 if label == "dog" else 0

        return img_transformed, label
     
train_data = CatsDogsDataset(train_list, transform=train_transforms)
train_loader = DataLoader(dataset = train_data, batch_size=4, shuffle=True )

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mae = MAE(
    encoder = v,
    masking_ratio = 0.75,   # the paper recommended 75% masked patches
    decoder_dim = 512,      # paper showed good results with just 512
    decoder_depth = 6       # anywhere from 1 to 8
)
mae.to(device)

# MAE is self-supervised: the labels from the loader are never used
optimizer = torch.optim.Adam(mae.parameters(), lr=3e-4)  # lr is an assumed, typical value

for epoch in range(epochs):
    for data, _ in tqdm(train_loader):
        data = data.to(device)
        loss = mae(data)  # returns the masked-patch reconstruction loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # without this step the encoder weights would never be updated
    print(f'epoch: {epoch} -- loss: {loss.item():.4f}')

# save only the encoder (the ViT itself) as a pretrained checkpoint
torch.save(v.state_dict(), './trained-vit_epoch100.pt')

Note that the training above never uses the label information, only the images. Once it finishes you get a file trained-vit_epoch100.pt, which is effectively a ViT checkpoint/pretrained model. We then fine-tune directly from this file with the following code:

from __future__ import print_function

import glob
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm

from vit_pytorch import ViT

# Training settings
batch_size = 64
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)
device = 'cuda'

os.makedirs('data', exist_ok=True)
train_dir = 'data/train'
test_dir = 'data/test'

train_list = glob.glob(os.path.join(train_dir,'*.jpg'))
test_list = glob.glob(os.path.join(test_dir, '*.jpg'))
labels = [path.split('/')[-1].split('.')[0] for path in train_list]

train_list, valid_list = train_test_split(train_list, 
                                          test_size=0.2,
                                          stratify=labels,
                                          random_state=seed)
print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")
print(f"Test Data: {len(test_list)}")

train_transforms = transforms.Compose(
    [
        # keep 256x256 inputs so fine-tuning matches the resolution used for MAE pretraining
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(256),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ]
)


test_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ]
)

class CatsDogsDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path)
        img_transformed = self.transform(img)

        label = img_path.split("/")[-1].split(".")[0]
        label = 1 if label == "dog" else 0

        return img_transformed, label

train_data = CatsDogsDataset(train_list, transform=train_transforms)
valid_data = CatsDogsDataset(valid_list, transform=val_transforms)
test_data = CatsDogsDataset(test_list, transform=test_transforms)
train_loader = DataLoader(dataset = train_data, batch_size=batch_size, shuffle=True )
valid_loader = DataLoader(dataset = valid_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset = test_data, batch_size=batch_size, shuffle=True)



# the ViT config must match the MAE-pretrained encoder so the checkpoint loads cleanly;
# num_classes stays at 1000 for the same reason (labels 0/1 are still valid class indices for CrossEntropyLoss)
model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
).to(device)

model.load_state_dict(torch.load('trained-vit_epoch100.pt'))

# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)


for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)
        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

The baseline experiment is exactly https://github.com/lucidrains/vit-pytorch/blob/main/examples/cats_and_dogs.ipynb, so it is not pasted here; a minimal sketch of what it amounts to follows below.
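
A minimal sketch of the baseline, assuming it is simply the fine-tuning script above with the checkpoint loading removed, i.e. the same ViT trained from random initialization (the linked notebook is the reference implementation):

# Baseline sketch (assumption): same data pipeline and training loop as above,
# but the ViT starts from random weights -- no MAE checkpoint is loaded.
model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
).to(device)

# model.load_state_dict(torch.load('trained-vit_epoch100.pt'))  # omitted for the baseline

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
# ...the remaining 20-epoch training/validation loop is identical to the one above.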
Comparing the results:
ViT with MAE pretraining:

Epoch : 1 - loss : 0.7540 - acc: 0.5233 - val_loss : 0.6809 - val_acc: 0.5572
Epoch : 2 - loss : 0.6860 - acc: 0.5577 - val_loss : 0.6658 - val_acc: 0.6038
Epoch : 3 - loss : 0.6796 - acc: 0.5676 - val_loss : 0.6628 - val_acc: 0.5971
Epoch : 4 - loss : 0.6799 - acc: 0.5692 - val_loss : 0.7245 - val_acc: 0.5265
Epoch : 5 - loss : 0.6738 - acc: 0.5791 - val_loss : 0.6697 - val_acc: 0.5829
Epoch : 6 - loss : 0.6714 - acc: 0.5847 - val_loss : 0.6504 - val_acc: 0.6185
Epoch : 7 - loss : 0.6613 - acc: 0.5991 - val_loss : 0.6250 - val_acc: 0.6495
Epoch : 8 - loss : 0.6471 - acc: 0.6173 - val_loss : 0.6043 - val_acc: 0.6679
Epoch : 9 - loss : 0.6371 - acc: 0.6354 - val_loss : 0.5948 - val_acc: 0.6824
Epoch : 10 - loss : 0.6185 - acc: 0.6542 - val_loss : 0.5795 - val_acc: 0.6932
Epoch : 11 - loss : 0.6089 - acc: 0.6660 - val_loss : 0.5585 - val_acc: 0.7067
Epoch : 12 - loss : 0.5935 - acc: 0.6765 - val_loss : 0.5573 - val_acc: 0.7057
Epoch : 13 - loss : 0.5893 - acc: 0.6823 - val_loss : 0.5524 - val_acc: 0.7207
Epoch : 14 - loss : 0.5783 - acc: 0.6891 - val_loss : 0.5372 - val_acc: 0.7282
Epoch : 15 - loss : 0.5762 - acc: 0.6890 - val_loss : 0.5453 - val_acc: 0.7193
Epoch : 16 - loss : 0.5637 - acc: 0.7035 - val_loss : 0.5251 - val_acc: 0.7348
Epoch : 17 - loss : 0.5625 - acc: 0.7077 - val_loss : 0.5188 - val_acc: 0.7427
Epoch : 18 - loss : 0.5590 - acc: 0.7051 - val_loss : 0.5312 - val_acc: 0.7296
Epoch : 19 - loss : 0.5538 - acc: 0.7113 - val_loss : 0.5123 - val_acc: 0.7482
Epoch : 20 - loss : 0.5506 - acc: 0.7143 - val_loss : 0.5251 - val_acc: 0.7318

Plain ViT (trained from scratch):

Epoch : 1 - loss : 0.6959 - acc: 0.5061 - val_loss : 0.6900 - val_acc: 0.5530
Epoch : 2 - loss : 0.6917 - acc: 0.5264 - val_loss : 0.6822 - val_acc: 0.5845
Epoch : 3 - loss : 0.6836 - acc: 0.5485 - val_loss : 0.6748 - val_acc: 0.5791
Epoch : 4 - loss : 0.6799 - acc: 0.5667 - val_loss : 0.6659 - val_acc: 0.6082
Epoch : 5 - loss : 0.6756 - acc: 0.5721 - val_loss : 0.6661 - val_acc: 0.5817
Epoch : 6 - loss : 0.6700 - acc: 0.5845 - val_loss : 0.6452 - val_acc: 0.6250
Epoch : 7 - loss : 0.6651 - acc: 0.5868 - val_loss : 0.6463 - val_acc: 0.6173
Epoch : 8 - loss : 0.6526 - acc: 0.6033 - val_loss : 0.6297 - val_acc: 0.6454
Epoch : 9 - loss : 0.6458 - acc: 0.6110 - val_loss : 0.6112 - val_acc: 0.6551
Epoch : 10 - loss : 0.6427 - acc: 0.6197 - val_loss : 0.6343 - val_acc: 0.6288
Epoch : 11 - loss : 0.6333 - acc: 0.6327 - val_loss : 0.6021 - val_acc: 0.6711
Epoch : 12 - loss : 0.6266 - acc: 0.6412 - val_loss : 0.5872 - val_acc: 0.6843
Epoch : 13 - loss : 0.6177 - acc: 0.6497 - val_loss : 0.5842 - val_acc: 0.6893
Epoch : 14 - loss : 0.6121 - acc: 0.6560 - val_loss : 0.5775 - val_acc: 0.6877
Epoch : 15 - loss : 0.6116 - acc: 0.6583 - val_loss : 0.5734 - val_acc: 0.6901
Epoch : 16 - loss : 0.6061 - acc: 0.6615 - val_loss : 0.5720 - val_acc: 0.6903
Epoch : 17 - loss : 0.6006 - acc: 0.6683 - val_loss : 0.5756 - val_acc: 0.6903
Epoch : 18 - loss : 0.5978 - acc: 0.6752 - val_loss : 0.5725 - val_acc: 0.6964
Epoch : 19 - loss : 0.5936 - acc: 0.6762 - val_loss : 0.5629 - val_acc: 0.7004
Epoch : 20 - loss : 0.5935 - acc: 0.6791 - val_loss : 0.5610 - val_acc: 0.7087

As you can see, the plain ViT reaches a val_acc of 0.7087 at epoch 20, while the MAE-pretrained ViT needs only about 12 epochs to reach a comparable level. The paper reports speedups of 2.8x and 3.7x on ViT-L; I did not reach that here for two reasons: 1. my training set is small (one epoch takes less than 2 minutes) and the model has not yet converged; 2. my ViT is shallow, not even ViT-B sized, so the speedup is less pronounced. Even so, the result is already quite good.
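
To make the comparison concrete, here is a small sketch that reads off, from the two logs above, the first epoch at which each run reaches a given validation accuracy (the lists are just the val_acc columns copied from the tables):

# val_acc per epoch, copied from the two training logs above
mae_val_acc = [0.5572, 0.6038, 0.5971, 0.5265, 0.5829, 0.6185, 0.6495, 0.6679,
               0.6824, 0.6932, 0.7067, 0.7057, 0.7207, 0.7282, 0.7193, 0.7348,
               0.7427, 0.7296, 0.7482, 0.7318]
vit_val_acc = [0.5530, 0.5845, 0.5791, 0.6082, 0.5817, 0.6250, 0.6173, 0.6454,
               0.6551, 0.6288, 0.6711, 0.6843, 0.6893, 0.6877, 0.6901, 0.6903,
               0.6903, 0.6964, 0.7004, 0.7087]

def first_epoch_reaching(accs, threshold):
    # return the 1-based epoch at which val_acc first reaches the threshold
    for epoch, acc in enumerate(accs, start=1):
        if acc >= threshold:
            return epoch
    return None

print(first_epoch_reaching(mae_val_acc, 0.70))  # 11 (MAE-pretrained ViT)
print(first_epoch_reaching(vit_val_acc, 0.70))  # 19 (ViT from scratch)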
Is the tried-and-true family recipe due for a change?
