Using a Vision Transformer to Classify the CIFAR-10 Dataset

Without further ado, straight to the code.
The main idea of ViT is to cut the image into multiple patches of size patch_size, giving (size / patch_size)^2 patches in total.
Each patch is flattened into a one-dimensional vector and passed through a Transformer encoder to extract features; from there the problem looks just like an NLP task!

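Concretely for CIFAR-10: with img_size = 32 and patch_size = 8, each image yields (32 / 8)^2 = 16 patches, and each flattened patch carries 8 × 8 × 3 = 192 values, which is exactly the embed_dim used in the code below.
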
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import optim
import timeit
import os
from tqdm import tqdm


class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout):
        super(PatchEmbedding, self).__init__()
        # A Conv2d with kernel_size == stride == patch_size cuts the image into
        # non-overlapping patches and projects each patch to embed_dim in one step
        self.patcher = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size),
            nn.Flatten(2)  # [batch, embed_dim, H/p, W/p] -> [batch, embed_dim, num_patches]
        )

        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)
        self.position_embedding = nn.Parameter(torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # [batch_size, 1, embed_dim]
        x = self.patcher(x).permute(0, 2, 1)  # [batch_size, num_patches, embed_dim]
        x = torch.cat([cls_token, x], dim=1)  # prepend the classification token
        x = x + self.position_embedding  # broadcasts over the batch: [1, num_patches + 1, embed_dim]
        x = self.dropout(x)
        return x

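# Shape walkthrough for one batch, assuming the hyperparameters used further below
# (batch_size=64, patch_size=8, embed_dim=192, num_patches=16):
#   input            [64, 3, 32, 32]
#   after Conv2d     [64, 192, 4, 4]   (one 192-dim vector per 8x8 patch)
#   after Flatten(2) [64, 192, 16]
#   after permute    [64, 16, 192]
#   with cls_token   [64, 17, 192]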

class Vit(nn.Module):
    def __init__(self, in_channels, patch_size, embed_dim, num_patches, dropout,
                 num_head, activation, num_encoders, num_class):
        super(Vit, self).__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, embed_dim, num_patches, dropout)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_head, dropout=dropout,
                                                   activation=activation, batch_first=True, norm_first=True)

        # Stack num_encoders copies of the layer into a TransformerEncoder
        self.encoder_layers = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)
        self.MLP = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_class)
        )

    def forward(self, x):
        x = self.patch_embedding(x)
        x = self.encoder_layers(x)
        x = self.MLP(x[:, 0, :])  # classify from the cls_token position
        return x
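
# A minimal sanity check on random input (a sketch; uncomment to run):
# _vit = Vit(in_channels=3, patch_size=8, embed_dim=192, num_patches=16, dropout=0.01,
#            num_head=8, activation="gelu", num_encoders=10, num_class=10)
# print(_vit(torch.randn(2, 3, 32, 32)).shape)  # expected: torch.Size([2, 10])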


in_channels = 3
img_size = 32
patch_size = 8
embed_dim = patch_size ** 2 * in_channels  # 8 * 8 * 3 = 192
num_patches = (img_size // patch_size) ** 2  # (32 // 8) ** 2 = 16
dropout = 0.01
batch_size = 64
device = "cuda" if torch.cuda.is_available() else "cpu"
epochs = 50
num_head = 8
activation = "gelu"
num_encoders = 10
num_classes = 10
learning_rate = 1e-4
weight_decay = 1e-4
betas = (0.9, 0.999)
train_transform = transforms.Compose([
    transforms.RandomRotation(15),  # light augmentation, training set only
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(root="../../datas", train=True, transform=train_transform, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = torchvision.datasets.CIFAR10(root="../../datas", train=False, transform=test_transform, download=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
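
# Quick look at one batch (a sketch; uncomment to run):
# X, y = next(iter(train_dataloader))
# print(X.shape, y.shape)  # expected: torch.Size([64, 3, 32, 32]) torch.Size([64])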

model = Vit(in_channels, patch_size, embed_dim, num_patches, dropout,
            num_head, activation, num_encoders, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=betas, weight_decay=weight_decay)
start = timeit.default_timer()
best_acc = 0
os.makedirs("./checkpoints", exist_ok=True)  # torch.save below needs this directory
print(f"training on : {device}")
for epoch in range(epochs):
    model.train()
    train_running_loss = 0
    n = 0  # batch counter
    train_n_sum = 0
    train_n_correct = 0
    for idx, (X, y) in enumerate(tqdm(train_dataloader, position=0, leave=True)):
        X = X.to(device)
        y = y.to(device)
        y_pred = model(X)
        y_pred_label = torch.argmax(y_pred, dim=1)

        loss = criterion(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()
        n += 1
        train_n_sum += X.size(0)
        train_n_correct += (y == y_pred_label).sum().item()
    train_loss = train_running_loss / n  # mean loss per batch
    train_acc = train_n_correct / train_n_sum
    model.eval()
    val_running_loss = 0
    test_n_sum = 0
    test_n_correct = 0
    with torch.no_grad():
        n = 0
        for idx, (X, y) in enumerate(tqdm(test_dataloader, position=0, leave=True)):
            X = X.to(device)
            y = y.to(device)
            y_pred = model(X)
            y_pred_label = torch.argmax(y_pred, dim=1)

            loss = criterion(y_pred, y)
            val_running_loss += loss.item()
            n += 1
            test_n_sum += X.size(0)
            test_n_correct += (y == y_pred_label).sum().item()
        test_loss = val_running_loss / n
        test_acc = test_n_correct / test_n_sum
        if test_acc > best_acc:
            best_acc = test_acc
            checkpoint = {
                "state": model.state_dict(),
                "acc": test_acc,
                "loss": test_loss
            }
            print("save model : ", checkpoint["acc"])
            torch.save(checkpoint, "./checkpoints/vit_model.pth")
    print("-" * 30)
    print(f"train loss epoch : {epoch + 1} : {train_loss:.4f}")
    print(f"test loss epoch : {epoch + 1} : {test_loss:.4f}")
    print(
        f"train acc epoch : {epoch + 1} : {train_acc:.4f}"
    )
    print(
        f"test acc epoch : {epoch + 1} : {test_acc:.4f}"
    )
    print("-" * 30)

stop = timeit.default_timer()
print(f"training time : {stop - start:.2f}")

# Quick shape check for PatchEmbedding (uncomment to run):
# patcher = PatchEmbedding(in_channels=in_channels, patch_size=patch_size,
#                          embed_dim=embed_dim, num_patches=num_patches, dropout=dropout)
# for idx, (x, y) in enumerate(train_dataloader):
#     print("before : ", idx, x.shape, y.shape)
#     x = patcher(x)
#     print("after : ", x.shape)
#     break
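
To reuse the best checkpoint later, here is a minimal loading sketch (it assumes the same model hyperparameters and the "state" / "acc" keys saved above):

checkpoint = torch.load("./checkpoints/vit_model.pth", map_location=device)
model.load_state_dict(checkpoint["state"])
model.eval()
print("restored test acc :", checkpoint["acc"])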

Finally, a few plots of the training results:

[training result figures]