Query on the subtree
题意:给你一颗有点权边权均为1的树,每次操作要么修改某点点权,要么查询距离u点不超过d距离的所有点权的和
解法:我们根据点分治找出所有的重心并且相连,用树状数组或动态开点线段树维护每个重心的子树的权值信息,每次查询或者修改顺重心树的链往上走依次查询或修改遇到的所有点的子树信息即可。
#include<bits/stdc++.h>
#define low(x) x&-x
using namespace std;
const int maxn = 1e5 + 10, inf = 1e9;
int d[maxn * 2][20], dep[maxn], id[maxn], w[maxn], cnt;
vector<int> G[maxn];
struct bit {
int n;
vector<int> c;
void init(int m) {
this->n = m;
c.clear();
for (int i = 0; i <= n; i++)
c.push_back(0);
}
void up(int x, int v) {
if (!x)
c[x] += v;
for (; x <= n && x; x += low(x))
c[x] += v;
}
int qu(int x) {
if (x > n)
x = n;
int res = c[0];
for (; x; x -= low(x))
res += c[x];
return res;
}
} T[maxn * 2];
int n, size, rt, cat, sz[maxn], f[maxn], vis[maxn];
void init() {
for (int i = 1; i <= n; i++)
G[i].clear(), vis[i] = f[i] = 0;
cnt = 0;
}
void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
d[++cnt][0] = dep[u];
id[u] = cnt;
for (auto v : G[u])
if (v != fa) {
dfs(v, u);
d[++cnt][0] = dep[u];
}
}
void rmq_init() {
for (int i = 1; i < 18; i++)
for (int j = 1; j + (1 << i) - 1 <= cnt; j++)
d[j][i] = min(d[j][i - 1], d[j + (1 << i - 1)][i - 1]);
}
int LCA(int x, int y) {
int l = id[x], r = id[y];
if (l > r)
swap(l, r);
int k = log2(r - l + 1);
return min(d[l][k], d[r + 1 - (1 << k)][k]);
}
int dis(int x, int y) {
return dep[x] + dep[y] - 2 * LCA(x, y);
}
void findrt(int u, int fa) {
sz[u] = 1;
int mx = 0;
for (auto v : G[u])
if (v != fa && !vis[v]) {
findrt(v, u);
sz[u] += sz[v];
mx = max(mx, sz[v]);
}
mx = max(mx, size - sz[u]);
if (mx < cat)
cat = mx, rt = u;
}
void divide(int u, int fa) {
f[u] = fa;
vis[u] = 1;
T[u].init(size);
T[u + n].init(size);
int tmp = size;
for (auto v : G[u])
if (v != fa && !vis[v]) {
size = (sz[v] > sz[u]) ? tmp - sz[u] : sz[v];
cat = inf;
findrt(v, 0);
divide(rt, u);
}
}
void up(int u, int v, int val) {
if (!u)
return;
int dist = dis(u, v);
T[u].up(dist, val);
if (f[u]) {
int dist2 = dis(f[u], v);
T[u + n].up(dist2, val);
}
up(f[u], v, val);
}
int gao(int u, int v, int d) {
int cat = 0, dist = dis(u, v);
if (dist <= d)
cat += T[u].qu(d - dist);
if (f[u]) {
dist = dis(f[u], v);
if (dist <= d)
cat -= T[u + n].qu(d - dist);
cat += gao(f[u], v, d);
}
return cat;
}
int main() {
int q, u, v;
char c;
while (~scanf("%d%d", &n, &q)) {
init();
for (int i = 1; i <= n; i++)
scanf("%d", &w[i]);
for (int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1, 0);
rmq_init();
size = n;
cat = inf;
findrt(1, 0);
divide(rt, 0);
for (int i = 1; i <= n; i++)
up(i, i, w[i]);
while (q--) {
getchar();
scanf("%c%d%d", &c, &u, &v);
if (c == '?')
printf("%d\n", gao(u, u, v));
else {
up(u, u, v - w[u]);
w[u] = v;
}
}
}
}