计算矩阵连乘积
在科学计算中经常要计算矩阵的乘积。矩阵A和B可乘的条件是矩阵A的列数等于矩阵B的行数。若A是一个p×q的矩阵,B是一个q×r的矩阵,则其乘积C=AB是一个p×r的矩阵。由该公式知计算C=AB总共需要pqr次的数乘。其标准计算公式为:
现在的问题是,给定n个矩阵{A1,A2,…,An}。其中Ai与Ai+1是可乘的,i=1,2,…,n-1。要求计算出这n个矩阵的连乘积A1A2…An,最少的乘法次数。
递归公式:
算法参考如下:
void MatrixChain(int *p,int n,int **m,int **s)
{ for (int i = 1; i <= n; i++)
m[i][i] = 0;
for (int r = 2; r <= n; r++)
for (int i = 1; i <= n - r+1; i++) {
int j=i+r-1;
m[i][j] = m[i+1][j]+ p[i-1]*p[i]*p[j];
s[i][j] = i;
for (int k = i+1; k < j; k++) {
int t = m[i][k] + m[k+1][j] + p[i-1]*p[k]*p[j];
if (t < m[i][j]) { m[i][j] = t; s[i][j] = k;}
}
}
}
void traceback(int i,int j,int **s)
{
if(i==j)
cout<<"A"<<i;
else if (i==j-1)
cout<<"(A"<<i<<"A"<<j<<")";
else
{
cout<<"(";
traceback(i,s[i][j],s);
traceback(s[i][j]+1,j,s);
cout<<")";
}
}
具体代码:
#include<stdio.h>
#define MAX 100
void MatrixChain(int *p,int n,int **m,int **s)
{
int i,j,r;
for(i=0;i<=n;i++)
for(j=0;j<n;j++)
{
if(i==0||j==0)m[i][j]=-1,s[i][j]=-1;
if(j<i)m[i][j]=-1,s[i][j]=-1;
}
for(i=1;i<=n;i++)
m[i][i]=0,s[i][i]=0;
for(r=2;r<=n;r++)
for(i=1;i<=n-r+1;i++)
{
j=i+r-1;
m[i][j]=m[i+1][j]+p[i-1]*p[i]*p[j];
s[i][j]=i;
for(int k=i+1;k<j;k++)
{
int t=m[i][k]+m[k+1][j]+p[i-1]*p[k]*p[j];
if(t<m[i][j])
{
m[i][j]=t;
s[i][j]=k;
}
}
}
printf("最优值m[][]数组为:\n");
for(i=1;i<=n;i++)
{
for(j=1;j<=n;j++)
if(m[i][j]==-1)printf(" \t");
else printf("%d\t",m[i][j]);
printf("\n");
}
printf("最优解s[][]数组为:\n");
for(i=1;i<=n;i++)
{
for(j=1;j<=n;j++)
{
if(m[i][j]==-1)printf(" \t");
else printf("%d\t",s[i][j]);
}
printf("\n");
}
}
void traceback(int i,int j,int **s)
{
if(i==j)
printf("A%d",i);
else if(i==j-1)
printf("(A%dA%d)",i,j);
else
{
printf("(");
traceback(i,s[i][j],s);
traceback(s[i][j]+1,j,s);
printf(")");
}
}
void main()
{
int n,i;
int *p=new int[MAX];
printf("一共有n个矩阵,请输入n的值:");
scanf("%d",&n);
printf("请输入依次n个矩阵的行值和列值(列如2个矩阵30x35,35x15,则输入30 35 15):\n");
for(i=0;i<=n;i++)
scanf("%d",&p[i]);
int **m = new int*[n+1];
for(i= 0; i <=n; i++)
m[i] = new int[n+1];
int **s = new int*[n+1];
for(i= 0; i <=n; i++)
s[i] = new int[n+1];
MatrixChain(p,n,m,s);
traceback(1,n,s);
printf("\n");
}