【题目链接】
【思路要点】
- 基于连通性状态压缩的动态规划（插头 DP）：用括号序列记录轮廓线上各插头的连通关系，另用独立插头表示路径的端点。
- 时间复杂度\(O(4^N*N*M)\)。
【代码】
// Maximum-weight simple path on an n x m grid, solved with a plug DP
// (connectivity-based DP over the row-major frontier).
//
// Input : n m, followed by the n * m cell weights, row by row.
// Output: the largest path weight recorded in `finalans`.
//
// Frontier state: m + 1 slots with 2 bits per slot (slot i = bits 2i, 2i+1):
//   0 = no plug
//   1 = "(" : one end of a path segment whose other end is also on the frontier
//   2 = ")" : the matching other end of such a segment
//   3 = independent plug: a segment whose other end is already a path endpoint
// While cell (x, y) is processed, slot y-1 holds the plug entering from the
// left and slot y the plug entering from above; afterwards slot y-1 holds the
// plug leaving downwards and slot y the plug leaving to the right.
#include <algorithm>
#include <cstring>
#include <iostream>
using namespace std;

constexpr int MAXN = 105;      // rows are 1-based, so n <= MAXN - 1
constexpr int MAXM = 10;       // per-state slot arrays (slots 0..m)
constexpr int MAX_COLS = 8;    // largest m whose states fit in CURR
constexpr int CURR = 262144;   // 4^(MAX_COLS + 1): number of distinct states
constexpr int HASH = 2000005;  // entries per hash table
constexpr int MOD = 100003;    // buckets per hash table
static_assert((1 << 2 * (MAX_COLS + 1)) == CURR, "CURR must hold every state of MAX_COLS + 1 slots");
static_assert(MAX_COLS < MAXM, "slot arrays must index slots 0..MAX_COLS");

int n, m;                  // grid rows / columns
int limit;                 // largest state using only slots 0..m
int value[MAXN][MAXM];     // cell weights, 1-based
// Best path weight found so far. NOTE(review): it starts at 0, so a negative
// best path is never reported; confirm this matches the problem statement.
int finalans;
int num[MAXM];             // num[i] == 3 << 2i: bit mask of slot i
bool ending[CURR];         // state describes a finished path (see visit)
bool visited[CURR];        // state has already been decoded by visit
int slotBits[CURR][MAXM];  // state & num[i]
int slotType[CURR][MAXM];  // plug type (0..3) stored in slot i
int partner[CURR][MAXM];   // for bracket slots: slot of the matching bracket

// Decodes state `curr` into slotBits / slotType / partner / ending on first use.
void visit(int curr) {
    if (visited[curr]) return;
    visited[curr] = true;

    int typeSum = 0;
    for (int i = 0; i <= m; i++) {
        slotBits[curr][i] = curr & num[i];
        slotType[curr][i] = slotBits[curr][i] >> (2 * i);
        typeSum += slotType[curr][i];
    }
    // For a well-formed state, a sum of 3 means exactly one independent plug
    // or exactly one bracket pair: all chosen cells lie on a single path whose
    // dangling ends may simply stop in the cells they leave.
    ending[curr] = (typeSum == 3);

    // Match brackets using a stack of open "(" slots (1-based).
    int opens[MAXM] = {};
    int top = 0;
    for (int i = 0; i <= m; i++) {
        if (slotType[curr][i] == 1) opens[++top] = i;
        if (slotType[curr][i] == 2) {
            partner[curr][i] = opens[top];
            partner[curr][opens[top]] = i;
            top--;
        }
    }
}

// Separate-chaining hash map from state to the best path weight reaching it.
struct Hash_Table {
    int size, head[MOD], nxt[HASH];
    int curr[HASH], value[HASH];

    void init() {
        size = 0;
        memset(head, 0, sizeof(head));
    }

    // Stores max(existing, val) for state x and, when x is a finished-path
    // state, folds val into finalans. States with plugs beyond slot m are ignored.
    void modify(int x, int val) {
        if (x > limit) return;
        const int bucket = x % MOD;
        int p = head[bucket];
        while (p && curr[p] != x) p = nxt[p];
        if (p) {
            value[p] = max(value[p], val);
        } else {
            p = ++size;
            curr[p] = x;
            value[p] = val;
            nxt[p] = head[bucket];
            head[bucket] = p;
        }
        visit(x);
        // Any earlier value stored for x was already folded in, so val suffices.
        if (ending[x]) finalans = max(finalans, val);
    }
} ans[2];

// Switch key for the (left plug, upper plug) type pair seen by a cell.
constexpr int pairKey(int left, int up) { return left * 4 + up; }

// Applies every legal plug transition at cell (x, y) to state `curr`, whose
// path weight so far is `val`, and stores the resulting states in ans[dest].
// Note: s ^ (a << 2k) ^ (b << 2k) rewrites slot k of s from type a to type b.
void Extend(int x, int y, int curr, int dest, int val) {
    visit(curr);
    const int *bits = slotBits[curr], *type = slotType[curr], *mate = partner[curr];
    Hash_Table &out = ans[dest];
    const int tcurr = curr ^ bits[y] ^ bits[y - 1];  // curr with slots y-1, y cleared
    const int cellValue = value[x][y];
    val += cellValue;  // every transition except "skip the cell" takes the cell

    switch (pairKey(type[y - 1], type[y])) {
    case pairKey(0, 0):
        out.modify(curr, val - cellValue);                          // cell not on the path
        out.modify(curr ^ (3 << 2 * y), val);                       // new endpoint, exits right
        out.modify(curr ^ (3 << 2 * (y - 1)), val);                 // new endpoint, exits down
        out.modify(curr ^ (1 << 2 * (y - 1)) ^ (2 << 2 * y), val);  // new segment: "(" down, ")" right
        break;
    case pairKey(0, 1):
        out.modify(curr, val);                                             // "(" from above exits right
        out.modify(tcurr ^ (1 << 2 * (y - 1)), val);                       // ... or exits down
        out.modify(tcurr ^ (2 << 2 * mate[y]) ^ (3 << 2 * mate[y]), val);  // ... or stops: its ")" becomes 3
        break;
    case pairKey(0, 2):
        out.modify(curr, val);                                             // ")" from above exits right
        out.modify(tcurr ^ (2 << 2 * (y - 1)), val);                       // ... or exits down
        out.modify(tcurr ^ (1 << 2 * mate[y]) ^ (3 << 2 * mate[y]), val);  // ... or stops: its "(" becomes 3
        break;
    case pairKey(0, 3):
        out.modify(curr, val);                        // independent plug from above exits right
        out.modify(tcurr ^ (3 << 2 * (y - 1)), val);  // ... or exits down
        break;
    case pairKey(1, 0):
        out.modify(curr, val);                                                     // "(" from left exits down
        out.modify(tcurr ^ (1 << 2 * y), val);                                     // ... or exits right
        out.modify(tcurr ^ (2 << 2 * mate[y - 1]) ^ (3 << 2 * mate[y - 1]), val);  // ... or stops: its ")" becomes 3
        break;
    case pairKey(1, 1):
        // Two "(" join: the ")" matching slot y now opens, pairing with the
        // ")" that matched slot y-1.
        out.modify(tcurr ^ (2 << 2 * mate[y]) ^ (1 << 2 * mate[y]), val);
        break;
    case pairKey(1, 2):
        break;  // joining a segment's own two ends would close a cycle
    case pairKey(1, 3):
        out.modify(tcurr ^ (2 << 2 * mate[y - 1]) ^ (3 << 2 * mate[y - 1]), val);  // ")" of slot y-1 becomes 3
        break;
    case pairKey(2, 0):
        out.modify(curr, val);                                                     // ")" from left exits down
        out.modify(tcurr ^ (2 << 2 * y), val);                                     // ... or exits right
        out.modify(tcurr ^ (1 << 2 * mate[y - 1]) ^ (3 << 2 * mate[y - 1]), val);  // ... or stops: its "(" becomes 3
        break;
    case pairKey(2, 1):
        out.modify(tcurr, val);  // ")" then "(": their partners now pair with each other
        break;
    case pairKey(2, 2):
        // Two ")" join: the "(" matching slot y-1 now closes, pairing with the
        // "(" that matched slot y.
        out.modify(tcurr ^ (1 << 2 * mate[y - 1]) ^ (2 << 2 * mate[y - 1]), val);
        break;
    case pairKey(2, 3):
        out.modify(tcurr ^ (1 << 2 * mate[y - 1]) ^ (3 << 2 * mate[y - 1]), val);  // "(" of slot y-1 becomes 3
        break;
    case pairKey(3, 0):
        out.modify(curr, val);                  // independent plug from left exits down
        out.modify(tcurr ^ (3 << 2 * y), val);  // ... or exits right
        break;
    case pairKey(3, 1):
        out.modify(tcurr ^ (2 << 2 * mate[y]) ^ (3 << 2 * mate[y]), val);  // ")" of slot y becomes 3
        break;
    case pairKey(3, 2):
        out.modify(tcurr ^ (1 << 2 * mate[y]) ^ (3 << 2 * mate[y]), val);  // "(" of slot y becomes 3
        break;
    case pairKey(3, 3):
        // Both path ends meet: a complete path only if no other plug remains.
        if (tcurr == 0) finalans = max(finalans, val);
        break;
    }
}

// Runs the DP over cells (x, y), (x, y + 1), ... in row-major order up to
// (n, m). ans[now] holds the states before cell (x, y); each step fills
// ans[now ^ 1] and then the two tables swap roles.
void work(int x, int y, int now) {
    for (;;) {
        Hash_Table &from = ans[now];
        ans[now ^ 1].init();
        // A new row moves every slot one position right (2 bits); states that
        // had a plug in slot m (past the right edge) then exceed limit.
        const int shift = (y == 1) ? 2 : 0;
        for (int i = 1; i <= from.size; i++) {
            const int state = from.curr[i] << shift;
            if (state <= limit) Extend(x, y, state, now ^ 1, from.value[i]);
        }
        if (x == n && y == m) return;
        now ^= 1;
        if (++y > m) {
            y = 1;
            x++;
        }
    }
}

int main() {
    cin >> n >> m;
    if (n <= 0 || m <= 0) {  // empty grid: no cell to walk, answer stays 0
        cout << 0 << '\n';
        return 0;
    }
    if (n >= MAXN || m > MAX_COLS) {  // would overflow the fixed-size tables
        cerr << "grid too large\n";
        return 1;
    }
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= m; j++) cin >> value[i][j];

    num[0] = 3;
    for (int i = 1; i <= m; i++) num[i] = num[i - 1] << 2;
    limit = (1 << (m + 1) * 2) - 1;  // all m + 1 slots may be occupied

    finalans = 0;
    ans[0].modify(0, 0);  // empty frontier before the first cell
    work(1, 1, 0);
    cout << finalans << '\n';
    return 0;
}