【题目链接】
【思路要点】
- 行的取反情况与每一列的初始元素可以看做若干个小于\(2^{N}\)的二进制数。
- 注意到行很小,我们考虑先枚举每一行是否取反,令行的取反情况为\(Mask\)。
- 那么,每列的二进制数\(x_i\)应当变成\(x_i\ xor\ Mask\)。
- 令\(bits_i\)代表\(i\)的二进制表示1的个数,\(Mask\)的最小答案应当为\(\sum_{i=1}^{M}min(bits_{x_i\ xor\ Mask},N-bits_{x_i\ xor\ Mask})\)。
- 由此,我们得到了一个\(O(M*2^N)\)的算法。
- 注意到实际上我们只关心\(bits_{x_i\ xor\ Mask}\)而不关心\(x_i\ xor\ Mask\),不妨设\(dp_{k,Mask}\)表示\(bits_{x_i\ xor\ Mask}=k\)的\(i\)的数量。
- \(dp_{0,Mask}\)就是\(x_i=Mask\)的\(i\)的数量。
- 对于\(dp_{k,Mask}\)有贡献的\(x_i\),\(x_i\)与\(Mask\)有\(k\)位不同,考虑枚举不同的位\(p\),其对\(dp_{k,Mask}\)的贡献为\(\frac{1}{k}dp_{k-1,Mask\ xor\ 2^p}\),但在这里,我们重复统计了原本与\(Mask\)有\(k-2\)位不同,但在\(p\)处与\(Mask\)相同的\(x_i\),因此,我们需要将\(\frac{1}{k}dp_{k-2,Mask}\)减去,类似地,我们要加上\(\frac{1}{k}dp_{k-3,Mask\ xor\ 2^p}\),减去\(\frac{1}{k}dp_{k-4,Mask}\)……。
- 这样,我们就得到了一个\(O(N^3*2^N)\)的DP,稍加优化即可得到\(O(N^2*2^N)\)的复杂度,可以通过本题。
- 但实际上还有一种更快,也更容易的做法。
- 我们直接考虑答案数组\(Ans_{Mask}\),令\(Cnt_k\)表示\(x_i=k\)的\(i\)的个数,\(tmp_i=min(bits_i,N-bits_i)\)则有$$Ans_{Mask}=\sum_{i\ xor\ Mask=j}cnt_i*tmp_j=\sum_{i\ xor\ j=Mask}cnt_i*tmp_j$$
- 因此,直接用FWT将\(Cnt\)和\(tmp\)卷积即可,时间复杂度\(O(N*2^N)\)。
【代码】
/*DP Version O(N ^ 2 * 2 ^ N)*/ #include<bits/stdc++.h> using namespace std; const int MAXN = 25; const int MAXM = 100005; const int MAXS = 1 << 20; template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } int n, m, dp[MAXN][MAXS]; int bit[MAXN], x[MAXM]; char s[MAXN][MAXM]; int sum[2][MAXS]; void chkmin(int &x, int y) {x = min(x, y); } int main() { read(n), read(m); for (int i = 1; i <= n; i++) scanf("\n%s", s[i] + 1); bit[1] = 1; for (int i = 2; i <= n; i++) bit[i] = bit[i - 1] << 1; for (int j = 1; j <= m; j++) { for (int i = 1; i <= n; i++) if (s[i][j] == '1') x[j] += bit[i]; dp[0][x[j]]++; sum[0][x[j]]++; } int u = (1 << n) - 1; for (int k = 1; k <= n; k++) { for (int s = 0; s <= u; s++) { int tans = 0; for (int i = 1; i <= n; i++) tans += sum[(k & 1) ^ 1][s ^ bit[i]] - sum[k & 1][s]; dp[k][s] = tans / k; } for (int s = 0; s <= u; s++) sum[k & 1][s] += dp[k][s]; } int ans = n * m; for (int s = 0; s <= u; s++) { int now = 0; for (int i = 0; i <= n; i++) now += min(i, n - i) * dp[i][s]; chkmin(ans, now); } writeln(ans); return 0; } /*FWT Version O(N * 2 ^ N)*/ #include<bits/stdc++.h> using namespace std; const int MAXN = 25; const int MAXM = 100005; const int MAXS = 1 << 20; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } int n, m; char s[MAXN][MAXM]; int bit[MAXN], x[MAXM]; long long cnt[MAXS], bits[MAXS], res[MAXS]; void FWT(long long *a, int N) { for (int len = 2; len <= N; len <<= 1) for (int i = 0; i < N; i += len) for (int j = i, k = i + len / 2; k < i + len; j++, k++) { long long tmp = a[j], tnp = a[k]; a[j] = tmp + tnp; a[k] = tmp - tnp; } } void UFWT(long long *a, int N) { for (int len = 2; len <= N; len <<= 1) for (int i = 0; i < N; i += len) for (int j = i, k = i + len / 2; k < i + len; j++, k++) { long long tmp = a[j], tnp = a[k]; a[j] = (tmp + tnp) / 2; a[k] = (tmp - tnp) / 2; } } int main() { read(n), read(m); for (int i = 1; i <= n; i++) scanf("\n%s", s[i] + 1); bit[1] = 1; for (int i = 2; i <= n; i++) bit[i] = bit[i - 1] << 1; for (int j = 1; j <= m; j++) { for (int i = 1; i <= n; i++) if (s[i][j] == '1') x[j] += bit[i]; cnt[x[j]]++; } int u = 1 << n; for (int i = 1; i < u; i++) bits[i] = bits[i ^ (i & -i)] + 1; for (int i = 0; i < u; i++) chkmin(bits[i], n - bits[i]); FWT(cnt, u); FWT(bits, u); for (int i = 0; i < u; i++) res[i] = cnt[i] * bits[i]; UFWT(res, u); long long ans = n * m; for (int i = 0; i < u; i++) chkmin(ans, res[i]); writeln(ans); return 0; }