D - One to One
思路
我们先来考虑先不将 a [ i ] = − 1 a[i] = -1 a[i]=−1 的边连出去。
则此时我们的图会变成多个连通块,每个连通块是基环树或环或树。
其中如果一个连通块是树,则其中一定有一个节点的 a a a 值为 − 1 -1 −1 。
如果一个连通块是基环树或环,则其中没有节点的 a a a 值为 − 1 -1 −1 。
此时我们可以发现,对于所有的基环树和环他们的贡献都可以算出来,因为如果目前我们将一个 a a a 值为 − 1 -1 −1 的点连向基环树或环,则这个点所在的连通块的贡献则不会算进去,也并不会对基环树和环的贡献产生影响。
并且此时基环树和环已经确定了,并不会再连出去,所以我们可以先将所有基环树和环的贡献算出来。
接下来我们考虑树连通块的贡献。
我们知道如果一个树连通块连向基环树和环,则它们的贡献就可以不用计算。
如果不连向基环树和环,则最后这些树连通块一定会变成一个环。
于是我们可以把所有的树连通块看成一个点,并且点的权值为连通块的大小,然后我们考虑 D P DP DP 。
我们设 g [ i ] g[i] g[i] 表示形成长度为 i i i 的环的方案数。
则我们可以对于每个点 k k k 来更新 g g g 值。
则方程式为 g [ i ] = g [ i ] + g [ i − 1 ] ∗ s i z e [ k ] g[i] = g[i] + g[i - 1] * size[k] g[i]=g[i]+g[i−1]∗size[k] 。
我们求出来 g g g 后考虑如何计算答案。
我们可以对每个环算贡献。
我们知道一个长度为 i i i 的环的贡献为 g [ i ] ∗ n c n t − i ∗ ( n − 1 ) ! g[i] * n^{cnt - i} * (n-1)! g[i]∗ncnt−i∗(n−1)! ,其中 c n t cnt cnt 是树连通块的个数。
我们来看看为什么。
首先除了这 i i i 个点,其他我们可以随便连,它们的贡献后面再算,于是就有 n c n t − i n^{cnt - i} ncnt−i 种可能出现长度为 i i i 的环。
因为我们算贡献并没有考虑顺序的问题,于是我们考虑顺序,此时所有点的全排列是 n ! n! n! 的,又因为它可以旋转,所以要除上 n n n ,也就变为 ( n − 1 ) ! (n - 1)! (n−1)! 。
到此问题就结束了。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const LL MOD = 998244353;
int n, m, a[2005], fa[2005], flag[2005], num;
LL g[2005], size[2005], ans, mul[2005], sum[2005];
LL Pow(LL a, LL b) {
LL s = 1;
while (b) {
if (b & 1)
s = s * a % MOD;
a = a * a % MOD, b = b >> 1;
}
return s;
}
int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }
int main() {
scanf("%d", &n);
mul[0] = 1;
for (int i = 1; i <= n; i++)
scanf("%d", &a[i]), fa[i] = i, sum[i] = 1, mul[i] = mul[i - 1] * i % MOD;
//mul[i]是i的阶乘。
for (int i = 1; i <= n; i++)
if (a[i] != -1) {
int fx = find(i), fy = find(a[i]);
if (fx != fy)
fa[fy] = fx, sum[fx] += sum[fy], sum[fy] = 0;
//用并查集维护连通块。
}
for (int i = 1; i <= n; i++)
if (a[i] == -1)
flag[find(i)] = 1, size[++m] = sum[find(i)];
//如果这个连通块里有点的a值为-1,则这个连通块是一个树连通块。
for (int i = 1; i <= n; i++)
if (!flag[i] && fa[i] == i)
num++;
//num是树连通块的数量
ans = num * Pow(n, m) % MOD;
//先把每个环或基环树的贡献计算上,其中pow(n, m)是所有可能的方案数。
g[0] = 1;
for (int i = 1; i <= m; i++)
for (int j = i; j; j--)
g[j] = (g[j] + g[j - 1] * size[i] % MOD) % MOD;
for (int i = 1; i <= m; i++)
ans = (ans + g[i] * mul[i - 1] % MOD * Pow(n, m - i) % MOD) % MOD;
printf("%lld", ans);
return 0;
}