ValueError: multiclass format is not supported
找到报错位置:
def compute_auc(pred, label):
if isinstance(pred, torch.Tensor):
pred = pred.cpu().detach().numpy().flatten()
if isinstance(label, torch.Tensor):
label = label.cpu().detach().numpy().flatten()
fpr, tpr, thresholds = metrics.roc_curve(label, pred)#报错位置
return metrics.auc(fpr, tpr)
报错原因:
设置断点debug看看这两个输入数据的内容和类型
报错的原因就是pos_label=None了
仔细查看label的数值,有0, 1,2
,是个多分类,但是roc曲线一般是二分类的,多分类用混淆矩阵来做。我的数据就是二分类,那肯定就是我的label数据有个标签打成了2,改为1就可以了。