目录
当到达叶子结点后,虽然已知所有分段点所需的乘法次数为s+delt_s,但是剩余的乘法次数怎么求?
2、限界2:如果当前节点的可行解乘法次数下界大于等于best
前言
最近笔者有个大作业是“回溯法解决矩阵乘法链的最小乘法次数问题”,而且要做PPT,笔者觉得自己对此问题的分析和C++代码可能对其他人有帮助,今天就把PPT内容搬过来了。矩阵乘法链问题具体是什么就不多赘述了,不了解的小伙伴可以自行搜索。
问题分析
第一步:如何形式化表示所有可能的不同矩阵乘法次序
注:这里的“不同”指的是“括号”层面的不同,而不是“乘号顺序”上的不同。举个例子:对于A*B*C*D来说,一共三个乘号,标号为1、2、3,那么对于计算时不同的乘法顺序“1 3 2”和“3 1 2”这两者,由于先算3还是先算1是一样的,这里认为是“相同的乘法次序”。至于为什么管这个叫做“括号层面的不同”,是因为“1 3 2”和“3 1 2”的乘法顺序用含有最少括号的表达式来表示都是A*B*(C*D)。区分上边这两种表示方式,对下文的理解十分重要。
第一种:对于每个乘号排序
例如A*B*C*D,一共有三个乘号,分别编号为1、2、3。则所有可能的顺序有A_3^3=6种,分别为(用最少得括号表示):
1 2 3:A*B*C*D
1 3 2:A*B*(C*D)
2 1 3:A*(B*C)*D
2 3 1:A*((B*C)*D)
3 1 2:A*B*(C*D)
3 2 1:A*(B*(C*D))
可以发现,在上面六种情况中132和312是重复的,在这两种情况中先计算第一个乘号还是先计算第三个对所用的乘法次数以及结果一定没有影响,为了避免这种冗余,需要另一种形式来表示对总乘法次数不同的乘法次序。
第二种:分割乘法链(加括号)
观察矩阵乘法的规律:当矩阵乘法链中只有一个矩阵,所用乘法次数是固定的0;当乘法链中有两个矩阵,所用的乘法次数也是确定的,只有一种;当乘法链中有三个矩阵,可能的乘法次序不唯一,显然有两个。
想要让三个矩阵的乘法链具有确定的乘法次序,需要将其拆分成“长度为一的乘法链”和“长度为二的乘法链”的组合,并且显然有:越靠前的分割点处的乘法越后算。
例子:
A*B*C,一共有两种切割方式(切一刀即可将其分为len=1or2的乘法链的组合):(A*B)*(C) 、(A)*(B*C)。
方法描述:
对整个矩阵乘法链(长度>=3)在乘号处进行切割,当最长子链长度<=2时产生一种可能且确定的乘法次序,由此方法产生的所有乘法次序在总乘法次数方面没有冗余。
例子(融入回溯思想):
对于A*B*C*D,第一个切割点可以选在第一个乘号处,于是得到(A)*(B*C*D) ,但此时并不满足max_len==2的条件,继续分割;下一个分段点不能继续选在第一个乘号处,于是向后顺序搜索,可以选在第二个乘号处,得到(A)*((B)*(C*D)),此时得到一个满足条件的清晰的乘法次序。使用这种回溯的搜索方法可以得到一个描述所有乘法次序的树:
上面的树(数字表示在哪个矩阵前切割)的叶子结点和后面的五个乘法次序分别对应:A*(B*(C*D))、A*((B*C)*D)、A*B*(C*D)、(A*B*C)*D、(A*(B*C))*D。观察发现,和上一个方法(乘号排序)相比,此方法确实没有冗余。
第二步:如何使用回溯法找到最小的矩阵乘法次数
不难看出,只要能在“上页中能遍历“所有不同乘法总数的乘法次序”的树“中遍历每一个叶子结点对应的所需乘法次数并找到最小值,即可在遍历结束后找到最小的乘法次数(暴力搜索)。为了节约时间,可以一边生成这棵树一边更新best(最小的乘法次数)值。
如何在展开这棵树的过程中记录乘法次数?
观察路径“3”对应的加括号的方式”(A*B)*(C*D)”(在第三个矩阵前加括号)可以看出,总共需要的乘法次数可以分类为两种:
1、由于切割矩阵乘法链导致的已经确定的乘法次数
2、在没有被切割的乘号处,但是由于此乘号所在的乘法链长度为2,这个乘号的所需的已确定的乘法次数。
类似于01背包问题中的s迭代量,在此算法中可以用s表示当前已经确定了的分割点对应的所需乘法次数的和,于是在每次回溯和展开节点的时候,都要对s进行修改,以与当前路径匹配。每当到达叶子结点,可得到一个总的所需乘法次数,即为s与剩下的乘号所需乘法次数的和,这是易于计算的。为了存储最优值,使用best变量,每到达一个叶子节点就使用当前节点对应的乘法次数的值更新best。不过,在讨论如何计算每一步s的增量delt_s之前,先要讨论下面这个问题。
如何描述一个节点的路径?
由于此算法是迭代回溯,可以将路径保存在一个迭代量中,比如v[1:n],这样的数组不仅能描述一个节点对应的分割点的位置,还能表示分割点的次序,足以表征一个节点路径的有效信息,可以满足算法需要。但是在想要观察矩阵链整体分割情况的时候(比如查看最大子链长度),那就要对分割点进行排序,最快也要额外造成nlogn量级的时间复杂度,于是可以使用有序链表list来描述分段点的位置,这样一来,每次查询相关信息只要线性时间即可。
例如路径“42”对应的加括号的方式”(A*(B*C))*D”(在第三个矩阵前加括号),对应的v[1:n]为{4,2},而list={2,4}。可见list中的内容和分断点的顺序无关,只是为了快速查询
如何计算每一步的s的增量delt_s?
首先,使用数组r[1:n+1]形式化描述一个矩阵乘法链
例如:(矩阵中间为矩阵编号1~n)
这些矩阵的大小如图所示,例如矩阵2为“r2*r3”的大小。
如何描述delt_s?
观察例子:对于上图使用42的分割顺序,那么当分割矩阵2前的乘号的时候,造成的delt_s为多少?
由上图不难看出,在2前切割导致的乘法次数是r1*r2*r4,可以看出,此次分割所需要的乘法次数即为此次分割产生的两个子链中“前方子链中第一个矩阵的行数(记作a)”乘以“前方子链中最后一个矩阵的列数或后方子链中第一个矩阵的行数”再乘以“后方子链中最后一个矩阵的列数(记作b)”。
由于a和b都是r数组中的元素,且该元素的下标与之前已经选过的分段点的位置有关,所以可以通过查询该分段点前方和后方最近的分段点位置得到这个a和b的下标ra和rb。于是delt_s = r[ra] * r[rb] * r[v[k]] (v[k]是当前节点选择的分段点的位置)。但是想要得到一个可行解(叶子结点)的总共的乘法次数,如何得到?
当到达叶子结点后,虽然已知所有分段点所需的乘法次数为s+delt_s,但是剩余的乘法次数怎么求?
观察上图可知,在这个矩阵乘法链中一共有n-1=3次乘法,由于分割已经解决了其中的两个乘法,接下来只需要求得仅剩的一个乘法(即在矩阵2、3之间的乘法)的所需乘法次数即可。
实现方法:在list中添加1和n+1两个元素作为边界,然后遍历,若发现相邻两元素差为2,则找到一个需要计算的乘法,总共需要的额外的乘法次数记为rest,则每找到一个就rest+=此次乘法需要的乘法次数。
在这个例子中,list一开始是{2,4},加入1和n-1后变成了{1 2 4 5} 遍历发现2 4之间相差2,于是找到一个需要额外加入的乘法 于是rest(初始为0)对应地自增r[2]*r[3]*r[4],这一操作重复直到遍历list结束。最后把s+delt_s+rest作为当前叶子结点的乘法次数即可,并更新best的值(若小于best)。
第三步:限界剪枝的方法
1、限界1:不能重复选取已经选过的分割点位置
实现方法:在list中顺序查找,如果碰到v[k](当前分割点位置),则返回1,表示不能展开该子节点,如果碰到比v[k]大的数,则返回0,表示可以展开
2、限界2:如果当前节点的可行解乘法次数下界大于等于best
实现方法:先找到每次矩阵乘法的最小需要的乘法次数(可以通过对r数组排序实现),然后在某一节点处用这个值乘以还有几个乘法没有算在s内,相加即可得到当前节点的可行解乘法次数下界
而best要相应地在每个答案节点(叶子结点)处更新,且在每个非叶子节点处用当前节点的可行解乘法次数上界更新。
第四步:伪代码与复杂度分析
伪代码:
伪代码:
traceback(r[1:n+1],n)
if n = 1
return 0
if n = 2
return r[1] * r[2] * r[3]
minn_mult <- r[1:n+1]中任意三个元素相乘的最小值
maxx_mult <- r[1:n+1]中任意三个元素相乘的最大值
best <- 正无穷 //表示当前已经找到的最小值
k <- 1 //当前节点层数
s <- 0 //当前由于分割矩阵序列已经消耗的乘法次数
ml <- n //当前最长矩阵序列的长度
list //存放路径上分断点有序链表
v[1:n]<-{1,1,...,1} //表示当前路径的数组,并初始化为1
while k >= 1 //没有从根节点返回表示尚未结束
if v[k] < n //当前节点尚未遍历所有分支
对v[]和list进行和路径相应的修改
if 当前分割点已经选过
continue
if 当前节点的可行解乘法次数下界大于等于best
continue
if 当前分割完最大矩阵序列长度=2
ans <- s + 此次分割s增量 + 剩余乘法次数
用ans更新best
else
用当前节点可行解上界更新best
更新迭代量s和k
else
list路径回退且v[k]重置为1
回退迭代量s和k
return best
function ()
scanf(矩阵数量n)
for i <- 1 to 2n
scanf(r[i/2+1])
answer <- traceback(r,n)
print(answer)
时间、空间复杂度分析(仅供参考)
第五步:用C++实现该算法(三个文件)
main.c
#include <iostream>
#include <algorithm>
#include "list.h"
#include <windows.h>
#define LEN 100
void print(int* v, int k)
{
for (int i = 1; i <= k; ++i)
{
std::cout << v[i] << " ";
}
}
int traceback(int *r,int n)
{
//当n很小,直接计算
if (n == 1)
return 0;
if (n == 2)
return r[1] * r[2] * r[3];
//找每步乘法的最大和最少乘法次数
int r_copy[LEN] = { 0 };
for (int i = 1; i <= n + 1; ++i)
r_copy[i] = r[i];
std::sort(r_copy , r_copy + n + 2);
int maxx_mult = r_copy[n + 1] * r_copy[n] * r_copy[n - 1];
int minn_mult = r_copy[1] * r_copy[2] * r_copy[3];
//迭代量以及其他变量
unsigned int best = -1;
int k = 1;
int s = 0;
int ans;
int delt_s;
int ml = n;
int ra, rb;
List list; //List类是一个有序的链表类型
list.insert(n + 1);
list.insert(1);
int v[LEN];
for (int i = 0; i < LEN; ++i)
v[i] = 1;
//开始回溯迭代
while (k >= 1)
{
if (v[k] < n)
{
//修改v和l,相当于修改路径
if (v[k] == 1)
list.insert(2);
else
{
list.delet(v[k]);
list.insert(v[k] + 1);
}
v[k]++;
//限界1:不能在一个地方重复切割
if (list.judge())
continue;
//计算s的增量
list.ra_rb(v[k],ra,rb);
delt_s = r[ra] * r[rb] * r[v[k]];
//限界2:如果当前节点的可行解乘法次数下界大于等于best
if (s+delt_s+(n-1-k)*minn_mult >= best)
continue;
//判定答案节点条件:最短矩阵链长度为2
if (list.max_len()==2)
{
ans = delt_s + s + list.rest(r);
//print(v, k);
//l.print();
if (best > ans)
best = ans;
}
//展开正常节点
else
{
best = min((int)best, s + delt_s + (n - 1 - k) * maxx_mult);
s += delt_s;
k++;
}
}
//回溯并回退迭代量
else
{
list.delet(v[k]);
v[k] = 1;
k--;
if (k == 0)
break;
list.ra_rb(v[k], ra, rb);
s -= r[ra] * r[rb] * r[v[k]];
}
}
return best;
}
int main()
{
DWORD start, end;
int times = 1;
float sum_time = 0;
int ans = 0;
//输入数据
int n;
scanf_s("%d", &n);
int r[100] = { 0 };
for (int i = 1; i <= 2 * n; ++i)
std::cin >> r[i / 2 + 1];
/*
* 确保输入数据的格式形如:
* 8
* 9 16
* 16 4
* 4 1
* 1 7
* 7 2
* 2 11
* 11 4
* 4 16
*
* 表示有8个矩阵,第二个矩阵的行数为16,列数为4
*/
开始计算结果和时间(ms)
//int n = 12; //
//int r[LEN] = { 0,50, 35, 25, 10, 60, 70, 3, 5, 10 , 35, 25, 10, 60, 70, 3, 5, 10 };
for (int i = 0; i < times; ++i)
{
start = GetTickCount();
ans = traceback(r, n);
end = GetTickCount();
sum_time += end - start;
}
printf("n is : %d\n", n);
printf("ans is : %d\n", ans);
printf("avg_time is : %f\n", sum_time / times);
return 0;
}
//测试用例
//int n = 6; //answer:15125
//int r[100] = { 0, 30 ,35, 15, 5 ,10 ,20 ,25 };
list.c
#include "list.h"
#include <malloc.h>
#include <iostream>
#include <stack>
using namespace std;
List::List()
{
num = 0;
if (!(head = (Node*)malloc(sizeof(Node))))
exit(0);
head->data = -1;
head->next = NULL;
}
List::~List()
{
stack<Node*> sta;
sta.push(head);
Node* p = head;
while (p->next != NULL)
{
p = p->next;
sta.push(p);
}
while (!sta.empty())
{
free(sta.top());
sta.pop();
}
}
void List::insert(int a)
{
//cout << "插入:" << a << endl;
Node* I = (Node*)malloc(sizeof(Node));
I->data = a;
Node* p = head;
while (p->next != NULL)
{
if (!(p->data <= a && p->next->data >= a))
p = p->next;
else
break;
}
I->next = p->next;
p->next = I;
num++;
}
void List::print()
{
Node* p = head;
while (p->next != NULL)
{
p = p->next;
cout << p->data << " ";
}
cout << endl;
}
void List::delet(int del)
{
Node* p = head;
while (1)
{
if (p->next == NULL)
{
cout << "no del" << endl;
break;
}
if (p->next->data == del)
{
p->next = p->next->next;
break;
}
p = p->next;
}
num--;
}
int List::at(int index) //从一开始
{
Node* p = head;
for (int i = 0; i < index; ++i)
{
p = p->next;
}
return p->data;
}
bool List::judge()
{
Node* p = head;
while (1)
{
if (p->next == NULL)
break;
if (p->next->next == NULL)
break;
if (p->next->data == p->next->next->data)
return 1;
p = p->next;
}
return 0;
}
int List::max_len()
{
int maxx = 0;
Node* p = head;
while (p->next->next != NULL)
{
maxx = max(maxx, p->next->next->data - p->next->data);
p = p->next;
}
return maxx;
}
void List::ra_rb(int vk, int& ra, int& rb)
{
Node* p = head -> next;
while (p -> next -> next!= NULL)
{
if (p->next->data == vk)
{
ra = p->data;
rb = p->next->next->data;
return;
}
p = p->next;
}
cout << "error";
exit(0);
}
int List::rest(int *r)
{
int sum = 0;
Node* p = head->next;
while (p->next != NULL)
{
if (p->data + 2 == p->next->data)
{
int x = p->data;
sum += r[x] * r[x + 1] * r[x + 2];
}
p = p->next;
}
return sum;
}
list.h
#pragma once
class List
{
private:
struct Node
{
int data;
struct Node* next;
};
Node* head;
int num; //不含头结点
public:
List();
~List();
void insert(int);
void print();
int at(int);
void delet(int);
void ra_rb(int vk,int &ra,int &rb);
bool judge();
int max_len();
int rest(int *r);
};
第六步:测试数据并分析时间
实际测量运行时间方法:
可以使用window.h库中DWORD start = GetTickCount();的方式获取时间,利用这一方式可以对一个程序多次运行求得运行时间平均值
测试结果:
在附件1中链接网页(就是这个)中的数据集测试中,最多只能通过容量为8的第二个数据集,由于数据量较小,运行时间较短,测得0ms。(如下图)
通过更改数据集的容量(如下图),可以发现:随着n的增长,程序运行时间增长地越来越快,到n=14时,对于某一数据集已经需要52s,到n=16时,时间已经是不可接受的。
图不清晰的话可以看下面文字:
当有九个矩阵(n=9),运行时间为3.1ms; 当n=10,需要12.5ms; n=11,需要82.8ms; n=12,需要323.4ms; 当n=13,需要1344ms; 当n=14,需要52422ms; (n=15没测出来,因为一测电脑就暴毙了:(