KNN算法的實現

转载 2012年03月22日 18:51:40


KNN算法 基本思想
2009年02月08日 星期日 16:42

KNN(K 最近邻居)算法

该算法的基本思路是:在给定新文本后,考虑在训练文本集中与该新文本距离最近(最相似)的 K 篇文本,根据这 K 篇文本所属的类别判定新文本所属的类别,具体的算法步骤如下:

STEP ONE:根据特征项集合重新描述训练文本向量

STEP TWO:在新文本到达后,根据特征词分词新文本,确定新文本的向量表示

STEP THREE:在训练文本集中选出与新文本最相似的 K 个文本,计算公式为:

其中,K 值的确定目前没有很好的方法,一般采用先定一个初始值,然后根据实验测试的结果调整 K 值,一般初始值定为几百到几千之间。

STEP FOUR:在新文本的 K 个邻居中,依次计算每类的权重,计算公式如下:

其中, 为新文本的特征向量, 为相似度计算公式,与上一步骤的计算公式相同,而 为类别属性函数,即,如果 属于类 ,那么函数值为 1,否则为 0。

STEP FIVE:比较类的权重,将文本分到权重最大的那个类别中。

除此以外,支持向量机和神经网络算法在文本分类系统中应用得也较为广泛,支持向量机的基本思想是使用简单的线形分类器划分样本空间。对于在当前特征空间中线形不可分的模式,则使用一个核函数把样本映射到一个高维空间中,使得样本能够线形可分。

而神经网络算法采用感知算法进行分类。在这种模型中,分类知识被隐式地存储在连接的权值上,使用迭代算法来确定权值向量。当网络输出判别正确时,权值向量保持不变,否则进行增加或降低的调整,因此也称为奖惩法。


KNN算法的實現

转自 小橋流水

Knn.h

#pragma once

class Knn
{
private:
 double** trainingDataset;
 double* arithmeticMean;
 double* standardDeviation;
 int m, n;

 void RescaleDistance(double* row);
 void RescaleTrainingDataset();
 void ComputeArithmeticMean();
 void ComputeStandardDeviation();

 double Distance(double* x, double* y);
public:
 Knn(double** trainingDataset, int m, int n);
 ~Knn();
 double Vote(double* test, int k);
};

 

Knn.cpp

 

#include "Knn.h"
#include <cmath>
#include <map>

using namespace std;

Knn::Knn(double** trainingDataset, int m, int n)
{
 this->trainingDataset = trainingDataset;
 this->m = m;
 this->n = n;
 ComputeArithmeticMean();
 ComputeStandardDeviation();
 RescaleTrainingDataset();
}

void Knn::ComputeArithmeticMean()
{
 arithmeticMean = new double[n - 1];

 double sum;

 for(int i = 0; i < n - 1; i++)
 {
  sum = 0;
  for(int j = 0; j < m; j++)
  {
   sum += trainingDataset[j][i];
  }

  arithmeticMean[i] = sum / n;
 }
}

void Knn::ComputeStandardDeviation()
{
 standardDeviation = new double[n - 1];

 double sum, temp;

 for(int i = 0; i < n - 1; i++)
 {
  sum = 0;
  for(int j = 0; j < m; j++)
  {
   temp = trainingDataset[j][i] - arithmeticMean[i];
   sum += temp * temp;
  }

  standardDeviation[i] = sqrt(sum / n);
 }
}

void Knn::RescaleDistance(double* row)
{
 for(int i = 0; i < n - 1; i++)
 {
  row[i] = (row[i] - arithmeticMean[i]) / standardDeviation[i];
 }
}

void Knn::RescaleTrainingDataset()
{
 for(int i = 0; i < m; i++)
 {
  RescaleDistance(trainingDataset[i]);
 }
}

Knn::~Knn()
{
 delete[] arithmeticMean;
 delete[] standardDeviation;
}

double Knn::Distance(double* x, double* y)
{
 double sum = 0, temp;
 for(int i = 0; i < n - 1; i++)
 {
  temp = (x[i] - y[i]);
  sum += temp * temp;
 }

 return sqrt(sum);
}

double Knn::Vote(double* test, int k)
{
 RescaleDistance(test);

 double distance;

 map<int, double>::iterator max;

 map<int, double> mins;

 for(int i = 0; i < m; i++)
 {
  distance = Distance(test, trainingDataset[i]);
  if(mins.size() < k)
   mins.insert(map<int, double>::value_type(i, distance));
  else
  {
   max = mins.begin();
   for(map<int, double>::iterator it = mins.begin(); it != mins.end(); it++)
   {
    if(it->second > max->second)
     max = it;
   }
   if(distance < max->second)
   {
    mins.erase(max);
    mins.insert(map<int, double>::value_type(i, distance));
   }
  }
 }

 map<double, int> votes;
 double temp;

 for(map<int, double>::iterator it = mins.begin(); it != mins.end(); it++)
 {
  temp = trainingDataset[it->first][n-1];
  map<double, int>::iterator voteIt = votes.find(temp);
  if(voteIt != votes.end())
   voteIt->second ++;
  else
   votes.insert(map<double, int>::value_type(temp, 1));
 }

 map<double, int>::iterator maxVote = votes.begin();

 for(map<double, int>::iterator it = votes.begin(); it != votes.end(); it++)
 {
  if(it->second > maxVote->second)
   maxVote = it;
 }

 test[n-1] = maxVote->first;

 return maxVote->first;
}

 

main.cpp

 

#include <iostream>
#include "Knn.h"

using namespace std;

int main(const int& argc, const char* argv[])
{
 double** train = new double* [14];
 for(int i = 0; i < 14; i ++)
  train[i] = new double[5];
 double trainArray[14][5] = 
 {
  {0, 0, 0, 0, 0},
  {0, 0, 0, 1, 0},
  {1, 0, 0, 0, 1},
  {2, 1, 0, 0, 1},
  {2, 2, 1, 0, 1},
  {2, 2, 1, 1, 0},
  {1, 2, 1, 1, 1},
  {0, 1, 0, 0, 0},
  {0, 2, 1, 0, 1},
  {2, 1, 1, 0, 1},
  {0, 1, 1, 1, 1},
  {1, 1, 0, 1, 1},
  {1, 0, 1, 0, 1},
  {2, 1, 0, 1, 0}
 };

 for(int i = 0; i < 14; i ++)
  for(int j = 0; j < 5; j ++)
   train[i][j] = trainArray[i][j];

 Knn knn(train, 14, 5);

 double test[5] = {2, 2, 0, 1, 0};
 cout<<knn.Vote(test, 3)<<endl;

 for(int i = 0; i < 14; i ++)
  delete[] train[i];

 delete[] train;

 return 0;
}


Hibernate(六)实现一对多、多对一映射关联关系

一对多、多对一这种关系在现实生活中很多,例如部门与员工的关系,学校里班级与学生的关系... 那么在具体的系统实现中如果i实现这种关联关系呢?这里以部门和员工的关系为例。 部门实体类 package t...
  • u011731233
  • u011731233
  • 2015年07月16日 09:12
  • 3436

简单日程表的实现

使用了BeautyEye优化界面 简单的日历查询功能 简单的提醒功能(倒计时向)查询界面的实现 package CalendarExp;import org.jb2011.lnf.beautyeye....
  • Len_master
  • Len_master
  • 2016年12月28日 22:05
  • 601

pagerank算法的Matlab實現

网页搜索中的网页排名PageRank( L, sigma )  基本转移矩阵M满足: M= L' * D^(-1)     上标一撇'表示转秩  L为网络图中的邻接矩阵,D为对角线上元素D(k,k...
  • DS_agent
  • DS_agent
  • 2016年01月18日 22:43
  • 1628

Java实现多线程经典问题:使用三个线程实现输出ABCABC循环

使用三个线程实现ABCABC……循环。 代码如下://标记类,用来让三个线程共享,同时也是三个线程中同步代码快的标记对象。 //之前这个标记我设置成Integer,但是发现Integer进行加法运算...
  • LeoSha
  • LeoSha
  • 2015年08月08日 16:05
  • 759

LabVIEW按钮延时自动弹起

同事在现场碰到的问题: 要求按钮被按下后,一定时间后自动弹起。 问题分析: 1、采用事件结构,记录按钮按下的时间; 2、只有当按钮按下时,才开始计时; 3、计时结束后,弹起按钮,并初始化计时器; ...
  • ap0108220
  • ap0108220
  • 2013年11月19日 15:38
  • 2205

在Linux屏幕上画框和抛物线(framebuffer,mapped)

#include #include #include #include #include #include #define FB0 "/dev/fb0" #define FBSZ (64...
  • dezhihuang
  • dezhihuang
  • 2014年08月30日 09:04
  • 1427

quick-cocos2d-x 實現遮罩并實踐運用

实现遮罩的最经常使用的两个方式
  • longolder
  • longolder
  • 2014年07月14日 23:33
  • 1704

HttpModule原理與實現

HttpModule是向實現類提供模塊初始化和處置事件。當一個HTTP請求到達HttpModule時,整個ASP.NET Framework系統還並沒有對這個HTTP請求做任何處理,也就是說此時對於H...
  • mikemiller2
  • mikemiller2
  • 2013年12月30日 22:31
  • 469

Comparable接口的實現和使用

Comparable接口的實現和使用
  • herion_123
  • herion_123
  • 2017年01月10日 00:36
  • 109

修改hosts文件实现网页的屏蔽

servlet中的GetremoteAddr是获取远程主机ip的api,可以用来获取请求方的ip地址   ①当访问本机的web应用时GetremoteAddr获取到的是0:0:0:0:0:0:0:...
  • blacktone
  • blacktone
  • 2015年04月09日 21:15
  • 491
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:KNN算法的實現
举报原因:
原因补充:

(最多只允许输入30个字)