洛谷P4067:[SDOI2016]储能表 (数位DP)

题目传送门:https://www.luogu.org/problemnew/show/P4067


题目分析:一道令我心态爆炸的数位DP。一调调一天,WA不花一分钱

先说一下我理解的数位DP是什么。数位DP本质上还是个DP,它里面有很多重复的子问题。但现在题面给了DP的下标一个上界限制,而我们不能直接枚举下标,所以要贴着这个上界限制来DFS。形象地做个比喻就是走楼梯,DFS的时候要紧贴着楼梯走,楼梯下方的部分可以通过DP算出来。而这个DP,也可以用记忆化搜索代替。

这就导致数位DP有两种写法。一种是记当前这位是否受到题面给出的上界限制(0/1),然后DFS。DFS的时候顺便算出不受限制时的DP数组(0),并重复调用。另一种方法是预处理出不受限制时的DP数组f,这样DFS的时候就只会沿着上界限制向低位走,一旦不受限制了就直接调用f数组。在多组询问,问题相同且模数相同时,这个f数组是可以重复调用的。

再回到这题。经过一些推导,我们可以发现这题要求的就是x[0,n),y[0,m),xy[0,k]的(x,y)的组数及它们的异或和,其中是异或。那么这就相当于给出了一个三维的限制,对此我们可以在DFS的时候用三个0/1记录当前x是否受到n的上界限制,y是否受到m的上界限制,以及xy是否受到k的上界限制。转移到下一位的时候,枚举x,y这一位放什么数,看一下这一位的异或值和k的限制的关系即可。数位DP的代码一般细节很多,但只要跟暴力试几组数据没有问题,就基本上没问题了。

这道题我一开始不停地TLE最后两个点,还以为是常数问题,后来发现是自己的思路有问题。我一开始写code的时候,为了方便只记录了x和y是否受到上界限制,而没有记录k。然后我令DFS的状态强行贴着k走,当x和y的异或值不受k限制的时候,就用一个O(log(n))的时间暴力算出结果。这就相当于只对其中的两维进行数位DP,第三维暴力计算。结果我的时间复杂度比别人多了一个log(n)。后来我发现在第三维的计算中引入记忆化,就可以做到均摊O(1),这样就能A了。

在写code的时候要时刻注意取模,不然很容易爆long long。tututu比赛的时候写了个跟暴力对拍没问题的程序,结果因为long long的原因只拿了20分。


CODE:

#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;

const int maxl=63;
typedef long long LL;

#define P pair<long long,long long>
#define MP(x,y) make_pair((long long)x,(long long)y)

int t;
LL n,m,k,M;

P f[maxl][2][2];
P g[maxl][2][2];
P sum,Dec;

LL Min(LL x,LL y)
{
    if (x<y) return x;
    return y;
}

LL Max(LL x,LL y)
{
    if (x>y) return x;
    return y;
}

LL Mod(LL x)
{
    if (x>=M) return x-M;
    return x;
}

void Plus(P &x,P y)
{
    x.first=Mod(x.first+y.first);
    x.second=Mod(x.second+y.second);
}

P Calc(LL len,LL x,LL y)
{
    if (len==-2)
    {
        LL val=0;
        LL s=(1LL<<61);
        while (s)
        {
            LL p=(s<<1);
            LL n0,n1,m0,m1;

            n0=(x/p)*s+Min(s-1LL,x%p)+1LL;
            n1=(x/p)*s+Max(0LL,x%p-s+1LL);
            m0=(y/p)*s+Min(s-1LL,y%p)+1LL;
            m1=(y/p)*s+Max(0LL,y%p-s+1LL);

            n0%=M; n1%=M; m0%=M; m1%=M;
            val=(val+ (n0*m1+n1*m0)%M*(s%M) )%M;
            s>>=1;
        }
        x++;
        y++;
        return MP(val, (x%M)*(y%M)%M );
    }
    else
    {
        if (len==-1) return MP(0,1);
        if (g[len][x][y].first>=0) return g[len][x][y];
        g[len][x][y]=MP(0,0);

        LL s=(1LL<<len);
        bool fn=(n&s),fm=(m&s);
        int un=1,um=1;
        if (x&&(!fn)) un=0;
        if (y&&(!fm)) um=0;

        for (int i=0; i<=un; i++)
            for (int j=0; j<=um; j++)
            {
                bool nn=(x&&(i==un));
                bool mm=(y&&(j==um));
                bool v=i^j;

                P z=Calc(len-1,nn,mm);
                z.first=(z.first+ z.second*v*(s%M)%M )%M;
                Plus(g[len][x][y],z);
            }
        return g[len][x][y];
    }
}

P Work(int x,bool ln,bool lm)
{
    if (x==-1) return MP(0,1);
    if (f[x][ln][lm].first>=0) return f[x][ln][lm];
    f[x][ln][lm]=MP(0,0);

    LL s=(1LL<<x);
    bool fk=(k&s),fn=(n&s),fm=(m&s);
    bool un,um;
    if (!ln) un=1; else un=fn;
    if (!lm) um=1; else um=fm;

    for (int i=0; i<=un; i++)
        for (int j=0; j<=um; j++)
        {
            bool nn=(ln&&(i==un));
            bool mm=(lm&&(j==um));
            bool v=i^j;

            if (v<fk)
            {
                LL a=(nn? n%s:s-1LL);
                LL b=(mm? m%s:s-1LL);
                Plus(f[x][ln][lm], Calc(x-1,nn,mm) );
            }

            if (v==fk)
            {
                P y=Work(x-1,nn,mm);
                y.first=(y.first+ y.second*v*(s%M)%M )%M;
                Plus(f[x][ln][lm],y);
            }
        }

    return f[x][ln][lm];
}

int main()
{
    freopen("B.in","r",stdin);
    freopen("B.out","w",stdout);

    memset(f,-1,sizeof(f));
    memset(g,-1,sizeof(g));
    scanf("%d",&t);
    while (t--)
    {
        scanf("%lld%lld%lld%lld",&n,&m,&k,&M);
        if ( (!n) || (!m) )
        {
            printf("0\n");
            continue;
        }
        n--;
        m--;

        LL N=1;
        int Lg=0;
        while ( N<=n+2 || N<=m+2 || N<=k+2 ) N<<=1,Lg++;
        N<<=1,Lg++;

        sum=Calc(-2,n,m);
        Dec=Work(Lg,1,1);

        for (int i=0; i<=Lg; i++)
            for (int j=0; j<2; j++)
                for (int w=0; w<2; w++)
                    f[i][j][w].first=g[i][j][w].first=-1;

        LL ans=Mod(sum.first-Dec.first+M);
        ans=Mod(ans- (sum.second-Dec.second+M)*(k%M)%M +M);
        printf("%lld\n",ans);
    }

    return 0;
}
发布了160 篇原创文章 · 获赞 76 · 访问量 10万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 技术黑板 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览