// Maximum-weight independent set on a binary tree, solved by dynamic
// programming; the algorithm comes from these lecture slides:
// http://www.cs.princeton.edu/~wayne/kleinberg-tardos/pdf/IntractabilityIII-2x2.pdf
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <stack>
#include <vector>
#include <Windows.h>
using namespace std;
// A binary-tree node carrying a label and a weight, plus the two
// dynamic-programming values that weightedIndeSet() fills in.
struct node
{
    char data = '\0';            // label reported when the node is chosen
    struct node *left = nullptr;
    struct node *right = nullptr;
    int weight = 0;
    int _s_in = 0;  // subtree weight sum with this node
    int _s_out = 0; // subtree weight sum without this node
    node() = default;
    node(char d, int w) : data(d), weight(w) {}
    // Members are always initialized in declaration order (data, left, right,
    // weight); the list is written in that order to match and avoid -Wreorder.
    node(char d, int w, node *l, node *r) : data(d), left(l), right(r), weight(w) {}
};
void checkNode(node *curr, node *parent, vector<char> &set)
{
if (curr == nullptr)
return;
if (parent && parent->_s_in < parent->_s_out && curr->_s_in > curr->_s_out)
set.push_back(curr->data);
if (!parent && curr->_s_in > curr->_s_out)
set.push_back(curr->data);
checkNode(curr->left, curr, set);
checkNode(curr->right, curr, set);
}
int weightedIndeSet(node *root, vector<char> &set)
{
stack<node*> s;
node *lst = nullptr; // last visited node
node *cur = root; // current visit node
node *top = nullptr; // top node on stack
while (!s.empty() || cur) // postorder traverse
{
if (cur)
{
s.push(cur);
cur = cur->left;
}
else
{
top = s.top();
if (top->right && lst != top->right)
cur = top->right;
else
{
if (top->left == top->right)
top->_s_in = top->weight;
else
{
top->_s_in += top->weight;
if (top->left)
{
top->_s_in += top->left->_s_out;
top->_s_out += max(top->left->_s_in, top->left->_s_out);
}
if (top->right)
{
top->_s_in += top->right->_s_out;
top->_s_out += max(top->right->_s_in, top->right->_s_out);
}
}
lst = top;
s.pop();
}
}
}
checkNode(root, nullptr, set);
return max(root->_s_in, root->_s_out);
}
int main()
{
node *root = new node('F', 1);
root->left = new node('B', 10);
root->right = new node('G', 1);
root->left->left = new node('A', 1);
root->left->right = new node('D', 1);
root->left->right->left = new node('C', 1);
root->left->right->right = new node('E', 1);
root->right->right = new node('I', 6);
root->right->right->left = new node('H', 1);
vector<char> set;
cout << weightedIndeSet(root, set) << endl;
for (auto a : set)
cout << a << " ";
cout << endl;
system("PAUSE");
return 0;
}