题目大意
对于一个长度为 n n n的整数序列 X = ( x 1 , x 2 , … x n ) X=(x_1,x_2,\dots x_n) X=(x1,x2,…xn),每个元素都在 1 1 1到 n n n之间,令 f ( X ) f(X) f(X)表示以下问题的答案:
- 有一个 n n n个顶点 n n n条边的无向图(可能有重边和自环),第 i i i条边连接 i i i和 X i X_i Xi,求联通块的数量
给一个正整数 n n n和一个长度为 n n n的序列 A = ( a 1 , a 2 … a n ) A=(a_1,a_2\dots a_n) A=(a1,a2…an),其每一个元素都在 1 1 1到 n n n之间,或者为 − 1 -1 −1。
你可以将每个值为 − 1 -1 −1的 a i a_i ai变为任意一个 1 1 1到 n n n之间的数,求所有情况下 f ( A ) f(A) f(A)的和。输出答案对 998244353 998244353 998244353取模。
题解
令 k k k表示 a i = − 1 a_i=-1 ai=−1的元素的个数。
我们可以先将 a i ≠ − 1 a_i\neq -1 ai=−1的边连上,那么现在图上的每一个连通块都是树或环或基环树。
如果是树的话,则这个连通块有且只有一个 a i = − 1 a_i=-1 ai=−1的点
如果是环或基环树的话,则这个连通块没有 a i = − 1 a_i=-1 ai=−1的点
我们可以先把环和基环树的贡献算出来,每个环或基环树的贡献为 n k n^k nk,因为不管怎么连,环或基环树都会有 1 1 1的贡献。那么如果有树向环或基环树连边,则这棵树不计算贡献。
树与环或基环树连边的贡献不需计算,那么我们只需要求树与树连边的贡献了。
因为每棵树只有一条边连出去,所以我们可以将每棵树看成一个点。
如果不连向环和基环树,那么这些树一定会形成一个环。对于一个顺序已确定的环,形成这样的环的方案数为 ∏ s i z i \prod siz_i ∏sizi。
我们考虑DP。设 f i f_i fi表示形成长度为 i i i的环的方案数,那么对于每个点 j j j,有转移式
f i = f i + f i − 1 × s i z k f_i=f_i+f_{i-1}\times siz_k fi=fi+fi−1×sizk
求出 f f f后我们考虑如何计算答案。对于所有长度为 i i i的环的贡献为 f i × ( i − 1 ) ! × n k − i f_i\times (i-1)!\times n^{k-i} fi×(i−1)!×nk−i。其中 ( i − 1 ) ! (i-1)! (i−1)!表示 i i i个点按不同顺序可以构成 ( i − 1 ) ! (i-1)! (i−1)!个不同的环, n k − i n^{k-i} nk−i表示其他 n − k n-k n−k个点可以任意连边。
这样问题就解决了,时间复杂度为 O ( n 2 ) O(n^2) O(n2)。
code
#include<bits/stdc++.h>
using namespace std;
int n,tot=0,vt=0,a[2005],d[5005],l[5005],r[5005],s[2005],z[2005],siz[2005];
long long ans,f[2005],jc[2005],mi[2005];
long long mod=998244353;
void add(int xx,int yy){
l[++tot]=r[xx];d[tot]=yy;r[xx]=tot;
}
void dfs(int u){
z[u]=1;siz[u]=1;
for(int i=r[u];i;i=l[i]){
if(!z[d[i]]){
dfs(d[i]);siz[u]+=siz[d[i]];
}
}
}
int main()
{
scanf("%d",&n);
jc[0]=mi[0]=1;
for(int i=1;i<=n;i++){
jc[i]=jc[i-1]*i%mod;
mi[i]=mi[i-1]*n%mod;
}
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
if(a[i]==-1) continue;
add(i,a[i]);add(a[i],i);
}
for(int i=1;i<=n;i++){
if(a[i]==-1){
dfs(i);s[++vt]=siz[i];
}
}
for(int i=1;i<=n;i++){
if(!z[i]){
dfs(i);ans=(ans+mi[vt])%mod;
}
}
f[0]=1;
for(int i=1;i<=vt;i++){
for(int j=i;j>=1;j--) f[j]=(f[j]+f[j-1]*s[i]%mod)%mod;
}
for(int i=1;i<=vt;i++){
ans=(ans+f[i]*jc[i-1]%mod*mi[vt-i]%mod)%mod;
}
printf("%lld",ans);
return 0;
}