Code Walkthrough of "Unpaired Unsupervised CT Metal Artifact Reduction"

See the previous post for an explanation of the paper itself.
        The paper, "Unpaired Unsupervised CT Metal Artifact Reduction" by Bo-Yuan Chen and Chu-Song Chen, studies how to use deep learning to reduce the CT image artifacts caused by metal implants.

The project includes experiments with several different U-Net variants; pytorch_Net.py is used as the example here.

train

1. The hyperparameters are as follows:

batch_size = 8 
num_epoch = 25
lr = 2e-5
channels = 3
img_size = 320
lmda_g = 0.05
lmda_dnn = 0.1
input_shape = (channels, img_size, img_size)

Note that the input is 3-channel (the slices are stored as .jpg files); remember to modify this for your own data.
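If your CT slices are single-channel, a minimal sketch of the changes (the Grayscale step and the normalization statistics here are my assumptions, not from the repo):

    import torchvision.transforms as transforms

    channels = 1
    img_size = 320
    input_shape = (channels, img_size, img_size)

    single_channel_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.Grayscale(num_output_channels=1),  # collapse the RGB jpgs to one channel
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),  # placeholder stats, not the ImageNet ones used below
    ])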

2. Getting the patient information

    train_patient_info_noise, train_patient_info_clear, train_noise_num, train_clear_num = get_patient_info(CT_dir, OMA_dir, patients_id_list_train, semi=True)
    test_patient_info_noise, test_patient_info_clear, test_noise_num, test_clear_num = get_patient_info(CT_dir, OMA_dir, patients_id_list_test, semi=True)
import os
import pandas as pd

def get_patient_info(root, OMA_dir, patients_id_list, semi=False):
    # OMA_dir and semi are accepted to match the call site above; they are
    # used elsewhere in the full script and not needed in this excerpt
    noise_rows, clear_rows = [], []
    noise_num, clear_num = 0, 0
    for patient_id in patients_id_list:
        patient_id_path = os.path.join(root, patient_id)
        # MA_slice_num.txt lists the slice numbers that contain metal artifacts
        with open(os.path.join(patient_id_path, 'MA_slice_num.txt')) as f:
            noisy_patients_No = f.read().splitlines()
        for item in os.listdir(patient_id_path):
            if '.jpg' in item and item.split('_')[0] in noisy_patients_No:
                noise_rows.append({'name': item, 'path': patient_id_path, 'class': 1})  # noise: 1
                noise_num += 1
            elif '.jpg' in item:
                clear_rows.append({'name': item, 'path': patient_id_path, 'class': 0})  # clear: 0
                clear_num += 1
    # build the DataFrames once at the end (DataFrame.append was removed in pandas 2.0)
    patient_info_noise = pd.DataFrame(noise_rows, columns=['name', 'path', 'class'])
    patient_info_clear = pd.DataFrame(clear_rows, columns=['name', 'path', 'class'])
    return patient_info_noise, patient_info_clear, noise_num, clear_num

The returned tables record, for each CT slice, whether it is clean or noisy, along with the file name and path.
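A hypothetical call (the patient ids are made up) to show what comes back:

    info_noise, info_clear, n_noise, n_clear = get_patient_info(
        CT_dir, OMA_dir, ['patient01', 'patient02'])
    print(info_noise.head())   # columns: name, path, class (1 = noisy)
    print(n_noise, n_clear)    # slice counts for each class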

3. Splitting the training and test sets by patient id (only test_transform is shown here; the corresponding train_transform is used below but not shown)


    test_transform = transforms.Compose([  
        transforms.Resize((img_size, img_size)),                                 
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
    ])

    train_set_noise1 = CTImg(transform = train_transform, patient_info = train_patient_info_noise,CT_dir=CT_dir,OMA_dir=OMA_dir,Mask_dir=Mask_dir)

    train_set_noise = ConcatDataset([train_set_noise1, train_set_noise1, train_set_noise1, train_set_noise1])
    train_set_noise = ConcatDataset([train_set_noise,train_set_noise])

    train_set_clear = CTImg(transform = train_transform, patient_info = train_patient_info_clear,CT_dir=CT_dir,OMA_dir=OMA_dir,Mask_dir=Mask_dir)
    test_set_noise =  CTImg(transform = test_transform, patient_info = test_patient_info_noise,CT_dir=CT_dir,OMA_dir=OMA_dir,Mask_dir=Mask_dir)
    test_set_clear =  CTImg(transform = test_transform, patient_info = test_patient_info_clear,CT_dir=CT_dir,OMA_dir=OMA_dir,Mask_dir=Mask_dir)


    train_noise_loader = DataLoader(train_set_noise, batch_size = batch_size, shuffle=True)
    train_clear_loader = DataLoader(train_set_clear, batch_size = batch_size, shuffle=True)
    test_noise_loader = DataLoader(test_set_noise, batch_size = batch_size, shuffle=False)
    test_clear_loader = DataLoader(test_set_clear, batch_size = batch_size, shuffle=False)

The loaders provide both noisy and clean CT data. Note that the noisy training set is concatenated with itself 4 × 2 = 8 times, presumably to oversample the smaller noisy class against the clean one.
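The same 8x oversampling could also be written with a sampler instead of nested ConcatDataset calls; a sketch:

    from torch.utils.data import DataLoader, RandomSampler

    # draw 8x the dataset length with replacement, mirroring the
    # 4 * 2 = 8 copies created by the ConcatDataset calls above
    noise_sampler = RandomSampler(train_set_noise1, replacement=True,
                                  num_samples=8 * len(train_set_noise1))
    train_noise_loader = DataLoader(train_set_noise1, batch_size=batch_size,
                                    sampler=noise_sampler)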

4. The loss functions


    g_loss = torch.nn.BCEWithLogitsLoss()
    g_r_loss = torch.nn.MSELoss()
    d_loss = torch.nn.BCEWithLogitsLoss()
    dnn_loss = torch.nn.MSELoss()
    dnn_r_loss = torch.nn.MSELoss()
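Here d_loss and g_loss drive the adversarial game on the diff features, g_r_loss and dnn_r_loss are the semi-supervised reconstruction terms (weighted by lmda_g and lmda_dnn below), and dnn_loss trains the denoiser to recover the clean image.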

5. Three networks: the noise generator Gen, the discriminator Dis, and the denoiser Dnn (itself a U-Net)

    Gen = Generator(input_shape)
    Dis = Discriminator(input_shape)
    Dnn = Denoiser_UNet(input_shape)

6. Moving everything to CUDA, initializing the weights, and setting up the optimizers (note that the discriminator uses half the learning rate, lr/2)


    if cuda:
        Gen = Gen.cuda()
        Dis = Dis.cuda()
        Dnn = Dnn.cuda()
        g_loss.cuda()
        d_loss.cuda()
        dnn_loss.cuda()


    # Initialize weights
    Gen.apply(weights_init_normal)
    Dis.apply(weights_init_normal)
    Dnn.apply(weights_init_normal)


    # Optimizers
    optimizer_Gen = torch.optim.Adam(Gen.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_Dis = torch.optim.Adam(Dis.parameters(), lr=lr/2, betas=(0.5, 0.999))
    optimizer_Dnn = torch.optim.Adam(Dnn.parameters(), lr=lr, betas=(0.5, 0.999))

    # Input tensor type
    Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
    fix_batch_sample_z = Tensor(get_random_sample(([batch_size] + list(input_shape)), method = 'uniform'))
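get_random_sample is defined elsewhere in the repo; a plausible minimal stand-in for method = 'uniform' (the sampling range is my assumption, check the actual helper):

    import numpy as np

    def get_random_sample(shape, method='uniform'):
        # hypothetical stand-in: draw a noise sample of the requested shape
        if method == 'uniform':
            return np.random.uniform(0.0, 1.0, size=shape)
        raise ValueError(f'unknown method: {method}')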

7. Training. Discriminator step: sample z and generate a noise map g_noise = Gen(cat(z, clear_img)); add it to the clean image to form the fake noisy image g_img; apply the diff operator (defined elsewhere in the script) to both the real noisy images and g_img; feed both through Dis; compute real_loss and fake_loss; and backpropagate.

            """ Train D """
            optimizer_Dis.zero_grad()
            batch_sample_z = Tensor(get_random_sample(([len(clear_img)] + list(input_shape)), method = 'uniform'))
            g_noise = Gen(torch.cat((Variable(batch_sample_z).cuda(),Variable(clear_img).cuda()), 1))
            g_img = g_noise + Variable(clear_img).cuda()

            noisy_real = diff(Variable(noise_img).cuda())
            noisy_fake = diff(g_img)
            #if i ==0:
            #    print(f"shape of noisy_real: {noisy_real.shape}, shape of noisy_fake: {noisy_fake.shape}")
            real_logit = Dis(noisy_real.detach())
            fake_logit = Dis(noisy_fake.detach())
            
            real_label = Variable(noise_cls.float().cuda()) #1
            fake_label = Variable(clear_cls.float().cuda()) #0
            
            real_loss = d_loss(real_logit, real_label)
            fake_loss = d_loss(fake_logit, fake_label)
            loss_D = (real_loss + fake_loss) / 2
            
            loss_D.backward()
            optimizer_Dis.step()

Generator and denoiser step. The semi-supervised part loops over the noisy batch: for slices with a paired ground truth (supervised flag s), the generator is asked to reproduce the true artifact ni - nl from the clean label nl, and the denoiser is trained on that generated artifact, giving the reconstruction terms loss_g_r and loss_dnn_r.

            optimizer_Gen.zero_grad()
            optimizer_Dnn.zero_grad()
            batch_sample_z = Tensor(get_random_sample(([len(clear_img)] + list(input_shape)), method = 'uniform'))
            g_noise = Gen(torch.cat((Variable(batch_sample_z).cuda(),Variable(clear_img).cuda()), 1))
            
            # semi-part
            
            loss_g_r, loss_dnn_r = 0, 0
            spl = 0
            for li, (ni,s,nl) in enumerate(zip(noise_img, supervised, noise_label)):
                b_s_z = Tensor(get_random_sample(([1] + list(input_shape)), method = 'uniform'))
                if s:
                    spl += 1
                    g_n_GT = Gen(torch.cat((Variable(b_s_z).cuda(),Variable(nl[None]).cuda()), 1))
                    loss_g_r += g_r_loss(g_n_GT, Variable(ni[None]).cuda() -Variable(nl)[None].cuda())
                    dnn_p_GT = Dnn(g_n_GT.detach())
                    # accumulate with += (the original used '=', which
                    # overwrote earlier supervised samples before the
                    # division by spl below)
                    loss_dnn_r += dnn_r_loss(dnn_p_GT, Variable(ni[None]).cuda() - Variable(nl[None]).cuda())
            if spl != 0:
                loss_g_r /= spl
                loss_dnn_r /= spl
            g_img = g_noise + Variable(clear_img).cuda()
            noisy_fake = diff(g_img)
            fake_logit = Dis(noisy_fake)         
            loss_G = g_loss(fake_logit, torch.ones((len(clear_img))).cuda()) + lmda_g * loss_g_r 
            loss_G.backward()
            optimizer_Gen.step()
                                  
            # note: Dnn is fed the generated noise map during training,
            # while at test time (below) it receives the noisy image itself
            dnn_pred = Dnn(g_noise.detach())                
            out = g_img.detach() - dnn_pred               
            loss_Dnn = dnn_loss(out,Variable(clear_img).cuda()) + lmda_dnn * loss_dnn_r 
            loss_Dnn.backward()
            optimizer_Dnn.step()      
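Putting the three updates together, the objectives as implemented are (my shorthand: x_c is a clean image, x_n a real noisy image, z the sampled code, g_noise = Gen(z, x_c), and diff the operator defined elsewhere in the script):

    loss_D   = ( BCE(Dis(diff(x_n)), 1) + BCE(Dis(diff(x_c + g_noise)), 0) ) / 2
    loss_G   = BCE(Dis(diff(x_c + g_noise)), 1) + lmda_g * loss_g_r
    loss_Dnn = MSE( (x_c + g_noise) - Dnn(g_noise), x_c ) + lmda_dnn * loss_dnn_r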

8. Validation and saving

        with torch.no_grad():
            psnr = PSNR()
            mae = MAE()
            N_GT_psnr, DN_GT_psnr, N_GT_mae, DN_GT_mae, N_GT_ssim, DN_GT_ssim = 0, 0, 0, 0, 0, 0
            for i, ((noise_img, _,_,noise_label,_), (clear_img,_,_,clear_label,_)) in enumerate(zip(test_noise_loader, test_clear_loader)):
                '''Gen'''
                g_noise = Gen(torch.cat((Variable(fix_batch_sample_z).cuda(),Variable(clear_img).cuda()), 1))            
                g_img = g_noise + Variable(clear_img).cuda()
                '''Dnn'''
                dnn_pred = Dnn(Variable(noise_img).cuda())
                out = Variable(noise_img).cuda() - dnn_pred
                batch_len = len(out)
                for (noise,label) in zip(Variable(noise_img).cuda(),Variable(noise_label).cuda()): 
                    N_GT_psnr += psnr(noise, label)/batch_len
                    #N_GT_ssim += compare_ssim(noise,label)/batch_len
                    N_GT_mae += mae(noise,label)/batch_len

                # clp (defined elsewhere in the script) presumably clamps the
                # denoised output to the valid intensity range
                for (denoise,label) in zip(out,Variable(noise_label).cuda()): 
                    DN_GT_psnr += psnr(clp(denoise), label)/batch_len
                    #DN_GT_ssim += compare_ssim(denoise,label)/batch_len
                    DN_GT_mae += mae(clp(denoise), label)/batch_len

                if  i == 0:                
                    fig = plt.figure(figsize=[8*6,8*4])
                    axes = [fig.add_subplot(6, 1, r+1 ) for r in range(0, 6)]
                    for ax in axes:
                        ax.axis('off')
                    plt.gca().xaxis.set_major_locator(plt.NullLocator())
                    plt.gca().yaxis.set_major_locator(plt.NullLocator())
                    plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
                    plt.margins(0,0) 
                    axes[0].imshow(torchvision.utils.make_grid(clear_img.cpu(), nrow=8).permute(1, 2, 0))
                    #torchvision.utils.save_image(clear_img.cpu(), './samples/origin_clear_ep{:02d}-{:04d}.png'.format(epoch, i))               
                    axes[1].imshow(torchvision.utils.make_grid(g_noise.cpu(), nrow=8).permute(1, 2, 0))
                    #torchvision.utils.save_image(g_noise.cpu(), './samples/gen_noise_ep{:02d}-{:04d}.png'.format(epoch, i))                
                    axes[2].imshow(torchvision.utils.make_grid(g_img.cpu(), nrow=8).permute(1, 2, 0))
                    #torchvision.utils.save_image(g_img.cpu(), './samples/gen_img_ep{:02d}-{:04d}.png'.format(epoch, i))                                                         
                    axes[3].imshow(torchvision.utils.make_grid(noise_img.cpu(), nrow=8).permute(1, 2, 0))
                    #torchvision.utils.save_image(noise_img.cpu(), './samples/origin_noise_ep{:02d}-{:04d}.png'.format(epoch, i))
                    axes[4].imshow(torchvision.utils.make_grid(dnn_pred.cpu(), nrow=8).permute(1, 2, 0))
                    #torchvision.utils.save_image(dnn_pred.cpu(),  './samples/dnn_noise_ep{:02d}-{:04d}.png'.format(epoch, i))
                    axes[5].imshow(torchvision.utils.make_grid(out.cpu(), nrow=8).permute(1, 2, 0))
                    #torchvision.utils.save_image(out.cpu(), './samples/denoised_img_ep{:02d}-{:04d}.png'.format(epoch, i))
                    fig.savefig("results/SS_DNN2UNet/cv{:02d}ep{:02d}.png".format(idx+1,epoch),bbox_inches = 'tight',pad_inches = 0)
                    plt.close(fig)
                    print("saving...")

model

class Generator(nn.Module):
    def __init__(self, input_shape, cat=True):
        super(Generator, self).__init__()
        
        channels, _, _ = input_shape
        if cat:
            channels*=2 
        self.down1 = G_Down(channels, 32, normalize=False) 
        self.down2 = G_Down(32, 32) 
        self.down3 = G_Down(32, 64, pooling=True, dropout=0.5) 
        self.down4 = G_Down(64, 64)         
        self.down5 = G_Down(64, 128, pooling=True, dropout=0.5) 
        self.down6 = G_Down(128, 128, normalize=False) 

        self.up1 = G_Up(256, 64, uppooling=True, dropout=0.5)
        self.up2 = G_Up(64, 64)
        self.up3 = G_Up(128, 32, uppooling=True, dropout=0.5)
        self.up4 = G_Up(32, 32)
        self.up5 = G_Up(32, 3)

        self.final = nn.Sequential(
            nn.Conv2d(3, 3, kernel_size = 3,stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):               # [batchsize,   6, 320, 320] (shapes shown for img_size = 320)
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)              # [batchsize,  32, 320, 320]
        d2 = self.down2(d1)             # [batchsize,  32, 320, 320]
        d3 = self.down3(d2)             # [batchsize,  64, 160, 160]
        d4 = self.down4(d3)             # [batchsize,  64, 160, 160]
        d5 = self.down5(d4)             # [batchsize, 128,  80,  80]
        d6 = self.down6(d5)             # [batchsize, 128,  80,  80]
        cat1 = torch.cat((d6, d5), 1)   # [batchsize, 256,  80,  80]
        u1 = self.up1(cat1)             # [batchsize,  64, 160, 160]
        u2 = self.up2(u1)               # [batchsize,  64, 160, 160]
        cat2 = torch.cat((u2, d4), 1)   # [batchsize, 128, 160, 160]
        u3 = self.up3(cat2)             # [batchsize,  32, 320, 320]
        u4 = self.up4(u3)               # [batchsize,  32, 320, 320]
        u5 = self.up5(u4)               # [batchsize,   3, 320, 320]
        return self.final(u5)           # [batchsize,   3, 320, 320]
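A quick shape check (a sketch; it assumes the repo's G_Down and G_Up helper blocks are in scope):

    gen = Generator(input_shape=(3, 320, 320))
    z = torch.rand(2, 3, 320, 320)
    x_clear = torch.rand(2, 3, 320, 320)
    g_noise = gen(torch.cat((z, x_clear), 1))  # generator input has 6 channels
    print(g_noise.shape)                       # torch.Size([2, 3, 320, 320])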

 

class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        channels, height, width = input_shape
        self.input_shape = (channels*2, height, width)                        # [batchsize, 6, 320, 320]
        # Calculate output of image discriminator (PatchGAN)
        self.output_shape = (1, height // 2 ** 3, width // 2 ** 3)

        def discriminator_block(in_filters, out_filters, normalization=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, 2, 1)]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels*2, 16, normalization=False),      # [batchsize,  16, 160, 160]
            *discriminator_block(16, 32),                                   # [batchsize,  32,  80,  80]
            *discriminator_block(32, 128),                                  # [batchsize, 128,  40,  40]
            *discriminator_block(128, 128),                                 # [batchsize, 128,  20,  20]
        )
        
        self.final = nn.Sequential(
            nn.Linear(128 * 20 * 20, 1),  # 20 = 320 / 2**4, so this head assumes img_size = 320
            nn.Sigmoid(),                 # note: the training script pairs this with BCEWithLogitsLoss, which applies a sigmoid again
        )


    def forward(self, img):
        # Concatenate image and condition image by channels to produce input
        conv = self.model(img)
        conv = conv.view(conv.shape[0], -1)
        return self.final(conv).view(-1)
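And the matching check for the discriminator (a sketch; the 6-channel input mirrors channels*2 above):

    dis = Discriminator(input_shape=(3, 320, 320))
    logit = dis(torch.rand(2, 6, 320, 320))
    print(logit.shape)  # torch.Size([2]) -- one score per image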

 

In summary, the implementation is consistent with the framework described in the paper, with nothing convoluted.
