一、概述
最近在利用业余时间学习机器学习算法,由于笔者是嵌入式软件工程师,想将机器学习算法在单片机端实现,KNN算法(k-Nearest Neighbor,K最近邻算法)是为数不多的可在单片机端实现的机器学习算法。
通过检索发现,在单片机端实现KNN算法的例子较少,仅有几个用单片机实现手写数字识别的。
本例程硬件使用的是STM32F103C8T6最小系统板,IAR/MDK开发环境,STM32CubeMX进行配置并生成工程文件,鸢尾花数据集是从UCI机器学习官网https://archive.ics.uci.edu/ml/index.php下载,笔者已上传到资源,也可移步https://download.csdn.net/download/wanglong3713/26849636直接下载笔者整理好的。程序完整工程已上传至[面包多]:[001]STM32单片机使用KNN算法实现鸢尾花分类(IAR开发)
[002]STM32单片机使用KNN算法实现鸢尾花分类(Keil_MDK版)
欢迎下载。
鸢尾花Iris数据集共150组数据,分3类,分别是Iris Setosa,Iris Versicolour,Iris Virginica,每组数据有4个特征,分别是花萼长度,花萼宽度,花瓣长度,花瓣宽度。本例程选取了3个分类的各前30组共90组作为训练集,剩余20组共60组的数据作为测试集。
二、程序流程
KNN算法原理不再详细介绍,读者可自行检索,本篇博文的目的在于介绍算法的C语言实现。根据算法的原理可得程序流程:
① 计算一个测试样本,与所有训练样本的距离,距离一般选用欧氏距离,也可以选择其他距离,如切比雪夫距离、曼哈顿距离,这几种距离的定义等内容可自行检索;
② 对上述求得的距离从小到大进行排序,可使用最简单的冒泡排序法;
③ 取前k个距离,统计这k个距离中,对应的每种样本分类出现的个数,个数最大的分类,即为测试样本的分类,此处k即为KNN算法中的k;
④ 重复以上步骤,计算其他测试样本的分类。
三、主要代码
1. 计算欧式距离
根据不同的距离公式,程序稍作修改即可得到其他距离,常用距离计算公式可参考常用距离计算单片机C语言程序。
/*******************************************************************************
* 函数名:EuclideanDistance
* 功 能:计算一个样本与一个训练样本的欧几里得距离
* 参 数:*u16DataA测试样本
*u16DataB训练样本
u8Size数据维度
* 返回值:u16Dist距离
* 说 明:无
*******************************************************************************/
uint16_t EuclideanDistance(uint16_t *u16DataA, uint16_t *u16DataB, uint8_t u8Size)
{
uint16_t u16Dist = 0;
int16_t s16Temp = 0;
uint8_t i;
for (i = 0; i < u8Size; i++)
{
s16Temp = ((int16_t)*(u16DataA + i) - ((int16_t)*(u16DataB + i)));
s16Temp = s16Temp * s16Temp;
u16Dist += (uint16_t)s16Temp;
}
u16Dist = (uint16_t)sqrt(u16Dist);
return u16Dist;
}
2. 分类
排好序的数据,统计前k个数据中每个分类出现的个数,个数最大的结果即为分类结果;
通过printf函数打印分类情况到串口。
/*******************************************************************************
* 函数名:KNN_Classify
* 功 能:分类
* 参 数:无
* 返回值:无
* 说 明:无
*******************************************************************************/
void KNN_Classify(void)
{
uint8_t u8SetosaCnt = 0;
uint8_t u8VersiColorCnt = 0;
uint8_t u8VirginicaCnt = 0;
uint8_t u8Max = 0;
uint8_t i, j, m;
uint16_t *pTest, *pTrain;
Result_ts sIrisResult;
for (i = 0; i < TEST_ROW; i++)
{
memset(&sIrisResult, 0, sizeof(sIrisResult));
for (j = 0; j < TRAIN_ROW; j++)//
{
pTest = (uint16_t *)&u16TestSet[i];
pTrain = (uint16_t *)&u16TrainSet[j];
sIrisResult.u16Distance[j][0] = EuclideanDistance(pTest, pTrain, TRAIN_COLUMN - 1);//
//sIrisResult.u16Distance[j][0] = ChebyshevDistance(pTest, pTrain, TRAIN_COLUMN - 1);//
//sIrisResult.u16Distance[j][0] = ManhattanDistance(pTest, pTrain, TRAIN_COLUMN - 1);
sIrisResult.u16Distance[j][1] = u16TrainSet[j][4];
}
BubbleSort(sIrisResult.u16Distance, TRAIN_ROW);//第i个测试集的数据,排序
u8SetosaCnt = 0;
u8VersiColorCnt = 0;
u8VirginicaCnt = 0;
HAL_IWDG_Refresh(&hiwdg);
for (m = 0; m < K_VALUE; m++)//前k个数据
{
switch (sIrisResult.u16Distance[m][1])
{
case SETOSA: u8SetosaCnt++; break;
case VERSICOLOR: u8VersiColorCnt++; break;
case VIRGINICA: u8VirginicaCnt++; break;
default:break;
}
}
u8Max = max(max(u8SetosaCnt, u8VersiColorCnt), u8VirginicaCnt);
if(u8Max == u8SetosaCnt)
{
u8Max = SETOSA;
}else
{
if (u8Max == u8VersiColorCnt)
{
u8Max = VERSICOLOR;
}else
{
if (u8Max == u8VirginicaCnt)
{
u8Max = VIRGINICA;
}
}
}
sIrisResult.u8Class = u8Max;//保存分类结果
printf(" %.1f,%.1f,%.1f,%.1f ",(float)u16TestSet[i][0]/10,(float)u16TestSet[i][1]/10,(float)u16TestSet[i][2]/10,(float)u16TestSet[i][3]/10);//
switch(u8Max)
{
case SETOSA:
{
printf("class: Iris-setosa ");//输出分类结果
if (sIrisResult.u8Class == u16TestSet[i][4])//分类正确
{
printf(" Success\n");//
}else
{
printf(" Fail\n");
}
}break;
case VERSICOLOR:
{
printf("class: Iris-versicolor ");//输出分类结果
if (sIrisResult.u8Class == u16TestSet[i][4])//分类正确
{
printf(" Success\n");//
}else
{
printf(" Fail\n");
}
}break;
case VIRGINICA:
{
printf("class: Iris-virginica ");
if (sIrisResult.u8Class == u16TestSet[i][4])//分类正确
{
printf(" Success\n");//
}else
{
printf(" Fail\n");
}
}break;
default:break;
}
HAL_IWDG_Refresh(&hiwdg);
}
}
三、运行效果
可看出,60组测试集,有2组分类错误,58组正确,准确率为96.7%。其实用肉眼观察分析这两组数据,也可以看出确实不太好分类。另外,如果训练集和测试集选择的合适,准确率可以达到100%。
四、总结
1. 关于距离公式
在相关文献看到,切比雪夫距离的效果优于其他距离,但实际测试发现并非如此,可能与训练集、测试集的选取有关;
2.关于k的取值
在分类只有2种的情况下,建议k取奇数,防止两种分类出现平局的现象;但在分类有2种以上的时候,无论k是奇数还是偶数,都可能出现两种甚至多种分类出现平局,此时要考虑其他方式,选择出最佳的分类;
3. 关于训练集和测试集
分类效果和训练集、测试集的选取关系很大。本例程只是按照鸢尾花的数据集的顺序,选取了3个分类的各前30组共90组作为训练集,剩余20组共60组的数据作为测试集,如果改变训练集和测试集,最终的分类准确率不同。