题意:给一棵树,两个参数 k , L k,L k,L,需要选择 k k k 个连通块,使得这 k k k 个连通块存在一个公共点,且该公共点到 k k k 个连通块内的任意一点的距离不超过 L L L,求方案数 模 998244353 998244353 998244353。两种方案不同当且仅当连通块的集合不同。
n ≤ 1 0 6 , k ≤ 10 n\leq 10^6,k\leq 10 n≤106,k≤10
已经写绝望了
对于一种连通块的集合,合法的钦定的点一定是一个连通块。所以可以通过 点 数 − 边 数 = 1 点数-边数=1 点数−边数=1 来容斥。即考虑每个点的贡献,再减去对于每条边,两个端点都合法的方案。
然后考虑暴力 dp。
设 f ( u , L ) f(u,L) f(u,L) 为以 u u u 为根的子树内,包含 u u u,到的距离不超过 L L L 的连通块个数 + 1 +1 +1(为了方便转移,也可理解为允许为空)。
g ( u , L ) g(u,L) g(u,L) 表示 u u u 往上走,即必须包含 u u u,且不能包含 u u u 子树内其他结点,到 u u u 的距离不超过 L L L 的连通块个数。(注意不 + 1 +1 +1,即不能为空。)
得到转移
f ( u , L ) = ∏ v ∈ s o n ( u ) f ( v , L − 1 ) + 1 f(u,L)=\prod_{v\in son(u)}f(v,L-1)+1 f(u,L)=v∈son(u)∏f(v,L−1)+1
边界 f ( u , 0 ) = 1 f(u,0)=1 f(u,0)=1
g ( u , L ) = g ( f a u , L − 1 ) ∏ v ∈ s o n ( f a u ) , v ≠ u f ( v , L − 2 ) + 1 g(u,L)=g(fa_u,L-1)\prod_{v\in son(fa_u),v\neq u}f(v,L-2)+1 g(u,L)=g(fau,L−1)v∈son(fau),v=u∏f(v,L−2)+1
边界 f ( u , 0 ) = f ( u , − 1 ) = 1 f(u,0)=f(u,-1)=1 f(u,0)=f(u,−1)=1。后面这个 + 1 +1 +1 表示 { u } \{u\} {u} 这个连通块。
最终答案为
∑ u = 1 n ( f ( u , L ) − 1 ) k g ( u , L ) k − [ u ≠ r t ] ( f ( u , L − 1 ) − 1 ) k ( g ( u , L ) − 1 ) k \sum_{u=1}^n(f(u,L)-1)^kg(u,L)^k-[u\neq rt](f(u,L-1)-1)^k(g(u,L)-1)^k u=1∑n(f(u,L)−1)kg(u,L)k−[u=rt](f(u,L−1)−1)k(g(u,L)−1)k
发现状态和深度有关,考虑长链剖分。以下设 m x u mx_u mxu 表示 u u u 到子树内最远点经过的 点数,简称深度。
f ( u , L ) = ∏ v ∈ s o n ( u ) f ( v , L − 1 ) + 1 f(u,L)=\prod_{v\in son(u)}f(v,L-1)+1 f(u,L)=v∈son(u)∏f(v,L−1)+1
这个是经典的长链剖分的形式,直接继承长儿子的信息,短儿子暴力转移。
然后状态定义的是不超过,所以你维护的只是 DP 数组里 [ 0 , m x u ) [0,mx_u) [0,mxu) 的信息, [ m x u , + ∞ ) [mx_u,+\infin) [mxu,+∞) 也是有值的。如果暴力到长儿子的深度会让复杂度退化。
不过注意到 [ m x u , + ∞ ) [mx_u,+\infin) [mxu,+∞) 内的值都是 f ( u , m x u − 1 ) f(u,mx_u-1) f(u,mxu−1),所以相当于是个后缀乘法。然后 DP 式子后面还有个 +1 ,相当于要维护以下操作:
- 单点修改,要求 O ( 1 ) O(1) O(1)
- 全局加,要求 O ( 1 ) O(1) O(1)
- [ x , + ∞ ) [x,+\infin) [x,+∞) 乘,要求 O ( x ) O(x) O(x)
这可以通过打全局标记来实现。具体来讲,我们对当前点 u u u 维护两个标记 m u l u , p l s u mul_u,pls_u mulu,plsu,表示存储的一个数 x x x 表示的真实值为 m u l u x + p l s u mul_u x+pls_u mulux+plsu。
2 操作直接改标记,3操作修改 m u l u mul_u mulu 后把 [ 0 , x ) [0,x) [0,x) 乘上逆元,1 操作改完后倒着把存储的值算出来放进去,就可以做到 O ( n ) O(n) O(n)。
你以为这就完了?奶义务!
乘上的这个数可能在模意义下为 0 0 0,是没有逆元的,并且不像一年后的某道莫反矩阵树缝合怪题,这个东西非常好构造,直接连长度分别为 2 , 2 , ⋯ , 2 ⏟ 23 , 6 , 16 \begin{matrix} \underbrace{ 2,2,\cdots,2 } \\ 23\end{matrix},6,16 2,2,⋯,223,6,16 的链就可以了。
所以我为什么没在 CSP 前看到这个东西
所以我们需要再开两个标记 l i m u , v a l u lim_u,val_u limu,valu,表示 [ l i m u , + ∞ ) [lim_u,+\infin) [limu,+∞) 这一段的存储的值是 v a l u val_u valu。如果这个数是 0 0 0,相当于后缀赋值,把 l i m u lim_u limu 赋值为 x x x, v a l u val_u valu 赋值为真实值为 0 0 0 时对应的存储值。
Q:为什么不能定义为" l i m u lim_u limu 及之后的数都是 0 0 0",还可以少开个标记?
A:因为这里只是暂时为 0 0 0,之后的全局加对这里是有影响的。
这样做到了 O ( n log P ) O(n\log P) O(nlogP)。注意到每次求逆元的都是 f ( v , m x v − 1 ) f(v,mx_v-1) f(v,mxv−1) ,即不限制距离的方案数,所以可以先做一个简单的 DP 算出来,然后 O ( n ) O(n) O(n) 离线求逆元,注意要跳过为 0 0 0 的。维护 m u l u mul_u mulu 标记的时候顺便维护一下它的逆元,就可以 O ( n ) O(n) O(n) 了。
g ( u , L ) = g ( f a u , L − 1 ) ∏ v ∈ s o n ( f a u ) , v ≠ u f ( v , L − 2 ) + 1 g(u,L)=g(fa_u,L-1)\prod_{v\in son(fa_u),v\neq u}f(v,L-2)+1 g(u,L)=g(fau,L−1)v∈son(fau),v=u∏f(v,L−2)+1
大家可能会觉得很奇怪,这个往上走的 DP 怎么能用长链剖分优化呢?
注意到我们答案需要的只有 g ( u , L ) g(u,L) g(u,L),所以对于一个叶子结点,它没有儿子需要它的其他信息,所以只需要维护 g ( u , L ) g(u,L) g(u,L) 这一个位置。类似的,对于点 u u u ,我们只需要维护 [ L − m x u + 1 , L ] [L-mx_u+1,L] [L−mxu+1,L] 中的值。
也就是说我们规定 g ( u , … ) g(u,\dots) g(u,…) 的定义域只有 [ max ( L − m x u + 1 , 0 ) , L ] [\max(L-mx_u+1,0),L] [max(L−mxu+1,0),L],这样状态数就和深度正相关了。
把信息直接继承给长儿子,短儿子暴力转移,再乘上一个 f ( u , L − 1 ) − 1 f ( v , L − 2 ) \frac{f(u,L-1)-1}{f(v,L-2)} f(v,L−2)f(u,L−1)−1
然后你又错了,因为 f ( v , L − 2 ) f(v,L-2) f(v,L−2) 可能没有逆元。所以我们只能算前缀积和后缀积了。
前缀积在遍历的时候可以顺便维护。为了方便实现,可以把每个结点的轻儿子按深度从小到大排序,这样你只需要记 3 3 3 个标记。严格意义上需要桶排保证复杂度,不过直接 sort 也能过。之后假装这个排序是 O ( n ) O(n) O(n) 的。
然后开一个数组 p r e pre pre ,用 p r e i pre_i prei 记录 f ( v , i ) f(v,i) f(v,i) 的前缀积就可以了,配合后缀赋值标记就可以维护整个前缀积。
对于后缀积是不能跑一遍记下来的,因为开不下……
但我们在计算 f f f 的时候做了一遍这东西,怎么能浪费了呢?
我们在计算 f f f 的时候倒着做,即按轻儿子深度从大到小排序。对于 d p dp dp 值和 5 5 5 个标记的修改,把它修改的过程记录下来,对就是可撤销并查集的那个东西。
然后在算 g g g 的时候不断把修改撤销,这样 f ( u , L − 1 ) f(u,L-1) f(u,L−1) 维护的就是后缀积。
这样只能算出轻儿子,重儿子因为撤回不了,所以需要再利用之前你算的前缀积单独搞一下。
因为还是有乘法和全局加操作,所以你还是得维护一堆标记。并且尽管定义域很有限,为了保证复杂度,你还是得维护后缀赋值标记。注意这个标记和前缀积的标记没有关系。
需要注意的细节:
- 因为有边界情况,需要手动把 f ( u , 0 ) f(u,0) f(u,0) 改成 1 1 1。注意因为定义不同, f f f 需要先改 0 0 0 再全局加,而 g g g 是全局加了再改 0 0 0。并且 g g g 还要判断 0 0 0 在不在定义域内。
- 边界情况 l i m u lim_u limu 需要维护准确值,或者用其他一些骚操作,不然之前的值会出问题。
- 因为定义域的问题,不能偷懒把前缀积挂在 g ( u , L − 1 ) g(u,L-1) g(u,L−1) 上。
- 回退时的 f ( u , L − 1 ) f(u,L-1) f(u,L−1) 实际上维护的是 f ( v , L − 2 ) f(v,L-2) f(v,L−2) 的后缀积,在 L = 1 L=1 L=1 的时候是未定义的,需要特判。
复杂度 O ( n log k ) O(n\log k) O(nlogk)
用尽各种毒瘤方法把一个不可做的计数题做到线性,最后却因为一个 10 10 10 的快速幂无法把复杂度写成 O ( n ) O(n) O(n),真是悲壮……
代码中的迷惑部分都有注释。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#include <utility>
#include <list>
#include <algorithm>
#define MAXN 1000005
using namespace std;
inline int read()
{
int ans=0;
char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
const int MOD=998244353;
typedef long long ll;
inline int add(const int& x,const int& y){return x+y>=MOD? x+y-MOD:x+y;}
inline int dec(const int& x,const int& y){return x<y? x-y+MOD:x-y;}
inline int qpow(int a,int p)
{
int ans=1;
while (p)
{
if (p&1) ans=(ll)ans*a%MOD;
a=(ll)a*a%MOD,p>>=1;
}
return ans;
}
vector<int> T[MAXN],e[MAXN];//T 是所有相邻的点,e 是所有轻儿子
int fa[MAXN],son[MAXN],mx[MAXN],s[MAXN],sinv[MAXN],n,L,k;
void dfs(int u,int f)
{
fa[u]=f,s[u]=1;
for (int i=0;i<(int)T[u].size();i++)
if (T[u][i]!=f)
{
dfs(T[u][i],u);
if (mx[T[u][i]]>mx[son[u]]) son[u]=T[u][i];
s[u]=(ll)s[u]*s[T[u][i]]%MOD;
}
mx[u]=mx[son[u]]+1;
s[u]=add(s[u],1);
}
int fac[MAXN],finv[MAXN];
inline bool cmp(const int& x,const int& y){return mx[x]>mx[y];}
inline void init()
{
fac[0]=1;
for (int i=1;i<=n;i++)
if (s[i]) fac[i]=(ll)fac[i-1]*s[i]%MOD;
else fac[i]=fac[i-1];
finv[n]=qpow(fac[n],MOD-2);
for (int i=n-1;i>=1;i--)
if (s[i+1]) finv[i]=(ll)finv[i+1]*s[i+1]%MOD;//跳过 0,后同
else finv[i]=finv[i+1];
for (int i=1;i<=n;i++) if (s[i]) sinv[i]=(ll)finv[i]*fac[i-1]%MOD;
for (int i=1;i<=n;i++) stable_sort(e[i].begin(),e[i].end(),cmp);//stable 是为了方便调试
}
void dfs(int u)
{
if (son[u]) dfs(son[u]);
for (int i=0;i<(int)T[u].size();i++)
if (T[u][i]!=fa[u]&&T[u][i]!=son[u])
e[u].push_back(T[u][i]),dfs(T[u][i]);
}
int F1[MAXN],F2[MAXN],G1[MAXN];
struct BackDS
{
typedef pair<int*,int> pi;
list<pi> his;
inline void modify(int& x,int v){his.push_back(make_pair(&x,x)),x=v;}
inline void undo(){while (!his.empty()) *his.back().first=his.back().second,his.pop_back();}
}q[MAXN];
namespace F
{
int buf[MAXN],*cur=buf;
int* dp[MAXN];
inline int* newbuf(int x){int* p=cur;cur+=x;return p;}
int mul[MAXN],inv[MAXN],pls[MAXN],lim[MAXN],val[MAXN];
inline int calc(int u,int i)//计算真实值
{
if (i<lim[u]) return ((ll)mul[u]*dp[u][i]+pls[u])%MOD;
else return ((ll)mul[u]*val[u]+pls[u])%MOD;
}
inline int clac(int u,int v){return (ll)dec(v,pls[u])*inv[u]%MOD;}//根据真实值的得到应该存储的值
void dfs(int u)
{
if (son[u])
{
dp[son[u]]=dp[u]+1;
dfs(son[u]);
mul[u]=mul[son[u]],inv[u]=inv[son[u]],pls[u]=pls[son[u]];
lim[u]=lim[son[u]]+1,val[u]=val[son[u]];
dp[u][0]=clac(u,1);
}
else
{
mul[u]=inv[u]=lim[u]=1,pls[u]=F1[u]=F2[u]=2;
return;
}
int las=0;
for (int k=0;k<(int)e[u].size();k++)
{
int v=las=e[u][k];
dp[v]=newbuf(mx[v]),dfs(v);
for (int i=1;i<=mx[v];i++)
{
if (i==lim[u]) q[v].modify(dp[u][i],val[u]), q[v].modify(lim[u],lim[u]+1);
q[v].modify(dp[u][i],clac(u,(ll)calc(u,i)*calc(v,i-1)%MOD));
}
if (s[v])
{
q[v].modify(mul[u],(ll)mul[u]*s[v]%MOD);
q[v].modify(inv[u],(ll)inv[u]*sinv[v]%MOD);
q[v].modify(pls[u],(ll)pls[u]*s[v]%MOD);
for (int i=0;i<=mx[v];i++) q[v].modify(dp[u][i],clac(u,(ll)calc(u,i)*sinv[v]%MOD));
}
else q[v].modify(lim[u],mx[v]+1),q[v].modify(val[u],clac(u,0));
}
if (las) q[las].modify(pls[u],add(pls[u],1));//把全局加挂在最后一个轻儿子上,这样一来就会撤回
else pls[u]=add(pls[u],1);//没有轻儿子的话反正都没有用,随便加
F1[u]=calc(u,L),F2[u]=calc(u,L-1);
}
inline void solve(){dp[1]=newbuf(mx[1]),dfs(1);}
}
namespace G
{
int buf[MAXN],pre[MAXN],*cur=buf;
int* dp[MAXN];
inline int* newbuf(int x){int* p=cur;cur+=x;return p;}
int mul[MAXN],inv[MAXN],pls[MAXN],lim[MAXN],val[MAXN];
inline int calc(int u,int i)
{
if (i<lim[u]) return ((ll)mul[u]*dp[u][i]+pls[u])%MOD;
return ((ll)mul[u]*val[u]+pls[u])%MOD;
}
inline int clac(int u,int v){return (ll)dec(v,pls[u])*inv[u]%MOD;}
void dfs(int u)
{
G1[u]=calc(u,L);
pre[0]=1;
int pos=1,cur=1,cinv=1;
for (int k=(int)e[u].size()-1;k>=0;k--)//按深度从小到达枚举
{
int v=e[u][k];
q[v].undo();
dp[v]=newbuf(mx[v])-max(0,L-mx[v]+1);
mul[v]=inv[v]=1,lim[v]=L+1;
for (int i=max(0,L-mx[v]+1);i<=L;i++)
{
int t=1;
if (i) t=(ll)t*calc(u,i-1)%MOD;
if (i>1)
{
t=(ll)t*F::calc(u,i-1)%MOD;//见细节4
if (i-2<pos) t=(ll)t*pre[i-2]%MOD;
else t=(ll)t*cur%MOD;
}
dp[v][i]=clac(v,t);
}
pls[v]=add(pls[v],1);
if (L-mx[v]+1<=0) dp[v][0]=clac(v,1);//是否在定义域内
for (int i=0;i<=mx[v];i++)
{
if (i<pos) pre[i]=(ll)pre[i]*F::calc(v,i)%MOD;
else pre[i]=(ll)cur*F::calc(v,i)%MOD;
}
pos=mx[v]+1;
cur=(ll)cur*s[v]%MOD,cinv=(ll)cinv*sinv[v]%MOD;
}
int v=son[u];
if (v)
{
mul[v]=mul[u],inv[v]=inv[u],pls[v]=pls[u],lim[v]=lim[u]+1,val[v]=val[u];
dp[v]=dp[u]-1;
for (int i=max(2,L-mx[v]+1);i<=pos+1;i++)
{
if (i==lim[v]) dp[v][lim[v]++]=val[v];
dp[v][i]=clac(v,(ll)calc(v,i)*pre[i-2]%MOD);
}
if (cur)
{
mul[v]=(ll)mul[v]*cur%MOD;
pls[v]=(ll)pls[v]*cur%MOD;
inv[v]=(ll)inv[v]*cinv%MOD;
for (int i=max(0,L-mx[v]+1);i<=pos+1;i++) dp[v][i]=clac(v,(ll)calc(v,i)*cinv%MOD);
}
else lim[v]=pos+1,val[v]=clac(v,0);
pls[v]=add(pls[v],1);
if (L-mx[v]+1<=0) dp[v][0]=clac(v,1);
dfs(v);
}
for (int i=0;i<(int)e[u].size();i++) dfs(e[u][i]);//算完再递归,避免 pre 冲突
}
inline void solve(){dp[1]=newbuf(mx[1])-max(L-mx[1]+1,0),mul[1]=inv[1]=pls[1]=1,lim[1]=L+1,dfs(1);}
}
int main()
{
n=read(),L=read(),k=read();
if (!L) return printf("%d\n",n),0;
for (int i=1;i<n;i++)
{
int u,v;
u=read(),v=read();
T[u].push_back(v),T[v].push_back(u);
}
dfs(1,0),dfs(1);
init();
F::solve(), G::solve();
int ans=0;
for (int i=1;i<=n;i++)
{
ans=add(ans,qpow((ll)dec(F1[i],1)*G1[i]%MOD,k));
if (i>1) ans=dec(ans,qpow((ll)dec(F2[i],1)*dec(G1[i],1)%MOD,k));
}
cout<<ans;
return 0;
}