http://www.elijahqi.win/archives/3009
题意:给定n个环再给定k个人 求把k个人放到n个环上每个环至少一个人的方案数
直接暴力dp 65分 那么考虑正解
设A(x) 为当这个环大小为size时的生成函数 那么x^i前面的系数
A(x)=∑i=1size(sizei)
A
(
x
)
=
∑
i
=
1
s
i
z
e
(
s
i
z
e
i
)
然后直接合并即可 但是暴力合并复杂度是不对的 注意到合并大小为sa,sb循环的时候复杂度是(sa+sb)log(n)那么这个和哈夫曼编码的过程其实相似我们使用哈夫曼编码的顺序合并每次将最小的两个size合并一下 然后因为多项式次数不会超过n所以复杂度是
n∗log(n)2
n
∗
l
o
g
(
n
)
2
#include<queue>
#include<cmath>
#include<cstdio>
#include<cctype>
#include<cstring>
#include<algorithm>
#define pa pair<int,int>
using namespace std;
const int mod=998244353;
const int N=200000;
const int gg=3;
#define ll long long
inline char gc(){
static char now[1<<16],*S,*T;
if(T==S){T=(S=now)+fread(now,1,1<<16,stdin);if (T==S) return EOF;}
return *S++;
}
inline int read(){
int x=0,f=1;char ch=gc();
while(!isdigit(ch)) {if (ch=='-') f=-1;ch=gc();}
while(isdigit(ch)) x=x*10+ch-'0',ch=gc();
return x*f;
}
inline int ksm(int b,int t){static ll tmp;
tmp=1;for (;t;b=(ll)b*b%mod,t>>=1) if (t&1) tmp=tmp*b%mod;return tmp;
}
int R[N<<2],g[N],inv[N],fa[N],pend[N],k,m;bool flag[N];
inline int calc(int n,int m){
return (ll)g[n]*inv[m]%mod*inv[n-m]%mod;
}
inline void ntt(vector<int> &x,int f,const int &n){
for (int i=0;i<n;++i) if (i<R[i]) swap(x[i],x[R[i]]);
for (int i=1;i<n;i<<=1){
ll wn=ksm(gg,f==1?(mod-1)/(i<<1):mod-1-(mod-1)/(i<<1));
for (int j=0;j<n;j+=i<<1){ll w=1;
for (int k=0;k<i;++k,w*=wn,w%=mod){
int t1=x[j+k],t2=w*x[i+j+k]%mod;
x[j+k]=t1+t2>=mod?t1+t2-mod:t1+t2;
x[i+j+k]=t1-t2<0?t1-t2+mod:t1-t2;
}
}
}
}
vector<int> p[N];
inline void init(int n,int l){
for (int i=0;i<n;++i) R[i]=(R[i>>1]>>1)|(i&1)<<l-1;
}
int n;
priority_queue<pa,vector<pa>,greater<pa> >q;
int main(){
//freopen("a.in","r",stdin);
int T=read();g[0]=1;
for (int i=1;i<=180000;++i) g[i]=(ll)g[i-1]*i%mod;
inv[180000]=ksm(g[180000],mod-2);
for (int i=180000-1;~i;--i) inv[i]=(ll)inv[i+1]*(i+1)%mod;
while(T--){static int top,nn;
nn=n=read();k=read();memset(flag,0,sizeof(flag));
for (int i=1;i<=n;++i) fa[i]=read();top=0;
for (int i=1;i<=n;++i) {static int cnt,x;
if (flag[i]) continue;cnt=0;x=i;
while(!flag[x]) ++cnt,flag[x]=1,x=fa[x];
pend[++top]=cnt;
}
for (int i=1;i<=top;++i) {
static int size;size=pend[i];p[i].reserve(size+1);p[i].push_back(0);
for (int j=1;j<=size;++j) p[i].push_back(calc(size,j));
q.push(make_pair(size,i));
}static int id1,id2;
while(q.size()>=2){static int invv;
id1=q.top().second;m=q.top().first;q.pop();
id2=q.top().second;m+=q.top().first;q.pop();
for (n=1;n<=m;n<<=1);init(n,log2(n));
p[id1].resize(n<<1);p[id2].resize(n<<1);
ntt(p[id1],1,n);ntt(p[id2],1,n);invv=ksm(n,mod-2);
for (int i=0;i<n;++i) p[id1][i]=(ll)p[id1][i]*p[id2][i]%mod;
ntt(p[id1],-1,n);p[id2].clear();
for (int i=0;i<n;++i) p[id1][i]=(ll)invv*p[id1][i]%mod;
while(p[id1][m+1])++m;p[id1].resize(m+1);q.push(make_pair(m,id1));
}printf("%d\n",(ll)p[q.top().second][k]*ksm(calc(nn,k),mod-2)%mod);p[q.top().second].clear();q.pop();
}
return 0;
}