f[i][j][S]表示转移到第i行第j列状态为S的方案数
括号表示法
S是一个3进制数
0表示没有插头,1表示左括号,2表示右括号
括号表示法
S是一个3进制数
0表示没有插头,1表示左括号,2表示右括号
还是用滚动数组来实现
讲道理的话,这种题目还是直接看代码吧。
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<iostream>
#define maxn 10010
#define inf 1000000000
using namespace std;
int tot[2],hash[2][maxn],f[2][maxn],head[210],to[maxn],next[maxn];
int a[110][110],bit[110];
int now,pre,num,ans;
int n,m;
int find_pos(int s,int i)
{
return (s/(1<<bit[i-1]))%4;
}
void add(int s,int x)
{
int pos=s%200;
for (int p=head[pos];p;p=next[p])
if (hash[now][to[p]]==s)
{
f[now][to[p]]=max(f[now][to[p]],x);
return;
}
tot[now]++;
hash[now][tot[now]]=s;
f[now][tot[now]]=x;
to[++num]=tot[now];next[num]=head[pos];head[pos]=num;
}
int find_l(int s,int k)
{
int cnt=0;
for (int i=k;i>=1;i--)
{
int p=find_pos(s,i);
if (p==2) cnt++; else if (p==1) cnt--;
if (!cnt) return i;
}
return 0;
}
int find_r(int s,int k)
{
int cnt=0;
for (int i=k;i<=m+1;i++)
{
int p=find_pos(s,i);
if (p==1) cnt++; else if (p==2) cnt--;
if (!cnt) return i;
}
}
void dp()
{
now=1;pre=0;
tot[now]=1;
hash[now][1]=0;
f[now][1]=0;
for (int i=1;i<=n;i++)
{
for (int j=1;j<=tot[now];j++) hash[now][j]<<=2;
for (int j=1;j<=m;j++)
{
swap(now,pre);
for (int k=1;k<=tot[now];k++) f[now][k]=-inf;
tot[now]=0;num=0;
memset(head,0,sizeof(head));
for (int k=1;k<=tot[pre];k++)
{
int s=hash[pre][k],num=f[pre][k]+a[i][j];
int p=find_pos(s,j),q=find_pos(s,j+1);
if (p && q)
{
if (p==1 && q==1)
{
int r=find_r(s,j+1);
s-=(1<<bit[j-1])+(1<<bit[j])+(1<<bit[r-1]+1);s+=(1<<bit[r-1]);
add(s,num);
}
else if (p==2 && q==1)
{
s-=(1<<bit[j-1]+1)+(1<<bit[j]);
add(s,num);
}
else if (p==2 && q==2)
{
int l=find_l(s,j);
s-=(1<<bit[j-1]+1)+(1<<bit[j]+1)+(1<<bit[l-1]);s+=(1<<bit[l-1]+1);
add(s,num);
}
else
{
s-=(1<<bit[j-1])+(1<<bit[j]+1);
if (!s) ans=max(ans,num);
}
}
else if (p)
{
if (p==1)
{
s-=(1<<bit[j-1]);
if (j<m) add(s+(1<<bit[j]),num);
if (i<n) add(s+(1<<bit[j-1]),num);
}
else
{
s-=(1<<bit[j-1]+1);
if (j<m) add(s+(1<<bit[j]+1),num);
if (i<n) add(s+(1<<bit[j-1]+1),num);
}
}
else if (q)
{
if (q==1)
{
s-=(1<<bit[j]);
if (j<m) add(s+(1<<bit[j]),num);
if (i<n) add(s+(1<<bit[j-1]),num);
}
else
{
s-=(1<<bit[j]+1);
if (j<m) add(s+(1<<bit[j]+1),num);
if (i<n) add(s+(1<<bit[j-1]+1),num);
}
}
else
{
add(s,num-a[i][j]);
if (i<n && j<m) add(s+(1<<bit[j-1])+(1<<bit[j]+1),num);
}
}
}
}
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=1;i<=100;i++) bit[i]=i*2;
for (int i=1;i<=n;i++)
for (int j=1;j<=m;j++)
scanf("%d",&a[i][j]);
ans=-inf;
dp();
printf("%d\n",ans);
return 0;
}