k-近邻算法(kNN),它的工作原理是:存在若干个样本数据集,并且我们已知这些样本数据集的标签,现在输入一个新的样本数据集,标签未知。我们通过计算新的样本数据集与其他已知标签的数据集之间的距离来确定新样本数据集的标签。一般距离计算采用欧氏距离。我们通过计算欧氏距离得到新样本数据集与其他已知标签的数据集之间的距离之后,通过k值,即取前k个最小距离,再判断前k个最小距离当中,哪个标签的出现频率最高,就将新数据集的标签与之划等号。
以下是通过python实现k近邻算法来判断一部已知打斗镜头和搞笑镜头的电影到底是属于动作片还是属于喜剧片。矩阵group代表6部电影的相关数据,其中group的第一列代表每部电影的搞笑镜头,第二列代表每部电影的打斗镜头。函数classify()是用来对输入的新数据集进行距离计算,选择前k个最小距离以及进行排序等相关算法操作的函数。
import numpy as np
import operator
import matplotlib.pyplot as plt
group=np.array([[83,26],[15,98],[80,13],[70,20],[56,93],[5,105]])
labels=["comedy","action","comedy","comedy","action","action"]
def classify(data,dataSet,labels,k):
dataSetC=np.shape(dataSet)[0]
dis=np.tile(data,(dataSetC,1))-dataSet
dis=dis**2
dis=np.sum(dis,axis=1)
distanceData=dis**0.5
sortDistance=np.argsort(distanceData)
classCount={}
for i in range(k):
sortLabel=labels[sortDistance[i]]
classCount[sortLabel]=classCount.get(sortLabel,0)+1
sortClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortClassCount[0][0]
if __name__=='__main__':
data=[50,70]
sortClassCount=classify(data,group,labels,3)
plt.scatter(group[:,0],group[:,1],s=100)
for i in range(len(labels)):
plt.text(group[i][0],group[i][1],labels[i],fontsize=15)
plt.scatter(data[0],data[1],marker='^',c='red',s=100)
plt.text(data[0],data[1],sortClassCount,fontsize=15)
plt.xlabel("Number of funny scenes")
plt.ylabel("Number of fight scenes")
plt.grid()
plt.show()
以下是程序运行结果: