How to implement a segment tree

    // A point-update / range-sum segment tree over n elements (0-based from
    // the outside). Element access goes through a proxy (`getter`) so that
    // `tree[i] = v` can keep every ancestor sum up to date.
    template <typename T>
    class segment_tree
    {
        // Heap-ordered node index: the root is node 1 and node k has children
        // 2k and 2k + 1. Node 0 is never used.
        typedef typename std::vector<T>::size_type node_index;

        int            m_size;
        // m_sums[ k ] holds the sum of every element covered by node k. A leaf
        // covers exactly one element, so a leaf's sum IS that element's value.
        std::vector<T> m_sums;

        // Heap index of the leaf that holds the 1-based position `position`.
        // The span [lo, hi] is split at mid = lo + (hi - lo) / 2, left half
        // first; sum_of() must use exactly the same split.
        node_index leaf_of( int position ) const
        {
            assert( position >= 1 && position <= m_size );
            node_index node = 1;
            int lo = 1;
            int hi = m_size;
            while ( lo < hi ) {
                const int mid = lo + ( hi - lo ) / 2;
                node *= 2;
                if ( position <= mid ) {
                    hi = mid;
                } else {
                    lo = mid + 1;
                    node += 1;
                }
            }
            return node;
        }

        // Sum of the 1-based positions in [left, right] that lie inside the
        // span [lo, hi] covered by `node`. Only partially covered nodes are
        // descended into, so at most O(log n) nodes are visited.
        T sum_of( node_index node, int lo, int hi, int left, int right ) const
        {
            if ( right < lo || hi < left ) {
                return T();             // disjoint: contributes nothing
            }
            if ( left <= lo && hi <= right ) {
                return m_sums[ node ];  // fully covered: use the cached sum
            }
            const int mid = lo + ( hi - lo ) / 2;
            return sum_of( 2 * node, lo, mid, left, right )
                 + sum_of( 2 * node + 1, mid + 1, hi, left, right );
        }

    public:
        // Inclusive, 0-based index range [first, second].
        typedef std::pair<int, int> range;

        // Proxy for one element: converts to its value and, on assignment,
        // updates the leaf and every ancestor sum.
        class getter
        {
            friend class segment_tree;

            segment_tree * m_tree;
            node_index     m_leaf;

            // `position` is 1-based.
            getter( segment_tree & tree, int position )
                : m_tree( &tree ), m_leaf( tree.leaf_of( position ) )
            {
            }

        public:
            getter( const getter & ) = default;

            operator T() const { return m_tree->m_sums[ m_leaf ]; }

            // Stores `value` in the leaf and adds the change to each ancestor
            // (the parent of heap node k is k / 2, down to the root at 1).
            getter & operator=( const T & value )
            {
                const T diff = value - m_tree->m_sums[ m_leaf ];
                m_tree->m_sums[ m_leaf ] = value;
                for ( node_index node = m_leaf / 2; node != 0; node /= 2 ) {
                    m_tree->m_sums[ node ] += diff;
                }
                return *this;
            }

            // `tree[i] = tree[j]` must copy the VALUE. The implicit copy
            // assignment would only rebind this temporary proxy and leave the
            // tree unchanged.
            getter & operator=( const getter & other )
            {
                return *this = static_cast<T>( other );
            }
        };

        // Creates a tree of n zero (value-initialised) elements. 4n nodes are
        // reserved because, for a midpoint split, leaf indices can approach
        // 4n when n is not a power of two (e.g. n = 10 puts position 6 at
        // node 24); 2n is not enough.
        segment_tree( int n )
            : m_size( n )
            , m_sums( n > 0 ? 4 * static_cast<node_index>( n ) : node_index( 0 ), T() )
        {
            assert( n > 0 );
        }

        // Read/write access to element i (0-based).
        // Throws std::out_of_range when i is not in [0, size).
        getter operator[]( int i )
        {
            if ( i < 0 || i >= m_size ) {
                throw std::out_of_range( "bad index" );
            }
            return getter( *this, i + 1 );
        }

        // Read-only access to element i (0-based).
        // Throws std::out_of_range when i is not in [0, size).
        // The returned proxy is const, so it cannot be assigned to. A non-const
        // COPY of it could still write into *this, so never assign through a
        // copy obtained from a const tree.
        const getter operator[]( int i ) const
        {
            if ( i < 0 || i >= m_size ) {
                throw std::out_of_range( "bad index" );
            }
            return getter( const_cast<segment_tree &>( *this ), i + 1 );
        }

        // Sum of elements r.first .. r.second (inclusive, 0-based).
        // Throws std::out_of_range for an empty, reversed or out-of-bounds range.
        const T operator[]( range r ) const
        {
            if ( r.first < 0 || r.second >= m_size || r.first > r.second ) {
                throw std::out_of_range( "bad index" );
            }
            return sum_of( 1, 1, m_size, r.first + 1, r.second + 1 );
        }
    };


The class can be used as follows (note that `verify` is a checking helper that is not defined in this post):

        segment_tree<int> st(10);
        typedef segment_tree<int>::range range;

        st[2] = 10;
        st[0] = 5;
        st[1] = 5;
        verify( st[ 1 ] == 5, stderr, "Failed to get the specified item\n" );
        st[ 3 ] = 15;
        st[8] = 10;
        st[9] = 20;
        st[1] = 10;
        verify( st[ 1 ] == 10, stderr, "Failed to get the specified item\n" );
        st[ 2 ] = 30;
        verify(st[range{1, 2}] == 40, stderr, "Failed to get the sum of the specified range\n");


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值