Iris Data (Plotting ROC Curves for Different Classifiers)

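This section uses the well-known Iris data to show how to plot ROC curves. The Iris dataset is a classic dataset that is often used as an example in statistical learning and machine learning. It contains 150 records in 3 classes, with 50 records per class. Each record has 4 features: sepal length, sepal width, petal length, and petal width. These 4 features can be used to predict which species a flower belongs to (iris-setosa, iris-versicolour, or iris-virginica).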
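The dataset description above can be checked quickly. The sketch below assumes scikit-learn's bundled copy of the data (`sklearn.datasets.load_iris`) instead of the local `iris.data` file the article reads. Note that scikit-learn spells the class names `setosa`, `versicolor`, and `virginica`.

```python
import numpy as np
from sklearn.datasets import load_iris

# Assumption: use scikit-learn's bundled Iris data instead of a local 'iris.data' file
iris = load_iris()
X, y = iris.data, iris.target

print(X.shape)                    # (150, 4): 150 records, 4 features each
print(np.bincount(y))             # [50 50 50]: 50 records per class
print(iris.feature_names)         # sepal length/width, petal length/width (cm)
print(iris.target_names.tolist()) # ['setosa', 'versicolor', 'virginica']
```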
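This section classifies the dataset with three kinds of classifiers: KNN, logistic regression, and SVM (the last with both a linear and an RBF kernel). To keep the task from being too easy, the code uses only the first two features (sepal length and sepal width). The code is as follows: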

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegressionCV
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn import metrics
from itertools import cycle


if __name__ == '__main__':
    np.random.seed(0)
    pd.set_option('display.width', 300)
    np.set_printoptions(suppress=True)
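    # iris.data: comma-separated UCI Iris file without a header row; column 4 holds the class name
    data = pd.read_csv('iris.data', header=None)
    # Map the class names to integer labels 0, 1, 2
    iris_types = data[4].unique()
    for i, iris_type in enumerate(iris_types):
        data.loc[data[4] == iris_type, 4] = i
    # Use only the first two features (sepal length and sepal width)
    x = data.iloc[:, :2]
    print(x)
    y = data.iloc[:, -1].astype(int)
    c_number = np.unique(y).size
    # 60% of the samples for training, 40% for testing
    x, x_test, y, y_test = train_test_split(x, y, train_size=0.6, random_state=0)
    # One-hot encode the test labels: shape (n_test, c_number)
    y_one_hot = label_binarize(y_test, classes=np.arange(c_number))
    # 20 log-spaced candidate values from 0.01 to 100 for C (and gamma)
    alpha = np.logspace(-2, 2, 20)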
    models = [
        ['KNN', KNeighborsClassifier(n_neighbors=7)],
        ['LogisticRegression', LogisticRegressionCV(Cs=alpha, penalty='l2', cv=3)],
        ['SVM(Linear)', GridSearchCV(SVC(kernel='linear', decision_function_shape='ovr'), param_grid={'C': alpha})],
        ['SVM(RBF)', GridSearchCV(SVC(kernel='rbf', decision_function_shape='ovr'), param_grid={'C': alpha, 'gamma': alpha})]]
    colors = cycle('gmcr')
    plt.figure(figsize=(7, 6), facecolor='w')
    for (name, model), color in zip(models, colors):
        model.fit(x, y)
        # Print the hyperparameters selected by cross-validation
        if hasattr(model, 'C_'):
            print(model.C_)
        if hasattr(model, 'best_params_'):
            print(model.best_params_)
        # KNN and logistic regression provide class probabilities;
        # SVC (probability=False) only provides decision_function scores
        if hasattr(model, 'predict_proba'):
            y_score = model.predict_proba(x_test)
        else:
            y_score = model.decision_function(x_test)
        # Flatten the (n_test, c_number) arrays: one micro-averaged ROC curve over all classes
        fpr, tpr, thresholds = metrics.roc_curve(y_one_hot.ravel(), y_score.ravel())
        auc = metrics.auc(fpr, tpr)
        print(auc)
        plt.plot(fpr, tpr, c=color, lw=2, alpha=0.7, label='%s, AUC=%.3f' % (name, auc))
    # Diagonal reference line: a classifier that guesses at random
    plt.plot((0, 1), (0, 1), c='#808080', lw=2, ls='--', alpha=0.7)
    plt.xlim((-0.01, 1.02))
    plt.ylim((-0.01, 1.02))
    plt.xticks(np.arange(0, 1.1, 0.1))
    plt.yticks(np.arange(0, 1.1, 0.1))
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    plt.grid(True, ls=':')
    plt.legend(loc='lower right', fancybox=True, framealpha=0.8, fontsize=12)
    plt.title('ROC and AUC of Different Classifiers on the Iris Data', fontsize=17)
    plt.show()
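The loop above flattens both the one-hot labels and the score matrix with `ravel()` before calling `metrics.roc_curve`, so every (sample, class) pair is treated as one binary decision. This is the micro-averaged ROC. The sketch below checks that this gives the same number as `metrics.roc_auc_score(..., average='micro')`, and also prints the one-vs-rest AUC of each class. It assumes scikit-learn's bundled Iris data and a plain `LogisticRegression(max_iter=1000)` as a stand-in for the article's `LogisticRegressionCV`.

```python
import numpy as np
from sklearn import metrics
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize

# Assumption: bundled Iris data, first two features only (as in the article's code)
X, y = load_iris(return_X_y=True)
X = X[:, :2]
x_train, x_test, y_train, y_test = train_test_split(X, y, train_size=0.6, random_state=0)
y_one_hot = label_binarize(y_test, classes=np.arange(3))  # shape (n_test, 3)

# Assumption: plain LogisticRegression stands in for LogisticRegressionCV
model = LogisticRegression(max_iter=1000).fit(x_train, y_train)
y_score = model.predict_proba(x_test)                      # shape (n_test, 3)

# Micro average: flatten so each (sample, class) pair is one binary decision
fpr, tpr, _ = metrics.roc_curve(y_one_hot.ravel(), y_score.ravel())
auc_micro = metrics.auc(fpr, tpr)

# The same quantity computed directly by roc_auc_score
auc_micro_direct = metrics.roc_auc_score(y_one_hot, y_score, average='micro')

# One-vs-rest AUC per class, and their unweighted (macro) mean
auc_per_class = [metrics.roc_auc_score(y_one_hot[:, k], y_score[:, k]) for k in range(3)]
auc_macro = float(np.mean(auc_per_class))

print(f"micro AUC (ravel + auc):   {auc_micro:.3f}")
print(f"micro AUC (roc_auc_score): {auc_micro_direct:.3f}")
print("per-class AUC:", np.round(auc_per_class, 3), "macro AUC:", round(auc_macro, 3))
```

The micro average weights every (sample, class) pair equally, while the macro average weights every class equally; with balanced classes such as Iris the two are usually close.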

The final result is:

[Figure: ROC curves and AUC values of the different classifiers on the Iris data]
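Each point on these curves corresponds to one score threshold: samples scoring at or above it are called positive, giving one (FPR, TPR) pair. To make that concrete, here is a from-scratch NumPy sketch of the idea, with AUC computed by the trapezoidal rule. The helper names `roc_points` and `auc_trapezoid` are hypothetical; this is not scikit-learn's implementation (its `roc_curve` additionally drops intermediate collinear points by default).

```python
import numpy as np

def roc_points(y_true, y_score):
    """Hypothetical helper: (fpr, tpr) at every distinct threshold, highest first.

    Assumes binary labels (1 = positive, 0 = negative) with both classes present.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    order = np.argsort(-y_score, kind="mergesort")  # sort scores in descending order
    y_sorted, s_sorted = y_true[order], y_score[order]
    # Cumulative TP/FP counts when everything up to index i is called positive
    tps = np.cumsum(y_sorted == 1)
    fps = np.cumsum(y_sorted == 0)
    # Tied scores share one threshold: keep only the last index of each tie group
    last_of_group = np.r_[np.diff(s_sorted) != 0, True]
    tps, fps = tps[last_of_group], fps[last_of_group]
    # Prepend the (0, 0) point (threshold above every score) and normalize
    tpr = np.r_[0, tps] / tps[-1]
    fpr = np.r_[0, fps] / fps[-1]
    return fpr, tpr

def auc_trapezoid(fpr, tpr):
    """Hypothetical helper: area under the ROC curve by the trapezoidal rule."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))

fpr, tpr = roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(fpr, tpr)
print(auc_trapezoid(fpr, tpr))  # 0.75
```

Grouping tied scores matters: if a positive and a negative sample have the same score, the curve moves diagonally between the two thresholds, which is exactly the random-guess contribution of 0.5 for that pair.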
