// Introduction to Algorithms, 3rd edition, section 15.5: a dynamic programming problem
// optimal_binary_search_tree introduction to algorithm 3rd, example 15.5
#include <iostream>
#include <vector>
#include <algorithm>
#include <cassert>
#include <ostream>
#include <iterator>
#include <limits>
#include <complex>
#include "stopwatch.h"
/// One memo-table entry describing the best subtree found for a key range.
struct Record
{
    Record() = default;

    double min = 0.0;    ///< minimal expected search cost found for the range
    double weight = 0.0; ///< summed weight of the range's keys and dummies
    size_t root = 0;     ///< index of the key chosen as the subtree root
};

/// Table of Records; entry [i][j] describes the key range [i, j + 1).
using solution_t = std::vector<std::vector<Record>>;
double optimize_bst(const std::vector<double> & nodes, size_t start, size_t end, const std::vector<double> & dummies, solution_t & solution)
{
if (start == end)
{
return dummies[start];
}
if (solution[start][end-1].weight > 1.0e-6)
{
return solution[start][end-1].min;
}
auto min = std::numeric_limits<double>::max();
for (size_t i = start; i < end; ++i)
{
auto submin = optimize_bst(nodes, start, i, dummies, solution);
auto lweight = start == i ? dummies[start] : solution[start][i-1].weight;
submin += lweight;
submin += nodes[i];
submin += optimize_bst(nodes, i+1, end, dummies, solution);
auto rweight = i+1 == end ? dummies[i+1] : solution[i+1][end-1].weight;
submin += rweight;
if (min > submin)
{
min = submin;
solution[start][end-1].weight = lweight + nodes[i] + rweight;
solution[start][end-1].root = i;
solution[start][end-1].min = submin;
}
}
return min;
}
/// Bottom-up (iterative) solver for the optimal binary search tree problem.
///
/// Fills solution[is][is + il] for every key range of length il + 1, shortest
/// ranges first, so every sub-range a candidate root needs is already solved.
///
/// @param nodes     weights of the keys
/// @param dummies   weights of the dummy leaves; needs at least nodes.size() + 1 entries
/// @param solution  nodes.size() x nodes.size() table, written in place
/// @return          minimal expected search cost of the whole tree
///                  (dummies[0] when there are no keys)
double optimize_bst_none_recursive(const std::vector<double> & nodes, const std::vector<double> & dummies, solution_t & solution)
{
    // The loops below read dummies[0] .. dummies[nodes.size()].
    assert(dummies.size() > nodes.size() && "need one more dummy than keys");
    const size_t n = nodes.size();
    if (n == 0)
    {
        // No keys: the tree is just the dummy leaf d_0. Without this guard,
        // n - 1 below underflows and solution[0] is read out of bounds.
        return dummies[0];
    }
    for (size_t il = 0; il < n; ++il)          // il + 1 = number of keys in the range
    {
        for (size_t is = 0; is < n - il; ++is) // is = first key of the range
        {
            auto min = std::numeric_limits<double>::max();
            for (size_t ii = is; ii < is + il + 1; ++ii) // ii = candidate root
            {
                // An empty side is a lone dummy leaf, whose cost and weight are both its own weight.
                auto subleft = ii == is ? dummies[ii] : solution[is][ii - 1].min;
                auto sublw = ii == is ? dummies[ii] : solution[is][ii - 1].weight;
                auto subright = ii == is + il ? dummies[ii + 1] : solution[ii + 1][is + il].min;
                auto subrw = ii == is + il ? dummies[ii + 1] : solution[ii + 1][is + il].weight;
                // Both subtree costs plus the total weight of the range.
                auto submin = subleft + sublw + nodes[ii] + subright + subrw;
                if (min > submin)
                {
                    min = submin;
                    solution[is][is + il].min = submin;
                    solution[is][is + il].weight = sublw + nodes[ii] + subrw;
                    solution[is][is + il].root = ii;
                }
            }
        }
    }
    return solution[0][n - 1].min;
}
/// Prints the optimal tree as one line per parent/child relation, including dummy leaves.
///
/// Key index i is printed as N{i+1}; dummy index i as D{i}.
///
/// @param nodes     weights of the keys
/// @param start     first key index of the subtree's range
/// @param end       one past the last key index; start == end denotes the dummy leaf D{start}
/// @param root      index of the parent key; the top-level call passes 0 with the full range
/// @param dummies   weights of the dummy leaves
/// @param solution  table filled by one of the solvers
void print_solution(const std::vector<double> & nodes,
                    size_t start,
                    size_t end,
                    size_t root,
                    const std::vector<double> & dummies,
                    const solution_t & solution)
{
    if (start == end)
    {
        if (nodes.empty())
        {
            // No keys at all: the dummy leaf is the whole tree and has no parent key to name.
            std::cout << "Root D" << start << "(" << dummies[start] << ")" << std::endl;
            return;
        }
        std::cout << "D" << start << "(" << dummies[start] << ") is " << (root >= start ? "left" : "right")
                  << " child of N" << root+1 <<"(" << nodes[root] << ")" << std::endl;
        return;
    }
    // A child range always touches its parent: a left subtree ends at the parent's index
    // and a right subtree starts just after it. Testing that (rather than root == 0) stops
    // the right subtree of key 0 from being reported as the root of the whole tree.
    const bool is_top_level = (end != root) && (start != root + 1);
    auto & record = solution[start][end-1];
    if (is_top_level)
    {
        std::cout << "Root N" << record.root+1 << "(" << nodes[record.root] << ")" << std::endl;
    }
    else
    {
        // Print the key's own weight (the original mistakenly read dummies[record.root]).
        std::cout << "N" << record.root+1 << "(" << nodes[record.root] << ") is "
                  << (root >= record.root ? "left" : "right") << " child of N" << root+1 <<"(" << nodes[root] << ")" << std::endl;
    }
    print_solution(nodes, start, record.root, record.root, dummies, solution);
    print_solution(nodes, record.root+1, end, record.root, dummies, solution);
}
/// Prints the optimal tree in a compact nested form: N{k}(left,right).
///
/// A range holding a single key prints as just N{k}. Empty ranges (dummy
/// leaves) print nothing, so a key with only a right child shows as N{k}(,N{j}).
/// The nodes and root parameters are not read; they mirror the other printer's signature.
void print_solution(const std::vector<double> & nodes, size_t start, size_t end, size_t root, const solution_t & solution)
{
    if (start == end)
        return;

    const size_t r = solution[start][end - 1].root;
    std::cout << "N" << r + 1;

    const bool single_key = (end - start == 1);
    if (single_key)
        return;

    std::cout << "(";
    print_solution(nodes, start, r, r, solution);
    std::cout << ",";
    print_solution(nodes, r + 1, end, r, solution);
    std::cout << ")";
}
/// Runs the top-down memoized solver on one problem instance, timing it and
/// printing the minimal cost and the compact tree.
///
/// @param nodes    weights of the keys
/// @param dummies  weights of the dummy leaves; needs at least nodes.size() + 1 entries
void optimize_bst(const std::vector<double> & nodes, const std::vector<double> & dummies)
{
    // The solver reads dummies[0] .. dummies[nodes.size()]; a shorter vector would be read out of bounds.
    assert(dummies.size() > nodes.size() && "need one more dummy than keys");
    // n x n table of default Records (weight 0, i.e. not yet solved).
    solution_t solution(nodes.size(), std::vector<Record>(nodes.size()));
    Stopwatch watch;
    watch.start();
    auto cost = optimize_bst(nodes, 0, nodes.size(), dummies, solution);
    watch.stop();
    std::cout << "find the minimise cost " << cost << " use time " << watch.elapsed() << " ms" << std::endl;
    std::cout << "the solution is :\n";
    print_solution(nodes, 0, nodes.size(), 0, /*dummies,*/ solution);
    std::cout << std::endl;
}
/// Runs the bottom-up solver on one problem instance, timing it and
/// printing the minimal cost and the compact tree.
///
/// @param nodes    weights of the keys
/// @param dummies  weights of the dummy leaves; needs at least nodes.size() + 1 entries
void optimize_bst_none_recursive(const std::vector<double> & nodes, const std::vector<double> & dummies)
{
    // The solver reads dummies[0] .. dummies[nodes.size()]; a shorter vector would be read out of bounds.
    assert(dummies.size() > nodes.size() && "need one more dummy than keys");
    // n x n table of default Records, filled in place by the solver.
    solution_t solution(nodes.size(), std::vector<Record>(nodes.size()));
    Stopwatch watch;
    watch.start();
    auto cost = optimize_bst_none_recursive(nodes, dummies, solution);
    watch.stop();
    std::cout << "none recursive find the minimise cost " << cost << " use time " << watch.elapsed() << " ms" << std::endl;
    std::cout << "the solution is :\n";
    print_solution(nodes, 0, nodes.size(), 0, /*dummies,*/ solution);
    std::cout << std::endl;
}
/// First data set (five keys, six dummies): echoes the input, then solves it
/// with both the recursive and the bottom-up solver.
void test1()
{
    const std::vector<double> nodes{0.15, 0.10, 0.05, 0.10, 0.20};
    const std::vector<double> dummies{0.05, 0.10, 0.05, 0.05, 0.05, 0.10};
    std::cout << "test 1 : \n nodes: ";
    std::copy(nodes.begin(), nodes.end(), std::ostream_iterator<double>(std::cout, " "));
    std::cout << " dummies: ";
    // Separate the values like the keys above; an empty delimiter ran them together.
    std::copy(dummies.begin(), dummies.end(), std::ostream_iterator<double>(std::cout, " "));
    std::cout << std::endl;
    optimize_bst(nodes, dummies);
    std::cout << std::endl;
    optimize_bst_none_recursive(nodes, dummies);
    std::cout << std::endl;
}
/// Second data set (seven keys, eight dummies): echoes the input, then solves
/// it with both the recursive and the bottom-up solver.
void test2()
{
    const std::vector<double> nodes{0.04, 0.06, 0.08, 0.02, 0.10, 0.12, 0.14};
    const std::vector<double> dummies{0.06, 0.06, 0.06, 0.06, 0.05, 0.05, 0.05, 0.05};
    std::cout << "test 2 : \n nodes: ";
    std::copy(nodes.begin(), nodes.end(), std::ostream_iterator<double>(std::cout, " "));
    std::cout << " dummies: ";
    // Separate the values like the keys above; an empty delimiter ran them together.
    std::copy(dummies.begin(), dummies.end(), std::ostream_iterator<double>(std::cout, " "));
    std::cout << std::endl;
    optimize_bst(nodes, dummies);
    std::cout << std::endl;
    optimize_bst_none_recursive(nodes, dummies);
    std::cout << std::endl;
}
int main()
{
std::cout << "start" << std::endl;
test1();
test2();
return 0;
}