MachineLearningInAction_k近邻
根据书中内容,采用python3.5编写相关程序。不熟悉python,故备注详细注释及相关拓展。
k近邻理解[1]
1、算法过程
- 计算待测试数据与已知数据(含标签)的距离(常用欧氏距离来表示)
- 根据距离值进行排序
- 选择前K个点集,统计标签的频次
- 选择频次最高的标签值作为分类值
2、K近邻缺陷
- 计算复杂度高
- 空间复杂度高
- 无法给出数据基础结构信息(即平均实例样本和典型实例样本的特征)
模型[2]
输入:
- $T=\{(x_{1i},y_1),(x_{2i},y_2),\dots,(x_{ki},y_k)\}$,其中 $i=1,\dots,m$,$y_1,\dots,y_k\in\{C_1,\dots,C_n\}$
算法:
- 给定距离度量:
$$L_p(x_i,x_j)=\Big(\sum_{l=1}^{n}\big|x_i^{(l)}-x_j^{(l)}\big|^p\Big)^{\frac{1}{p}}$$
p=2为欧氏距离
p=1为曼哈顿距离
$p=\infty$ 为切比雪夫距离
- 寻找最邻近的K个点得到邻域 $N_k(x)$,k值可通过交叉验证来选择
- 在 $N_k(x)$ 中根据分类决策规则(多数表决)决定x的类别y
$$y=\arg\max_{c_j}\sum_{x_i\in N_k(x)} I(y_i=c_j),\quad i=1,2,\dots,N;\ j=1,2,\dots,K$$
I为指示函数,相等时I=1,否则I=0
输出:
- 实例x所属的分类y
程序[1]
# @Time : 2019/7/4 16:46
# @Author : Belg
# @File : knn.py
# @Software: PyCharm
# @Desc:
from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
from os import listdir #列出给定路径下的文件名
def creat_data_set():
    """Build the four-point toy data set used to demonstrate kNN.

    Returns:
        tuple: ``(group, labels)`` where ``group`` is a 4x2 NumPy array of
        sample coordinates and ``labels`` is a list holding the class label
        of each row of ``group``.
    """
    # Each entry pairs a 2-D sample with its class label.
    samples = (
        ([1.0, 1.1], 'A'),
        ([1.0, 1.0], 'A'),
        ([0, 0], 'B'),
        ([0, 0.1], 'B'),
    )
    group = array([point for point, _ in samples])
    labels = [label for _, label in samples]
    return group, labels
def classify0(inx, data_set, labels, k):
    """Classify ``inx`` by majority vote among its k nearest neighbours.

    Distances are Euclidean. If several labels receive the same number of
    votes, the label whose first vote came from the closest neighbour wins.

    Args:
        inx: Query vector (list or array) with one value per feature. A
            ``(1, n_features)`` array is accepted as well.
        data_set: Training samples, shape ``(n_samples, n_features)``.
            Lists are converted with ``asarray``.
        labels: Sequence with the class label of each row of ``data_set``.
        k: Number of neighbours that vote. Values larger than the number of
            training samples are clamped to that number.

    Returns:
        The winning label, taken from ``labels``.

    Raises:
        ValueError: If ``k`` is smaller than 1 or ``data_set`` is empty.
    """
    samples = asarray(data_set)
    sample_count = samples.shape[0]
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    if sample_count == 0:
        raise ValueError("data_set must contain at least one sample")
    # Clamp instead of indexing past the end of the sorted neighbour list.
    # (A conditional is used because `from numpy import *` may shadow min().)
    neighbour_count = k if k < sample_count else sample_count

    # Broadcasting subtracts the query from every row; no tile() needed.
    diff_mat = asarray(inx) - samples
    distances = (diff_mat ** 2).sum(axis=1) ** 0.5  # sum over features, per row
    nearest = argsort(distances)[:neighbour_count]  # indices, closest first

    class_count = {}
    for sample_index in nearest:
        vote_label = labels[sample_index]
        class_count[vote_label] = class_count.get(vote_label, 0) + 1

    # reverse=True sorts by vote count in DESCENDING order; the sort is
    # stable, so tied labels keep their first-seen (closest-first) order.
    sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]
#######################################################################################################################
# 读取文件
def file2matrix(filename, num_features=3):
    """Load a tab-separated data file into a feature matrix and label list.

    Every non-blank line must hold at least ``num_features`` numeric columns
    followed by an integer class label in the last column. Blank lines are
    skipped.

    Args:
        filename: Path of the text file to read.
        num_features: Number of leading columns to store as features
            (defaults to 3, the layout the original code assumed).

    Returns:
        tuple: ``(return_mat, class_label)`` where ``return_mat`` is a float
        array of shape ``(n_rows, num_features)`` and ``class_label`` is a
        list of ints, one per row.

    Raises:
        ValueError: If a line has too few columns or a field is not numeric.
    """
    rows = []
    # The context manager closes the file even if parsing fails.
    with open(filename) as fr:
        for line_no, line in enumerate(fr, 1):
            line = line.strip()  # drop surrounding whitespace and the newline
            if not line:
                continue  # tolerate blank lines, e.g. a trailing newline
            fields = line.split('\t')
            if len(fields) <= num_features:
                raise ValueError(
                    "%s:%d: expected %d feature columns plus a label, got %d columns"
                    % (filename, line_no, num_features, len(fields))
                )
            rows.append(fields)

    return_mat = zeros((len(rows), num_features))
    class_label = []
    for index, fields in enumerate(rows):
        return_mat[index, :] = [float(value) for value in fields[:num_features]]
        class_label.append(int(fields[-1]))  # the label is the last column
    return return_mat, class_label
# 归一化
def auto_norm(data_set):
    """Min-max normalise every column of ``data_set`` into ``[0, 1]``.

    Computes ``(value - column_min) / column_range`` for every element.
    A column whose values are all identical has a range of zero; its range is
    reported as 1 so the column normalises to 0 instead of NaN/inf, and so
    callers can reuse ``ranges`` as a divisor safely.

    Args:
        data_set: 2-D array (or nested list) of shape
            ``(n_samples, n_features)``.

    Returns:
        tuple: ``(norm_data_set, ranges, min_val)`` where ``norm_data_set``
        has the shape of ``data_set`` and ``ranges`` / ``min_val`` hold one
        value per column.
    """
    data_set = asarray(data_set)
    min_val = data_set.min(0)  # axis 0: minimum of each column
    ranges = data_set.max(0) - min_val
    # Guard against division by zero for constant columns.
    ranges = where(ranges == 0, 1, ranges)
    # Broadcasting applies the per-column offset and scale to every row.
    norm_data_set = (data_set - min_val) / ranges
    return norm_data_set, ranges, min_val
#######################################################################################################################
def dating_class_test(ho_ratio=0.10, filename='datingTestSet2.txt', k=3):
    """Measure the hold-out error rate of the kNN classifier on the dating data.

    The first ``ho_ratio`` fraction of the rows is used as the test set and
    the remaining rows as training data. Features are normalised with
    ``auto_norm`` before classification.

    Args:
        ho_ratio: Fraction of rows held out for testing.
        filename: Tab-separated data file read with ``file2matrix``.
        k: Number of neighbours passed to ``classify0``.

    Returns:
        float: Fraction of held-out rows that were misclassified.

    Raises:
        ValueError: If the hold-out set would be empty.
    """
    dating_data_mat, dating_labels = file2matrix(filename)
    norm_mat, _ranges, _min_val = auto_norm(dating_data_mat)
    m = norm_mat.shape[0]
    num_test = int(m * ho_ratio)
    if num_test < 1:
        # Previously this ended in a ZeroDivisionError when the rate was computed.
        raise ValueError(
            "hold-out set is empty: %d rows with ho_ratio=%r" % (m, ho_ratio)
        )

    # The training split does not change between iterations; slice it once.
    train_mat = norm_mat[num_test:m, :]
    train_labels = dating_labels[num_test:m]

    error_count = 0
    for i in range(num_test):
        classifier_result = classify0(norm_mat[i, :], train_mat, train_labels, k)
        if classifier_result != dating_labels[i]:
            error_count += 1
            # NOTE(review): the source lost its indentation; this report is
            # assumed to belong to the mismatch branch - confirm.
            print("output classifier: %d, real classifier: %d" % (classifier_result, dating_labels[i]))

    error_rate = error_count / float(num_test)
    print("Total Error:%f" % error_rate)
    return error_rate
def classify_person():
    """Interactively classify one person with the dating data set.

    Prompts for three feature values, normalises them with the minimum and
    range returned by ``auto_norm`` for 'datingTestSet2.txt', runs a
    3-nearest-neighbour vote and prints the matching verdict.
    """
    verdicts = ['not at all', 'in small doses', 'in large doses']

    game_percent = float(input("percentage of time spent playing video games?"))
    flier_miles = float(input("frequent flier miles earned per year?"))
    ice_cream_liters = float(input("liters of ice cream consumed per year?"))

    samples, sample_labels = file2matrix('datingTestSet2.txt')
    normalized, ranges, min_val = auto_norm(samples)

    # Scale the query with the same offset and range as the training data.
    raw_query = array([flier_miles, game_percent, ice_cream_liters])
    query = (raw_query - min_val) / ranges

    predicted = classify0(query, normalized, sample_labels, 3)
    # The predicted label is used as a 1-based index into the verdict list.
    verdict = verdicts[predicted - 1]
    print("You will probably like this person: %s" % verdict)
#######################################################################################################################
# 图像转为向量
def img2vec(filename):
    """Flatten a 32x32 text image into a 1x1024 row vector.

    The first 32 characters of each of the first 32 lines are converted with
    ``int`` and stored row by row.

    Args:
        filename: Path of the text image file.

    Returns:
        NumPy float array of shape ``(1, 1024)``.
    """
    return_vec = zeros((1, 1024))
    # The context manager closes the file; the original leaked the handle.
    with open(filename) as fr:
        for i in range(32):
            line = fr.readline()
            for j in range(32):
                return_vec[0, 32 * i + j] = int(line[j])
    return return_vec
# 手写数字识别
def hand_writing_test(train_dir='trainingDigits', test_dir='testDigits', k=3):
    """Evaluate the kNN classifier on text-encoded handwritten digit images.

    Every file name is expected to start with the true class followed by an
    underscore; the part before the first underscore is parsed with ``int``.
    Each image is loaded with ``img2vec``.

    Args:
        train_dir: Directory holding the training images.
        test_dir: Directory holding the test images.
        k: Number of neighbours passed to ``classify0``.

    Returns:
        float: Fraction of test images that were misclassified.

    Raises:
        ValueError: If ``test_dir`` contains no files.
    """
    training_file_list = listdir(train_dir)
    train_mat = zeros((len(training_file_list), 1024))
    hw_labels = []
    for i, file_name in enumerate(training_file_list):
        hw_labels.append(int(file_name.split('_')[0]))
        train_mat[i, :] = img2vec('%s/%s' % (train_dir, file_name))

    test_file_list = listdir(test_dir)
    m_test = len(test_file_list)
    if m_test == 0:
        # Previously this ended in a ZeroDivisionError when the ratio was computed.
        raise ValueError("no test files found in %r" % (test_dir,))

    error_count = 0
    for file_name in test_file_list:
        real_class = int(file_name.split('_')[0])
        test_mat = img2vec('%s/%s' % (test_dir, file_name))
        out_class = classify0(test_mat, train_mat, hw_labels, k)
        if out_class != real_class:
            error_count += 1
            # NOTE(review): the source lost its indentation; this report is
            # assumed to belong to the mismatch branch - confirm.
            print("real class:%d out class:%d" % (real_class, out_class))

    error_ratio = error_count / float(m_test)
    print("Total Error Ratio:%f" % error_ratio)
    return error_ratio
if __name__ == '__main__':
    # Script entry point. Only example 1 is active; the other examples are
    # kept commented out and can be enabled one at a time.
    # ### Example 1: classify the point [0, 0] against the toy data set
    group, labels = creat_data_set()
    re = classify0([0, 0], group, labels, 3)  # the result is stored in `re` but not printed
    # ### Example 2: analyse the data
    # dating_data_mat, dating_labels = file2matrix('datingTestSet2.txt')
    # norm_mat, ranges, min_val = auto_norm(dating_data_mat)
    # fig = plt.figure()
    # ax = fig.add_subplot(111)
    # ax.scatter(dating_data_mat[:, 0], dating_data_mat[:, 1], 15.0*array(dating_labels), 15.0*array(dating_labels))
    # plt.show()
    # print(norm_mat)
    # ### Example 3: test the algorithm
    # dating_class_test()
    # ### Example 4: use the algorithm
    # classify_person()
    # ### Example 5: handwritten digit recognition
    # hand_writing_test()
待延伸
KD-Tree
参考资料
[1] Book_MachineLearningInAction
[2] Book_统计学习方法