pg_mri----brain policy network

该代码段涉及使用深度学习策略评估磁共振成像(MRI)数据的重建质量。它创建数据加载器,然后在测试数据集上进行循环,每次迭代时在mask中添加一行,以模拟逐步获取k空间数据的过程。通过计算结构相似度指数(SSIM)和峰值信噪比(PSNR)来评估重建效果,并记录这些指标随时间的变化。
摘要由CSDN通过智能技术生成

Create data loader

test_loader = create_data_loader()

def test(args, recon_model):


    # Create data loader
    test_loader = create_data_loader(policy_args, 'test', shuffle=False)
def create_data_loader(args, partition, shuffle=False, display=False):
    dataset = create_fastmri_dataset(args, partition)

def create_fastmri_dataset(args, partition):

    elif partition == 'test':
        path = args.data_path / f'singlecoil_test'
        use_seed = True


    mask = MaskFunc(args.center_fractions, args.accelerations)
class MaskFunc:

        if len(center_fractions) != len(accelerations):
            raise ValueError('Number of center fractions should match number of accelerations')

        self.center_fractions = center_fractions
        self.accelerations = accelerations
        self.rng = np.random.RandomState()
    dataset = SliceData(
        root=path,
        transform=DataTransform(mask, args.resolution, use_seed=use_seed),
       dataset=args.dataset,
        sample_rate=args.sample_rate,
        acquisition=args.acquisition,
        center_volume=args.center_volume
    )

class SliceData(Dataset):

        self.transform = transform

        self.examples = []

        self.dataset = dataset
        assert dataset in ['knee', 'brain'], f"Dataset must be 'knee'' or 'brain'', not {dataset}"
        # Using rss for Brain data
        self.recons_key = 'reconstruction_esc' if self.dataset == 'knee' \
            else 'reconstruction_rss'

        data_path = pathlib.Path(root)
        files = sorted(list(data_path.iterdir()))


        for fname in sorted(files):


            target = h5py.File(fname, 'r')[self.recons_key]
            num_slices = target.shape[0]

            if center_volume:  # Only use the slices in the center half of the volume
            else:
                self.examples += [(fname, slice) for slice in range(num_slices // 4, (num_slices // 4) + 2)]

在这里插入图片描述

    print(f'{partition.capitalize()} slices: {len(dataset)}')

    return dataset

[output]
Test slices: 4
    elif partition.lower() in ['val', 'test']:
        batch_size = args.val_batch_size #64

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    return loader
    test_data_range_dict = create_data_range_dict(policy_args, test_loader)

    do_and_log_evaluation(policy_args, -1, recon_model, model, test_loader, writer, 'Test', test_data_range_dict)

    writer.close()

test_data_range_dict = create_data_range_dict()

    test_data_range_dict = create_data_range_dict(policy_args, test_loader)

–>跳到policy_model_utils.py

def create_data_range_dict(args, loader):
    # Locate ground truths of a volume
    gt_vol_dict = {} #{dict:0}{}
    for it, data in enumerate(loader):
        kspace, masked_kspace, mask, zf, gt, gt_mean, gt_std, fname, slice = data
        for i, vol in enumerate(fname): #i=0, vol='file_brain_AXFLAIR_200_6002492.h5'
            if vol not in gt_vol_dict:
                gt_vol_dict[vol] = []
            gt_vol_dict[vol].append(gt[i] * gt_std[i] + gt_mean[i])

请添加图片描述
在这里插入图片描述

    # Find max of a volume
    data_range_dict = {}
    for vol, gts in gt_vol_dict.items(): #vol='file_brain_AXFLAIR_200_6002492.h5'
        # Shape 1 x 1 x 1 x 1
        data_range_dict[vol] = torch.stack(gts).max().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(args.device)
    del gt_vol_dict
    return data_range_dict

在这里插入图片描述
在这里插入图片描述

do_and_log_evaluation

    do_and_log_evaluation(policy_args, -1, recon_model, model, test_loader, writer, 'Test', test_data_range_dict)
def do_and_log_evaluation(args, epoch, recon_model, model, loader, writer, partition, data_range_dict):
    """
    Helper function for logging.
    """
    ssims, psnrs, score_time = evaluate(args, epoch, recon_model, model, loader, writer, partition, data_range_dict)

图为下面代码enumerate(loader)中的loader
在这里插入图片描述

def evaluate(args, epoch, recon_model, model, loader, writer, partition, data_range_dict):
    """
    Evaluates the policy on all slices in a validation or test dataset on the SSIM and PSNR metrics.

    :param args: Argument object, containing hyperparameters for model evaluation.
    :param epoch: int, current training epoch.
    :param recon_model: reconstruction model object.
    :param model: policy model object.
    :param loader: training data loader.
    :param writer: Tensorboard writer.
    :param partition: str, dataset partition to evaluate on ('val' or 'test')
    :param data_range_dict: dictionary containing the dynamic range of every volume in the validation or test data.
    :return: (dict: average SSIMS per time step, dict: average PSNR per time step, float: evaluation duration)
    """
    model.eval()
    ssims, psnrs = 0, 0
    tbs = 0  # data set size counter
    start = time.perf_counter()
    with torch.no_grad():
        for it, data in enumerate(loader):
            kspace, masked_kspace, mask, zf, gt, gt_mean, gt_std, fname, _ = data
            # shape after unsqueeze = batch x channel x columns x rows x complex
            kspace = kspace.unsqueeze(1).to(args.device) #tensor(4,256,256,2)---->(4,1,256,256,2)
            masked_kspace = masked_kspace.unsqueeze(1).to(args.device)#tensor(4,256,256,2)---->(4,1,256,256,2)
            mask = mask.unsqueeze(1).to(args.device)#tensor(4,1,256,1)---->(4,1,1,256,1)
            # shape after unsqueeze = batch x channel x columns x rows
            zf = zf.unsqueeze(1).to(args.device)#tensor(4,256,256)---->(4,1,256,256)
            gt = gt.unsqueeze(1).to(args.device)#tensor(4,256,256)---->(4,1,256,256)
            gt_mean = gt_mean.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(args.device)#tensor:(4,)--->(4,1,1,1)
            gt_std = gt_std.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(args.device)#tensor:(4,)--->(4,1,1,1)
            unnorm_gt = gt * gt_std + gt_mean #(4,1,256,256)
            data_range = torch.stack([data_range_dict[vol] for vol in fname])#(4,1,1,1)
            tbs += mask.size(0)#4

            # Base reconstruction model forward pass
            recons = recon_model(zf) #(4,1,256,256)
            unnorm_recons = recons[:, :, :, :] * gt_std + gt_mean #(4,1,256,256)
            init_ssim_val = compute_ssim(unnorm_recons, unnorm_gt, size_average=False,
                                         data_range=data_range).mean(dim=(-1, -2)).sum()
            init_psnr_val = compute_psnr(args, unnorm_recons, unnorm_gt, data_range).sum()
def compute_psnr(args, unnorm_recons, gt_exp, data_range):
    """
    Compute one PSNR score per (slice, trajectory) pair.

    The trajectory dimension is folded into the batch dimension before calling `psnr`, and the scores are
    unfolded again afterwards.

    :param args: Argument object; only `args.resolution` (spatial size used for reshaping) is read here.
    :param unnorm_recons: unnormalised reconstructions, reshaped to (batch * trajectories, 1, res, res).
    :param gt_exp: ground truths; its first two dimensions define batch and trajectory counts.
    :param data_range: per-slice dynamic range, expanded along dimension 1 to the number of trajectories.
    :return: tensor of PSNR scores with shape batch x trajectories.
    """
    batch_size, num_traj = gt_exp.size(0), gt_exp.size(1)
    num_images = batch_size * num_traj
    flat_shape = (num_images, 1, args.resolution, args.resolution)

    # Reconstructions are clamped to [0, 10] before scoring; in the walkthrough's example the values were
    # already inside this interval, so the clamp left them unchanged.
    flat_recons = unnorm_recons.clamp(0., 10.).reshape(flat_shape)
    flat_gt = gt_exp.reshape(flat_shape)

    # Duplicate the data range over trajectories first, then flatten, so that every entry stays aligned with
    # its reconstruction / ground-truth pair.
    flat_range = data_range.expand(-1, num_traj, -1, -1).reshape(num_images, 1, 1, 1)

    flat_scores = psnr(flat_recons, flat_gt, reduction='none', data_range=flat_range)
    # Unfold back to batch x trajectories.
    return flat_scores.reshape(batch_size, num_traj)
            batch_ssims = [init_ssim_val.item()]#[2.8300583362579346]
            batch_psnrs = [init_psnr_val.item()]#[104.21051788330078]

            for step in range(args.acquisition_steps):#args.acquisition_steps=16
                policy, probs = get_policy_probs(model, recons, mask)
DataParallel(
  (module): PolicyModel(
    (channel_layer): ConvBlock(in_chans=1, out_chans=8, drop_prob=0, max_pool_size=1)
    (down_sample_layers): ModuleList(
      (0): ConvBlock(in_chans=8, out_chans=16, drop_prob=0, max_pool_size=2)
      (1): ConvBlock(in_chans=16, out_chans=32, drop_prob=0, max_pool_size=2)
      (2): ConvBlock(in_chans=32, out_chans=64, drop_prob=0, max_pool_size=2)
      (3): ConvBlock(in_chans=64, out_chans=128, drop_prob=0, max_pool_size=2)
      (4): ConvBlock(in_chans=128, out_chans=256, drop_prob=0, max_pool_size=2)
    )
    (fc_out): Sequential(
      (0): Linear(in_features=16384, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
      (4): Linear(in_features=256, out_features=256, bias=True)
    )
  )
)
def get_policy_probs(model, recons, mask):#mask:(4,1,1,256,1), recons(4,1,256,256)
    channel_size = mask.shape[1]#1
    res = mask.size(-2)#256
    # Reshape trajectory dimension into batch dimension for parallel forward pass
    recons = recons.view(mask.size(0) * channel_size, 1, res, res)#(4,1,256,256)
    # Obtain policy model logits
    output = model(recons)
class PolicyModel(nn.Module):

    def forward(self, image):
        """
        Embed an image batch and map each sample to a single output vector.

        Args:
            image (torch.Tensor): Input batch handed to ``self.channel_layer``; the original docstring gives
                its shape as [batch_size, self.in_chans, height, width].

        Returns:
            (torch.Tensor): 2-D tensor produced by ``self.fc_out`` from the embedding flattened over all
                non-batch dimensions (first dimension is the batch).
        """
        # Initial block, followed by every down-sampling block in order.
        features = self.channel_layer(image)
        for block in self.down_sample_layers:
            features = block(features)

        # Flatten everything except the batch dimension before the fully connected head.
        scores = self.fc_out(features.flatten(start_dim=1))
        assert len(scores.shape) == 2
        return scores
    # Reshape trajectories back into their own dimension
    output = output.view(mask.size(0), channel_size, res)#(4,1,256)
    # Mask already acquired rows by setting logits to very negative numbers
    loss_mask = (mask == 0).squeeze(-1).squeeze(-2).float() #(4,1,256),除了中间32个0,其余全是1
    logits = torch.where(loss_mask.byte(), output, -1e7 * torch.ones_like(output))#(4,1,256),
    # Softmax over 'logits' representing row scores
    probs = torch.nn.functional.softmax(logits - logits.max(dim=-1, keepdim=True)[0], dim=-1)#(4,1,256)
    # Also need this for sampling the next row at the end of this loop
    policy = torch.distributions.Categorical(probs)
    return policy, probs

在这里插入图片描述

一直往mask中加竖线,for循环16遍

step=0
                if step == 0:
                    actions = torch.multinomial(probs.squeeze(1), args.num_test_trajectories, replacement=True)#tensor(4,8)值全是145

                # Samples trajectories in parallel
                # For evaluation we can treat greedy and non-greedy the same: in both cases we just simulate
                # num_test_trajectories acquisition trajectories in parallel for each slice in the batch, and store
                # the average SSIM score every time step.
                mask, masked_kspace, zf, recons = compute_next_step_reconstruction(recon_model, kspace, masked_kspace, mask, actions) #kspace(4,1,256,256,2), masked_kspace(4,1,256,256,2), mask(4,1,1,256,1)

对于probs.squeeze(1):(4,256),除了第145列,其余值都是0
在这里插入图片描述

在这里插入图片描述

def compute_next_step_reconstruction(recon_model, kspace, masked_kspace, mask, next_rows): #next_rows=actions
    # This computation is done by reshaping the masked k-space tensor to (batch . num_trajectories x 1 x res x res)
    # and then reshaping back after performing a reconstruction.
    mask, masked_kspace = acquire_rows_in_batch_parallel(kspace, masked_kspace, mask, next_rows)
def acquire_rows_in_batch_parallel(k, mk, mask, to_acquire):
    """
    Mark the requested k-space lines as acquired and copy their values from the full k-space.

    :param k: full k-space; only trajectory index 0 along dimension 1 is read.
    :param mk: masked k-space that receives the newly acquired lines.
    :param mask: sampling mask; acquired positions along dimension 3 are set to 1.
    :param to_acquire: integer tensor, batch x trajectories, holding the index to acquire for each trajectory.
    :return: (updated mask, updated masked k-space).

    When `mask`, `mk` and `to_acquire` already agree on the size of dimension 1, `mask` and `mk` are updated
    in place and returned as-is. Otherwise both are tiled along dimension 1 to one copy per requested index
    (e.g. mask (4, 1, 1, 256, 1) -> (4, 8, 1, 256, 1) for 8 trajectories) and the copies are updated.
    """
    num_requested = to_acquire.size(1)
    if mask.size(1) == mk.size(1) == num_requested:
        # Either a single index per slice is requested, or every requested index continues an existing
        # trajectory: no expansion needed.
        m_exp, mk_exp = mask, mk
    else:
        # Trajectories have to be initialised: every requested index starts its own trajectory.
        tiling = (1, num_requested, 1, 1, 1)
        m_exp = mask.repeat(*tiling)
        mk_exp = mk.repeat(*tiling)

    # Outer loop over slices in the batch, inner loop over that slice's trajectories.
    for slice_idx, slice_rows in enumerate(to_acquire):
        for traj_idx, row in enumerate(slice_rows):
            line = row.item()
            m_exp[slice_idx, traj_idx, :, line, :] = 1.
            # The measured values always come from the single full k-space of this slice.
            mk_exp[slice_idx, traj_idx, :, line, :] = k[slice_idx, 0, :, line, :]
    return m_exp, mk_exp

经过这段代码后,得到新的under-sampled kspace
在这里插入图片描述

    channel_size = masked_kspace.shape[1]#masked_kspace(4,8,256,256,2),channel_size=8
    res = masked_kspace.size(-2)#256
    # Combine batch and channel dimension for parallel computation if necessary
    masked_kspace = masked_kspace.view(mask.size(0) * channel_size, 1, res, res, 2)#(32,1,256,256,2)
    zf, _, _ = get_new_zf(masked_kspace) #(32,1,256,256)
def get_new_zf(masked_kspace_batch):
    """
    Build the normalised zero-filled image batch for a batch of masked k-space data.

    :param masked_kspace_batch: masked k-space; (32, 1, 256, 256, 2) in the walkthrough's example.
    :return: (image batch clamped to [-6, 6], means, stds) where means / stds are whatever
             `transforms.normalize` returns alongside the normalised images.
    """
    # Inverse Fourier transform followed by the complex magnitude gives the zero-filled image.
    magnitude = transforms.complex_abs(transforms.ifft2(masked_kspace_batch))
    # Normalise over the last two dimensions, then limit the normalised values to [-6, 6].
    normalized, means, stds = transforms.normalize(magnitude, dim=(-2, -1), eps=1e-11)
    return normalized.clamp(-6, 6), means, stds
    recon = recon_model(zf)#(32,1,256,256)

    # Reshape back to B X C (=parallel acquisitions) x H x W
    recon = recon.view(mask.size(0), channel_size, res, res)#recon(4,8,256,256)。mask(4,8,1,256,1)
    zf = zf.view(mask.size(0), channel_size, res, res)#(4,8,256,256)
    masked_kspace = masked_kspace.view(mask.size(0), channel_size, res, res, 2)#(32,1,256,256,2)-->(4,8,256,256,2)
    return mask, masked_kspace, zf, recon

在这里插入图片描述

                ssim_scores, psnr_scores = compute_scores(args, recons, gt_mean, gt_std, unnorm_gt, data_range, comp_psnr=True)
def compute_scores(args, recons, gt_mean, gt_std, unnorm_gt, data_range, comp_psnr=True):
    """
    Score every reconstruction (one per slice and per trajectory) against its ground truth.

    :param args: Argument object, forwarded to `compute_psnr`.
    :param recons: normalised reconstructions, batch x trajectories x H x W ((4, 8, 256, 256) in the
                   walkthrough's example).
    :param gt_mean: mean used to undo the normalisation of `recons`.
    :param gt_std: standard deviation used to undo the normalisation of `recons`.
    :param unnorm_gt: unnormalised ground truths, expanded along dimension 1 to match `recons`.
    :param data_range: dynamic range passed to both metrics.
    :param comp_psnr: bool, whether PSNR is computed in addition to SSIM.
    :return: `(ssim_scores, psnr_scores)` if `comp_psnr` is true, otherwise only `ssim_scores`;
             SSIM scores have shape batch x trajectories.
    """
    num_traj = recons.shape[1]
    # Undo the normalisation of the reconstructions.
    denormalized = recons * gt_std + gt_mean
    # One ground-truth view per trajectory, so all acquisitions are scored in parallel.
    targets = unnorm_gt.expand(-1, num_traj, -1, -1)

    ssim_map = compute_ssim(denormalized, targets, size_average=False, data_range=data_range)
    # Average the SSIM map over its last two dimensions.
    ssim_scores = ssim_map.mean(-1).mean(-1)

    if not comp_psnr:
        return ssim_scores
    psnr_scores = compute_psnr(args, denormalized, targets, data_range)
    return ssim_scores, psnr_scores

在这里插入图片描述
在这里插入图片描述

                assert len(ssim_scores.shape) == 2
                ssim_scores = ssim_scores.mean(-1).sum()#tensor(2.8645, device='cuda:0')
                psnr_scores = psnr_scores.mean(-1).sum()#tensor(105.1553, device='cuda:0')
                # eventually shape = al_steps
                batch_ssims.append(ssim_scores.item())
                batch_psnrs.append(psnr_scores.item())

一次for循环之后,列表中索引0是初始值,索引1是第一次for循环(即mask加了一条线)之后的结果。然后跳回for step in range(args.acquisition_steps)继续遍历,range(0, 16)一共要遍历16次。
在这里插入图片描述

step=1
            for step in range(args.acquisition_steps):
                policy, probs = get_policy_probs(model, recons, mask) #probs只有第146列有值,四个文件的值分别是0.96784,0.93953,0.92923,0.86803

在这里插入图片描述

                if step == 0:

                else:
                    actions = policy.sample()#(4,8)值全是146
                # Samples trajectories in parallel
                # For evaluation we can treat greedy and non-greedy the same: in both cases we just simulate
                # num_test_trajectories acquisition trajectories in parallel for each slice in the batch, and store
                # the average SSIM score every time step.
                mask, masked_kspace, zf, recons = compute_next_step_reconstruction(recon_model, kspace, masked_kspace, mask, actions)
def compute_next_step_reconstruction(recon_model, kspace, masked_kspace, mask, next_rows):
    """
    Acquire the requested k-space lines and reconstruct every resulting trajectory.

    :param recon_model: reconstruction model, called on the zero-filled images.
    :param kspace: full k-space from which the new lines are copied.
    :param masked_kspace: masked k-space before the acquisition.
    :param mask: sampling mask before the acquisition.
    :param next_rows: indices to acquire, forwarded to `acquire_rows_in_batch_parallel`.
    :return: (mask, masked k-space, zero-filled images, reconstructions), the last three shaped
             batch x trajectories x res x res (with a trailing complex dimension of 2 for the k-space).
    """
    mask, masked_kspace = acquire_rows_in_batch_parallel(kspace, masked_kspace, mask, next_rows)

    batch_size = mask.size(0)
    num_traj = masked_kspace.shape[1]
    res = masked_kspace.size(-2)

    # Fold trajectories into the batch dimension so that all of them are reconstructed in one pass.
    flat_kspace = masked_kspace.view(batch_size * num_traj, 1, res, res, 2)
    flat_zf, _, _ = get_new_zf(flat_kspace)
    flat_recon = recon_model(flat_zf)

    # Unfold back to batch x trajectories (= parallel acquisitions) x H x W.
    per_trajectory = (batch_size, num_traj, res, res)
    recon = flat_recon.view(per_trajectory)
    zf = flat_zf.view(per_trajectory)
    masked_kspace = flat_kspace.view(*per_trajectory, 2)
    return mask, masked_kspace, zf, recon
                ssim_scores, psnr_scores = compute_scores(args, recons, gt_mean, gt_std, unnorm_gt, data_range, comp_psnr=True)
                assert len(ssim_scores.shape) == 2
                ssim_scores = ssim_scores.mean(-1).sum()#tensor(2.8925, device='cuda:0')
                psnr_scores = psnr_scores.mean(-1).sum()#tensor(105.9600, device='cuda:0')
                # eventually shape = al_steps
                batch_ssims.append(ssim_scores.item())
                batch_psnrs.append(psnr_scores.item())

在这里插入图片描述

step=2,3…15

看一下step=15
在这里插入图片描述
在这里插入图片描述

            # shape of al_steps
            ssims += np.array(batch_ssims)
            psnrs += np.array(batch_psnrs)

在这里插入图片描述
在这里插入图片描述
然后跳到for it, data in enumerate(loader):循环下一个batch

循环完所有batch计算ssim和psnr
    ssims /= tbs
    psnrs /= tbs

在这里插入图片描述
在这里插入图片描述

    # Logging

    elif partition == 'Test':
        # Only computed once, so loop over all epochs for wandb logging
        if args.wandb:#False

    return ssims, psnrs, time.perf_counter() - start
    ssims_str = ", ".join(["{}: {:.4f}".format(i, l) for i, l in enumerate(ssims)])
    psnrs_str = ", ".join(["{}: {:.3f}".format(i, l) for i, l in enumerate(psnrs)])
    logging.info(f'{partition}SSIM = [{ssims_str}]')
    logging.info(f'{partition}PSNR = [{psnrs_str}]')
    logging.info(f'{partition}ScoreTime = {score_time:.2f}s')

在这里插入图片描述
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值