SynthRAD2023 CBCT冠军论文 代码复现

1、仅学习使用,有侵权行为请务必联系我,论文题目和作者大佬贴在这

2、因为是很多尝试的网络里面的一个,粗劣复现,代码提及没有复现的地方,且代码为节选粘贴,请仅作参考

3、文字为两个训练集的联合训练,本文仅以一个为例

4、数据集:

SynthRAD2023 Grand Challenge validation dataset: synthetizing computed tomography for radiotherapy (zenodo.org)

网络部分

网络是使用segmentation_models_pytorch实现,github

import segmentation_models_pytorch as smp
def load_net(args,net=None):
    if net==None:
        net = args.net
    if net == 'unetpp':
        # model = UnetPlusPlus(num_classes=3, deep_supervision=deep_supervision)
        model = smp.UnetPlusPlus(
            encoder_name="vgg16",        # 选择解码器, 例如 mobilenet_v2 或 efficientnet-b7
            encoder_weights="imagenet",     # 使用预先训练的权重imagenet进行解码器初始化
            in_channels=args.chan_out,                  # 模型输入通道(1个用于灰度图像,3个用于RGB等)
            classes=args.chan_out,                      # 模型输出通道(数据集所分的类别总数)
            )
    print("====> model is loaded")
    return model

2.5D dataset 

本文仅展示大致2.5D内容,切割和归一化请自行调整

# 请按自己喜好创建pytorch 结构
from dataset_tool.dataset_nii import  SliceData_nii
train_set = SliceData_nii(datapath,patch_size=64)
#arg. 内容可以自行调整,测试可以全部为1
training_data_loader = DataLoader(dataset=train_set, num_workers=args.threads, batch_size=args.batchSize, shuffle=True)

import SimpleITK as sitk
import torch.utils.data as data
import numpy as np
import os
import pathlib

class SliceData_nii(data.Dataset):
    """
    A PyTorch Dataset template
    """

    def __init__(self, root, sample_rate=1, patch_size=None, add_noise_in=False):
        """
        Args:
            root (pathlib.Path): Path to the dataset.
            sample_rate (float, optional): A float between 0 and 1. This controls what fraction
                of the volumes should be loaded.
        """
        self.patch_size=patch_size
        self.add_noise_in=add_noise_in
        self.root=root
        self.example=[]
        namelist=os.listdir(root)
        for casename in namelist:
            ctpath=os.path.join(self.root, casename, "ct.nii.gz")
            cbctpath=os.path.join(self.root, casename, "cbct.nii.gz")
            maskpath=os.path.join(self.root, casename, "mask.nii.gz")

            ct=sitk.ReadImage(ctpath)
            ctarray=sitk.GetArrayFromImage(ct)
            cbct=sitk.ReadImage(cbctpath)
            cbctarray=sitk.GetArrayFromImage(cbct)
            assert cbctarray.shape == ctarray.shape
            if min(cbctarray.shape[1],cbctarray.shape[2])<=256:
                continue
            for idx in range(2, ctarray.shape[0] - 2):
                self.example.append((ctpath, cbctpath, maskpath, idx))

    def __len__(self):
        return len(self.example)

    def __getitem__(self, i):

        ctpath, cbctpath, maskpath, idx=self.example[i]

        ct=sitk.ReadImage(ctpath)
        ctarray=sitk.GetArrayFromImage(ct)[idx - 2:idx + 3, :, :]
        cbct=sitk.ReadImage(cbctpath)
        cbctarray=sitk.GetArrayFromImage(cbct)[idx - 2:idx + 3, :, :]
        mask=sitk.ReadImage(maskpath)
        maskarray=sitk.GetArrayFromImage(mask)[idx - 2:idx + 3, :, :]

        cbctnorm=self.norm(cbctarray)
        ctnorm=self.norm(ctarray)
        # print(cbctnorm.shape,ctnorm.shape,maskarray.shape)
        input, x, y=self.get_patch(img=cbctnorm)
        target, _, _=self.get_patch(img=ctnorm, x=x, y=y)
        mask_in, _, _=self.get_patch(img=maskarray, x=x, y=y)

        # print(input.shape,target.shape)
        return input, target,mask_in  # mask

    def norm(self, data):
        datamax=3000
        datamin=-1000
        data=np.clip(data, datamin, datamax)
        data=(data+1000)/2000-1
        return data

    def get_patch(self, img, x=None, y=None):
        ps=self.patch_size
        h=img.shape[1]
        w=img.shape[2]
        if x and y:
            img_patch=img[:, x:x + ps, y:y + ps].astype(np.float32)
            return img_patch, x, y

        # print(h)
        x=np.random.randint(h - ps)
        y=np.random.randint(w - ps)
        img_patch=img[:, x:x + ps, y:y + ps].astype(np.float32)
        return img_patch, x, y

训练和损失

from __future__ import print_function
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from utils.loss import VGG
from dataset_tool.dataset import  SliceData
from dataset_tool.dataset_nii import  SliceData_nii

from utils.file_utils import export_args
from option import args
from net.load_net import load_net
import time

def main(args):
    print("===> Building model")
    model = load_net(args,args.net)
    # set losses
    criterion_mae = nn.L1Loss()
    criterion_per = VGG()

    print("===> Setting GPU")
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()
        #损失部分确实了对遮盖mask 的损失,数据集我倒是给出了,感兴趣自行添加遮盖损失
        criterion_l = criterion_mae.cuda()
        criterion_per = criterion_per.cuda()

    print("===> Setting optimizer")
    params = list(model.parameters())
    optimizer = optim.Adam(params, lr=args.lr)
    # 自行修改name
    model_name = str(args.dataset)+'_'+str(args.net)+'_'+args.version
    export_args(args,os.path.join(args.project_model,model_name),'args.json')
    
    print("===> Training")
    for epoch in range(args.start_epoch, args.nEpochs + 1):        
        train(args,training_data_loader, optimizer, model, criterion_l, criterion_per, epoch)
        save_checkpoint(model, epoch, args)

def adjust_learning_rate(args, epoch): 
    lr = args.lr * (0.5 ** (epoch // args.step))
    return lr

def train(args,training_data_loader, optimizer, model, criterion_l, criterion_per, epoch):
    #    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    lr = adjust_learning_rate(args, epoch-1)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    model.train()    
    for iteration, batch in enumerate(training_data_loader):

        input, target ,mask  = Variable(batch[0]), Variable(batch[1]), Variable(batch[2])

        input = np.transpose(input,(1,0,2,3))
        target = np.transpose(target,(1,0,2,3))
        mask = np.transpose(mask,(1,0,2,3))

        if args.cuda:
            input = input.cuda()
            target = target.cuda()
            mask = mask.cuda()
            input=input.float()
            target=target.float()

            recon = model(input)

            loss_l = criterion_l(recon, target)

            if recon.shape[1] == 1:
                recon1 = torch.cat([recon, recon, recon], axis=1)
                target1 = torch.cat([target, target, target], axis=1)
            else:
                recon1 = recon
                target1 = target
            loss_per = criterion_per(recon1,target1)
        #选择了论文里提到的最好的参数
        loss_G =  10* loss_l +  loss_per
        
        optimizer.zero_grad()
        loss_G.backward() 
    #        nn.utils.clip_grad_norm(model.parameters(),args.clip) 
        optimizer.step()
        # print(iteration)
        if iteration%10 == 0:
            print("===> Epoch[{0}]({1}/{2}): Loss: {3:.7f}, Loss_L: {4:.7f},Loss_Per:{5:7f}, lr={6}, time:"
                  .format(epoch, iteration, len(training_data_loader), loss_G.item(),
                          loss_l.item(),loss_per.item(),str(optimizer.param_groups[0]["lr"])),time.asctime(time.localtime()))
        elif iteration%200 == 0:
            print("===> Epoch[{0}]({1}/{2}): Loss: {3:.7f}, Loss_L: {4:.7f},Loss_Per:{5:7f}, lr={6}, time:"
                  .format(epoch, iteration, len(training_data_loader), loss_G.item(),
                          loss_l.item(),loss_per.item(),str(optimizer.param_groups[0]["lr"])),time.asctime(time.localtime()))

def save_checkpoint(model, epoch, args):
    #args.name 还是arg 自行替换
    model_name = args.name
    model_out_path_folder = os.path.join(args.project_model,'model',model_name)
    model_out_path = os.path.join(model_out_path_folder,"model_epoch_{}.pth".format(epoch))
    state = {"epoch": epoch ,"model": model}
    if not os.path.exists(model_out_path_folder):
        os.makedirs(model_out_path_folder)
    if epoch%args.step_save == 0:   
        torch.save(state, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

if __name__ == '__main__':
    main(args)

推理test部分是针对项目的,就不贴了

参数来源

实验效果还不错,实际使用略差,有改进建议请留言

以上内容,写的很粗略,仅供参考,如有问题请留言

  • 8
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

请站在我身后

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值