正题
可怜的出题人(跟九条可怜没有关系)要给 n n n 个地方出题。
但是出题人太累了,他决定把以前给这些地方出过的题重新搬一搬。
这 n n n 个地方以 1 , … , n 1,\dots,n 1,…,n 编号。出题人总结出了他们之间的联系,是一个树形。如果出题人把以前给第 i i i 个地方出的题搬到第 j j j 个地方,那么他希望 i i i 和 j j j 在这棵树上的距离越长越好。
于是出题人需要决定一个排列
p
1
,
…
,
p
n
p_1,\dots,p_n
p1,…,pn ,要求最大化
S
=
∑
i
=
1
n
d
i
s
(
i
,
p
i
)
S=\sum_{i=1}^n \mathrm{dis}(i,p_i)
S=∑i=1ndis(i,pi)
这个
S
S
S 越大,出题人可能受到的损失(被指出是原题于是被上某乎或扣工资等)就越小。
有时候出题人还想知道在最大化
S
S
S 的前提下自己有多少种搬题的方案。即有多少排列
p
1
,
…
,
p
n
p_1,\dots,p_n
p1,…,pn,使得
S
S
S 最大。
由于这个数可能非常大,你只需要输出其对 998244353998244353 取模的结果。
首先考虑如何最大化
S
S
S。
我们只需要找到重心作为根,求出
2
∗
∑
i
=
1
n
d
e
p
i
2*\sum_{i=1}^n dep_i
2∗∑i=1ndepi即为答案,不会有更大的答案,因为可以对于每一条边考虑,贡献的权值最多为
2
∗
min
(
s
z
[
x
]
,
n
−
s
z
[
x
]
)
2*\min (sz[x],n-sz[x])
2∗min(sz[x],n−sz[x]),其中该条边所连接的远根点为
x
x
x,在上面的方案中达到了最大值。
方案数显然就是要在重心为根时,将重心的每一个儿子的子树点放到其他的儿子子树中。
显然可以容斥,记
S
S
S为选择放在当前子树的集合,
a
i
a_i
ai表示第
i
i
i棵子树有多少个选择放在当前子树,
s
i
s_i
si表示第
i
i
i棵子树有多少个点。
则有答案式子
∑
S
(
−
1
)
∣
S
∣
(
n
−
∣
S
∣
)
!
∏
i
=
1
m
C
s
i
a
i
P
s
i
a
i
\sum_{S} (-1)^{|S|} (n-|S|)!\prod_{i=1}^m C_{s_i}^{a_i}P_{s_i}^{a_i}
S∑(−1)∣S∣(n−∣S∣)!i=1∏mCsiaiPsiai
后面使用Dp或者NTT来优化即可。
#include<bits/stdc++.h>
using namespace std;
const int N=200010,mod=998244353;
struct edge{
int y,nex;
}s[N<<1];
int first[N],len=0,n,t,sz[N],mmin=1e9,pos=0,mmax[N],fac[N],inv[N];
int f[N];
void ins(int x,int y){
s[++len]=(edge){y,first[x]};first[x]=len;
}
int ksm(int x,int t){
int tot=1;
while(t){
if(t&1) tot=1ll*tot*x%mod;
x=1ll*x*x%mod;
t/=2;
}
return tot;
}
void dfs(int x,int fa){
sz[x]=1;
for(int i=first[x];i!=0;i=s[i].nex) if(s[i].y!=fa){
dfs(s[i].y,x);
sz[x]+=sz[s[i].y];
mmax[x]=max(mmax[x],sz[s[i].y]);
}
if(max(mmax[x],n-sz[x])<mmin) mmin=max(mmax[x],n-sz[x]),pos=x;
}
long long tt=0;
void dfs_2(int x,int fa,int dep){
sz[x]=1;tt+=dep;
for(int i=first[x];i!=0;i=s[i].nex) if(s[i].y!=fa){
dfs_2(s[i].y,x,dep+1);
sz[x]+=sz[s[i].y];
}
}
int down(int x,int y){
return 1ll*fac[x]*inv[x-y]%mod;
}
int main(){
scanf("%d %d",&n,&t);
int x,y;
for(int i=1;i<n;i++) scanf("%d %d",&x,&y),ins(x,y),ins(y,x);
dfs(1,0);dfs_2(pos,0,0);printf("%lld\n",tt*2);if(t==1) return 0;
fac[0]=1;for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%mod;
inv[n]=ksm(fac[n],mod-2);for(int i=n-1;i>=0;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod;
int tot=0;f[0]=1;
for(int i=first[pos];i!=0;i=s[i].nex){
int m=sz[s[i].y];
for(int j=tot;j>=0;j--)
for(int k=m;k>=1;k--)
f[j+k]=(f[j+k]+1ll*f[j]*down(m,k)%mod*down(m,k)%mod*inv[k])%mod;
tot+=m;
}
int ans=0,t=1;
for(int j=0;j<=tot;j++)
ans=(ans+1ll*t*fac[n-j]%mod*f[j])%mod,t=(t==1?mod-1:1);
printf("%d\n",ans);
}