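/*
 * 线段树模板(pushDown) - 做题用
 * Segment tree template with lazy propagation (pushDown), for contest use.
 *
 * 普通线段树lazy更新版本。
 * Standard segment tree, lazy-update version (range add / range sum).
 */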

#include <time.h>

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>
using namespace std;

/*
 * Brute-force reference used to cross-check SegmentTree.
 * Keeps its own copy of the input and answers every operation by walking
 * the requested inclusive range directly, i.e. in O(right - left + 1).
 */
class A
{
public:
    A(const std::vector<int>& Vec) : data_(Vec) {}

    // Sum of data_[left..right] (inclusive); 0 when left > right.
    int query(int left, int right)
    {
        int total = 0;
        int pos = left;
        while (pos <= right)
            total += data_[pos++];
        return total;
    }

    // Adds `value` to every element of data_[left..right] (inclusive).
    void modify(int left, int right, int value)
    {
        for (int pos = right; pos >= left; --pos)
            data_[pos] += value;
    }

private:
    std::vector<int> data_;   // private copy; the caller's vector is never touched
};

/*
 * Segment tree supporting range addition and range-sum queries with lazy
 * propagation ("pushDown").
 *
 * Layout: the input is padded with _Identity_Element up to _size, the
 * smallest power of two >= the input length (minimum 1), which gives a
 * perfect binary tree stored in the array ST. Node i has children 2i+1 and
 * 2i+2; nodes [0, _size-1) are internal and [_size-1, 2*_size-1) are leaves.
 *
 * Invariant: ST[i]->_val is the exact sum of node i's segment once every
 * pending addition of its ancestors has been pushed down to it. cache[i]
 * (internal nodes only) is an addition already included in ST[i]->_val but
 * not yet applied to node i's children.
 *
 * Nodes are owned through std::unique_ptr, so they are released when the
 * tree is destroyed. Copying produces an independent deep copy.
 */
class SegmentTree {

    struct SegmentTreeNode
    {
        int start, end;   // inclusive index range covered by this node
        int  _val;        // sum over [start, end] (see invariant above)
        SegmentTreeNode(int start, int end, int val) {
            this->start = start;
            this->end = end;
            this->_val = val;
        }
    };

public:

    // Builds the tree over Vec. Positions past Vec.size() (padding) hold
    // _Identity_Element. The tree keeps its own copy of the values.
    SegmentTree(const std::vector<int>& Vec)
        :_size(1),
        _Identity_Element(0)
    {
        find_size(static_cast<int>(Vec.size()));
        ST.resize(2 * _size - 1);
        cache.assign(_size - 1, _Identity_Element);
        build(0, 0, _size - 1, Vec);
    }

    // Deep copy: the new tree owns its own nodes, so updating one tree
    // never affects the other (the former implicit copy aliased the nodes).
    SegmentTree(const SegmentTree& other)
        :_size(other._size),
        ST(other.ST.size()),
        cache(other.cache),
        _Identity_Element(other._Identity_Element)
    {
        for (std::size_t i = 0; i < other.ST.size(); ++i)
            if (other.ST[i])
                ST[i].reset(new SegmentTreeNode(*other.ST[i]));
    }

    SegmentTree(SegmentTree&&) = default;

    // Sum of elements in [start, end] (inclusive). Indices that fall in the
    // padding contribute _Identity_Element.
    int query(int start, int end)
    {
        return doQuery(0, start, end);
    }

    // Adds `value` to every element in [left, right] (inclusive).
    void augment(int left, int right, int value) {
        doAugment(0, left, right, value);
    }


protected:

    int _size;                                     // number of leaves (power of two)

    std::vector<std::unique_ptr<SegmentTreeNode>> ST;  // nodes in heap order

    std::vector<int> cache;                        // pending adds, internal nodes only

    const int _Identity_Element;                   // additive identity (0)

private:

    static int leftChild(int index) { return 2 * index + 1; }
    static int rightChild(int index) { return 2 * index + 2; }

    // Grows _size to the smallest power of two >= size.
    void find_size(int size)
    {
        while (_size < size)
        {
            _size <<= 1;
        }
    }

    // Recursively creates node `index` covering [start, end] and its subtree,
    // filling leaves from Vec (or _Identity_Element for padding).
    void build(int index, int start, int end, const std::vector<int>& Vec)
    {
        if (start == end)
        {
            const bool inInput = static_cast<std::size_t>(start) < Vec.size();
            ST[index].reset(new SegmentTreeNode(start, end,
                inInput ? Vec[start] : _Identity_Element));
            return;
        }

        const int mid = start + (end - start) / 2;
        ST[index].reset(new SegmentTreeNode(start, end, _Identity_Element));

        build(leftChild(index), start, mid, Vec);
        build(rightChild(index), mid + 1, end, Vec);

        ST[index]->_val = ST[leftChild(index)]->_val + ST[rightChild(index)]->_val;
    }

    // Applies the pending addition of internal node `index` to both children,
    // forwards it to their own caches when they are internal, then clears it.
    void pushDown(int index)
    {
        if (index >= _size - 1) return;   // leaves have no cache slot

        const int value = cache[index];
        if (value == 0) return;           // nothing pending

        const int l = leftChild(index);
        const int r = rightChild(index);

        // Each child's sum grows by value times its own segment length.
        ST[l]->_val += value * (ST[l]->end - ST[l]->start + 1);
        ST[r]->_val += value * (ST[r]->end - ST[r]->start + 1);

        // In a perfect tree both children are internal or both are leaves.
        if (l < _size - 1)
        {
            cache[l] += value;
            cache[r] += value;
        }

        cache[index] = 0;
    }

    // Sum of [left, right] intersected with node `index`'s segment.
    int doQuery(int index, int left, int right)
    {
        const SegmentTreeNode& node = *ST[index];

        // Disjoint: contributes nothing.
        if (left > node.end || right < node.start)
            return _Identity_Element;

        // Fully covered: the stored sum is already exact.
        if (left <= node.start && node.end <= right)
            return node._val;

        // Partial overlap: bring children up to date before descending.
        pushDown(index);
        return doQuery(leftChild(index), left, right) + doQuery(rightChild(index), left, right);
    }

    // Adds `value` over [left, right] within node `index`'s subtree.
    void doAugment(int index, int left, int right, int value)
    {
        SegmentTreeNode& node = *ST[index];

        if (left > node.end || right < node.start) return;

        if (left <= node.start && node.end <= right)
        {
            node._val += value * (node.end - node.start + 1);
            if (index < _size - 1)
                cache[index] += value;    // defer the update for the children
            return;
        }

        pushDown(index);
        doAugment(leftChild(index), left, right, value);
        doAugment(rightChild(index), left, right, value);

        node._val = ST[leftChild(index)]->_val + ST[rightChild(index)]->_val;
    }

};


/*
 * Randomized self-check: applies the same sequence of random range additions
 * to the brute-force reference A and to SegmentTree, and after each update
 * compares their answers to a random range-sum query.
 * Prints every pair of answers and the mismatch count; exits with
 * EXIT_FAILURE if any answer differs.
 */
int main()
{
    // Seed exactly once. Re-seeding with time(0) inside the same second
    // restarts the identical rand() stream, which made later ranges and the
    // augment values repeat earlier draws. Print the seed so a failing run
    // can be reproduced.
    const unsigned seed = static_cast<unsigned>(time(nullptr));
    std::srand(seed);
    std::cout << "seed: " << seed << '\n';

    const int size = 100;   // initial values and augments are drawn from [0, size)
    const int test = 100;   // number of update/query rounds
    const int num = 120;    // array length (not a power of two, so padding is exercised)

    std::vector<int> a(num);
    for (int i = 0; i < num; ++i)
        a[i] = std::rand() % size;

    // Draws a random inclusive range [lo, hi] within [0, num).
    auto randomRange = [&]() {
        int lo = std::rand() % num;
        int hi = std::rand() % num;
        if (lo > hi) std::swap(lo, hi);
        return std::pair<int, int>(lo, hi);
    };

    std::vector<std::pair<int, int> > qtest(test);
    std::vector<std::pair<int, int> > mtest(test);
    for (int i = 0; i < test; ++i)
    {
        qtest[i] = randomRange();
        mtest[i] = randomRange();
    }

    A t1(a);
    SegmentTree t2(a);

    int _errnum = 0;
    for (int i = 0; i < test; ++i)
    {
        const int aug = std::rand() % size;
        const int ql = qtest[i].first;
        const int qr = qtest[i].second;
        const int ml = mtest[i].first;
        const int mr = mtest[i].second;
        t1.modify(ml, mr, aug);
        t2.augment(ml, mr, aug);

        const int ans1 = t1.query(ql, qr);
        const int ans2 = t2.query(ql, qr);
        std::cout << "ans1: " << ans1 << "\t" << "ans2: " << ans2 << '\n';
        if (ans1 != ans2) ++_errnum;
    }

    std::cout << "erronum: " << _errnum << std::endl;
#ifdef _WIN32
    // Keep the console window open when launched from Explorer / an IDE;
    // "pause" is a cmd.exe builtin and does not exist on other platforms.
    std::system("pause");
#endif
    return _errnum == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}
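/*
 * NOTE: the text below is leftover from the blog page this code was copied
 * from (like / favorite / comment counters, red-packet and wallet widgets).
 * It is not part of the program and is kept only inside this comment so the
 * file compiles.
 *
 *   • 0
 *     点赞
 *   • 0
 *     收藏
 *     觉得还不错? 一键收藏
 *   • 0
 *     评论
 *
 * “相关推荐”对你有帮助么?
 *
 *   • 非常没帮助
 *   • 没帮助
 *   • 一般
 *   • 有帮助
 *   • 非常有帮助
 * 提交
 * 评论
 * 添加红包
 *
 * 请填写红包祝福语或标题
 *
 * 红包个数最小为10个
 *
 * 红包金额最低5元
 *
 * 当前余额3.43前往充值 >
 * 需支付:10.00
 * 成就一亿技术人!
 * 领取后你会自动成为博主和红包主的粉丝 规则
 * hope_wisdom
 * 发出的红包
 * 实付
 * 使用余额支付
 * 点击重新获取
 * 扫码支付
 * 钱包余额 0
 *
 * 抵扣说明:
 *
 * 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
 * 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
 *
 * 余额充值
 */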