plt知识点整理——keras分类测试

plt知识点整理,写的比较全面的博客:https://www.cnblogs.com/zhizhan/p/5615947.html(感谢作者)

*****************predict_generator和evaluate_generator 区别**************

model.predict_generator用于模型预测,常在测试时使用(返回的结果就是概率值)——不需要true label

model.evaluate_generator用于模型评估,常在训练过程中验证集上使用(或者测试过程中求准确率)——需要true label

注意:如果同时使用预测和评估(分别是predict_generator和evaluate_generator),一定要使用reset()来初始化一下predict_generator 和evaluate_generator对象;

如下:

test_generator.reset()
pred = model.predict_generator(test_generator, verbose=1)

********************************下面是keras分类测试程序***************************************

PS:分类结果使用混淆矩阵统计

*********分享是为了更好的进步*********

import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import itertools
import os

def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          save_path='./test_confusion_matrix',
                          cmap=plt.cm.Greens,  # 这个地方设置混淆矩阵的颜色主题,这个主题看着就干净~
                          normalize=True):
    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    # plt.title(title)
    plt.title(title+'\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if normalize:
            plt.text(j, i, "{:0.4f}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
        else:
            plt.text(j, i, "{:,}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    # plt.tight_layout()

    plt.ylabel('True label')
    # plt.xlabel('Predicted label\n accuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    plt.xlabel('Predicted label')

    # plt.text(0, 0, 'accuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))

    # 这里这个savefig是保存图片,如果想把图存在什么地方就改一下下面的路径,然后dpi设一下分辨率即可。
    aa=save_path+'/confusionmatrix.png'
    plt.savefig(save_path+'/confusionmatrix.png')
    plt.savefig('./confusionmatrix.png')

    # plt.show()

if __name__=='__main__':
    #test image directory
    dst_path = 'E:/datasets/Rcam-plusMelangerTaile/8KLSBackWindow/test'
    #model path
    model_file = "E:/datasets/Rcam-plusMelangerTaile/8KLSBackWindow/baseline_Seblock/modelpath/baseline_seblock.h5"
    title=model_file.split('/')[-3]
    save_path= "E:/datasets/Rcam-plusMelangerTaile/8KLSBackWindow/Evaluate_confusion_matrix"
    if not os.path.exists(save_path):
        os.makedirs(save_path
                    )
    batch_size = 8

    # load model
    model = load_model(model_file)
    # generator image
    test_datagen = ImageDataGenerator(rescale=1)

    test_generator = test_datagen.flow_from_directory(
        dst_path,
        target_size=(128, 128),
        batch_size=batch_size,
        shuffle=False
        )

    labels = test_generator.class_indices #查看类别的label
    #然后直接用predice_geneorator 可以进行预测
    test_generator.reset()
    pred = model.predict_generator(test_generator, verbose=1)
    # 输出每个图像的预测类别
    predicted_class_indices = np.argmax(pred, axis=1)
    #测试集的真实类别
    true_label= test_generator.classes

    #使用pd.crosstab来简单画出混淆矩阵
    import pandas as pd
    # table=pd.crosstab(predicted_class_indices,true_label,colnames=['predict'],rownames=['label'])
    table = pd.crosstab(true_label,predicted_class_indices, rownames=['label'],  colnames=['predict'])
    print(table)
    #图片化显示混淆矩阵,非常好看的说,哈哈哈
    conf_mat = confusion_matrix(y_true=true_label, y_pred=predicted_class_indices)
    plt.figure()
    plot_confusion_matrix(conf_mat, normalize=False, target_names=labels, title=title+'_Confusion Matrix',save_path=save_path)

效果图如下:

当然,也可以直接输出混淆矩阵结果:

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值