# 转载 (Reposted)
#coding=utf-8
"""
#演示目的:利用鸢尾花数据集画出P-R曲线,mooc
"""
# Demo (from a MOOC): plot precision-recall (P-R) curves on the iris dataset.
# The module docstring above is printed verbatim at startup.
import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm, datasets
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
# (train_test_split formerly lived in sklearn.cross_validation, which was
# removed in scikit-learn 0.20.)
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier

print(__doc__)

# --- Data: iris, labels one-hot encoded for one-vs-rest P-R curves ---
iris = datasets.load_iris()
x = iris.data
y = iris.target
print(x.shape, y.shape)

# label_binarize turns the class ids into a 0/1 indicator matrix with one
# column per class; precision_recall_curve is then applied column by column.
y = label_binarize(y, classes=[0, 1, 2])
n_classes = y.shape[1]
print(y.shape)

random_state = np.random.RandomState(0)
n_samples, n_features = x.shape  # rows (samples) and columns (features)
print("n_samples:", n_samples, "n_features:", n_features)

# Append 200 * n_features columns of noise so the classifier has to cope
# with many uninformative features. randn draws from N(0, 1), whereas
# np.random.random would draw from U[0, 1).
x = np.c_[x, random_state.randn(n_samples, 200 * n_features)]

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=.5, random_state=random_state)

# --- Model: one linear SVM per class; its decision_function is the score ---
# probability=True is not set: only decision_function() is used, and enabling
# probability estimates makes fit() run an extra internal cross-validation
# per class just to calibrate predict_proba().
classifier = OneVsRestClassifier(
    svm.SVC(kernel='linear', random_state=random_state))
y_score = classifier.fit(x_train, y_train).decision_function(x_test)

# --- Metrics: per-class and micro-averaged precision / recall / AP ---
precision = {}
recall = {}
average_precision = {}
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(
        y_test[:, i], y_score[:, i])
    average_precision[i] = average_precision_score(
        y_test[:, i], y_score[:, i])

# Micro-average: pool every (sample, class) decision into one binary problem.
precision["micro"], recall["micro"], _ = precision_recall_curve(
    y_test.ravel(), y_score.ravel())
average_precision["micro"] = average_precision_score(
    y_test, y_score, average="micro")

# --- Plot ---
plt.clf()
# Include the micro-averaged AP in the label, matching the per-class labels.
plt.plot(recall["micro"], precision["micro"],
         label='micro-average Precision-recall curve (area = {0:0.2f})'
         .format(average_precision["micro"]))
for i in range(n_classes):
    plt.plot(recall[i], precision[i],
             label='Precision-recall curve of class {0} (area = {1:0.2f})'
             .format(i, average_precision[i]))
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.title('Precision-Recall curves (one-vs-rest, iris + noise features)')
plt.legend(loc="lower right")  # legend() builds the legend from the labels
plt.show()