K-近邻法
距离度量标准:欧式距离或更一般的 L p L_p Lp 距离。
k 值的选择:k 值小时,k 近邻模型更复杂;k 值大时,模型更简单(当 k=N 时,最简单);用交叉验证法取得最合适的 k 值。
分类决策原则:多数表决。
构造kd树
kd树用于搜索与规定点空间距离最小的点。
kd树与线段树类似,线段树存某一区间,kd树存某一k维空间。
构造kd树的方法也是从一整个空间开始,递归往下分配空间。
每个节点存其所在维度,父子节点,以便之后搜索。
#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
const int k = 2; //k维度
const int maxn = 100; //最大节点数
int cmp_dim; //排序时的维度
vector<vector<int> >dat; //存节点坐标
struct Node { //节点
int value[k]; //坐标
int dim; //节点所在维度
Node* parant=NULL;
Node* left_ch=NULL; //父节点,子节点
Node* right_ch=NULL;
}kdnode[maxn];
bool cmp(const Node& a,const Node& b) {
return a.value[cmp_dim] < b.value[cmp_dim];
}
Node* build(int l,int r,int d) {
if (l > r)return NULL;
int m;
cmp_dim = d;
if ((l + r) & 1)m = 1+(l + r) >> 1; //中位点
else m = (l + r) >> 1;
nth_element(kdnode + l, kdnode + m, kdnode + r + 1,cmp);
kdnode[m].dim = d;
Node* lch = build(l, m-1, (d + 1) % k); //递归建树
if(lch!=NULL)
lch->parant = kdnode + m;
Node* rch = build(m + 1, r, (d + 1) % k);
if(rch!=NULL)
rch->parant = kdnode + m;
return kdnode + m;
}
int main() {
int n = 6;
dat.push_back({ 2,3 });
dat.push_back({ 5,4 });
dat.push_back({ 9,6 });
dat.push_back({ 4,7 });
dat.push_back({ 8,1 });
dat.push_back({ 7,2 });
for (int i = 0; i < n; i++) {
for (int j = 0; j < k; j++) {
kdnode[i].value[j] = dat[i][j];
}
}
build(0, n - 1, 0);
return 0;
}
搜索kd树
-
找到包含目标点的叶节点
-
以此节点为当前最近节点
-
向上回退。同时检查:
-
若父节点更近,更新;
-
若兄弟节点空间可能存在更近的点(到轴距离小于当前最小距离),检查兄弟节点。
-
-
当回到根节点,搜索结束。最终最近节点即为结果。
所有代码:
#include<iostream>
#include<algorithm>
#include<vector>
#include<cmath>
using namespace std;
const int k = 2; //k维度
const int maxn = 100; //最大节点数
int cmp_dim; //排序时的维度
vector<vector<int> >dat; //存节点坐标
struct Node { //节点
int value[k]; //坐标
int dim; //节点所在维度
Node* parent=NULL;
Node* left_ch=NULL; //父节点,子节点
Node* right_ch=NULL;
}kdnode[maxn];
bool cmp(const Node& a,const Node& b) {
return a.value[cmp_dim] < b.value[cmp_dim];
}
Node* build(int l,int r,int d) {
if (l > r)return NULL;
int m;
cmp_dim = d;
if ((l + r) & 1)m = 1+(l + r) >> 1; //中位点
else m = (l + r) >> 1;
nth_element(kdnode + l, kdnode + m, kdnode + r + 1,cmp);
kdnode[m].dim = d;
Node* lch = build(l, m-1, (d + 1) % k); //递归建树
if (lch != NULL)
{
lch->parent = kdnode + m;
kdnode[m].left_ch = lch;
}
Node* rch = build(m + 1, r, (d + 1) % k);
if (rch != NULL)
{
rch->parent = kdnode + m;
kdnode[m].right_ch = rch;
}
return kdnode + m;
}
//distance返回欧式距离的平方
int distance(int* a,int* b) {
int res = 0;
for (int i = 0; i < k; i++)
res += (a[i]-b[i])*(a[i]-b[i]);
return res;
}
Node* root; //kd树的根
int min_dis; //最近点与目标点距离
Node* nearest; //最近点
//找到包含目标点的叶节点
Node* the_one(int* arr) {
int d = 0;
Node* r=root;
Node* tmp=r;
while (r != NULL) {
tmp = r;
if (r->value[d] > arr[d]) {
r = r->left_ch;
}
else
r = r->right_ch;
d = (d + 1) % k;
}
min_dis = distance(arr, tmp->value);
nearest = tmp;
return tmp;
}
//在某一节点 从上往下 找最近点
void search_son(Node* son, int* tar) {
if (son == NULL)
return;
if (distance(son->value, tar) < min_dis) { //如果当前点更近,则更新
nearest = son;
min_dis = distance(son->value, tar);
}
//若轴与目标点距离小于当前最小距离,则两边都要找
if (pow(son->value[son->dim] - tar[son->dim], 2) < min_dis) {
search_son(son->left_ch, tar);
search_son(son->right_ch, tar);
}
//只要找一边
else {
//目标点在轴左边,找左边空间
if (son->value[son->dim] > tar[son->dim])
search_son(son->left_ch, tar);
else
search_son(son->right_ch, tar);
}
}
//在某一点 从下往上 找最近点
void search_all(Node* now,int tar[k]) {
Node* par = now->parent;
while (now->parent != NULL) {
now = now->parent;
//若父节点更近,则更新
if (distance(tar, now->value) < min_dis) {
nearest = now;
min_dis = distance(tar, now->value);
}
Node* son;
//若轴与目标点距离小于当前最小距离,还要在另一边找
if (pow(now->value[now->dim], 2) < min_dis) {
//目标点在轴右边,找右边空间
if (now->value[now->dim] > tar[now->dim])
son = now->right_ch;
else
son = now->left_ch;
//确定在哪边空间后,从上往下找
search_son(son,tar);
}
}
}
int main() {
int n = 6;
dat.push_back({ 2,3 });
dat.push_back({ 5,4 });
dat.push_back({ 9,6 });
dat.push_back({ 4,7 });
dat.push_back({ 8,1 });
dat.push_back({ 7,2 });
for (int i = 0; i < n; i++) {
for (int j = 0; j < k; j++) {
kdnode[i].value[j] = dat[i][j];
}
}
root=build(0, n - 1, 0);
//测试
int a[2] = { 8,2 };
Node* now = the_one(a);
search_all(now, a);
//distance返回的是最小距离的平方,结果还应开根号
cout << distance(nearest->value, a);
return 0;
}