# KNN algorithm code example in Python (Iris dataset)
#鸢尾花数据集knn分类器代码实例
import numpy as np
import matplotlib.pyplot as plt
import SHUJUCHULI as S
import KNN as K
# Load the first four feature columns and the numeric class labels, then
# min-max normalise the features before they are passed to the classifier.
dataset,ClssLable = S.data_matrix('iris.csv',4)
NormDataSet = S.autoNorm(dataset)

# Visual inspection: one scatter plot for every unordered pair of features,
# coloured by class label. The pairs are visited in the order
# (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
title = ['Sepal.Length','Sepal.Width','Petal.Length','Petal.Width']
for a in range(len(title) - 1):
    for b in range(a + 1, len(title)):
        plt.scatter(dataset[:,a],dataset[:,b],c=ClssLable)
        plt.xlabel(title[a])
        plt.ylabel(title[b])
        plt.show()

# KNN classifier set-up; m is the fraction of rows used as the training set.
# NOTE(review): the split is positional (first rows train, last rows test) and
# the data is not shuffled. If iris.csv is ordered by species, every test row
# comes from a single class -- confirm this is intended.
m = 0.9
numberofdataset = NormDataSet.shape[0]
train_size = int(m*numberofdataset)
# Take the complement instead of int((1-m)*n): (1-0.9)*150 evaluates to
# 14.999... in floating point, which truncates to 14 and drops the last row.
test_size = numberofdataset - train_size
print(train_size,test_size)

# Classify each held-out row against the training rows (5 nearest neighbours)
# and count the misclassifications.
error = 0
for i in range(test_size):
    result = K.KNN(NormDataSet[train_size+i],NormDataSet[0:train_size],ClssLable[0:train_size],5)
    if result != ClssLable[train_size+i]:
        error += 1
print('错误率为:',error/test_size)
###数据处理模块:SHUJUCHULI
# Read the file
import numpy as np
# file_path is the path of the CSV file; k is the number of feature columns.
def data_matrix(file_path,k):
    """Read a comma-separated file into a feature matrix and a label list.

    The first line of the file is treated as a header and skipped. For every
    remaining non-blank line, fields 1..k (the field at position 0 is ignored)
    are stored as one row of the matrix, and the last field is mapped to a
    numeric class label: 'setosa' -> 1, 'versicolor' -> 2, 'virginica' -> 3.

    Args:
        file_path: path of the CSV file to read.
        k: number of feature columns to copy into the matrix.

    Returns:
        A tuple ``(matrix, classlable)`` where ``matrix`` is a float ndarray
        with one row per data line and ``k`` columns, and ``classlable`` is a
        list of integer labels.

    Raises:
        ValueError: if a data line has fewer than ``k`` feature fields or a
            feature field is not a number.
    """
    label_codes = {'setosa': 1, 'versicolor': 2, 'virginica': 3}

    # Read the file once and let the context manager close it; the header
    # line is dropped and surrounding whitespace is stripped from each row.
    with open(file_path) as file:
        rows = [line.strip() for i, line in enumerate(file) if i > 0]
    # Blank lines (e.g. trailing newlines at the end of the file) carry no
    # data and would otherwise break the row assignment below.
    rows = [row for row in rows if row]

    matrix = np.zeros((len(rows), k))
    classlable = []
    for index, row in enumerate(rows):
        # Fields are separated by commas.
        fields = row.split(',')
        # Copy fields 1..k into the current matrix row.
        matrix[index, :] = [float(value) for value in fields[1:k+1]]
        # Numeric labels (1, 2, 3) make it easy to colour scatter plots.
        # NOTE(review): a label outside the three known names is skipped
        # without a placeholder, so classlable can end up shorter than
        # matrix -- confirm whether this should raise instead.
        code = label_codes.get(fields[-1])
        if code is not None:
            classlable.append(code)
    return matrix,classlable
# Min-max normalisation of the data set.
def autoNorm(dataSet):
    """Rescale every column of ``dataSet`` linearly to the range [0, 1].

    Each value becomes ``(value - column_min) / (column_max - column_min)``.
    A column whose values are all equal has a zero range; it is mapped to
    all zeros instead of producing NaN from a 0/0 division.

    Args:
        dataSet: 2-D array-like of numbers, one sample per row.

    Returns:
        A float ndarray of the same shape as ``dataSet``.
    """
    data = np.asarray(dataSet, dtype=float)
    minval = data.min(0)
    value_range = data.max(0) - minval
    # Substitute 1 for zero ranges so constant columns divide cleanly to 0.
    # This replaces the former global np.seterr(...) call, which silenced
    # divide/invalid warnings for the whole process.
    safe_range = np.where(value_range == 0, 1.0, value_range)
    return (data - minval) / safe_range