Description
给出一棵树,定义两个点的距离为其路径上的边数。
问有多少种选择k个有序连通子图的方案,满足存在一个点u,其在所有子图内,且子图中的所有点到其的距离<=L
n<=10^6
Solution
考虑枚举点u,问题变成有多少个连通子图到u的距离<=L且包含u
但是这样会算重所以我们要减去u和u父亲都被包含
设F[x][i]表示以x为根的子树内,到x距离<=i的连通块个数
那么我们有
F
[
x
]
[
i
]
=
∏
y
是
x
儿
子
(
F
[
y
]
[
i
−
1
]
+
1
)
F[x][i]=\prod_{y是x儿子}(F[y][i-1]+1)
F[x][i]=∏y是x儿子(F[y][i−1]+1)
再设G[x][i]表示排除x为根的子树,到x距离<=i的连通块个数+1
我们可以推出
G
[
y
]
[
i
]
=
G
[
x
]
[
i
−
1
]
∗
∏
z
!
=
y
,
z
是
x
儿
子
(
F
[
z
]
[
i
−
2
]
+
1
)
G[y][i]=G[x][i-1]*\prod_{z!=y,z是x儿子}(F[z][i-2]+1)
G[y][i]=G[x][i−1]∗∏z!=y,z是x儿子(F[z][i−2]+1)
那么答案就是
∑
i
=
1
n
(
F
[
i
]
[
L
]
∗
G
[
i
]
[
L
]
)
k
−
∑
i
=
2
n
(
F
[
i
]
[
L
−
1
]
∗
(
G
[
i
]
[
L
]
−
1
)
)
k
\sum_{i=1}^{n}(F[i][L]*G[i][L])^k-\sum_{i=2}^{n}(F[i][L-1]*(G[i][L]-1))^k
∑i=1n(F[i][L]∗G[i][L])k−∑i=2n(F[i][L−1]∗(G[i][L]−1))k
因为到i和i父亲的距离都要<=L那么到i的距离就是<=L-1
这个东西的状态数只和深度有关,考虑长链剖分
F很好算,只需要区间乘区间加,用线段树即可
考虑如何算G,我们需要一个O(∑轻儿子深度)的做法
可以分重链和轻链的转移,重链直接算
轻链的话注意到我们的答案只和G[x][L]有关,考虑只保留和转移有关的G[x][L]
对于一个轻子树,如果其往下的深度为d,那么可以转移出去的范围只有[L-d,L]这个区间,于是我们只维护这个范围内的Dp值,同样也是可以直接做
因为算G的时候要用到F所以线段树要可持久化
总复杂度O(n log n),可以获得80分
满分留坑代填(咕咕咕
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define rep(i,a) for(int i=lst[a];i;i=nxt[i])
using namespace std;
typedef long long ll;
int read() {
char ch;
for(ch=getchar();ch<'0'||ch>'9';ch=getchar());
int x=ch-'0';
for(ch=getchar();ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x;
}
const int N=1e6+5,M=N<<6,Mo=998244353;
int pwr(int x,int y) {
int z=1;
for(;y;y>>=1,x=(ll)x*x%Mo)
if (y&1) z=(ll)z*x%Mo;
return z;
}
int t[N<<1],nxt[N<<1],lst[N],l;
void link(int x,int y) {t[++l]=y;nxt[l]=lst[x];lst[x]=l;}
int n,L,k,son[N],dep[N],mx[N];
namespace Segment_Tree{
int rt_f[N],rt_g[N],ls[M],rs[M],ad[M],ml[M],bel[M],now,tot;
int newnode(int v) {
if (now==bel[v]) return v;
bel[++tot]=now;
ad[tot]=ad[v];ml[tot]=ml[v];
ls[tot]=ls[v];rs[tot]=rs[v];
return tot;
}
void add(int &v,int z) {v=newnode(v);(ad[v]+=z)%=Mo;}
void mult(int &v,int z) {v=newnode(v);ml[v]=(ll)ml[v]*z%Mo;ad[v]=(ll)ad[v]*z%Mo;}
void down(int v) {
if (ml[v]!=1) {
mult(ls[v],ml[v]);mult(rs[v],ml[v]);
ml[v]=1;
}
if (ad[v]) {
add(ls[v],ad[v]);add(rs[v],ad[v]);
ad[v]=0;
}
}
void Mult(int &v,int l,int r,int x,int y,int z) {
if (x>y) return;
v=newnode(v);
if (x<=l&&r<=y) {mult(v,z);return;}
int mid=l+r>>1;down(v);
if (x<=mid) Mult(ls[v],l,mid,x,y,z);
if (y>mid) Mult(rs[v],mid+1,r,x,y,z);
}
void Add(int &v,int l,int r,int x,int y,int z) {
if (x>y) return;
v=newnode(v);
if (x<=l&&r<=y) {add(v,z);return;}
int mid=l+r>>1;down(v);
if (x<=mid) Add(ls[v],l,mid,x,y,z);
if (y>mid) Add(rs[v],mid+1,r,x,y,z);
}
int Query(int v,int l,int r,int x) {
if (!v) return 0;
if (l==r) return ad[v];
int mid=l+r>>1,tmp=(x<=mid)?Query(ls[v],l,mid,x):Query(rs[v],mid+1,r,x);
return ((ll)tmp*ml[v]%Mo+ad[v])%Mo;
}
};
using namespace Segment_Tree;
void dfs(int x,int y) {
dep[x]=dep[y]+1;int k=-1;
rep(i,x)
if (t[i]!=y) {
dfs(t[i],x);
if (mx[t[i]]>k) k=mx[t[i]],son[x]=t[i];
}
mx[x]=(son[x])?mx[son[x]]+1:0;
}
void dfs_f(int x,int y) {
rep(i,x) if (t[i]!=y) dfs_f(t[i],x);
rt_f[x]=rt_f[son[x]];now=x;
Add(rt_f[x],1,n<<1,dep[x],dep[x]+mx[x],1);
rep(i,x)
if (t[i]!=y&&t[i]!=son[x]) {
fo(j,0,mx[t[i]]) {
int tmp=Query(rt_f[t[i]],1,n<<1,j+dep[t[i]])+1;
Mult(rt_f[x],1,n<<1,j+1+dep[x],j+1+dep[x],tmp);
}
int tmp=Query(rt_f[t[i]],1,n<<1,mx[t[i]]+dep[t[i]])+1;
Mult(rt_f[x],1,n<<1,mx[t[i]]+1+dep[x]+1,mx[x]+dep[x],tmp);
}
}
int a[N];
void solve(int x) {
int rt=0,Mx=-1;now=-1;
Add(rt,1,n<<1,1,n<<1,1);
fo(i,1,a[0]) {
if (a[i]==son[x]) {
now=-1;
int tr=rt_f[a[i]];
Add(tr,1,n<<1,dep[a[i]],dep[a[i]]+mx[x],1);
fo(j,0,Mx){
now=a[i];
int tmp=Query(rt,1,n<<1,j+dep[a[i]]);
Mult(rt_g[a[i]],1,n<<1,j+2+n+1-dep[a[i]],j+2+n+1-dep[a[i]],tmp);
now=-1;
Mult(tr,1,n<<1,j+dep[a[i]],j+dep[a[i]],tmp);
}
now=a[i];
int tmp=Query(rt,1,n<<1,Mx+dep[a[i]]);
Mult(rt_g[a[i]],1,n<<1,Mx+3+n+1-dep[a[i]],(n<<1)-dep[a[i]],tmp);
now=-1;
Mult(tr,1,n<<1,Mx+dep[a[i]]+1,mx[x]+dep[a[i]],tmp);
rt=tr;
Mx=max(Mx,mx[a[i]]);
} else {
now=a[i];
fo(j,max(0,L-mx[a[i]]-2),L-2) {
int tmp=Query(rt,1,n<<1,min(j,Mx)+dep[a[i]]);
Mult(rt_g[a[i]],1,n<<1,j+2+n+1-dep[a[i]],j+2+n+1-dep[a[i]],tmp);
}
now=-1;Mx=max(Mx,mx[a[i]]);
fo(j,0,mx[a[i]]) {
int tmp=Query(rt_f[a[i]],1,n<<1,j+dep[a[i]])+1;
Mult(rt,1,n<<1,j+1+dep[x],j+1+dep[x],tmp);
}
int tmp=Query(rt_f[a[i]],1,n<<1,mx[a[i]]+dep[a[i]])+1;
Mult(rt,1,n<<1,mx[a[i]]+1+dep[x]+1,mx[x]+dep[x],tmp);
}
}
}
void dfs_g(int x,int y) {
now=x;Add(rt_g[x],1,n<<1,n+1-dep[x],L+n+1-dep[x],1);
a[0]=0;rep(i,x) if (t[i]!=y) rt_g[t[i]]=rt_g[x],a[++a[0]]=t[i];
solve(x);reverse(a+1,a+a[0]+1);solve(x);
rep(i,x) if (t[i]!=y) dfs_g(t[i],x);
}
int deg[N];
int main() {
freopen("hope.in","r",stdin);
freopen("hope.out","w",stdout);
n=read();L=read();k=read();L=min(L,n-1);
fo(i,1,n-1) {
int x=read(),y=read();
link(x,y);link(y,x);
deg[x]++;deg[y]++;
}
if (!L) {printf("%d\n",n);return 0;}
bool is_chain=1;
fo(i,1,n) if (deg[i]>2) is_chain=0;
if (is_chain) {
int ans=0;
fo(i,1,n) {
int tmp=(ll)min(i,L+1)*min(n-i+1,L+1)%Mo;
(ans+=pwr(tmp,k))%=Mo;
}
fo(i,2,n) {
int t1=min(n-i+1,L)%Mo,t2=min(i,L+1)%Mo;
(ans-=pwr((ll)t1*(t2-1)%Mo,k))%=Mo;
}
printf("%d\n",(ans+Mo)%Mo);
return 0;
}
ml[0]=1;
dfs(1,0);dfs_f(1,0);dfs_g(1,0);
int ans=0;
fo(i,1,n) {
int tmp=Query(rt_f[i],1,n<<1,dep[i]+min(L,mx[i]));
tmp=(ll)tmp*Query(rt_g[i],1,n<<1,L+n+1-dep[i])%Mo;
(ans+=pwr(tmp,k))%=Mo;
}
fo(i,2,n) {
int tmp=Query(rt_f[i],1,n<<1,dep[i]+min(L-1,mx[i]));
tmp=(ll)tmp*(Query(rt_g[i],1,n<<1,L+1+n-dep[i])-1)%Mo;
(ans-=pwr(tmp,k))%=Mo;
}
printf("%d\n",(ans+Mo)%Mo);
return 0;
}