ensres = staple(th.stack(enslist,dim=0)).squeeze(0)
vutils.save_image(ensres, fp = os.path.join(args.out_dir, str(slice_ID)+'_output_ens'+".jpg"), nrow = 1, padding = 10)
ensres 的尺寸是:
print(ensres.shape)
torch.Size([8, 256, 256])
tensor.save_image 中的通道数只能是 1 或 3 ,也就是灰度图和 RBG 图像,我的这个代码此处应该保存的灰度图,但是通道数显示的 8 ,经过分析,是因为预测采样时 batch_size 设置成了 8 ,导致了最后的通道数有误,将 batch_size 设置成 1 后,问题解决,顺利运行。