题目链接:https://nanti.jisuanke.com/t/41388
题解:这个题比赛最后也没改对,训练时突然想到有个地方写了emmmm
思路:我们先按照dfs序把每个点跑出来,然后vector记录下每个深度的点,很容易就可以想到,对于每个深度,距离x小于等于k的那些点肯定是在vector内的一段连续的序列,因此我们就可以二分找到最左最右的位置,求权值的话,我们维护一个前缀和就好了,但要注意一点就是,我们要先向上求,记录一下祖先向下延伸最远的点
比如:
1 2
1 3
3 4
这个图,再求距离2小于等于3的点时,直接从2向下是不行的,因为父节点1向下延伸的更长。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6 + 10;
struct node {
int to, nex;
}e[N * 2];
int head[N], len;
vector<int> v[N];
int dep[N];
int maxdep;
int n, m;
ll val[N];
vector<ll> sum[N];
int dp[N][22];
int son[N], nex[N];
void add(int x, int y) {
e[len].to = y;
e[len].nex = head[x];
head[x] = len++;
}
int pos[N];
int lenv[N];
void dfs1(int u, int fa) {
dep[u] = dep[fa] + 1;
maxdep = max(maxdep, dep[u]);
v[dep[u]].push_back(u);
pos[u] = lenv[dep[u]]++;
dp[u][0] = fa;
for(int i = 1; i < 22; i++)
if(dp[u][i - 1])
dp[u][i] = dp[dp[u][i - 1]][i - 1];
else
break;
int to;
int flag = 0;
for(int i = head[u]; i != -1; i = e[i].nex) {
to = e[i].to;
if(to == fa) continue;
dfs1(to, u);
flag = 1;
if(son[to] > son[nex[u]]) {
nex[u] = to;
}
}
if(!flag) son[u] = 1;
else son[u] = son[nex[u]] + 1;
}
bool judge(int x, int y, int k) {
k--;
int cntx = x, cnty = y;
if(dep[x] < dep[y]) swap(x,y);
int tmp = dep[x] - dep[y];
for(int i = 0; i < 22; i++)
if(tmp & ( 1 << i))
x = dp[x][i];
if(x == y) {
return dep[cntx] + dep[cnty] - 2 * dep[x] <= k;
}
for(int i = 21; i >= 0; i--){
if(dp[x][i] != dp[y][i]) {
x = dp[x][i];
y = dp[y][i];
}
}
x = dp[x][0];
return dep[cntx] + dep[cnty] - 2 * dep[x] <= k;
}
int main() {
scanf("%d", &n);
memset(head, -1, sizeof(head));
for(int i = 1; i <= n; i++) scanf("%lld", &val[i]);
int x, y, k;
for(int i = 1; i < n; i++) {
scanf("%d %d", &x, &y);
add(x, y);
add(y, x);
}
dfs1(1, 0);
scanf("%d", &m);
ll cnt = 0;
for(int i = 1; i <= maxdep; i++) {
cnt = 0;
for(int j = 0; j < v[i].size(); j++) {
cnt += val[v[i][j]];
sum[i].push_back(cnt);
}
}
int l, r, mid;
int ansl, ansr;
ll ans;
int tmpk;
int tt;
int maxx;
int tx;
int id;
while(m--) {
scanf("%d %d", &x, &k);
k++;
y = x;
ans = 0;
tmpk = k;
maxx = son[x] - 1;
id = x;
tt = 0;
tx = 0;
while(y && tmpk) {
tmpk--;
l = 0, r = pos[y];
if(min(son[y] , k - tt ) - 1 > maxx + tt) {
maxx = min(son[y] , k - tt ) - 1 - tt;
id = y;
tx = tt;
}
tt++;
while(l <= r) {
mid = (l + r) >> 1;
if(judge(v[dep[y]][mid], x, k)) {
ansl = mid;
r = mid - 1;
} else {
l = mid + 1;
}
}
l = pos[y], r = lenv[dep[y]] - 1;
while(l <= r) {
mid = (l + r) >> 1;
if(judge(v[dep[y]][mid], x, k)) {
ansr = mid;
l = mid + 1;
} else {
r = mid - 1;
}
}
if(ansl == 0) cnt = 0;
else cnt = sum[dep[y]][ansl - 1];
ans += sum[dep[y]][ansr] - cnt;
y = dp[y][0];
}
if(tx > 0) {
while(tx) {
tx--;
id = nex[id];
}
}
y = nex[id];
tmpk = k - 1;
int maxx, id;
while(y && tmpk) {
tmpk--;
l = 0, r = pos[y];
ansl = -1;
ansr = -1;
while(l <= r) {
mid = (l + r) >> 1;
if(judge(v[dep[y]][mid], x, k)) {
ansl = mid;
r = mid - 1;
} else {
l = mid + 1;
}
}
l = pos[y], r = lenv[dep[y]] - 1;
while(l <= r) {
mid = (l + r) >> 1;
if(judge(v[dep[y]][mid], x, k)) {
ansr = mid;
l = mid + 1;
} else {
r = mid - 1;
}
}
if(ansl == -1 || ansr == -1) break;
if(ansl == 0) cnt = 0;
else cnt = sum[dep[y]][ansl - 1];
ans += sum[dep[y]][ansr] - cnt;
y = nex[y];
}
printf("%lld\n", ans);
}
return 0;
}
/*
10
1 1 1 1 1 1 1 1 1 1
1 2
1 3
2 4
3 5
3 6
3 7
4 8
8 10
6 9
5
7
1 1 1 1 1 1 1
1 2
1 3
2 4
2 5
3 6
3 7
5
1 100
1 0
2 1
4 100
4 2
10
1 1 1 1 1 1 1 1 1 1
1 2
1 3
2 4
2 5
3 6
3 7
1 8
9 10
8 9
5
6
4 5 2 1 2 1
1 4
1 3
5 3
6 2
3 6
5
*/