树状数组模板
// Fenwick tree (binary indexed tree) that counts point insertions and answers
// prefix-count queries over positions 1..n-1.
// Position 0 is unusable: the lowbit step (x & -x) is 0 at x == 0.
struct fenwick {
    int* data;  // data[1..n-1] hold partial counts; data[0] is never touched
    int n;      // allocated length; valid positions are 1..n-1

    // Allocates a zero-initialised tree with n slots (positions 1..n-1 usable).
    fenwick(int _n) : n(_n) {
        data = new int[n];
        for (int i = 0; i < n; i++) data[i] = 0;
    }
    ~fenwick() {
        delete[] data;
    }
    // The tree owns a raw buffer; a shallow copy would delete[] it twice.
    fenwick(const fenwick&) = delete;
    fenwick& operator=(const fenwick&) = delete;

    // Increments the count at position x (expected 1 <= x < n).
    // Positions <= 0 are ignored: at x == 0 the step (x & -x) is 0 and the
    // loop would never terminate.
    void add(int x) {
        if (x <= 0) return;
        while (x < n) {
            data[x]++;
            x += (x & -x);
        }
    }
    // Returns the total count at positions 1..x; 0 when x <= 0.
    // x beyond the last position is clamped so the walk stays in bounds.
    int get(int x) {
        if (x >= n) x = n - 1;
        int ans = 0;
        while (x > 0) {
            ans += data[x];
            x -= (x & -x);
        }
        return ans;
    }
};
例题:LeetCode315. 计算右侧小于当前元素的个数
给定一个整数数组 nums,按要求返回一个新数组 counts。数组 counts 有该性质: counts[i] 的值是 nums[i] 右侧小于 nums[i] 的元素的数量。
示例:
输入:nums = [5,2,6,1]
输出:[2,1,1,0]
解释:
5 的右侧有 2 个更小的元素 (2 和 1)
2 的右侧仅有 1 个更小的元素 (1)
6 的右侧有 1 个更小的元素 (1)
1 的右侧有 0 个更小的元素
提示:
0 <= nums.length <= 10^5
-10^4 <= nums[i] <= 10^4
题解:树状数组+离散化
class Solution {
public:
    // Fenwick tree counting insertions at positions 1..n-1 (position 0 unusable,
    // since the lowbit step (x & -x) is 0 there).
    struct fenwick {
        int* data;  // data[1..n-1] hold partial counts
        int n;      // allocated length
        fenwick(int _n) : n(_n) {
            data = new int[n];
            for (int i = 0; i < n; i++) data[i] = 0;
        }
        ~fenwick() {
            delete[] data;
        }
        // Owns a raw buffer; copying would delete[] it twice.
        fenwick(const fenwick&) = delete;
        fenwick& operator=(const fenwick&) = delete;
        // Increments the count at position x; positions <= 0 are ignored.
        void add(int x) {
            if (x <= 0) return;
            while (x < n) {
                data[x]++;
                x += (x & -x);
            }
        }
        // Total count at positions 1..x (0 when x <= 0; x is clamped to n-1).
        int get(int x) {
            if (x >= n) x = n - 1;
            int ans = 0;
            while (x > 0) {
                ans += data[x];
                x -= (x & -x);
            }
            return ans;
        }
    };

    // LeetCode 315: counts[i] = number of elements right of nums[i] that are
    // strictly smaller than nums[i].
    //
    // Values are coordinate-compressed to ranks 1..sz, then nums is scanned
    // right to left. Before rank[i] is inserted, the prefix count of ranks
    // < rank[i] equals the number of smaller elements already seen to its right.
    // O(n log n) time, O(n) extra space. `nums` is not modified.
    vector<int> countSmaller(vector<int>& nums) {
        int sz = (int)nums.size();
        vector<int> vals(nums);
        sort(vals.begin(), vals.end());
        // lower_bound maps equal values to the same rank, so
        // "rank < ranks[i]" is exactly "value < nums[i]".
        vector<int> ranks(sz);
        for (int i = 0; i < sz; i++) {
            ranks[i] = (int)(lower_bound(vals.begin(), vals.end(), nums[i]) - vals.begin()) + 1;
        }
        vector<int> r(sz, 0);
        fenwick fwk(sz + 1);  // stack-owned, released automatically on return
        for (int i = sz - 1; i >= 0; i--) {
            r[i] = fwk.get(ranks[i] - 1);
            fwk.add(ranks[i]);
        }
        return r;
    }
};
线段树单点更新、区间查询 模板
// Generic segment tree supporting point assignment and range queries under a
// binary operation `func`.
//
// `func` must be associative (e.g. sum, min, max); query(l, r) returns
// data[l] op data[l+1] op ... op data[r].
// Nodes are stored heap-style in `tree` (root 0, children 2i+1 / 2i+2);
// 4*n slots are always enough for a tree over n leaves.
template<class T>
class SegmentTree {
public:
    // Copies arr[0..n-1] and builds the tree. A non-positive n yields an
    // empty tree on which get/query/modify throw.
    SegmentTree(T* arr, int n, function<T(T, T)> fun) {
        this->len = n > 0 ? n : 0;
        this->func = fun;
        data = new T[len];
        tree = new T[4 * len];
        for (int i = 0; i < len; i++) {
            data[i] = arr[i];
        }
        // buildSegment(0, 0, -1) would write past the end of empty buffers.
        if (len > 0) {
            buildSegment(0, 0, len - 1);
        }
    }

    // The tree owns raw buffers; a shallow copy would delete[] them twice.
    SegmentTree(const SegmentTree&) = delete;
    SegmentTree& operator=(const SegmentTree&) = delete;

    // Returns the current value at `index`; throws std::exception if out of range.
    T get(int index) const {
        if (index < 0 || index >= this->len) {
            throw exception();
        }
        return data[index];
    }

    // Number of elements covered by the tree.
    int getSize() const {
        return this->len;
    }

    // Folds data[queryL..queryR] (inclusive) with `func`.
    // Throws std::exception unless 0 <= queryL <= queryR < getSize(); an invalid
    // range would otherwise recurse past the leaves.
    T query(int queryL, int queryR) {
        if (queryL < 0 || queryR >= len || queryL > queryR) {
            throw exception();
        }
        return queryRange(0, 0, len - 1, queryL, queryR);
    }

    // Sets data[index] = e and refreshes every ancestor node.
    // Throws std::exception if `index` is out of range.
    void modify(int index, T e) {
        if (index < 0 || index >= len) {
            throw exception();
        }
        data[index] = e;
        modifyPoint(0, 0, len - 1, index, e);
    }

    virtual ~SegmentTree() {
        delete[] data;
        delete[] tree;
    }

private:
    // Writes e into the leaf for `index` under node treeIndex (covering [l, r]),
    // then recombines each node on the way back up.
    void modifyPoint(int treeIndex, int l, int r, int index, T e) {
        if (l == r) {
            tree[treeIndex] = e;
            return;
        }
        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        if (index <= mid) {
            modifyPoint(leftTreeIndex, l, mid, index, e);
        } else {
            modifyPoint(rightTreeIndex, mid + 1, r, index, e);
        }
        tree[treeIndex] = func(tree[leftTreeIndex], tree[rightTreeIndex]);
    }

    // Folds [queryL, queryR], which must lie inside node treeIndex's range [l, r].
    T queryRange(int treeIndex, int l, int r, int queryL, int queryR) {
        if (l == queryL && r == queryR) {
            return tree[treeIndex];
        }
        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        if (queryL >= mid + 1) {
            return queryRange(rightTreeIndex, mid + 1, r, queryL, queryR);
        } else if (queryR <= mid) {
            return queryRange(leftTreeIndex, l, mid, queryL, queryR);
        } else {
            // Range straddles mid: fold left part then right part (order matters
            // for non-commutative func).
            T leftResult = queryRange(leftTreeIndex, l, mid, queryL, mid);
            T rightResult = queryRange(rightTreeIndex, mid + 1, r, mid + 1, queryR);
            return func(leftResult, rightResult);
        }
    }

    int leftChild(int index) {
        return 2 * index + 1;
    }

    int rightChild(int index) {
        return 2 * index + 2;
    }

    // Fills node treeIndex covering [l, r] (l <= r) from data[l..r].
    void buildSegment(int treeIndex, int l, int r) {
        if (l == r) {
            tree[treeIndex] = data[l];
            return;
        }
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        int mid = l + (r - l) / 2;
        buildSegment(leftTreeIndex, l, mid);
        buildSegment(rightTreeIndex, mid + 1, r);
        tree[treeIndex] = func(tree[leftTreeIndex], tree[rightTreeIndex]);
    }

private:
    int len;                   // number of elements
    T* data;                   // current element values
    T* tree;                   // heap-indexed node values, 4*len slots
    function<T(T, T)> func;    // associative combine operation
};
例题:LeetCode307. 区域和检索 - 数组可修改
给定一个整数数组 nums,求出数组从索引 i 到 j (i ≤ j) 范围内元素的总和,包含 i, j 两点。
update(i, val) 函数可以通过将下标为 i 的数值更新为 val,从而对数列进行修改。
示例:
Given nums = [1, 3, 5]
sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8
说明:
数组仅可以在 update 函数下进行修改。
你可以假设 update 函数与 sumRange 函数的调用次数是均匀分布的。
题解:线段树的单点更新和区间查询
// Generic segment tree supporting point assignment and range queries under a
// binary operation `func`.
//
// `func` must be associative (e.g. sum, min, max); query(l, r) returns
// data[l] op data[l+1] op ... op data[r].
// Nodes are stored heap-style in `tree` (root 0, children 2i+1 / 2i+2);
// 4*n slots are always enough for a tree over n leaves.
template<class T>
class SegmentTree {
public:
    // Copies arr[0..n-1] and builds the tree. A non-positive n yields an
    // empty tree on which get/query/modify throw.
    SegmentTree(T* arr, int n, function<T(T, T)> fun) {
        this->len = n > 0 ? n : 0;
        this->func = fun;
        data = new T[len];
        tree = new T[4 * len];
        for (int i = 0; i < len; i++) {
            data[i] = arr[i];
        }
        // buildSegment(0, 0, -1) would write past the end of empty buffers.
        if (len > 0) {
            buildSegment(0, 0, len - 1);
        }
    }

    // The tree owns raw buffers; a shallow copy would delete[] them twice.
    SegmentTree(const SegmentTree&) = delete;
    SegmentTree& operator=(const SegmentTree&) = delete;

    // Returns the current value at `index`; throws std::exception if out of range.
    T get(int index) const {
        if (index < 0 || index >= this->len) {
            throw exception();
        }
        return data[index];
    }

    // Number of elements covered by the tree.
    int getSize() const {
        return this->len;
    }

    // Folds data[queryL..queryR] (inclusive) with `func`.
    // Throws std::exception unless 0 <= queryL <= queryR < getSize(); an invalid
    // range would otherwise recurse past the leaves.
    T query(int queryL, int queryR) {
        if (queryL < 0 || queryR >= len || queryL > queryR) {
            throw exception();
        }
        return queryRange(0, 0, len - 1, queryL, queryR);
    }

    // Sets data[index] = e and refreshes every ancestor node.
    // Throws std::exception if `index` is out of range.
    void modify(int index, T e) {
        if (index < 0 || index >= len) {
            throw exception();
        }
        data[index] = e;
        modifyPoint(0, 0, len - 1, index, e);
    }

    virtual ~SegmentTree() {
        delete[] data;
        delete[] tree;
    }

private:
    // Writes e into the leaf for `index` under node treeIndex (covering [l, r]),
    // then recombines each node on the way back up.
    void modifyPoint(int treeIndex, int l, int r, int index, T e) {
        if (l == r) {
            tree[treeIndex] = e;
            return;
        }
        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        if (index <= mid) {
            modifyPoint(leftTreeIndex, l, mid, index, e);
        } else {
            modifyPoint(rightTreeIndex, mid + 1, r, index, e);
        }
        tree[treeIndex] = func(tree[leftTreeIndex], tree[rightTreeIndex]);
    }

    // Folds [queryL, queryR], which must lie inside node treeIndex's range [l, r].
    T queryRange(int treeIndex, int l, int r, int queryL, int queryR) {
        if (l == queryL && r == queryR) {
            return tree[treeIndex];
        }
        int mid = l + (r - l) / 2;
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        if (queryL >= mid + 1) {
            return queryRange(rightTreeIndex, mid + 1, r, queryL, queryR);
        } else if (queryR <= mid) {
            return queryRange(leftTreeIndex, l, mid, queryL, queryR);
        } else {
            // Range straddles mid: fold left part then right part (order matters
            // for non-commutative func).
            T leftResult = queryRange(leftTreeIndex, l, mid, queryL, mid);
            T rightResult = queryRange(rightTreeIndex, mid + 1, r, mid + 1, queryR);
            return func(leftResult, rightResult);
        }
    }

    int leftChild(int index) {
        return 2 * index + 1;
    }

    int rightChild(int index) {
        return 2 * index + 2;
    }

    // Fills node treeIndex covering [l, r] (l <= r) from data[l..r].
    void buildSegment(int treeIndex, int l, int r) {
        if (l == r) {
            tree[treeIndex] = data[l];
            return;
        }
        int leftTreeIndex = leftChild(treeIndex);
        int rightTreeIndex = rightChild(treeIndex);
        int mid = l + (r - l) / 2;
        buildSegment(leftTreeIndex, l, mid);
        buildSegment(rightTreeIndex, mid + 1, r);
        tree[treeIndex] = func(tree[leftTreeIndex], tree[rightTreeIndex]);
    }

private:
    int len;                   // number of elements
    T* data;                   // current element values
    T* tree;                   // heap-indexed node values, 4*len slots
    function<T(T, T)> func;    // associative combine operation
};
// LeetCode 307: range-sum queries with point updates, backed by a sum
// SegmentTree. An empty input yields an object whose sumRange always returns 0
// and whose update is a no-op.
class NumArray {
public:
    NumArray(vector<int>& nums) {
        int n = (int)nums.size();
        if (n > 0) {
            // SegmentTree copies the input, so the vector's own buffer can be
            // passed directly; no temporary array is needed.
            segTree = new SegmentTree<int>(nums.data(), n, [](int x, int y) { return x + y; });
        }
    }
    ~NumArray() {
        delete segTree;
    }
    // Owns segTree through a raw pointer; copying would delete it twice.
    NumArray(const NumArray&) = delete;
    NumArray& operator=(const NumArray&) = delete;

    // Sets nums[i] = val.
    void update(int i, int val) {
        if (segTree == nullptr) {
            return;
        }
        segTree->modify(i, val);
    }
    // Returns nums[i] + ... + nums[j] (inclusive).
    int sumRange(int i, int j) {
        if (segTree == nullptr) {
            return 0;
        }
        return segTree->query(i, j);
    }
private:
    // Null when constructed from an empty array; must be initialised so the
    // null checks above never read an indeterminate pointer.
    SegmentTree<int>* segTree = nullptr;
};
/**
* Your NumArray object will be instantiated and called as such:
* NumArray* obj = new NumArray(nums);
* obj->update(i,val);
* int param_2 = obj->sumRange(i,j);
*/
线段树区间更新区间查询模板
// Range-add / range-sum segment tree with lazy propagation.
//
// Inputs are int, but sums and pending adds are kept in long long because range
// sums can exceed 32 bits. Nodes are heap-indexed (root 0, children 2*tId+1 and
// 2*tId+2), so 4*n slots suffice.
//
// The recursive entry points are called from the root with 0-based inclusive bounds:
//   modify(0, 0, n - 1, ml, mr, v)  -- add v to every element of [ml, mr]
//   get(0, 0, n - 1, gl, gr)        -- sum of [gl, gr]
// Naming note: pull() moves a node's pending add DOWN to its children, and
// push() recomputes a node's sum UP from its children.
class segtree {
public:
    int n;
    int* data;        // copy of the initial values, read by build()
    long long* tree;  // tree[tId]: sum over the node's range, its own pending add included
    long long* lazy;  // lazy[tId]: per-element add not yet applied to the children
    segtree(int* arr, int _n) : n (_n) {
        data = new int[_n];
        tree = new long long[_n * 4];
        lazy = new long long[_n * 4];
        fill_n(tree, _n * 4, 0);
        fill_n(lazy, _n * 4, 0);
        for (int i = 0; i < n; i++) {
            data[i] = arr[i];
        }
        // With n == 0, build(0, 0, -1) would recurse forever (mid == -1).
        if (n > 0) {
            build(0, 0, n - 1);
        }
    }
    ~segtree() {
        delete []data;
        delete []tree;
        delete []lazy;
    }
    // Owns raw buffers; a shallow copy would delete[] them twice.
    segtree(const segtree&) = delete;
    segtree& operator=(const segtree&) = delete;

    // Builds node tId covering [l, r] from data[l..r].
    void build(int tId, int l, int r) {
        if (l == r) {
            tree[tId] = data[l];
            lazy[tId] = 0;
            return;
        }
        int mid = (l + r) >> 1;
        build((tId << 1) | 1, l, mid);
        build((tId << 1) + 2, mid + 1, r);
        push(tId);
    }
    // Adds v to every element of [ml, mr] within node tId's range [l, r].
    void modify(int tId, int l, int r, int ml, int mr, int v) {
        // Empty or disjoint update range: nothing to do. This also stops the
        // unbounded recursion an inverted/disjoint range would otherwise cause.
        if (ml > mr || mr < l || r < ml) {
            return;
        }
        if (ml <= l && r <= mr) {
            // Widen before multiplying: (r - l + 1) * v can overflow int.
            tree[tId] += static_cast<long long>(r - l + 1) * v;
            lazy[tId] += v;
            return;
        }
        if (lazy[tId] != 0) {
            pull(tId, l, r);
        }
        int mid = (l + r) >> 1;
        if (mr <= mid) {
            modify((tId * 2) + 1, l, mid, ml, mr, v);
        } else if (ml > mid) {
            modify((tId * 2) + 2, mid + 1, r, ml, mr, v);
        } else {
            modify((tId * 2) + 1, l, mid, ml, mid, v);
            modify((tId * 2) + 2, mid + 1, r, mid + 1, mr, v);
        }
        push(tId);
    }
    // Returns the sum over [gl, gr] within node tId's range [l, r];
    // 0 for an empty or disjoint range.
    long long get(int tId, int l, int r, int gl, int gr) {
        if (gl > gr || gr < l || r < gl) {
            return 0;
        }
        if (gl <= l && r <= gr) {
            return tree[tId];
        }
        if (lazy[tId]) {
            pull(tId, l, r);
        }
        int mid = (l + r) >> 1;
        if (gr <= mid) {
            return get((tId * 2) + 1, l, mid, gl, gr);
        } else if (gl > mid) {
            return get((tId * 2) + 2, mid + 1, r, gl, gr);
        } else {
            return get((tId * 2) + 1, l, mid, gl, mid) + get((tId * 2) + 2, mid + 1, r, mid + 1, gr);
        }
    }
    // Pushes node tId's pending add (node covers [l, r]) down into both children.
    void pull(int tId, int l, int r) {
        int mid = (l + r) / 2;
        tree[(tId * 2) + 1] += (mid - l + 1) * lazy[tId];
        tree[(tId * 2) + 2] += (r - mid) * lazy[tId];
        lazy[(tId * 2) + 1] += lazy[tId];
        lazy[(tId * 2) + 2] += lazy[tId];
        lazy[tId] = 0;
    }
    // Recomputes node tId's sum from its two children.
    void push(int tId) {
        tree[tId] = tree[(tId * 2) + 1] + tree[(tId * 2) + 2];
    }
};
例题:poj3468
A Simple Problem with Integers
Description
You have N integers, A1, A2, ... , AN. You need to deal with two kinds of operations. One type of operation is to add some given number to each number in a given interval. The other is to ask for the sum of numbers in a given interval.
Input
The first line contains two numbers N and Q. 1 ≤ N,Q ≤ 100000.
The second line contains N numbers, the initial values of A1, A2, ... , AN. -1000000000 ≤ Ai ≤ 1000000000.
Each of the next Q lines represents an operation.
"C a b c" means adding c to each of Aa, Aa+1, ... , Ab. -10000 ≤ c ≤ 10000.
"Q a b" means querying the sum of Aa, Aa+1, ... , Ab.
Output
You need to answer all Q commands in order. One answer in a line.
Sample Input
10 5
1 2 3 4 5 6 7 8 9 10
Q 4 4
Q 1 10
Q 2 4
C 3 6 3
Q 2 4
Sample Output
4
55
9
15
Hint
The sums may exceed the range of 32-bit integers.
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
// Range-add / range-sum segment tree with lazy propagation.
//
// Inputs are int, but sums and pending adds are kept in long long because range
// sums can exceed 32 bits. Nodes are heap-indexed (root 0, children 2*tId+1 and
// 2*tId+2), so 4*n slots suffice.
//
// The recursive entry points are called from the root with 0-based inclusive bounds:
//   modify(0, 0, n - 1, ml, mr, v)  -- add v to every element of [ml, mr]
//   get(0, 0, n - 1, gl, gr)        -- sum of [gl, gr]
// Naming note: pull() moves a node's pending add DOWN to its children, and
// push() recomputes a node's sum UP from its children.
class segtree {
public:
    int n;
    int* data;        // copy of the initial values, read by build()
    long long* tree;  // tree[tId]: sum over the node's range, its own pending add included
    long long* lazy;  // lazy[tId]: per-element add not yet applied to the children
    segtree(int* arr, int _n) : n (_n) {
        data = new int[_n];
        tree = new long long[_n * 4];
        lazy = new long long[_n * 4];
        fill_n(tree, _n * 4, 0);
        fill_n(lazy, _n * 4, 0);
        for (int i = 0; i < n; i++) {
            data[i] = arr[i];
        }
        // With n == 0, build(0, 0, -1) would recurse forever (mid == -1).
        if (n > 0) {
            build(0, 0, n - 1);
        }
    }
    ~segtree() {
        delete []data;
        delete []tree;
        delete []lazy;
    }
    // Owns raw buffers; a shallow copy would delete[] them twice.
    segtree(const segtree&) = delete;
    segtree& operator=(const segtree&) = delete;

    // Builds node tId covering [l, r] from data[l..r].
    void build(int tId, int l, int r) {
        if (l == r) {
            tree[tId] = data[l];
            lazy[tId] = 0;
            return;
        }
        int mid = (l + r) >> 1;
        build((tId << 1) | 1, l, mid);
        build((tId << 1) + 2, mid + 1, r);
        push(tId);
    }
    // Adds v to every element of [ml, mr] within node tId's range [l, r].
    void modify(int tId, int l, int r, int ml, int mr, int v) {
        // Empty or disjoint update range: nothing to do. This also stops the
        // unbounded recursion an inverted/disjoint range would otherwise cause.
        if (ml > mr || mr < l || r < ml) {
            return;
        }
        if (ml <= l && r <= mr) {
            // Widen before multiplying: (r - l + 1) * v can overflow int.
            tree[tId] += static_cast<long long>(r - l + 1) * v;
            lazy[tId] += v;
            return;
        }
        if (lazy[tId] != 0) {
            pull(tId, l, r);
        }
        int mid = (l + r) >> 1;
        if (mr <= mid) {
            modify((tId * 2) + 1, l, mid, ml, mr, v);
        } else if (ml > mid) {
            modify((tId * 2) + 2, mid + 1, r, ml, mr, v);
        } else {
            modify((tId * 2) + 1, l, mid, ml, mid, v);
            modify((tId * 2) + 2, mid + 1, r, mid + 1, mr, v);
        }
        push(tId);
    }
    // Returns the sum over [gl, gr] within node tId's range [l, r];
    // 0 for an empty or disjoint range.
    long long get(int tId, int l, int r, int gl, int gr) {
        if (gl > gr || gr < l || r < gl) {
            return 0;
        }
        if (gl <= l && r <= gr) {
            return tree[tId];
        }
        if (lazy[tId]) {
            pull(tId, l, r);
        }
        int mid = (l + r) >> 1;
        if (gr <= mid) {
            return get((tId * 2) + 1, l, mid, gl, gr);
        } else if (gl > mid) {
            return get((tId * 2) + 2, mid + 1, r, gl, gr);
        } else {
            return get((tId * 2) + 1, l, mid, gl, mid) + get((tId * 2) + 2, mid + 1, r, mid + 1, gr);
        }
    }
    // Pushes node tId's pending add (node covers [l, r]) down into both children.
    void pull(int tId, int l, int r) {
        int mid = (l + r) / 2;
        tree[(tId * 2) + 1] += (mid - l + 1) * lazy[tId];
        tree[(tId * 2) + 2] += (r - mid) * lazy[tId];
        lazy[(tId * 2) + 1] += lazy[tId];
        lazy[(tId * 2) + 2] += lazy[tId];
        lazy[tId] = 0;
    }
    // Recomputes node tId's sum from its two children.
    void push(int tId) {
        tree[tId] = tree[(tId * 2) + 1] + tree[(tId * 2) + 2];
    }
};
// POJ 3468 driver. Reads N, Q, the N initial values, then Q commands:
//   "Q a b"   -> print the sum of A[a..b]
//   "C a b c" -> add c to every element of A[a..b]
// Input indices are 1-based; segtree uses 0-based inclusive bounds.
// Stops early (without error) if the input ends or is malformed.
int main() {
    int n, q;
    if (scanf("%d %d", &n, &q) != 2 || n <= 0) {
        return 0;
    }
    vector<int> arr(n);  // segtree copies these values; the vector frees itself
    for (int i = 0; i < n; i++) {
        if (scanf("%d", &arr[i]) != 1) {
            return 0;
        }
    }
    segtree st(arr.data(), n);
    while (q--) {
        char str[10];
        int l, r;
        // %9s bounds the token to the buffer (9 chars + terminator); checking the
        // count avoids using uninitialised l/r at end of input.
        if (scanf("%9s %d %d", str, &l, &r) != 3) {
            break;
        }
        l--;
        r--;
        if (str[0] == 'Q') {
            cout << st.get(0, 0, n - 1, l, r) << "\n";
        } else {
            int val;
            if (scanf("%d", &val) != 1) {
                break;
            }
            st.modify(0, 0, n - 1, l, r, val);
        }
    }
    return 0;
}