题目描述
master对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的k次方和,而且每次的k可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil并不会这么复杂的操作,你能帮他解决吗?
输入
第一行包含一个正整数n,表示树的节点数。
之后n−1行每行两个空格隔开的正整数i,j,表示树上的一条连接点i和点j的边。
之后一行一个正整数m,表示询问的数量。
之后每行三个空格隔开的正整数i,j,k,表示询问从点i到点j的路径上所有节点深度的k次方和。由于这个结果可能非常大,输出其对998244353取模的结果。
树的节点从1开始标号,其中1号节点为树的根。
输出
对于每组数据输出一行一个正整数表示取模后的结果。
样例输入
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
样例输出
33
503245989
提示
以下用d(i)表示第i个节点的深度。
对于样例中的树,有d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2。
因此第一个询问答案为(2^5+1^5+0^5) mod 998244353=33,第二个询问答案为(2^45+1^45+2^45) mod 998244353=503245989。
对于30%的数据,1≤n,m≤100;
对于60%的数据,1≤n,m≤1000;
对于100%的数据,1≤n,m≤300000,1≤k≤50。
Solution
最近公共祖先(LCA)模板题,当然还是祭出最可爱的太监算法啦!
注意3e5个询问,还带50幂次求和,挨个儿去算很容易T,预处理打个表就好啦
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int MX = 300005;
const int MOD = 998244353;
int n, x, y, Q, ans[MX], f[MX], deep[MX], que[MX], h, t;
int bz[MX], head[MX], tot, headQ[MX], totQ;
int biao[300001][51];
bool vis[MX];
struct node {
int v, w, nxt;
} e[MX << 1], q[MX << 1];
struct data {
int u, v, w;
} query[MX];
int Find(int x) {
return f[x] == x ? x : (f[x] = Find(f[x]));
}
void addEdge(int u, int v) {
tot++;
e[tot].v = v;
e[tot].nxt = head[u];
head[u] = tot;
tot++;
e[tot].v = u;
e[tot].nxt = head[v];
head[v] = tot;
}
void addEdge2(int u, int v, int w) {
totQ++;
q[totQ].v = v;
q[totQ].w = w;
q[totQ].nxt = headQ[u];
headQ[u] = totQ;
totQ++;
q[totQ].v = u;
q[totQ].w = w;
q[totQ].nxt = headQ[v];
headQ[v] = totQ;
}
void work(int x) {
bz[x] = 1;
for (int k = head[x]; k; k = e[k].nxt)
if (!bz[e[k].v]) {
work(e[k].v), f[e[k].v] = x;
}
for (int k = headQ[x]; k; k = q[k].nxt)
if (bz[q[k].v])
ans[q[k].w] = Find(q[k].v);
}
int Ans(int I) {
int a = deep[ans[I]];
int b = deep[query[I].u];
int c = deep[query[I].v];
int w = query[I].w;
int res;
if (b > c) swap(b, c);
if (a==0) res = biao[b][w];
else res=(biao[b][w]-biao[a-1][w]+MOD)%MOD;
res=(res+biao[c][w])%MOD;
res=(res-biao[a][w]+MOD)%MOD;
return res;
}
int main() {
//freopen("../in", "r", stdin);
for (int i = 1; i <= 300000; ++i) {
biao[i][0] = 1;
for (int j = 1; j <= 50; ++j)
biao[i][j] = 1ll * biao[i][j - 1] * i % MOD;
}
for (int i=1;i<=50;++i)
for (int j=1;j<=300000;++j)
biao[j][i] = (biao[j - 1][i] + biao[j][i]) % MOD;
scanf("%d", &n);
for (int i = 1; i < n; ++i) {
scanf("%d%d", &x, &y);
addEdge(x, y);
}
deep[1] = 0;
h = t = -1;
que[++t] = 1;
vis[1] = true;
while (h < t) {
x = que[++h];
for (int i = head[x]; i; i = e[i].nxt) {
y = e[i].v;
if (vis[y]) continue;
vis[y] = true;
deep[y] = deep[x] + 1;
que[++t] = y;
}
}
//for (int i=1;i<=n;++i) printf("%d %d\n",i,deep[i]);
scanf("%d", &Q);
for (int i = 0; i < Q; ++i) {
scanf("%d%d%d", &x, &y, &t);
addEdge2(x, y, i);
query[i].u = x;
query[i].v = y;
query[i].w = t;
}
for (int i = 1; i <= n; ++i) f[i] = i;
work(1);
for (int i = 0; i < Q; ++i) printf("%d\n", Ans(i));
}