Multi-scale Transformer Network with Edge-aware Pre-training for Cross-Modality MR Image Synthesis: Code Reproduction and Walkthrough
Reproduction
1. Unzip the code and upload it to the server.
2. Set up the environment
pip install -r requirements.txt
Pay attention to the torch version.
3. Download the dataset
Download the BraTS2020 dataset from Kaggle. The file name should be ./data/archive.zip. Unzip it in ./data/; this produces data/MICCAI_BraTS2020_TrainingData.
4. Run preprocessing
python utils/preprocessing.py
5. Pre-training
python pretrain.py
Change the default settings in ./options/pretrain_options.py as needed.
The weights will be saved in ./weight/EdgeMAE/.
6. Fine-tuning
python Finetune.py
You can change the default settings in ./options/finetune_options.py, especially the data_rate option, which controls how much paired data is used for fine-tuning. You can also increase num_workers to speed up fine-tuning.
The weights will be saved in ./weight/finetuned/. Note that for MT-Net the input size must be 256×256.
7. Testing
Run test.py for testing.
The synthesized images will be saved in ./snapshot/test/.
Code Walkthrough
Preprocessing
Read the four modalities (t1, t2, flair, t1ce) and the ground-truth (gt) segmentation:
for i, (t1_path, t2_path, t1ce_path, flair_path, gt_path) in enumerate(zip(t1_list, t2_list, t1ce_list, flair_list, gt_list)):
    print('preprocessing the', i + 1, 'th subject')
    t1_img = nib.load(t1_path)  # (240, 240, 155)
    t2_img = nib.load(t2_path)
    flair_img = nib.load(flair_path)
    t1ce_img = nib.load(t1ce_path)
    gt_img = nib.load(gt_path)
Convert to numpy arrays and normalize:
    # to numpy
    t1_data = t1_img.get_fdata()
    t2_data = t2_img.get_fdata()
    flair_data = flair_img.get_fdata()
    t1ce_data = t1ce_img.get_fdata()
    gt_data = gt_img.get_fdata()
    gt_data = gt_data.astype(np.uint8)
    gt_data[gt_data == 4] = 3  # label 3 is missing in BraTS 2020
    t1_data = normalize(t1_data)  # normalize to [0, 1]
    t2_data = normalize(t2_data)
    t1ce_data = normalize(t1ce_data)
    flair_data = normalize(flair_data)
Stack into a single 5-channel volume:
    tensor = np.stack([t1_data, t2_data, t1ce_data, flair_data, gt_data])  # (5, 240, 240, 155)
Crop the central region and save individual slices:
    if i < train_len:
        for j in range(60):
            Tensor = tensor[:, 10:210, 25:225, 50 + j]
            np.save(train_path + str(60 * i + j + 1) + '.npy', Tensor)
    else:
        for j in range(60):
            Tensor = tensor[:, 10:210, 25:225, 50 + j]
            np.save(test_path + str(60 * (i - train_len) + j + 1) + '.npy', Tensor)
Each saved .npy file is one (5, 200, 200) slice; 60 slices are saved per subject.
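The normalize function is not shown above; here is a minimal sketch, assuming a simple min-max normalization to [0, 1] (the repo's version in utils/preprocessing.py may instead clip percentiles or normalize only the non-zero brain region):

import numpy as np

def normalize(data):
    # Hypothetical min-max normalization to [0, 1].
    data = data.astype(np.float32)
    lo, hi = data.min(), data.max()
    return (data - lo) / (hi - lo + 1e-8)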
Pretrain
The MAE network:
mae = EdgeMAE(img_size=opt.img_size, patch_size=opt.patch_size, embed_dim=opt.dim_encoder, depth=opt.depth,
              num_heads=opt.num_heads, in_chans=1,
              decoder_embed_dim=opt.dim_decoder, decoder_depth=opt.decoder_depth, decoder_num_heads=opt.decoder_num_heads,
              mlp_ratio=opt.mlp_ratio, norm_pix_loss=False, patchwise_loss=opt.use_patchwise_loss)
The encoder is defined first: a PatchEmbed layer followed by a stack of Transformer Blocks:
# MAE encoder specifics
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
self.H = int(img_size / patch_size)
self.W = int(img_size / patch_size)
self.patch_size = patch_size
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
self.blocks = nn.ModuleList([
    Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
    for i in range(depth)])
self.norm = norm_layer(embed_dim)
PatchEmbed defaults to an image size of 224, a patch size of 16, and 3 input channels (all overridden here). The key step is a single convolution, x = self.proj(x), followed by flattening: x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    output_fmt: Format

    def __init__(
            self,
            img_size: Optional[int] = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer: Optional[Callable] = None,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            bias: bool = True,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        if img_size is not None:
            self.img_size = to_2tuple(img_size)
            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
            self.num_patches = self.grid_size[0] * self.grid_size[1]
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None
        if output_fmt is not None:
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        else:
            # flatten spatial dim and transpose to channels last, kept for bwd compat
            self.flatten = flatten
            self.output_fmt = Format.NCHW
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        if self.img_size is not None:
            _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
            _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        x = self.norm(x)
        return x
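A quick shape check of PatchEmbed, using settings this walkthrough assumes for pretraining (1-channel 256×256 slices, patch size 8, embedding dim 128; your opt values may differ):

import torch

embed = PatchEmbed(img_size=256, patch_size=8, in_chans=1, embed_dim=128)
x = torch.randn(2, 1, 256, 256)  # a batch of 2 single-channel slices
tokens = embed(x)
print(tokens.shape)  # torch.Size([2, 1024, 128]); (256/8)**2 = 1024 patches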
Each Block is a standard pre-norm Transformer block: LayerNorm → Attention → LayerScale → DropPath with a residual connection, then the same pattern around the MLP:
class Block(nn.Module):
    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            qk_norm=False,
            proj_drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            mlp_layer=Mlp,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
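The drop_path function it wraps is not shown above; for completeness, this matches the timm implementation:

def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    # Zero out the whole residual branch for a random subset of samples.
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # broadcast over all but the batch dim
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)  # rescale so the expected value is unchanged
    return x * random_tensor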
Then the decoder is defined. decoder_blocks, rec_blocks, and edge_blocks cover the shared decoding and the two tasks; all three are just stacks of Blocks:
# MAE decoder specifics
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
self.decoder_blocks = nn.ModuleList([
    Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
    for i in range(decoder_depth)])
self.rec_blocks = nn.ModuleList([
    Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
    for i in range(rec_depth)])
self.edge_blocks = nn.ModuleList([
    Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
    for i in range(edge_depth)])
self.rec_norm = norm_layer(decoder_embed_dim)
self.rec_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
self.rec_sig = nn.Sigmoid()
self.edge_norm = norm_layer(decoder_embed_dim)
self.edge_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)
self.edge_sig = nn.Sigmoid()
Go straight to forward:
def forward(self, imgs, mask_ratio=0.75, epoch=1):
    latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
    x_edge, x_rec = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
    weit = self.structure_loss(mask, epoch)
    rec_loss = self.rec_loss(imgs, x_rec, mask, weit)
    edge_loss, edge_gt = self.edge_loss(imgs, x_edge, mask, weit)
    return rec_loss, edge_loss, edge_gt, x_edge, x_rec, mask
First the encoder runs: embed the patches, add positional embeddings, randomly mask, append the cls token, then apply the Transformer blocks:
def forward_encoder(self, x, mask_ratio):
    # embed patches
    x = self.patch_embed(x)
    # add pos embed w/o cls token
    x = x + self.pos_embed[:, 1:, :]
    # masking: length -> length * mask_ratio
    x, mask, ids_restore = self.random_masking(x, mask_ratio)
    # append cls token
    cls_token = self.cls_token + self.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)
    # apply Transformer blocks
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)
    return x, mask, ids_restore
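random_masking is not shown above; the standard MAE implementation, which this repo most likely follows, shuffles patches by random noise and keeps the first (1 - mask_ratio) fraction:

def random_masking(self, x, mask_ratio):
    # Per-sample random masking via argsort of uniform noise (standard MAE recipe).
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(N, L, device=x.device)            # noise in [0, 1)
    ids_shuffle = torch.argsort(noise, dim=1)            # ascending: small noise is kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)      # inverse permutation
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
    mask = torch.ones([N, L], device=x.device)           # 0 is keep, 1 is remove
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)  # back to original order
    return x_masked, mask, ids_restore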
Then the decoder runs. The two tasks share the decoder blocks; afterwards the output is split into two heads, each with its own norm, linear projection, and sigmoid:
def forward_decoder(self, x, ids_restore):
    # embed tokens
    x = self.decoder_embed(x)
    # append mask tokens to sequence
    mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
    x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
    x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
    # add pos embed
    x = x + self.decoder_pos_embed
    # apply Transformer blocks
    for blk in self.decoder_blocks:
        x = blk(x)
    x_rec = x
    x_edge = x
    x_rec = self.rec_norm(x_rec)
    x_rec = self.rec_sig(self.rec_pred(x_rec))
    x_edge = self.edge_norm(x_edge)
    x_edge = self.edge_sig(self.edge_pred(x_edge))
    # remove cls token
    x_edge = x_edge[:, 1:, :]
    x_rec = x_rec[:, 1:, :]
    return x_edge, x_rec
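The unshuffle via torch.gather is the subtlest step; a toy example (hypothetical values, L=4 patches, D=2, two kept tokens) shows how ids_restore puts tokens back in original patch order:

import torch

kept = torch.tensor([[[10., 10.], [30., 30.]]])  # encoder output for original patches 0 and 2
mask_tok = torch.zeros(1, 2, 2)                  # two mask tokens (zeros for illustration)
x_ = torch.cat([kept, mask_tok], dim=1)          # shuffled order: [kept..., mask...]
ids_restore = torch.tensor([[0, 2, 1, 3]])       # where each original position lives in x_
out = torch.gather(x_, 1, ids_restore.unsqueeze(-1).repeat(1, 1, 2))
print(out)  # rows: patch0, mask, patch2, mask -- original spatial order restored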
Then the two losses are computed separately:
rec_loss = self.rec_loss(imgs, x_rec, mask, weit)
edge_loss, edge_gt = self.edge_loss(imgs, x_edge, mask, weit)

def rec_loss(self, imgs, pred, mask, weit):
    target = self.patchify(imgs)
    weit = self.patchify(weit)
    if self.norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6)**.5
    loss = (pred - target) ** 2
    if self.patchwise_loss:
        loss = loss * weit
    loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
    loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
    return loss
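The final line averages only over removed patches; a toy check (hypothetical per-patch losses) makes the masked mean concrete:

import torch

loss = torch.tensor([[1.0, 2.0, 3.0, 4.0]])  # per-patch loss, N=1, L=4
mask = torch.tensor([[0.0, 1.0, 1.0, 0.0]])  # patches 1 and 2 were masked out
print((loss * mask).sum() / mask.sum())      # tensor(2.5000): mean of 2.0 and 3.0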
def edge_loss(self, imgs, pred, mask, weit):
    """
    imgs: [N, 3, H, W]
    pred: [N, L, p*p*3]
    mask: [N, L], 0 is keep, 1 is remove,
    """
    with torch.no_grad():
        edge_gt = self.operator(imgs)
    target = self.patchify(edge_gt)
    weit = self.patchify(weit)
    if self.norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + 1.e-6)**.5
    loss = (pred - target) ** 2
    loss = loss * weit
    loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
    loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
    return loss, edge_gt
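self.operator computes the edge ground truth on the fly. Its definition is not shown here; a minimal sketch of a Sobel-style operator, assuming that is the edge detector used (the repo's kernel and scaling may differ):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SobelOperator(nn.Module):
    # Hypothetical stand-in for self.operator; kernels and rescaling are assumptions.
    def __init__(self):
        super().__init__()
        kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]])
        self.register_buffer('weight', torch.stack([kx, kx.t()]).unsqueeze(1))  # (2, 1, 3, 3)

    def forward(self, x):  # x: (N, 1, H, W) in [0, 1]
        g = F.conv2d(x, self.weight, padding=1)                   # x/y gradients
        mag = torch.sqrt((g ** 2).sum(dim=1, keepdim=True) + 1e-6)
        return mag / (mag.amax(dim=(2, 3), keepdim=True) + 1e-6)  # rescale to [0, 1]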
The training loop itself is unremarkable; the total loss is loss = rec_loss * opt.l1_loss + edge_loss:
for epoch in range(1, opt.epoch):
    for i, img in enumerate(train_loader):
        adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
        optimizer.zero_grad()
        img = img.to(device, dtype=torch.float)
        rec_loss, edge_loss, edge_gt, x_edge, x_rec, mask = mae(img, opt.masking_ratio, epoch)
        loss = rec_loss * opt.l1_loss + edge_loss
        loss.backward()
        optimizer.step()
        print(
            "[Epoch %d/%d] [Batch %d/%d] [rec_loss: %f] [edge_loss: %f] [lr: %f]"
            % (epoch, opt.epoch, i, len(train_loader), rec_loss.item(), edge_loss.item(), get_lr(optimizer))
        )
        if i % opt.save_output == 0:
            y1, im_masked1, im_paste1 = mae.MAE_visualize(img, x_rec, mask)
            y2, im_masked2, im_paste2 = mae.MAE_visualize(edge_gt, x_edge, mask)
            edge_gt, img = edge_gt.cpu(), img.cpu()
            save_image([img[0], im_masked1, im_paste1, edge_gt[0], im_masked2, im_paste2],
                       opt.img_save_path + str(epoch) + ' ' + str(i) + '.png', nrow=3, normalize=False)
            logging.info("[Epoch %d/%d] [Batch %d/%d] [rec_loss: %f] [edge_loss: %f] [lr: %f]"
                         % (epoch, opt.epoch, i, len(train_loader), rec_loss.item(), edge_loss.item(), get_lr(optimizer)))
    if epoch % opt.save_weight == 0:
        torch.save(mae.state_dict(), opt.weight_save_path + str(epoch) + 'MAE.pth')
torch.save(mae.state_dict(), opt.weight_save_path + './MAE.pth')
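adjust_lr and get_lr come from the repo's utils; a plausible sketch, assuming a simple step decay (the actual schedule may differ):

def adjust_lr(optimizer, init_lr, epoch, decay_rate, decay_epoch):
    # Hypothetical step decay: multiply by decay_rate every decay_epoch epochs.
    lr = init_lr * (decay_rate ** (epoch // decay_epoch))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def get_lr(optimizer):
    # Read back the current learning rate for logging.
    return optimizer.param_groups[0]['lr']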
Finetune
Two encoders are instantiated, nearly identical (they differ only in depth): E extracts features from the input modality, and FC_module is the feature consistency module:
E = MAE_finetune(img_size=opt.img_size, patch_size=opt.mae_patch_size, embed_dim=opt.encoder_dim, depth=opt.depth,
                 num_heads=opt.num_heads, in_chans=1, mlp_ratio=opt.mlp_ratio)
FC_module = MAE_finetune(img_size=opt.img_size, patch_size=opt.mae_patch_size, embed_dim=opt.encoder_dim,
                         depth=opt.fc_depth, num_heads=opt.num_heads, in_chans=1,
                         mlp_ratio=opt.mlp_ratio)  # feature consistency module
class MAE_finetune(nn.Module):
    def __init__(self, img_size=256, patch_size=8, in_chans=1, embed_dim=128, depth=24, num_heads=16, mlp_ratio=4., norm_layer=nn.LayerNorm):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

    def patchify(self, imgs):
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
        return x

    def forward_encoder(self, x):
        # embed patches
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]
        # apply Transformer blocks
        feature = [x]
        for blk in self.blocks:
            x = blk(x)
            feature.append(x)
        x = self.norm(x)
        feature.append(x)
        return feature

    def forward(self, imgs):
        latent = self.forward_encoder(imgs)
        return latent
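The pretrained EdgeMAE weights are then loaded into these encoders. A hedged sketch of how that typically looks (the path and strict=False handling are assumptions: the encoder shares key names with EdgeMAE, while decoder-only keys get dropped):

import torch

state = torch.load('./weight/EdgeMAE/MAE.pth', map_location='cpu')
missing, unexpected = E.load_state_dict(state, strict=False)  # decoder keys end up in `unexpected`
print(len(missing), len(unexpected))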
The MTNet network:
G = MTNet(img_size=opt.img_size, patch_size=opt.patch_size, in_chans=1, num_classes=1, embed_dim=opt.vit_dim,
          depths=[2, 2, 2, 2], depths_decoder=[2, 2, 2, 2], num_heads=[8, 8, 16, 32], window_size=opt.window_size,
          mlp_ratio=opt.mlp_ratio, qkv_bias=True, qk_scale=None, drop_rate=0.,
          attn_drop_rate=0., drop_path_rate=0, norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
          use_checkpoint=False, final_upsample="expand_first", fine_tune=True)
def forward(self, x1, x2):
    x2 = self.expand(x2)
    x1, x_downsample1 = self.forward_features1(x1)
    x2, x_downsample2 = self.forward_features2(x2)
    x = self.forward_up_features(x2, x_downsample1, x_downsample2)
    x = self.to_img(x)
    x = self.sig(x)
    return x
The fine-tuning training loop:
for epoch in range(1, opt.epoch):
    for i, (img, gt) in enumerate(data_loader):
        it = len(data_loader) * epoch + i
        for id, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedule[it]
        optimizer.zero_grad()
        img = img.to(device, dtype=torch.float)
        gt = gt.to(device, dtype=torch.float)
        Feature = E(img)
        f1, f2 = Feature[-1].clone(), Feature[-1].clone()
        pred = G(f1, f2)
        feature = FC_module(pred)
        feature_gt = FC_module(gt.detach())
        l1_loss = F.l1_loss(pred, gt)
        fc_loss = 0  # feature consistency loss
        for j in range(opt.fc_depth):
            fc_loss = fc_loss + F.l1_loss(feature[j], feature_gt[j])
        loss = opt.l1_loss * l1_loss + fc_loss
        loss.backward()
        optimizer.step()
The encoder features are fed to the generator; the loss is an L1 loss between the prediction and the ground truth plus a feature consistency loss (L1 between FC_module features of the prediction and of the ground truth), followed by backward and the optimizer step.
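lr_schedule here is a precomputed per-iteration learning-rate array indexed by it; a plausible sketch, assuming a warmup-plus-cosine schedule (the repo's builder may differ):

import numpy as np

def build_lr_schedule(base_lr, final_lr, epochs, iters_per_epoch, warmup_epochs=0):
    # Hypothetical helper: linear warmup, then cosine decay to final_lr.
    warmup = np.linspace(0, base_lr, int(warmup_epochs * iters_per_epoch))
    t = np.arange(int((epochs - warmup_epochs) * iters_per_epoch))
    cosine = final_lr + 0.5 * (base_lr - final_lr) * (1 + np.cos(np.pi * t / max(len(t), 1)))
    return np.concatenate([warmup, cosine])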