首先,我们来定义两个函数,一个是随机生成数据集的函数,用于生成我们的训练集和测试集,以及训练集的类别;另一个是将不同类别的训练集,和测试集数据,用可视化的方式呈现出来,可以用于校验knn算法的准确性。
import numpy as np
#随机生成数据集函数
np.random.seed(23) #锚定随机值,这样每次运行代码输出的随机数不变
def generate_data(num_samples,num_features=2): #2个参数分别对应随机生成的数据集的行数和列数,这里为了能够在图表中展示数据集,所以默认生成的是2维数据(列数为2)
train_data=np.random.randint(0,100,(num_samples,num_features))
labels=np.random.randint(0,2,(num_samples,1)) #这里约束了训练集的类别只有0、1两种,为训练集随机生成其所属类别
return train_data.astype(np.float32),labels #后面应用的openCV对于数据类型有些过分的讲究,所以此处强制把数据点的类型转换为np.float32
import matplotlib.pyplot as plt
#数据的可视化呈现
plt.style.use('ggplot') #这里采用比较美观的ggplot作为图表风格
def plot_data(label0_data,label1_data,test_data): #前2个参数是将训练集数据按照2个类别拆分出来的子数据集,第3个参数是测试数据集
plt.figure()
plt.scatter(label0_data[:,0],label0_data[:,1],color='blue',marker=