template <typename T>
class segment_tree
{
struct record
{
int left;
int right;
T value;
T sum;
};
int m_size;
std::vector<record> m_records;
friend class getter;
public:
typedef std::pair<int, int> range;
class getter
{
friend class segment_tree;
record * m_record, m_reserve;
std::vector<record*> m_ancestors;
int m_left, m_right;
getter( segment_tree & tree, int value_index ) : m_left( 1 ), m_right( tree.m_size )
{
assert( value_index >= m_left && value_index <= m_right );
int my_index = 1;
m_record = &tree.m_records[ my_index ];
for ( ;; ) {
if ( m_record->left <= 0 ) {
m_record->left = m_left;
m_record->right = m_right;
}
if ( m_record->left == value_index && m_record->right == value_index ) {
break;
}
int mid = ( m_left + m_right ) / 2;
my_index <<= 1;
if ( value_index <= mid ) {
m_right = mid;
} else {
m_left = mid + 1;
++my_index;
}
m_ancestors.push_back( m_record );
m_record = &tree.m_records[ my_index ];
}
}
getter( segment_tree & tree, int left, int right ) : m_left( 1 ), m_right( tree.m_size )
{
assert( left >= m_left && right <= m_right && left <= right );
int my_left = 1, my_right = tree.m_size, my_index = 1;
record * current = &tree.m_records[ my_index ];
bool is_left = true;
T left_sum = 0;
for ( ; ; ) {
if ( current->left <= 0 ) {
current->left = my_left;
current->right = my_right;
}
if ( current->left == left && current->right == left ) {
break;
}
int mid = ( my_left + my_right ) / 2;
my_index <<= 1;
if ( left <= mid ) {
my_right = mid;
is_left = true;
} else {
my_left = mid + 1;
++my_index;
is_left = false;
}
if ( !is_left ) {
left_sum += tree.m_records[ my_index - 1 ].sum;
}
current = &tree.m_records[ my_index ];
}
my_left = 1, my_right = tree.m_size, my_index = 1;
current = &tree.m_records[ my_index ];
is_left = true;
T right_sum = 0;
for ( ;; ) {
if ( current->left <= 0 ) {
current->left = my_left;
current->right = my_right;
}
if ( current->left == right && current->right == right ) {
break;
}
int mid = ( my_left + my_right ) / 2;
my_index <<= 1;
if ( right <= mid ) {
my_right = mid;
is_left = true;
} else {
my_left = mid + 1;
++my_index;
is_left = false;
}
if ( is_left ) {
right_sum += tree.m_records[ my_index + 1 ].sum;
}
current = &tree.m_records[ my_index ];
}
m_reserve.left = left;
m_reserve.right = right;
m_reserve.sum = tree.m_records[1].sum - (left_sum + right_sum);
m_reserve.value = m_reserve.sum;
m_record = &m_reserve;
}
public:
operator T() { return m_record->value; }
operator const T() const { return m_record->value; }
getter & operator=( typename boost::call_traits<T>::param_type value )
{
T diff = value - m_record->value;
m_record->value = value;
m_record->sum = value;
std::for_each(m_ancestors.begin(), m_ancestors.end(), [diff](record * r) { r->sum += diff; });
return *this;
}
};
segment_tree( int n ) : m_size( n ), m_records( m_size * 2, record{ -1, -1, 0, 0 } )
{
assert(n > 0);
}
getter operator[](int i)
{
if (i < 0 || i >= m_size) {
throw std::out_of_range("bad index");
}
return getter( *this, i + 1 );
}
const getter operator[]( int i ) const
{
if ( i < 0 || i >= m_size ) {
throw std::out_of_range( "bad index" );
}
return getter( *this, i + 1 );
}
const T operator[]( range r ) const
{
if ( r.first < 0 || r.second >= m_size || r.first > r.second ) {
throw std::out_of_range( "bad index" );
}
if (r.first != r.second) {
return static_cast<const T>( getter( const_cast<segment_tree<T> &>( *this ), r.first + 1, r.second + 1 ) );
}
else {
return static_cast<const T>( getter( const_cast<segment_tree<T> &>(*this), r.first + 1 ) );
}
}
};
The class can be used as follows:
// A tree of ten elements.
segment_tree<int> st( 10 );
using range = segment_tree<int>::range;

// Point updates go through the proxy returned by operator[].
st[ 2 ] = 10;
st[ 0 ] = 5;
st[ 1 ] = 5;
verify( st[ 1 ] == 5, stderr, "Failed to get the specified item\n" );

st[ 3 ] = 15;
st[ 8 ] = 10;
st[ 9 ] = 20;

// Assigning again replaces the previous value of the element.
st[ 1 ] = 10;
verify( st[ 1 ] == 10, stderr, "Failed to get the specified item\n" );

// Range query over elements 1..2 (inclusive), which now hold 10 and 30.
st[ 2 ] = 30;
verify( st[ range{ 1, 2 } ] == 40, stderr, "Failed to get the sum of the specified range\n" );