AtCoder Grand Contest 019E: Shuffle and Swap 题解

非常好的dp+组合题
这个版本的做法参考了tourist的editorial
我们不考虑两个序列的random shuffle,而是考虑这样的两个操作
1. 确定a序列和b序列的匹配方法
2. 确定这些匹配方法的出现顺序
我们考虑a序列和b序列匹配好以后,在A序列里面每个ai向bi连一条有向边
我们发现A序列的每个位置只有三种情况
1. 有某个a对应没有b对应,这样这个点只有出边
2. 由某个a对应也有某个b对应,这样这个点有入边也有出边
3. 没有a对应有某个b对应,这样这个点只有入边
设1类点有e个,2类点有m个,因为A和B中1的个数相等,所以3类点也有e个
考虑这个图本身,我们发现这个图一定是若干个环和若干个链组成的
我们发现组成环的这些位置所对应的A序列中的位置的数一定都是1,所以环当中边出现的顺序是任意的
我们发现链中匹配边出现的顺序有且只有一种,因为一条链中对应的A序列中的值,只有链尾是0,其他都是1,对应B序列中只有链头是0,其他都是1,所以边
一定要按照从后向前的顺序出现
环中的点都是2类点,链头和链尾都是1类和3类点,链中间的点是2类点
我们考虑怎样把2类点分到e条链和若干个环中
令dp[i][j]表示已经考虑到将j个点放入i条链的方案数,dp[i][j]=k=0jdp[i1][k](jk+1)!(分母上的阶乘的意义在后面解释)
最后把dp[e][0~m]的答案加起来,再乘上e!m!(e+m)!
e!指e条链的链头和链尾配对,有e!种配对方法
m!指m个2类点的连接顺序,比如说一条链的点确定了,但这条链的连法有阶乘种
我们还要考虑边的出现顺序,所有的出现顺序是(e+m)!但是每条链的出现顺序只有一种所以要除以若干个(u+1)!,这个在算dp的时候已经除过了
这样就有了一个O(n3)的做法
考虑优化
我们发现dp的转移方程是一个卷积的形式,k+(j-k+1)=j+1,所以每层的转移可以NTT优化,复杂度降到O(n2logn)
然后我们发现每次乘的多项式都是一样的,都是f(x)=i=0mxi(i+1)!,所以可以快速幂+NTT,复杂度O(nlog2n)

#include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <utility>
#include <cctype>
#include <algorithm>
#include <bitset>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <deque>
#include <stack>
#include <cmath>
#define LL long long
#define LB long double
#define x first
#define y second
#define Pair pair<int,int>
#define pb push_back
#define pf push_front
#define mp make_pair
#define LOWBIT(x) x & (-x)
using namespace std;

const int MOD=998244353;
const LL LINF=2e16;
const int INF=1e9;
const int magic=348;
const double eps=1e-10;
const double pi=3.14159265;
const int G=3;

inline int getint()
{
    char ch;int res;bool f;
    while (!isdigit(ch=getchar()) && ch!='-') {}
    if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
    while (isdigit(ch=getchar())) res=res*10+ch-'0';
    return f?res:-res;
}

int inv[200048];
LL finv[200048],fac[200048];
int n;
char s1[100048],s2[100048];
int e,m;

inline void init_inv()
{
    int i;
    fac[0]=fac[1]=inv[0]=inv[1]=finv[0]=finv[1]=1;
    for (i=2;i<=100000;i++)
    {
        fac[i]=(fac[i-1]*i)%MOD;
        inv[i]=MOD-((long long)(MOD/i)*inv[MOD%i])%MOD;
        finv[i]=(finv[i-1]*inv[i])%MOD;
    }
} 

inline LL quick_pow(LL x,LL y)
{
    x%=MOD;LL res=1;
    while (y)
    {
        if (y&1) res=(res*x)%MOD,y--;
        x=(x*x)%MOD;y>>=1;
    }
    return res;
}

int len;
LL wn_pos[100048],wn_neg[100048];
inline void init_wn()
{
    for (register int clen=2;clen<=len;clen<<=1)
    {
        wn_pos[clen]=quick_pow(G,(MOD-1)/clen);
        wn_neg[clen]=quick_pow(G,(MOD-1)-(MOD-1)/clen);
    }
}

LL a[100048],b[100048];

inline void NTT(LL c[],int fl)
{
    int i,j,k,clen;
    for (i=(len>>1),j=1;j<len;j++)
    {
        if (i<j) swap(c[i],c[j]);
        for (k=(len>>1);i&k;k>>=1) i^=k;
        i^=k;
    }
    for (clen=2;clen<=len;clen<<=1)
    {
        LL wn=(fl==1?wn_pos[clen]:wn_neg[clen]);
        for (j=0;j<len;j+=clen)
        {
            LL w=1;
            for (k=j;k<j+(clen>>1);k++)
            {
                LL tmp1=c[k],tmp2=(c[k+(clen>>1)]*w)%MOD;
                c[k]=(tmp1+tmp2)%MOD;c[k+(clen>>1)]=((tmp1-tmp2)%MOD+MOD)%MOD;
                w=(w*wn)%MOD;
            }
        }
    }
    if (fl==-1)
        for (i=0;i<len;i++) c[i]=(c[i]*inv[len])%MOD;
}

inline void calc_NTT()
{
    NTT(a,1);NTT(b,1);
    for (register int i=0;i<len;i++) a[i]=(a[i]*b[i])%MOD;
    NTT(a,-1);
}

struct poly
{
    LL A[100048];
    inline poly operator * (const poly B) const
    {
        int i;poly res;
        memset(a,0,sizeof(a));memset(b,0,sizeof(b));
        for (i=0;i<=m;i++) a[i]=A[i],b[i]=B.A[i];
        calc_NTT();
        for (i=0;i<=m;i++) res.A[i]=a[i];
        return res;
    }
};

inline poly Quick_pow(poly x,LL y)
{
    int i;poly res;
    for (i=0;i<=m;i++) x.A[i]%=MOD;
    res.A[0]=1;
    while (y)
    {
        if (y&1) res=res*x,y--;
        x=x*x;y>>=1;
    }
    return res;
}

int main ()
{
    int i,j,k;
    scanf("%s%s",s1+1,s2+1);n=strlen(s1+1);
    init_inv();
    e=m=0;
    for (i=1;i<=n;i++)
    {
        if (s1[i]=='1' && s2[i]=='0') e++;
        if (s1[i]=='1' && s2[i]=='1') m++;
    }
    len=1;while (len<=m*2) len<<=1;
    init_wn();
    poly ans;for (i=0;i<=m;i++) ans.A[i]=finv[i+1];
    ans=Quick_pow(ans,e);
    LL fans=0;
    for (i=0;i<=m;i++) fans=(fans+ans.A[i])%MOD;
    fans=fans*fac[e]%MOD*fac[m]%MOD*fac[e+m]%MOD;
    printf("%lld\n",fans);
    return 0;
}
阅读更多
想对作者说点什么?

博主推荐

换一批

没有更多推荐了,返回首页