// 普通线段树lazy更新版本。 (Ordinary segment tree, lazy-propagation version: range add + range sum.)
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <vector>
using namespace std;
// Brute-force reference implementation used to cross-check SegmentTree.
// Keeps its own copy of the values; every query/modify walks the range, O(n).
class A
{
public:
	// Copies Vec; later modify() calls never touch the caller's vector.
	explicit A(const std::vector<int>& Vec) : v(Vec) {}

	// Returns the sum of v[left..right] (inclusive); 0 when left > right.
	// Assumes 0 <= left and right < v.size() -- indices are not checked.
	int query(int left, int right) const
	{
		int result = 0;
		for (int i = left; i <= right; ++i)
			result += v[i];
		return result;
	}

	// Adds value to every element of v[left..right] (inclusive).
	// Same index assumptions as query().
	void modify(int left, int right, int value)
	{
		for (int i = left; i <= right; ++i)
			v[i] += value;
	}

private:
	std::vector<int> v;
};
// Segment tree over int values supporting range add (augment) and range sum
// (query) in O(log n) using lazy propagation.
//
// Layout: a complete binary tree over _size leaves, where _size is the input
// length rounded up to a power of two (minimum 1). The tree is stored in ST:
// the root is ST[0], the children of node i are 2i+1 and 2i+2, internal nodes
// occupy [0, _size - 1) and leaves occupy [_size - 1, 2 * _size - 1).
// Leaves past the input length hold _Identity_Element (0).
//
// Nodes are heap-allocated and owned by this object, so copying is disabled
// to avoid two trees deleting the same nodes.
class SegmentTree {
	// One tree node covering the inclusive index range [start, end]. _val is
	// the sum of that range and already includes any lazy tag pending on it.
	struct SegmentTreeNode
	{
		int start, end;
		int _val;
		SegmentTreeNode(int start, int end, int val)
			: start(start), end(end), _val(val) {}
	};
public:
	// Builds the tree from Vec's current values. Vec is not retained.
	// Vec may be empty.
	explicit SegmentTree(const std::vector<int>& Vec)
		:_size(1),
		_Identity_Element(0)
	{
		find_size(static_cast<int>(Vec.size()));
		ST.resize(2 * _size - 1, nullptr);
		// cache[i] is the pending "add to every element" tag of internal node i.
		// Leaves never carry a tag, so only _size - 1 slots are needed.
		cache.assign(_size - 1, _Identity_Element);
		build(0, 0, _size - 1, Vec);
	}

	~SegmentTree()
	{
		for (SegmentTreeNode* node : ST)
			delete node;
	}

	// Copying would share (and later double-delete) the raw node pointers.
	SegmentTree(const SegmentTree&) = delete;
	SegmentTree& operator=(const SegmentTree&) = delete;

	// Returns the sum of elements in [start, end] (inclusive). Parts of the
	// range outside [0, _size - 1] contribute _Identity_Element.
	int query(int start, int end)
	{
		return doQuery(0, start, end);
	}

	// Adds value to every element in [left, right] (inclusive).
	void augment(int left, int right, int value) {
		doAugment(0, left, right, value);
	}
protected:
	int _size;
	std::vector<SegmentTreeNode*> ST;
	std::vector<int> cache;
	const int _Identity_Element;
private:
	// Sets _size to the smallest power of two >= size (1 if size <= 1).
	void find_size(int size)
	{
		while (_size < size)
		{
			_size <<= 1;
		}
	}

	// Recursively creates node `index` covering [start, end] and its subtree,
	// then sets each internal node's sum from its two children.
	void build(int index, int start, int end, const std::vector<int>& Vec)
	{
		if (start == end)
		{
			// Leaf: take the input value, or the identity for padding slots.
			const bool inInput = static_cast<std::size_t>(start) < Vec.size();
			ST[index] = new SegmentTreeNode(start, end,
				inInput ? Vec[start] : _Identity_Element);
			return;
		}
		const int mid = (start + end) / 2;
		const int leftChild = (index << 1) + 1;
		const int rightChild = (index << 1) + 2;
		ST[index] = new SegmentTreeNode(start, end, _Identity_Element);
		build(leftChild, start, mid, Vec);
		build(rightChild, mid + 1, end, Vec);
		ST[index]->_val = ST[leftChild]->_val + ST[rightChild]->_val;
	}

	// Moves the pending add-tag of internal node `index` onto its children.
	// Each child's sum is updated, and children that are internal nodes
	// inherit the tag. No-op for leaves or when nothing is pending.
	void pushDown(int index)
	{
		if (index >= _size - 1) return;
		const int value = cache[index];
		if (value == 0) return;
		const int leftChild = (index << 1) + 1;
		const int rightChild = (index << 1) + 2;
		// Every range length is a power of two, so each child covers exactly
		// half of this node. Multiplying by the child length (rather than
		// value * parentLen / 2) keeps intermediates smaller.
		const int childLen = (ST[index]->end - ST[index]->start + 1) / 2;
		ST[leftChild]->_val += value * childLen;
		ST[rightChild]->_val += value * childLen;
		if (leftChild < _size - 1)
		{
			cache[leftChild] += value;
			cache[rightChild] += value;
		}
		cache[index] = 0;
	}

	// Returns the sum of the overlap between [left, right] and node `index`.
	int doQuery(int index, int left, int right)
	{
		const SegmentTreeNode* node = ST[index];
		// Disjoint: contributes nothing.
		if (left > node->end || right < node->start)
			return _Identity_Element;
		// Fully covered: the stored sum already includes any pending tag.
		if (left <= node->start && node->end <= right)
			return node->_val;
		// Partial overlap: children must be up to date before recursing.
		pushDown(index);
		return doQuery((index << 1) + 1, left, right)
			+ doQuery((index << 1) + 2, left, right);
	}

	// Adds value to the elements of node `index` that lie in [left, right].
	void doAugment(int index, int left, int right, int value)
	{
		const int start = ST[index]->start;
		const int end = ST[index]->end;
		if (left > end || right < start) return;
		if (left <= start && end <= right)
		{
			// Fully covered: update this node's sum. Internal nodes record the
			// add lazily instead of touching their subtree.
			ST[index]->_val += value * (end - start + 1);
			if (index < _size - 1)
				cache[index] += value;
			return;
		}
		// Partial overlap: flush the pending tag, recurse, then recompute the sum.
		pushDown(index);
		const int leftChild = (index << 1) + 1;
		const int rightChild = (index << 1) + 2;
		doAugment(leftChild, left, right, value);
		doAugment(rightChild, left, right, value);
		ST[index]->_val = ST[leftChild]->_val + ST[rightChild]->_val;
	}
};
// Randomized self-test: applies the same random range-add / range-sum rounds
// to the brute-force reference (A) and to SegmentTree, prints both answers
// for each round, and reports how many disagreed. Returns a failure exit
// status if any round disagreed.
int main()
{
	// Seed once. Re-seeding with time() inside the same second restarts the
	// sequence and makes later test cases repeat earlier ones.
	srand(static_cast<unsigned>(time(nullptr)));
	const int size = 100;  // values and increments are drawn from [0, size)
	const int test = 100;  // number of modify+query rounds
	const int num = 120;   // array length; not a power of two, so padding is exercised
	vector<int> a(num);
	vector<pair<int, int> > qtest(test);
	vector<pair<int, int> > mtest(test);
	for (int i = 0; i < num; ++i)
		a[i] = rand() % size;
	for (int i = 0; i < test; ++i)
	{
		int ql = rand() % num;
		int qr = rand() % num;
		if (ql > qr) swap(ql, qr);
		qtest[i] = { ql, qr };
		int ml = rand() % num;
		int mr = rand() % num;
		if (ml > mr) swap(ml, mr);
		mtest[i] = { ml, mr };
	}
	A t1(a);
	SegmentTree t2(a);
	int _errnum = 0;
	for (int i = 0; i < test; ++i)
	{
		const int aug = rand() % size;
		const int ql = qtest[i].first;
		const int qr = qtest[i].second;
		const int ml = mtest[i].first;
		const int mr = mtest[i].second;
		t1.modify(ml, mr, aug);
		t2.augment(ml, mr, aug);
		const int ans1 = t1.query(ql, qr);
		const int ans2 = t2.query(ql, qr);
		cout << "ans1: " << ans1 << "\t" << "ans2: " << ans2 << '\n';
		if (ans1 != ans2) _errnum++;
	}
	cout << "erronum: " << _errnum << endl;
#ifdef _WIN32
	// Keeps the console window open on Windows. The "pause" command does not
	// exist on other platforms.
	system("pause");
#endif
	return _errnum == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}