带有独立插头的dp,状态用4进制数表示,掌握位运算技巧还是蛮好写的,虽然很慢。。。
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<iostream>
#define maxn 100010
#define inf 1000000000
using namespace std;
int n,m,now,pre,num,ans;
int hash[2][maxn],head[10010],to[maxn],next[maxn],tot[2],f[2][maxn];
int a[110][110],bit[50];
void add(int s,int x)
{
int pos=s%10000;
for (int p=head[pos];p;p=next[p])
if (hash[now][to[p]]==s)
{
f[now][to[p]]=max(f[now][to[p]],x);
return;
}
tot[now]++;
hash[now][tot[now]]=s;
f[now][tot[now]]=x;
to[++num]=tot[now];next[num]=head[pos];head[pos]=num;
}
int find_pos(int s,int i)
{
return (s/(1<<bit[i-1]))%4;
}
int find_l(int s,int k)
{
int cnt=0;
for (int i=k;i>=1;i--)
{
int p=find_pos(s,i);
if (p==2) cnt++;
else if (p==1) cnt--;
if (!cnt) return i;
}
}
int find_r(int s,int k)
{
int cnt=0;
for (int i=k;i<=m+1;i++)
{
int p=find_pos(s,i);
if (p==1) cnt++;
else if (p==2) cnt--;
if (!cnt) return i;
}
}
void dp()
{
now=1;pre=0;
tot[now]=1;
hash[now][1]=0;
f[now][1]=0;
for (int i=1;i<=n;i++)
{
for (int j=1;j<=tot[now];j++) hash[now][j]<<=2;
for (int j=1;j<=m;j++)
{
swap(now,pre);
for (int k=1;k<=tot[now];k++) f[now][k]=-1;
num=0;tot[now]=0;
memset(head,0,sizeof(head));
for (int k=1;k<=tot[pre];k++)
{
int s=hash[pre][k],num=f[pre][k]+a[i][j];
if (s>=(1<<bit[m+1])) continue;
int p=find_pos(s,j),q=find_pos(s,j+1),e=s-(p<<bit[j-1])-(q<<bit[j]);
if (!p && !q)
{
add(e,num-a[i][j]);
add(e+(1<<bit[j-1])+(2<<bit[j]),num);
add(e+(3<<bit[j-1]),num);
add(e+(3<<bit[j]),num);
}
else if (!p)
{
if (q==1)
{
add(e+(1<<bit[j-1]),num);
add(e+(1<<bit[j]),num);
add(e^(1<<bit[find_r(s,j+1)-1]),num);
}
else if (q==2)
{
add(e+(2<<bit[j-1]),num);
add(e+(2<<bit[j]),num);
add(e^(2<<bit[find_l(s,j+1)-1]),num);
}
else
{
add(e+(3<<bit[j-1]),num);
add(e+(3<<bit[j]),num);
if (!e) ans=max(ans,num);
}
}
else if (!q)
{
if (p==1)
{
add(e+(1<<bit[j-1]),num);
add(e+(1<<bit[j]),num);
add(e^(1<<bit[find_r(s,j)-1]),num);
}
else if (p==2)
{
add(e+(2<<bit[j-1]),num);
add(e+(2<<bit[j]),num);
add(e^(2<<bit[find_l(s,j)-1]),num);
}
else
{
add(e+(3<<bit[j-1]),num);
add(e+(3<<bit[j]),num);
if (!e) ans=max(ans,num);
}
}
else if (p==1 && q==1) add(e^(3<<bit[find_r(s,j+1)-1]),num);
else if (p==2 && q==2) add(e^(3<<bit[find_l(s,j)-1]),num);
else if (p==2 && q==1) add(e,num);
else if (p==3 && q==1) add(e^(1<<bit[find_r(s,j+1)-1]),num);
else if (p==3 && q==2) add(e^(2<<bit[find_l(s,j+1)-1]),num);
else if (p==1 && q==3) add(e^(1<<bit[find_r(s,j)-1]),num);
else if (p==2 && q==3) add(e^(2<<bit[find_l(s,j)-1]),num);
else if (p==3 && q==3)
{
if (!e) ans=max(ans,num);
}
}
}
}
}
int main()
{
for (int i=1;i<=20;i++) bit[i]=i*2;
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
for (int j=1;j<=m;j++)
{
scanf("%d",&a[i][j]);
ans=max(ans,a[i][j]);
}
dp();
printf("%d\n",ans);
return 0;
}