题目描述
master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的kk 次方和,而且每次的kk 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?
输入格式
第一行包含一个正整数\(n\),表示树的节点数。
之后\(n-1\) 行每行两个空格隔开的正整数\(i, j\),表示树上的一条连接点\(i\)和点\(j\)的边。
之后一行一个正整数\(m\)表示询问的数量。
之后每行三个空格隔开的正整数\(i, j, k\),表示询问从点ii 到点jj 的路径上所有节点深度的\(k\) 次方和。由于这个结果可能非常大,输出其对\(998244353\) 取模的结果。
树的节点从\(1\) 开始标号,其中\(1\)号节点为树的根。
输出格式
对于每组数据输出一行一个正整数表示取模后的结果。
输入输出样例
输入 #1复制
5
1 2
1 3
2 4
2 5
2
1 4 5
5 4 45
输出 #1复制
33
503245989
说明/提示
样例解释
以下用\(d (i)\) 表示第ii 个节点的深度。
对于样例中的树,有\(d (1) = 0, d (2) = 1, d (3) = 1, d (4) = 2, d (5) = 2\)
因此第一个询问答案为\((2^5 + 1^5 + 0^5)\ mod\ 998244353\),第二个询问答案为\((2^{45} + 1^{45} + 2^{45})\ mod\ 998244353 = 503245989\)
数据范围
对于\(30\%\) 的数据,\(1 \leq n,m \leq 100\)
对于\(60\%\) 的数据,\(1 \leq n,m \leq 1000\)
对于\(100\%\) 的数据,\(1 \leq n,m \leq 300000, 1 \leq k \leq 50\)
另外存在5个不计分的hack数据
提示
数据规模较大,请注意使用较快速的输入输出方式。
敲完树剖求lca华丽走人
我们可以发现,lca的情况无非就是三种
1.\(lca==a\)
2.\(lca==b\)
3.\(lca\)在\(a\)和\(b\)的上面
1,2情况直接暴力跳就行,
3.情况分别从\(a\)向\(lca\)和从\(b\)向\(lca\)跳,然后我们发现\(lca\)算了两次,然后再减去一次\(lca\)的贡献就行
#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
const int M=400100;
const int N=400100;
int ne[M],head[M],ver[M],idx;
int dep[N],fa[N],son[N],sz[N],top[N];
long long ans;
int n,m;
inline void add(int u,int v)
{
ne[idx]=head[u];
ver[idx]=v;
head[u]=idx;
idx++;
}
inline void dfs1(int u,int father,int depth)
{
fa[u]=father;
sz[u]=1;
dep[u]=depth;
for(int i=head[u]; i!=-1; i=ne[i])
{
int j=ver[i];
if(j==father)continue;
dfs1(j,u,depth+1);
sz[u]+=sz[j];
if(sz[son[u]]<sz[j]) son[u]=j;
}
}
inline void dfs2(int u,int t)
{
top[u]=t;
if(!son[u]) return ;
dfs2(son[u],t);
for(int i=head[u]; i!=-1; i=ne[i])
{
int j=ver[i];
if(j==fa[u]||j==son[u])continue;
dfs2(j,j);
}
}
inline int lca(int u,int v)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]])
swap(u,v);
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
return v;
}
inline int qmi(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans=(long long)ans*a%mod;
a=(long long)a*a%mod;
b>>=1;
}
return ans;
}
inline int read()
{
int x=0;
int f=1;
char ch;
ch=getchar();
while(ch>'9'||ch<'0')
{
if(ch=='-')f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=x*10,x=x+ch-'0';
ch=getchar();
}
return x*f;
}
int main()
{
memset(head,-1,sizeof(head));
n=read();
for(register int i=1; i<n; i++)
{
int u,v;
u=read();
v=read();
add(u,v);
add(v,u);
}
dfs1(1,0,0);
dfs2(1,1);
m=read();
for(register int i=1; i<=m; i++)
{
int a,b,k;
ans=0;
a=read();
b=read();
k=read();
int LCA=lca(a,b);
if(LCA==a)
{
for(register int j=dep[a]; j<=dep[b]; j++)
{
ans=(ans+qmi(j,k)+mod)%mod;
}
}
else if(LCA==b)
{
for(register int j=dep[b]; j<=dep[a]; j++)
{
ans=(ans+qmi(j,k)+mod)%mod;
}
}
else
{
for(register int j=dep[LCA]; j<=dep[b]; j++)
{
ans=(ans+qmi(j,k)%mod+mod)%mod;
}
for(register int j=dep[LCA]; j<=dep[a]; j++)
{
ans=(ans+qmi(j,k)%mod+mod)%mod;
}
ans=(ans-qmi(dep[LCA],k)%mod+mod)%mod;
}
printf("%lld\n",ans);
}
return 0;
}