洛谷传送门
BZOJ传送门
题目描述
有一个
n
n
n 行
m
m
m 列的表格,行从
0
0
0 到
n
−
1
n−1
n−1 编号,列从
0
0
0 到
m
−
1
m−1
m−1 编号。每个格子都储存着能量。最初,第
i
i
i 行第
j
j
j 列的格子储存着
(
i
x
o
r
j
)
(i\ xor\ j)
(i xor j) 点能量。所以,整个表格储存的总能量是,
∑
i
=
0
n
−
1
∑
j
=
0
m
−
1
(
i
x
o
r
j
)
\sum_{i=0}^{n-1} \sum_{j=0}^{m-1} (i \mathrm{xor} j)
i=0∑n−1j=0∑m−1(ixorj)
随着时间的推移,格子中的能量会渐渐减少。一个时间单位,每个格子中的能量都会减少
1
1
1。显然,一个格子的能量减少到
0
0
0 之后就不会再减少了。
也就是说,
k
k
k 个时间单位后,整个表格储存的总能量是
∑
i
=
0
n
−
1
∑
j
=
0
m
−
1
m
a
x
(
(
i
x
o
r
j
)
−
k
,
0
)
\sum_{i=0}^{n-1} \sum_{j=0}^{m-1} \mathrm{max} ((i \mathrm{xor} j)-k,0)
i=0∑n−1j=0∑m−1max((ixorj)−k,0)
给出一个表格,求
k
k
k 个时间单位后它储存的总能量。
由于总能量可能较大,输出时对 p p p 取模。
输入输出格式
输入格式:
第一行一个整数 T T T,表示数据组数。接下来 T T T 行,每行四个整数 n n n、 m m m、 k k k、 p p p。
输出格式:
共 T T T 行,每行一个数,表示总能量对 p p p 取模后的结果
输入输出样例
输入样例#1:
3
2 2 0 100
3 3 0 100
3 3 1 100
输出样例#1:
2
12
6
数据范围
T ≤ 5000 , n ≤ 1 0 18 , m ≤ 1 0 18 , k ≤ 1 0 18 , p ≤ 1 0 9 T \le5000,n \leq 10 ^ {18},m \leq 10 ^ {18},k \leq 10 ^ {18},p \leq 10 ^ 9 T≤5000,n≤1018,m≤1018,k≤1018,p≤109 。
解题分析
太菜了, 居然看不出来这是道数位 D P DP DP…(虽然看到数据范围大概猜得到一点QAQ)
发现每一个数位的贡献是互不影响的, 所以我们可以分开计算。
这里还有一个 k k k的限制, 所以多记一维表示 k k k的限制是否还存在, 直接不算 ≤ k \le k ≤k的部分。
最后答案就是总的 > k >k >k的值的和减去个数*k。
代码如下:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <cctype>
#include <utility>
#include <algorithm>
#define R register
#define IN inline
#define W while
#define gc getchar()
#define pi std::pair<int, int>
#define ll long long
template <class T>
IN void in(T &x)
{
x = 0; R char c = gc;
for (; !isdigit(c); c = gc);
for (; isdigit(c); c = gc)
x = (x << 1) + (x << 3) + c - 48;
}
template <class T> IN T max(T a, T b) {return a > b ? a : b;}
template <class T> IN T min(T a, T b) {return a < b ? a : b;}
int T, mx, mod;
pi dp[65][2][2][2];
bool vis[65][2][2][2];
ll N, M, K;
IN void ad(int &x, R int v) {x += v; if (x >= mod) x -= mod;}
pi DFS(R int dgt, R bool n, R bool m, R bool k)
{
if (dgt > mx) return std::make_pair(1, 0);
if (vis[dgt][n][m][k]) return dp[dgt][n][m][k];
vis[dgt][n][m][k] = true;
int nlim = n ? ((N >> mx - dgt) & 1) : 1;
int mlim = m ? ((M >> mx - dgt) & 1) : 1;
int klim = k ? ((K >> mx - dgt) & 1) : 1;
pi ret;
for (R int i = 0; i <= nlim; ++i)
for (R int j = 0; j <= mlim; ++j)
{
if (k && klim > (i ^ j)) continue;
ret = DFS(dgt + 1, n && (i == nlim), m && (j == mlim), k && ((i ^ j) == klim));
ad(dp[dgt][n][m][k].first, ret.first);
ad(dp[dgt][n][m][k].second, (1ll * (1ll << mx - dgt) % mod * (i ^ j) * ret.first % mod + ret.second % mod) % mod);
}
return dp[dgt][n][m][k];
}
int main(void)
{
in(T); pi ret;
W (T--)
{
in(N), in(M), in(K), in(mod);
N--, M--, mx = 0;
ll buf = max(N, max(M, K));
mx = log2(buf) + 1;
std::memset(vis, false, sizeof(vis));
std::memset(dp, 0, sizeof(dp));
ret = DFS(1, 1, 1, 1);
printf("%d\n", (ret.second - 1ll * ret.first * (K % mod) % mod + mod) % mod);
}
}