一、KNN算法
具体算法原理不再细说,此处生成3类,60个点集合,并对预测值进行预测分类。
# -*- coding: utf-8 -*-
"""
Created on Wed May 16 12:09:34 2018
@author: Administrator
"""
from sklearn.datasets.samples_generator import make_blobs
from matplotlib import pyplot as plt
import numpy as np
#生成数据
#中心点
centers=[[-2,2],[2,2],[0,4]]
'''
使用sklearn.datasets.samples_generator包下的make_blobs函数来生成数据集,生成
60个训练样本,这60个训练样本分布在以centers参数指定中心点周围。
cluster_std:为标准差,来指明生成的点分布的松散程度
X:存放生成的训练集
y:存放数据的类标记
'''
X,y=make_blobs(n_samples=60,centers=centers,random_state=0,cluster_std=0.6)
print(X)
print(y)
#画出X,y数据
plt.figure(figsize=(4,4),dpi=144)
c=np.array(centers)
#画出样本
plt.scatter(X[:,0],X[:,1],c=y,s=100,cmap='cool')
#画出中心点
plt.scatter(c[:,0],c[:,1],s=100,marker='^',c='orange')
#使用KNN对算法进行训练
from sklearn.neighbors import KNeighborsClassifier
#模型训练,
k=5
clf=KNeighborsClassifier(n_neighbors=k)
clf.fit(X,y)
#进行预测,要预测的数据为[0,2]
X_sample=[0,2]
y_sample=clf.predict(X_sample)
neighbors=clf.kneighbors(X_sample,return_distance=False)
#画出示意图
plt.figure(figsize=(4,4),dpi=144)
#画出样本
plt.scatter(X[:,0],X[:,1],c=y,s=100,cmap='cool')
#画出中心点
plt.scatter(c[:,0],c[:,1],s=100,marker='^',c='k')
#画出待预测的点
plt.scatter(X_sample[0],X_sample[1],s=100,marker='x',c=y_sample,cmap='cool')
print(X_sample[0])
print(X_sample[1])
for i in neighbors[0]:
plt.plot([X[i][0],X_sample[0]],[X[i][1],X_sample[1]],'k--',linewidth=0.6)
运行结果:
二、KNN回归
# -*- coding: utf-8 -*-
"""
Created on Thu May 17 09:33:12 2018
@author: Administrator
"""
'''
使用KNN进行回归拟合
'''
import numpy as np
from matplotlib import pyplot as plt
from numpy import *
#生成数据集,在余弦曲线的基础上加入了噪声
n_dots=40
X=5*np.random.rand(n_dots,1)
y=cos(X).ravel()
#添加噪声
y+=0.2*np.random.rand(n_dots)-0.1
#使用KNeighborsRegressor来训练模型
from sklearn.neighbors import KNeighborsRegressor
k=5
knn=KNeighborsRegressor(k)
knn.fit(X,y)
#生成足够密集的点进行预测
T=np.linspace(0,5,5000)[:,np.newaxis]
y_pred=knn.predict(T)
print(knn.score(X,y))
#把这些预测点连起来,构成拟合曲线
#画出拟合曲线
plt.figure(figsize=(4,5),dpi=144)
#画出训练曲线
plt.scatter(X,y,c='g',label='data',s=100)
#画出拟合曲线
plt.plot(T,y_pred,c='k',label='prediction',lw=4)
plt.axis('tight')
plt.title('KNeighborRegressor (k=%i)'%k)
plt.show()
运行结果: