给出如下定义:
-
子矩阵:从一个矩阵当中选取某些行和某些列交叉位置所组成的新矩阵(保持行与 列的相对顺序)被称为原矩阵的一个子矩阵。
例如,下面左图中选取第 2、4 行和第 2、4、5 列交叉位置的元素得到一个 2*3 的子矩阵如右图所示。
-
相邻的元素:矩阵中的某个元素与其上下左右四个元素(如果存在的话)是相邻的。
-
矩阵的分值:矩阵中每一对相邻元素之差的绝对值之和。
本题任务:给定一个 n 行 m 列的正整数矩阵,请你从这个矩阵中选出一个 r 行 c 列的 子矩阵,使得这个子矩阵的分值最小,并输出这个分值。
![](http://codevs.cn/media/image/problem/3904.png)
第一行包含用空格隔开的四个整数 n,m,r,c,意义如问题描述中所述,每两个整数之间用一个空格隔开。
接下来的 n 行,每行包含 m 个用空格隔开的整数,用来表示问题描述中那个 n 行 m 列的矩阵。
输出共 1 行,包含 1 个整数,表示满足题目描述的子矩阵的最小分值。
样例输入1
5 5 2 3
9 3 3 3 9
9 4 8 7 4
1 7 4 6 6
6 8 5 6 9
7 4 5 6 1
样例输入2
7 7 3 3
7 7 7 6 2 10 5
5 8 8 2 1 6 2
2 9 5 5 6 1 7
7 9 3 6 1 7 8
1 9 1 4 7 8 8
10 5 9 1 1 8 10
1 3 1 5 4 8 6
样例输出1
6
样例输出2
16
对于 50%的数据,1 ≤ n ≤ 12, 1 ≤ m ≤ 12, 矩阵中的每个元素 1 ≤ a[i][j] ≤20;
对于 100%的数据,1 ≤ n ≤ 16, 1 ≤ m ≤ 16, 矩阵中的每个元素 1 ≤ a[i][j] ≤1000,1 ≤ r ≤ n, 1 ≤ c ≤ m。
时间限制:每一组测试数据1s。
【输入输出样例 1 说明】
该矩阵中分值最小的 2 行 3 列的子矩阵由原矩阵的第 4 行、第 5 行与第 1 列、第 3 列、 第 4 列交叉位置的元素组成,为
6 5 6
7 5 6
,其分值为 |6 − 5| + |5 − 6| + |7 − 5| + |5 − 6| + |6 − 7| + |5 − 5| + |6 − 6| = 6。
【输入输出样例 2 说明】
该矩阵中分值最小的 3 行 3 列的子矩阵由原矩阵的第 4 行、第 5 行、第 6 行与第 2 列、第 6 列、第 7 列交叉位置的元素组成,选取的分值最小的子矩阵为
9 7 8
9 8 8
5 8 10
题解:dfs+dp
先用dfs搜索出所有行的选择状态,然后枚举状态,预处理c1[i][j]表示选这两列相邻中选定的行的贡
献(左右),sum[i]表示选这一列中选定行的贡献(上下)。
然后就可以dp,f[i][j]表示第i列选j的最小值。
f[i][j]=min(f[i][j],f[i-1][k]+c[k][j]+sum[j])
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 20
using namespace std;
int n,m,r,c,ans,cnt;
int f[N][N],a[N][N],c1[N][N],b[N],sta[70000],sum[N],v[N];
void dfs(int x,int num,int ans)
{
if (x==n+1) {
if (num==r) sta[++cnt]=ans;
return;
}
dfs(x+1,num+1,ans+(1<<(x-1)));
dfs(x+1,num,ans);
}
int main()
{
freopen("a.in","r",stdin);
// freopen("my.out","w",stdout);
scanf("%d%d%d%d",&n,&m,&r,&c);
for (int i=1;i<=n;i++)
for (int j=1;j<=m;j++) scanf("%d",&a[i][j]);
dfs(1,0,0);
ans=1000000000;
for (int i=1;i<=cnt;i++) {
memset(b,0,sizeof(b));
memset(sum,0,sizeof(sum));
memset(c1,0,sizeof(c1));
for (int j=1;j<=n;j++)
if ((sta[i]>>(j-1))&1) b[j]=1;
for (int l=1;l<=m;l++)
for (int j=1;j<=m;j++)
{
if (l==j) continue;
for (int k=1;k<=n;k++)
if (b[k]) c1[l][j]+=abs(a[k][l]-a[k][j]);
}
for (int j=1;j<=m;j++) {
int t=0;
for (int k=1;k<=n;k++)
if (b[k]) v[++t]=a[k][j];
for (int k=2;k<=r;k++) sum[j]+=abs(v[k-1]-v[k]);
}
memset(f,127,sizeof(f));
f[0][0]=0;
for (int j=1;j<=c;j++)
for (int k=j;k<=m;k++)
for (int l=0;l<k;l++) {
int t=f[j-1][l]+c1[l][k]+sum[k];
f[j][k]=min(f[j][k],t);
if (j==c) ans=min(ans,f[j][k]);
}
}
printf("%d\n",ans);
}