# Channel triples for label values 1, 2, ...; they are passed to cv2.imwrite
# unchanged. Label 0 (background) is never painted.
_LABEL_COLORS = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0], [255, 0, 255], [0, 255, 255],
                 [244, 208, 63], [234, 240, 241]]


def _predict_slice(net, slice_2d, patch_size, device):
    """Segment one 2-D slice with ``net``; return ``(pred, out)``.

    ``out`` is the argmax label map at ``patch_size`` resolution and ``pred`` is
    ``out`` resampled (nearest neighbour) back to the slice's own shape. When the
    slice already has ``patch_size`` shape, no resampling happens and ``pred`` is ``out``.
    """
    x, y = slice_2d.shape[0], slice_2d.shape[1]
    needs_resize = x != patch_size[0] or y != patch_size[1]
    if needs_resize:
        # Cubic interpolation for the network input (order=0 was used previously).
        slice_input = zoom(slice_2d, (patch_size[0] / x, patch_size[1] / y), order=3)
    else:
        slice_input = slice_2d
    # Add batch and channel dims: (H, W) -> (1, 1, H, W).
    net_input = torch.from_numpy(slice_input).unsqueeze(0).unsqueeze(0).float().to(device)
    with torch.no_grad():
        outputs = net(net_input)
        out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
        out = out.cpu().detach().numpy()
    if needs_resize:
        # order=0 keeps the resampled map made of integer class ids.
        pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
    else:
        pred = out
    return pred, out


def _paint_labels(canvas, mask, classes):
    """Paint pixels whose ``mask`` value is in ``1..classes-1`` onto ``canvas`` in place; return it.

    ``canvas`` is an (H, W, 3) array and ``mask`` an (H, W) array. Colours cycle
    through ``_LABEL_COLORS`` when there are more foreground classes than palette
    entries; indexing the palette directly used to raise IndexError there.
    """
    for i in range(1, classes):
        canvas[mask == i] = _LABEL_COLORS[(i - 1) % len(_LABEL_COLORS)]
    return canvas


def _save_overlay_frames(frame_dir, case, ind, slice_img, slice_prd, slice_lab, classes):
    """Write prediction and ground-truth overlays of slice ``ind`` as PNGs under ``frame_dir``.

    The three 2-D inputs must share one shape. ``slice_img`` is scaled by 255
    before casting to uint8, so it is presumably normalised to [0, 1]. ``case`` is
    concatenated into the filenames and must therefore be a str.
    """
    import cv2  # imported lazily: only needed when frames are actually written
    import os
    os.makedirs(frame_dir, exist_ok=True)
    print("test_save_frame_path", frame_dir)
    # rot90 with k=3 gives the orientation the saved frames have always used.
    slice_img = np.rot90(slice_img, 3)
    slice_prd = np.rot90(slice_prd, 3)
    slice_lab = np.rot90(slice_lab, 3)
    gray_3ch = np.repeat(slice_img[:, :, np.newaxis], 3, axis=2) * 255
    img_prd = _paint_labels(gray_3ch.copy(), slice_prd, classes)
    img_lab = _paint_labels(gray_3ch, slice_lab, classes)
    cv2.imwrite(frame_dir + '/' + case + '_imgprd_slice' + str(ind).zfill(3) + '.png',
                img_prd.astype(np.uint8))
    cv2.imwrite(frame_dir + '/' + case + '_imglab_slice' + str(ind).zfill(3) + '.png',
                img_lab.astype(np.uint8))


def _save_volume_outputs(image, prediction, label, classes, test_save_path, case, z_spacing):
    """Write image/prediction/label as NIfTI files plus per-slice PNGs.

    The spacing is set as ``(1, 1, z_spacing)``, so this expects 3-D arrays.
    NOTE(review): it relies on a module-level ``sitk`` (SimpleITK) name, exactly
    as the previously disabled code did; confirm it is imported in this file.
    """
    import cv2  # imported lazily: only needed when frames are actually written
    import os
    for array, suffix in ((prediction, "_pred.nii.gz"), (image, "_img.nii.gz"), (label, "_gt.nii.gz")):
        itk_image = sitk.GetImageFromArray(array.astype(np.float32))
        itk_image.SetSpacing((1, 1, z_spacing))
        sitk.WriteImage(itk_image, test_save_path + '/' + case + suffix)
    frame_dir = test_save_path + "_frames"
    os.makedirs(frame_dir, exist_ok=True)
    print("test_save_frame_path", frame_dir)
    for ind in range(image.shape[0]):
        tag = str(ind).zfill(3) + '.png'
        slice_img = np.rot90(image[ind], 3)
        cv2.imwrite(frame_dir + '/' + case + '_img_slice' + tag, (slice_img * 255).astype(np.uint8))
        slice_prd = np.rot90(prediction[ind], 3)
        slice_lab = np.rot90(label[ind], 3)
        prd_rgb = _paint_labels(np.zeros(slice_prd.shape + (3,)), slice_prd, classes)
        lab_rgb = _paint_labels(np.zeros(slice_lab.shape + (3,)), slice_lab, classes)
        cv2.imwrite(frame_dir + '/' + case + '_prd_slice' + tag, prd_rgb.astype(np.uint8))
        cv2.imwrite(frame_dir + '/' + case + '_lab_slice' + tag, lab_rgb.astype(np.uint8))


def test_single_volume(image, label, net, classes, patch_size=(256, 256), test_save_path=None, case=None,
                       z_spacing=1, *, device="cuda", save_volume=False):
    """Segment one case with ``net`` and return per-foreground-class metrics.

    Args:
        image, label: tensors whose leading dim 0 is squeezed away. If ``image``
            is then 3-D, it is segmented slice by slice along axis 0. Each slice is
            resized to ``patch_size`` for the network and its prediction is resized
            back. Any other rank is fed to ``net`` in one pass at its native size.
        net: model called on a (1, 1, H, W) float tensor; the argmax over dim 1
            of its output is taken as the label map.
        classes: number of classes including background 0; metrics cover
            ``1..classes-1``.
        patch_size: (height, width) the network is fed in the 3-D path.
        test_save_path: when not None, the function skips metric computation and
            uses ``[0, 0]`` for every class. In the 3-D path it also writes
            image/prediction/label overlay PNGs to
            ``test_save_path + "_frames224cover"`` for each slice with non-zero
            ground truth.
        case: case identifier used in output filenames (a str when saving).
        z_spacing: z spacing written into the NIfTI files (``save_volume`` only).
        device: device the network input is moved to. The default "cuda" matches
            the previous hard-coded ``.cuda()``.
        save_volume: if True and ``test_save_path`` is set, also write NIfTI
            volumes and raw per-slice PNGs via ``_save_volume_outputs``. This path
            used to be disabled with ``and False``.

    Returns:
        list with one entry per foreground class: ``calculate_metric_percase``'s
        result, or the placeholder ``[0, 0]`` in saving mode.
    """
    image = image.squeeze(0).cpu().detach().numpy()
    label = label.squeeze(0).cpu().detach().numpy()
    net.eval()  # set once; it used to be re-set for every slice
    if len(image.shape) == 3:
        prediction = np.zeros_like(label)
        for ind in range(image.shape[0]):
            slice_2d = image[ind, :, :]
            pred, out = _predict_slice(net, slice_2d, patch_size, device)
            prediction[ind] = pred
            # Frames are only dumped for slices whose ground truth has foreground.
            if test_save_path is None or np.sum(label[ind]) == 0:
                continue
            x, y = slice_2d.shape[0], slice_2d.shape[1]
            scale = (patch_size[0] / x, patch_size[1] / y)
            # Bring image and label to patch_size so they align with ``out``.
            _save_overlay_frames(test_save_path + "_frames224cover", case, ind,
                                 zoom(slice_2d, scale, order=1), out,
                                 zoom(label[ind], scale, order=0), classes)
    else:
        net_input = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float().to(device)
        with torch.no_grad():
            out = torch.argmax(torch.softmax(net(net_input), dim=1), dim=1).squeeze(0)
        prediction = out.cpu().detach().numpy()
    if test_save_path is not None:
        # Saving mode skips metrics; placeholders keep one entry per class.
        metric_list = [[0, 0] for _ in range(1, classes)]
    else:
        metric_list = [calculate_metric_percase(prediction == i, label == i) for i in range(1, classes)]
    if save_volume and test_save_path is not None:
        _save_volume_outputs(image, prediction, label, classes, test_save_path, case, z_spacing)
    return metric_list


# The name starts with ``test_`` but this is an evaluation helper, not a pytest
# test; without this flag pytest would collect it and fail on missing fixtures.
test_single_volume.__test__ = False