绘制ROC曲线

ROC曲线绘制


适用于二分类任务。
前提条件:预训练好的模型和二分类的数据集(代码中的casia1_t就是数据集,最好和预训练模型使用的数据集同源,casia1下有两个文件夹authentic和tampered,代表两个类别)。注意label_mapping = {‘authentic’: 1, ‘tampered’: 0}这和预训练时的分类标签要一样。

import torch
import os
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from sklearn.metrics import roc_curve, auc


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
main_folder = 'casia1_t'
subfolders = [f.path for f in os.scandir(main_folder) if f.is_dir()]
class_labels = {}

for i, subfolder in enumerate(subfolders):
    class_label = os.path.basename(subfolder)  
    class_labels[i] = class_label
    
print(class_labels)
y_true = []

for root, dirs, files in os.walk(main_folder):
    class_label = os.path.basename(root)
    for file in files:
        y_true.append(class_label)
label_mapping = {'authentic': 1, 'tampered': 0}
y_true_binary = [label_mapping[label] for label in y_true]       
print(f"y_true_binary:{y_true_binary}")


DATA_DIR = "casia1_t"
transform = transforms.Compose([transforms.Resize([128, 128]), transforms.ToTensor()])
dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
batch_size = 32 
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

model = torch.load('checkpoints-model/srm_casia1_100_128_005.pth')
model.eval() 
model = model.to(device)
predictions = []
for batch in data_loader:
    images, _ = batch 
    images = images.to(device)
    with torch.no_grad():
        output = model(images)
        probabilities = F.softmax(output, dim=1)
        predicted_classes = torch.argmax(probabilities, dim=1).cpu().numpy()
        predictions.extend(predicted_classes)
    
# y_score = probabilities[0].cpu().detach().numpy()
print(f"predictions:{predictions}")
fpr, tpr, _ = roc_curve(y_true_binary,predictions)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = {:.2f})'.format(roc_auc))
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC)')
plt.legend(loc='lower right')
plt.show()
plt.savefig('picture/AUC-ROC.png')
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值