model.py
import torch
from torch import nn
from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm
class Embeddings(nn.Module):
    '''
    Encode the image as a sequence: treat the image like a sentence, split it
    into patches, and treat each patch as a word (token).
    '''
def __init__(self, config, img_size, in_channels=3):
super(Embeddings, self).__init__()
patch_size = config.patch_size
n_patches = (img_size // patch_size) * (img_size // patch_size)
self.patch_embeddings = Conv2d(in_channels=in_channels,
out_channels=config.hidden_size,
kernel_size=patch_size,
stride=patch_size)
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))
self.dropout = Dropout(config.transformer.dropout_rate)
def forward(self, x):
x = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
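# Shape walk-through for Embeddings (a hand-check against the defaults in
# config.py: img_size=224, patch_size=16, hidden_size=768):
#   x: (bs, 3, 224, 224)
#   Conv2d(3, 768, kernel_size=16, stride=16) -> (bs, 768, 14, 14)
#   .flatten(2)                               -> (bs, 768, 196)
#   .transpose(-1, -2)                        -> (bs, 196, 768)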
# 2. The self-attention module
class Attention(nn.Module):
def __init__(self, config, vis):
super(Attention, self).__init__()
self.vis = vis
self.num_attention_heads = config.transformer.num_heads # 12
self.attention_head_size = int(config.hidden_size / self.num_attention_heads) # 768/12=64
self.all_head_size = self.num_attention_heads * self.attention_head_size # 12*64=768
        self.query = Linear(config.hidden_size, self.all_head_size)  # Wq: (768, 768)
        self.key = Linear(config.hidden_size, self.all_head_size)    # Wk: (768, 768)
        self.value = Linear(config.hidden_size, self.all_head_size)  # Wv: (768, 768)
self.attn_dropout = Dropout(config.transformer.attention_dropout_rate)
self.softmax = Softmax(dim=-1)
def transpose_for_scores(self, x):
new_x_shape = x.shape[:-1] + (
self.num_attention_heads,
self.attention_head_size)
x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)  # (bs, 12, 196, 64)
def forward(self, hidden_states):
        # hidden_states: (bs, 196, 768)
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)  # (bs, 12, 196, 64)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # Q @ K^T: (bs, 12, 196, 196)
        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        # scale by sqrt(d_k) = sqrt(64) = 8
        attention_scores = attention_scores / self.attention_head_size ** 0.5
        attention_probs = self.softmax(attention_scores)  # softmax over keys -> attention probabilities
        weights = attention_probs if self.vis else None  # keep the weights for visualization if requested
        attention_probs = self.attn_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)  # weighted sum of the value vectors
        context_layer = context_layer.permute(0, 2, 1, 3).flatten(2)
        return context_layer, weights  # (bs, 196, 768), (bs, 12, 196, 196) or None
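# Shape walk-through for Attention (assuming the default 12 heads of size 64
# over the 196-token sequence):
#   Q, K, V after transpose_for_scores: (bs, 12, 196, 64)
#   Q @ K^T / sqrt(64):                 (bs, 12, 196, 196)
#   probs @ V:                          (bs, 12, 196, 64) -> merged to (bs, 196, 768)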
# 3. The feed-forward network: two linear layers with a GELU activation in between
class Mlp(nn.Module):
def __init__(self, config):
super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer.mlp_dim)  # hidden_size -> mlp_dim (768 -> 1024 with the default config)
        self.fc2 = Linear(config.transformer.mlp_dim, config.hidden_size)  # mlp_dim -> hidden_size
        self.act_fn = nn.GELU()
        self.dropout = Dropout(config.transformer.dropout_rate)
def forward(self, x):
        x = self.fc1(x)      # expand
        x = self.act_fn(x)   # GELU
        x = self.dropout(x)
        x = self.fc2(x)      # project back
return x
# 4. The reusable encoder Block: each Block contains a self-attention module and an MLP module
class Block(nn.Module):
def __init__(self, config, vis):
super(Block, self).__init__()
        self.hidden_size = config.hidden_size  # 768
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)  # pre-norm before attention
self.attn = Attention(config, vis)
self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn = Mlp(config)
def forward(self, x):
h = x
x = self.attention_norm(x)
x, weights = self.attn(x)
x = x + h
h = x
x = self.ffn(self.ffn_norm(x))
x = x + h
return x, weights
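# Note the pre-norm ("Pre-LN") arrangement used by ViT: LayerNorm is applied
# before each sub-layer and the residual connection is added afterwards:
#   x = x + Attn(LN(x));  x = x + MLP(LN(x))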
# 5. The Encoder: a stack of num_layers Blocks followed by a final LayerNorm
class Encoder(nn.Module):
def __init__(self, config, vis):
super(Encoder, self).__init__()
self.layer = nn.ModuleList(
[Block(config, vis) for _ in range(config.transformer.num_layers)]
)
self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
def forward(self, embedding_input):
attn_weights = []
for layer_block in self.layer:
embedding_input, weights = layer_block(embedding_input)
attn_weights.append(weights)
encoded = self.encoder_norm(embedding_input)
return encoded, attn_weights
# 6. The full Transformer: the image is first encoded into a token sequence by the
# embedding module, then fed into the Encoder
class Transformer(nn.Module):
def __init__(self, config, img_size, vis):
super(Transformer, self).__init__()
        # Patch-embed the image into a sequence of shape (bs, n_patches=196, hidden_size=768);
        # this implementation uses no [CLS] token, so the sequence length stays 196
self.embeddings = Embeddings(config, img_size=img_size)
self.encoder = Encoder(config, vis)
def forward(self, input_image):
        embedding_output = self.embeddings(input_image)  # (bs, 196, 768)
        encoded, attn_weights = self.encoder(embedding_output)
        return encoded, attn_weights  # (bs, 196, 768), list of per-layer attention weights
# 7. VisionTransformer: the full model for image classification
class VisionTransformer(nn.Module):
def __init__(self, config):
super(VisionTransformer, self).__init__()
self.num_classes = config.num_classes
self.transformer = Transformer(config, config.img_size, config.vis)
        # Global average pooling over the tokens replaces the [CLS] token,
        # followed by a linear classification head
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(config.hidden_size, config.num_classes)
def forward(self, x):
x, attn_weights = self.transformer(x)
        x = x.transpose(1, 2)              # (bs, 768, 196)
        x = self.gap(x)                    # (bs, 768, 1)
        x = self.classifier(x.flatten(1))  # (bs, num_classes)
return x, attn_weights
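# Minimal smoke test (a sketch; assumes config.py from this repo is importable
# and keeps its defaults: img_size=224, num_classes=4):
if __name__ == '__main__':
    from config import get_config
    cfg = get_config()
    net = VisionTransformer(cfg)
    logits, attn = net(torch.randn(2, 3, cfg.img_size, cfg.img_size))
    print(logits.shape)  # expected: torch.Size([2, 4])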
data.py
import torch
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os.path as osp
class CustomDataset(Dataset):
"""
root
- images
x1.png
x2.jpg
- no
x3.png
train.txt
x1.png 1
x2.jpg 0
val.txt
no/x3.png 1
"""
def __init__(self, root, transform=None, mode='train'):
self.root = root
self.transform = transform
        with open(osp.join(root, f'{mode}.txt'), 'r') as f:
            # skip blank lines so they cannot break __getitem__
            self.labels = [line.strip().split() for line in f if line.strip()]
        print(f'{mode}: {len(self.labels)} samples')
def __len__(self):
return len(self.labels)
def __getitem__(self, item):
image = Image.open(osp.join(self.root, self.labels[item][0]))
image = image.convert('RGB')
if self.transform:
image = self.transform(image)
label = int(self.labels[item][1])
return image, torch.LongTensor([label])
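# Usage sketch (hypothetical paths; assumes a local './small/' folder laid out
# as the docstring above describes):
#   ds = CustomDataset('./small/', transform=None, mode='train')
#   image, label = ds[0]   # a PIL RGB image and a LongTensor([label])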
def get_loader(cfg):
    # train-time augmentation / deterministic test-time preprocessing
transform_train = transforms.Compose([
transforms.RandomResizedCrop((cfg.img_size, cfg.img_size), scale=(0.05, 1.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
transform_test = transforms.Compose([
transforms.Resize((cfg.img_size, cfg.img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# dataset
if cfg.dataset_name == "cifar10":
trainset = datasets.CIFAR10(root=cfg.dataset_train, train=True, download=True,
transform=transform_train)
testset = datasets.CIFAR10(root=cfg.dataset_test, train=False, download=True,
transform=transform_test)
elif cfg.dataset_name == 'cifar100':
trainset = datasets.CIFAR100(root=cfg.dataset_train, train=True, download=False,
transform=transform_train)
testset = datasets.CIFAR100(root=cfg.dataset_test, train=False, download=False,
transform=transform_test)
else:
trainset = CustomDataset(root=cfg.dataset_train, transform=transform_train, mode='train')
        testset = CustomDataset(root=cfg.dataset_test, transform=transform_test, mode='test')
# loader
train_loader = DataLoader(trainset, batch_size=cfg.train_batch_size, shuffle=True,
num_workers=cfg.num_workers,
pin_memory=True, drop_last=True)
test_loader = DataLoader(testset, batch_size=cfg.eval_batch_size, shuffle=False,
pin_memory=True)
print("train images:", len(train_loader), 'bs')
print("test images:", len(test_loader), 'bs')
return train_loader, test_loader
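# Quick sanity check (a sketch; assumes the default config, whose dataset_name
# is '' and whose dataset paths point at a local './small/' folder laid out as
# the CustomDataset docstring describes):
if __name__ == '__main__':
    from config import get_config
    train_loader, test_loader = get_loader(get_config())
    images, labels = next(iter(train_loader))
    print(images.shape, labels.shape)  # e.g. (16, 3, 224, 224) and (16, 1)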
config.py
import ml_collections
def get_config():
config = ml_collections.ConfigDict()
# dataset
    config.dataset_name = ''  # '' -> CustomDataset; 'cifar10' / 'cifar100' also supported
config.num_classes = 4
config.dataset_train = './small/'
config.dataset_test = './small/'
config.train_batch_size = 16
config.eval_batch_size = 16
config.img_size = 224
# train/val
config.learning_rate = 0.03
config.output_dir = './output/'
config.total_epoch = 2
config.momentum = 0.9
config.weight_decay = 0
config.vis = False
config.device = 'cuda'
config.num_workers = 0
# model
config.patch_size = 16
config.hidden_size = 768
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 1024
config.transformer.num_heads = 12
config.transformer.num_layers = 3
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.transformer.aux_scale = 1e-5
config.classifier = 'token'
config.representation_size = None
return config
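# Printing the config is a convenient sanity check. Note that
# ml_collections.ConfigDict supports both attribute access (cfg.hidden_size)
# and item access (cfg["hidden_size"]) interchangeably:
if __name__ == '__main__':
    cfg = get_config()
    assert cfg.hidden_size == cfg["hidden_size"]
    print(cfg)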
utils.py
import os
import torch
import numpy as np
from tqdm import tqdm
# Evaluate how well the trained model does on the test set
@torch.no_grad()
def evaluate(cfg, model, test_loader, loss_function, device):
eval_loss = 0.0
total_acc = 0.0
model.eval()
for i, (x, y) in enumerate(test_loader):
x, y = x.to(device), y.to(device)
y = y.flatten()
        logits, _ = model(x)  # the model returns (bs, num_classes) logits and the attention weights
batch_loss = loss_function(logits, y)
        # accumulate the loss
        eval_loss += batch_loss.item()
        # accumulate the number of correct predictions
        _, preds = logits.max(1)
        total_acc += (preds == y).sum().item()
    loss = eval_loss / len(test_loader)
    # divide by the true dataset size: the last test batch may be smaller than
    # eval_batch_size, since the test loader does not drop it
    acc = total_acc / len(test_loader.dataset)
return loss, acc
def train(cfg, model, optimizer, train_loader, test_loader,
loss_func, device):
    # per-epoch histories: validation loss/accuracy and training loss
val_loss_list = []
val_acc_list = []
train_loss_list = []
best_acc = 0.
eval_loss, eval_acc = 0., 0.
model.train()
for epoch in range(cfg.total_epoch):
train_loss = 0
        pbar = tqdm(enumerate(train_loader, 1), total=len(train_loader))
for idx, (x, y) in pbar:
x, y = x.to(device), y.to(device)
y = y.flatten()
optimizer.zero_grad()
logits,_ = model(x)
loss = loss_func(logits, y)
train_loss += loss.item()
loss.backward()
optimizer.step()
# bar
pbar.set_description(f'Epoch={epoch}')
pbar.set_postfix(loss=f'{train_loss/idx:.3f}',
eval_loss=f'{eval_loss:.3f}',
eval_acc=f'{eval_acc:.3f}')
# train
train_loss = train_loss / len(train_loader)
train_loss_list.append(train_loss)
# val
        eval_loss, eval_acc = evaluate(cfg, model, test_loader, loss_func, device)
val_loss_list.append(eval_loss)
val_acc_list.append(eval_acc)
# save model
if eval_acc > best_acc:
best_acc = eval_acc
model_checkpoint = os.path.join(cfg.output_dir, "best.pth")
torch.save(model.state_dict(), model_checkpoint)
model_checkpoint = os.path.join(cfg.output_dir, "last.pth")
torch.save(model.state_dict(), model_checkpoint)
# save results
np.savetxt(f"{cfg.output_dir}/train_loss.txt", train_loss_list)
np.savetxt(f"{cfg.output_dir}/val_loss.txt", val_loss_list)
np.savetxt(f"{cfg.output_dir}/val_acc.txt", val_acc_list)
main.py
import os
import time
import torch
from model import VisionTransformer
from data import get_loader
from config import get_config
from utils import train
config = get_config()
device = torch.device("cuda" if torch.cuda.is_available() and \
config.device in ['cuda', 'gpu'] else "cpu")
# outdir
tt = time.strftime('%y%m%d_%H%M%S', time.localtime())
config.output_dir = os.path.join(config.output_dir, tt)
os.makedirs(config.output_dir, exist_ok=True)
loss_func = torch.nn.CrossEntropyLoss()
train_loader, test_loader = get_loader(config)
model = VisionTransformer(config)
model.to(device)
optimizer = torch.optim.SGD(model.parameters(),
lr=config.learning_rate,
momentum=config.momentum,
weight_decay=config.weight_decay)
train(config,
model,
optimizer,
train_loader,
test_loader,
loss_func,
device)
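# Run with `python main.py`. Checkpoints (best.pth on each new best eval
# accuracy, last.pth every epoch) and the loss/accuracy curves
# (train_loss.txt, val_loss.txt, val_acc.txt) land in the timestamped
# folder under config.output_dir.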