《scikit-learn常用机器学习算法》学习笔记---KNN算法及KNN回归

一、KNN算法

具体算法原理不再细说,此处生成3类,60个点集合,并对预测值进行预测分类。

# -*- coding: utf-8 -*-
"""
Created on Wed May 16 12:09:34 2018

@author: Administrator
"""

from sklearn.datasets.samples_generator import make_blobs
from matplotlib import pyplot as plt
import numpy as np
#生成数据
#中心点
centers=[[-2,2],[2,2],[0,4]]
'''
使用sklearn.datasets.samples_generator包下的make_blobs函数来生成数据集,生成
60个训练样本,这60个训练样本分布在以centers参数指定中心点周围。
cluster_std:为标准差,来指明生成的点分布的松散程度
X:存放生成的训练集
y:存放数据的类标记
'''
X,y=make_blobs(n_samples=60,centers=centers,random_state=0,cluster_std=0.6)

print(X)
print(y)
#画出X,y数据
plt.figure(figsize=(4,4),dpi=144)
c=np.array(centers)

#画出样本
plt.scatter(X[:,0],X[:,1],c=y,s=100,cmap='cool')
#画出中心点
plt.scatter(c[:,0],c[:,1],s=100,marker='^',c='orange')


#使用KNN对算法进行训练
from sklearn.neighbors import KNeighborsClassifier
#模型训练,
k=5
clf=KNeighborsClassifier(n_neighbors=k)
clf.fit(X,y)

#进行预测,要预测的数据为[0,2]
X_sample=[0,2]
y_sample=clf.predict(X_sample)
neighbors=clf.kneighbors(X_sample,return_distance=False)


#画出示意图
plt.figure(figsize=(4,4),dpi=144)
#画出样本
plt.scatter(X[:,0],X[:,1],c=y,s=100,cmap='cool')
#画出中心点
plt.scatter(c[:,0],c[:,1],s=100,marker='^',c='k')
#画出待预测的点
plt.scatter(X_sample[0],X_sample[1],s=100,marker='x',c=y_sample,cmap='cool')

print(X_sample[0])
print(X_sample[1])

for i in neighbors[0]:
    plt.plot([X[i][0],X_sample[0]],[X[i][1],X_sample[1]],'k--',linewidth=0.6)

运行结果:




二、KNN回归

# -*- coding: utf-8 -*-
"""
Created on Thu May 17 09:33:12 2018

@author: Administrator
"""

'''
使用KNN进行回归拟合
'''


import numpy as np
from matplotlib import pyplot as plt
from numpy import *
#生成数据集,在余弦曲线的基础上加入了噪声

n_dots=40
X=5*np.random.rand(n_dots,1)
y=cos(X).ravel()

#添加噪声
y+=0.2*np.random.rand(n_dots)-0.1

#使用KNeighborsRegressor来训练模型
from sklearn.neighbors import KNeighborsRegressor
k=5
knn=KNeighborsRegressor(k)
knn.fit(X,y)

#生成足够密集的点进行预测
T=np.linspace(0,5,5000)[:,np.newaxis]
y_pred=knn.predict(T)
print(knn.score(X,y))

#把这些预测点连起来,构成拟合曲线
#画出拟合曲线
plt.figure(figsize=(4,5),dpi=144)
#画出训练曲线
plt.scatter(X,y,c='g',label='data',s=100)
#画出拟合曲线
plt.plot(T,y_pred,c='k',label='prediction',lw=4)

plt.axis('tight')
plt.title('KNeighborRegressor (k=%i)'%k)
plt.show()

运行结果:


阅读更多
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

不良信息举报

《scikit-learn常用机器学习算法》学习笔记---KNN算法及KNN回归

最多只允许输入30个字

加入CSDN,享受更精准的内容推荐,与500万程序员共同成长!
关闭
关闭