train_policy.py
def test(args, recon_model):
# Create data loader
test_loader = create_data_loader(policy_args, 'test', shuffle=False)
def create_data_loader(args, partition, shuffle=False, display=False):
dataset = create_fastmri_dataset(args, partition)
def create_fastmri_dataset(args, partition):
elif partition == 'test':
path = args.data_path / f'singlecoil_test'
use_seed = True
mask = MaskFunc(args.center_fractions, args.accelerations)
class MaskFunc:
    def __init__(self, center_fractions, accelerations):
        if len(center_fractions) != len(accelerations):
            raise ValueError('Number of center fractions should match number of accelerations')
self.center_fractions = center_fractions
self.accelerations = accelerations
self.rng = np.random.RandomState()
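MaskFunc.__call__ is not reproduced in these notes. A minimal sketch of the standard fastMRI-style column sampling it performs — the body below is an assumption, not the repository's exact code:

def __call__(self, shape, seed=None):
    num_cols = shape[-2]
    self.rng.seed(seed)
    # Pick one (center_fraction, acceleration) pair at random
    choice = self.rng.randint(0, len(self.accelerations))
    center_fraction = self.center_fractions[choice]
    acceleration = self.accelerations[choice]
    # Always keep the center columns; sample the remaining columns so the
    # expected total number of sampled columns is num_cols / acceleration
    num_low_freqs = int(round(num_cols * center_fraction))
    prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs)
    mask = self.rng.uniform(size=num_cols) < prob
    pad = (num_cols - num_low_freqs + 1) // 2
    mask[pad:pad + num_low_freqs] = True
    return mask

With 256 columns, a center fraction of 0.125 and acceleration 8 this gives num_low_freqs = 32 and prob = 0, which would explain the masks observed later: exactly the 32 center columns are set and nothing else.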
dataset = SliceData(
root=path,
transform=DataTransform(mask, args.resolution, use_seed=use_seed),
dataset=args.dataset,
sample_rate=args.sample_rate,
acquisition=args.acquisition,
center_volume=args.center_volume
)
class SliceData(Dataset):
self.transform = transform
self.examples = []
self.dataset = dataset
assert dataset in ['knee', 'brain'], f"Dataset must be 'knee' or 'brain', not {dataset}"
# Using rss for Brain data
self.recons_key = 'reconstruction_esc' if self.dataset == 'knee' \
else 'reconstruction_rss'
data_path = pathlib.Path(root)
files = sorted(list(data_path.iterdir()))
for fname in sorted(files):
target = h5py.File(fname, 'r')[self.recons_key]
num_slices = target.shape[0]
if center_volume:  # Only use the slices in the center half of the volume
    self.examples += [(fname, slice) for slice in range(num_slices // 4, num_slices * 3 // 4)]
else:
    self.examples += [(fname, slice) for slice in range(num_slices // 4, (num_slices // 4) + 2)]
print(f'{partition.capitalize()} slices: {len(dataset)}')
return dataset
[output]
Test slices: 4
(With the two-slices-per-volume selection above, 4 slices corresponds to two test volumes.)
elif partition.lower() in ['val', 'test']:
batch_size = args.val_batch_size #64
loader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=args.num_workers,
pin_memory=True,
)
return loader
test_data_range_dict = create_data_range_dict(policy_args, test_loader)
do_and_log_evaluation(policy_args, -1, recon_model, model, test_loader, writer, 'Test', test_data_range_dict)
writer.close()
test_data_range_dict = create_data_range_dict(policy_args, test_loader)
--> jump to policy_model_utils.py
def create_data_range_dict(args, loader):
# Locate ground truths of a volume
gt_vol_dict = {}
for it, data in enumerate(loader):
kspace, masked_kspace, mask, zf, gt, gt_mean, gt_std, fname, slice = data
for i, vol in enumerate(fname): #i=0, vol='file_brain_AXFLAIR_200_6002492.h5'
if vol not in gt_vol_dict:
gt_vol_dict[vol] = []
gt_vol_dict[vol].append(gt[i] * gt_std[i] + gt_mean[i])
# Find max of a volume
data_range_dict = {}
for vol, gts in gt_vol_dict.items(): #vol='file_brain_AXFLAIR_200_6002492.h5'
# Shape 1 x 1 x 1 x 1
data_range_dict[vol] = torch.stack(gts).max().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(args.device)
del gt_vol_dict
return data_range_dict
do_and_log_evaluation
do_and_log_evaluation(policy_args, -1, recon_model, model, test_loader, writer, 'Test', test_data_range_dict)
def do_and_log_evaluation(args, epoch, recon_model, model, loader, writer, partition, data_range_dict):
"""
Helper function for logging.
"""
ssims, psnrs, score_time = evaluate(args, epoch, recon_model, model, loader, writer, partition, data_range_dict)
(Figure: the loader object passed to enumerate(loader) in the code below.)
def evaluate(args, epoch, recon_model, model, loader, writer, partition, data_range_dict):
"""
Evaluates the policy on all slices in a validation or test dataset on the SSIM and PSNR metrics.
:param args: Argument object, containing hyperparameters for model evaluation.
:param epoch: int, current training epoch.
:param recon_model: reconstruction model object.
:param model: policy model object.
:param loader: validation or test data loader.
:param writer: Tensorboard writer.
:param partition: str, dataset partition to evaluate on ('val' or 'test')
:param data_range_dict: dictionary containing the dynamic range of every volume in the validation or test data.
:return: (dict: average SSIMS per time step, dict: average PSNR per time step, float: evaluation duration)
"""
model.eval()
ssims, psnrs = 0, 0
tbs = 0 # data set size counter
start = time.perf_counter()
with torch.no_grad():
for it, data in enumerate(loader):
kspace, masked_kspace, mask, zf, gt, gt_mean, gt_std, fname, _ = data
# shape after unsqueeze = batch x channel x columns x rows x complex
kspace = kspace.unsqueeze(1).to(args.device) #tensor(4,256,256,2)---->(4,1,256,256,2)
masked_kspace = masked_kspace.unsqueeze(1).to(args.device)#tensor(4,256,256,2)---->(4,1,256,256,2)
mask = mask.unsqueeze(1).to(args.device)#tensor(4,1,256,1)---->(4,1,1,256,1)
# shape after unsqueeze = batch x channel x columns x rows
zf = zf.unsqueeze(1).to(args.device)#tensor(4,256,256)---->(4,1,256,256)
gt = gt.unsqueeze(1).to(args.device)#tensor(4,256,256)---->(4,1,256,256)
gt_mean = gt_mean.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(args.device)#tensor:(4,)--->(4,1,1,1)
gt_std = gt_std.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(args.device)#tensor:(4,)--->(4,1,1,1)
unnorm_gt = gt * gt_std + gt_mean #(4,1,256,256)
data_range = torch.stack([data_range_dict[vol] for vol in fname])#(4,1,1,1)
tbs += mask.size(0)#4
# Base reconstruction model forward pass
recons = recon_model(zf) #(4,1,256,256)
unnorm_recons = recons[:, :, :, :] * gt_std + gt_mean #(4,1,256,256)
init_ssim_val = compute_ssim(unnorm_recons, unnorm_gt, size_average=False,
data_range=data_range).mean(dim=(-1, -2)).sum()
init_psnr_val = compute_psnr(args, unnorm_recons, unnorm_gt, data_range).sum()
def compute_psnr(args, unnorm_recons, gt_exp, data_range):
# Have to reshape to batch . trajectories x res x res and then reshape back to batch x trajectories x res x res
# because of psnr implementation
psnr_recons = torch.clamp(unnorm_recons, 0., 10.).reshape(gt_exp.size(0) * gt_exp.size(1), 1, args.resolution, args.resolution)  # before this line unnorm_recons[0,0].min()=7.0674e-06, max=0.0002; after it psnr_recons (4,1,256,256) has the same min/max, so the clamp changes nothing here
psnr_gt = gt_exp.reshape(gt_exp.size(0) * gt_exp.size(1), 1, args.resolution, args.resolution)#(4,1,256,256)-->(4,1,256,256)
# First duplicate data range over trajectories, then reshape: this to ensure alignment with recon and gt.
psnr_data_range = data_range.expand(-1, gt_exp.size(1), -1, -1)#(4,1,1,1)
psnr_data_range = psnr_data_range.reshape(gt_exp.size(0) * gt_exp.size(1), 1, 1, 1)#(4,1,1,1)
psnr_scores = psnr(psnr_recons, psnr_gt, reduction='none', data_range=psnr_data_range) #(4,) tensor([24.5812, 24.8475, 26.9927, 27.7892], device='cuda:0')
psnr_scores = psnr_scores.reshape(gt_exp.size(0), gt_exp.size(1)) #tensor(4,1)
return psnr_scores
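The psnr helper called here is not shown in these notes. Given that it accepts a per-image tensor data_range and reduction='none', it plausibly reduces to the textbook definition; a sketch under those assumptions:

import torch

def psnr(recons, gt, reduction='none', data_range=None):
    # Per-image PSNR = 10 * log10(data_range^2 / MSE), with the MSE taken
    # over the channel and spatial dimensions of N x C x H x W tensors
    mse = ((recons - gt) ** 2).mean(dim=(-3, -2, -1))
    scores = 10 * torch.log10(data_range.view(-1) ** 2 / mse)
    return scores if reduction == 'none' else scores.mean()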
batch_ssims = [init_ssim_val.item()]#[2.8300583362579346]
batch_psnrs = [init_psnr_val.item()]#[104.21051788330078]
for step in range(args.acquisition_steps):#args.acquisition_steps=16
policy, probs = get_policy_probs(model, recons, mask)
[output]
DataParallel(
(module): PolicyModel(
(channel_layer): ConvBlock(in_chans=1, out_chans=8, drop_prob=0, max_pool_size=1)
(down_sample_layers): ModuleList(
(0): ConvBlock(in_chans=8, out_chans=16, drop_prob=0, max_pool_size=2)
(1): ConvBlock(in_chans=16, out_chans=32, drop_prob=0, max_pool_size=2)
(2): ConvBlock(in_chans=32, out_chans=64, drop_prob=0, max_pool_size=2)
(3): ConvBlock(in_chans=64, out_chans=128, drop_prob=0, max_pool_size=2)
(4): ConvBlock(in_chans=128, out_chans=256, drop_prob=0, max_pool_size=2)
)
(fc_out): Sequential(
(0): Linear(in_features=16384, out_features=256, bias=True)
(1): LeakyReLU(negative_slope=0.01)
(2): Linear(in_features=256, out_features=256, bias=True)
(3): LeakyReLU(negative_slope=0.01)
(4): Linear(in_features=256, out_features=256, bias=True)
)
)
)
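ConvBlock itself is not printed above. A minimal sketch consistent with the repr attributes (in_chans, out_chans, drop_prob, max_pool_size) — an assumption, not the repository's definition:

import torch
from torch import nn

class ConvBlock(nn.Module):
    def __init__(self, in_chans, out_chans, drop_prob, max_pool_size):
        super().__init__()
        layers = [
            nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_chans),
            nn.LeakyReLU(),
            nn.Dropout2d(drop_prob),
        ]
        if max_pool_size > 1:
            layers.append(nn.MaxPool2d(max_pool_size))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

This is at least dimensionally consistent with the printout: five max-pools of size 2 turn a 1 x 256 x 256 input into 256 x 8 x 8 = 16384 features, matching in_features=16384 of fc_out.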
def get_policy_probs(model, recons, mask):#mask:(4,1,1,256,1), recons(4,1,256,256)
channel_size = mask.shape[1]#1
res = mask.size(-2)#256
# Reshape trajectory dimension into batch dimension for parallel forward pass
recons = recons.view(mask.size(0) * channel_size, 1, res, res)#(4,1,256,256)
# Obtain policy model logits
output = model(recons)
class PolicyModel(nn.Module):
def forward(self, image):
"""
Args:
    image (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
Returns:
    (torch.Tensor): Flattened policy logits of shape [batch_size, resolution]
"""
# Image embedding
# Initial block
image_emb = self.channel_layer(image)#(4,1,256,256)---->(4,8,256,256)
# Apply down-sampling layers
for layer in self.down_sample_layers:
image_emb = layer(image_emb)#(4,8,256,256)--->(4,16,256,256)-->(4,32,256,256)-->64-->128--->256
image_emb = self.fc_out(image_emb.flatten(start_dim=1)) # flatten all but batch dimension #(4,256)
assert len(image_emb.shape) == 2
return image_emb
# Reshape trajectories back into their own dimension
output = output.view(mask.size(0), channel_size, res)#(4,1,256)
# Mask already acquired rows by setting logits to very negative numbers
loss_mask = (mask == 0).squeeze(-1).squeeze(-2).float()  # (4,1,256): all 1s except the 32 center entries (already acquired), which are 0
logits = torch.where(loss_mask.byte(), output, -1e7 * torch.ones_like(output))  # (4,1,256)
# Softmax over 'logits' representing row scores
probs = torch.nn.functional.softmax(logits - logits.max(dim=-1, keepdim=True)[0], dim=-1)#(4,1,256)
# Also need this for sampling the next row at the end of this loop
policy = torch.distributions.Categorical(probs)
return policy, probs
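A toy run of this masking-plus-softmax logic, at resolution 6 with columns 2 and 3 already acquired (.bool() stands in for the deprecated .byte(); all numbers hypothetical):

import torch

output = torch.tensor([[[0.5, 1.0, 3.0, 2.0, 0.0, -1.0]]])          # (1,1,6) fake logits
mask = torch.tensor([0., 0., 1., 1., 0., 0.]).view(1, 1, 1, 6, 1)   # columns 2,3 acquired
loss_mask = (mask == 0).squeeze(-1).squeeze(-2).float()             # (1,1,6)
logits = torch.where(loss_mask.bool(), output, -1e7 * torch.ones_like(output))
probs = torch.softmax(logits - logits.max(dim=-1, keepdim=True)[0], dim=-1)
print(probs)  # acquired columns get ~0 probability; the mass goes to the rest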
The evaluation keeps adding vertical lines (k-space columns) to the mask; the for loop runs 16 times.
step=0
if step == 0:
actions = torch.multinomial(probs.squeeze(1), args.num_test_trajectories, replacement=True)  # (4,8), every entry is 145
# Samples trajectories in parallel
# For evaluation we can treat greedy and non-greedy the same: in both cases we just simulate
# num_test_trajectories acquisition trajectories in parallel for each slice in the batch, and store
# the average SSIM score every time step.
mask, masked_kspace, zf, recons = compute_next_step_reconstruction(recon_model, kspace, masked_kspace, mask, actions) #kspace(4,1,256,256,2), masked_kspace(4,1,256,256,2), mask(4,1,1,256,1)
For probs.squeeze(1), shape (4,256): every value is 0 except column 145, which is why all sampled actions above are identical.
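A toy check of that behaviour: with (near-)one-hot probability rows, multinomial sampling with replacement returns the same column (almost) every time. Hypothetical values:

import torch

probs = torch.zeros(4, 256)
probs[:, 145] = 1.0                                     # one-hot rows, as observed
actions = torch.multinomial(probs, 8, replacement=True)
print(actions)                                          # (4,8), every entry 145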
def compute_next_step_reconstruction(recon_model, kspace, masked_kspace, mask, next_rows): #next_rows=actions
# This computation is done by reshaping the masked k-space tensor to (batch . num_trajectories x 1 x res x res)
# and then reshaping back after performing a reconstruction.
mask, masked_kspace = acquire_rows_in_batch_parallel(kspace, masked_kspace, mask, next_rows)
def acquire_rows_in_batch_parallel(k, mk, mask, to_acquire):# k(4,1,256,256,2)
if mask.size(1) == mk.size(1) == to_acquire.size(1):
# Two cases:
# 1) We are only requesting a single k-space column to acquire per batch.
# 2) We are requesting multiple k-space columns per batch, and we are already in a trajectory of the non-greedy
# model: every column in to_acquire corresponds to an existing trajectory that we have sampled the next
# column for.
m_exp = mask
mk_exp = mk
else:
# We have to initialise trajectories: every row in to_acquire corresponds to a trajectory.
m_exp = mask.repeat(1, to_acquire.size(1), 1, 1, 1)  # (4,8,1,256,1): the 32 center columns are 1, the rest 0
mk_exp = mk.repeat(1, to_acquire.size(1), 1, 1, 1)  # (4,8,256,256,2): same as kspace, a white band in the middle with black on both sides
# Loop over slices in batch
for sl, rows in enumerate(to_acquire):  # sl=0, rows=tensor([145, 145, 145, 145, 145, 145, 145, 145], device='cuda:0'); the batch holds 4 slices here, so sl runs 4 times, and each slice carries 8 trajectory mask templates
# Loop over indices to acquire
for index, row in enumerate(rows):  # Will only be single index if first case (see comment above); here index=0, row=145
    m_exp[sl, index, :, row.item(), :] = 1.  # m_exp is (4,8,1,256,1): the 32 center columns (112-143) plus column 145 are 1, the rest 0
    mk_exp[sl, index, :, row.item(), :] = k[sl, 0, :, row.item(), :]  # copy kspace column 145 into masked_kspace; once index has run over all 8 trajectories, every template holds that column
return m_exp, mk_exp
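A toy check of this acquisition logic at resolution 8, with one slice and two trajectories acquiring columns 3 and 5 (all values hypothetical):

import torch

k = torch.ones(1, 1, 8, 8, 2)            # fake fully-sampled k-space
mk = torch.zeros(1, 1, 8, 8, 2)          # nothing acquired yet
mask = torch.zeros(1, 1, 1, 8, 1)
to_acquire = torch.tensor([[3, 5]])      # one slice, two trajectories

m_exp = mask.repeat(1, 2, 1, 1, 1)       # (1,2,1,8,1)
mk_exp = mk.repeat(1, 2, 1, 1, 1)        # (1,2,8,8,2)
for sl, rows in enumerate(to_acquire):
    for index, row in enumerate(rows):
        m_exp[sl, index, :, row.item(), :] = 1.
        mk_exp[sl, index, :, row.item(), :] = k[sl, 0, :, row.item(), :]
print(m_exp[0, 0, 0, :, 0])  # tensor([0., 0., 0., 1., 0., 0., 0., 0.])
print(m_exp[0, 1, 0, :, 0])  # tensor([0., 0., 0., 0., 0., 1., 0., 0.])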
After this code we have the new under-sampled k-space.
channel_size = masked_kspace.shape[1]#masked_kspace(4,8,256,256,2),channel_size=8
res = masked_kspace.size(-2)#256
# Combine batch and channel dimension for parallel computation if necessary
masked_kspace = masked_kspace.view(mask.size(0) * channel_size, 1, res, res, 2)#(32,1,256,256,2)
zf, _, _ = get_new_zf(masked_kspace) #(32,1,256,256)
def get_new_zf(masked_kspace_batch):#(32,1,256,256,2)
# Inverse Fourier Transform to get zero filled solution
image_batch = transforms.ifft2(masked_kspace_batch)#(32,1,256,256,2)
# Absolute value
image_batch = transforms.complex_abs(image_batch)#(32,1,256,256)
# Normalize input
image_batch, means, stds = transforms.normalize(image_batch, dim=(-2, -1), eps=1e-11)
image_batch = image_batch.clamp(-6, 6)
return image_batch, means, stds
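transforms.normalize is not reproduced here. Judging from the call signature (dim=(-2, -1), eps=1e-11), it standardises each image over its spatial dimensions; a sketch under that assumption:

def normalize(data, dim=(-2, -1), eps=1e-11):
    # Per-image mean/std over the spatial dims, kept broadcastable
    mean = data.mean(dim=dim, keepdim=True)
    std = data.std(dim=dim, keepdim=True)
    return (data - mean) / (std + eps), mean, std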
recon = recon_model(zf)#(32,1,256,256)
# Reshape back to B X C (=parallel acquisitions) x H x W
recon = recon.view(mask.size(0), channel_size, res, res)  # recon (4,8,256,256); mask (4,8,1,256,1)
zf = zf.view(mask.size(0), channel_size, res, res)#(4,8,256,256)
masked_kspace = masked_kspace.view(mask.size(0), channel_size, res, res, 2)#(32,1,256,256,2)-->(4,8,256,256,2)
return mask, masked_kspace, zf, recon
ssim_scores, psnr_scores = compute_scores(args, recons, gt_mean, gt_std, unnorm_gt, data_range, comp_psnr=True)
def compute_scores(args, recons, gt_mean, gt_std, unnorm_gt, data_range, comp_psnr=True):
# For every slice in the batch, and every acquired action per slice, compute the resulting SSIM (and PSNR) scores
# in parallel.
# Unnormalise reconstructions
unnorm_recons = recons * gt_std + gt_mean #(4,8,256,256)
# Reshape targets if necessary (for parallel computation of multiple acquisitions)
gt_exp = unnorm_gt.expand(-1, recons.shape[1], -1, -1)#(4,8,256,256)
# SSIM scores = batch x k (channels)
ssim_scores = compute_ssim(unnorm_recons, gt_exp, size_average=False, data_range=data_range).mean(-1).mean(-1)#(4,8)
# Also compute PSNR
if comp_psnr:
    psnr_scores = compute_psnr(args, unnorm_recons, gt_exp, data_range)
    return ssim_scores, psnr_scores
return ssim_scores
assert len(ssim_scores.shape) == 2
ssim_scores = ssim_scores.mean(-1).sum()#tensor(2.8645, device='cuda:0')
psnr_scores = psnr_scores.mean(-1).sum()#tensor(105.1553, device='cuda:0')
# eventually shape = al_steps
batch_ssims.append(ssim_scores.item())
batch_psnrs.append(psnr_scores.item())
After one pass through the loop, batch_ssims holds the initial score at index 0 and, at index 1, the score after the first iteration (one extra line in the mask). Control then returns to for step in range(args.acquisition_steps) and continues through range(0, 16): 16 iterations in total.
step=1
for step in range(args.acquisition_steps):
policy, probs = get_policy_probs(model, recons, mask)  # probs is nonzero only at column 146; for the four files the values are 0.96784, 0.93953, 0.92923, 0.86803
if step == 0:
    ...  # the multinomial sampling shown above
else:
    actions = policy.sample()  # (4,8), every entry is 146
# Samples trajectories in parallel
# For evaluation we can treat greedy and non-greedy the same: in both cases we just simulate
# num_test_trajectories acquisition trajectories in parallel for each slice in the batch, and store
# the average SSIM score every time step.
mask, masked_kspace, zf, recons = compute_next_step_reconstruction(recon_model, kspace, masked_kspace, mask, actions)
ssim_scores, psnr_scores = compute_scores(args, recons, gt_mean, gt_std, unnorm_gt, data_range, comp_psnr=True)
assert len(ssim_scores.shape) == 2
ssim_scores = ssim_scores.mean(-1).sum()#tensor(2.8925, device='cuda:0')
psnr_scores = psnr_scores.mean(-1).sum()#tensor(105.9600, device='cuda:0')
# eventually shape = al_steps
batch_ssims.append(ssim_scores.item())
batch_psnrs.append(psnr_scores.item())
Steps 2, 3, …, 15 proceed in exactly the same way.
After step 15 (the last acquisition step), the per-batch score lists are accumulated:
# shape of al_steps
ssims += np.array(batch_ssims)
psnrs += np.array(batch_psnrs)
Control then returns to for it, data in enumerate(loader): and the next batch is processed.
Once all batches are done, the accumulated SSIM and PSNR sums are averaged over the dataset size:
ssims /= tbs
psnrs /= tbs
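The accumulation pattern in miniature, with hypothetical numbers (two batches of 4 slices each and 2 acquisition steps, so each batch contributes a length-3 vector of summed SSIMs):

import numpy as np

ssims, tbs = 0, 0
for batch_ssims in ([2.83, 2.86, 2.89], [2.80, 2.84, 2.88]):
    ssims += np.array(batch_ssims)   # elementwise sum over batches
    tbs += 4                         # slices in this batch
ssims /= tbs
print(ssims)  # per-step average SSIM: [0.70375 0.7125  0.72125]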
# Logging
elif partition == 'Test':
# Only computed once, so loop over all epochs for wandb logging
if args.wandb:  # False here, so the wandb logging block is skipped
return ssims, psnrs, time.perf_counter() - start
Back in do_and_log_evaluation, the returned scores are logged:
ssims_str = ", ".join(["{}: {:.4f}".format(i, l) for i, l in enumerate(ssims)])
psnrs_str = ", ".join(["{}: {:.3f}".format(i, l) for i, l in enumerate(psnrs)])
logging.info(f'{partition}SSIM = [{ssims_str}]')
logging.info(f'{partition}PSNR = [{psnrs_str}]')
logging.info(f'{partition}ScoreTime = {score_time:.2f}s')