【概率DP+常系数线性齐次递推+NTT】BZOJ4944 NOI2017泳池

版权声明:这是蒟蒻的BLOG,神犇转载也要吱一声哦~ https://blog.csdn.net/Dream_Lolita/article/details/82186397

【题目大意】
给定一块底边长为n,高度为1001的矩形,矩形的每个格子有q的概率是安全的,1q的概率是危险的。一个子矩形是合法的当且仅当这个子矩形的下底边贴着大矩形的底边且子矩形内所有格子都是安全的。问最大合法子矩形的面积为k的概率是多少。n1e9,k1000.

【解题思路】
听这题听说了几次了,刚好没有什么安排就来做一下。下面的推导次数界可能会有点小问题,不过不太影响。

首先题目要求的是最大安全矩形面积恰好k的概率,我们可以计算最大子矩形面积不超过k和不超过k1的答案做差。

gi,j表示高度为i,长度为j的海域都是安全的,剩下部分未知(最大子矩形面积k)的概率。
hi,j表示高度为i+1,长度为j的海域前i行都是安全的,且(i+1,j)这个位置是危险的,剩下部分未知(最大子矩形面积k)的概率。
我们有边界:

gk,1=qk(11),gi,0=1,hi,0=1

k1到1进行DP,对于(i,j)这个点,枚举i+1行的下一个危险格子在哪里,然后进行转移。
gi,j=k=0jhi,kgi+1,jkhi,j=k=0j1hi,kgi+1,jk1qi(1q)

因为第i行的宽度不会超过ki,即暴力DP的复杂度应该就是
i=1kki2=O(k2)

这部分预处理已经可以满足要求了。

下面考虑答案的计算。
fi为前i列最大子矩形k的概率,那么

fi=j=1kfijg1,j1(1q)

我们令ai=g1,i1(1q),那么这个就是一个常系数线性递推的形式
fi=j=1kajfij

然后矩阵快速幂是O(k3logk)的,我们考虑用特征多项式做到更优秀的复杂度。

我们矩阵乘法的转移矩阵为A,我们只关注An是什么而不关注中间的项。
我们将A看成变量,构造一个多项式B满足

An=i=0k1biAi

如果我们将初始矩阵设为St,我们将上面等式的两边同时乘上St,因为我们只关注最后矩阵的第0项,而上面这个等式在只取第0项的时候也是成立的,那么我们最终可以得到
Ans=i=0k1biSti

也就是说,事实上StB就是最后的答案。

于是现在我们要构造出这个B

如果我们将A写成A=Q(A)g(A)+R(A)的形式,我们钦点g(A)的次数为k,如果此时g(A)=0,则事实上R=B(因为R的次数一定比g低,我们可以将R写成上面的多个幂次的求和形式)。

现在问题就是求B=An(mod g(A)),因为是在模意义下,而我们知道Ak1次下的答案,所以这个是可以通过多项式快速幂,多项式取模来做到O(klogklogn)的,对于这题可以直接暴力取模O(k2logn)做到。

现在的问题是构造一个g出来。
根据Cayley-Hamilton定理,|λIA|是一个关于λk次多项式(I是单位矩阵),记为g(λ),且对于任意矩阵A都有g(A)=0
对于上面的这个式子,我们有一个结论:g(λ)=λki=1kaiλki,其中k是矩阵A的大小,ai就是A的第i项。
然后这个东西我们直接算就是O(k)的了。

因此整个计算答案的部分我们已经可以做到O(klogklogn),远低于题目要求的范围。
但是我们对g,hDP复杂度都已经是O(k2)的了,我们是不是很亏?所以我们要想办法优化前面的复杂度。

我们试着将两个DP结果写成生成函数的形式,设

Ai(x)=j0gi,jxjBi(x)=j0hi,jxjci=qi(1q)

那么

Ai(x)=Bi(x)Ai+1(x)Bi(x)=cixAi+1(x)Bi(x)+1Bi(x)=11cixAi+1(x)

于是对于k1到1行,我们每一行都可以用多项式求逆来计算当前DP值,所以复杂度是:

i=1kkilogki=O(klog2k)

于是最终我们得到了一个比较优秀的总做法,总复杂度是:O(klog2k+klogklogn),常数极大,未必跑得过将一个logk换成k的暴力(但实测确实复杂度优秀的算法跑得会快很多),于是我们可以自适应一下,在k比较小的时候暴力卷积,在比较大的时候用NTT应该会得到不错的效果。
当然谁会去这么无聊分开写呢。

【参考代码】

#include<bits/stdc++.h>
using namespace std;

typedef long long LL;
const int N=3e5+10,M=N*4;
const LL mod=998244353,g=3,inv2=(mod+1)>>1;

LL qpow(LL x,LL y) {LL ret=1;for(;y;y>>=1,x=x*x%mod)if(y&1)ret=ret*x%mod;return ret;}
void up(LL &x,LL y) {x+=y;if(x>=mod)x-=mod;if(x<0)x+=mod;}

namespace NTT
{
    int n,L,rev[N];
    LL w1[N],w2[N],d[N],e[N],Q[N],P[N];
    LL f[N],x[N],y[N],z[N];

    void init(int m)
    {
        for(n=1,L=0;n<m;n<<=1,++L);
        for(int i=2;i<=n;i<<=1) 
            w1[i>>1]=qpow(g,(mod-1)/i),w2[i>>1]=qpow(w1[i>>1],mod-2);
        rev[0]=0;
        for(int i=1;i<n;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
    }

    void ntt(LL *a,int f)
    {
        for(int i=0;i<n;++i) if(i>rev[i]) swap(a[i],a[rev[i]]);
        for(int i=1;i<n;i<<=1)
        {
            LL wn=(f==1?w1[i]:w2[i]);
            for(int j=0;j<n;j+=(i<<1))
            {
                LL w=1;
                for(int k=0;k<i;++k,w=w*wn%mod)
                {
                    LL x=a[j+k],y=w*a[i+j+k]%mod;
                    a[j+k]=(x+y)%mod;a[i+j+k]=(x-y+mod)%mod;
                }
            }
        }
        if(!~f) for(int i=0,inv=qpow(n,mod-2);i<n;++i) a[i]=a[i]*inv%mod;
    }

    void clear(LL *a,LL *b,int m)
    {
        for(int i=0;i<m;++i) a[i]=b[i];
        for(int i=m;i<n;++i) a[i]=0;
    }

    void inverse(LL *a,LL *b,int m)
    {
        if(m==1) {b[0]=qpow(a[0],mod-2);return;}
        inverse(a,b,m>>1); init(m<<1);
        clear(x,a,m);clear(y,b,m>>1);
        ntt(x,1);ntt(y,1);
        for(int i=0;i<n;++i) x[i]=y[i]*(2-x[i]*y[i]%mod)%mod,up(x[i],mod);
        ntt(x,-1);
        for(int i=0;i<m;++i) b[i]=x[i];
    }

    void module(LL *a,LL *b,LL *c,int n1,int n2)
    {
        int k=1; while(k<=n1-n2+1) k<<=1; k<<=1;
        for(int i=0;i<=n1;++i) d[i]=a[i];
        for(int i=0;i<=n2;++i) e[i]=b[i];
        reverse(d,d+n1+1); reverse(e,e+n2+1);
        for(int i=n1-n2+1;i<k;++i) d[i]=e[i]=0;
        inverse(e,f,k>>1);
        for(int i=n1-n2+1;i<k;++i) f[i]=0;
        init(k); ntt(d,1); ntt(f,1);
        for(int i=0;i<n;++i) e[i]=d[i]*f[i]%mod;
        ntt(e,-1);
        for(int i=0;i<=n1-n2;++i) c[i]=e[i];
        reverse(c,c+n1-n2+1);
    }

    void mul(LL *a,LL *b,LL *c,int n)
    {
        int k=1; while(k<=n) k<<=1; k<<=1; 
        for(int i=0;i<k;++i) Q[i]=P[i]=0;
        for(int i=0;i<=n;++i) Q[i]=a[i],P[i]=b[i];
        init(k);ntt(Q,1);ntt(P,1);
        for(int i=0;i<k;++i) Q[i]=Q[i]*P[i]%mod;
        ntt(Q,-1);
        for(int i=0;i<k;++i) P[i]=0;
        int n2=k-1; while(!Q[n2]) --n2;
        module(Q,c,P,n2,n);
        for(int i=0;i<n;++i) a[i]=Q[i];
        for(int i=0;i<(k>>1);++i) Q[i]=c[i];
        for(int i=(k>>1);i<k;++i) Q[i]=0;
        init(k); ntt(Q,1); ntt(P,1);
        for(int i=0;i<k;++i) Q[i]=Q[i]*P[i]%mod;
        ntt(Q,-1);
        for(int i=0;i<n;++i) up(a[i],-Q[i]); 
    }

    void powmod(LL *a,LL *b,LL *c,int m,int n)
    {
        if(!n) return;
        powmod(a,b,c,m,n>>1); 
        mul(a,a,c,m); if(n&1) mul(a,b,c,m);
    }
}

namespace SOL
{
    LL n,K,X,Y,q,q2,ans;
    LL fac1[N],fac2[N];
    LL g[2][N],h[N],fin[N];
    LL a[N],b[N],c[N],d[N],e[N],f[N];

    int DP(LL K)
    {
        int now=1,las=0;
        memset(g,0,sizeof(g));memset(h,0,sizeof(h));
        h[0]=1;g[0][0]=1;g[0][1]=q2*fac1[K]%mod;
        for(int i=K-1;i;--i,now^=1,las^=1)
        {
            LL dt=K/i,ct=q2*fac1[i]%mod,m=1;
            while(m<=dt) m<<=1;
            e[0]=1; for(int j=1;j<m;++j) e[j]=-ct*g[las][j-1];
            NTT::inverse(e,h,m); m<<=1;
            for(int j=dt+1;j<m;++j) h[j]=0;
            NTT::init(m); NTT::ntt(g[las],1); NTT::ntt(h,1);
            for(int j=0;j<m;++j) g[now][j]=g[las][j]*h[j]%mod;
            NTT::ntt(g[now],-1);
            for(int j=dt+1;j<m;++j) g[now][j]=0;
        }
        memset(a,0,sizeof(a));
        a[0]=1; for(int i=1;i<=K+1;++i) a[i]=-g[las][i-1]*q2%mod;
        return las;
    }

    LL solve(LL K)
    {
        if(K==0) return qpow(1-q+mod,n);
        NTT::init(K);int las=DP(K);

        LL ret=0,m=1,pw=n-K; while(m<=K+1) m<<=1;
        NTT::inverse(a,f,m<<1);
        if(n<=(K+1)<<1)
        {
            for(int i=0;i<=n && i<=K;++i) up(ret,f[n-i]*g[las][i]%mod);
            return ret;
        }

        memset(a,0,sizeof(a));memset(c,0,sizeof(c));memset(d,0,sizeof(d));
        a[K+1]=1; for(int i=0;i<=K;++i) a[i]=-g[las][K-i]*q2%mod,up(a[i],mod);
        if(K) c[1]=1; else c[0]=-a[0]; d[0]=1;
        NTT::powmod(d,c,a,K+1,pw); reverse(d,d+K+1);
        NTT::init(m<<2); NTT::ntt(d,1); NTT::ntt(f,1);
        for(int i=0;i<m<<2;++i) fin[i]=d[i]*f[i]%mod;
        NTT::ntt(fin,-1);
        for(int i=0,j=K<<1;i<=K;++i) up(ret,g[las][i]*fin[j-i]%mod);
        return ret;
    }   

    void Dream_Lolita()
    {
        scanf("%lld%lld%lld%lld",&n,&K,&X,&Y);
        q=X*qpow(Y,mod-2)%mod;q2=(Y-X)*qpow(Y,mod-2)%mod;
        fac1[0]=fac2[0]=1;
        for(int i=1;i<=K;++i) 
            fac1[i]=fac1[i-1]*q%mod,fac2[i]=fac2[i-1]*q2%mod;
        up(ans,solve(K));up(ans,-solve(K-1));
        printf("%lld\n",ans);
    }
}

int main()
{
#ifndef ONLINE_JUDGE
    freopen("BZOJ4944.in","r",stdin);
    freopen("BZOJ4944.out","w",stdout);
#endif
    SOL::Dream_Lolita();

    return 0;
}

【总结】
论常系数线性齐次递推中,多项式技巧是如何优化矩阵快速幂的。

展开阅读全文

没有更多推荐了,返回首页