iris 数据分类

  转载:

#coding=utf-8
"""
#演示目的:利用鸢尾花数据集画出P-R曲线,mooc
"""
print(__doc__)

import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm, datasets
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split

#from sklearn.cross_validation import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
#以iris数据为例,画出P-R曲线
iris =datasets.load_iris()
x =iris.data
y =iris.target
#print(x)
print(x.shape,y.shape)
y=label_binarize(y,classes=[0,1,2])
n_classes=y.shape[1]
print(y.shape)
'''
#print(y)
print(x.shape,y.shape)
'''
random_state=np.random.RandomState(0)

#print(random_state)
n_samples,n_features=x.shape# 分别是行数和列数
print("n_samples:",n_samples,"n_features:",n_features)
x=np.c_[x,random_state.randn(n_samples,200*n_features)]  # The different of randn and random 
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=.5,random_state=random_state)
#print("x_train",x_train,"x_test:",x_test)
classifier = OneVsRestClassifier(svm.SVC(kernel='linear',probability=True,random_state=random_state))
y_socre = classifier.fit(x_train,y_train).decision_function(x_test)
#print("classifier:",classifier)
#print("y_socre:",y_socre)
precision=dict()
recall=dict()
average_precision=dict()

for i in range(n_classes):
    precision[i],recall[i],_= precision_recall_curve(y_test[:,i],y_socre[:,i])
    average_precision[i]= average_precision_score(y_test[:,i],y_socre[:,i])

precision["micro"],recall["micro"],_=precision_recall_curve(y_test.ravel(),y_socre.ravel())
average_precision["micro"] =average_precision_score(y_test,y_socre,average="micro")

plt.clf()
plt.plot(recall["micro"],precision["micro"],label='micro-average Precision-recall'.format(average_precision["micro"]))

for i in range(n_classes):
    plt.plot(recall[i], precision[i],
             label='Precision-recall curve of class {0} (area = {1:0.2f})'.format(i, average_precision[i]))


plt.legend(loc="lower right")#legend 是用于设置图例的函数
plt.show()

 

 

 

 

 

 

 

 

 

 

https://www.icourse163.org/learn/HIT-1206320802?tid=1206635203#/learn/content?type=detail&id=1212706083&sm=1

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值