原本是CV研究HC算法时,有一步要求对图片中每一个像素点,在lab空间上距离它最近的K个点。意识到对于哪怕小尺寸的图片(例如300*500)数据点也有足足1e5数量级,对每一个遍历全图来一次,O(N^2),哪怕Matlab矩阵优化顶破天,一幅图最起码也要3s,em,MSRA10k的图库一晚上都跑不完。
最后勉强想起来acm里面有一个东西叫KD树,专解此类问题。于是认知研究了一发。这里以HDU 4347为例,解释KD树的基本思想。
(图片引用自http://blog.csdn.net/acdreamers/article/details/44664645)
全部的算法思想分为三步:划分尺度,建树方法,查询最近K的方法。
1.划分尺度原则很简单。对于一定区间内的l~r内的n个元素,写作n*k的矩阵,对每一个列向量求解其方差,取方差最大的那一列作为metric。
可以利用一个计算技巧: D(x)=E(x^2)-E(x)*E(x) 一次for循环求解,记录下最大的那个方差对应的列数
2.数据结构的构造
又是完全二叉树精彩的方法。越来越喜欢利用数组建二叉树的方法,有着太多的好处。 类似于线段树,堆的构建思想,从1节点开始,lc下标是pa*2,rc下标是pa*2+1; 与此同时,构建一个标记数组flag,记录此位置是否为有效节点;还需要一个metric数组,描述下标为rt的根节点其采用分割的参考维度是哪一维。
造树的过程,本质上也是原始数据数组排序分治的过程。
分治到某一步,我们给出的根节点下标是rt,对应的数据区间是l~r,我们按照如下步骤操作:
a. 按照法1找出方差最大的那一维,并将此维度数存于metric[rt]中,
b.标记根节点为有效节点,flag[rt]=1;
c.利用nth_element函数,找出按照理论上sort后处于区间正中间的元素,把他放在那个位置上,并把比他小的元素散乱的放在数组左边,大于等于他的放在其右边。如此一来,我们把中间元素作为根节点,左右区间分别构建左右子树,如此分治下去。即:
tree[rt]=data[mid];
build(l,mid-1,rt*2)
build(mid+1,r,rt*2+1);
当l<r时结束递归。
3.查询最近K的方法。
上一次见到此算法还是在Tarjan算法里;在dfs搜索的过程中,利用函数回传的契机同时实现回溯搜索,的确是很高明的做法。如果是我写也许又是一边函数递归,一边向上搜索找爹了。
简单的框架如下:
f()
{
1.二叉搜索其孩子( f(lc) or f(rc))
2.更新修改存有前k近元素的优先队列
3.判定是否需要搜索另一颗子树
}
这个算法精妙在哪? 要知道,在找到叶子节点之前,函数一直在持续堆栈,每次都是执行完1后就进入下一层,直到最后到达了叶子底端,无法继续调用时,才开始执行2 3.
我们用一个优先队列记录元素信息。这个优先队列开始是空的,直到递归到叶子节点时候,才开始执行2,才开始有元素进入。对于此叶子节点来说,3是没有意义的,完成后向上回调。
我们换一个角度分析。从叶子父亲的角度看。父亲节点执行完1后,就在等待孩子的结果。孩子做了什么,他不管。他只关心此时自己如何更新队列
a.如果车上还有空座,进去再说。因为此时我们处于函数回调的过程中,因此此节点越有可能处于低级的位置,无论如何,目前看来它可能的概率至少比那些更上级(更远)的节点大。那么也因为先来先得,即使我的儿子没有满足下面的条件c,先进去再说。万一找不够k个,那还是得我们先去凑数
b.如果车已经满了,如果我比车上最垃圾的那个人更亲近,那时我可以把他赶下来,自己进去
c.在人满的情况下,我另一个儿子呢?看情况。如果它所占据的区域范围边界(即父亲自己在此维上的值)到被查询节点直线距离小于最远的k,那么有理由继续判定内部是否存在能继续更新的节点。
于是就得出来了结果。最后注意我们采用的是最大堆,堆上的元素是k个元素中距离最远的,因此如果要从小到大输出,还需要一个栈转置一下。
在这里我们是利用pair<double,node>来保存每一个节点到被查询节点的距离以及此节点的。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <string>
#include <vector>
#include <stack>
#include <bitset>
#include <set>
#include <list>
#include <deque>
#include <map>
#include <queue>
#define INF 0x3f3f3f3f
#define ll long long
#define lc (rt<<1)
#define rc ((rt<<1)+1)
using namespace std;
const int MAXN = 50010;
int k,n,dim;
struct node
{
double ele[10];
bool operator < (const node &u) const
{
return ele[k] < u.ele[k];
}
}dot[MAXN],data[MAXN*4];
int crit[4*MAXN];
bool flg[4*MAXN];
priority_queue< pair<double,node> > que;
double elucid(node x,node y)
{
double res=0;
for(int i=0;i<dim;i++)
res+=(x.ele[i]-y.ele[i])*(x.ele[i]-y.ele[i]);
return sqrt(res);
}
void build(int l,int r,int rt)
{
if(l>r) return;
flg[rt]=1;
double max=-1;
for(int i=0;i<dim;i++)
{
double Ex_2=0,Ex=0;
for(int j=l;j<=r;j++)
{
Ex_2+=dot[j].ele[i]*dot[j].ele[i];
Ex+=dot[j].ele[i];
}
Ex_2/=(r-l+1);
Ex/=(r-l+1);
if(Ex_2-Ex*Ex>max)
{
max=Ex_2-Ex*Ex;
k=i;
}
}
crit[rt]=k;
int mid=(l+r)/2;
nth_element(dot+l,dot+mid,dot+r+1);
data[rt]=dot[mid];
build(l,mid-1,lc);
build(mid+1,r,rc);
}
void query(node p,int rt,int m)
{
if(!flg[rt]) return;
double dis=elucid(p,data[rt]);
pair<double,node> cur(dis,data[rt]); //pair默认用first比较大小
bool fg=0; //是否需要搜索另一颗子树,满足条件a或c时执行
int how=crit[rt],nxt;
if(p.ele[how] >= data[rt].ele[how])
nxt=rc;
else
nxt=lc;
if(flg[nxt])
query(p,nxt,m); //1.二叉搜索孩子
if( que.size()<m ) //条件a
{
que.push(cur);
fg=1;
}
else
{
if(cur.first<que.top().first) //条件b
{
que.pop();
que.push(cur);
}
if(fabs(p.ele[how]-data[rt].ele[how])<que.top().first) //在b基础上判定条件c
fg=1;
}
if(flg[nxt^1] && fg) //执行另一颗子树更新搜索。这里运用了一个技巧,因为左右子树下标差1,xor 1即代表搜索另一颗子树
query(p,nxt^1,m);
}
int main()
{
while(~scanf("%d %d",&n,&dim))
{
memset(flg,0,sizeof(flg));
for(int i=1;i<=n;i++)
{
for(int j=0;j<dim;j++)
scanf("%lf",&dot[i].ele[j]);
}
build(1,n,1);
int t,m;
scanf("%d",&t);
while(t--)
{
node ask;
for(int i=0;i<dim;i++)
scanf("%lf",&ask.ele[i]);
scanf("%d",&m);
query(ask,1,m);
printf("the closest %d points are:\n", m);
node sta[20];
int top=0;
while(!que.empty())
{
sta[top++]=que.top().second;
que.pop();
}
while(top)
{
node tmp=sta[--top];
printf("%d", (int)tmp.ele[0]);
for(int i=1;i<dim;i++)
printf(" %d", (int)tmp.ele[i]);
printf("\n");
}
}
}
}