HDU7191 Count Set 解题报告
题目大意
n n n 个数构成排列 p i ( 1 ≤ i ≤ n ) p_i(1\le i\le n) pi(1≤i≤n),计算基数为 k k k 的集合 T T T 的个数,其中 T T T 需满足 T ∩ P ( T ) = ∅ T\cap P(T)=\varnothing T∩P(T)=∅,其中 P ( T ) = { p i ∣ i ∈ T } P(T)=\{p_i|i\in T\} P(T)={pi∣i∈T}.
其中
1 ≤ n ≤ 5 × 1 0 5 , ∑ n ≤ 5 × 1 0 6 , 0 ≤ k ≤ n 1\le n\le 5\times 10^5,\sum n\le 5\times 10^6,0\le k\le n 1≤n≤5×105,∑n≤5×106,0≤k≤n
解题思路
首先,这样的 i → p i i\to p_i i→pi 的关系建成一个图,图中必然出现若干个大大小小的环,记它们的大小分别为 c 1 , c 2 , … , c m c_1,c_2,\dots, c_m c1,c2,…,cm。我们要做的事情就是在这些环中选取恰 k k k 个点。
容易发现,不同的环之间是独立的。在一个环中,要满足 T ∩ P ( T ) = ∅ T\cap P(T)=\varnothing T∩P(T)=∅,也就是一个环中选取的点不可相邻。这个问题只跟 c c c(环的大小)有关。
一个大小为 n n n 的环上选取 m m m 个点且不相邻,方案数如何计算呢?
这里有一个博客写得很好:不相邻问题 by hht2005。
先考虑链上的情况。 我们先把这不相邻的 k k k 个点放下来,再在最前、每两个球之间、最尾插入一些球,依次为 x 0 , x 1 , … x k x_0,x_1,\dots x_k x0,x1,…xk 共 k k k 个数,且:
x 0 + x 1 + ⋯ + x k = n − k x 0 ≥ 0 , x k ≥ 0 , x i ≥ 1 ( 1 ≤ i ≤ k − 1 ) x_0+x_1+\dots+x_k=n-k \\ x_0\ge 0, x_k\ge 0, x_i\ge 1(1\le i\le k-1) x0+x1+⋯+xk=n−kx0≥0,xk≥0,xi≥1(1≤i≤k−1)
变为
x 0 + x 1 + ⋯ + x k = n − 2 k + 1 x i ≥ 0 ( 0 ≤ i ≤ k ) x_0+x_1+\dots+x_k=n-2k+1 \\ x_i\ge 0(0\le i\le k) x0+x1+⋯+xk=n−2k+1xi≥0(0≤i≤k)
由插板法可知,方案数为 f ( n , k ) = ( n − 2 k + 1 + k + 1 − 1 k + 1 − 1 ) = ( n − k + 1 k ) f(n,k)=\binom{n-2k+1+k+1-1}{k+1-1}=\binom{n-k+1}{k} f(n,k)=(k+1−1n−2k+1+k+1−1)=(kn−k+1)
然后考虑环上的情况。 对于环上一点,若它选,则在剩下 n − 3 n-3 n−3 个点中选 k − 1 k-1 k−1 个,方案数为 f ( n − 3 , m − 1 ) = ( n − k − 1 k − 1 ) f(n-3,m-1)=\binom{n-k-1}{k-1} f(n−3,m−1)=(k−1n−k−1);若它不选,则在剩下 n − 1 n-1 n−1 个点中选择 k k k 个,方案数为 f ( n − 1 , k ) = ( n − k k ) f(n-1,k)=\binom{n-k}{k} f(n−1,k)=(kn−k)。总方案数为 g ( n , k ) = ( n − k − 1 k − 1 ) + ( n − k k ) g(n,k)=\binom{n-k-1}{k-1}+\binom{n-k}{k} g(n,k)=(k−1n−k−1)+(kn−k)。
于是,我们可以算出,大小为 c c c 的环上,选取 i ( 0 ≤ i ≤ c ) i(0\le i\le c) i(0≤i≤c) 个点时方案数为多少。
现在我们要合并答案,这 m m m 个环共取 k k k 个点方案数为多少?容易想到 OGF。
每个环的 OGF 就为
f i ( x ) = ∑ i = 0 c i g ( c i , i ) x i f_i(x)=\sum_{i=0}^{c_i}g(c_i,i)x^i fi(x)=i=0∑cig(ci,i)xi
最终答案就是
[ x k ] ∏ i = 1 m f i ( x ) [x^k]\prod_{i=1}^mf_i(x) [xk]i=1∏mfi(x)
如何将这 m m m 个多项式乘起来呢?容易发现,这些多项式总长度为 c 1 + c 2 + ⋯ + c m = n c_1+c_2+\dots+c_m=n c1+c2+⋯+cm=n,每次合并的代价为 O ( l e n log l e n ) ≤ O ( l e n log n ) O(len\log len)\le O(len\log n) O(lenloglen)≤O(lenlogn)(其中 l e n = c i + c j len=c_i+c_j len=ci+cj),故按照哈夫曼树合并即可,总时间复杂度 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)。
当然,用线段树一样的分治结构进行合并也是可以的,时间复杂度 O ( n log 2 n ) O(n\log^2n) O(nlog2n)。
/*
* @Author: rijuyuezhu
* @Date: 2022-08-02 14:42:15
* @Description: http://acm.hdu.edu.cn/contest/problem?cid=1048&pid=1007
* @Tag: 多项式,图论,生成函数
*/
#include<cstring>
#include<vector>
#include<queue>
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
char In[1 << 20], *ss = In, *tt = In;
#define getchar() (ss == tt && (tt = (ss = In) + fread(In, 1, 1 << 20, stdin), tt == ss) ? EOF : *ss++)
ll read() {
ll x = 0, f = 1; char ch = getchar();
for(; ch < '0' || ch > '9'; ch = getchar()) if(ch == '-') f = -1;
for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + int(ch - '0');
return x * f;
}
const int MAXN = 5e5 + 5;
const int P = 998244353;
namespace MINT {
struct mint {
int v;
mint(int v = 0) : v(v) {}
};
int MOD(int v) {return v >= P ? v - P : v;}
mint operator + (mint a, mint b) {return MOD(a.v + b.v);}
mint operator - (mint a, mint b) {return MOD(a.v - b.v + P);}
mint operator * (mint a, mint b) {return 1ll * a.v * b.v % P;}
mint qpow(mint a, int n=P-2) {mint ret = 1; for(; n; n >>= 1, a = a * a) if(n & 1) ret = ret * a; return ret;}
mint operator += (mint& a, mint b) {return a = a + b;}
mint operator -= (mint& a, mint b) {return a = a - b;}
mint operator *= (mint& a, mint b) {return a = a * b;}
} using namespace MINT;
namespace Poly {
const int MAXL = (1 << 19) + 5, Bas = 1 << 19;
typedef vector<mint> poly;
mint inv[MAXL], fac[MAXL], ifac[MAXL], _g[MAXL];
int tr[MAXL];
void init() {
inv[1] = 1; for(int i = 2; i < MAXL; i++) inv[i] = (P - P / i) * inv[P % i];
fac[0] = ifac[0] = 1;
for(int i = 1; i < MAXL; i++) fac[i] = fac[i-1] * i, ifac[i] = ifac[i-1] * inv[i];
_g[0] = 1; _g[1] = qpow(3, (P-1) / Bas);
for(int i = 2; i < Bas; i++) _g[i] = _g[i-1] * _g[1];
}
int glim(int n) {
int lim = 1; for(; lim < n; lim <<= 1);
return lim;
}
void DFT(poly& f, int lim) {
if((int)f.size() < lim) f.resize(lim);
for(int i = 0; i < lim; i++) tr[i] = (tr[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
for(int i = 0; i < lim; i++) if(i < tr[i]) swap(f[i], f[tr[i]]);
for(int l = 2, k = 1; l <= lim; l <<= 1, k <<= 1)
for(int i = 0; i < lim; i += l)
for(int j = i; j < i+k; j++) {
mint tt = f[j+k] * _g[Bas / l * (j-i)];
f[j+k] = f[j] - tt;
f[j] = f[j] + tt;
}
}
void IDFT(poly& f, int lim) {
DFT(f, lim); reverse(f.begin()+1, f.begin()+lim);
for(int i = 1; i < lim; i++) f[i] *= inv[lim];
}
poly Mul(poly f, poly g) {
int n = f.size() + g.size() - 1, lim = glim(n);
DFT(f, lim), DFT(g, lim);
for(int i = 0; i < lim; i++) f[i] *= g[i];
IDFT(f, lim); f.resize(n); return f;
}
} using namespace Poly;
typedef pair<int, int> pr;
int upto[MAXN], dist[MAXN], n, k, p[MAXN], cnt[MAXN], _cnt;
poly F[MAXN], f, g;
int getup(int x) {
if(upto[x] == x) return x;
int f = getup(upto[x]);
dist[x] += dist[upto[x]];
return upto[x] = f;
}
mint C(int n, int m) {
if(n < 0 || m < 0 || n < m) return 0;
return fac[n] * ifac[m] * ifac[n-m];
}
void work() {
n = read(), k = read();
_cnt = 0;
for(int i = 1; i <= n; i++) p[i] = read(), upto[i] = i, dist[i] = 0;
for(int i = 1; i <= n; i++) {
int fx = getup(i), fy = getup(p[i]);
if(fx == fy) {
cnt[++_cnt] = dist[i] + dist[p[i]] + 1;
continue;
}
upto[fx] = fy;
dist[fx] = dist[p[i]] + 1;
}
priority_queue<pr, vector<pr>, greater<pr> > pq;
for(int i = 1; i <= _cnt; i++) {
int n = cnt[i];
F[i].resize(n+1);
for(int j = 0; j <= n; j++)
F[i][j] = C(n-j, j) + C(n-j-1,j-1);
pq.push(pr(F[i].size(), i));
}
while(pq.size() >= 2) {
int i = pq.top().second; pq.pop();
int j = pq.top().second; pq.pop();
F[i] = Mul(F[i], F[j]);
pq.push(pr(F[i].size(), i));
}
int t = pq.top().second;
printf("%d\n", F[t][k].v);
}
int main() {
init();
int T = read();
while(T--) work();
return 0;
}