倍增,处理一下前缀,代码如下
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cmath>
#include<vector>
using namespace std;
const int M=998244353;
int n,m;
vector<int> G[300005];
int pa[20][300005];
int dis[300005];
long long sum[55][300005];
void dfs(int v,int p,int d)
{
pa[0][v]=p;
dis[v]=d;
if(p==0) {
for(int i=1;i<=50;i++) {
sum[i][1]=0;
}
}
else{
int temp=1;
for(int i=1;i<=50;i++) {
temp=1ll*temp*d%M;
sum[i][v]=(sum[i][p]+temp)%M;
}
}
for(int i=0;i<G[v].size();i++) {
if(G[v][i]!=p) {
dfs(G[v][i],v,d+1);
}
}
}
void init()
{
dfs(1,0,0);
for(int k=0;k<20;k++) {
for(int v=1;v<=n;v++) {
if(pa[k][v]==0) pa[k+1][v]=0;
else pa[k+1][v]=pa[k][pa[k][v]];
}
}
}
int lca(int u,int v)
{
if(dis[u]>dis[v]) swap(u,v);
for(int k=0;k<20;k++) {
if((dis[v]-dis[u])>>k&1) {
v=pa[k][v];
}
}
if(u==v) {
return u;
}
for(int k=19;k>=0;k--) {
if(pa[k][u]!=pa[k][v]) {
u=pa[k][u];
v=pa[k][v];
}
}
return pa[0][u];
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++) {
int u,v;
scanf("%d%d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
init();
scanf("%d",&m);
while(m--) {
int u,v,k;
scanf("%d%d%d",&u,&v,&k);
int lc=lca(u,v);
long long ans=(((sum[k][u]+sum[k][v])%M-sum[k][lc]+M)%M-sum[k][pa[0][lc]]+M)%M;
printf("%lld\n",ans);
}
}