HDU 4347 KNN+KDTree

暂时wrong answer,求各路大神指错。。


#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>

using namespace std;

struct POINT
{
    int p[5];

    int &operator[](int x)
    {
        return p[x];
    }

    const int &operator[](int x) const
    {
        return p[x];
    }
}point[500005];

bool cmp0(const POINT &a, const POINT &b)
{
    return a[0] < b[0];
}
bool cmp1(const POINT &a, const POINT &b)
{
    return a[1] < b[1];
}
bool cmp2(const POINT &a, const POINT &b)
{
    return a[2] < b[2];
}
bool cmp3(const POINT &a, const POINT &b)
{
    return a[3] < b[3];
}
bool cmp4(const POINT &a, const POINT &b)
{
    return a[4] < b[4];
}
bool cmp5(const POINT &a, const POINT &b)
{
    return a[5] < b[5];
}
bool cmp6(const POINT &a, const POINT &b)
{
    return a[6] < b[6];
}
bool cmp7(const POINT &a, const POINT &b)
{
    return a[7] < b[7];
}
bool cmp8(const POINT &a, const POINT &b)
{
    return a[8] < b[8];
}
bool cmp9(const POINT &a, const POINT &b)
{
    return a[9] < b[9];
}
int n, k;
int q, m;
bool (*cmp[10])(const POINT &a, const POINT &b);
int tree[1000010];
int splitAxis[500005];
POINT knn[20];
POINT p;
int knnSize;

long long distance2(const POINT &a, const POINT &b)
{
    long long total = 0;
    for (int i = 0; i < k; i++)
        total += (b[i] - a[i]) * (b[i] - a[i]);
    return total;
}

bool knnCmp(const POINT &a, const POINT &b)
{
    return distance2(a, p) < distance2(b, p);
}

void addPoint(const POINT &crtPoint)
{
    if (knnSize < m)
    {
        knn[knnSize++] = crtPoint;
        sort(knn, knn + knnSize, knnCmp);
        return;
    }

    if (distance2(crtPoint, p) < distance2(knn[knnSize - 1], p))
        knn[knnSize - 1] = crtPoint;
    sort(knn, knn + knnSize, knnCmp);
}

void BuildKDTree(int l, int r, int axis, int root)
{
    if (r < l) return;
    sort(point + l, point + r, cmp[axis]);

    int mid = (l + r) / 2;
//    while (mid + 1 <= r && point[mid][axis] == point[mid + 1][axis])
//        mid++;
    tree[root] = mid;
    splitAxis[mid] = axis;
    axis = (axis + 1) % k;

    BuildKDTree(l, mid - 1, axis, root * 2 + 1);
    BuildKDTree(mid + 1, r, axis, root * 2 + 2);
}

void TestKDTree()
{
    for (int i = 0; i < n; i++)
    {
        int index = tree[i];
        for (int j = 0; j < k; j++)
            printf("%d ", point[index][j]);
        printf("\n");
    }
}

void findKnn(int root)
{
    if (tree[root] == -1)
        return;
    int index = tree[root];

    addPoint(point[index]);

    int axis = splitAxis[index];
    if (p[axis] <= point[index][axis])
    {
        findKnn(root * 2 + 1);
        if (knnSize < m)
            findKnn(root * 2 + 2);
        else
        {
            long long disToSplitPlane = (point[index][axis] - p[axis]) * (point[index][axis] - p[axis]);
            long long radius = distance2(knn[knnSize - 1], p);
            if (radius >= disToSplitPlane)
                findKnn(root * 2 + 2);
        }
    }
    else
    {
        findKnn(root * 2 + 2);
        if (knnSize < m)
            findKnn(root * 2 + 1);
        else
        {
            long long disToSplitPlane = (point[index][axis] - p[axis]) * (point[index][axis] - p[axis]);
            long long radius = distance2(knn[knnSize - 1], p);
            if (radius >= disToSplitPlane)
                findKnn(root * 2 + 1);
        }
    }
}

int main()
{
    cmp[0] = cmp0;
    cmp[1] = cmp1;
    cmp[2] = cmp2;
    cmp[3] = cmp3;
    cmp[4] = cmp4;
    cmp[5] = cmp5;
    cmp[6] = cmp6;
    cmp[7] = cmp7;
    cmp[8] = cmp8;
    cmp[9] = cmp9;

    while (scanf("%d", &n) != EOF)
    {
        scanf("%d", &k);
        memset(point, 0, sizeof(point));
        memset(tree, -1, sizeof(tree));

        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < k; j++)
                scanf("%d", &point[i][j]);
        }

        BuildKDTree(0, n - 1, 0, 0);
        //TestKDTree();

        scanf("%d", &q);

        for (int i = 0; i < q; i++)
        {
            for (int j = 0; j < k; j++)
                scanf("%d", &p[j]);
            scanf("%d", &m);
            knnSize = 0;

            findKnn(0);
            printf("the closest %d points are:\n", m);
            for (int j = 0; j < m; j++)
            {
                for (int l = 0; l < k - 1; l++)
                    printf("%d ", knn[j][l]);
                printf("%d\n", knn[j][k - 1]);
            }
        }
    }


}


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值