题意:给出一个 n × n n\times n n×n的矩阵,求所有子矩阵的 and \text{and} and值之和和 or \text{or} or值之和。
显然可以把每一位分开来求,那么对于某一位而言矩阵中的元素不是 0 0 0就是 1 1 1。一个仅由 0 0 0和 1 1 1构成的子矩阵的 and \text{and} and值是 1 1 1当且仅当这个子矩阵中全是 1 1 1,否则是 0 0 0;类似地, or \text{or} or值是 0 0 0当且仅当这个子矩阵中全是 0 0 0,否则是 1 1 1。
所以问题转化为求一个 0 − 1 0-1 0−1矩阵有多少个仅由 x x x构成的子矩阵,其中 x ∈ { 0 , 1 } x\in\{0,1\} x∈{0,1}。原题传送
以 x = 1 x=1 x=1为例:首先求出 r ( i , j ) r(i,j) r(i,j)表示第 i i i行第 j j j列这个位置开始往右最长连续 1 1 1的个数,这个递推非常简单。然后考虑 f ( i , j ) f(i,j) f(i,j)表示以 ( i , j ) (i,j) (i,j)为左下角的全 1 1 1子矩阵的个数。我们从右往左一列一列地做。对于第 i i i列第 j j j个位置,应该有 f ( j , i ) = f ( p , i ) + ( j − p ) × r ( j , i ) f(j,i)=f(p,i)+(j-p)\times r(j,i) f(j,i)=f(p,i)+(j−p)×r(j,i),其中 p p p是使得 r ( p , i ) < r ( j , i ) r(p,i)<r(j,i) r(p,i)<r(j,i)的最大的下标。这个 p p p显然可以用单调栈维护。所以这里的复杂度 O ( n 2 ) O(n^2) O(n2)。另外 f f f数组可以去掉一维,因为每一列都是独立的。
总复杂度 O ( n 2 log max { a i j } ) O(n^2\log\max\{a_{ij}\}) O(n2logmax{aij})。
#include <cctype>
#include <cstdio>
#include <climits>
#include <algorithm>
template <typename T> inline void read(T& x) {
int f = 0, c = getchar(); x = 0;
while (!isdigit(c)) f |= c == '-', c = getchar();
while (isdigit(c)) x = x * 10 + c - 48, c = getchar();
if (f) x = -x;
}
template <typename T, typename... Args>
inline void read(T& x, Args&... args) {
read(x); read(args...);
}
template <typename T> void write(T x) {
if (x < 0) x = -x, putchar('-');
if (x > 9) write(x / 10);
putchar(x % 10 + 48);
}
template <typename T> inline void writeln(T x) { write(x); puts(""); }
template <typename T> inline bool chkmin(T& x, const T& y) { return y < x ? (x = y, true) : false; }
template <typename T> inline bool chkmax(T& x, const T& y) { return x < y ? (x = y, true) : false; }
const int mod = 1e9 + 7;
const int maxn = 1007;
int a[maxn][maxn];
bool b[maxn][maxn];
int r[maxn][maxn];
int f[maxn], stk[maxn], tp;
int n, mx, as, os;
inline int count(bool x) {
for (int i = 1; i <= n; ++i)
for (int j = n; j; --j)
r[i][j] = b[i][j] == x ? r[i][j + 1] + 1 : 0;
int ans = 0;
for (int i = n; i; --i) {
stk[tp = 0] = 0;
for (int j = 1; j <= n; ++j) {
while (tp && r[stk[tp]][i] >= r[j][i]) --tp;
int pos = stk[tp];
stk[++tp] = j;
f[j] = (f[pos] + 1ll * (j - pos) * r[j][i]) % mod;
if ((ans += f[j]) >= mod) ans -= mod;
}
}
return ans;
}
int main() {
read(n);
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n; ++j)
read(a[i][j]), chkmax(mx, a[i][j]);
int all = n * (n + 1) >> 1;
all = 1ll * all * all % mod;
for (int w = 0; (1ll << w) <= mx; ++w) {
for (int i = 1; i <= n; ++i)
for (int j = 1; j <= n; ++j)
b[i][j] = a[i][j] & (1 << w);
as = (as + 1ll * ((1ll << w) % mod) * count(1) % mod) % mod;
os = (os + 1ll * ((1ll << w) % mod) * ((all - count(0) + mod) % mod) % mod) % mod;
}
write(as); putchar(' '); write(os);
return 0;
}