题目链接:点击这里
题目大意:
给定
A
,
B
,
K
,
W
A,B,K,W
A,B,K,W 四个整数,求满足如下条件的二元组
(
x
,
y
)
(x,y)
(x,y) 的个数:
- x , y x,y x,y 是整数
- x ≤ A , y ≤ B x\le A,y\le B x≤A,y≤B
- ∣ x − y ∣ ≤ K |x-y|\le K ∣x−y∣≤K
- x x o r y ≤ W x\ xor\ y\le W x xor y≤W
题目分析:
数位
d
p
dp
dp ,定义状态
d
p
[
x
−
y
+
k
≤
0
]
[
y
−
x
+
k
≤
0
]
[
x
≤
A
]
[
y
≤
B
]
[
x
x
o
r
y
≤
W
]
dp[x-y+k\le 0][y-x+k\le 0][x\le A][y\le B][x\ xor\ y\le W]
dp[x−y+k≤0][y−x+k≤0][x≤A][y≤B][x xor y≤W]
后三个状态直接按照数位
d
p
dp
dp 的常规处理方式直接处理即可。
主要解释前两个状态(为了满足
∣
x
−
y
∣
≤
K
|x-y|\le K
∣x−y∣≤K )的处理:
设
v
1
,
v
2
v1,v2
v1,v2 分别记录
x
−
y
+
k
,
y
−
x
+
k
x-y+k,y-x+k
x−y+k,y−x+k ,取值为
−
1
,
0
,
1
-1,0,1
−1,0,1
在此主要解释
v
1
,
v
2
v1,v2
v1,v2 为何只需定义这三种取值,由于我们是全部转化为
2
2
2 进制来进行数位
d
p
dp
dp 的,所以
x
,
y
,
K
x,y,K
x,y,K 只可能取
0
0
0 或
1
1
1 ,那么有
−
1
≤
x
−
y
+
k
≤
2
,
−
1
≤
y
−
x
+
k
≤
2
-1\le x-y+k \le 2,-1\le y-x+k\le 2
−1≤x−y+k≤2,−1≤y−x+k≤2 ,而我们转移是按二进制从高位向低位转移的,所以转移时要乘进制
2
2
2 ,接下来我们对状态
v
1
v1
v1 分情况讨论一下(
v
2
v2
v2 同理):
- 若上个状态的 v 1 v1 v1 小于等于 − 2 -2 −2 ,那么乘 2 2 2 后,有的 v 1 ≤ − 4 v1\le -4 v1≤−4 ,那么由于 − 1 ≤ x − y + k ≤ 2 -1\le x-y+k \le 2 −1≤x−y+k≤2 显然之后无论如何也到不了合法状态 v 1 = 0 v1=0 v1=0 了
- 若上个状态的 v 1 v1 v1 等于 − 1 -1 −1 ,则乘 2 2 2 后有 v 1 = − 2 v1=-2 v1=−2 ,是有机会在之后达到 v 1 = 0 v1=0 v1=0 的
- 若上个状态的 v 1 v1 v1 等于 0 0 0 ,此时就是合法状态
- 若上个状态的 v 1 v1 v1 大于等于 1 1 1 ,那么乘 2 2 2 后有 v 1 ≥ 2 v1 \ge 2 v1≥2 ,那么之后无论如何都能满足条件了,故将其状态压缩为 1 1 1 即可
具体细节见代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<set>
#include<map>
#define ll long long
#define inf 0x3f3f3f3f
using namespace std;
int read()
{
int res = 0,flag = 1;
char ch = getchar();
while(ch<'0' || ch>'9')
{
if(ch == '-') flag = -1;
ch = getchar();
}
while(ch>='0' && ch<='9')
{
res = (res<<3)+(res<<1)+(ch^48);//res*10+ch-'0';
ch = getchar();
}
return res*flag;
}
const int maxn = 1e5+5;
const int mod = 1e9+7;
const double pi = acos(-1);
const double eps = 1e-8;
int A[maxn],B[maxn],K[maxn],W[maxn];
ll dp[32][3][3][2][2][2];
ll dfs(int pos,int v1,int v2,int f1,int f2,int f3)
{
v1 = min(v1,1); v2 = min(v2,1);
if(v1 < -1 || v2 < -1) return 0;
if(pos == -1) return v1>=0 && v2>=0;
if(dp[pos][v1+1][v2+1][f1][f2][f3] != -1)
return dp[pos][v1+1][v2+1][f1][f2][f3];
int upa = f1 ? A[pos] : 1;
int upb = f2 ? B[pos] : 1;
int upw = f3 ? W[pos] : 1;
ll res = 0;
for(int i = 0;i <= upa;i++)
for(int j = 0;j <= upb;j++)
{
if((i^j) > upw) continue;
res += dfs(pos-1,v1*2+i-j+K[pos],v2*2+j-i+K[pos],f1&&i==A[pos],f2&&j==B[pos],f3&&(i^j)==W[pos]);
}
return dp[pos][v1+1][v2+1][f1][f2][f3] = res;
}
ll solve(int a,int b,int k,int w)
{
memset(dp,-1,sizeof(dp));
for(int i = 0;i <= 30;i++)
{
A[i] = a&1;a >>= 1;
B[i] = b&1;b >>= 1;
K[i] = k&1;k >>= 1;
W[i] = w&1;w >>= 1;
}
return dfs(30,0,0,1,1,1);
}
int main()
{
int t = read();
while(t--)
{
int a = read(),b = read(),k = read(),w = read();
printf("%lld\n",solve(a,b,k,w));
}
return 0;
}