题目描述
给出如下定义:
子矩阵:从一个矩阵当中选取某些行和某些列交叉位置所组成的新矩阵(保持行与列的相对顺序)被称为原矩阵的一个子矩阵。
例如,下面左图中选取第222、444行和第222、444、555列交叉位置的元素得到一个2×32 \times 32×3的子矩阵如右图所示。
9 3 3 3 9
9 4 8 7 4
1 7 4 6 6
6 8 5 6 9
7 4 5 6 1
的其中一个2×32 \times 32×3的子矩阵是
4 7 4
8 6 9
相邻的元素:矩阵中的某个元素与其上下左右四个元素(如果存在的话)是相邻的。
矩阵的分值:矩阵中每一对相邻元素之差的绝对值之和。
本题任务:给定一个nnn行mmm列的正整数矩阵,请你从这个矩阵中选出一个rrr行ccc列的子矩阵,使得这个子矩阵的分值最小,并输出这个分值。
(本题目为2014NOIP普及T4)
输入格式
第一行包含用空格隔开的四个整数n,m,r,cn,m,r,cn,m,r,c,意义如问题描述中所述,每两个整数之间用一个空格隔开。
接下来的nnn行,每行包含mmm个用空格隔开的整数,用来表示问题描述中那个nnn行mmm列的矩阵。
输出格式
一个整数,表示满足题目描述的子矩阵的最小分值。
输入输出样例
输入 #1
5 5 2 3
9 3 3 3 9
9 4 8 7 4
1 7 4 6 6
6 8 5 6 9
7 4 5 6 1
输出 #1
6
#include<bits/stdc++.h>
using namespace std;
const int N = 20;
int n, m, r, c;
int num[N][N], ch[N], gs = 1;
int lc[N], hc[N][N];
int f[N][N];
/*
具体思路:
首先想到暴力方法
枚举每个rc大小的矩阵,然后计算差值和,但复杂度过高
需要优化:
首先枚举n行中的r行
利用dfs()枚举每种可能,当dfs()进行到已经不能完成r行时停止
预处理选出来的r行每一列之间的行与行的差值和即lc[i]
在预处理c列中任意两列之间的差值和即hc[i][j] 第i列和第j列的差值和
利用线性dp处理即可
f[i][j]表示前i列中选j列的最小值(选第j列)
f[i][j] = min(f[i][j], f[k][j - 1] + lc[i] + hc[i][k])
*/
void Init()
{
for(int i = 1; i <= m; i++)
{
lc[i] = 0;
for(int j = 1; j < r; j++)
{
lc[i] += abs(num[ch[j]][i] - num[ch[j + 1]][i]);//
}
}
for(int i = 2; i <= m; i++)
{
for(int j = 1; j < i; j++)
{
hc[i][j] = 0; //第i列减第j列的中r行的差值差值
for(int k = 1; k <= r; k++)
{
hc[i][j] += abs(num[ch[k]][i] - num[ch[k]][j]);
}
}
}
}
int minn = 2e9;
int cmin;
void dp()
{
for(int i = 1; i <= m; i++) //前i行
{
cmin = min(i, c);
for(int j = 1; j <= cmin; j++)//选j行
{
if(j == 1)
f[i][j] = lc[i];
else
{
if(i == j)
{
f[i][j] = f[i - 1][j - 1] + lc[i] + hc[i][j - 1];
}
else
{
f[i][j] = 2e8;
for(int k = j - 1; k < i; k++)
{
f[i][j] = min(f[i][j], f[k][j - 1] + lc[i] + hc[i][k]);
}
}
}
if(j == c) minn = min(minn, f[i][c]);//前i列选c列
}
}
}
void dfs(int node)
{
if(node > n)
{
Init();
dp();
return;
}
if(r - gs + 1 == n - node + 1)//已经取地行数还没取得行数 优化剪枝
{
ch[gs ++] = node;
dfs(node + 1);
ch[gs --] = 0;
return;
}
dfs(node + 1);
if(gs <= r)
{
ch[gs ++] = node;
dfs(node + 1);
ch[gs --];
}
}
int main()
{
cout << 2<<20 << e
scanf("%d%d%d%d", &n, &m, &r, &c);
for(int i = 1; i <= n; i++)
for(int j = 1; j <= m; j++)
scanf("%d", &num[i][j]);
dfs(1);
printf("%d", minn);
return 0;
}