题意
给定一颗有根带权树,记\(LCA(u,v)\)为\(u,v\)的最近公共祖先,\(dis(u)\)表示树根到\(u\)的距离
每个节点可以是黑色或白色,初始结点颜色为白色
有\(m\)次操作,操作分为两种
将结点\(x\)染成黑色
记所有黑点形成的集合为\(S\)与一个节点\(x\),求出下面式子的值
\[ \sum_{y \in S}F(dis(LCA(u,v))) \]
其中函数\(F\)定义为
\[ F(x)=\sum_{i=1}^x i^k \]
解法
朴素的暴力是求取\(x\)与\(S\)中点的\(LCA\)并进行计算
我们考虑换一种角度,考虑\(LCA\)的贡献
\(x\)与\(S\)中点的\(LCA\)一定位于由\(x\)到树根的这一条路径上
对于每个修改操作,我们把由\(x\)到根的这条路径都打上标记(即加一)
每次查询就查询\(x\)到根上的路径的标记之和即可
为了保证复杂度,我们用树链剖分来处理
由于在修改时,我们将\(x\)到根的一整条路径上的标记都加了一,但对于点\(x\)来说,真正有意义的只是\(LCA(x,S_i)\)上的那一个标记,所以由根到\(fa[LCA(x,S_i)]\)上的标记实际上是不合法的
为了结局这个问题,我们可以把每个节点的权值设为\(F(dis(x))-F(dis(fa[x]))\)
这样求出来的和就只会计算到合法标记的贡献
为什么这样做是对的呢?这实际上是一个差分的操作
把差分数组的\([l,r]\)区间均加上\(1\),在求前缀和意义下实际上就是在\(a_l\)处加上了\(1\)
还要注意\(F\)函数要用线性筛进行预处理
代码
#include <cstdio>
using namespace std;
const int N = 2e5 + 10;
const int mod = 998244353;
int read();
int n, m, k;
const int MAX_N = 1e7 + 10;
int pri[MAX_N], is[MAX_N], func[MAX_N];
int qpow(int x, int y) {
int res = 1;
for (; y; y >>= 1, x = 1LL * x * x % mod)
if (y & 1) res = 1LL * res * x % mod;
return res;
}
void sieve() {
int cnt = 0; func[1] = 1;
for (int i = 2; i < MAX_N; ++i) {
if (!is[i]) pri[++cnt] = i, func[i] = qpow(i, k);
for (int j = 1; j <= cnt; ++j) {
if (i * pri[j] >= MAX_N) break;
is[i * pri[j]] = 1;
func[i * pri[j]] = 1LL * func[i] * func[pri[j]] % mod;
if (i % pri[j] == 0) break;
}
}
for (int i = 1; i < MAX_N; ++i)
func[i] = (func[i] + func[i - 1]) % mod;
}
int cap;
int head[N], to[N], nxt[N], val[N];
inline void add(int x, int y, int z) {
to[++cap] = y, nxt[cap] = head[x], head[x] = cap, val[cap] = z;
}
int ind;
int sz[N], dep[N], fa[N];
int top[N], son[N], ver[N], id[N];
void DFS(int x) {
sz[x] = 1;
for (int i = head[x]; i; i = nxt[i]) {
dep[to[i]] = dep[x] + val[i], fa[to[i]] = x, DFS(to[i]), sz[x] += sz[to[i]];
if (sz[to[i]] > sz[son[x]]) son[x] = to[i];
}
// printf("data: %d %d\n", x, son[x]);
}
void DFS(int x, int tp) {
top[x] = tp, id[x] = ++ind, ver[ind] = func[dep[x]] - func[dep[fa[x]]];
if (son[x]) {
DFS(son[x], tp);
for (int i = head[x]; i; i = nxt[i])
if (!id[to[i]]) DFS(to[i], to[i]);
}
}
struct SegTree {
#define ls(x) x << 1
#define rs(x) x << 1 | 1
struct node {
int val, sum, tag;
node() : val(0), tag(0), sum(0) {}
} t[N << 2];
void build(int x, int l, int r) {
if (l == r)
return t[x].val = ver[l], void();
int mid = l + r >> 1;
build(ls(x), l, mid);
build(rs(x), mid + 1, r);
t[x].val = (t[ls(x)].val + t[rs(x)].val) % mod;
}
void addtag(int x, int v) {
t[x].sum = (t[x].sum + 1LL * v * t[x].val % mod + mod) % mod;
t[x].tag = (t[x].tag + v) % mod;
}
void pushdown(int x) {
if (t[x].tag) {
addtag(ls(x), t[x].tag);
addtag(rs(x), t[x].tag);
t[x].tag = 0;
}
}
void modify(int x, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr)
return addtag(x, 1), void();
int mid = l + r >> 1;
pushdown(x);
if (ql <= mid)
modify(ls(x), l, mid, ql, qr);
if (qr > mid)
modify(rs(x), mid + 1, r, ql, qr);
t[x].sum = (t[ls(x)].sum + t[rs(x)].sum) % mod;
}
int query(int x, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr)
return t[x].sum;
int mid = l + r >> 1, res = 0;
pushdown(x);
if (ql <= mid)
res = (res + query(ls(x), l, mid, ql, qr)) % mod;
if (qr > mid)
res = (res + query(rs(x), mid + 1, r, ql, qr)) % mod;
return res;
}
#undef ls
#undef rs
} tr;
void change(int x) {
while (x) {
tr.modify(1, 1, n, id[top[x]], id[x]);
x = fa[top[x]];
}
}
int ask(int x) {
int res = 0;
while (x) {
res = (res + tr.query(1, 1, n, id[top[x]], id[x])) % mod;
x = fa[top[x]];
}
return res;
}
int vis[N];
int main() {
n = read(), m = read(), k = read();
for (int i = 2; i <= n; ++i) {
int u = read(), v = read();
add(u, i, v);
}
sieve();
DFS(1);
DFS(1, 1);
tr.build(1, 1, n);
while (m--) {
int op = read(), x = read();
if (op == 1) {
if (vis[x]) continue;
vis[x] = 1;
change(x);
} else
printf("%d\n", ask(x));
}
return 0;
}
int read() {
int x = 0, c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
return x;
}