分类精度计算细节

    preds = torch.tensor(preds).to('cuda') if not isinstance(preds,torch.Tensor) else preds
    targets = torch.tensor(targets).to('cuda') if not isinstance(targets, torch.Tensor) else targets
    probs = torch.tensor(probs).to('cuda') if not isinstance(probs, torch.Tensor) else probs

    accuracy = torchmetrics.Accuracy(num_classes=2, average='micro').to('cuda')
    lesion_acc = accuracy(preds, targets)
        
    cm = confusion_matrix(targets.squeeze().cpu().numpy(), preds.squeeze().cpu().numpy())

   在计算acc的时候,pred和target需要是torch.Tensor类型,且在cuda上,在计算混淆矩阵的时候,则需要转成numpy,且在cpu上

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值