sklearn.metrics.roc_auc报错ValueError: unknown format is not supported

好记性不如烂笔头,之前踩的坑,不记下来,还是会掉进去爬不出来。

在使用sklearn.metrics.roc_auc绘制roc曲线时,报错ValueError: unknown format is not supported。查了好多资料,发现是输入数据的type为unknown导致的。不过这里的type不是用type函数打印的,而是sklearn.utils.multiclass.type_of_target。具体看下面代码。

import torch
import torch.nn as nn
import os
import numpy as np
import SimpleITK as sitk
from iantsen_dataset import HecktorDataset
from torch.utils.data import DataLoader
from DMCTNet_noGC import DMCTNet_noGC
from scipy import interp
from pathlib import Path
import matplotlib.pyplot as plt
from itertools import cycle
import transforms
from sklearn.metrics import roc_curve, auc, f1_score, precision_recall_curve, average_precision_score
from sklearn.utils.multiclass import type_of_target

os.environ['CUDA_VISIBLE_DEVICES'] = "0"

resample_data_path = r''  # 测试集路径
test_weights_path = r''  # 训练好的模型参数路径
num_class = 1  # 类别数量  
gpu = "cuda:0"

def test(model, test_path):
    # 加载测试集和预训练模型参数

    path_to_imgs = Path(resample_data_path)
    # patients只有一个CHUP015
    patients = [p for p in os.listdir(path_to_imgs) if os.path.isdir(path_to_imgs / p)]

    test_paths = []
    for p in patients:
        path_to_ct = path_to_imgs / p / (p + '_ct.nii.gz')
        path_to_pt = path_to_imgs / p / (p + '_pt.nii.gz')
        path_to_gtvt = path_to_imgs / p / (p + '_gtvt.nii.gz')
        test_paths.append((path_to_ct, path_to_pt))

    val_transforms = transforms.Compose([
        transforms.NormalizeIntensity(),
        transforms.ToTensor(mode='test')
    ])
    output_transform = transforms.Compose([
        transforms.transform_back(mode='test')
    ])

    test_set = HecktorDataset(test_paths, transforms=val_transforms, mode='test')
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    
    model = torch.nn.DataParallel(model).cuda()
    checkpoint = torch.load(test_path)
    model.load_state_dict(checkpoint, strict=True)
    model.eval()

    # 上面部分就是加载数据,加载模型

    score_list = []  # 存储预测得分
    label_list = []  # 存储真实标签
    for sample in test_loader:
        inputs = sample['input'].cuda()
        outputs = model(inputs)
	# score_array shape:[1, 1, 144,144, 72]
    score_array = outputs.detach().cpu().numpy()
    label_itk = sitk.ReadImage(str(path_to_gtvt))
    # label_array shape:[144, 144, 72]
    label_array = sitk.GetArrayFromImage(label_itk)
    score_array = np.squeeze(score_array)
    # label_array type: <class 'numpy.ndarray'>
    # score_array type: <class 'numpy.ndarray'>
    print('label_array type:', type(label_array))
    print("score_array type:", type(score_array))  
    # type of score: unknown
    # type of label: unknown
    print('type of score:', type_of_target(score_array))
    print('type of label:', type_of_target(label_array))
    
    score_array = score_array.reshape(-1, 1)
    label_array = label_array.reshape(-1, 1)
    # score_array shape: (2985984, 1)
    # label_array shape: (2985984, 1)
    print('score_array shape:', score_array.shape)
    print('label_array shape:', label_array.shape)
    
	# label_array type: <class 'numpy.ndarray'>
    # score_array type: <class 'numpy.ndarray'>
    # type of score: continuous
    # type of label: binary
    print('label_array type:', type(label_array))
    print("score_array type:", type(score_array))  
    print('type of score:', type_of_target(score_array))
    print('type of label:', type_of_target(label_array))

由于我的需求是基于已经训练好的模型,绘制测试集上的ROC曲线。以一个样本为例,首先加载测试集,加载训练模型,利用训练模型生成预测结果outputs。这个过程是在gpu上实现的,此时的数据类型是tensor。利用cpu()和numpy()函数将数据转移到cpu上,并转换成numpy格式score_array。此时,score_array的大小为1×1×144×144×144。然后加载label,label的大小为144×144×72。roc_curve的使用需要保证score_array和label_array大小一致,因此利用numpy.squeeze()将socre_array大小变为144×144×72。此时直接使用roc_curve函数会报上述错误。我们分别使用type和type_of_target打印score_array和labe_array的数据类型。可以看到type输出结果为numpy.ndarray,而type_of_target输出结果为unknown。

查了一些资料,方法大多是将输出转换为numpy或者list格式,还有就是使用astype函数将结果设置为int类型。这些方法都尝试过了,对我来说,没用。然后去看官方文档,看到说score_array和label_array的大小是样本数。心里有点儿疑惑,144×144×72的shape满足样本数的格式要求吗?然后想着,死马当做活马医,试一下吧,利用reshape函数将shape转换成了n×1的格式,再次打印type_of_target,居然变成了continuous和binary。本来不抱希望的,没想到居然有用,可见不管自己觉得多离谱,想到了就得试试。

代码还没改完就来记下这个坑,以后再也不能掉进来了。博客仅是记录以防后续踩坑,如有不严谨或者分析不对的地方,欢迎大家批评指正。(我要滚去改代码了。。。。)

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值