题目链接:点击此处查看
题目大意:
给出
A、B、C、D(0≤A、B、C、D≤1018)
四个数,求满足
a+c>b+d、a+d≥b+c
的四元组
(a,b,c,d)
的个数,其中
0≤a≤A,0≤b≤B,0≤c≤C,0≤d≤D
。
题解:
这题有三种解法。一种是大胆瞎猜推测答案是关于
A、B、C、D
的多项式,该多项式最多有
24
项,因此设出各项的系数,用小用例解方程组即可。
第二种暴力推公式,也可以得到一个答案的多项式 - -…推不出来..
第三种是数位DP,这个数位DP比较神奇,没写过这种数位DP。但是该题也有比较明显的特点,是给出特定区间,统计满足条件的数(四元组)的个数。同时对四个数进行数位DP,需要记录的状态是每位
a+c−b−d
和
a+d−b−c
的值。因为
0≤a+b≤18
,所以如果某一层的
a+b−c−d≥2
或
a+b−c−d≤−2
,我们就可以认为后面的位无法再对正负产生影响,那么只需要
{−2,−1,0,1,2}
五个状态就好了。每层递归内部的复杂度是
104
,加上字符长度20、状态25和1000组用例..爆炸了。题目的办法是将数字数字转化为2进制再进行DP,从而将每层递归的复杂度降为
24
,这样的复杂度就可以接受了。
通常在写数位DP的时候,为了记忆化数据对每一次查询都能够使用,我们会设置一个fp标记到当前位为止是否是上界的值。在这里我们需要设置4个上界。4数字个都不为上界的情况实际上是比较少的。因此如果只记忆化非上界的值,会有大量的重复运算。在这里只能放弃记忆化对所有查询的贡献,而要多开四维状态来记录每个数是否处于上界,这四维状态可以压位处理。因为DP数组只记录特定的查询的DP值,所以每次查询的时候需要对DP数组重新初始化。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
#include <cstdlib>
#include <string>
#include <cmath>
#include <set>
#include <map>
#include <bitset>
using namespace std;
typedef long long ll;
const int mod = 1000000007;
const double eps = 1e-6;
const int inf = 0x3f3f3f3f;
const ll INF = 100000000000000000ll;
const int MAXN = 100010;
const int MAXM = 300030;
int d1[67],d2[67],d3[67],d4[67];
int dp[67][4][4][16];
int dfs(int len,int s1,int s2,int mask){
if(!len){
if(s1>0&&s2>=0) return 1;
return 0;
}
if(dp[len][s1+1][s2+1][mask] != -1) return dp[len][s1+1][s2+1][mask];
int a,b,c,d;
a = (mask&8)?d1[len]:1; b = (mask&4)?d2[len]:1; c = (mask&2)?d3[len]:1; d = (mask&1)?d4[len]:1;
int ret = 0;
for(int i = 0;i <= a;i++){
for(int j = 0;j <= b;j++){
for(int k = 0;k <= c;k++){
for(int l = 0;l <= d;l++){
int t1 = s1*2,t2 = s2*2,t = 0;
t1 += i+k-j-l;
t2 += i+l-j-k;
if(t1<-1||t2<-1){
continue;
}
if(t1>1) t1 = 2;
if(t2>1) t2 = 2;
if(i==a) t|=8; if(j==b) t|=4; if(k==c) t|=2; if(l==d) t|=1;
ret += dfs(len-1,t1,t2,mask&t); ret %= mod;
}
}
}
}
return dp[len][s1+1][s2+1][mask] = ret;
}
int f(ll a,ll b,ll c,ll d){
int len = 0;
while(a||b||c||d){
d1[++len] = a&1; d2[len] = b&1; d3[len] = c&1; d4[len] = d&1;
a>>=1; b>>=1; c>>=1; d>>=1;
}
return dfs(len,0,0,15);
}
int main()
{
int T;
//freopen("1011.in","r",stdin);
//freopen("out","w",stdout);
cin>>T;
ll a,b,c,d;
while(T--){
memset(dp,-1,sizeof(dp));
scanf("%I64d%I64d%I64d%I64d",&a,&b,&c,&d);
printf("%d\n",f(a,b,c,d));
//printf("cnt = %I64d nn = %I64d\n",cnt,nn);
}
return 0;
}