画轮廓曲线

# Overlay colours, one triple per foreground class (class 1 uses index 0).
# The triples are written unchanged into the arrays handed to cv2.imwrite.
_OVERLAY_COLORS = (
    (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
    (255, 0, 255), (0, 255, 255), (244, 208, 63), (234, 240, 241),
)


def _model_device(net):
    """Return the device of ``net``'s first parameter, or CUDA if it has none."""
    try:
        return next(net.parameters()).device
    except (AttributeError, StopIteration):
        # No parameters to inspect: keep the historical behaviour of
        # sending inputs to the current CUDA device.
        return torch.device("cuda")


def _segment_slice(net, array_2d, device):
    """Run ``net`` on one 2D array and return the per-pixel argmax as a NumPy array."""
    # Two unsqueezes add the leading dimensions the network is called with.
    batch = torch.from_numpy(array_2d).unsqueeze(0).unsqueeze(0).float().to(device)
    with torch.no_grad():
        scores = torch.softmax(net(batch), dim=1)
        labels = torch.argmax(scores, dim=1).squeeze(0)
    return labels.cpu().detach().numpy()


def _save_overlay_frames(image_slice, pred_patch, label_slice, classes,
                         patch_size, test_save_path, case, ind):
    """Write the prediction and ground-truth colour overlays of one slice as PNGs."""
    # Imported lazily so the dependencies are only needed when frames are
    # actually written.
    import os
    import cv2

    frame_dir = test_save_path + "_frames224cover"
    os.makedirs(frame_dir, exist_ok=True)
    print("test_save_frame_path", frame_dir)

    rows, cols = image_slice.shape[0], image_slice.shape[1]
    factors = (patch_size[0] / rows, patch_size[1] / cols)
    # The image and the label are resampled to the patch size so they line
    # up with ``pred_patch``, which is the raw network output.
    gray = np.rot90(zoom(image_slice, factors, order=1), 3)
    pred = np.rot90(pred_patch, 3)
    truth = np.rot90(zoom(label_slice, factors, order=0), 3)

    # Grey image replicated over three channels and scaled by 255
    # (assumes intensities lie in [0, 1] -- TODO confirm against the loader).
    base = np.repeat(gray[:, :, np.newaxis], 3, axis=2) * 255
    pred_overlay = base.copy()
    truth_overlay = base.copy()
    for cls in range(1, classes):
        # Cycle through the table so more classes than colours does not
        # raise IndexError.
        color = _OVERLAY_COLORS[(cls - 1) % len(_OVERLAY_COLORS)]
        pred_overlay[pred == cls] = color
        truth_overlay[truth == cls] = color

    stem = frame_dir + '/' + case
    suffix = str(ind).zfill(3) + '.png'
    cv2.imwrite(stem + '_imgprd_slice' + suffix, pred_overlay.astype(np.uint8))
    cv2.imwrite(stem + '_imglab_slice' + suffix, truth_overlay.astype(np.uint8))


def test_single_volume(image, label, net, classes, patch_size=(256, 256), test_save_path=None, case=None, z_spacing=1):
    """Segment one case with ``net`` and return per-class metrics.

    ``image`` and ``label`` are tensors; dimension 0 is squeezed away and
    both are converted to NumPy arrays. If the squeezed image has three
    dimensions it is processed one slice at a time along its first axis:
    each slice is resampled to ``patch_size`` when its size differs, passed
    through ``net``, and the argmax prediction is resampled back to the
    slice's size. Otherwise the whole squeezed image is passed through
    ``net`` in a single call.

    Inputs are moved to the device of ``net``'s first parameter, falling
    back to CUDA when ``net`` exposes no parameters. ``net`` is switched to
    evaluation mode.

    Args:
        image: Input tensor for one case.
        label: Ground-truth tensor, laid out like ``image``.
        net: Model to evaluate; called on a float tensor with two leading
            singleton dimensions.
        classes: Number of classes; classes ``1 .. classes - 1`` are
            reported and drawn.
        patch_size: ``(height, width)`` that slices are resampled to before
            being passed to ``net``.
        test_save_path: When not ``None``, colour-overlay PNGs are written
            to the directory ``test_save_path + "_frames224cover"`` for
            every slice whose label sum is non-zero, and metric computation
            is skipped.
        case: Case identifier used as the file-name prefix of saved frames;
            must be a string when frames are written.
        z_spacing: Unused; kept so existing callers keep working.

    Returns:
        A list with one entry per class ``1 .. classes - 1``: the result of
        ``calculate_metric_percase`` for that class, or the placeholder
        ``[0, 0]`` when ``test_save_path`` is given.
    """
    image = image.squeeze(0).cpu().detach().numpy()
    label = label.squeeze(0).cpu().detach().numpy()
    device = _model_device(net)
    net.eval()

    if len(image.shape) == 3:
        prediction = np.zeros_like(label)
        for ind in range(image.shape[0]):
            image_slice = image[ind, :, :]
            rows, cols = image_slice.shape[0], image_slice.shape[1]
            needs_resize = rows != patch_size[0] or cols != patch_size[1]
            if needs_resize:
                net_input = zoom(image_slice, (patch_size[0] / rows, patch_size[1] / cols), order=3)
            else:
                net_input = image_slice

            out = _segment_slice(net, net_input, device)
            if needs_resize:
                # order=0 keeps the resampled values to existing class ids.
                prediction[ind] = zoom(out, (rows / patch_size[0], cols / patch_size[1]), order=0)
            else:
                prediction[ind] = out

            # Frames are only written for slices whose label sum is non-zero.
            if test_save_path is not None and np.sum(label[ind]) != 0:
                _save_overlay_frames(image_slice, out, label[ind], classes,
                                     patch_size, test_save_path, case, ind)
    else:
        prediction = _segment_slice(net, image, device)

    if test_save_path is not None:
        # Saving mode reports placeholders instead of computing metrics.
        return [[0, 0] for _ in range(1, classes)]
    return [calculate_metric_percase(prediction == i, label == i) for i in range(1, classes)]


# Despite the ``test_`` prefix this is an evaluation helper, not a unit test;
# this stops pytest from collecting it when a test module imports it.
test_single_volume.__test__ = False

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值