master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的 k 次方和,而且每次的 k 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。 他把这个问题交给了 pupil,但 pupil 并不会这么复杂的操作,你能帮他解决吗?
输入
第一行包含一个正整数 n,表示树的节点数。
之后 n−1行每行两个空格隔开的正整数 i,j,表示树上的一条连接点 i 和点 j的边。
之后一行一个正整数 m,表示询问的数量。
之后每行三个空格隔开的正整数 i,j,k,表示询问从点 i 到点 j 的路径上所有节点深度的 k 次方和。由于这个结果可能非常大,输出其对 998244353 取模的结果。
树的节点从 1 开始标号,其中 1 号节点为树的根。
输出
对于每组数据输出一行一个正整数表示取模后的结果。
样例输入 [复制]
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
样例输出 [复制]
33
503245989
预处理1-n的1-50次方前缀和。最后用两端的值减去lca的值即可
#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read()
{
int data=0;int w=1; char ch=0;
ch=getchar();
while(ch!='-' && (ch<'0' || ch>'9')) ch=getchar();
if(ch=='-') w=-1,ch=getchar();
while(ch>='0' && ch<='9') data=(data<<3)+(data<<1)+ch-'0',ch=getchar();
return data*w;
}
const int N=511010;
const int mod=998244353;
int sum[51][N];
struct node{
int v,nxt;
}e[N<<1];
int fir[N],cnt=0;
inline void add(int u,int v){ e[++cnt]=(node){v,fir[u]};fir[u]=cnt;}
inline int quickpow(int a,int b){
int c=1;
while(b){
if(b&1) c=c*a%mod;
a=a*a%mod;
b=b>>1;
}
return c;
}
int n,m;
int dep[N],f[N][25];
inline void dfs(int u,int fa){
dep[u]=dep[fa]+1;
for(int i=0;i<=19;i++){
f[u][i+1]=f[f[u][i]][i];
}
for(int i=fir[u];i;i=e[i].nxt){
int v=e[i].v;
if(v==fa) continue;
f[v][0]=u;
dfs(v,u);
}
}
inline int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int i=20;i>=0;i--){
if(dep[f[x][i]]>=dep[y]) x=f[x][i];
if(x==y) return x;
}
for(int i=20;i>=0;i--){
if(f[x][i]!=f[y][i]){
x=f[x][i];y=f[y][i];
}
}
return f[x][0];
}
signed main(){
// int size=100<<20;//40M
//__asm__ ("movl %0, %%esp\n"::"r"((char*)malloc(size)+size));//调试用这个
// __asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));//提交用这个
//main函数代码
n=read();
for(int i=1;i<=50;i++){
for(int j=1;j<=n;j++){
sum[i][j]=(sum[i][j-1]+quickpow(j,i))%mod;
}
}
for(int i=1;i<n;i++){
int u=read(),v=read();
add(u,v);add(v,u);
}
dep[0]=-1;
dfs(1,0);
m=read();
for(int i=1;i<=m;i++){
int x=read(),y=read(),k=read();
int L=lca(x,y);
int ans=(sum[k][dep[x]]+sum[k][dep[y]]-sum[k][dep[L]]-sum[k][max(dep[f[L][0]],0ll)]+mod+mod)%mod;
printf("%lld\n",ans);
}
exit(0);//必须用exit
return 0;
}