采用的样本是非常经典的鸢尾花数据。
鸢尾花分为三类:
1、setosa
2、versicolor
3、virginica
已知数据集的特征分为:
1、萼片长度
2、萼片宽度
3、花瓣长度
4、花瓣宽度
在讲算法之前,我们首先应该给距离下定义:
当然代码采用的是“欧式距离”
算法原理:
1、读取训练集和测试集数据(包括特征、类型)
2、计算两者距离:
相对应特征的距离。
3、选取排名前k个最小距离进行投票,投票多的就是判定的鸢尾花类型
4、利用混淆矩阵、评测该方法的合理性(理论上95%正确率是很好的模型了)。
代码:
/***********************************
*说明:总数据有150条
根据1:1分类原则:把数据分为:
75条测试数据
和75条训练数据
利用四种特征区分鸢尾花类型
************************************/
float testdata[75][4]; //测试集特征数据
int testclass[75]; //测试集类型数据
float traindata[75][4]; //训练集特征数据
int trainclass[75]; //训练集类型数据
float distances[75]; //距离
int results[75]; //类型结果
void sort(float *dist, int *class_)
{
int N = 75;
float temp = 0;
for (int pass = 0; pass<N; pass++)//冒泡排序成从小到大
{
for (int i = pass + 1; i<N; i++)
{
if (dist[pass] >= dist[i])
{
temp = dist[pass];
dist[pass] = dist[i];
dist[i] = temp;
int k = class_[pass];
class_[pass] = class_[i];
class_[i] = k;
}
}
}
}
//选取前k个最小距离投票
int classify(float *dist, int *res, int k)
{
int i, c[4];
c[1] = c[2] = c[3] = 0;
for (i = 0; i< k; i+