题目链接
题意
求一个最大点权回路
思路
插头DP模板题多个权值,还是挺模板,加强理解。
- 枚举最右点,不能存在右插头。不然之后转移状态虽然没考虑但是会一直记录,然后炸时间和空间
- 形成一个连通分量只会出现一次
()
的情况,不要再push进状态
代码
#include <bits/stdc++.h>
using namespace std;
struct hash_table {
int hash_mod = 19991;
int state[20000], ans[20000], up;
int tot, first[20000], nxt[20000], w[20000];
void init() {
memset(first, 0, sizeof(first));
tot = 0;
up = 0;
}
int ins(int sta, int val) {
int key = sta%hash_mod;
for(int i = first[key]; i; i = nxt[i]) {
if(state[w[i]] == sta) return ans[w[i]] = max(ans[w[i]],val);
}
state[++up] = sta;
ans[up] = val;
nxt[++tot] = first[key];
w[tot] = up;
first[key] = tot;
return val;
}
}dp[2];
#define prel (1<<bit[j-1])
#define prer (1<<bit[j])
int n, m, a[105][10], bit[10];
void solve() {
int cur = 0, ans = -1e9;
dp[cur].init();
dp[0].ins(0,0);
for(int i = 1; i <= n; ++i) {
for(int j = 1; j <= dp[cur].up; ++j) dp[cur].state[j] <<= 2;
for(int j = 1; j <= m; ++j) {
cur ^= 1;
dp[cur].init();
for(int k = 1; k <= dp[cur^1].up; ++k) {
int sta = dp[cur^1].state[k];
int w = dp[cur^1].ans[k];
int d = (sta>>bit[j])&3;
int r = (sta>>bit[j-1])&3;
if(!r && !d) {
dp[cur].ins(sta,w);
if(j != m) dp[cur].ins(sta+prel+2*prer,w+a[i][j]);
}
else if(!r && d) {
if(j != m) dp[cur].ins(sta,w+a[i][j]);
dp[cur].ins(sta-d*prer+d*prel,w+a[i][j]);
}
else if(r && !d) {
dp[cur].ins(sta,w+a[i][j]);
if(j != m) dp[cur].ins(sta-r*prel+r*prer,w+a[i][j]);
}
else if(r == 1 && d == 1) {
int cnt = 1;
for(int p = j+1; p <= m; ++p) {
if(((sta>>bit[p])&3) == 1) ++cnt;
if(((sta>>bit[p])&3) == 2) --cnt;
if(!cnt) {
dp[cur].ins(sta-prel-prer-(1<<bit[p]),w+a[i][j]);
break;
}
}
}
else if(r == 2 && d == 2) {
int cnt = 1;
for(int p = j-2; p >= 0; --p) {
if(((sta>>bit[p])&3) == 1) --cnt;
if(((sta>>bit[p])&3) == 2) ++cnt;
if(!cnt) {
dp[cur].ins(sta-2*prel-2*prer+(1<<bit[p]),w+a[i][j]);
break;
}
}
}
else if(r == 2 && d == 1) {
dp[cur].ins(sta-prel*r-prer*d,w+a[i][j]);
}
else {
if(sta == prel+2*prer) ans = max(ans,w+a[i][j]);
// 形成一个连通分量,不能再push进去不然存在多个回路
// else dp[cur].ins(sta-prel-2*prer,w+a[i][j]);
}
}
}
}
printf("%d\n",ans);
}
int main() {
scanf("%d%d",&n,&m);
for(int i = 1; i <= n; ++i) for(int j = 1; j <= m; ++j) scanf("%d",&a[i][j]);
for(int i = 1; i <= m; ++i) bit[i] = i<<1;
solve();
return 0;
}