设置块的大小为sz,当更新的层的节点数小于sz时直接暴力更新,大于sz时给这一行打上更新标记,查询的时候对于一个点x实际上就是x的子树中小于sz的层的贡献+x的子树中大于sz的层的贡献
小于sz的时候直接用树状数组+dfs序就可以维护
大于sz的时候实际上就是求子树中每层的节点个数*这一层增加的值,然后对这些层求和,每层增加的值直接用一个数组就可以维护,每层节点的个数我采用了一种log级的算法,首先按dfs序顺序用vector保存每层节点的dfs序,若要求x的子树某一层的节点个数,实际上就是拿着x子树的起始dfs序s和终止dfs序t在这一层的所有节点中二分查找,即upperbound(t) - lowerbound(s)
复杂度为sz*logn+(n/sz)*logn(粗略计算),当sz取sqrt(n)时最优
细节的话看代码吧
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MAXN = 1e5 + 10;
int N, Q, sz;
int mem, head[MAXN];
struct Edge { int to, next; }edges[MAXN<<1];
void init_edges() {
mem = 0;
for (int i = 0; i <= N; i++) head[i] = -1;
}
void addedge(int from, int to) {
edges[mem].to = to, edges[mem].next = head[from], head[from] = mem++;
}
int s[MAXN], t[MAXN], dfs_clock, maxh;
vector<int> node[MAXN];
vector<int> big;
void dfs(int fa, int u, int depth) {
maxh = max(maxh, depth);
s[u] = ++dfs_clock;
node[depth].push_back(dfs_clock);
for (int i = head[u]; ~i; i = edges[i].next) {
int v = edges[i].to;
if (v == fa) continue;
dfs(u, v, depth+1);
}
t[u] = dfs_clock;
}
LL C[MAXN];
int lowbit(int x) { return x & (-x); }
LL sum(int x) {
LL ans = 0;
while (x > 0) { ans += C[x]; x -= lowbit(x); }
return ans;
}
void add(int x, int d) {
while (x <= N) { C[x] += d; x += lowbit(x); }
}
LL update[MAXN];
void init() {
sz = sqrt(N);
init_edges();
dfs_clock = 0;
maxh = 0;
for (int i = 0; i <= N; i++) node[i].clear();
big.clear();
for (int i = 0; i <= N; i++) C[i] = 0;
for (int i = 0; i <= N; i++) update[i] = 0;
}
int main() {
scanf("%d%d", &N, &Q);
init();
for (int i = 1; i <= N-1; i++) {
int u, v; scanf("%d%d", &u, &v);
addedge(u, v); addedge(v, u);
}
dfs(-1, 1, 0);
for (int i = 0; i <= maxh; i++) {
if (node[i].size() > sz) big.push_back(i);
}
while (Q--) {
int op;
scanf("%d", &op);
if (op == 1) {
int x, y;
scanf("%d%d", &x, &y);
if (node[x].size() <= sz) {
for (int i = 0; i < node[x].size(); i++) add(node[x][i], y);
}
else update[x] += y;
}
else {
int x;
scanf("%d", &x);
LL ans = sum(t[x]) - sum(s[x]-1);
for (int i = 0; i < big.size(); i++) {
int idx = big[i];
ans += (upper_bound(node[idx].begin(), node[idx].end(), t[x]) - lower_bound(node[idx].begin(), node[idx].end(), s[x]))*update[idx];
}
printf("%lld\n", ans);
}
}
return 0;
}