C++实现最基本的KNN手写数字识别

KNN算法概述:KNN算法的思想即,一个样本属于其在特征空间里的最近邻样本中数目最多的分类。

该算法有几个关键要素:

一.K的取值,即选区K个最近邻样本中所属最多的分类,显然算法的效果很大程度上取决于K的大小。

二.样本对象之间距离的定义,一般使用欧氏距离或曼哈顿距离:

KNN算法流程

1.计算预测样本与每个训练集样本的距离

2.将样本按距离从小到大排序

3.从前往后取K个样本,统计各个标签对应样本数目

4.找出对于数目最多的标签,则为测试样本所属分类

 

这里我们选择欧式距离来计算样本间距离,测试不同K的取值下算法的预测精度。对MNIST数据集进行一些处理,即将像素点不为0的位置全部定义为1。因为朴素的KNN算法计算复杂度非常高,仅选用部分训练数据和测试数据进行实验。KNN算法并不会训练出一个用于分类或回归的模型,所以每次进行预测时,我们都需要打包所有的训练数据。

#include <bits/stdc++.h>
using namespace std ;
vector<double>labels;
vector<vector<double> >images;//训练集
vector<double>labels1;
vector<vector<double> >images1;//测试集
const int train_number=10000;
const int test_number=500;
int a[20];
int KNN(int i,int k);
struct node
{
    int labels;
    int dis;
}q[train_number+100];
bool cmp(node a,node b)
{
    return a.dis<b.dis;
}
/**********************************/
int ReverseInt(int i)

{

	unsigned char ch1, ch2, ch3, ch4;

	ch1 = i & 255;

	ch2 = (i >> 8) & 255;

	ch3 = (i >> 16) & 255;

	ch4 = (i >> 24) & 255;

	return((int)ch1 << 24) + ((int)ch2 << 16) + ((int)ch3 << 8) + ch4;

}
void read_Mnist_Label(string filename, vector<double>&labels)

{
    ifstream file;
	file.open("train-labels.idx1-ubyte", ios::binary);

	if (file.is_open())

	{

		int magic_number = 0;

		int number_of_images = 0;

		file.read((char*)&magic_number, sizeof(magic_number));

		file.read((char*)&number_of_images, sizeof(number_of_images));

		magic_number = ReverseInt(magic_number);

		number_of_images = ReverseInt(number_of_images);

		cout << "magic number = " << magic_number << endl;

		cout << "number of images = " << number_of_images << endl;





		for (int i = 0; i < number_of_images; i++)

		{

			unsigned char label = 0;

			file.read((char*)&label, sizeof(label));

			labels.push_back((double)label);

		}



	}

}



void read_Mnist_Images(string filename, vector<vector<double> >&images)

{

	ifstream file("train-images.idx3-ubyte", ios::binary);

	if (file.is_open())

	{

		int magic_number = 0;

		int number_of_images = 0;

		int n_rows = 0;

		int n_cols = 0;

		unsigned char label;

		file.read((char*)&magic_number, sizeof(magic_number));

		file.read((char*)&number_of_images, sizeof(number_of_images));

		file.read((char*)&n_rows, sizeof(n_rows));

		file.read((char*)&n_cols, sizeof(n_cols));

		magic_number = ReverseInt(magic_number);

		number_of_images = ReverseInt(number_of_images);

		n_rows = ReverseInt(n_rows);

		n_cols = ReverseInt(n_cols);



		cout << "magic number = " << magic_number << endl;

		cout << "number of images = " << number_of_images << endl;

		cout << "rows = " << n_rows << endl;

		cout << "cols = " << n_cols << endl;



		for (int i = 0; i < number_of_images; i++)

		{

			vector<double>tp;

			for (int r = 0; r < n_rows; r++)

			{

				for (int c = 0; c < n_cols; c++)

				{

					unsigned char image = 0;

					file.read((char*)&image, sizeof(image));

					tp.push_back(image);

				}

			}

			images.push_back(tp);

		}

	}

}
void read_Mnist_Label1(string filename, vector<double>&labels)

{
    ifstream file;
	file.open("t10k-labels.idx1-ubyte", ios::binary);

	if (file.is_open())

	{

		int magic_number = 0;

		int number_of_images = 0;

		file.read((char*)&magic_number, sizeof(magic_number));

		file.read((char*)&number_of_images, sizeof(number_of_images));

		magic_number = ReverseInt(magic_number);

		number_of_images = ReverseInt(number_of_images);


		for (int i = 0; i < number_of_images; i++)

		{

			unsigned char label = 0;

			file.read((char*)&label, sizeof(label));

			labels.push_back((double)label);

		}



	}

}



void read_Mnist_Images1(string filename, vector<vector<double> >&images)
{
	ifstream file("t10k-images.idx3-ubyte", ios::binary);
	if (file.is_open())
	{
		int magic_number = 0;

		int number_of_images = 0;

		int n_rows = 0;

		int n_cols = 0;

		unsigned char label;

		file.read((char*)&magic_number, sizeof(magic_number));

		file.read((char*)&number_of_images, sizeof(number_of_images));

		file.read((char*)&n_rows, sizeof(n_rows));

		file.read((char*)&n_cols, sizeof(n_cols));

		magic_number = ReverseInt(magic_number);

		number_of_images = ReverseInt(number_of_images);

		n_rows = ReverseInt(n_rows);

		n_cols = ReverseInt(n_cols);

		for (int i = 0; i < number_of_images; i++)

		{

			vector<double>tp;

			for (int r = 0; r < n_rows; r++)

			{

				for (int c = 0; c < n_cols; c++)

				{

					unsigned char image = 0;

					file.read((char*)&image, sizeof(image));

					tp.push_back(image);

				}

			}

			images.push_back(tp);

		}

	}
}
/**************以上为MNIST数据集读取部分,下面开始KNN算法**************/
void test(int k)
{
    int sum=0;
    for(int i=0;i<test_number;i++)
    {
        int predict=KNN(i,k);
        //printf("pre:%d label:%d\n",predict,(int)labels1[i]);
        if(predict==(int)labels1[i]) sum++;
    }
    printf("k=%d    precision: %.5f\n",k,1.0*sum/test_number);
}
int KNN(int number,int k)//预测函数
{
    memset(q,0,sizeof(q));
    memset(a,0,sizeof(a));
    int dis=0;
    for(int i=0;i<train_number;i++)
    {
        for(int j=0;j<784;j++)
            dis+=(images[i][j]-images1[number][j])*(images[i][j]-images1[number][j]);
        dis=sqrt(dis);//获得欧式距离
        q[i].dis=(int)dis;
        q[i].labels=(int)labels[i];
    }
    sort(q,q+train_number,cmp);
    for(int i=0;i<k;i++)
    {
        a[q[i].labels]++;
    }
    int ans=-1,minn=-1;
    for(int i=0;i<10;i++)
    {
        if(a[i]>minn)
        {
            minn=a[i];
            ans=i;
        }
    }
    return ans;
}
int main()
{
    read_Mnist_Label("t10k-labels.idx1-ubyte", labels);
	read_Mnist_Images("t10k-images.idx3-ubyte", images);
	read_Mnist_Label1("t10k-labels.idx1-ubyte", labels1);
	read_Mnist_Images1("t10k-images.idx3-ubyte", images1);//读取mnist数据集
	for (int i = 0; i < images1.size(); i++)
	{
		for (int j = 0; j < images1[0].size(); j++)
		{
            images1[i][j]=(images1[i][j]>0)?1:0;
		}
	}
	for (int i = 0; i < images.size(); i++)
	{
		for (int j = 0; j < images[0].size(); j++)
		{
            images[i][j]=(images[i][j]>0)?1:0;
		}
	}
	test(1);
	test(2);
	test(3);
	test(4);
	return 0;
}

部分测试截图如下:

 

KNN算法存在几个主要的缺点:

1.当样本数据分布不均衡的时候,很有可能出现当输入一个未知样本时,该样本的K个邻居中大数量类的样本占多数,但是这类样本并不接近目标样本(就像下图的Y点),会被误判为蓝色分类。针对此类情况,可以采用距离加权的优化算法,即距离越近的样本对应分类获得更大的全权值。

2.数据量庞大时,计算复杂度过高。执行一次KNN算法,需要遍历一遍所有数据集。可以采用数据结构优化的办法,常见的有K-D数和球树等优化。

 

 

 

  • 0
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值