Description
master 对树上的求和非常感兴趣。
他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的k 次方和,而且每次的k 可能是不同的。
此处节点深度的定义是这个节点到根的路径上的边数。
他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?
Input
第一行包含一个正整数n ,表示树的节点数。
之后n-1 行每行两个空格隔开的正整数i,j ,表示树上的一条连接点i 和点j 的边。
之后一行一个正整数m ,表示询问的数量。
之后每行三个空格隔开的正整数i,j,k ,表示询问从点i 到点j 的路径上所有节点深度的k 次方和。
由于这个结果可能非常大,输出其对998244353 取模的结果。
树的节点从1 开始标号,其中1 号节点为树的根。
Output
对于每组数据输出一行一个正整数表示取模后的结果。
1≤n,m≤300000,1≤k≤50
Sample Input
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
Sample Output
33
503245989
说明
样例解释
以下用d(i) 表示第i 个节点的深度。
对于样例中的树,有d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2。
因此第一个询问答案为(2^5 + 1^5 + 0^5) mod 998244353 = 33
第二个询问答案为(2^45 + 1^45 + 2^45) mod 998244353 = 503245989。
503245989
说明
样例解释
以下用d(i) 表示第i 个节点的深度。
对于样例中的树,有d(1)=0,d(2)=1,d(3)=1,d(4)=2,d(5)=2。
因此第一个询问答案为(2^5 + 1^5 + 0^5) mod 998244353 = 33
第二个询问答案为(2^45 + 1^45 + 2^45) mod 998244353 = 503245989。
题解Here!
暴力就是暴力往上跳父亲,跳到一个节点统计答案。。。
这也太暴力了吧。。。
我们发现树的形态没有变,也就是每个节点的深度没有变!
那么我们直接做一个树上前缀和,每次询问时差分一下就好了。
并且$k<=50$,非常好,$O(nk)$并不会$TLE$。。。
于是直接暴力处理就好。
这个前缀和长这个样:
for(int i=1;i<=50;i++)val[rt][i]=(val[fa[rt]][i]+mexp(deep[rt]-1,i))%MOD;
每次询问就这样:
long long ans=((val[x][k]-val[lca][k]+MOD)%MOD+(val[y][k]-val[fa[lca]][k]+MOD)%MOD)%MOD;
记得取模会有负数。。。
总复杂度的话就是$O(nk\log_2k+m\log_2n)$。
话说$BJOI$怎么$T1$这么简单,送分题啊。。。
跟$HNOI/AHOI$完全不能比啊。。。
附代码:
#include<iostream>
#include<algorithm>
#include<cstdio>
#define MAXN 300010
#define MOD 998244353LL
using namespace std;
int n,m,c=1;
int head[MAXN],deep[MAXN],son[MAXN],size[MAXN],fa[MAXN],top[MAXN];
long long val[MAXN][55];
struct Tree{
int next,to;
}a[MAXN<<1];
inline int read(){
int date=0,w=1;char c=0;
while(c<'0'||c>'9'){if(c=='-')w=-1;c=getchar();}
while(c>='0'&&c<='9'){date=date*10+c-'0';c=getchar();}
return date*w;
}
long long mexp(long long a,long long b){
long long s=1;
while(b){
if(b&1)s=s*a%MOD;
a=a*a%MOD;
b>>=1;
}
return s%MOD;
}
inline void add(int x,int y){
a[c].to=y;a[c].next=head[x];head[x]=c++;
a[c].to=x;a[c].next=head[y];head[y]=c++;
}
void dfs1(int rt){
son[rt]=0;size[rt]=1;
for(int i=1;i<=50;i++)val[rt][i]=(val[fa[rt]][i]+mexp(deep[rt]-1,i))%MOD;
for(int i=head[rt];i;i=a[i].next){
int will=a[i].to;
if(!deep[will]){
deep[will]=deep[rt]+1;
fa[will]=rt;
dfs1(will);
size[rt]+=size[will];
if(size[will]>size[son[rt]])son[rt]=will;
}
}
}
void dfs2(int rt,int f){
top[rt]=f;
if(son[rt])dfs2(son[rt],f);
for(int i=head[rt];i;i=a[i].next){
int will=a[i].to;
if(will!=fa[rt]&&will!=son[rt])dfs2(will,will);
}
}
int LCA(int x,int y){
while(top[x]!=top[y]){
if(deep[top[x]]<deep[top[y]])swap(x,y);
x=fa[top[x]];
}
if(deep[x]>deep[y])swap(x,y);
return x;
}
void work(){
int x,y,k;
while(m--){
x=read();y=read();k=read();
int lca=LCA(x,y);
long long ans=((val[x][k]-val[lca][k]+MOD)%MOD+(val[y][k]-val[fa[lca]][k]+MOD)%MOD)%MOD;
printf("%lld\n",ans);
}
}
void init(){
int x,y;
n=read();
for(int i=1;i<n;i++){
x=read();y=read();
add(x,y);
}
m=read();
deep[1]=1;
dfs1(1);
dfs2(1,1);
}
int main(){
init();
work();
return 0;
}