题意:
给出 N 个矩阵(A1,A2,…,An),求完全括号化方案,使得计算乘积(A1A2…An)所需乘法次数最少。并输出方案。
思路:经典区间DP。
要求的是【0,n-1】的最小代价。且大区间的决策依赖于小区间。矩阵连乘的最后一定有一个最后一次乘法,假设最后一个乘号在第 k 个矩阵后,也就是P = A1 x A2 x … Ak 和Q = A(k+1) x A(k+2) x … x A(n)。只需分别求出P,Q的最优方案(最优子结构)。为了计算P的最优方案,需要继续枚举P = A1 x A2 x … Ak的最后一次乘法,把它分成两部分。由此发现,这个问题的子问题都是“把Ai,A(i+1),,Aj”乘起来需要多少次乘法。定义状态:d[ i ][ j ]:从第 i 个矩阵连续乘到第 j 个矩阵的最小乘法次数。
转移方程为:d[ i ][ j ] = min( d[ i ][ k ] + d[ k+1 ][ j ] + 最后一次乘法的代价 ).
记忆化搜索:
int f(int i, int j){
int& ans = d[i][j];
if(ans < INF) return ans;
for(int k = i; k < j; ++k){ // 没有等于!!
if( ans > f(i,k) + f(k+1,j) + p[i]*q[k]*q[j] ){
ans = f(i,k) + f(k+1,j) + p[i]*q[k]*q[j];
mark[i][j] = k;
}
}
return ans;
}
递推:
for(int len = 1; len <= n; ++len){// 枚举链乘长度(区间长度)
for(int i = 0; i <= n-len; ++i) // 枚举区间起点
for(int k = i; k < i+len-1; ++k){ // 枚举最后乘法的位置
int t = d[i][k] + d[k+1][i+len-1] + p[i]*q[k]*q[i+len-1];
if(d[i][i+len-1] > t){
d[i][i+len-1] = t;
mark[i][i+len-1] = k; // 标记每次的决策,便于递归输出路径
}
}
}
完整代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int INF = 0x3f3f3f3f;
const int N = 20;
int d[N][N], p[N], q[N], mark[N][N];
void print_ans(int i, int j){
if(i == j){ printf("A%d",i+1); return; }
printf("(");
print_ans(i, mark[i][j]);
printf(" x ");
print_ans(mark[i][j]+1, j);
printf(")");
}
int f(int i, int j){
int& ans = d[i][j];
if(ans < INF) return ans;
for(int k = i; k < j; ++k){ // 没有等于!!
if( ans > f(i,k) + f(k+1,j) + p[i]*q[k]*q[j] ){
ans = f(i,k) + f(k+1,j) + p[i]*q[k]*q[j];
mark[i][j] = k;
}
}
return ans;
}
int main()
{
//freopen("in.txt","r",stdin);
int n, kase = 1;
while(scanf("%d",&n) == 1&&n){
memset(d,0x3f,sizeof(d));
memset(mark,0,sizeof(mark));
for(int i = 0; i < n; ++i){
scanf("%d %d",&p[i],&q[i]);
d[i][i] = 0;
}
// 递推
for(int len = 1; len <= n; ++len){// 枚举链乘长度(区间长度)
for(int i = 0; i <= n-len; ++i) // 枚举区间起点
for(int k = i; k < i+len-1; ++k){ // 枚举区间内最后乘法位置
int t = d[i][k] + d[k+1][i+len-1] + p[i]*q[k]*q[i+len-1];
if(d[i][i+len-1] > t){
d[i][i+len-1] = t;
mark[i][i+len-1] = k;
}
}
}
//int ans = f(0,n-1);
//printf("%d\n",ans);
printf("Case %d: ",kase++);
print_ans(0,n-1);
printf("\n");
}
return 0;
}