对于一组$s_{1\cdots k}$,合法的$u$构成一个连通块,满足$\left\lvert V\right\rvert-\left\lvert E\right\rvert=1$
考虑算出算$f_{x,i}$表示$x$子树内与$x$距离$\leq i$的点构成含$x$连通块的方案数,类似定义$g$表示子树外
那么答案就是$\sum\limits_{i=1}^n\left(f_{i,L}g_{i,L}\right)^k-\sum\limits_{fa_y=x}\left(f_{y,L-1}\left(g_{y,L}-1\right)\right)^k$
$f_{x,i}=\prod\limits_{fa_y=x}\left(f_{y,i-1}+1\right)$
$g_{y,i}=1+g_{x,i-1}\prod\limits_{\substack{fa_z=y\\y\neq z}}\left(f_{z,i-2}+1\right)$
自底向上转移$f$,自顶向下转移$g$,我们得到一个$O\left(nL\right)$的暴力
设$md_x$表示以$x$为根的子树的最大深度,那么对于$i\geq md_x$有$f_{x,i}=f_{x,md_x}$,只需求出$g_{x,0\cdots md_x}$
对于$g$,因为最后我们需要$g_{x,L}$,所以只需求出$g_{x,L-md_x\cdots L}$
两部分要求的数量都是$md_x$,考虑用长链剖分优化
对于$f$和$g$的重边转移,都是重边$O(1)$,轻边暴力转移
对于$g$的轻边转移,我们需要$f$的前后缀信息
首先转移时只有后缀乘和全局加,维护一个全局$ax+b$的标记,后缀乘$v$时先把前缀乘上$v^{-1}$,然后再全局乘,对于$v=0$,再维护一个后缀赋值标记即可,显然任意时刻只会存在一个赋值标记
对于前后缀信息,因为我们在转移$f$时本质就是在求前缀,所以把整个数据结构可回退化即可
在求$g$时逆序访问所有节点(相对于求$f$时的顺序),回退的同时维护后缀信息即可
一个小问题是求逆,但要求逆的只有$f_{x,md_x}$,所以一开始$O(n)$DP求出所有$f_{x,md_x}$再$O\left(n+\log p\right)$求出所有数的逆就可以把总时间复杂度降到线性了
好像并没有想象中的那么毒瘤...
#include<stdio.h>
#include<string.h>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long ll;
typedef vector<int> vi;
const int mod=998244353,inf=2147483647;
int mul(int a,int b){return(ll)a*b%mod;}
int ad(int a,int b){return(a+=b)>=mod?a-mod:a;}
void inc(int&a,int b){(a+=b)>=mod?a-=mod:0;}
int pow(int a,int b){
int s=1;
while(b){
if(b&1)s=mul(s,a);
a=mul(a,a);
b>>=1;
}
return s;
}
const int N=1e6+10;
int n,L;
int h[N],nex[N*2],to[N*2],M;
void add(int a,int b){
M++;
to[M]=b;
nex[M]=h[a];
h[a]=M;
}
int md[N],son[N];
vi e[N];
int fL[N];
void dfs(int fa,int x){
int i,k,mx;
k=0;
mx=0;
fL[x]=1;
for(i=h[x];i;i=nex[i]){
if(to[i]!=fa){
dfs(x,to[i]);
fL[x]=mul(fL[x],fL[to[i]]+1);
if(md[to[i]]+1>mx){
mx=md[to[i]]+1;
k=to[i];
}
}
}
son[x]=k;
md[x]=mx;
for(i=h[x];i;i=nex[i]){
if(to[i]!=fa&&to[i]!=son[x])e[x].push_back(to[i]);
}
}
int pos[N];
void dfs(int x){
pos[x]=++M;
if(son[x])dfs(son[x]);
for(int y:e[x])dfs(y);
}
struct op{
int t,*p,v;
}st[N*12];
int tp;
bool flag;
void sgn(int&x,int v){
if(flag)st[++tp]={M,&x,x};
x=v;
}
void back(){
*st[tp].p=st[tp].v;
tp--;
}
struct arr{
int*f,n,a,ia,b,ti,tv;
int get(int x){
return ad(mul(a,x<ti?f[min(x,n)]:tv),b);
}
void mult(int x,int v){
if(x<0||x>n)return;
sgn(f[x],mul(mul(v,mul(a,f[x])+b)+mod-b,ia));
}
void set(int x){
if(x<=n){
sgn(ti,x);
sgn(tv,mul(mod-b,ia));
}
}
void set(int x,int v){
if(x<0||x>n)return;
f[x]=mul(v+mod-b,ia);
}
}f[N],g[N],tmp;
int pf[N];
int rf[N];
int f1[N],f2[N];
int tm[N];
void dfs1(int x){
if(son[x]){
dfs1(son[x]);
f[x]=f[son[x]];
f[x].f--;
if(f[x].ti!=inf)f[x].ti++;
f[x].n++;
}else
f[x]={pf+pos[x],0,1,1,0,inf,0};
inc(f[x].b,1);
f[x].set(0,1);
int i,t;
for(int y:e[x])dfs1(y);
for(int y:e[x]){
tm[y]=++M;
for(i=0;i<=md[y];i++)f[x].mult(i+1,f[y].get(i)+1);
t=ad(fL[y],1);
if(t){
for(i=0;i<=md[y]+1;i++)f[x].mult(i,rf[y]);
f[x].a=mul(f[x].a,t);
f[x].ia=mul(f[x].ia,rf[y]);
f[x].b=mul(f[x].b,t);
}else
f[x].set(md[y]+2);
}
f1[x]=f[x].get(L);
f2[x]=f[x].get(L-1);
}
int pg[N],pt[N];
int gL[N];
void dfs2(int x){
int i,j,n,t;
reverse(e[x].begin(),e[x].end());
n=0;
for(int y:e[x])n=max(n,md[y]);
memset(pt,0,(n+1)<<2);
tmp={pt,n,1,1,1,inf,0};
for(int y:e[x]){
t=ad(fL[y],1);
if(t){
f[x].a=mul(f[x].a,rf[y]);
f[x].ia=mul(f[x].ia,t);
f[x].b=mul(f[x].b,rf[y]);
}
while(st[tp].t==tm[y])back();
g[y]={pg+pos[y],md[y],1,1,0,inf,0};
for(i=0;i<=md[y];i++){
j=i-md[y]+L;
if(j<2){
if(j==1)g[y].f[i]=1;
continue;
}
g[y].f[i]=mul(g[x].get(j-1+md[x]-L),mul(f[x].get(j-1),tmp.get(j-2)));
}
inc(g[y].b,1);
for(i=0;i<=md[y];i++)tmp.mult(i,f[y].get(i)+1);
if(t){
for(i=0;i<=md[y];i++)tmp.mult(i,rf[y]);
tmp.a=mul(tmp.a,t);
tmp.ia=mul(tmp.ia,rf[y]);
tmp.b=mul(tmp.b,t);
}else
tmp.set(md[y]+1);
}
gL[x]=g[x].get(md[x]);
if(son[x]){
g[son[x]]=g[x];
g[son[x]].n--;
for(int y:e[x]){
for(i=0;i<=md[y];i++)g[son[x]].mult(i+2+md[son[x]]-L,f[y].get(i)+1);
j=md[y]+2+md[son[x]]-L;
if(j<md[son[x]]){
t=ad(fL[y],1);
if(t){
for(i=max(0,md[son[x]]-L);i<=j;i++)g[son[x]].mult(i,rf[y]);
g[son[x]].a=mul(g[son[x]].a,t);
g[son[x]].ia=mul(g[son[x]].ia,rf[y]);
g[son[x]].b=mul(g[son[x]].b,t);
}else
g[son[x]].set(j+1);
}
}
inc(g[son[x]].b,1);
g[son[x]].set(md[son[x]]-L,1);
g[son[x]].set(md[son[x]]-L+1,2);
}
for(int y:e[x])dfs2(y);
if(son[x])dfs2(son[x]);
}
int a[N],sa[N];
int main(){
int k,i,x,y,res;
scanf("%d%d%d",&n,&L,&k);
if(!L){
printf("%d",n);
return 0;
}
for(i=1;i<n;i++){
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs(0,1);
sa[0]=1;
for(i=1;i<=n;i++){
a[i]=ad(fL[i],1);
if(!a[i])a[i]=1;
sa[i]=mul(sa[i-1],a[i]);
}
rf[n]=pow(sa[n],mod-2);
for(i=n;i>0;i--)rf[i-1]=mul(rf[i],a[i]);
for(i=1;i<=n;i++)rf[i]=mul(rf[i],sa[i-1]);
M=0;
dfs(1);
M=0;
flag=1;
dfs1(1);
g[1]={pg+1,md[1],1,1,1,inf,0};
flag=0;
dfs2(1);
res=0;
for(i=1;i<=n;i++)inc(res,pow(mul(f1[i],gL[i]),k));
for(i=2;i<=n;i++)inc(res,mod-pow(mul(f2[i],gL[i]+mod-1),k));
printf("%d\n",res);
}