【题目链接】
【思路要点】
- 先说这道题的正解:
- 将棋盘看做一张二分图,每一条边拆成两个点,分别属于二分图的一边。
- 我们需要做一件类似于匹配的事情,同一条边的两侧或是都没有管道,或是都有管道。
- 通过合适的建边我们能够用最小费用最大流来解决本题。
- 时间复杂度 $O(\mathrm{MinCostFlow}(NM, NM))$,即在点数与边数均为 $O(NM)$ 的图上跑一次最小费用最大流。
- 但自从笔者从考场回来就一直在思考本题的插头DP做法的复杂度是不是能够被证明更优,因为我们发现实际上合法的插头数少之又少。
- 遗憾的是笔者并不能证明其更优的复杂度,但笔者将考场上的插头DP代码的状态开成了64位整型,就通过了本题,而且似乎比大部分费用流的做法要快。
- 因此笔者相信在本题中插头DP的实际复杂度低于理论上界 $O(NM \cdot 2^{\sqrt{NM}})$。
- 以下代码同时给出了网络流做法与插头DP做法的实现。
- 插头DP的时间复杂度为 $O(S)$,其中 $S$ 为合法的状态数。
- Upd: 更新了网络流做法。
【代码】
// Two solutions to the pipe-rotation puzzle ("无限之环").
//
// Input : n m, then an n x m grid of pipe masks.
// Mask  : bit 0 = up, bit 1 = right, bit 2 = down, bit 3 = left.
// Output: the minimum total number of quarter turns so that every pipe end
//         meets a pipe end of the neighbouring cell, or -1 if impossible.
// Both solvers treat straight pipes (masks 5 and 10) as non-rotatable and
// assume every mask is in [0, 15].
//
// flow_version: min-cost max-flow.
// dp_version  : plug (broken-profile) DP over 64-bit plug masks.
// main() runs flow_version::solve; dp_version::solve has the same signature.

#include <algorithm>
#include <bitset>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <queue>
#include <utility>
#include <vector>

using Grid = std::vector<std::vector<int>>;

// Reads the next decimal integer from stdin, skipping any non-digit
// characters (each '-' in the skipped run flips the sign).
// Stops at EOF instead of spinning forever; x is 0 if no digits remain.
template <typename T>
void read(T &x) {
    x = 0;
    int sign = 1;
    int c = std::getchar();  // int, so EOF is representable for isdigit
    for (; c != EOF && !std::isdigit(c); c = std::getchar())
        if (c == '-') sign = -sign;
    for (; c != EOF && std::isdigit(c); c = std::getchar())
        x = x * 10 + (c - '0');
    x *= sign;
}

namespace flow_version {

constexpr int kInf = 2000000000;

// Successive-shortest-path min-cost max-flow. Shortest paths use SPFA
// (queue-based Bellman-Ford) because residual edges carry negative costs.
class MinCostFlow {
    struct Edge {
        int dest;
        int cap;
        int rev;  // index of the paired reverse edge in graph_[dest]
        int cost;
    };

public:
    explicit MinCostFlow(int nodes)
        : graph_(nodes), dist_(nodes), prev_node_(nodes), prev_edge_(nodes),
          in_queue_(nodes, 0) {}

    // Adds from -> to with capacity `cap` and per-unit `cost`, plus the
    // zero-capacity reverse edge used by the residual graph.
    void add_edge(int from, int to, int cap, int cost) {
        graph_[from].push_back({to, cap, static_cast<int>(graph_[to].size()), cost});
        graph_[to].push_back({from, 0, static_cast<int>(graph_[from].size()) - 1, -cost});
    }

    // Augments along cheapest s -> t paths until none remain.
    // Returns {total flow, total cost}.
    std::pair<int, long long> run(int s, int t) {
        int flow = 0;
        long long cost = 0;
        while (shortest_path(s, t)) {
            int push = kInf;
            for (int v = t; v != s; v = prev_node_[v])
                push = std::min(push, graph_[prev_node_[v]][prev_edge_[v]].cap);
            for (int v = t; v != s; v = prev_node_[v]) {
                Edge &e = graph_[prev_node_[v]][prev_edge_[v]];
                e.cap -= push;
                graph_[v][e.rev].cap += push;
            }
            flow += push;
            cost += static_cast<long long>(push) * dist_[t];
        }
        return {flow, cost};
    }

private:
    // SPFA from s over edges with remaining capacity; records the
    // predecessor edge of every improved node. True if t is reachable.
    bool shortest_path(int s, int t) {
        std::fill(dist_.begin(), dist_.end(), kInf);
        std::queue<int> pending;  // unbounded, unlike a fixed linear array
        dist_[s] = 0;
        pending.push(s);
        in_queue_[s] = 1;
        while (!pending.empty()) {
            const int u = pending.front();
            pending.pop();
            in_queue_[u] = 0;
            for (int i = 0; i < static_cast<int>(graph_[u].size()); ++i) {
                const Edge &e = graph_[u][i];
                if (e.cap > 0 && dist_[u] + e.cost < dist_[e.dest]) {
                    dist_[e.dest] = dist_[u] + e.cost;
                    prev_node_[e.dest] = u;
                    prev_edge_[e.dest] = i;
                    if (!in_queue_[e.dest]) {
                        in_queue_[e.dest] = 1;
                        pending.push(e.dest);
                    }
                }
            }
        }
        return dist_[t] != kInf;
    }

    std::vector<std::vector<Edge>> graph_;
    std::vector<int> dist_;
    std::vector<int> prev_node_;
    std::vector<int> prev_edge_;
    std::vector<char> in_queue_;
};

// Builds the network for the n x m grid and returns the minimum number of
// quarter turns, or -1 if not every pipe end can be matched.
long long solve(int n, int m, const Grid &grid) {
    const int kSource = 0;
    const int kSink = 1;
    // Every cell owns 5 nodes: 4 port nodes (indexed by direction) + 1 centre.
    const auto port = [m](int i, int j, int d) { return 2 + (i * m + j) * 5 + d; };
    const auto centre = [m](int i, int j) { return 2 + (i * m + j) * 5 + 4; };
    const auto opposite = [](int d) { return (d + 2) & 3; };
    const auto has = [](int mask, int d) { return ((mask >> d) & 1) != 0; };

    MinCostFlow mcf(2 + 5 * n * m);
    long long ends_on_side[2] = {0, 0};  // [0] source-side cells, [1] sink-side cells

    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            const int mask = grid[i][j];
            // Checkerboard colouring: (i + j) odd -> source side, even -> sink side.
            const bool src_side = ((i + j) & 1) != 0;
            // Unit-capacity edge oriented the way flow travels on this cell's
            // side: as written on source-side cells, reversed on sink-side cells.
            const auto link = [&](int from, int to, int cost) {
                if (src_side)
                    mcf.add_edge(from, to, 1, cost);
                else
                    mcf.add_edge(to, from, 1, cost);
            };

            const int c = centre(i, j);
            const int ends = static_cast<int>(std::bitset<4>(mask).count());
            if (src_side)
                mcf.add_edge(kSource, c, ends, 0);
            else
                mcf.add_edge(c, kSink, ends, 0);

            // Pipe ends in their current position cost nothing.
            for (int d = 0; d < 4; ++d) {
                if (has(mask, d)) {
                    link(c, port(i, j, d), 0);
                    ++ends_on_side[src_side ? 0 : 1];
                }
            }

            // Rotation options, priced by the quarter turns they need.
            if (ends == 1) {
                // A single end can face any other side: 1 turn, 2 for the opposite side.
                int p = 0;
                while (!has(mask, p)) ++p;
                for (int d = 0; d < 4; ++d)
                    if (d != p) link(c, port(i, j, d), d == opposite(p) ? 2 : 1);
            } else if (ends == 3) {
                // Rotating a T-pipe moves one of its ends into the gap.
                int gap = 0;
                while (has(mask, gap)) ++gap;
                for (int d = 0; d < 4; ++d)
                    if (d != gap)
                        link(port(i, j, d), port(i, j, gap), d == opposite(gap) ? 2 : 1);
            } else if (ends == 2 && mask != 5 && mask != 10) {
                // A quarter turn of an elbow moves one end to its opposite side.
                for (int d = 0; d < 4; ++d)
                    if (has(mask, d)) link(port(i, j, d), port(i, j, opposite(d)), 1);
            }
        }
    }

    // Facing ports of neighbouring cells are joined, always from the
    // source-side cell's port to the sink-side cell's port.
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            const bool src_side = ((i + j) & 1) != 0;
            if (j + 1 < m) {
                const int here = port(i, j, 1);       // right port
                const int there = port(i, j + 1, 3);  // neighbour's left port
                if (src_side)
                    mcf.add_edge(here, there, 1, 0);
                else
                    mcf.add_edge(there, here, 1, 0);
            }
            if (i + 1 < n) {
                const int here = port(i, j, 2);       // down port
                const int there = port(i + 1, j, 0);  // neighbour's up port
                if (src_side)
                    mcf.add_edge(here, there, 1, 0);
                else
                    mcf.add_edge(there, here, 1, 0);
            }
        }
    }

    const auto [flow, cost] = mcf.run(kSource, kSink);
    // Every end must be matched; unequal totals on the two sides never reach the goal.
    const long long goal = std::max(ends_on_side[0], ends_on_side[1]);
    return flow == goal ? cost : -1;
}

}  // namespace flow_version

namespace dp_version {

constexpr int kBuckets = 100003;
// Cost of k counter-clockwise quarter turns (3 CCW == 1 CW).
constexpr int kTurnCost[4] = {0, 1, 2, 1};

// Chained hash map from a plug-state bitmask to the cheapest cost seen for it.
// Entries keep insertion order, and clear() touches only the buckets in use.
class StateTable {
public:
    StateTable() : head_(kBuckets, -1) {}

    void clear() {
        for (const std::uint64_t key : keys_) head_[key % kBuckets] = -1;
        keys_.clear();
        values_.clear();
        next_.clear();
    }

    // Inserts key with value, or lowers the stored value to `value` if smaller.
    void relax(std::uint64_t key, int value) {
        int &bucket = head_[key % kBuckets];
        for (int p = bucket; p != -1; p = next_[p]) {
            if (keys_[p] == key) {
                values_[p] = std::min(values_[p], value);
                return;
            }
        }
        keys_.push_back(key);
        values_.push_back(value);
        next_.push_back(bucket);
        bucket = static_cast<int>(keys_.size()) - 1;
    }

    // Returns the stored value for key, or -1 if absent.
    int query(std::uint64_t key) const {
        for (int p = head_[key % kBuckets]; p != -1; p = next_[p])
            if (keys_[p] == key) return values_[p];
        return -1;
    }

    std::size_t size() const { return keys_.size(); }
    std::uint64_t key(std::size_t i) const { return keys_[i]; }
    int value(std::size_t i) const { return values_[i]; }

private:
    std::vector<int> head_;
    std::vector<std::uint64_t> keys_;
    std::vector<int> values_;
    std::vector<int> next_;
};

// Rotates a pipe mask one quarter turn counter-clockwise
// (right -> up -> left -> down -> right).
int rotate_ccw(int mask) { return (mask >> 1) | ((mask & 1) << 3); }

// Places the pipe `mask` in column y (1-based) on top of plug state `state`,
// where bit y-1 is the plug entering from the left and bit y the plug entering
// from above. Every orientation whose left/up ends match those plugs is
// relaxed into `out` with its rotation cost added to `cost`.
void extend(int y, int mask, std::uint64_t state, int cost, StateTable &out) {
    const std::uint64_t left_bit = std::uint64_t{1} << (y - 1);
    const std::uint64_t up_bit = std::uint64_t{1} << y;
    // Ends the placed pipe must have: bit 3 (left) and bit 0 (up).
    const int need = ((state & left_bit) ? 8 : 0) | ((state & up_bit) ? 1 : 0);
    const std::uint64_t rest = state & ~(left_bit | up_bit);
    const int tries = (mask == 5 || mask == 10) ? 1 : 4;  // straight pipes are fixed
    int shape = mask;
    for (int turns = 0; turns < tries; ++turns, shape = rotate_ccw(shape)) {
        if ((shape & 9) != need) continue;
        std::uint64_t next = rest;
        if (shape & 4) next |= left_bit;  // down end -> plug for the cell below
        if (shape & 2) next |= up_bit;    // right end -> plug for the next cell
        out.relax(next, cost + kTurnCost[turns]);
    }
}

// Plug DP over the grid in row-major order; returns the minimum number of
// quarter turns, or -1 if no valid configuration exists.
long long solve(int n, int m, const Grid &grid) {
    // Keep the profile across the shorter side: when n < m, transpose the grid.
    // Mirroring along the main diagonal swaps up<->left and right<->down.
    Grid g = grid;
    if (n < m) {
        Grid t(m, std::vector<int>(n));
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < m; ++j) {
                const int x = grid[i][j];
                t[j][i] = ((x & 1) << 3) | ((x & 8) >> 3) | ((x & 2) << 1) | ((x & 4) >> 1);
            }
        }
        g = std::move(t);
        std::swap(n, m);
    }
    // NOTE(review): the profile uses bits 0..m, so m must be <= 63. This holds
    // under the presumed problem bound n*m <= 2000 (then m <= 44) — confirm.

    StateTable tables[2];
    int cur = 0;
    tables[cur].relax(0, 0);
    for (int x = 1; x <= n; ++x) {
        for (int y = 1; y <= m; ++y) {
            const StateTable &from = tables[cur];
            StateTable &to = tables[cur ^ 1];
            to.clear();
            for (std::size_t i = 0; i < from.size(); ++i) {
                std::uint64_t state = from.key(i);
                if (y == 1) {
                    // A plug leaving the previous row's right edge is invalid;
                    // otherwise shift so bit 0 becomes the (empty) left plug.
                    if ((state >> m) & 1) continue;
                    state <<= 1;
                }
                extend(y, g[x - 1][y - 1], state, from.value(i), to);
            }
            cur ^= 1;
        }
    }
    return tables[cur].query(0);
}

}  // namespace dp_version

// Reads "n m" and the n x m grid from stdin and prints the answer of the
// min-cost-flow solver (dp_version::solve is an alternative solver).
int main() {
    int n = 0;
    int m = 0;
    read(n);
    read(m);
    Grid grid(n, std::vector<int>(m));
    for (auto &row : grid)
        for (int &cell : row) read(cell);
    std::printf("%lld\n", flow_version::solve(n, m, grid));
    return 0;
}