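Matrix-chain multiplication — *Introduction to Algorithms*, 3rd edition, Section 15.2.

算法导论 第三版 15.2 动态规划问题的实现 (An implementation of the dynamic-programming problem from Section 15.2 of *Introduction to Algorithms*, 3rd edition.)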


// Matrix-chain multiplication: Introduction to Algorithms, 3rd edition, Section 15.2

#include <iostream>
#include <vector>
#include <algorithm>
#include <cassert>
#include <ostream>
#include <iterator>
#include <limits>
#include "stopwatch.h"

/// A matrix described only by its dimensions: `left` x `right`.
/// The cost formulas in this file treat `left` as the row count and `right`
/// as the column count, and the sample chains satisfy
/// chain[i].right == chain[i+1].left.
struct Matrix
{
    size_t left;   // row count
    size_t right;  // column count

    Matrix(size_t l, size_t r) : left{l}, right{r} {}
};

std::ostream & operator<<(std::ostream & os, const Matrix & m)
{
    return os << "M[" << m.left << ',' << m.right << ']';
}

/// One cell of the memo table for the inclusive sub-chain [i, j].
///   min - minimum scalar-multiplication count found for the sub-chain;
///         numeric_limits<size_t>::max() means "not computed yet".
///   pos - split index k: the sub-chain is evaluated as [i, k) * [k, j].
struct Solution
{
    Solution() {}
    size_t min = std::numeric_limits<size_t>::max();
    size_t pos = 0;
};

size_t split_matrix_chain(std::vector<Matrix> & matrixChain,
                          size_t startIndex,
                          size_t endIndex,
                          std::vector<std::vector<Solution>> & solution)
{
    if (startIndex == endIndex-1)
    {
        solution[startIndex][endIndex-1].min = 0;
        return 0;
    }


    size_t count = std::numeric_limits<size_t>::max();
    for (size_t index = startIndex+1; index < endIndex; ++index)
    {
        auto left  = solution[startIndex][index-1].min != std::numeric_limits<size_t>::max()
                   ? solution[startIndex][index-1].min
                   : split_matrix_chain(matrixChain, startIndex, index, solution);
        auto right = solution[index][endIndex-1].min != std::numeric_limits<size_t>::max()
                   ? solution[index][endIndex-1].min
                   : split_matrix_chain(matrixChain, index, endIndex, solution);
        auto subCount = left
                      + matrixChain[startIndex].left * matrixChain[index].left * matrixChain[endIndex-1].right
                      + right;


        if (subCount < count)
        {
            count = subCount;
            solution[startIndex][endIndex-1].min = count;
            solution[startIndex][endIndex-1].pos = index;
        }
    }


    return count;
}

size_t split_matrix_chain_none_recursive(std::vector<Matrix> & matrixChain,
                                         size_t startIndex,
                                         size_t endIndex,
                                         std::vector<std::vector<Solution>> & solution)
{
    for (auto si = startIndex+1; si < endIndex; ++si)
    {
        for (auto li = startIndex; li < endIndex-si; ++li)
        {
            size_t min = std::numeric_limits<size_t>::max();
            for (auto pi = li + 1; pi < li+si+1; ++pi)
            {
                auto submin = solution[li][pi-1].min + solution[pi][li+si].min
                            + matrixChain[li].left * matrixChain[pi].left * matrixChain[li+si].right;
                if (min > submin)
                {
                    min = submin;
                    solution[li][li+si].pos = pi;
                }
            }
            solution[li][li+si].min = min;
        }
    }

    return solution[startIndex][endIndex-1].min;
}

void print_solution(std::vector<Matrix> & matrixChain,
                    size_t startIndex,
                    size_t endIndex,
                    std::vector<std::vector<Solution>> & solution)
{
    if (startIndex == endIndex-1)
    {
        std::cout << matrixChain[startIndex];
    }
    else
    {
        auto pos = solution[startIndex][endIndex-1].pos;
        if (startIndex == pos-1)
        {
            std::cout << matrixChain[startIndex];
        }
        else
        {
            std::cout << "(";
            print_solution(matrixChain, startIndex, pos, solution);
            std::cout << ")";
        }
        if (pos == endIndex-1)
        {
            std::cout << matrixChain[pos];
        }
        else
        {
            std::cout << "(";
            print_solution(matrixChain, pos, endIndex, solution);
            std::cout << ")";
        }
    }
}

void print_solution(std::vector<Matrix> & matrixChain, std::vector<std::vector<Solution>> & solution)
{
    std::cout << "The solution is :\n    ";
    print_solution(matrixChain, 0, matrixChain.size(), solution);
    std::cout << std::endl;
}


void split_matrix_chain(std::vector<Matrix> & matrixChain)
{
    std::vector<std::vector<Solution>> solution(matrixChain.size());
    for (auto & i : solution)
    {
        i.resize(matrixChain.size());
    }


    Stopwatch watch;
    watch.start();
    auto min = split_matrix_chain(matrixChain, 0, matrixChain.size(), solution);
    watch.stop();
    assert(solution[0][matrixChain.size()-1].min == min);
    std::cout << "The minimum multiplication count is " << min << " use time " << watch.elapsed() << " ms" << std::endl;
    print_solution(matrixChain, solution);
}

void split_matrix_chain_none_recursive(std::vector<Matrix> & matrixChain)
{
    std::vector<std::vector<Solution>> solution(matrixChain.size());
    for (auto & i : solution)
    {
        i.resize(matrixChain.size());
    }
    for (size_t i = 0; i < matrixChain.size(); ++i)
    {
        solution[i][i].min = 0;
    }


    Stopwatch watch;
    watch.start();
    auto min = split_matrix_chain_none_recursive(matrixChain, 0, matrixChain.size(), solution);
    watch.stop();
    assert(solution[0][matrixChain.size()-1].min == min);
    std::cout << "(no recursive) The minimum multiplication count is " << min << " use time " << watch.elapsed() << " ms" << std::endl;
    print_solution(matrixChain, solution);
}

std::vector<Matrix> create_matrix_chain1()
{
    std::vector<Matrix> mc;
    mc.reserve(6);
    mc.push_back(Matrix(5 ,10));
    mc.push_back(Matrix(10 ,3));
    mc.push_back(Matrix(3 ,12));
    mc.push_back(Matrix(12 ,5));
    mc.push_back(Matrix(5 ,50));
    mc.push_back(Matrix(50 ,6));


    std::cout << "Input matrix is :\n    ";
    std::copy(mc.begin(), mc.end(), std::ostream_iterator<Matrix>(std::cout, " "));
    std::cout << std::endl;


    return mc;
}

std::vector<Matrix> create_matrix_chain2()
{
    std::vector<Matrix> mc;
    mc.reserve(6);
    mc.push_back(Matrix(30, 35));
    mc.push_back(Matrix(35, 15));
    mc.push_back(Matrix(15, 5));
    mc.push_back(Matrix(5,  10));
    mc.push_back(Matrix(10, 20));
    mc.push_back(Matrix(20, 25));


    std::cout << "Input matrix is :\n    ";
    std::copy(mc.begin(), mc.end(), std::ostream_iterator<Matrix>(std::cout, " "));
    std::cout << std::endl;


    return mc;
}

int main()
{
    std::cout << "start" << std::endl;
    std::cout << "test 1" << std::endl;
    auto matrixChain = create_matrix_chain1();
    split_matrix_chain(matrixChain);
    split_matrix_chain_none_recursive(matrixChain);


    std::cout << "\ntest 2" << std::endl;
    auto matrixChain2 = create_matrix_chain2();
    split_matrix_chain(matrixChain2);
    split_matrix_chain_none_recursive(matrixChain2);


    return 0;
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值