Multi-scale Transformer Network with Edge-aware Pre-training for Cross-Modality MR Image Synthesis 代码

Multi-scale Transformer Network with Edge-aware Pre-training for Cross-Modality MR Image Synthesis 代码复现和讲解

复现

1、下载项目:GitHub - lyhkevin/MT-Net: Multi-scale Transformer Network for Cross-Modality MR Image Synthesis (IEEE TMI)

解压后打开,放入服务器中

2、 安装环境

pip install -r requirements.txt

注意torch版本

3、下载数据集

  • Download BraTS2020 dataset from kaggle. The file name should be ./data/archive.zip. Unzip the file in ./data/.
    data/MICCAI_BraTS2020_TrainingData

4、运行预处理

python utils/preprocessing.py

5、预训练

python pretrain.py

可以在 ./options/pretrain_options.py 中更改默认设置。

权重将保存在 ./weight/EdgeMAE/ 中。

6、微调

python Finetune.py

可以在 ./options/finetune_options.py 中更改默认设置,特别是 data_rate 选项,用于调整微调时使用的配对数据数量。此外,可以增大 num_workers 来加速微调。

权重将保存在 ./weight/finetuned/ 中。注意,对于 MT-Net,输入大小必须为 256×256。

7、使用

运行 python test.py 进行测试

合成图片将保存在./snapshot/test/

代码讲解

preprocessing

读取 t1、t2、flair、t1ce 四个模态的数据以及分割标签(gt)数据

    for i,(t1_path, t2_path, t1ce_path, flair_path, gt_path) in enumerate(zip(t1_list,t2_list,t1ce_list,flair_list,gt_list)):

        print('preprocessing the',i+1,'th subject')

        t1_img = nib.load(t1_path)  # (240,140,155)
        t2_img = nib.load(t2_path)
        flair_img = nib.load(flair_path)
        t1ce_img = nib.load(t1ce_path)
        gt_img = nib.load(gt_path)

转换为 numpy 数组:四个模态分别归一化到 [0,1];分割标签转换为 uint8,并把标签 4 改为 3(BraTS 2020 中没有标签 3)

        #to numpy
        t1_data = t1_img.get_fdata()
        t2_data = t2_img.get_fdata()
        flair_data = flair_img.get_fdata()
        t1ce_data = t1ce_img.get_fdata()
        gt_data = gt_img.get_fdata()
        gt_data = gt_data.astype(np.uint8)
        gt_data[gt_data == 4] = 3 #label 3 is missing in BraTS 2020

        t1_data = normalize(t1_data) # normalize to [0,1]
        t2_data = normalize(t2_data)
        t1ce_data = normalize(t1ce_data)
        flair_data = normalize(flair_data)

按 t1、t2、t1ce、flair、gt 的顺序堆叠为一个 5 通道数组

        tensor = np.stack([t1_data, t2_data, t1ce_data, flair_data, gt_data])  # (4, 240, 240, 155)

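按受试者划分数据集:前 train_len 个受试者保存到训练目录,其余保存到测试目录。每个受试者裁剪为 200×200(第一维 10:210,第二维 25:225),并取第三维索引 50–109 共 60 个切片,每个切片单独保存为一个 .npy 文件: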

        if i < train_len:
            for j in range(60):
                Tensor = tensor[:, 10:210, 25:225, 50 + j]
                np.save(train_path + str(60 * i + j + 1) + '.npy', Tensor)
        else:
            for j in range(60):
                Tensor = tensor[:, 10:210, 25:225, 50 + j]
                np.save(test_path + str(60 * (i - train_len) + j + 1) + '.npy', Tensor)

每个 .npy 文件对应第三维上的一个切片,形状为 (5, 200, 200),5 个通道依次为 t1、t2、t1ce、flair 和 gt;每个受试者共保存 60 个这样的文件。

pretrain(预训练)

构建 MAE 网络(EdgeMAE):

# Build the edge-aware MAE used for pre-training. Inputs are single-channel
# slices (in_chans=1); all other hyper-parameters come from the pretrain options.
# norm_pix_loss=False: the reconstruction/edge targets are the raw patchified
# values rather than per-patch normalized ones.
# rec_depth / edge_depth are not passed, so EdgeMAE's own defaults apply
# (their values are not shown here).
mae = EdgeMAE(img_size=opt.img_size,patch_size=opt.patch_size, embed_dim=opt.dim_encoder, depth=opt.depth, num_heads=opt.num_heads, in_chans=1,
        decoder_embed_dim=opt.dim_decoder, decoder_depth=opt.decoder_depth, decoder_num_heads=opt.decoder_num_heads,
        mlp_ratio=opt.mlp_ratio,norm_pix_loss=False,patchwise_loss=opt.use_patchwise_loss)

编码器部分:先定义 PatchEmbed,再堆叠 depth 个 Block,最后接一个 LayerNorm:

        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        
        self.H = int(img_size / patch_size)
        self.W = int(img_size / patch_size)
        self.patch_size = patch_size
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

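PatchEmbed 的默认参数为图像大小 224、patch 大小 16、输入通道数 3,但这里实际传入的是 img_size、patch_size 和 in_chans=1。其核心是一层卷积核大小和步长都等于 patch_size 的卷积,把图像切成不重叠的 patch 并映射到 embed_dim 维: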

先做卷积 `x = self.proj(x)`,再展平并转置 `x = x.flatten(2).transpose(1, 2)`,即 NCHW -> NLC(L 为 patch 数)。

class PatchEmbed(nn.Module):
    """Split a 2D image into non-overlapping patches and embed each patch.

    A single strided convolution (kernel size == stride == ``patch_size``)
    maps every patch to an ``embed_dim`` vector. By default the result is
    flattened to ``(B, num_patches, embed_dim)``; when ``output_fmt`` is given,
    the spatial layout is kept and converted with ``nchw_to`` instead.
    """
    output_fmt: Format

    def __init__(
            self,
            img_size: Optional[int] = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer: Optional[Callable] = None,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            bias: bool = True,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)

        # Without a fixed image size the patch grid cannot be precomputed,
        # and forward() skips its input-size check.
        if img_size is None:
            self.img_size = self.grid_size = self.num_patches = None
        else:
            self.img_size = to_2tuple(img_size)
            self.grid_size = tuple(side // step for side, step in zip(self.img_size, self.patch_size))
            self.num_patches = self.grid_size[0] * self.grid_size[1]

        if output_fmt is None:
            # Default path: flatten to channels-last tokens (kept for backward compatibility).
            self.flatten = flatten
            self.output_fmt = Format.NCHW
        else:
            # An explicit output format disables flattening.
            self.flatten = False
            self.output_fmt = Format(output_fmt)

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        _, _, H, W = x.shape
        if self.img_size is not None:
            _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
            _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")

        patches = self.proj(x)
        if self.flatten:
            # (B, C, H', W') -> (B, H'*W', C)
            patches = patches.flatten(2).transpose(1, 2)  # NCHW -> NLC
        elif self.output_fmt != Format.NCHW:
            patches = nchw_to(patches, self.output_fmt)
        return self.norm(patches)

每个 Block 由两个残差子层组成:LayerNorm → 注意力 → LayerScale → DropPath → 残差相加;然后 LayerNorm → MLP → LayerScale → DropPath → 残差相加。

class Block(nn.Module):
    """Pre-norm Transformer block with two residual sub-layers.

    Each sub-layer computes ``x + drop_path(layer_scale(sublayer(norm(x))))``:
    first multi-head self-attention, then an MLP whose hidden width is
    ``int(dim * mlp_ratio)``. LayerScale is replaced by ``nn.Identity`` when
    ``init_values`` is falsy, and DropPath when ``drop_path`` is not > 0.
    """

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            qk_norm=False,
            proj_drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            mlp_layer=Mlp,
    ):
        super().__init__()

        def make_layer_scale():
            # Learnable per-channel scaling only when init_values is set.
            return LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        def make_drop_path():
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # Attention sub-layer.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = make_layer_scale()
        self.drop_path1 = make_drop_path()

        # MLP sub-layer.
        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = make_layer_scale()
        self.drop_path2 = make_drop_path()

    def forward(self, x):
        attn_branch = self.ls1(self.attn(self.norm1(x)))
        x = x + self.drop_path1(attn_branch)
        mlp_branch = self.ls2(self.mlp(self.norm2(x)))
        return x + self.drop_path2(mlp_branch)

class DropPath(nn.Module):
    """Stochastic depth: drop whole residual paths per sample.

    Thin ``nn.Module`` wrapper around the functional ``drop_path``. It forwards
    ``self.training`` so the functional can tell training from evaluation
    (presumably a no-op in eval mode -- see ``drop_path``).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        rounded = round(self.drop_prob, 3)
        return f'drop_prob={rounded:0.3f}'

然后定义解码器:decoder_blocks、rec_blocks、edge_blocks 都是 Block 的堆叠;rec_norm/rec_pred/rec_sig 与 edge_norm/edge_pred/edge_sig 分别是重建和边缘预测两个任务的输出头。

        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(decoder_depth)])

        self.rec_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(rec_depth)])

        self.edge_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(edge_depth)])

        self.rec_norm = norm_layer(decoder_embed_dim)
        self.rec_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
        self.rec_sig = nn.Sigmoid()

        self.edge_norm = norm_layer(decoder_embed_dim)
        self.edge_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)
        self.edge_sig = nn.Sigmoid()

直接看forward


    def forward(self, imgs, mask_ratio=0.75, epoch=1):
        """Run one masked-autoencoding pass and compute both pre-training losses.

        Args:
            imgs: input image batch.
            mask_ratio: fraction of patches hidden from the encoder.
            epoch: forwarded to ``structure_loss``, whose result is used as the
                weight map inside ``rec_loss`` / ``edge_loss``.

        Returns:
            (rec_loss, edge_loss, edge_gt, x_edge, x_rec, mask)
        """
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        edge_pred, rec_pred = self.forward_decoder(latent, ids_restore)
        weights = self.structure_loss(mask, epoch)
        reconstruction_loss = self.rec_loss(imgs, rec_pred, mask, weights)
        edge_term, edge_target = self.edge_loss(imgs, edge_pred, mask, weights)
        return reconstruction_loss, edge_term, edge_target, edge_pred, rec_pred, mask

先运行编码器:patch 嵌入 → 加位置编码(不含 cls 位置)→ 随机掩码 → 拼接带位置编码的 cls token → 经过 Transformer 块和 LayerNorm。

    def forward_encoder(self, x, mask_ratio):
        """Encode only the visible (unmasked) patches.

        Returns:
            latent: encoder output with the cls token prepended at index 0.
            mask: per-patch mask produced by ``random_masking`` (1 = removed).
            ids_restore: indices the decoder uses to undo the masking shuffle.
        """
        # Patch tokens take positions 1..L of the table; slot 0 is reserved for cls.
        tokens = self.patch_embed(x) + self.pos_embed[:, 1:, :]

        # Keep only a (1 - mask_ratio) fraction of the tokens.
        tokens, mask, ids_restore = self.random_masking(tokens, mask_ratio)

        cls = (self.cls_token + self.pos_embed[:, :1, :]).expand(tokens.shape[0], -1, -1)
        hidden = torch.cat((cls, tokens), dim=1)

        for block in self.blocks:
            hidden = block(hidden)
        return self.norm(hidden), mask, ids_restore

然后运行解码器:先映射到解码维度,补上 mask token 并恢复原始顺序,再加位置编码,经过共享的 decoder_blocks;之后分成重建和边缘两个分支,各自经过独立的 LayerNorm、线性预测层和 Sigmoid。

    def forward_decoder(self, x, ids_restore):
        """Decode encoder tokens into per-patch edge and reconstruction predictions.

        Removed positions are filled with the learned ``mask_token``, and the
        original patch order is restored with ``ids_restore``. The full sequence
        then goes through the shared ``decoder_blocks`` and splits into two
        branches:
        ``rec_blocks -> rec_norm -> rec_pred -> rec_sig`` and
        ``edge_blocks -> edge_norm -> edge_pred -> edge_sig``.

        Args:
            x: encoder output, with the cls token at index 0 followed by the
                visible tokens.
            ids_restore: (N, L) indices that undo the masking shuffle.

        Returns:
            (x_edge, x_rec): sigmoid-activated per-patch predictions with the cls
            token removed, each of shape (N, L, patch_size**2 * in_chans).
        """
        # embed tokens
        x = self.decoder_embed(x)

        # Append one mask token per removed patch (cls token excluded), then
        # unshuffle back to the original patch order.
        n_removed = ids_restore.shape[1] + 1 - x.shape[1]
        mask_tokens = self.mask_token.repeat(x.shape[0], n_removed, 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # Shared decoder trunk.
        for blk in self.decoder_blocks:
            x = blk(x)

        # Task-specific branches. rec_blocks / edge_blocks are built in __init__
        # (rec_depth / edge_depth layers); applying them here keeps them from
        # being dead parameters. With depth 0 these loops are no-ops.
        x_rec = x
        for blk in self.rec_blocks:
            x_rec = blk(x_rec)
        x_edge = x
        for blk in self.edge_blocks:
            x_edge = blk(x_edge)

        x_rec = self.rec_sig(self.rec_pred(self.rec_norm(x_rec)))
        x_edge = self.edge_sig(self.edge_pred(self.edge_norm(x_edge)))

        # remove cls token
        return x_edge[:, 1:, :], x_rec[:, 1:, :]

然后分别计算损失

        rec_loss = self.rec_loss(imgs, x_rec, mask, weit)
        edge_loss,edge_gt = self.edge_loss(imgs, x_edge, mask, weit)
    def rec_loss(self, imgs, pred, mask, weit):
        """Masked MSE between predicted and true image patches.

        Args:
            imgs: input images; ``self.patchify(imgs)`` is the target.
            pred: per-patch reconstruction from the decoder, laid out like
                ``self.patchify(imgs)``.
            mask: [N, L] per-patch mask, 0 = kept, 1 = removed.
            weit: weight map, patchified like the images. It is multiplied
                into the squared error only when ``self.patchwise_loss`` is set.

        Returns:
            Scalar loss averaged over the removed patches only. It is 0 when no
            patch was removed (e.g. ``mask_ratio == 0``) instead of NaN.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            # Normalize each target patch to zero mean / unit variance.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        if self.patchwise_loss:
            # Only patchify the weights when they are actually used.
            loss = loss * self.patchify(weit)

        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        # Mean over removed patches; clamp avoids 0/0 = NaN when nothing is masked.
        return (loss * mask).sum() / mask.sum().clamp(min=1)
    def edge_loss(self, imgs, pred, mask, weit):
        """Masked, weighted MSE between predicted and ground-truth edge maps.

        The edge target is computed on the fly with ``self.operator`` under
        ``torch.no_grad()`` and patchified like the images.

        Args:
            imgs: [N, C, H, W] input images (C == in_chans; the pre-training
                script passes in_chans=1).
            pred: [N, L, p*p*C] per-patch edge prediction from the decoder.
            mask: [N, L], 0 is keep, 1 is remove.
            weit: weight map, patchified and always multiplied into the
                squared error (unlike ``rec_loss``, not gated by ``patchwise_loss``).

        Returns:
            (loss, edge_gt): the scalar loss averaged over removed patches (0 if
            none were removed, instead of NaN) and the un-patchified edge target.
        """
        with torch.no_grad():
            edge_gt = self.operator(imgs)
        target = self.patchify(edge_gt)

        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss * self.patchify(weit)

        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        # Mean over removed patches; clamp avoids 0/0 = NaN when nothing is masked.
        loss = (loss * mask).sum() / mask.sum().clamp(min=1)
        return loss, edge_gt
    

训练部分没有特殊之处,总损失为 loss = rec_loss * opt.l1_loss + edge_loss:

# EdgeMAE pre-training loop. For each batch the model masks the input,
# predicts both the image and its edge map, and the optimizer minimizes
# loss = opt.l1_loss * rec_loss + edge_loss.
# NOTE(review): range(1, opt.epoch) runs epochs 1..opt.epoch-1, so the
# "[Epoch %d/%d]" counter never reaches opt.epoch -- confirm whether opt.epoch
# epochs were intended.
for epoch in range(1,opt.epoch):
    for i,img in enumerate(train_loader):

        # NOTE(review): called every batch although all of its arguments only
        # change per epoch; harmless if adjust_lr simply sets the lr from epoch
        # (implementation not shown) -- confirm before hoisting it out.
        adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
        
        optimizer.zero_grad()

        img = img.to(device,dtype=torch.float)

        # mae returns both losses, the edge target, both predictions and the
        # patch mask (the last four are used below for visualization).
        rec_loss, edge_loss,edge_gt,x_edge,x_rec,mask = mae(img,opt.masking_ratio,epoch)
        # Reconstruction loss is scaled by opt.l1_loss; the edge loss is unscaled.
        loss = rec_loss * opt.l1_loss + edge_loss 
        
        loss.backward()
        optimizer.step()

        print(
                "[Epoch %d/%d] [Batch %d/%d] [rec_loss: %f] [edge_loss: %f] [lr: %f]"
                % (epoch, opt.epoch, i, len(train_loader), rec_loss.item(),edge_loss.item(),get_lr(optimizer))
            )

        # Every opt.save_output batches, save a 2x3 grid (nrow=3): row 1 is the
        # input image followed by MAE_visualize's outputs for the reconstruction,
        # row 2 the edge target followed by the outputs for the edge prediction.
        if i % opt.save_output == 0:
            y1, im_masked1, im_paste1 = mae.MAE_visualize(img, x_rec, mask)
            y2, im_masked2, im_paste2 = mae.MAE_visualize(edge_gt, x_edge, mask)
            # img is rebound to its CPU copy here; the next iteration fetches a
            # fresh batch, so training is unaffected.
            edge_gt,img = edge_gt.cpu(),img.cpu()
            save_image([img[0],im_masked1,im_paste1,edge_gt[0],im_masked2,im_paste2],
                 opt.img_save_path + str(epoch) + ' ' + str(i)+'.png', nrow=3,normalize=False)
            logging.info("[Epoch %d/%d] [Batch %d/%d] [rec_loss: %f] [edge_loss: %f] [lr: %f]"
                % (epoch, opt.epoch, i, len(train_loader), rec_loss.item(),edge_loss.item(),get_lr(optimizer)))

    # Periodic checkpoint every opt.save_weight epochs.
    if epoch % opt.save_weight == 0:
        torch.save(mae.state_dict(), opt.weight_save_path + str(epoch) + 'MAE.pth')

# Final weights. NOTE(review): './MAE.pth' is appended verbatim to
# weight_save_path (e.g. 'weight/EdgeMAE/' + './MAE.pth'); the file only lands
# inside that directory if weight_save_path ends with '/'.
torch.save(mae.state_dict(), opt.weight_save_path + './MAE.pth')

finetune(微调)

两个编码器结构相同(都是 MAE_finetune),区别只在深度:编码器 E 使用 opt.depth 个 Block,特征一致性模块 FC_module 使用 opt.fc_depth 个 Block。

# Encoder E: a single-channel ViT encoder (MAE_finetune) whose last returned
# feature map is fed to the synthesis network G in the training loop.
# NOTE(review): presumably initialized from the EdgeMAE pre-training weights;
# the loading code is not shown here.
E = MAE_finetune(img_size=opt.img_size, patch_size=opt.mae_patch_size, embed_dim=opt.encoder_dim, depth=opt.depth,
                 num_heads=opt.num_heads, in_chans=1, mlp_ratio=opt.mlp_ratio)
# Same architecture with opt.fc_depth blocks. In the training loop it extracts
# features from both the synthesized image and the ground truth, and their L1
# distance forms the feature-consistency loss.
FC_module = MAE_finetune(img_size=opt.img_size, patch_size=opt.mae_patch_size, embed_dim=opt.encoder_dim,
                         depth=opt.fc_depth, num_heads=opt.num_heads, in_chans=1,
                         mlp_ratio=opt.mlp_ratio)  # feature consistency module
class MAE_finetune(nn.Module):
    """ViT encoder reused from MAE pre-training, run without masking.

    ``forward`` returns the token features after every stage rather than only
    the final output: the patch embedding (plus positional embedding), the
    output of each Transformer block, and finally the LayerNorm-ed output.
    That is ``depth + 2`` tensors, each of shape (N, num_patches, embed_dim).
    """

    def __init__(self, img_size=256,patch_size=8, in_chans=1, embed_dim=128, depth=24, num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm):
        super().__init__()

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        # NOTE(review): cls_token is never used by forward_encoder; presumably it
        # is kept so pre-trained MAE checkpoints load without missing keys.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Non-trainable; slot 0 belongs to the (unused) cls token.
        # NOTE(review): nothing in this class fills in sin-cos values, so they
        # must come from a loaded checkpoint -- otherwise this stays all zeros.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True,norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

    def patchify(self, imgs):
        """Split square images into flattened, non-overlapping patches.

        Args:
            imgs: (N, C, H, W) with H == W and H divisible by the patch size p.

        Returns:
            (N, (H // p) ** 2, p * p * C); values inside each patch are ordered
            row, column, channel.
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
        n, c = imgs.shape[0], imgs.shape[1]
        h = w = imgs.shape[2] // p
        # Use the real channel count instead of assuming one channel, so the
        # method also works when the model is built with in_chans > 1.
        x = imgs.reshape(shape=(n, c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(n, h * w, p**2 * c))
        return x

    def forward_encoder(self, x):
        """Return the list of token features after every encoder stage."""
        # embed patches
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]
        # apply Transformer blocks, recording every intermediate output
        feature = [x]
        for blk in self.blocks:
            x = blk(x)
            feature.append(x)
        x = self.norm(x)
        feature.append(x)
        return feature

    def forward(self, imgs):
        latent = self.forward_encoder(imgs)
        return latent

MTNet网络

# Synthesis network G (MTNet, built with fine_tune=True). In the training loop
# it receives two clones of E's last feature map and its output is compared
# with the target-modality image.
# NOTE(review): num_classes=1 presumably sets a single output channel -- confirm.
G = MTNet(img_size=opt.img_size, patch_size=opt.patch_size, in_chans=1, num_classes=1, embed_dim=opt.vit_dim,
          depths=[2, 2, 2, 2], depths_decoder=[2, 2, 2, 2], num_heads=[8, 8, 16, 32], window_size=opt.window_size,
          mlp_ratio=opt.mlp_ratio, qkv_bias=True, qk_scale=None, drop_rate=0.,
          attn_drop_rate=0., drop_path_rate=0, norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
          use_checkpoint=False, final_upsample="expand_first", fine_tune=True)
    def forward(self, x1, x2):
        """Synthesize an image from two copies of the encoder features.

        ``x2`` is first passed through ``self.expand``. Each input then runs
        through its own feature path (``forward_features1`` /
        ``forward_features2``). ``forward_up_features`` receives the final
        output of the second path plus the ``x_downsample`` outputs of both
        paths. The result is projected by ``to_img`` and passed through ``sig``.
        """
        expanded = self.expand(x2)
        # Path 1 contributes only its x_downsample outputs; its final output
        # is not used downstream.
        _, downsample1 = self.forward_features1(x1)
        deep2, downsample2 = self.forward_features2(expanded)
        decoded = self.forward_up_features(deep2, downsample1, downsample2)
        return self.sig(self.to_img(decoded))

微调的训练循环如下:

# Fine-tuning loop: encoder E -> synthesis network G -> predicted image, trained
# with opt.l1_loss * L1(pred, gt) plus a feature-consistency loss from FC_module.
# NOTE(review): range(1, opt.epoch) runs epochs 1..opt.epoch-1, and the schedule
# index below starts at len(data_loader), so lr_schedule[0:len(data_loader)] is
# never used. Check how lr_schedule is built before changing either.
for epoch in range(1, opt.epoch):
    for i, (img, gt) in enumerate(data_loader):
        # Per-iteration learning rate from the precomputed schedule. (The old
        # `for id, param_group in enumerate(...)` left the builtin `id`
        # shadowed by an int at module scope; the index was never used.)
        it = len(data_loader) * epoch + i
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_schedule[it]

        optimizer.zero_grad()
        img = img.to(device, dtype=torch.float)
        gt = gt.to(device, dtype=torch.float)

        # The last entry of E's feature list is fed to G twice, with a separate
        # clone for each of G's two input paths.
        Feature = E(img)
        f1, f2 = Feature[-1].clone(), Feature[-1].clone()
        pred = G(f1, f2)

        # Features of the prediction and of the ground truth from FC_module.
        feature = FC_module(pred)
        feature_gt = FC_module(gt.detach())

        l1_loss = F.l1_loss(pred, gt)
        # NOTE(review): only entries 0..opt.fc_depth-1 of each feature list are
        # compared; the list also holds the last block output and the final
        # LayerNorm output -- confirm this range is intended.
        fc_loss = 0  # feature consistency loss
        for j in range(opt.fc_depth):
            fc_loss = fc_loss + F.l1_loss(feature[j], feature_gt[j])
        loss = opt.l1_loss * l1_loss + fc_loss
        loss.backward()
        optimizer.step()

编码器特征经 MTNet 生成合成图像后,计算两项损失:合成图像与真值之间的 L1 损失(按 opt.l1_loss 加权),以及二者经 FC_module 提取的特征之间的特征一致性损失。两项相加后反向传播。

backward

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

请站在我身后

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值