# Machine Learning In Action -- kNN (k Nearest Neighbors)

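### The k-Nearest Neighbors (kNN) Classification Algorithm

The k-nearest neighbors classifier is one of the simplest machine learning algorithms. It is mainly used to identify unknown items: a new sample is assigned the most common label among the k training samples closest to it.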

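#### Advantages

• Relatively high accuracy
• Makes no assumptions about the data
• Well suited to sample sets whose classes intersect or overlap heavily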

#### Disadvantages

• Computationally expensive
• High memory consumption
• Sensitive to imbalanced class sizes
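
The sensitivity to imbalanced class sizes can be seen in a small sketch. The 1-D data, the class names `rare` and `common`, and the helper `vote` are all made up for illustration; the point is only that a larger k can let a numerous class outvote a small class even when the small class is much closer to the query.

```python
from collections import Counter

import numpy as np

# Hypothetical 1-D data: 2 samples of the rare class sit right next to the
# query, while 20 samples of the common class lie further away.
rare = np.array([0.9, 1.1])
common = np.linspace(2.0, 4.0, 20)
samples = np.concatenate([rare, common])
labels = ['rare'] * len(rare) + ['common'] * len(common)

query = 1.0
# Sample indices ordered from nearest to farthest from the query
order = np.abs(samples - query).argsort()


def vote(k):
    """Majority label among the k samples nearest to the query."""
    return Counter(labels[i] for i in order[:k]).most_common(1)[0][0]


print(vote(3))  # rare   (2 rare votes vs 1 common)
print(vote(7))  # common (2 rare votes vs 5 common)
```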

#### Code (Python)

Note: Install matplotlib before running the code. The listings shown here import only NumPy.

    import numpy as np
    import operator
    from os import listdir

    def knn_classify(vec_in, data_set, labels, k):
        """Classify vec_in by majority vote among its k nearest training rows."""
        # Euclidean distance from vec_in to every row of data_set
        rows = data_set.shape[0]
        diffs = np.tile(vec_in, (rows, 1)) - data_set
        sq_diffs = diffs ** 2
        sq_distances = sq_diffs.sum(axis=1)
        distances = sq_distances ** 0.5
        # Indices of the training rows, nearest first
        sorted_dist_indices = distances.argsort()

        # Count the labels of the k nearest rows
        class_cnt = {}
        for i in range(k):
            vote_label = labels[sorted_dist_indices[i]]
            class_cnt[vote_label] = class_cnt.get(vote_label, 0) + 1
        # Most frequent label first; items() works on both Python 2 and 3
        sorted_class_cnt = sorted(class_cnt.items(), key=operator.itemgetter(1), reverse=True)
        return sorted_class_cnt[0][0]
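
To see the voting scheme in action, here is a minimal, self-contained sketch on made-up data. The function name `knn_predict` and the toy samples and labels are hypothetical. It also swaps in NumPy broadcasting and `collections.Counter` for `np.tile` and `operator.itemgetter`, purely for brevity; the distance and the majority vote are the same idea as in the listing above.

```python
from collections import Counter

import numpy as np


def knn_predict(query, samples, labels, k):
    """Return the majority label among the k samples closest to query."""
    # Broadcasting subtracts query from every row, so np.tile is not needed.
    distances = np.sqrt(((samples - query) ** 2).sum(axis=1))
    nearest = distances.argsort()[:k]
    votes = Counter(labels[i] for i in nearest)
    return votes.most_common(1)[0][0]


# Hypothetical toy data: two points near (1, 1) and two near (0, 0).
samples = np.array([[1.0, 1.1], [1.0, 1.0], [0.0, 0.0], [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B']

print(knn_predict(np.array([0.1, 0.2]), samples, labels, 3))  # B
print(knn_predict(np.array([0.9, 1.2]), samples, labels, 3))  # A
```

With k = 3, each query picks up two neighbors from its own cluster and one from the other, so the nearer cluster wins the vote.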

#### Example: Recognizing Handwritten Digits 0-9

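    def img2vector(filename):
        """Read a 32x32 grid of digit characters into a 1x1024 row vector."""
        ret_vec = np.zeros((1, 1024))
        with open(filename) as fp:
            for i in range(32):
                line_str = fp.readline()  # one image row per text line
                for j in range(32):
                    ret_vec[0, 32*i + j] = int(line_str[j])
        return ret_vec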
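
The input format is easiest to understand from a tiny example. The sketch below assumes each image is a text file of 32 lines with 32 digit characters per line (here only `0` and `1`), which is what the indexing in `img2vector` implies. The helper name `image_file_to_vector`, the file name `1_0.txt`, and the stroke pattern are made up for illustration.

```python
import os
import tempfile

import numpy as np


def image_file_to_vector(path, size=32):
    """Flatten a size x size grid of digit characters into a 1 x size*size array."""
    vec = np.zeros((1, size * size))
    with open(path) as fp:
        for i in range(size):
            line = fp.readline()
            for j in range(size):
                vec[0, size * i + j] = int(line[j])
    return vec


# Hypothetical image: all zeros except a vertical stroke in column 16.
rows = ['0' * 16 + '1' + '0' * 15 for _ in range(32)]

with tempfile.TemporaryDirectory() as tmp:
    # Made-up file name; the part before the underscore plays the role of the label.
    path = os.path.join(tmp, '1_0.txt')
    with open(path, 'w') as fp:
        fp.write('\n'.join(rows) + '\n')
    vec = image_file_to_vector(path)

print(vec.shape)       # (1, 1024)
print(int(vec.sum()))  # 32
```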

    def hand_writing_test():
        training_path = './digits/trainingDigits'
        test_path = './digits/testDigits'
        # Build the training matrix: one flattened image per row
        training_files = listdir(training_path)
        m = len(training_files)
        training_mat = np.zeros((m, 1024))
        labels = []
        for i in range(m):
            filename = training_files[i]
            # The label is the part of the file name before the underscore
            class_num_str = filename.split('_')[0]
            labels.append(class_num_str)
            training_mat[i, :] = img2vector(training_path + '/%s' % filename)
        # print(training_mat)
        # Classify every test image and count the mismatches
        test_files = listdir(test_path)
        m = len(test_files)
        err_cnt = 0.0
        for i in range(m):
            filename = test_files[i]
            class_num_str = filename.split('_')[0]
            vec_in = img2vector(test_path + '/%s' % filename)
            ret = knn_classify(vec_in, training_mat, labels, 3)
            if str(ret) != class_num_str:
                print("file %s, classifier result: %s, real ans: %s." % (filename, ret, class_num_str))
                err_cnt += 1.0
        print("The total number of errors is %d." % err_cnt)
        print("The failure rate is %f." % (err_cnt / float(m)))

    if __name__ == '__main__':
        hand_writing_test()
