2022杭电多校第五场1007(生成函数+启发式合并+ntt)

题目大意:

给定n,k以及一个长度为n的排列,让你选出一个长度为k的子集,假设子集为T,使得满足

|P(T)∩T|=0.P(T)={y|y=px,x∈T}.求方案数,答案模998244353

思路:

将问题转化,给定m个环,从m个环中选出k个元素的方案数,并且从环中选取的元素不相邻,

首先考虑从环中选取元素不相邻,设大小为n的环选取k个,

把n个数看成n个球,编号从1~n,对于编号为1的球,要么选要么不选

若选取编号为1的球,则编号为2和n的球必定不能选,则从剩下n-3个球中选取k-1个球,并且此时n-3个球不是环,将其中k-1个球看作隔板,隔开剩下n-3-(k-1)个球,方案数为C(n-k-1,k-1)

若不选取编号为1的球,同上则从n-1个球选取k个球,方案数为C(n-k,k)

总方案数为C(n-k-1,k-1)+C(n-k,k)

接下来解决从m个环选取k个的问题,首先要了解生成函数,对于一个序列,如果序列a有通项公式,那么它的普通生成函数的系数就是通项公式。

那么对于其中一个大小为n的环,它的生成函数就是

对于从m个环中选取k个,答案就是把他们的生成函数乘起来之后第k项的系数。

其中将函数乘起来的过程我用使用 ntt 来加速

我们选择用启发式合并来得到答案,具体步骤是从队列(队列中放的是环编号)选取两个长度较小的环,使它们的生成函数相乘(合并为一个),并把得到的新生成函数放回队列(放回的编号为选取的任意一个编号),循环此过程直至队列中的元素只剩下一个,输出这个元素(编号)对应的生成函数的第k项即为答案

代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define T() int tt;cin>>tt;while(tt--)
#define endl "\n"
const int mod=998244353;
const int N=2e6+10;
const int G=3,invG=332748118;//原根以及原根的逆元
int rev[N];//换位数组
/**************************************************************************************/
ll qpow(ll a,ll b)
{
    ll ans=1;
    while(b)
    {
        if(b&1)ans=ans*a%mod;
        a=a*a%mod;b>>=1;
    }
    return ans;
}
void NTT(ll *a,int op,int len)//ntt模板
{
	for(int i=0;i<len;i++)
	{
		if(i<rev[i])swap(a[i],a[rev[i]]);
	}

	for(int s=1;s<len;s<<=1)
	{
		ll wn=qpow(op==1?G:invG,(mod-1)/(s<<1));
		for(int i=0;i<len;i+=(s<<1))
		{
			ll w=1;
			for(int j=i;j<i+s;j++,w=w*wn%mod)
			{
				ll x=a[j],y=w*a[j+s]%mod;
				a[j]=(x+y)%mod;
				a[j+s]=((x-y)%mod+mod)%mod;
			}
		}
	}
	if(op==-1)
	{
		ll inv=qpow(len,mod-2);
		for(int i=0;i<len;i++)
		a[i]=a[i]*inv%mod;
	}
}
ll a[N],b[N];
ll jc[N],inv[N];
void init(int n)//预处理阶乘以及阶乘的逆元
{
    jc[0]=1;
    inv[0]=1;
    for(int i=1;i<=n;i++)jc[i]=jc[i-1]*i%mod;
    inv[n]=qpow(jc[n],mod-2);
    for(int i=n;i>=1;i--)inv[i-1]=inv[i]*i%mod;
}
vector< vector<ll> >vc;//储存生成函数的每一项
bool vis[N];
int pos[N];
ll C(int n,int m)
{
    if(n<0||n<m||m<0)return 0;
    return jc[n]*inv[m]%mod*inv[n-m]%mod;
}
struct node
{
    int pos;
    //写重载放入优先队列,使得编号对应的生成函数的数量从小到大排列
    bool operator <(const node &a)const
    {
        return vc[pos].size()>vc[a.pos].size();//由于是小根堆写大于号使得增序排列
    }
};
priority_queue<node>q;
int main()
{
    init(500005);//初始化
    vc=vector<vector<ll>>(500005);
    T()
    {
        while(q.size())q.pop();//优先队列清空防止影响答案
        int n,k;
        scanf("%d%d",&n,&k);
        memset(vis,0,sizeof(bool)*(n+1));//清空bool数组
        for(int i=1;i<=n;i++)scanf("%d",&pos[i]);

        int now=0;
        int tot=0;
        for(int i=1;i<=n;i++)
        {
            if(!vis[i])//遍历环,得到环的大小
            {
                int res=i;
                int sum=0;
                while(!vis[res])
                {
                    sum++;
                    vis[res]=1;
                    res=pos[res];
                }
                if(sum==1)continue;//若环的大小为1则不会影响答案直接跳过
                now++;
                tot+=sum/2;//记录能选的最多个数
                vc[now].clear();
                //存储生成函数的每一项
                for(int j=0;j<=sum/2;j++)vc[now].push_back((C(sum-j,j)+C(sum-j-1,j-1))%mod);
                q.push({now});//优先队列放入编号
            }
        }
        if(tot<k)
        {
            cout<<0<<endl;//如果能选取的最多个数仍然不足k,直接输出0
            for(int i=1;i<=n;i++)vc[i].clear();
            continue;
        }
        while(q.size()>1)//启发式合并,一直到队列元素只剩一个
        {
            auto t1=q.top();
            q.pop();
            auto t2=q.top();
            q.pop();
            int l1=vc[t1.pos].size()-1;
            int l2=vc[t2.pos].size()-1;
            int l=1;
            while(l<min(l1+l2,k)+1)l<<=1;
            for(int i=0;i<l;i++)
            {
                rev[i]=rev[i>>1]>>1|((i&1)?(l>>1):0);//ntt要用到的换位数组初始化
            }
            memset(a,0,sizeof (ll)*(l+1));//清空ntt要用的数组
            memset(b,0,sizeof (ll)*(l+1));
            //给ntt要用的数组赋值,没有上面的清空会影响答案
            for(int i=0;i<=min(l1,k);i++)a[i]=vc[t1.pos][i];
            for(int i=0;i<=min(l2,k);i++)b[i]=vc[t2.pos][i];
            NTT(a,1,l);NTT(b,1,l);
            for(int i=0;i<l;i++)a[i]=a[i]*b[i]%mod;
            //得到相乘后的数组
            NTT(a,-1,l);
            //将新的生成函数及编号放回优先队列和vector
            q.push(t1);
            vc[t1.pos].clear();
            for(int i=0;i<=min(l1+l2,k);i++)vc[t1.pos].push_back(a[i]);
        }
        //判定一下元素个数是否有k+1个(因为下标从0开始,访问第k个相当于vector中的第k+1个元素)
        if(vc[q.top().pos].size()>=k+1)cout<<vc[q.top().pos][k]<<endl;
        else cout<<0<<endl;
        vc[q.top().pos].clear();
        q.pop();
    }
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值