首先声明:代码是参考以下作者大佬的,如有侵权马上删。
本人只是在原有基础上加了点自己的笔记,改了点结构
Transformer 用在图像上的本质,其实就是把图像切成小块(patch),再靠线性层映射和注意力机制找出各块之间的关系。
transformer_seg
import logging
import math
import os
import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from einops import rearrange
from transformer_model import TransModel2d, TransConfig
import math
class Encoder2D(nn.Module):
    """Patchify an image, encode it with the transformer and, for segmentation,
    fold the token sequence back into a 2D feature map.

    Args:
        config: TransConfig providing patch_size, in_channels, hidden_size and
            sample_rate. Each patch side is folded back to
            ``patch_size // 2**sample_rate`` pixels of ``hidden_size`` channels.
        is_segmentation: when False, ``forward`` returns only the encoded token
            sequence; when True it returns ``(tokens, feature_map)``.
    """

    def __init__(self, config: TransConfig, is_segmentation=True):
        super().__init__()
        self.config = config
        self.out_channels = config.out_channels
        # Patch projection + position embeddings + transformer layers.
        self.bert_model = TransModel2d(config)
        sample_rate = config.sample_rate
        sample_v = int(math.pow(2, sample_rate))
        assert config.patch_size[0] * config.patch_size[1] * config.hidden_size % (sample_v**2) == 0, "不能除尽"
        if is_segmentation:
            # forward() reshapes each token into a (patch // sample_v)^2 grid with
            # hidden_size channels; that only matches final_dense's width when both
            # patch sides are exact multiples of sample_v.
            assert (config.patch_size[0] % sample_v == 0
                    and config.patch_size[1] % sample_v == 0), \
                "patch_size 必须能被 2**sample_rate 整除 (patch_size must be divisible by 2**sample_rate)"
        # Expands every hidden_size token into hh * ww pixels of hidden_size channels.
        self.final_dense = nn.Linear(config.hidden_size, config.patch_size[0] * config.patch_size[1] * config.hidden_size // (sample_v**2))
        self.patch_size = config.patch_size
        self.hh = self.patch_size[0] // sample_v
        self.ww = self.patch_size[1] // sample_v
        self.is_segmentation = is_segmentation

    def forward(self, x):
        """Encode an image batch.

        Args:
            x: image tensor unpacked as (b, c, h, w); c must equal
               config.in_channels and h, w must be multiples of the patch size.

        Returns:
            The last transformer layer's output, (b, h//p1 * w//p2, hidden_size),
            when ``is_segmentation`` is False; otherwise the tuple
            ``(encoded_tokens, feature_map)`` where feature_map has shape
            (b, hidden_size, h//p1 * hh, w//p2 * ww).

        Raises:
            ValueError: if h or w is not divisible by the patch size.
        """
        b, c, h, w = x.shape
        assert self.config.in_channels == c, "in_channels != 输入图像channel"
        p1 = self.patch_size[0]
        p2 = self.patch_size[1]
        if h % p1 != 0 or w % p2 != 0:
            # Raise instead of os._exit(0): a forward pass must not terminate the
            # whole process (and certainly not with a success exit status).
            raise ValueError(
                f"请重新输入img size 参数 必须整除: image size ({h}, {w}) "
                f"is not divisible by patch size ({p1}, {p2})"
            )
        hh = h // p1  # number of patches along the height
        ww = w // p2  # number of patches along the width
        x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p1, p2 = p2)
        encode_x = self.bert_model(x)[-1]  # output of the last transformer layer
        if not self.is_segmentation:
            return encode_x
        x = self.final_dense(encode_x)
        x = rearrange(x, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", p1 = self.hh, p2 = self.ww, h = hh, w = ww, c = self.config.hidden_size)
        return encode_x, x
class PreTrainModel(nn.Module):
    """Image classifier: Encoder2D in non-segmentation mode, mean-pooling over
    tokens, then a linear layer producing ``out_class`` logits.

    ``decode_features`` is accepted for signature compatibility but not used.
    """

    def __init__(self, patch_size,
                 in_channels,
                 out_class,
                 hidden_size=1024,
                 num_hidden_layers=8,
                 num_attention_heads=16,
                 decode_features=[512, 256, 128, 64]):
        super().__init__()
        self.encoder_2d = Encoder2D(
            TransConfig(patch_size=patch_size,
                        in_channels=in_channels,
                        out_channels=0,
                        hidden_size=hidden_size,
                        num_hidden_layers=num_hidden_layers,
                        num_attention_heads=num_attention_heads),
            is_segmentation=False,
        )
        self.cls = nn.Linear(hidden_size, out_class)

    def forward(self, x):
        tokens = self.encoder_2d(x)          # (b, num_patches, hidden_size)
        return self.cls(tokens.mean(dim=1))  # average over patches, then classify
class Vit(nn.Module):
    """Image classifier built on Encoder2D: token mean-pooling followed by a
    linear head mapping ``hidden_size`` to ``out_class`` logits."""

    def __init__(self, patch_size,
                 in_channels,
                 out_class,
                 hidden_size=1024,
                 num_hidden_layers=8,
                 num_attention_heads=16,
                 sample_rate=4,
                 ):
        super().__init__()
        cfg = TransConfig(
            patch_size=patch_size,
            in_channels=in_channels,
            out_channels=0,  # no segmentation output for a classifier
            sample_rate=sample_rate,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
        )
        self.encoder_2d = Encoder2D(cfg, is_segmentation=False)
        self.cls = nn.Linear(hidden_size, out_class)

    def forward(self, x):
        pooled = torch.mean(self.encoder_2d(x), dim=1)  # average over the token axis
        return self.cls(pooled)
class Decoder2D(nn.Module):
    """Convolutional decoder: four stages of (3x3 conv -> BatchNorm -> ReLU ->
    2x bilinear upsample), then a 3x3 conv to ``out_channels``.

    Spatial size grows 16x in total. ``features`` gives the four stage widths;
    the final conv reads ``features[-1]`` channels.
    """

    def __init__(self, in_channels, out_channels, features=[512, 256, 128, 64]):
        super().__init__()
        self.decoder_1 = self._stage(in_channels, features[0])
        self.decoder_2 = self._stage(features[0], features[1])
        self.decoder_3 = self._stage(features[1], features[2])
        self.decoder_4 = self._stage(features[2], features[3])
        self.final_out = nn.Conv2d(features[-1], out_channels, 3, padding=1)

    @staticmethod
    def _stage(c_in, c_out):
        """Build one conv-BN-ReLU-upsample stage (doubles height and width)."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
        )

    def forward(self, x):
        for stage in (self.decoder_1, self.decoder_2, self.decoder_3, self.decoder_4):
            x = stage(x)
        return self.final_out(x)
class SETRModel(nn.Module):
    """SETR segmentation network: Encoder2D (transformer + token-to-map folding)
    followed by the convolutional Decoder2D."""

    def __init__(self, patch_size=(32, 32),
                 in_channels=3,
                 out_channels=1,
                 hidden_size=1024,
                 num_hidden_layers=8,
                 num_attention_heads=16,
                 decode_features=[512, 256, 128, 64],
                 sample_rate=4,):
        super().__init__()
        settings = dict(
            patch_size=patch_size,
            in_channels=in_channels,
            out_channels=out_channels,
            sample_rate=sample_rate,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
        )
        config = TransConfig(**settings)
        self.encoder_2d = Encoder2D(config)
        self.decoder_2d = Decoder2D(
            in_channels=config.hidden_size,
            out_channels=config.out_channels,
            features=decode_features,
        )

    def forward(self, x):
        # Encoder2D returns (token sequence, 2D feature map); only the map is decoded.
        feature_map = self.encoder_2d(x)[1]
        return self.decoder_2d(feature_map)
if __name__ == "__main__":
    # Smoke test: build a small SETR model and push a random batch through it.
    demo_kwargs = dict(
        patch_size=(32, 32),    # pixels per patch side
        in_channels=3,          # input image channels
        out_channels=1,         # segmentation output channels
        hidden_size=1024,       # transformer hidden width
        sample_rate=5,          # Encoder2D folds each patch side to patch // 2**sample_rate pixels
        num_hidden_layers=1,    # number of transformer layers
        num_attention_heads=16, # attention heads
        decode_features=[512, 256, 128, 64],  # decoder stage widths
    )
    net = SETRModel(**demo_kwargs)
    sample = torch.rand(2, 3, 512, 512)
    print("input: " + str(sample.shape))
    print("output: " + str(net(sample).shape))
transformer_model
import logging
import math
import os
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from einops import rearrange
def swish(x):
    """Swish (SiLU) activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
def gelu(x):
    """Exact (erf-based) GELU: ``0.5 * x * (1 + erf(x / sqrt(2)))``."""
    scaled = x / math.sqrt(2.0)
    return (x * 0.5) * (1.0 + torch.erf(scaled))
def mish(x):
    """Mish activation: ``x * tanh(softplus(x))``."""
    return torch.tanh(nn.functional.softplus(x)).mul(x)
# Activation name (TransConfig.hidden_act) -> callable.
ACT2FN = dict(
    gelu=gelu,
    relu=torch.nn.functional.relu,
    swish=swish,
    mish=mish,
)
class TransConfig:
    """Plain container for the hyper-parameters of the SETR transformer modules.

    Args:
        patch_size: patch (height, width), read as patch_size[0] / patch_size[1].
        in_channels: channels of the input image.
        out_channels: channels of the segmentation output.
        sample_rate: Encoder2D folds each patch side back to
            ``patch_size // 2**sample_rate`` pixels.
        hidden_size: token width inside the transformer.
        num_hidden_layers: number of transformer layers.
        num_attention_heads: attention heads (must divide hidden_size).
        intermediate_size: width of the feed-forward hidden layer.
        hidden_act: key into ACT2FN ("gelu", "relu", "swish", "mish").
        hidden_dropout_prob: dropout after projections / embeddings.
        attention_probs_dropout_prob: dropout on attention probabilities.
        max_position_embeddings: size of the position-embedding table, i.e. the
            maximum number of patches per image.
        initializer_range: stored; not read by the modules in this file.
        layer_norm_eps: epsilon of the layer norms.
    """

    def __init__(
        self,
        patch_size,
        in_channels,
        out_channels,
        sample_rate=4,
        hidden_size=768,
        num_hidden_layers=8,
        num_attention_heads=6,
        intermediate_size=1024,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
    ):
        # Input / output geometry.
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sample_rate = sample_rate
        # Transformer width and depth.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        # Regularisation.
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        # Embeddings / normalisation / init.
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
class TransLayerNorm(nn.Module):
    """Layer normalisation over the last dimension, TF style (epsilon inside the sqrt).

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta`` where ``mean`` and
    the biased ``var`` are taken over the last axis.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super(TransLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.gamma * normed + self.beta
class TransEmbeddings(nn.Module):
    """Add learned absolute position embeddings to an already-projected token
    sequence, then apply layer norm and dropout.

    Despite the BERT-style parameter name, ``input_ids`` is not integer token ids:
    it is the float tensor of projected patch features, presumably
    (batch, seq_len, hidden_size). seq_len must not exceed
    ``config.max_position_embeddings`` (the embedding lookup would go out of range).
    """

    def __init__(self, config):
        super().__init__()
        # Lookup table: one learned hidden_size vector per position index.
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        batch_and_len = input_ids.size()[:2]
        positions = torch.arange(batch_and_len[1], dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand(batch_and_len)
        summed = input_ids + self.position_embeddings(positions)
        return self.dropout(self.LayerNorm(summed))
# Multi-head self-attention.
class TransSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention without an attention mask.

    Query/key/value come from three separate ``hidden_size -> hidden_size`` linear
    layers. The hidden dimension is split into ``num_attention_heads`` heads,
    attention probabilities are ``softmax(Q K^T / sqrt(head_size))`` (with
    dropout), and the weighted values are merged back to
    (batch, seq_len, hidden_size).

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads``.
    """

    # String annotation: the class can be defined without TransConfig in scope.
    def __init__(self, config: "TransConfig"):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.num_attention_heads = config.num_attention_heads
        # Exact integer division; divisibility was checked above.
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split heads: (batch, seq_len, all_head_size) -> (batch, heads, seq_len, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Attend over ``hidden_states`` (batch, seq_len, hidden_size); returns the same shape."""
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores Q K^T scaled by 1/sqrt(head_size). No mask is applied:
        # every position attends to every other position.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        # Dropout on the probabilities drops whole attended-to tokens, as in the
        # original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values, then merge heads back to (batch, seq_len, all_head_size).
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*new_context_layer_shape)
class TransSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual add -> layer norm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
# Multi-head attention followed by residual connection + layer norm.
class TransAttention(nn.Module):
    """Self-attention sub-layer: TransSelfAttention, then TransSelfOutput
    (projection, dropout, residual add with the input, layer norm)."""

    def __init__(self, config):
        super().__init__()
        self.self = TransSelfAttention(config)
        self.output = TransSelfOutput(config)

    def forward(
        self,
        hidden_states,
    ):
        return self.output(self.self(hidden_states), hidden_states)
class TransIntermediate(nn.Module):
    """Feed-forward expansion: Linear(hidden_size -> intermediate_size) followed by
    the activation named by ``config.hidden_act`` (looked up in ACT2FN; TransConfig
    defaults to "gelu")."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
class TransOutput(nn.Module):
    """Feed-forward contraction: Linear(intermediate_size -> hidden_size), dropout,
    residual add with ``input_tensor``, then layer norm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        residual_sum = self.dropout(self.dense(hidden_states)) + input_tensor
        return self.LayerNorm(residual_sum)
# One transformer block.
class TransLayer(nn.Module):
    """Post-norm transformer block: attention (+ residual, norm), then a
    feed-forward network (+ residual, norm)."""

    def __init__(self, config):
        super().__init__()
        self.attention = TransAttention(config)
        self.intermediate = TransIntermediate(config)
        self.output = TransOutput(config)

    def forward(
        self,
        hidden_states
    ):
        attended = self.attention(hidden_states)
        return self.output(self.intermediate(attended), attended)
class TransEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` TransLayer blocks.

    ``forward`` returns a list: every layer's output when
    ``output_all_encoded_layers`` is True, otherwise a one-element list holding
    the final hidden states (the input itself when there are no layers).
    """

    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList([TransLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        output_all_encoded_layers=True,
    ):
        collected = []
        for block in self.layer:
            hidden_states = block(hidden_states)
            if output_all_encoded_layers:
                collected.append(hidden_states)
        return collected if output_all_encoded_layers else [hidden_states]
class InputDense2d(nn.Module):
    """Project flattened 2D patches to hidden_size: linear -> activation -> layer norm.

    The input's last dimension must equal patch_size[0] * patch_size[1] * in_channels.
    """

    def __init__(self, config):
        super(InputDense2d, self).__init__()
        patch_dim = config.patch_size[0] * config.patch_size[1] * config.in_channels
        self.dense = nn.Linear(patch_dim, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act]
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
class InputDense3d(nn.Module):
    """Project flattened 3D patches to hidden_size: linear -> activation -> layer norm.

    The input's last dimension must equal
    patch_size[0] * patch_size[1] * patch_size[2] * in_channels.
    """

    def __init__(self, config):
        super(InputDense3d, self).__init__()
        d, h, w = config.patch_size[0], config.patch_size[1], config.patch_size[2]
        self.dense = nn.Linear(d * h * w * config.in_channels, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act]
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        projected = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(projected)
class TransModel2d(nn.Module):
    """2D patch transformer encoder: InputDense2d -> TransEmbeddings -> TransEncoder.

    forward(input_ids, output_all_encoded_layers=True):
        input_ids: flattened patches, (b, h*w, p1*p2*c), as produced by
            Encoder2D's rearrange.
        Returns a list with one (b, h*w, hidden_size) tensor per layer when
        ``output_all_encoded_layers`` is True, otherwise only the final
        layer's tensor.
    """

    def __init__(self, config):
        super(TransModel2d, self).__init__()
        self.config = config
        self.dense = InputDense2d(config)
        self.embeddings = TransEmbeddings(config)
        self.encoder = TransEncoder(config)

    def forward(
        self,
        input_ids,
        output_all_encoded_layers=True,
    ):
        dense_out = self.dense(input_ids)              # projection + activation + layer norm
        embedding_output = self.embeddings(dense_out)  # + position embeddings
        encoder_layers = self.encoder(
            embedding_output,
            output_all_encoded_layers=output_all_encoded_layers,
        )
        if not output_all_encoded_layers:
            # Only the final layer was requested: return the tensor, not a list.
            encoder_layers = encoder_layers[-1]
        return encoder_layers
class TransModel3d(nn.Module):
    """3D patch transformer encoder: InputDense3d -> TransEmbeddings -> TransEncoder.

    forward(input_ids, output_all_encoded_layers=True):
        input_ids: flattened 3D patches whose last dimension is
            patch_size[0] * patch_size[1] * patch_size[2] * in_channels.
        Returns a list with one tensor per layer when
        ``output_all_encoded_layers`` is True, otherwise only the final
        layer's tensor.
    """

    def __init__(self, config):
        super(TransModel3d, self).__init__()
        self.config = config
        self.dense = InputDense3d(config)
        self.embeddings = TransEmbeddings(config)
        self.encoder = TransEncoder(config)

    def forward(
        self,
        input_ids,
        output_all_encoded_layers=True,
    ):
        dense_out = self.dense(input_ids)
        embedding_output = self.embeddings(
            input_ids=dense_out
        )
        encoder_layers = self.encoder(
            embedding_output,
            output_all_encoded_layers=output_all_encoded_layers,
        )
        if not output_all_encoded_layers:
            # Only the final layer was requested: return the tensor, not a list.
            encoder_layers = encoder_layers[-1]
        return encoder_layers
汽车分割的例子
数据集可以去这下载:
# data_url : https://www.kaggle.com/c/carvana-image-masking-challenge/data
import torch
import numpy as np
from SETR.transformer_seg import SETRModel
from PIL import Image
import glob
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torchvision.utils import save_image
import os
import xlwt
# Image and mask paths. Both lists are sorted so the i-th image pairs with the
# i-th mask (assumes the two folders hold correspondingly named files).
img_url = sorted(glob.glob("./data/train/*"))
mask_url = sorted(glob.glob("./data/train_masks/*"))

# 80 / 20 split in sorted order (no shuffling): the last 20% is validation.
train_size = int(len(img_url) * 0.8)
train_img_url, val_img_url = img_url[:train_size], img_url[train_size:]
train_mask_url, val_mask_url = mask_url[:train_size], mask_url[train_size:]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device is " + str(device))

epoches = 100     # number of training epochs
out_channels = 1  # model output channels; 1 => channel dim is squeezed before the loss
def build_model():
    """Construct the SETR segmentation model used for the Carvana car-mask task."""
    return SETRModel(
        patch_size=(16, 16),
        in_channels=3,
        out_channels=1,
        hidden_size=1024,
        num_hidden_layers=6,
        num_attention_heads=16,
        decode_features=[512, 256, 128, 64],
    )
class CarDataset(Dataset):
    """Image/mask pairs resized to 256x256.

    ``__getitem__`` returns ``(image, mask)``: image is a float32 tensor of shape
    (3, 256, 256) scaled to [0, 1]; mask is a float32 (256, 256) tensor holding the
    mask image's raw pixel values (the BCE loss in training presumably expects 0/1
    values — verify for your data).
    """

    def __init__(self, img_url, mask_url):
        super(CarDataset, self).__init__()
        self.img_url = img_url
        self.mask_url = mask_url

    def __getitem__(self, idx):
        # Context managers close the underlying files; otherwise handles (notably
        # for multi-frame formats such as GIF) can stay open across many loads.
        with Image.open(self.img_url[idx]) as img:
            # convert("RGB") guarantees 3 channels for the transpose below.
            img = img.convert("RGB").resize((256, 256))
        img_array = np.array(img, dtype=np.float32) / 255
        with Image.open(self.mask_url[idx]) as mask:
            # Nearest-neighbour keeps label values intact instead of interpolating them.
            mask = mask.resize((256, 256), Image.NEAREST)
        mask = np.array(mask, dtype=np.float32)
        img_array = img_array.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
        return torch.tensor(img_array.copy()), torch.tensor(mask.copy())

    def __len__(self):
        return len(self.img_url)
def compute_dice(input, target):
    """Dice coefficient between a thresholded prediction and a target mask.

    Args:
        input: prediction probabilities (already passed through sigmoid);
            values > 0.5 count as foreground.
        target: ground-truth mask; values > 0.5 count as foreground.

    Returns:
        0-dim float tensor ``(2 * |A & B| + eps) / (|A| + |B| + eps)`` pooled over
        all elements (the whole batch). Two empty masks score 1.0.
    """
    eps = 0.0001
    pred_fg = (input > 0.5).float()
    true_fg = (target > 0.5).float()
    inter = torch.sum(pred_fg.reshape(-1) * true_fg.reshape(-1))
    union = torch.sum(pred_fg) + torch.sum(true_fg)
    # Smooth numerator and denominator with the same eps. The old form
    # 2 * (inter + eps) / (union + eps) returned 2.0 for empty-vs-empty masks.
    return (2 * inter + eps) / (union + eps)
def predict():
    """Show input / predicted mask / ground truth for every validation image.

    Loads the weights in ./SETR_car.pth (mapped to CPU) into a fresh model and
    displays the three images side by side with matplotlib, one figure per sample.
    """
    model = build_model()
    model.load_state_dict(torch.load("./SETR_car.pth", map_location="cpu"))
    # Evaluation mode: BatchNorm must use its running statistics (batch_size is 1
    # here) and dropout must be disabled; otherwise predictions differ from what
    # the trained model would produce.
    model.eval()
    print(model)
    import matplotlib.pyplot as plt
    val_dataset = CarDataset(val_img_url, val_mask_url)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    with torch.no_grad():
        for img, mask in val_loader:
            pred = torch.sigmoid(model(img))
            pred = (pred > 0.5).int()  # binarise probabilities
            plt.subplot(1, 3, 1)
            print(img.shape)
            img = img.permute(0, 2, 3, 1)  # (b, c, h, w) -> (b, h, w, c) for imshow
            plt.imshow(img[0])
            plt.subplot(1, 3, 2)
            plt.imshow(pred[0].squeeze(0), cmap="gray")
            plt.subplot(1, 3, 3)
            plt.imshow(mask[0], cmap="gray")
            plt.show()
def data_write(file_path, epoch, datas1, datas2):
    """Write per-epoch metrics to an .xls workbook (overwriting ``file_path``).

    Args:
        file_path: destination .xls path.
        epoch: list of epoch indices, one per data row.
        datas1: list of training losses (column ``train_Loss``).
        datas2: list of validation values (column ``Val_Loss``).
    ``epoch`` and ``datas2`` must be at least as long as ``datas1``.
    """
    book = xlwt.Workbook(encoding='utf-8')
    sheet = book.add_sheet(u'阿强的表1', cell_overwrite_ok=True)
    # Header row.
    for col, title in enumerate(('epoch', 'train_Loss', 'Val_Loss')):
        sheet.write(0, col, label=title)
    # One data row per entry of datas1, directly below the header.
    for idx, train_value in enumerate(datas1):
        row = idx + 1
        sheet.write(row, 0, epoch[idx])
        sheet.write(row, 1, train_value)
        sheet.write(row, 2, datas2[idx])
    book.save(file_path)
if __name__ == "__main__":
    model = build_model()
    model.to(device)
    train_dataset = CarDataset(train_img_url, train_mask_url)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_dataset = CarDataset(val_img_url, val_mask_url)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    loss_func = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)
    # Resume from a checkpoint only if one exists, so the first run does not crash.
    # map_location lets a checkpoint saved on another device load here.
    if os.path.exists('./SETR_car.pth'):
        model.load_state_dict(torch.load('./SETR_car.pth', map_location=device))
    step = 0
    train_men = []  # mean training BCE loss per epoch
    test_men = []   # 1 - mean validation Dice per epoch (written as "Val_Loss")
    epo = []
    for epoch in range(epoches):
        print("epoch is " + str(epoch))
        epo.append(epoch)
        train_loss = 0
        test_loss = 0  # sum of per-image validation Dice scores
        print("进行------训练------测试:")
        model.train()
        for img, mask in tqdm(train_loader, total=len(train_loader)):
            optimizer.zero_grad()
            step += 1
            img = img.to(device)
            mask = mask.to(device)
            pred_img = model(img)  # logits, (batch, out_channels, H, W)
            if out_channels == 1:
                pred_img = pred_img.squeeze(1)  # drop channel dim to match (batch, H, W) masks
            loss = loss_func(pred_img, mask)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
        # Validation
        model.eval()
        with torch.no_grad():
            print("进行------验证------测试:")
            for val_img, val_mask in tqdm(val_loader, total=len(val_loader)):
                val_img = val_img.to(device)
                val_mask = val_mask.to(device)
                pred_img = torch.sigmoid(model(val_img))
                if out_channels == 1:
                    pred_img = pred_img.squeeze(1)
                cur_dice = compute_dice(pred_img, val_mask)
                test_loss += cur_dice.item()
            # Save one (label, prediction) pair per epoch from the last validation
            # sample. unsqueeze(1) makes a (2, 1, H, W) batch so save_image tiles the
            # two masks side by side instead of treating them as two channels.
            sample = torch.stack([val_mask[0], pred_img[0]], 0).unsqueeze(1)
            os.makedirs('./outputs', exist_ok=True)
            save_image(sample.cpu(), f"./outputs/{step}.png")
        # Checkpoint once per epoch (previously on every validation image, and only
        # when step happened to be a multiple of 50).
        torch.save(model.state_dict(), "./SETR_car.pth")
        train_men.append(train_loss / len(train_loader))
        val_dice = test_loss / len(val_loader)
        test_men.append(1 - val_dice)
        data_write("Car_loss.xls", epo, train_men, test_men)
        print("train_loss is " + str(train_men[epoch]))
        # Report the actual Dice score (test_men stores 1 - Dice as a loss).
        print("Val_dice is " + str(val_dice))
SETR + VIT 的一些改动
原版的SETR结构有些复杂,阅读起来有些困难,但是写得很详细;而VIT写得简单易懂,却不适合直接用于分割网络。所以本人集两家之所长于一体。
transformer_seg
import logging
import math
import os
import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from einops import rearrange
from vit import ViT
class Decoder2D(nn.Module):
    """Four conv-BN-ReLU stages, each followed by 2x bilinear upsampling (16x
    overall), then a 3x3 conv producing ``out_channels`` maps.

    ``features`` lists the four stage widths; the final conv reads ``features[-1]``.
    """

    def __init__(self, in_channels, out_channels, features=[512, 256, 128, 64]):
        super().__init__()

        def up_block(c_in, c_out):
            # conv3x3 -> BatchNorm -> ReLU -> double the spatial size
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            )

        self.decoder_1 = up_block(in_channels, features[0])
        self.decoder_2 = up_block(features[0], features[1])
        self.decoder_3 = up_block(features[1], features[2])
        self.decoder_4 = up_block(features[2], features[3])
        self.final_out = nn.Conv2d(features[-1], out_channels, 3, padding=1)

    def forward(self, x):
        upsampled = self.decoder_4(self.decoder_3(self.decoder_2(self.decoder_1(x))))
        return self.final_out(upsampled)
class SETRModel(nn.Module):
    """SETR-style segmentation model: the simplified ViT encoder followed by Decoder2D.

    The ViT yields a (b, hidden_size, H', W') feature map, which the four-stage
    decoder upsamples 16x into ``out_channels`` maps.
    """

    def __init__(self, patch_size=(32, 32),
                 image_size=512,
                 in_channels=3,
                 out_channels=1,
                 hidden_size=1024,
                 num_hidden_layers=8,
                 num_attention_heads=16,
                 decode_features=[512, 256, 128, 64],
                 sample_rate=4,):
        super().__init__()
        vit_kwargs = {
            "image_size": image_size,
            "in_channels": in_channels,
            "patch_size": patch_size,
            "hidden_size": hidden_size,                  # width of every token vector
            "num_hidden_layers": num_hidden_layers,      # number of transformer encoder blocks (L)
            "num_attention_heads": num_attention_heads,  # attention heads
            "sample_rate": sample_rate,
        }
        self.encoder_2d = ViT(**vit_kwargs)
        self.decoder_2d = Decoder2D(in_channels=hidden_size, out_channels=out_channels, features=decode_features)

    def forward(self, x):
        return self.decoder_2d(self.encoder_2d(x))
if __name__ == "__main__":
    # Smoke test. With patch 32 and sample_rate 5 (2**5 == 32) each token becomes one
    # pixel, so a 512x512 input gives a 16x16 feature map that the decoder
    # upsamples 16x to 256x256.
    net = SETRModel(
        image_size = 512,
        patch_size=(32, 32),    # pixels per patch side
        in_channels=3,          # input channels
        out_channels=1,         # output channels
        hidden_size=1024,       # transformer hidden width
        sample_rate=5,          # patch side // 2**sample_rate = output pixels per patch side
        num_hidden_layers=1,    # number of transformer blocks
        num_attention_heads=16, # attention heads
        decode_features=[512, 256, 128, 64])  # decoder stage widths
    t1 = torch.rand(2, 3, 512, 512)
    print("input: " + str(t1.shape))
    print("output: " + str(net(t1).shape))
VIT
import torch
import math
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def swish(x):
    """Swish / SiLU: multiply ``x`` by its own sigmoid."""
    return torch.sigmoid(x).mul(x)
def gelu(x):
    """GELU via the Gaussian CDF: ``x * Phi(x)``, ``Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))``."""
    half_x = 0.5 * x
    return half_x * torch.erf(x / math.sqrt(2.0)).add(1.0)
def mish(x):
    """Mish: ``x * tanh(softplus(x))``."""
    soft = nn.functional.softplus(x)
    return x * torch.tanh(soft)
# Activation lookup used by ViT(act=...).
ACT2FN = dict(
    [("gelu", gelu), ("relu", torch.nn.functional.relu), ("swish", swish), ("mish", mish)]
)
# helpers
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise ``(t, t)``.

    Lets image_size / patch_size be given either as one int (e.g. 224 ->
    (224, 224)) or as an explicit (height, width) tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)
# classes
class PreNorm(nn.Module):
    """Layer-normalise the last dimension (TF style), then call ``fn`` on the result.

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta`` and forwards it,
    together with any keyword arguments, to ``fn`` (e.g. an Attention or
    FeedForward module).
    """

    def __init__(self, dim, fn, eps=1e-12):
        super().__init__()
        self.fn = fn
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.variance_epsilon = eps

    def forward(self, x, **kwargs):
        centered = x - x.mean(-1, keepdim=True)
        var = centered.pow(2).mean(-1, keepdim=True)
        normed = centered / torch.sqrt(var + self.variance_epsilon)
        return self.fn(self.gamma * normed + self.beta, **kwargs)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim -> dim) -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Multi-head self-attention with one fused, bias-free QKV projection.

    ``to_qkv`` maps dim -> 3 * heads * dim_head; the result is split into Q, K, V,
    attention is ``softmax(Q K^T * dim_head**-0.5)`` with dropout, and the merged
    heads are projected back to ``dim`` (the projection is skipped when
    heads == 1 and dim_head == dim).
    """

    def __init__(self, dim, heads = 16, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head        # e.g. 16 * 64 = 1024
        needs_projection = heads != 1 or dim_head != dim
        self.heads = heads
        self.scale = dim_head ** -0.5       # e.g. 64 ** -0.5 = 0.125
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        # Split the fused projection into (q, k, v), each (b, heads, n, dim_head).
        q, k, v = rearrange(self.to_qkv(x), 'b n (three h d) -> three b h n d', three = 3, h = self.heads)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # Q times K transposed
        attn = self.dropout(self.attend(dots))
        merged = rearrange(torch.matmul(attn, v), 'b h n d -> b n (h d)')
        return self.to_out(merged)
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm blocks:
    ``x = x + Attention(LN(x))`` then ``x = x + FeedForward(LN(x))``."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attn_block, ff_block in self.layers:
            x = x + attn_block(x)
            x = x + ff_block(x)
        return x
class ViT(nn.Module):
    """Simplified ViT encoder that returns a 2D feature map for the SETR decoder.

    Pipeline: split the image into patches -> Linear(patch_dim -> hidden_size) ->
    activation ``act`` -> + learned position embedding -> dropout -> pre-norm
    Transformer -> fold tokens back into a (b, hidden_size, H', W') map, where
    H' = image_height // patch_height and W' = image_width // patch_width.

    There is no projection after the transformer, so every hidden_size-wide token
    must become exactly one output pixel: ``patch // 2**sample_rate`` has to be 1
    on both sides (e.g. patch 32 with sample_rate 5). Inputs must have exactly
    ``image_size`` spatially, because the position embedding has one row per patch.
    """

    def __init__(self, *, image_size, patch_size, hidden_size, num_hidden_layers, num_attention_heads, in_channels = 3,
                 mlp_dim =2048, act = 'gelu', dim_head = 64, dropout = 0.1, emb_dropout = 0.1, sample_rate = 4):
        super().__init__()
        image_height, image_width = pair(image_size)  # e.g. 256 -> (256, 256)
        patch_height, patch_width = pair(patch_size)  # e.g. 32 -> (32, 32)
        dim = hidden_size
        depth = num_hidden_layers
        heads = num_attention_heads
        channels = in_channels
        sample_v = int(math.pow(2, sample_rate))
        # Use the normalised pair (not patch_size[0]) so an int patch_size works too.
        self.hh = patch_height // sample_v
        self.ww = patch_width // sample_v
        # forward() folds each hidden_size-dim token into an hh x ww block with
        # hidden_size channels, which only fits when hh == ww == 1. (The previous
        # check multiplied by num_hidden_layers, which is unrelated to token width.)
        assert self.hh == 1 and self.ww == 1, (
            "不能除尽: patch_size // 2**sample_rate must be 1 on both sides, "
            f"got ({self.hh}, {self.ww})"
        )
        self.h = image_height // patch_height
        self.w = image_width // patch_width
        self.hidden_size = hidden_size
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width  # flattened length of one patch
        assert act in {'gelu', 'relu', 'swish', 'mish'}, "act must be one of 'gelu', 'relu', 'swish', 'mish'"
        self.transform_act_fn = ACT2FN[act]
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),  # patch_dim -> hidden_size
        )
        # Learned absolute position embedding, one row per patch: (1, num_patches, dim).
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        # NOTE: cls_token is never used in forward(); it is kept only so existing
        # checkpoints (state_dict keys) keep loading.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    def forward(self, img):
        """Encode ``img`` (b, in_channels, image_height, image_width) to (b, hidden_size, H', W')."""
        x = self.to_patch_embedding(img)  # (b, num_patches, hidden_size)
        x = self.transform_act_fn(x)
        # Out-of-place add: the former in-place `x += ...` overwrote the activation
        # output, which autograd saves for relu's backward pass (act='relu' failed).
        x = x + self.pos_embedding
        x = self.dropout(x)
        x = self.transformer(x)  # (b, num_patches, hidden_size)
        x = rearrange(x, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
                      p1=self.hh, p2=self.ww, h=self.h, w=self.w, c=self.hidden_size)
        return x
总结:
VIT和SETR的结构有很多细节上的区别:
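| 区别 | VIT | SETR |
| --- | --- | --- |
| 结构 | 位置编码 —> 标准化 —> 多头注意力(Pre-Norm:先标准化,再进注意力,最后加残差) | 标准化 —> 位置编码 —> 标准化 —> 多头注意力(Post-Norm:先加残差,再标准化) |
| QKV | 1024 经过一个线性层变成 3072,再均分成三份,分别作为 Q、K、V | 1024 分别经过三个线性层,得到三个 1024 维的 Q、K、V |
| 位置编码 | 随机初始化的可学习参数(nn.Parameter) | 查表编码(nn.Embedding,同样可学习) |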
这些细节上的差异也让 VIT 和 SETR 的效果有明显区别:数值指标上差别不大,但效果图上的差距却很明显。
| 对比 | VIT | SETR |
| --- | --- | --- |
| 准确度 | 略差(20 个 epoch:0.975) | 略好(20 个 epoch:0.980) |
| 速度 | 耗时少约 1/3(train + Val = 5 分 20 秒) | train + Val = 7 分 55 秒 |
| 效果图 | 有点拉胯,但是能接受 | |
直接调包Monai
from monai.networks.nets import ViT
# Snippet (not standalone): builds MONAI's ViT inside some module's __init__.
# in_channels, img_size, hidden_size, mlp_dim, num_heads, pos_embed, dropout_rate
# and self.patch_size / self.num_layers / self.classification must be defined by
# the surrounding code, which is not shown here.
self.vit = ViT(
in_channels=in_channels, ## input channels
img_size=img_size, ## image size
patch_size=self.patch_size, ## patch size
hidden_size=hidden_size, ## hidden (token) width
mlp_dim=mlp_dim, ## MLP hidden width
num_layers=self.num_layers, ## number of transformer blocks
num_heads=num_heads, ## number of attention heads
pos_embed=pos_embed, ## patch/position embedding type
classification=self.classification, ## whether to build a classification head
dropout_rate=dropout_rate,
)