用sklearn.metrics 在测试集画ROC

进行了一个完整的二分类的测试集的预测,用了已经训练好的model,加载后进行的


if __name__ == '__main__':
    #from utils import read_split_data, train_one_epoch, evaluate
    from torch.utils.data import DataLoader
    import torch
    import torch.optim as optim
    from torch.utils.tensorboard import SummaryWriter
    from torchvision import transforms
    import torch.optim.lr_scheduler as lr_scheduler
    from model import efficientnet_b0 as create_model
    from my_dataset import MyDataSet
    from tqdm import tqdm
    import os
    import sys
    import json
    import pickle
    import random

    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    ls3 = []
    ls3_label = []
    for root, dirs, files in os.walk("..\\dataset\\ff++edge\\valid\\fake"):
        nums = 0
        for f in files:
            if nums%1==0:
                ls3.append(os.path.join(root, f))
                ls3_label.append(0)
            nums += 1
    ls4 = []
    ls4_label = []
    for root, dirs, files in os.walk("..\\dataset\\ff++edge\\valid\\real"):
        nums = 0
        for f in files:
            if nums%1 == 0:
                ls4.append(os.path.join(root, f))
                ls4_label.append(1)
            nums += 1

    val_images_path = ls3 + ls4
    val_images_label = ls3_label+ ls4_label

    img_size ={"B0": 224,"B1": 240,"B2": 260,"B3": 300,"B4": 380, "B5": 456,"B6": 528,"B7": 600}
    num_model = "B0"

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model]),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(img_size[num_model]),
                                   transforms.CenterCrop(img_size[num_model]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])

    val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=128,shuffle=False,pin_memory=True,num_workers=0,collate_fn=val_dataset.collate_fn)


    model = create_model(num_classes=2).to(device)
    model_weight_path = "weightsff++edge/model-21.pth"

    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    def evaluate(model, data_loader, device):
        model.eval()
        # 验证样本总个数
        total_num = len(data_loader.dataset)
        # 用于存储预测正确的样本个数
        sum_num = torch.zeros(1).to(device)
        data_loader = tqdm(data_loader, file=sys.stdout)
        pre_label = []
        pre_label2=[]
        with torch.no_grad():#强制之后的内容不进行图构建
            for step, data in enumerate(data_loader):
                images, labels = data

                pred = model(images.to(device))
                
                output = torch.squeeze(pred)
                pre_label2.extend(output.tolist())

                pred1 = torch.max(pred, dim=1)[1]  # 最终结果就是返回最大值的索引值
                pre_label.extend(pred1.tolist())
                
                sum_num += torch.eq(pred1, labels.to(device)).sum()

            return sum_num.item() / total_num,pre_label,pre_label2
    acc,pre_label,pre_label2 = evaluate(model=model,data_loader=val_loader,device=device)
    print(len(pre_label2))
    print(pre_label2[:10])#01数组
    
    #不会在batch里用softmax,就这样改写了
    pre_label3 = [torch.softmax(torch.tensor(i), dim=0).numpy()[1] for i in pre_label2]#转tensor才能用softmax,然后再转回非tensor,再取0,1预测的1的预测值(小数)
    print(pre_label3[:10])#1的预测值数组

    print('acc',acc)

    from sklearn.metrics import roc_curve, auc
    import matplotlib.pyplot as plt
    y_label =val_images_label 
    y_pre = pre_label3
    fpr, tpr, thersholds = roc_curve(y_label, y_pre)#用1的预测

    roc_auc = auc(fpr, tpr)
    print('ROC (area = {0:.4f})'.format(roc_auc))

    plt.plot(fpr, tpr, 'k--', label='ROC (area = {0:.4f})'.format(roc_auc), lw=2)

    plt.xlim([-0.05, 1.05])  # 设置x、y轴的上下限,以免和边缘重合,更好的观察图像的整体
    plt.ylim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')  # 可以使用中文,但需要导入一些库即字体
    plt.title('ROC Curve')
    plt.legend(loc="lower right")
    plt.show()

后续再补充混淆矩阵和其他的

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值