【题目链接】
【思路要点】
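- 分 $b$ 是否是 $a$ 的祖先讨论：若 $b$ 是 $a$ 的祖先，贡献为 $(size_a-1)\cdot\min(dep_a-1,k)$；若 $a$ 是 $b$ 的祖先，贡献为 $a$ 子树内满足 $dep_a < dep_b \le dep_a+k$ 的所有 $b$ 的 $(size_b-1)$ 之和，这一部分需要以 DFS 序为版本、以深度为下标的主席树来维护答案。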
- 时间复杂度 $O(N\log N+Q\log N)$。
【代码】
#include<bits/stdc++.h>
using namespace std;
#define MAXN 300005
#define MAXP 8000005
/// Reads one integer from stdin into x.
/// Characters before the first digit are skipped; every '-' seen while
/// skipping flips the sign. The character right after the number is consumed.
/// At end of input the function returns instead of looping forever; if no
/// digit is found, x is left as 0.
template <typename T> void read(T &x) {
    x = 0;
    int f = 1;
    // Keep the character as int: getchar() returns EOF (a negative int) at end
    // of input, and isdigit() is only defined for EOF or unsigned-char values.
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar())
        if (c == '-') f = -f;
    for (; c != EOF && isdigit(c); c = getchar())
        x = x * 10 + (c - '0');
    x *= f;
}
/// Persistent segment tree of sums over positions 1..n.
/// extend(v, pos, delta) creates version v from version v-1 by path copying,
/// so query(l, r, ql, qr) returns the total delta added by versions l+1..r
/// at positions ql..qr.
/// Index 0 acts as the empty node (lc = rc = 0, sum = 0); this relies on a[]
/// being zero-initialized, which holds for the global instance PST below.
struct PersistentSegmentTree {
    struct Node {
        int lc, rc;     // child node indices (0 = empty subtree)
        long long sum;  // total delta inside this node's position range
    } a[MAXP];
    int n, size, root[MAXN];  // position count, nodes allocated so far, root of each version

    // Allocates a node for every segment of [l, r] that does not have one yet.
    void build(int &root, int l, int r) {
        if (root == 0) root = ++size;
        if (l == r) return;
        int mid = (l + r) / 2;
        build(a[root].lc, l, mid);
        build(a[root].rc, mid + 1, r);
    }

    // Sets the position range to [1, x] and allocates the skeleton of version 0.
    // NOTE(review): a[] is not cleared here, so this is meant to be called once.
    void init(int x) {
        n = x;
        size = 0;
        build(root[0], 1, n);
    }

    // Copies the root-to-leaf path of `root` for position `pos`, adding `delta`
    // to every copied node; returns the index of the new root.
    int modify(int root, int l, int r, int pos, int delta) {
        int ans = ++size;
        a[ans] = a[root];
        a[ans].sum += delta;
        if (l == r) return ans;
        int mid = (l + r) / 2;
        if (mid >= pos) a[ans].lc = modify(a[root].lc, l, mid, pos, delta);
        else a[ans].rc = modify(a[root].rc, mid + 1, r, pos, delta);
        return ans;
    }

    // Version `version` = version `version - 1` plus `delta` at position `pos`.
    void extend(int version, int pos, int delta) {
        root[version] = modify(root[version - 1], 1, n, pos, delta);
    }

    // Sum over positions [ql, qr] of (tree rootr) - (tree rootl).
    // Precondition: l <= ql <= qr <= r.
    long long query(int rootl, int rootr, int l, int r, int ql, int qr) {
        if (l == ql && r == qr) return a[rootr].sum - a[rootl].sum;
        long long ans = 0;
        int mid = (l + r) / 2;
        if (mid >= ql) ans += query(a[rootl].lc, a[rootr].lc, l, mid, ql, min(mid, qr));
        if (mid + 1 <= qr) ans += query(a[rootl].rc, a[rootr].rc, mid + 1, r, max(mid + 1, ql), qr);
        return ans;
    }

    // Total delta added by versions l+1..r at positions [ql, qr].
    // The position range is clamped to [1, n], and an empty range yields 0.
    // The recursive overload above would otherwise recurse forever, e.g. for
    // ql = 1, qr = 0.
    long long query(int l, int r, int ql, int qr) {
        ql = std::max(ql, 1);
        qr = std::min(qr, n);
        if (ql > qr) return 0;
        return query(root[l], root[r], 1, n, ql, qr);
    }
} PST;
// Tree / DFS state shared by dfs() and main().
int timer;           // last DFS order index handed out
int dfn[MAXN];       // dfn[v]: DFS entry index of vertex v
int rit[MAXN];       // rit[v]: largest DFS index inside v's subtree
int home[MAXN];      // home[t]: vertex whose DFS index is t
int n, q;            // vertex count, query count
// size[v]: number of vertices in v's subtree.
// NOTE: with `using namespace std` under C++17 an unqualified `size[...]`
// is ambiguous with std::size; refer to this array as ::size.
int size[MAXN];
int depth[MAXN];     // depth[v]: depth of v; depth[0] = 0 acts as the root's parent
vector <int> a[MAXN];  // a[v]: adjacency list of vertex v
void dfs(int pos, int fa) {
size[pos] = 1;
depth[pos] = depth[fa] + 1;
dfn[pos] = ++timer;
home[timer] = pos;
for (unsigned i = 0; i < a[pos].size(); i++)
if (a[pos][i] != fa) {
dfs(a[pos][i], pos);
size[pos] += size[a[pos][i]];
}
rit[pos] = timer;
}
int main() {
read(n), read(q);
for (int i = 1; i <= n - 1; i++) {
int x, y;
read(x), read(y);
a[x].push_back(y);
a[y].push_back(x);
}
dfs(1, 0);
PST.init(n);
for (int i = 1; i <= n; i++)
PST.extend(i, depth[home[i]], size[home[i]] - 1);
for (int i = 1; i <= q; i++) {
int pos, k;
read(pos), read(k);
long long ans = PST.query(dfn[pos], rit[pos], min(depth[pos] + 1, n), min(depth[pos] + k, n));
ans += (size[pos] - 1ll) * min(depth[pos] - 1, k);
printf("%lld\n", ans);
}
return 0;
}