Optimal binary search tree — Introduction to Algorithms, 3rd edition, Section 15.5

算法导论（第三版）15.5 节：最优二叉搜索树（动态规划问题）


// Optimal binary search tree: Introduction to Algorithms (CLRS), 3rd edition, Section 15.5.


#include <iostream>
#include <vector>
#include <algorithm>
#include <cassert>
#include <ostream>
#include <iterator>
#include <limits>
#include <complex>
#include "stopwatch.h"

// One cell of the dynamic-programming table, describing the best subtree
// found for a contiguous range of keys:
//   min    - minimal expected search cost of that subtree,
//   weight - total probability (real keys plus dummy keys) of the range,
//   root   - index of the key chosen as the subtree's root.
// Every field starts out as zero.
struct Record
{
    double min    = 0.0;
    double weight = 0.0;
    size_t root   = 0;
};

// solution[i][j] holds the Record for the key range i..j (inclusive).
using solution_t = std::vector<std::vector<Record>>;

double optimize_bst(const std::vector<double> & nodes, size_t start, size_t end, const std::vector<double> & dummies, solution_t & solution)
{
    if (start == end)
    {
        return dummies[start];
    }

    if (solution[start][end-1].weight > 1.0e-6)
    {
        return solution[start][end-1].min;
    }

    auto min = std::numeric_limits<double>::max();
    for (size_t i = start; i < end; ++i)
    {
        auto submin = optimize_bst(nodes, start, i, dummies, solution);
        auto lweight = start == i ? dummies[start] : solution[start][i-1].weight;
        submin += lweight;
        submin += nodes[i];
        submin += optimize_bst(nodes, i+1, end, dummies, solution);
        auto rweight = i+1 == end ? dummies[i+1]   : solution[i+1][end-1].weight;
        submin += rweight;

        if (min > submin)
        {
            min = submin;
            solution[start][end-1].weight = lweight + nodes[i] + rweight;
            solution[start][end-1].root = i;
            solution[start][end-1].min = submin;
        }
    }

    return min;
}

double optimize_bst_none_recursive(const std::vector<double> &  nodes, const std::vector<double> &  dummies, solution_t & solution)
{
    for (size_t il = 0; il < nodes.size(); ++il)
    {
        for (size_t is = 0; is < nodes.size()-il; ++is)
        {
            auto min = std::numeric_limits<double>::max();
            for (size_t ii = is; ii < is+il+1; ++ii)
            {
                auto subleft  = ii == is    ? dummies[ii]   : solution[is][ii-1].min;
                auto sublw    = ii == is    ? dummies[ii]   : solution[is][ii-1].weight;
                auto subright = ii == is+il ? dummies[ii+1] : solution[ii+1][is+il].min;
                auto subrw    = ii == is+il ? dummies[ii+1] : solution[ii+1][is+il].weight;

                auto submin = subleft + sublw + nodes[ii] + subright + subrw;
                if (min > submin)
                {
                    min = submin;
                    solution[is][is+il].min    = submin;
                    solution[is][is+il].weight = sublw + nodes[ii] + subrw;
                    solution[is][is+il].root   = ii;
                }
            }
        }
    }
    return solution[0][nodes.size()-1].min;
}

void print_solution(const std::vector<double> &     nodes,
                    size_t                          start,
                    size_t                          end,
                    size_t                          root,
                    const std::vector<double> &     dummies,
                    const solution_t &              solution)
{
    if (start == end)
    {
        std::cout << "D" << start << "(" << dummies[start] << ") is " << (root >= start ? "left" : "right") 
                  << " child of N" << root+1 <<"(" << nodes[root] << ")" << std::endl;
        return;
    }

    auto & record = solution[start][end-1];
    if (root == 0)
    {
        std::cout << "Root N" << record.root+1 << "(" << nodes[record.root] << ")" << std::endl;
    }
    else
    {
        std::cout << "N" <<  record.root+1 << "(" << dummies[ record.root] << ") is " 
                  << (root >=  record.root ? "left" : "right") << " child of N" << root+1 <<"(" << nodes[root] << ")" << std::endl;
    }

    print_solution(nodes, start, record.root, record.root, dummies, solution);
    print_solution(nodes, record.root+1, end, record.root, dummies, solution);
}


void print_solution(const std::vector<double> & nodes, size_t start, size_t end, size_t root, const solution_t & solution)
{
    if (start == end)
    {
        return;
    }
    auto & record = solution[start][end-1];
    std::cout << "N" << record.root+1;
    if (start != end-1)
    {
        std::cout << "(";
        print_solution(nodes, start, record.root, record.root, solution);
        std::cout << ",";
        print_solution(nodes, record.root+1, end, record.root, solution);
        std::cout << ")";
    }
}

void optimize_bst(const std::vector<double> & nodes, const std::vector<double> & dummies)
{
    solution_t solution(nodes.size());
    for (auto & i : solution)
    {
        i.resize(nodes.size());
    }

    Stopwatch watch;
    watch.start();
    auto cost = optimize_bst(nodes, 0, nodes.size(), dummies, solution);
    watch.stop();
    std::cout << "find the minimise cost " << cost << " use time " << watch.elapsed() << " ms" << std::endl;
    std::cout << "the solution is :\n";
    print_solution(nodes, 0, nodes.size(), 0, /*dummies,*/ solution);
    std::cout << std::endl;
}

void optimize_bst_none_recursive(const std::vector<double> & nodes, const std::vector<double> & dummies)
{
    solution_t solution(nodes.size());
    for (auto & i : solution)
    {
        i.resize(nodes.size());
    }

    Stopwatch watch;
    watch.start();
    auto cost = optimize_bst_none_recursive(nodes, dummies, solution);
    watch.stop();
    std::cout << "none recursive find the minimise cost " << cost << " use time " << watch.elapsed() << " ms" << std::endl;
    std::cout << "the solution is :\n";
    print_solution(nodes, 0, nodes.size(), 0, /*dummies,*/ solution);
    std::cout << std::endl;
}

void test1()
{
    std::vector<double> nodes;
    nodes.reserve(5);
    nodes.push_back(0.15);
    nodes.push_back(0.10);
    nodes.push_back(0.05);
    nodes.push_back(0.10);
    nodes.push_back(0.20);
    std::vector<double> dummies;
    dummies.reserve(6);
    dummies.push_back(0.05);
    dummies.push_back(0.10);
    dummies.push_back(0.05);
    dummies.push_back(0.05);
    dummies.push_back(0.05);
    dummies.push_back(0.10);
    
    std::cout << "test 1 : \n nodes: ";
    std::copy(nodes.begin(), nodes.end(), std::ostream_iterator<double>(std::cout, " "));
    std::cout << "  dummies: ";
    std::copy(dummies.begin(), dummies.end(), std::ostream_iterator<double>(std::cout, ""));
    std::cout << std::endl;

    optimize_bst(nodes, dummies);
    std::cout << std::endl;
    
    optimize_bst_none_recursive(nodes, dummies);
    std::cout << std::endl;
}

void test2()
{
    std::vector<double> nodes;
    nodes.reserve(7);
    nodes.push_back(0.04);
    nodes.push_back(0.06);
    nodes.push_back(0.08);
    nodes.push_back(0.02);
    nodes.push_back(0.10);
    nodes.push_back(0.12);
    nodes.push_back(0.14);
    std::vector<double> dummies;
    dummies.reserve(8);
    dummies.push_back(0.06);
    dummies.push_back(0.06);
    dummies.push_back(0.06);
    dummies.push_back(0.06);
    dummies.push_back(0.05);
    dummies.push_back(0.05);
    dummies.push_back(0.05);
    dummies.push_back(0.05);
    
    std::cout << "test 2 : \n nodes: ";
    std::copy(nodes.begin(), nodes.end(), std::ostream_iterator<double>(std::cout, " "));
    std::cout << "  dummies: ";
    std::copy(dummies.begin(), dummies.end(), std::ostream_iterator<double>(std::cout, ""));
    std::cout << std::endl;

    optimize_bst(nodes, dummies);
    std::cout << std::endl;

    optimize_bst_none_recursive(nodes, dummies);
    std::cout << std::endl;
}

int main()
{
    std::cout << "start" << std::endl;
    test1();
    test2();
    return 0;
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值