题目
描述
master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的k 次方和,而且每次的k 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?
输入
第一行包含一个正整数
n
n
n,表示树的节点数。
之后 n − 1 n-1 n−1行每行两个空格隔开的正整数 i , j i, j i,j表示树上的一条连接点 i i i 和点 j j j 的边。
之后一行一个正整数 m m m,表示询问的数量。
之后每行三个空格隔开的正整数 i , j , k i, j, k i,j,k表示询问从点i 到点j 的路径上所有节点深度的 k k k 次方和。由于这个结果可能非常大,输出其对 998244353 998244353 998244353 取模的结果。
树的节点从
1
1
1开始标号,其中
1
1
1号节点为树的根。
输出
对于每组数据输出一行一个正整数表示取模后的结果
样例输入
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
样例输出
33
503245989
范围
对于
100
%
100\%
100%的数据,
1
≤
n
,
m
≤
300000
,
1
≤
k
≤
50
1 \leq n,m \leq 300000, 1 \leq k \leq 50
1≤n,m≤300000,1≤k≤50。
思路
我们要求
i
i
i和
j
j
j之间所有点的深度的
k
k
k次方和,那一定会经过它们的
L
C
A
LCA
LCA。那就可以转换为两部分,一是
i
i
i点到
L
C
A
LCA
LCA之间所有点深度的
k
k
k次方和,二是
j
j
j点到
L
C
A
LCA
LCA之间所有点深度的
k
k
k次方和。而
i
i
i点到
L
C
A
LCA
LCA之间所有点深度的
k
k
k次方和就可以转换为i到根节点的所有深度的
k
k
k次方和减去
L
C
A
LCA
LCA到根节点所有深度的
k
k
k次方和,
j
j
j点也相同
因为每条路径上每种深度是唯一的,因此你暴力把每种深度的
1
1
1~
50
50
50次方都打出来,在搞个前缀和就行了。
注意,
L
C
A
LCA
LCA的深度被减光了,因此还要加上一个
L
C
A
LCA
LCA深度的
k
k
k次方
心得
若果你在洛谷上一直40分,且是后面三个点和第三个点过,其他WA,我建议你重打,且在遍历树的时候就把前缀和搞出来(反正我是这样过的)
代码
#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;
#define mod 998244353
#define M 300005
#define LL long long
vector<int>G[M];
int n, m, s, e, k;
int f[M][22], dep[M];
LL p[M][55], d[M][55];
bool v[M];
inline int maxx(int x, int y){
return x > y? x: y;
}
inline LL qkp(LL x, LL y){
LL sum = 1;
while( y ){
if( y&1 )
sum = sum*x%mod;
x = x*x%mod;
y >>= 1;
}
return sum;
}
inline void dfs(int x, int last, int depth){
v[x] = 1;
f[x][0] = last;
dep[x] = depth;
for(int i = 0; i < 51; i ++){
d[depth][i] = qkp(depth, i);
p[depth][i] = (p[maxx(0,depth-1)][i]+d[depth][i])%mod;
}
for(int i = 1; i < 19; i ++)
f[x][i] = f[f[x][i-1]][i-1];
int siz = G[x].size();
for(int i = 0; i < siz; i ++){
int son = G[x][i];
if( !v[son] )
dfs(son, x, depth+1);
}
}
inline int LCA(int x, int y){
for(int i = 18; i >= 0; i --){
if( dep[f[x][i]] >= dep[y] )
x = f[x][i];
if( dep[f[y][i]] >= dep[x] )
y = f[y][i];
}
if( x != y ){
for(int i = 18; i >= 0; i --){
if( f[x][i] != f[y][i] )
x = f[x][i], y = f[y][i];
}
x = f[x][0];
}
return x;
}
int main(){
scanf("%d", &n);
for(int i = 1; i < n; i ++){
scanf("%d%d", &e ,&s);
G[e].push_back(s);
G[s].push_back(e);
}
dfs(1, 0, 0);
scanf("%d", &m);
while( m-- ){
scanf("%d%d%d", &s, &e, &k);
int lca = LCA(s, e);
printf("%lld\n", (((p[dep[s]][k]-p[dep[lca]][k]+mod)%mod + (p[dep[e]][k]-p[dep[lca]][k]+mod)%mod)%mod + d[dep[lca]][k])%mod );
}
return 0;
}