题面
题意
给出一棵树,每次随机选择一个叶子节点(可以重复选),将其染黑,问树上不经过黑点的最长链变短的期望染色次数是多少。
做法
这题的主要思路是用总代价除以总方案数。
树的直径有一个或两个必经点,当直径长度为奇数时,直径的中点即为必经点,可以将有希望成为直径上的叶子节点根据它属于必经点的哪棵子树进行分类,得到多个叶子集合,若直径长度为偶数,则有一条必经边,根据边将有可能成为直径上的叶子结点分成两个集合。
然后就可以发现直径变短时,有且只有一个集合中含有可能成为直径上的叶子结点没有被染黑,然后可以就可以枚举哪一个集合没有被染黑来计算总代价数。
当
m
m
m个叶子节点中有
x
x
x个集合中的点已经被染黑时,再染黑一个集合中的叶子结点的期望步数为
m
s
u
m
−
x
\frac{m}{sum-x}
sum−xm,可以据此来计算某个方案的代价。
令这棵树上共有
m
m
m个叶子节点,集合中的叶子数量总和为
s
u
m
sum
sum,此时考虑的集合大小为
u
u
u,染色结束后,这个集合中有
i
(
i
<
u
)
i(i<u)
i(i<u)个叶子结点已被染黑。
则这中情况的总方案数为
(
u
i
)
{u \choose i}
(iu)(表示从
u
u
u个点中选出
i
i
i个)
∗
(
s
u
m
−
u
+
i
−
1
)
!
∗
(
s
u
m
−
u
)
*(sum-u+i-1)!*(sum-u)
∗(sum−u+i−1)!∗(sum−u)(表示这
s
u
m
−
u
+
i
sum-u+i
sum−u+i个点被染黑的顺序种数,注意最后一个染黑的点不能在这个集合中)
∗
(
u
−
i
)
!
*(u-i)!
∗(u−i)!(表示把剩余点染黑的顺序种数,因为最后除的总方案是
s
u
m
!
sum!
sum!,所以可以看作完成任务后继续把所有点染黑)
花费的代价为
∑
i
=
u
−
i
+
1
s
u
m
m
i
\sum_{i=u-i+1}^{sum}\frac{m}{i}
∑i=u−i+1sumim,根据代价计算方式,不难理解,然后把两者相乘,也就是:
inline ll calc(ll u)
{
ll i,j,t,res=0;
for(i=0;i<u;i++)
{
t=C(u,i)%M*(sum-u)%M*jc[sum-u+i-1]%M*jc[u-i]%M;
t=t*need[sum-u+i]%M;
res+=t;
res%=M;
}
return res;
}
最后再除以总方案数 s u m ! sum! sum!即可。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
#define N 500100
#define M 998244353
using namespace std;
ll n,m,bb,mx,mm,ans,sum,first[N],fa[N],cnt[N],jc[N],ds[N],need[N];
bool gg[N];
struct Bn
{
ll to,next;
}bn[N<<1];
inline ll po(ll u,ll v)
{
ll res=1;
for(;v;)
{
if(v&1) res=res*u%M;
u=u*u%M;
v>>=1;
}
return res;
}
inline void add(ll u,ll v)
{
bb++;
bn[bb].to=v;
bn[bb].next=first[u];
first[u]=bb;
}
void dfs(ll now,ll last,ll dep)
{
ll p,q;
if(dep>mx) mx=dep,mm=now;
for(p=first[now];p!=-1;p=bn[p].next)
{
q=bn[p].to;
if(q==last) continue;
fa[q]=now;
dfs(q,now,dep+1);
}
}
ll Dfs(ll now,ll last,ll ned)
{
if(ned==1) return 1;
ll p,q,res=0;
for(p=first[now];p!=-1;p=bn[p].next)
{
q=bn[p].to;
if(q==last) continue;
res+=Dfs(q,now,ned-1);
}
return res;
}
inline ll C(ll u,ll v)
{
return jc[u]*po(jc[v],M-2)%M*po(jc[u-v],M-2)%M;
}
inline ll calc(ll u)
{
ll i,j,t,res=0;
for(i=0;i<u;i++)
{
t=C(u,i)%M*(sum-u)%M*jc[sum-u+i-1]%M*jc[u-i]%M;
t=t*need[sum-u+i]%M;
res+=t;
res%=M;
}
return res;
}
int main()
{
memset(first,-1,sizeof(first));
ll i,j,p,q,t;
cin>>n;
jc[0]=1;
for(i=1;i<=n;i++) jc[i]=jc[i-1]*i%M;
for(i=1;i<n;i++)
{
scanf("%lld%lld",&p,&q);
add(p,q),add(q,p);
ds[p]++,ds[q]++;
}
for(i=1;i<=n;i++) if(ds[i]==1) m++;
dfs(1,-1,0);
mx=0;
dfs(mm,-1,0);
for(i=mm,j=1;j<=mx/2;j++,i=fa[i]);
if(mx&1)
{
p=i,q=fa[i];
t=Dfs(q,p,mx/2+1);
sum+=t;
cnt[t]++;
t=Dfs(p,q,mx/2+1);
sum+=t;
cnt[t]++;
}
else
{
p=i;
t=Dfs(fa[p],p,mx/2);
sum+=t;
cnt[t]++;
for(i=first[p];i!=-1;i=bn[i].next)
{
if(bn[i].to==fa[p]) continue;
t=Dfs(bn[i].to,p,mx/2);
sum+=t;
cnt[t]++;
}
}
for(i=1;i<=n;i++) need[i]=need[i-1]+m*po(sum-i+1,M-2)%M,need[i]%=M;
for(i=1;i<=n;i++)
{
if(!cnt[i]) continue;
ans+=calc(i)*cnt[i]%M;
ans%=M;
}
cout<<ans*po(jc[sum],M-2)%M;
}