5021 Revenge of kNN II
Time Limit: 8000/5000 MS (Java/Others) Memory Limit: 32768/32768 K (Java/Others) Total Submission(s): 196 Accepted Submission(s): 56
Problem Description
In pattern recognition, the k-Nearest Neighbors algorithm (or k-NN for short) is a non-parametric method used for classification and regression. In both cases, the input consists of the k closest training examples in the feature space.
In k-NN regression, the output is the property value for the object. This value is the average of the values of its k nearest neighbors.
---Wikipedia
Today, kNN takes revenge on you, again. You have to handle a kNN case in a one-dimensional coordinate system. There are N points, each with a position Xi and a value Vi. Then there are M kNN queries; each query names a point with index i, and you must recalculate its value by averaging the values of its k nearest neighbors. Note that you have to replace the value of the i-th point with the newly calculated value. If there is a tie while choosing the k nearest neighbors, choose the one with the minimal index first.
(Have you ever tried the problem “Revenge of kNN”? They are twin problems!)
In k-NN regression, the output is the property value for the object. This value is the average of the values of its k nearest neighbors.
---Wikipedia
Today, kNN takes revenge on you, again. You have to handle a kNN case in a one-dimensional coordinate system. There are N points, each with a position Xi and a value Vi. Then there are M kNN queries; each query names a point with index i, and you must recalculate its value by averaging the values of its k nearest neighbors. Note that you have to replace the value of the i-th point with the newly calculated value. If there is a tie while choosing the k nearest neighbors, choose the one with the minimal index first.
(Have you ever tried the problem “Revenge of kNN”? They are twin problems!)
Input
The first line contains a single integer T, indicating the number of test cases.
Each test case begins with two integers N and M. Then N lines follow, each containing two integers Xi and Vi. Then M lines follow, each containing the queried index Qi and Ki, where Ki is the number of nearest neighbors to use.
[Technical Specification]
1. 1 <= T <= 5
2. 2 <= N <= 100 000
3. 1 <= M <= 100 000
4. 1 <= Vi <= 1 000
5. 1 <= Xi <= 1 000 000 000, and no two Xi are identical.
6. 1 <= Qi <= N
7. 1 <= Ki <= N - 1
Each test case begins with two integers N and M. Then N lines follow, each containing two integers Xi and Vi. Then M lines follow, each containing the queried index Qi and Ki, where Ki is the number of nearest neighbors to use.
[Technical Specification]
1. 1 <= T <= 5
2. 2 <= N <= 100 000
3. 1 <= M <= 100 000
4. 1 <= Vi <= 1 000
5. 1 <= Xi <= 1 000 000 000, and no two Xi are identical.
6. 1 <= Qi <= N
7. 1 <= Ki <= N - 1
Output
For each test case, output the sum of the results of all queries, rounded to three fractional digits.
Sample Input
1 5 3 1 2 2 3 3 6 4 8 5 8 2 2 3 2 4 2
Sample Output
17.000
Hint
For the first query, the 2-NN of point 2 are points 1 and 3, so the new value is (2 + 6) / 2 = 4. For the second query, the 2-NN of point 3 are points 2 and 4, and the value of point 2 was changed to 4 by the previous query, so the new value is (4 + 8) / 2 = 6. Huge input, faster I/O method is recommended.
官方思路:
考虑如何快速求出距离最近的k个点的权值之和,这里的距离具有明显的二分性。这样可以在log(MAXX)的时间内求出k个点的坐标范围。求出之后的问题是,区间求和,单点更新,树状数组足够解决这个问题了。
在二分的时候注意K和K+1可能都是符合条件的,如果算出K+1被舍弃的话,减小Distance可能得到的是K-1,并不连续,所以要判断一下这种情况。
在二分的时候注意K和K+1可能都是符合条件的,如果算出K+1被舍弃的话,减小Distance可能得到的是K-1,并不连续,所以要判断一下这种情况。
代码如下:(二分搜索有些难写,汗!!)
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
const int MAXN = 100005;
// One data point of the one-dimensional kNN instance.
struct node {
    int id;  // index of the point in input order (assigned by main)
    int x;   // coordinate of the point
    int v;   // value of the point as read from the input
};
// Points of the current test case; slots 1..n are used and are sorted by x in main.
node a[MAXN];
// Number of points in the current test case.
int n;
// Number of queries in the current test case.
int m;
// index[id] is the slot in `a` (after sorting) of the point whose input index is id.
// NOTE(review): on glibc-based toolchains this global name may clash with the
// legacy index() function declared through <cstring> -- confirm on the target compiler.
int index[MAXN];
// Fenwick-tree (binary indexed tree) storage over the current point values, 1-based.
double C[MAXN];
// Left end of the neighbour window produced by findLR (slot in `a`).
int L;
// Right end of the neighbour window produced by findLR (slot in `a`).
int R;
// Slot in `a` of the point being queried.
int q;
// Number of nearest neighbours requested by the current query.
int k;
// Comparator for sort(): true when point a lies strictly left of point b.
bool cmp(node a, node b) {
    const bool aComesFirst = b.x > a.x;
    return aComesFirst;
}
// Returns the value of the lowest set bit of x (0 when x is 0).
int lowbit(int x) {
    const int negated = -x;
    return negated & x;
}
// Fenwick-tree point update: adds v to position i (1-based) by walking
// upwards through every tree node that covers i, stopping past n.
void add(int i, double v) {
    for (int pos = i; pos <= n; pos += lowbit(pos)) {
        C[pos] += v;
    }
}
// Fenwick-tree prefix query: returns the sum of positions 1..i (0 when i <= 0).
double Sum(int i) {
    double total = 0;
    for (int pos = i; pos > 0; pos -= lowbit(pos)) {
        total += C[pos];
    }
    return total;
}
// Binary search over the sorted slots a[1..n].
// Returns the smallest slot whose coordinate is >= x.
// Returns the sentinel n + 1 when every coordinate is < x, so the result is
// never an indeterminate value.
int findL(int x) {
    int lo = 1;
    int hi = n;
    int res1 = n + 1;  // "not found" sentinel
    while (lo <= hi) {
        const int mid = lo + (hi - lo) / 2;
        if (a[mid].x >= x) {
            res1 = mid;    // candidate; keep looking further left
            hi = mid - 1;
        } else {
            lo = mid + 1;
        }
    }
    return res1;
}
// Binary search over the sorted slots a[1..n].
// Returns the largest slot whose coordinate is <= x.
// Returns the sentinel 0 when every coordinate is > x, so the result is
// never an indeterminate value.
int findR(int x) {
    int lo = 1;
    int hi = n;
    int res2 = 0;  // "not found" sentinel
    while (lo <= hi) {
        const int mid = lo + (hi - lo) / 2;
        if (a[mid].x <= x) {
            res2 = mid;    // candidate; keep looking further right
            lo = mid + 1;
        } else {
            hi = mid - 1;
        }
    }
    return res2;
}
// Sets the globals L and R so that a[L..R] holds the queried point a[q] plus
// exactly its k nearest neighbours (ties broken towards the smaller input id).
//
// Binary-searches the radius d: the number of neighbours within distance d of
// a[q] is R - L (q itself excluded) and never decreases as d grows. Because
// coordinates are distinct, one step in d adds at most two points (one on each
// side), so some radius yields either k or k + 1 neighbours. In the k + 1
// case the farther end of the window is dropped.
//
// The radius is searched in [0, a[n].x - a[1].x]. The lower bound must be a
// distance, not a coordinate: starting at a[1].x skips every radius smaller
// than the leftmost coordinate and can leave L/R covering far too many points.
//
// If k is outside [1, n - 1] no radius matches, and L/R keep the values of the
// last probe.
void findLR() {
    int lo = 0;
    int hi = a[n].x - a[1].x;
    while (lo <= hi) {
        const int mid = lo + (hi - lo) / 2;
        L = findL(a[q].x - mid);
        R = findR(a[q].x + mid);
        const int cnt = R - L;  // neighbours within radius mid
        if (cnt < k) {
            lo = mid + 1;
        } else if (cnt > k + 1) {
            hi = mid - 1;
        } else {
            if (cnt == k + 1) {
                // One point too many: discard the farther end; on equal
                // distance keep the end with the smaller input id.
                const int distLeft = a[q].x - a[L].x;
                const int distRight = a[R].x - a[q].x;
                if (distLeft < distRight ||
                    (distLeft == distRight && a[L].id < a[R].id)) {
                    R--;
                } else {
                    L++;
                }
            }
            return;
        }
    }
}
// Reads T test cases. For each one: loads the points, sorts them by
// coordinate, builds the Fenwick tree over their values, then answers the
// queries in order and prints the sum of all newly computed values.
int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d %d", &n, &m);
        for (int i = 1; i <= n; i++) {
            a[i].id = i;  // remember the input index before sorting
            scanf("%d %d", &a[i].x, &a[i].v);
        }
        sort(a + 1, a + n + 1, cmp);
        memset(C, 0, sizeof(C));
        for (int pos = 1; pos <= n; pos++) {
            add(pos, a[pos].v);
            index[a[pos].id] = pos;  // input index -> sorted slot
        }
        double ans = 0.0;
        while (m--) {
            scanf("%d %d", &q, &k);
            q = index[q];
            findLR();
            // Sum over the window, and the current value of the queried point
            // (recovered as a difference of two prefix sums).
            const double windowSum = Sum(R) - Sum(L - 1);
            const double own = Sum(q) - Sum(q - 1);
            const double updated = (windowSum - own) / k;
            ans += updated;
            add(q, updated - own);  // replace the stored value with the new average
        }
        printf("%.3f\n", ans);
    }
    return 0;
}
另一种二分搜索代码:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
const int MAXN = 100005;
// One data point of the one-dimensional kNN instance.
struct node {
    int id;  // index of the point in input order (assigned by main)
    int x;   // coordinate of the point
    int v;   // value of the point as read from the input
};
// Points of the current test case; slots 1..n are used and are sorted by x in main.
node a[MAXN];
// Number of points in the current test case.
int n;
// Number of queries in the current test case.
int m;
// index[id] is the slot in `a` (after sorting) of the point whose input index is id.
// NOTE(review): on glibc-based toolchains this global name may clash with the
// legacy index() function declared through <cstring> -- confirm on the target compiler.
int index[MAXN];
// Fenwick-tree (binary indexed tree) storage over the current point values, 1-based.
double C[MAXN];
// Left end of the neighbour window produced by findLR (slot in `a`).
int L;
// Right end of the neighbour window produced by findLR (slot in `a`).
int R;
// Slot in `a` of the point being queried.
int q;
// Number of nearest neighbours requested by the current query.
int k;
// Comparator for sort(): true when point a lies strictly left of point b.
bool cmp(node a, node b) {
    const bool aComesFirst = b.x > a.x;
    return aComesFirst;
}
// Returns the value of the lowest set bit of x (0 when x is 0).
int lowbit(int x) {
    const int negated = -x;
    return negated & x;
}
// Fenwick-tree point update: adds v to position i (1-based) by walking
// upwards through every tree node that covers i, stopping past n.
void add(int i, double v) {
    for (int pos = i; pos <= n; pos += lowbit(pos)) {
        C[pos] += v;
    }
}
// Fenwick-tree prefix query: returns the sum of positions 1..i (0 when i <= 0).
double Sum(int i) {
    double total = 0;
    for (int pos = i; pos > 0; pos -= lowbit(pos)) {
        total += C[pos];
    }
    return total;
}
// Converging binary search over the sorted slots a[1..n]: returns the smallest
// slot whose coordinate is >= x. The range shrinks until one slot remains, so
// if no coordinate is >= x the search still ends on slot n; callers must
// guarantee that a match exists.
int findL(int x) {
    int lo = 1;
    int hi = n;
    while (lo < hi) {
        const int mid = lo + (hi - lo) / 2;  // lower middle
        if (a[mid].x < x) {
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    return lo;
}
// Converging binary search over the sorted slots a[1..n]: returns the largest
// slot whose coordinate is <= x. The range shrinks until one slot remains, so
// if no coordinate is <= x the search still ends on slot 1; callers must
// guarantee that a match exists.
int findR(int x) {
    int lo = 1;
    int hi = n;
    while (lo < hi) {
        const int mid = lo + (hi - lo + 1) / 2;  // upper middle, so lo = mid always advances
        if (a[mid].x > x) {
            hi = mid - 1;
        } else {
            lo = mid;
        }
    }
    return hi;
}
// Sets the globals L and R so that a[L..R] holds the queried point a[q] plus
// exactly its k nearest neighbours (ties broken towards the smaller input id).
//
// Binary-searches the radius d: the number of neighbours within distance d of
// a[q] is R - L (q itself excluded) and never decreases as d grows. Because
// coordinates are distinct, one step in d adds at most two points (one on each
// side), so some radius yields either k or k + 1 neighbours. In the k + 1
// case the farther end of the window is dropped.
//
// The radius is searched in [0, a[n].x - a[1].x]. The lower bound must be a
// distance, not a coordinate: starting at a[1].x skips every radius smaller
// than the leftmost coordinate and can leave L/R covering far too many points.
//
// If k is outside [1, n - 1] no radius matches, and L/R keep the values of the
// last probe.
void findLR() {
    int lo = 0;
    int hi = a[n].x - a[1].x;
    while (lo <= hi) {
        const int mid = lo + (hi - lo) / 2;
        L = findL(a[q].x - mid);
        R = findR(a[q].x + mid);
        const int cnt = R - L;  // neighbours within radius mid
        if (cnt < k) {
            lo = mid + 1;
        } else if (cnt > k + 1) {
            hi = mid - 1;
        } else {
            if (cnt == k + 1) {
                // One point too many: discard the farther end; on equal
                // distance keep the end with the smaller input id.
                const int distLeft = a[q].x - a[L].x;
                const int distRight = a[R].x - a[q].x;
                if (distLeft < distRight ||
                    (distLeft == distRight && a[L].id < a[R].id)) {
                    R--;
                } else {
                    L++;
                }
            }
            return;
        }
    }
}
// Reads T test cases. For each one: loads the points, sorts them by
// coordinate, builds the Fenwick tree over their values, then answers the
// queries in order and prints the sum of all newly computed values.
int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d %d", &n, &m);
        for (int i = 1; i <= n; i++) {
            a[i].id = i;  // remember the input index before sorting
            scanf("%d %d", &a[i].x, &a[i].v);
        }
        sort(a + 1, a + n + 1, cmp);
        memset(C, 0, sizeof(C));
        for (int pos = 1; pos <= n; pos++) {
            add(pos, a[pos].v);
            index[a[pos].id] = pos;  // input index -> sorted slot
        }
        double ans = 0.0;
        while (m--) {
            scanf("%d %d", &q, &k);
            q = index[q];
            findLR();
            // Sum over the window, and the current value of the queried point
            // (recovered as a difference of two prefix sums).
            const double windowSum = Sum(R) - Sum(L - 1);
            const double own = Sum(q) - Sum(q - 1);
            const double updated = (windowSum - own) / k;
            ans += updated;
            add(q, updated - own);  // replace the stored value with the new average
        }
        printf("%.3f\n", ans);
    }
    return 0;
}