#coding:utf8
'''
Created on 2018年8月2日
@author: Administrator
'''
%matplotlib inline
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score,recall_score,precision_score
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def getPointInCle(r=1):
    """Return one random 2-D point ``[x, y]`` around the origin.

    Parameters
    ----------
    r : float or sequence of two floats, default 1
        * scalar -- the point lies inside the disc of radius ``r``.
        * ``[r_min, r_max]`` (list, tuple or 1-D array) -- the point lies in
          the ring ``r_min <= sqrt(x**2 + y**2) <= r_max``.

    Returns
    -------
    list
        A plain two-element ``list`` (not an array), so callers can append a
        label with ``getPointInCle(...) + [label]``.

    Notes
    -----
    Neither mode samples uniformly by area: the ring mode draws the radius
    uniformly, and the disc mode draws ``x`` uniformly.
    """
    if np.ndim(r) > 0:
        # Ring: draw an angle in [0, pi) and a radius in [r_min, r_max], then
        # flip the sign of y at random so the whole circle is covered.
        angle = np.random.random() * np.pi
        radius = np.random.random() * (r[1] - r[0]) + r[0]
        x = radius * np.cos(angle)
        # radius*sin(angle) is what x*tan(angle) computed, without going
        # through tan() and its singularity at angle == pi/2.
        y = radius * np.sin(angle) * (-1) ** np.random.randint(10)
        return [x, y]
    # Disc: x is uniform in [-r, r]; |y| is then bounded by sqrt(r**2 - x**2)
    # so the point stays inside the circle of radius r.
    x = np.random.random() * r * (-1) ** np.random.randint(10)
    y = np.sqrt((r ** 2 - x ** 2) * np.random.random()) * (-1) ** np.random.randint(10)
    return [x, y]
if __name__ == '__main__':
    print("getting data ...")
    # Class 0: 500 points inside the unit disc; class 1: 500 points in the
    # ring 2 <= radius <= 3. Each row is [x, y, label].
    aPoints = np.array([getPointInCle(r=1) + [0] for _ in range(500)])
    bPoints = np.array([getPointInCle(r=[2, 3]) + [1] for _ in range(500)])
    plt.scatter(aPoints[:, 0], aPoints[:, 1])
    plt.scatter(bPoints[:, 0], bPoints[:, 1])
    plt.show()

    print("visualizing RBF ...")
    # sv, the variance of every point's distance from the origin, plays the
    # role of sigma**2 in the lift z = exp(-(x**2 + y**2) / (2*sv)).
    allXY = np.vstack([aPoints[:, :2], bPoints[:, :2]])
    sv = np.var(np.hypot(allXY[:, 0], allXY[:, 1]))

    def rbfLift(points):
        """Return rows [x, y, exp(-(x**2 + y**2) / (2*sv))] for each point."""
        xs, ys = points[:, 0], points[:, 1]
        return np.column_stack([xs, ys, np.exp(-(xs ** 2 + ys ** 2) / (2 * sv))])

    aRPoints = rbfLift(aPoints)
    bRPoints = rbfLift(bPoints)
    fig = plt.figure()
    # Since matplotlib 3.6, Axes3D(fig) no longer adds itself to the figure
    # and fig.gca() no longer accepts projection=...; add_subplot works on
    # old and new versions (importing Axes3D registers the '3d' projection).
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(aRPoints[:, 0], aRPoints[:, 1], aRPoints[:, 2])
    ax.scatter(bRPoints[:, 0], bRPoints[:, 1], bRPoints[:, 2])
    # Flat reference plane at the constant height z = 0.2, drawn on the same
    # axes as the lifted points.
    xx, yy = np.meshgrid(range(-3, 4), range(-3, 4))
    ax.plot_surface(xx, yy, np.full(xx.shape, 0.2), alpha=0.2)
    plt.show()

    print("splitting train set and test set ...")
    # First 400 points of each class train the model; the last 100 test it.
    trainX = np.vstack([aPoints[:400, 0:2], bPoints[:400, 0:2]])
    testX = np.vstack([aPoints[400:, 0:2], bPoints[400:, 0:2]])
    trainY = np.concatenate([aPoints[:400, 2], bPoints[:400, 2]])
    testY = np.concatenate([aPoints[400:, 2], bPoints[400:, 2]])

    print("building model ...")
    myModel = SVC(kernel='rbf')
    # myModel = SVC(kernel='linear')  # uncomment to see the linear-kernel result
    print("training model ...")
    myModel.fit(trainX, trainY)
    print("predicting test set ...")
    yPre = myModel.predict(testX)
    print("estimating ...")
    # sklearn metrics take (y_true, y_pred): passing them reversed silently
    # swaps recall and precision. pos_label (not labels) selects the positive
    # class for binary scoring.
    print("acc:", accuracy_score(testY, yPre))
    print("recall:", recall_score(testY, yPre, pos_label=1))
    print("precision:", precision_score(testY, yPre, pos_label=1))
# 带核SVM分类器分类结果及分类边界示意图如下:
# SVM对该随机数据处理的结果如下:
# acc: 1.0
# recall: 1.0
# precision: 1.0