在数据挖掘学习札记:KNN算法(一)里,使用sklearn模块对例子进行了求解,但是并不清楚k的取值。
下面是我写的一个Python代码,程序采用“小题大做”的方式,一方面可以熟悉算法,另一方面练习Python编程,可以看到,当k取1,2,3,4,5时,knn算法预测未知电影的类型都是R,即Romance。
说明:
1. 距离使用欧氏距离;
2. k近邻搜索使用线性扫描;
3. 未知电影对象调用label方法,得到预测类型;
from math import sqrt
class Movie:
''' Represents a movie '''
lib=[]
total=0 # Number of movies
def __init__(self,nk,nf,tg):
''' Initialize a movie '''
self.nkiss=nk
self.nfight=nf
self.tag=tg
self.index=len(self.lib)+1
Movie.lib.append(self)# initialize a movie and add it to the library
Movie.total+=1
self.print()
def distance(self,movie):
''' Distance to another movie '''
if type(movie)!=type(self):
raise TypeError('requires a %s, given a %s'% (type(self),type(movie)))
else:
dis=(self.nkiss-movie.nkiss)**2
dis=dis+(self.nfight-movie.nfight)**2
dis=sqrt(dis)
return dis
def get(self,k):
''' Get the kth movie'''
if k>self.total:
raise IndexError('out of range')
else:
for m in self.lib:
if k==m.index:
return m
def neighbors(self,k):
'''Find its k neighbors '''
dis=[]
movie_many=[]
for movie in self.lib:
dis.append((movie.index,self.distance(movie)))
dis.sort(key=lambda dis:dis[1]) # sort according to distances
for i in range(1,k+1):
movie_many.append(self.get(dis[i][0]))
self.print(movie_many)
return movie_many
def print(self,movies=None):
''' Print the information of a movie or a set of movies '''
if movies==None:
print((self.index,self.nkiss,self.nfight,self.tag))
else:
for m in movies:
m.print()
def label(self,k=1):
'''From its k nearest neihbors to determin its tag: R or A ?'''
if self.tag=='unknown':
movie_many=self.neighbors(k)
nR=0
nA=0
for movie in movie_many:
if movie.tag=='R':
nR+=1
elif movie.tag=='A':
nA+=1
else:
raise TypeError('The movie with label %d is not a training data'%movie.label())
else:
if nR>nA:
tag='R'
elif nR<nR:
tag='A'
else:
tag='unknown'
return tag
Movie(3,104,'R')
Movie(2,100,'R')
Movie(1,81,'R')
Movie(101,10,'A')
Movie(99,5,'A')
Movie(98,2,'A')
test=Movie(18,90,'unknown')
计算结果如下:
>>> test.label(1)
(2, 2, 100, 'R')
'R'
>>> test.label(2)
(2, 2, 100, 'R')
(3, 1, 81, 'R')
'R'
>>> test.label(3)
(2, 2, 100, 'R')
(3, 1, 81, 'R')
(1, 3, 104, 'R')
'R'
>>> test.label(4)
(2, 2, 100, 'R')
(3, 1, 81, 'R')
(1, 3, 104, 'R')
(4, 101, 10, 'A')
'R'
>>> test.label(5)
(2, 2, 100, 'R')
(3, 1, 81, 'R')
(1, 3, 104, 'R')
(4, 101, 10, 'A')
(5, 99, 5, 'A')
'R'
>>> test.label(6)
(2, 2, 100, 'R')
(3, 1, 81, 'R')
(1, 3, 104, 'R')
(4, 101, 10, 'A')
(5, 99, 5, 'A')
(6, 98, 2, 'A')
'unknown'
>>>
上面的运行如下解释,
首先,初始化6个电影对象,代表6个实例,第7个是未知电影test,它的类型tag未知。然后调用label方法, test.label(2),表示利用test的最近的两个邻居来预测,test电影的类型。在运行中,我们打印出了,该未知电影的k个近邻。
知识点:
1. 使用sort进行排序时,可以指定key,使得按指定方式排序,上面的代码key=lambda dis:dis[1]使用每个元素(二元组)的第二个为键进行排序,因为第二元组的第一个是电影的序号,第二个是距离;
2. 虽然Python没有函数重载,但是可以使用默认参数达到同样的目的,比如print函数,既可以打印一个电影,也可以打印一个电影列表。关于Python的函数重载问题,网上有评论,比如这里。