【Transformer】SETR + ViT 的一些个人笔记

首先声明:代码参考了以下作者的实现,如有侵权请联系我,我会马上删除。

本人只是在原有基础上加了一些自己的笔记,并调整了部分结构。

https://github.com/920232796/SETR-pytorch

https://github.com/lucidrains/vit-pytorch

Transformer 用于图像时,本质上就是把图像切成小块(patch),用线性层把每个小块映射成一个向量,再用自注意力(self-attention)建模各小块之间的关系。

transformer_seg.py

import logging
import math
import os
import numpy as np 

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from einops import rearrange
from transformer_model import TransModel2d, TransConfig
import math 

class Encoder2D(nn.Module):
    """Transformer encoder over non-overlapping image patches.

    ``forward`` cuts the image into ``patch_size`` patches, flattens every
    patch into one token and runs the token sequence through
    ``TransModel2d``.  With ``is_segmentation=True`` the last encoder layer is
    additionally projected by ``final_dense`` and folded back into a 2-D
    feature map: every patch becomes an ``hh x ww`` block with
    ``hidden_size`` channels, where ``hh = patch_size[0] // 2**sample_rate``
    and ``ww = patch_size[1] // 2**sample_rate``.  The feature map is
    therefore ``2**sample_rate`` times smaller than the input on each axis.

    Args:
        config: ``TransConfig``; reads ``patch_size``, ``sample_rate``,
            ``hidden_size``, ``in_channels`` and ``out_channels``.
        is_segmentation: if False, ``forward`` returns only the encoder output.

    Raises:
        ValueError: at construction when the patch/hidden sizes cannot be
            split by ``2**sample_rate``; in ``forward`` when the input channel
            count or spatial size does not match the config.
    """

    def __init__(self, config: TransConfig, is_segmentation=True):
        super().__init__()
        self.config = config
        self.out_channels = config.out_channels
        self.bert_model = TransModel2d(config)  # patch projection + position embedding + transformer layers
        sample_rate = config.sample_rate
        sample_v = int(math.pow(2, sample_rate))  # downsample factor of the folded feature map
        if config.patch_size[0] * config.patch_size[1] * config.hidden_size % (sample_v**2) != 0:
            raise ValueError("不能除尽")
        if is_segmentation and (config.patch_size[0] % sample_v != 0
                                or config.patch_size[1] % sample_v != 0):
            # The product check above is not enough for the fold in forward():
            # e.g. patch (16, 16) with sample_v=32 passes it but gives
            # hh == ww == 0, which only failed later inside rearrange.
            raise ValueError(
                "patch_size %s must be divisible by 2**sample_rate (%d) on each axis"
                % (tuple(config.patch_size), sample_v)
            )
        # Each token -> hh * ww * hidden_size values (one hidden vector per sub-cell).
        self.final_dense = nn.Linear(config.hidden_size, config.patch_size[0] * config.patch_size[1] * config.hidden_size // (sample_v**2))
        self.patch_size = config.patch_size
        self.hh = self.patch_size[0] // sample_v
        self.ww = self.patch_size[1] // sample_v

        self.is_segmentation = is_segmentation

    def forward(self, x):
        """Encode ``x`` of shape ``(b, c, h, w)``.

        Returns:
            If ``is_segmentation`` is False: the last encoder layer,
            ``(b, num_patches, hidden_size)``.  Otherwise the tuple
            ``(encoder_output, feature_map)`` with ``feature_map`` of shape
            ``(b, hidden_size, h // 2**sample_rate, w // 2**sample_rate)``.

        Raises:
            ValueError: if ``c`` differs from ``config.in_channels`` or if
                ``h``/``w`` are not multiples of the patch size.  (Previously
                a size mismatch printed a message and called ``os._exit(0)``,
                killing the whole process with a success status.)
        """
        _, c, h, w = x.shape
        if self.config.in_channels != c:
            raise ValueError("in_channels != 输入图像channel")
        p1 = self.patch_size[0]
        p2 = self.patch_size[1]

        if h % p1 != 0 or w % p2 != 0:
            raise ValueError("请重新输入img size 参数 必须整除")
        hh = h // p1  # number of patches along the height
        ww = w // p2  # number of patches along the width

        # (b, c, h, w) -> (b, num_patches, p1 * p2 * c): one flattened patch per token.
        x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p1, p2=p2)

        encode_x = self.bert_model(x)[-1]  # output of the last encoder layer
        if not self.is_segmentation:
            return encode_x

        x = self.final_dense(encode_x)
        # Fold every token back into an hh x ww block of hidden_size-channel pixels.
        x = rearrange(x, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", p1=self.hh, p2=self.ww, h=hh, w=ww, c=self.config.hidden_size)
        return encode_x, x


class PreTrainModel(nn.Module):
    """Patch-transformer image classifier.

    The image is encoded by ``Encoder2D`` (without the segmentation fold),
    the patch tokens of the last layer are averaged and a linear layer maps
    the pooled vector to ``out_class`` logits.

    ``decode_features`` is accepted for signature compatibility but unused;
    ``sample_rate`` is not forwarded, so ``TransConfig``'s default applies.
    """

    def __init__(self, patch_size,
                        in_channels,
                        out_class,
                        hidden_size=1024,
                        num_hidden_layers=8,
                        num_attention_heads=16,
                        decode_features=[512, 256, 128, 64]):
        super().__init__()
        encoder_settings = dict(
            patch_size=patch_size,
            in_channels=in_channels,
            out_channels=0,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
        )
        self.encoder_2d = Encoder2D(TransConfig(**encoder_settings), is_segmentation=False)
        self.cls = nn.Linear(hidden_size, out_class)

    def forward(self, x):
        tokens = self.encoder_2d(x)       # (b, num_patches, hidden_size)
        pooled = tokens.mean(dim=1)       # average over the patch axis
        return self.cls(pooled)

class Vit(nn.Module):
    """Patch-transformer image classifier with a configurable ``sample_rate``.

    ``Encoder2D`` (no segmentation fold) encodes the image, the tokens of the
    last layer are mean-pooled over the patch axis and ``cls`` produces
    ``out_class`` logits.
    """

    def __init__(self, patch_size,
                        in_channels,
                        out_class,
                        hidden_size=1024,
                        num_hidden_layers=8,
                        num_attention_heads=16,
                        sample_rate=4,
                        ):
        super().__init__()
        self.encoder_2d = Encoder2D(
            TransConfig(
                patch_size=patch_size,
                in_channels=in_channels,
                out_channels=0,
                sample_rate=sample_rate,
                hidden_size=hidden_size,
                num_hidden_layers=num_hidden_layers,
                num_attention_heads=num_attention_heads,
            ),
            is_segmentation=False,
        )
        self.cls = nn.Linear(hidden_size, out_class)

    def forward(self, x):
        return self.cls(self.encoder_2d(x).mean(dim=1))

class Decoder2D(nn.Module):
    """Convolutional upsampling decoder.

    Four stages, each ``Conv3x3 -> BatchNorm -> ReLU -> bilinear x2 upsample``,
    with output widths ``features[0..3]``, followed by a 3x3 conv from
    ``features[-1]`` to ``out_channels``.  The output is 16x larger than the
    input on each spatial axis.
    """

    def __init__(self, in_channels, out_channels, features=[512, 256, 128, 64]):
        super().__init__()
        self.decoder_1 = self._up_block(in_channels, features[0])
        self.decoder_2 = self._up_block(features[0], features[1])
        self.decoder_3 = self._up_block(features[1], features[2])
        self.decoder_4 = self._up_block(features[2], features[3])

        self.final_out = nn.Conv2d(features[-1], out_channels, 3, padding=1)

    @staticmethod
    def _up_block(cin, cout):
        """One decoder stage; submodule indices 0-3 match the checkpoint keys."""
        return nn.Sequential(
            nn.Conv2d(cin, cout, 3, padding=1),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
        )

    def forward(self, x):
        for stage in (self.decoder_1, self.decoder_2, self.decoder_3, self.decoder_4):
            x = stage(x)
        return self.final_out(x)

class SETRModel(nn.Module):
    """SETR-style segmentation network.

    ``Encoder2D`` turns the image into a ``hidden_size``-channel feature map
    that is ``2**sample_rate`` times smaller than the input, and
    ``Decoder2D`` upsamples it 16x and maps it to ``out_channels``.  With the
    default ``sample_rate=4`` the output therefore has the input's size.
    """

    def __init__(self, patch_size=(32, 32),
                        in_channels=3,
                        out_channels=1,
                        hidden_size=1024,
                        num_hidden_layers=8,
                        num_attention_heads=16,
                        decode_features=[512, 256, 128, 64],
                        sample_rate=4,):
        super().__init__()
        settings = dict(
            patch_size=patch_size,
            in_channels=in_channels,
            out_channels=out_channels,
            sample_rate=sample_rate,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
        )
        config = TransConfig(**settings)
        self.encoder_2d = Encoder2D(config)
        self.decoder_2d = Decoder2D(
            in_channels=config.hidden_size,
            out_channels=config.out_channels,
            features=decode_features,
        )

    def forward(self, x):
        _tokens, feature_map = self.encoder_2d(x)
        return self.decoder_2d(feature_map)


if __name__ == "__main__":
    # Smoke test on a random 512x512 batch.  With 32x32 patches and
    # sample_rate=5 (2**5 = 32) every patch folds into a single cell, so the
    # encoder map is 16x16 and the 16x decoder yields a 256x256 output.
    net = SETRModel(
        patch_size=(32, 32),        # patch height/width in pixels
        in_channels=3,              # input image channels
        out_channels=1,             # output (mask) channels
        hidden_size=1024,           # token embedding width
        sample_rate=5,              # encoder feature map is 2**sample_rate times smaller than the input
        num_hidden_layers=1,        # number of transformer layers
        num_attention_heads=16,     # attention heads
        decode_features=[512, 256, 128, 64],  # channel widths of the 4 decoder stages
    )
    dummy = torch.rand(2, 3, 512, 512)
    print("input: " + str(dummy.shape))

    print("output: " + str(net(dummy).shape))

transformer_model.py

import logging
import math
import os

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from einops import rearrange

def swish(x):
    """Swish (SiLU) activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate

def gelu(x):
    """Exact GELU via the Gaussian CDF: ``0.5 * x * (1 + erf(x / sqrt(2)))``."""
    cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * cdf

def mish(x):
    """Mish activation: ``x * tanh(softplus(x))``."""
    softplus_x = nn.functional.softplus(x)
    return x * torch.tanh(softplus_x)

# Name -> activation lookup used with ``config.hidden_act``.
ACT2FN = dict(gelu=gelu, relu=torch.nn.functional.relu, swish=swish, mish=mish)

class TransConfig:
    """Hyper-parameters shared by the ``Trans*`` modules.

    Args:
        patch_size: patch size in pixels; ``patch_size[0]`` is split from the
            image height and ``patch_size[1]`` from the width (``[2]`` is also
            read by ``InputDense3d``).
        in_channels: channels of the input image.
        out_channels: channels of the segmentation output.
        sample_rate: ``2**sample_rate`` is the downsample factor
            ``Encoder2D`` uses when folding tokens back into a feature map.
        hidden_size: token width.
        num_hidden_layers: number of ``TransLayer`` blocks.
        num_attention_heads: attention heads; must divide ``hidden_size``.
        intermediate_size: width of the feed-forward hidden layer.
        hidden_act: key into ``ACT2FN``.
        hidden_dropout_prob: dropout after embeddings/projections.
        attention_probs_dropout_prob: dropout on attention weights.
        max_position_embeddings: rows of the learned position table, i.e. the
            largest number of patches an input may produce.
        initializer_range: stored only; no code shown here reads it.
        layer_norm_eps: epsilon for ``TransLayerNorm``.
    """

    def __init__(
        self,
        patch_size,
        in_channels,
        out_channels,
        sample_rate=4,
        hidden_size=768,
        num_hidden_layers=8,
        num_attention_heads=6,
        intermediate_size=1024,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
    ):
        # image / patch geometry
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sample_rate = sample_rate
        # transformer shape
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        # regularisation / numerics
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

class TransLayerNorm(nn.Module):
    """Layer normalisation over the last dimension, TF style.

    ``y = gamma * (x - mean) / sqrt(var + eps) + beta``, with the epsilon
    added to the (biased) variance inside the square root.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.gamma * normalized + self.beta
      
class TransEmbeddings(nn.Module):
    """Add learned absolute position embeddings to a token sequence, then
    apply LayerNorm and dropout.

    Despite its BERT-style name, the ``input_ids`` argument of ``forward`` is
    not integer token ids: it is a float feature tensor that is added
    element-wise to ``(batch, seq_len, hidden_size)`` position embeddings.
    """

    def __init__(self, config):
        super().__init__()
        # One learned hidden_size vector per position 0 .. max_position_embeddings - 1.
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        """Add position information to ``input_ids``.

        Args:
            input_ids: ``(batch, seq_len, hidden_size)`` token features.

        Returns:
            Tensor of the same shape.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_position_embeddings``
                (e.g. a 512x512 image with 16x16 patches yields 1024 patches,
                more than the ``TransConfig`` default of 512).
        """
        input_shape = input_ids.size()
        seq_length = input_shape[1]
        max_positions = self.position_embeddings.num_embeddings
        if seq_length > max_positions:
            # Fail with a readable message instead of an embedding index
            # error (or an opaque device-side assert on CUDA).
            raise ValueError(
                f"sequence length {seq_length} (number of patches) exceeds "
                f"max_position_embeddings={max_positions}; use larger patches, "
                "smaller images or a larger max_position_embeddings"
            )

        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(input_shape[:2])

        embeddings = input_ids + self.position_embeddings(position_ids)
        embeddings = self.LayerNorm(embeddings)
        return self.dropout(embeddings)

## 多头注意力机制
class TransSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention without masking.

    Input and output are ``(batch, seq_len, hidden_size)``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(self, config: "TransConfig"):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        # Exact integer division (divisibility checked above) instead of int(a / b).
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split heads: ``(b, n, all_head_size)`` -> ``(b, heads, n, head_size)``."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(*new_x_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states
    ):
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores q.k^T scaled by 1/sqrt(head_size): (b, heads, n, n).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # No attention mask: every patch attends to every other patch.
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # Dropout on the attention weights, as in the original Transformer.
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values, then merge heads back to (b, n, all_head_size).
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*new_context_layer_shape)


class TransSelfOutput(nn.Module):
    """Output of the attention sub-layer: Linear -> dropout -> residual add
    -> LayerNorm (normalisation after the residual, post-LN)."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)

## 多头注意力 + 残差
class TransAttention(nn.Module):
    """Attention sub-layer: multi-head self-attention followed by
    ``TransSelfOutput`` (projection, residual with the input, LayerNorm)."""

    def __init__(self, config):
        super().__init__()
        self.self = TransSelfAttention(config)
        self.output = TransSelfOutput(config)

    def forward(
        self,
        hidden_states,
    ):
        attended = self.self(hidden_states)
        return self.output(attended, hidden_states)


class TransIntermediate(nn.Module):
    """First half of the feed-forward sub-layer: Linear(hidden -> intermediate)
    followed by the activation named by ``config.hidden_act``."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Looked up in ACT2FN; TransConfig's default is "gelu".
        self.intermediate_act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))

class TransOutput(nn.Module):
    """Second half of the feed-forward sub-layer: Linear(intermediate ->
    hidden) -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)

## trans模块
class TransLayer(nn.Module):
    """One transformer block: the attention sub-layer, then the feed-forward
    sub-layer (``TransIntermediate`` + ``TransOutput``, which adds the
    attention output back as a residual before LayerNorm)."""

    def __init__(self, config):
        super().__init__()
        self.attention = TransAttention(config)
        self.intermediate = TransIntermediate(config)
        self.output = TransOutput(config)

    def forward(
        self,
        hidden_states
    ):
        attended = self.attention(hidden_states)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended)


class TransEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``TransLayer`` blocks.

    ``forward`` returns a list: the output of every layer when
    ``output_all_encoded_layers`` is True, otherwise a one-element list with
    the final hidden states (the unchanged input if there are no layers).
    """

    def __init__(self, config):
        super().__init__()
        self.layer = nn.ModuleList([TransLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        output_all_encoded_layers=True,
    ):
        outputs = []
        for block in self.layer:
            hidden_states = block(hidden_states)
            if output_all_encoded_layers:
                outputs.append(hidden_states)
        return outputs if output_all_encoded_layers else [hidden_states]

class InputDense2d(nn.Module):
    """Project flattened 2-D patches to ``hidden_size``: Linear -> activation
    -> LayerNorm.  The input's last dimension must be
    ``patch_size[0] * patch_size[1] * in_channels``."""

    def __init__(self, config):
        super().__init__()
        patch_dim = config.patch_size[0] * config.patch_size[1] * config.in_channels
        self.dense = nn.Linear(patch_dim, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act]
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)

class InputDense3d(nn.Module):
    """Project flattened 3-D patches to ``hidden_size``: Linear -> activation
    -> LayerNorm.  The input's last dimension must be
    ``patch_size[0] * patch_size[1] * patch_size[2] * in_channels``."""

    def __init__(self, config):
        super().__init__()
        patch_dim = config.patch_size[0] * config.patch_size[1] * config.patch_size[2] * config.in_channels
        self.dense = nn.Linear(patch_dim, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act]
        self.LayerNorm = TransLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)

class TransModel2d(nn.Module):
    """Patch-token transformer for 2-D inputs.

    Pipeline: ``InputDense2d`` (Linear + activation + LayerNorm) ->
    ``TransEmbeddings`` (learned position embeddings) -> ``TransEncoder``.

    ``forward(input_ids, output_all_encoded_layers=True)``:
        input_ids: flattened patches whose last dimension is
            ``patch_size[0] * patch_size[1] * in_channels``; despite the name
            these are pixel values, not token ids.
        Returns a list with every encoder layer's output
        (each ``(b, num_patches, hidden_size)``) when
        ``output_all_encoded_layers`` is True, otherwise only the last
        layer's tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dense = InputDense2d(config)
        self.embeddings = TransEmbeddings(config)
        self.encoder = TransEncoder(config)

    def forward(
        self,
        input_ids,
        output_all_encoded_layers=True,
    ):
        dense_out = self.dense(input_ids)               # project patches to hidden_size + normalise
        embedding_output = self.embeddings(dense_out)   # add position embeddings
        encoder_layers = self.encoder(
            embedding_output,
            output_all_encoded_layers=output_all_encoded_layers,
        )
        # Index only when a single tensor is requested: with num_hidden_layers
        # == 0 and output_all_encoded_layers=True the list is empty, and the
        # former unused `encoder_layers[-1]` raised IndexError there.
        if not output_all_encoded_layers:
            return encoder_layers[-1]
        return encoder_layers


class TransModel3d(nn.Module):
    """Patch-token transformer for 3-D inputs.

    Pipeline: ``InputDense3d`` -> ``TransEmbeddings`` -> ``TransEncoder``.

    ``forward(input_ids, output_all_encoded_layers=True)``:
        input_ids: flattened volume patches whose last dimension is
            ``patch_size[0] * patch_size[1] * patch_size[2] * in_channels``
            (the element ordering is up to the caller — TODO confirm).
        Returns a list with every encoder layer's output when
        ``output_all_encoded_layers`` is True, otherwise only the last
        layer's tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dense = InputDense3d(config)
        self.embeddings = TransEmbeddings(config)
        self.encoder = TransEncoder(config)

    def forward(
        self,
        input_ids,
        output_all_encoded_layers=True,
    ):
        dense_out = self.dense(input_ids)
        embedding_output = self.embeddings(
            input_ids=dense_out
        )
        encoder_layers = self.encoder(
            embedding_output,
            output_all_encoded_layers=output_all_encoded_layers,
        )
        # Index only when a single tensor is requested: with num_hidden_layers
        # == 0 and output_all_encoded_layers=True the list is empty, and the
        # former unused `encoder_layers[-1]` raised IndexError there.
        if not output_all_encoded_layers:
            return encoder_layers[-1]
        return encoder_layers

汽车分割的例子

数据集可以去这下载:

Carvana Image Masking Challenge | Kaggle

# data_url : https://www.kaggle.com/c/carvana-image-masking-challenge/data
import torch
import numpy as np
from SETR.transformer_seg import SETRModel
from PIL import Image
import glob
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torchvision.utils import save_image
import os
import xlwt

# Images and masks are paired purely by sorted filename order (presumably
# Carvana's <id>.jpg <-> <id>_mask.gif — verify both folders match).
img_url = sorted(glob.glob("./data/train/*"))
mask_url = sorted(glob.glob("./data/train_masks/*"))

# First 80% of the sorted list is used for training, the rest for validation.
train_size = int(len(img_url) * 0.8)
train_img_url, val_img_url = img_url[:train_size], img_url[train_size:]
train_mask_url, val_mask_url = mask_url[:train_size], mask_url[train_size:]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device is " + str(device))
epoches = 100       # number of training epochs
out_channels = 1    # channels of the model output (1 = binary mask logits)


def build_model():
    """Return the SETR network used for Carvana: 16x16 patches, 3 input
    channels, 1 output channel, 6 transformer layers of width 1024 with 16
    heads, and the default 4-stage decoder widths."""
    return SETRModel(
        patch_size=(16, 16),
        in_channels=3,
        out_channels=1,
        hidden_size=1024,
        num_hidden_layers=6,
        num_attention_heads=16,
        decode_features=[512, 256, 128, 64],
    )


class CarDataset(Dataset):
    """Image/mask pairs resized to 256x256.

    Args:
        img_url: list of image file paths.
        mask_url: list of mask file paths, index-aligned with ``img_url``.

    Each item is ``(image, mask)``:
        image: float32 tensor ``(C, 256, 256)`` scaled to [0, 1]
            (assumes a multi-channel image such as RGB — required by the
            ``transpose(2, 0, 1)``).
        mask: float32 tensor of the raw mask pixel values, ``(256, 256)``
            for a single-channel mask.
    """

    def __init__(self, img_url, mask_url):
        super().__init__()
        self.img_url = img_url
        self.mask_url = mask_url

    def __getitem__(self, idx):
        # Open both files in `with` blocks so the handles are closed right
        # away: Pillow keeps multi-frame files (e.g. GIF masks) open after
        # loading, which otherwise leaks descriptors over many iterations.
        with Image.open(self.img_url[idx]) as img:
            img_array = np.array(img.resize((256, 256)), dtype=np.float32) / 255
        with Image.open(self.mask_url[idx]) as mask_img:
            mask = np.array(mask_img.resize((256, 256)), dtype=np.float32)
        img_array = img_array.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)

        # torch.tensor copies, so the non-contiguous transposed view is fine.
        return torch.tensor(img_array), torch.tensor(mask)

    def __len__(self):
        return len(self.img_url)


def compute_dice(input, target, eps=0.0001):
    """Dice coefficient between a predicted probability map and a target mask.

    Both tensors are binarised at 0.5 and flattened over all dimensions, so
    a whole batch is scored as one pooled volume.

    Args:
        input: predicted probabilities (already passed through sigmoid).
        target: ground-truth mask with the same number of elements.
        eps: smoothing term added to numerator and denominator.

    Returns:
        0-d float tensor in [0, 1]; 1.0 when both masks are empty.
    """
    pred = (input > 0.5).float().reshape(-1)
    truth = (target > 0.5).float().reshape(-1)

    inter = torch.sum(pred * truth)
    union = torch.sum(pred) + torch.sum(truth)

    # Smooth numerator and denominator by the same eps.  The previous version
    # added eps to the intersection *and* the union, giving 2*eps / eps = 2.0
    # for two empty masks.
    return (2 * inter + eps) / (union + eps)


def predict():
    """Visualise the trained model on the validation split.

    Loads ``./SETR_car.pth`` onto the CPU and, for every validation image,
    shows input / thresholded prediction / ground-truth mask side by side.
    """
    model = build_model()
    model.load_state_dict(torch.load("./SETR_car.pth", map_location="cpu"))
    # Inference mode: BatchNorm must use its running statistics and Dropout
    # must be disabled.  Without this the decoder's BatchNorm normalised with
    # the statistics of the single image in each batch.
    model.eval()
    print(model)

    val_dataset = CarDataset(val_img_url, val_mask_url)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    with torch.no_grad():
        for img, mask in val_loader:
            pred = torch.sigmoid(model(img))
            pred = (pred > 0.5).int()  # binary mask
            plt.subplot(1, 3, 1)
            print(img.shape)
            img = img.permute(0, 2, 3, 1)  # (b, c, h, w) -> (b, h, w, c) for imshow
            plt.imshow(img[0])
            plt.subplot(1, 3, 2)
            plt.imshow(pred[0].squeeze(0), cmap="gray")
            plt.subplot(1, 3, 3)
            plt.imshow(mask[0], cmap="gray")
            plt.show()


def data_write(file_path, epoch, datas1, datas2):
    """Write per-epoch losses to an .xls workbook (overwrites ``file_path``).

    Args:
        file_path: destination path.
        epoch: epoch indices, written to column 0.
        datas1: training losses, written to column 1.
        datas2: validation losses, written to column 2.

    Row 0 holds the headers; entry ``j`` goes to row ``j + 1``.  The number
    of rows is driven by ``len(datas1)``.
    """
    book = xlwt.Workbook(encoding='utf-8')
    sheet = book.add_sheet(u'阿强的表1', cell_overwrite_ok=True)
    for col, title in enumerate(('epoch', 'train_Loss', 'Val_Loss')):
        sheet.write(0, col, label=title)
    for j, train_value in enumerate(datas1):
        row = j + 1
        sheet.write(row, 0, epoch[j])
        sheet.write(row, 1, train_value)
        sheet.write(row, 2, datas2[j])

    book.save(file_path)


if __name__ == "__main__":

    model = build_model()
    model.to(device)

    train_dataset = CarDataset(train_img_url, train_mask_url)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    val_dataset = CarDataset(val_img_url, val_mask_url)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    loss_func = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)

    checkpoint_path = "./SETR_car.pth"
    # Resume only when a checkpoint exists (the unconditional load crashed the
    # very first run); map_location lets a GPU-saved checkpoint load on CPU.
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    step = 0  # global count of training batches; also names the preview images

    train_men = []  # mean training BCE loss per epoch
    test_men = []   # 1 - mean validation dice per epoch (a dice loss)
    epo = []        # epoch indices, written as column 0 by data_write
    for epoch in range(epoches):
        print("epoch is " + str(epoch))
        epo.append(epoch)
        train_loss = 0
        dice_sum = 0
        print("进行------训练------测试:")
        model.train()
        for img, mask in tqdm(train_loader, total=len(train_loader)):
            optimizer.zero_grad()
            step += 1
            img = img.to(device)
            mask = mask.to(device)

            pred_img = model(img)  # logits, (batch, out_channels, H, W)
            if out_channels == 1:
                pred_img = pred_img.squeeze(1)  # drop the channel dim to match mask (batch, H, W)

            loss = loss_func(pred_img, mask)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        with torch.no_grad():
            print("进行------验证------测试:")
            for batch_idx, (val_img, val_mask) in enumerate(tqdm(val_loader, total=len(val_loader))):
                val_img = val_img.to(device)
                val_mask = val_mask.to(device)
                pred_img = torch.sigmoid(model(val_img))
                if out_channels == 1:
                    pred_img = pred_img.squeeze(1)
                dice_sum += compute_dice(pred_img, val_mask).item()

                # `step` is constant during validation, so without the
                # batch_idx guard the same file was rewritten for every
                # validation image.
                if step % 50 == 0 and batch_idx == 0:
                    # (2, 1, H, W): label and prediction saved as two grayscale
                    # tiles.  A (2, H, W) stack was treated by save_image as
                    # two channels of a single image instead.
                    preview = torch.stack([val_mask[0], pred_img[0]], 0).unsqueeze(1)
                    os.makedirs('./outputs', exist_ok=True)
                    save_image(preview.cpu(), f"./outputs/{step}.png")

        torch.save(model.state_dict(), checkpoint_path)

        mean_dice = dice_sum / len(val_loader)
        train_men.append(train_loss / len(train_loader))
        test_men.append(1 - mean_dice)
        data_write("Car_loss.xls", epo, train_men, test_men)

        print("train_loss is " + str(train_men[epoch]))
        # test_men stores 1 - dice; print the dice itself under the "Val_dice" label.
        print("Val_dice is " + str(mean_dice))

SETR + ViT 的一些改动

原版的 SETR 结构有些复杂,阅读起来有些困难,但是写得很详细;而 ViT 写得简单易懂,却不适合直接用于分割网络。所以本人集两家之所长于一体。

transformer_seg.py(ViT 版)

import logging
import math
import os
import numpy as np 

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from einops import rearrange
from vit import ViT

class Decoder2D(nn.Module):
    """Convolutional upsampling decoder.

    Four stages, each ``Conv3x3 -> BatchNorm -> ReLU -> bilinear x2 upsample``,
    with output widths ``features[0..3]``, followed by a 3x3 conv from
    ``features[-1]`` to ``out_channels``.  The output is 16x larger than the
    input on each spatial axis.
    """

    def __init__(self, in_channels, out_channels, features=[512, 256, 128, 64]):
        super().__init__()
        self.decoder_1 = self._up_block(in_channels, features[0])
        self.decoder_2 = self._up_block(features[0], features[1])
        self.decoder_3 = self._up_block(features[1], features[2])
        self.decoder_4 = self._up_block(features[2], features[3])

        self.final_out = nn.Conv2d(features[-1], out_channels, 3, padding=1)

    @staticmethod
    def _up_block(cin, cout):
        """One decoder stage; submodule indices 0-3 match the checkpoint keys."""
        return nn.Sequential(
            nn.Conv2d(cin, cout, 3, padding=1),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
        )

    def forward(self, x):
        for stage in (self.decoder_1, self.decoder_2, self.decoder_3, self.decoder_4):
            x = stage(x)
        return self.final_out(x)

class SETRModel(nn.Module):
    """SETR-style segmentation network built on the ``ViT`` encoder.

    ``encoder_2d`` is a ``ViT`` configured from the arguments below;
    ``decoder_2d`` is a ``Decoder2D`` taking ``hidden_size`` input channels,
    upsampling 16x and producing ``out_channels``.  The encoder is expected
    to return a ``(b, hidden_size, H', W')`` feature map — its forward is
    defined in vit.py and not visible here (TODO confirm).
    """

    def __init__(self, patch_size=(32, 32),
                        image_size=512,
                        in_channels=3,
                        out_channels=1,
                        hidden_size=1024,
                        num_hidden_layers=8,
                        num_attention_heads=16,
                        decode_features=[512, 256, 128, 64],
                        sample_rate=4,):
        super().__init__()
        vit_settings = dict(
            image_size=image_size,
            in_channels=in_channels,
            patch_size=patch_size,
            hidden_size=hidden_size,                  # width of every token vector
            num_hidden_layers=num_hidden_layers,      # number of stacked encoder layers (L)
            num_attention_heads=num_attention_heads,  # attention heads per layer
            sample_rate=sample_rate,
        )
        self.encoder_2d = ViT(**vit_settings)
        self.decoder_2d = Decoder2D(in_channels=hidden_size, out_channels=out_channels, features=decode_features)

    def forward(self, x):
        return self.decoder_2d(self.encoder_2d(x))


if __name__ == "__main__":
    # Smoke test of the ViT-based SETRModel on a random 512x512 batch.
    net = SETRModel(
        image_size=512,
        patch_size=(32, 32),        # patch height/width in pixels
        in_channels=3,              # input image channels
        out_channels=1,             # output (mask) channels
        hidden_size=1024,           # token embedding width
        sample_rate=5,              # forwarded to ViT; its meaning is defined in vit.py
        num_hidden_layers=1,        # number of transformer layers
        num_attention_heads=16,     # attention heads
        decode_features=[512, 256, 128, 64],  # channel widths of the 4 decoder stages
    )
    t1 = torch.rand(2, 3, 512, 512)
    print("input: " + str(t1.shape))
    print("output: " + str(net(t1).shape))

vit.py

import torch
import math
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

def swish(x):
    """Swish (SiLU) activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
def gelu(x):
    """Exact GELU via the Gaussian CDF: ``0.5 * x * (1 + erf(x / sqrt(2)))``."""
    cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * cdf
def mish(x):
    """Mish activation: ``x * tanh(softplus(x))``."""
    softplus_x = nn.functional.softplus(x)
    return x * torch.tanh(softplus_x)
# Name -> activation lookup table.
ACT2FN = dict(gelu=gelu, relu=torch.nn.functional.relu, swish=swish, mish=mish)

# helpers
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise ``(t, t)``.

    Lets image/patch sizes be given as a single int (e.g. ``224`` ->
    ``(224, 224)``).  Only ``tuple`` counts: a list is duplicated, not
    passed through.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes
##
class PreNorm(nn.Module):
    """Normalise the input, then call the wrapped sub-layer ``fn``.

    The normalisation is TF-style layer norm over the last dimension:
    ``y = gamma * (x - E[x]) / sqrt(Var[x] + eps) + beta``.  ``fn`` is e.g.
    ``Attention`` or ``FeedForward``; extra keyword arguments to ``forward``
    are passed through to it.
    """

    def __init__(self, dim, fn, eps=1e-12):
        super().__init__()
        self.fn = fn
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.variance_epsilon = eps

    def forward(self, x, **kwargs):
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.fn(self.gamma * normed + self.beta, **kwargs)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim -> dim) -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        for layer in self.net:
            x = layer(x)
        return x

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    One bias-free linear layer produces Q, K and V together
    (``dim -> 3 * heads * dim_head``). They are split, reshaped per head,
    combined as ``softmax(Q K^T * dim_head**-0.5) V`` and merged back to
    ``heads * dim_head`` features before the optional output projection.

    Args:
        dim: token feature size of the input and the output.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout applied to the attention weights and after the
            output projection.
    """

    def __init__(self, dim, heads = 16, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection is dropped only for a single head whose size
        # already equals ``dim`` (then inner_dim == dim).
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5   # 1 / sqrt(dim_head), e.g. 0.125 for dim_head=64

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )
        scores = q @ k.transpose(-2, -1) * self.scale     # (b, h, n, n)
        weights = self.dropout(self.attend(scores))
        mixed = weights @ v                                # (b, h, n, d)
        return self.to_out(rearrange(mixed, 'b h n d -> b n (h d)'))

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm transformer layers with residual connections.

    Each layer does ``x = x + attn(norm(x))`` followed by
    ``x = x + ff(norm(x))``, where each sub-layer owns its own ``PreNorm``.

    Args:
        dim: token feature size.
        depth: number of layers.
        heads: attention heads per layer.
        dim_head: feature size of each attention head.
        mlp_dim: hidden width of each feed-forward sub-layer.
        dropout: dropout rate used inside attention and feed-forward.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attention_block, feed_forward_block in self.layers:
            x = x + attention_block(x)
            x = x + feed_forward_block(x)
        return x

class ViT(nn.Module):
    """ViT-style encoder that turns an image into a 2-D feature map.

    Pipeline: cut the image into ``patch_size`` patches, flatten and linearly
    project each patch to ``hidden_size``, apply the ``act`` activation, add a
    learned position embedding, run ``num_hidden_layers`` pre-norm transformer
    layers, then fold the tokens back into a tensor of shape
    ``(b, hidden_size, h * hh, w * ww)``. Here ``h, w`` is the patch grid and
    ``hh, ww = patch_height // 2**sample_rate, patch_width // 2**sample_rate``.

    When ``hh * ww > 1`` every token is first widened by ``final_dense`` so it
    can be unfolded into an ``hh x ww`` block (the same idea as
    ``Encoder2D.final_dense``). When ``hh == ww == 1`` no extra layer is added.

    Args:
        image_size: int or (height, width) of the input image.
        patch_size: int or (height, width) of one patch. It must divide
            image_size and be at least ``2**sample_rate`` in each direction.
        hidden_size: token embedding size.
        num_hidden_layers: number of transformer layers.
        num_attention_heads: attention heads per layer.
        in_channels: channels of the input image.
        mlp_dim: hidden width of each feed-forward sub-layer.
        act: activation applied to the patch embeddings; a key of ``ACT2FN``.
        dim_head: per-head feature size inside attention.
        dropout: dropout rate inside the transformer layers.
        emb_dropout: dropout rate applied after adding the position embedding.
        sample_rate: log2 of the downsampling factor applied inside each patch.
    """

    def __init__(self, *, image_size, patch_size, hidden_size, num_hidden_layers, num_attention_heads,
                 in_channels=3, mlp_dim=2048, act='gelu', dim_head=64, dropout=0.1, emb_dropout=0.1,
                 sample_rate=4):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'
        assert act in ACT2FN, f"act must be one of {sorted(ACT2FN)}, got {act!r}"

        sample_v = int(math.pow(2, sample_rate))
        # Output pixels per patch along each axis. Use the unpacked sizes so an
        # int ``patch_size`` works too (``patch_size[0]`` would raise).
        self.hh = patch_height // sample_v
        self.ww = patch_width // sample_v
        assert self.hh > 0 and self.ww > 0, \
            "patch_size must be >= 2**sample_rate in both dimensions"

        self.h = image_height // patch_height    # patch-grid rows
        self.w = image_width // patch_width      # patch-grid columns
        self.hidden_size = hidden_size

        num_patches = self.h * self.w
        patch_dim = in_channels * patch_height * patch_width   # length of one flattened patch
        self.transform_act_fn = ACT2FN[act]

        # (b, c, H, W) -> (b, num_patches, patch_dim) -> (b, num_patches, hidden_size)
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, hidden_size),
        )

        # One learned position vector per patch (no extra slot for a class token).
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, hidden_size))
        # NOTE(review): cls_token is never used in forward(); it is kept only so
        # that state_dict keys stay the same for existing checkpoints.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(hidden_size, num_hidden_layers, num_attention_heads,
                                       dim_head, mlp_dim, dropout)

        pixels_per_token = self.hh * self.ww
        if pixels_per_token == 1:
            # Each token already maps to one output pixel: no extra parameters.
            self.final_dense = nn.Identity()
        else:
            # Widen each token to hh*ww*hidden_size features so forward() can
            # unfold it into an hh x ww block; without this the rearrange
            # below cannot split hidden_size into (hh, ww, hidden_size).
            self.final_dense = nn.Linear(hidden_size, pixels_per_token * hidden_size)

    def forward(self, img):
        """Encode ``img`` of shape (b, in_channels, H, W).

        Returns a tensor of shape (b, hidden_size, h * hh, w * ww).
        """
        x = self.to_patch_embedding(img)        # (b, h*w, hidden_size)
        x = self.transform_act_fn(x)
        # Out-of-place add: ``x += ...`` would overwrite the activation output,
        # which relu saves for its backward pass, and make backward() fail
        # with an in-place modification error.
        x = x + self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x)                 # (b, h*w, hidden_size)
        x = self.final_dense(x)                 # (b, h*w, hh*ww*hidden_size)

        return rearrange(x, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
                         p1=self.hh, p2=self.ww, h=self.h, w=self.w, c=self.hidden_size)

总结:

VIT和SETR的结构有很多细节上的区别:

| 区别 | VIT | SETR |
| --- | --- | --- |
| 结构 | 位置编码 —> 标准化 —> 多头注意力 | 标准化 —> 位置编码 —> 多头注意力机制 |
| QKV | 1024 维输入经过一个线性层变成 3072 维,再均分成三份,分别作为 Q、K、V | 1024 维输入分别经过三个线性层,得到三个 1024 维的 Q、K、V |
| 位置编码 | 随机初始化(torch.randn)的可学习参数 | 查表编码 |

VIT 和 SETR 的效果也有明显区别:数值指标上差别不大,但从效果图上看差别很明显。

| 对比 | VIT | SETR |
| --- | --- | --- |
| 准确度 | 略差(20 个 epoch:0.975) | 略好(20 个 epoch:0.980) |
| 速度 | 快了将近一半(train + val = 5 分 20 秒) | train + val = 7 分 55 秒 |

效果图

有点拉胯,但是能接受

直接调用 MONAI 库中现成的 ViT

from monai.networks.nets import ViT

## Fragment: MONAI's ViT used as the encoder. The surrounding class is not shown;
## presumably this sits inside a model's __init__ (it assigns to self.vit).
self.vit = ViT(
            in_channels=in_channels,            ## number of input image channels
            img_size=img_size,                  ## input image size
            patch_size=self.patch_size,         ## patch size
            hidden_size=hidden_size,            ## token embedding (hidden) size
            mlp_dim=mlp_dim,                    ## hidden width of the MLP blocks
            num_layers=self.num_layers,         ## number of transformer layers
            num_heads=num_heads,                ## number of attention heads
            pos_embed=pos_embed,                ## patch/position embedding type (see MONAI ViT docs)
            classification=self.classification, ## whether to use the classification variant
            dropout_rate=dropout_rate,          ## dropout rate
        )

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值