数据结构笔记(5):线段树和树状数组模板

树状数组模板

/// Fenwick (binary indexed) tree over positions 1..n-1 supporting
/// "add 1 at a position" and "prefix sum" queries, each in O(log n).
/// Slot 0 of `data` is never used: 0 has no lowest set bit, so the
/// `x += x & -x` / `x -= x & -x` walks cannot start from it.
struct fenwick {
    int* data;  // owned array of n counters, zero-initialized
    int n;      // array length; usable positions are 1..n-1
    /// Allocates `_n` zeroed slots (usable positions 1.._n-1).
    fenwick(int _n) : n(_n) {
      data = new int[n]();  // value-initialization zeroes every slot
    }
    ~fenwick() {
      delete[] data;
    }
    // The struct owns a raw buffer; a member-wise copy would double-delete it.
    fenwick(const fenwick&) = delete;
    fenwick& operator=(const fenwick&) = delete;

    /// Adds 1 at position x. Positions <= 0 are ignored (x == 0 would
    /// otherwise loop forever, since 0 & -0 == 0); positions >= n are no-ops.
    void add(int x) {
      if (x <= 0) return;
      while (x < n) {
        data[x]++;
        x += (x & -x);
      }
    }

    /// Returns the sum over positions 1..x. x <= 0 yields 0; x >= n is
    /// clamped to the last position n-1 instead of reading past the buffer.
    int get(int x) {
      if (x >= n) x = n - 1;
      int ans = 0;
      while (x > 0) {
        ans += data[x];
        x -= (x & -x);
      }
      return ans;
    }
  };

例题:LeetCode315. 计算右侧小于当前元素的个数

给定一个整数数组 nums,按要求返回一个新数组 counts。数组 counts 具有以下性质:counts[i] 的值是 nums[i] 右侧小于 nums[i] 的元素的数量。

 

示例:

输入:nums = [5,2,6,1]
输出:[2,1,1,0] 
解释:
5 的右侧有 2 个更小的元素 (2 和 1)
2 的右侧仅有 1 个更小的元素 (1)
6 的右侧有 1 个更小的元素 (1)
1 的右侧有 0 个更小的元素
 

提示:

0 <= nums.length <= 10^5
-10^4 <= nums[i] <= 10^4

题解:树状数组+离散化

class Solution {
public:
  /// Fenwick tree over positions 1..n-1 (slot 0 unused) supporting
  /// "add 1 at a position" and "prefix sum" queries in O(log n).
  struct fenwick {
    int* data;  // owned array of n counters, zero-initialized
    int n;      // array length; usable positions are 1..n-1
    fenwick(int _n) : n(_n) {
      data = new int[n]();  // value-initialization zeroes every slot
    }
    ~fenwick() {
      delete[] data;
    }
    // Owns a raw buffer; a member-wise copy would double-delete it.
    fenwick(const fenwick&) = delete;
    fenwick& operator=(const fenwick&) = delete;

    /// Adds 1 at position x; x <= 0 is ignored (0 & -0 == 0 would loop forever).
    void add(int x) {
      if (x <= 0) return;
      while (x < n) {
        data[x]++;
        x += (x & -x);
      }
    }

    /// Sum over positions 1..x; x >= n is clamped to n-1.
    int get(int x) {
      if (x >= n) x = n - 1;
      int ans = 0;
      while (x > 0) {
        ans += data[x];
        x -= (x & -x);
      }
      return ans;
    }
  };

  /// For each i, returns how many elements to the right of nums[i] are
  /// strictly smaller. Values are compressed to ranks 1..sz, then nums is
  /// scanned right to left: get(rank - 1) counts already-inserted (i.e.
  /// right-side) values with a smaller rank. O(n log n). `nums` is not modified.
  vector<int> countSmaller(vector<int>& nums) {
    const int sz = (int)nums.size();
    vector<int> sorted(nums);
    sort(sorted.begin(), sorted.end());
    // Rank = 1 + index of the first occurrence in sorted order, so equal
    // values share a rank and are never counted as "smaller".
    vector<int> rank(sz);
    for (int i = 0; i < sz; i++) {
      rank[i] = (int)(lower_bound(sorted.begin(), sorted.end(), nums[i]) - sorted.begin()) + 1;
    }
    vector<int> r(sz, 0);
    fenwick fwk(sz + 1);  // automatic storage: released when the function returns
    for (int i = sz - 1; i >= 0; i--) {
      r[i] = fwk.get(rank[i] - 1);
      fwk.add(rank[i]);
    }
    return r;
  }
};

线段树单点更新、区间查询模板

template<class T>
class SegmentTree {
public:
  /// Builds a tree over a copy of arr[0..n-1]; every internal node holds
  /// fun(left, right). `fun` is assumed associative (e.g. +, min, max).
  /// n == 0 yields an empty tree on which every index is out of range.
  SegmentTree(T* arr, int n, function<T(T, T)> fun) {
    this->len = n;
    this->func = fun;
    data = new T[n];
    tree = new T[4 * n];  // 4n nodes suffice for this heap-style layout
    for (int i = 0; i < n; i++) {
      data[i] = arr[i];
    }
    // buildSegment(0, 0, -1) would recurse without end, so skip it when empty.
    if (n > 0) {
      buildSegment(0, 0, n - 1);
    }
  }
  // The tree owns raw arrays; a member-wise copy would double-delete them.
  SegmentTree(const SegmentTree&) = delete;
  SegmentTree& operator=(const SegmentTree&) = delete;

  /// Returns the element at `index`; throws std::exception if out of range.
  T get(int index){
    if (index < 0 || index >= this->len) {
      throw exception();
    }
    return data[index];
  }
  /// Number of stored elements.
  int getSize(){
    return this->len;
  }
  /// Folds func over data[queryL..queryR] (inclusive). Throws std::exception
  /// for an empty or out-of-bounds range, which the recursion cannot handle.
  T query(int queryL, int queryR) {
    if (queryL < 0 || queryR >= len || queryL > queryR) {
      throw exception();
    }
    return queryImpl(0, 0, len - 1, queryL, queryR);
  }
  /// Sets data[index] = e and recombines every ancestor node.
  /// Throws std::exception if index is out of range.
  void modify(int index, T e) {
    if (index < 0 || index >= len) {
      throw exception();
    }
    data[index] = e;
    modifyImpl(0, 0, len - 1, index, e);
  }
  virtual ~SegmentTree() {
    delete[] data;
    delete[] tree;
  }
private:
  // Descends from node treeIndex (covering [l, r]) to the leaf for `index`,
  // stores e there, and recombines each ancestor on the way back up.
  void modifyImpl(int treeIndex, int l, int r, int index, T e) {
    if (l == r) {
      tree[treeIndex] = e;
      return;
    }
    int mid = l + (r - l) / 2;
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    if (index <= mid) {
      modifyImpl(leftTreeIndex, l, mid, index, e);
    } else {
      modifyImpl(rightTreeIndex, mid + 1, r, index, e);
    }
    tree[treeIndex] = func(tree[leftTreeIndex], tree[rightTreeIndex]);
  }
  // Folds func over [queryL, queryR], which must lie inside node's [l, r].
  T queryImpl(int treeIndex, int l, int r, int queryL, int queryR) {
    if (l == queryL && r == queryR) {
      return tree[treeIndex];
    }
    int mid = l + (r - l) / 2;
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    if (queryL >= mid + 1) {
      return queryImpl(rightTreeIndex, mid + 1, r, queryL, queryR);
    } else if (queryR <= mid) {
      return queryImpl(leftTreeIndex, l, mid, queryL, queryR);
    } else {
      // The range straddles mid: split it and combine both halves.
      T leftResult = queryImpl(leftTreeIndex, l, mid, queryL, mid);
      T rightResult = queryImpl(rightTreeIndex, mid + 1, r, mid + 1, queryR);
      return func(leftResult, rightResult);
    }
  }
  int leftChild(int index) {
    return 2 * index + 1;
  }
  int rightChild(int index) {
    return 2 * index + 2;
  }
  // Fills node treeIndex (covering [l, r], l <= r) and its subtree from data.
  void buildSegment(int treeIndex, int l, int r) {
    if (l == r) {
      tree[treeIndex] = data[l];
      return;
    }
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    int mid = l + (r - l) / 2;
    buildSegment(leftTreeIndex, l, mid);
    buildSegment(rightTreeIndex, mid + 1, r);
    tree[treeIndex] = func(tree[leftTreeIndex], tree[rightTreeIndex]);
  }
private:
  int len;                  // number of elements
  T* data;                  // owned copy of the elements
  T* tree;                  // owned node array, root at index 0
  function<T(T, T)> func;   // combine operation
};

例题:307. 区域和检索 - 数组可修改

给定一个整数数组 nums,求出数组从索引 i 到 j(i ≤ j)范围内元素的总和,包含 i 和 j 两点。

update(i, val) 函数可以通过将下标为 i 的数值更新为 val,从而对数列进行修改。

示例:

Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8
说明:

数组仅可以在 update 函数下进行修改。
你可以假设 update 函数与 sumRange 函数的调用次数是均匀分布的。

 

题解:线段树的单点更新和区间查询

template<class T>
class SegmentTree {
public:
  /// Builds a tree over a copy of arr[0..n-1]; every internal node holds
  /// fun(left, right). `fun` is assumed associative (e.g. +, min, max).
  /// n == 0 yields an empty tree on which every index is out of range.
  SegmentTree(T* arr, int n, function<T(T, T)> fun) {
    this->len = n;
    this->func = fun;
    data = new T[n];
    tree = new T[4 * n];  // 4n nodes suffice for this heap-style layout
    for (int i = 0; i < n; i++) {
      data[i] = arr[i];
    }
    // buildSegment(0, 0, -1) would recurse without end, so skip it when empty.
    if (n > 0) {
      buildSegment(0, 0, n - 1);
    }
  }
  // The tree owns raw arrays; a member-wise copy would double-delete them.
  SegmentTree(const SegmentTree&) = delete;
  SegmentTree& operator=(const SegmentTree&) = delete;

  /// Returns the element at `index`; throws std::exception if out of range.
  T get(int index){
    if (index < 0 || index >= this->len) {
      throw exception();
    }
    return data[index];
  }
  /// Number of stored elements.
  int getSize(){
    return this->len;
  }
  /// Folds func over data[queryL..queryR] (inclusive). Throws std::exception
  /// for an empty or out-of-bounds range, which the recursion cannot handle.
  T query(int queryL, int queryR) {
    if (queryL < 0 || queryR >= len || queryL > queryR) {
      throw exception();
    }
    return queryImpl(0, 0, len - 1, queryL, queryR);
  }
  /// Sets data[index] = e and recombines every ancestor node.
  /// Throws std::exception if index is out of range.
  void modify(int index, T e) {
    if (index < 0 || index >= len) {
      throw exception();
    }
    data[index] = e;
    modifyImpl(0, 0, len - 1, index, e);
  }
  virtual ~SegmentTree() {
    delete[] data;
    delete[] tree;
  }
private:
  // Descends from node treeIndex (covering [l, r]) to the leaf for `index`,
  // stores e there, and recombines each ancestor on the way back up.
  void modifyImpl(int treeIndex, int l, int r, int index, T e) {
    if (l == r) {
      tree[treeIndex] = e;
      return;
    }
    int mid = l + (r - l) / 2;
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    if (index <= mid) {
      modifyImpl(leftTreeIndex, l, mid, index, e);
    } else {
      modifyImpl(rightTreeIndex, mid + 1, r, index, e);
    }
    tree[treeIndex] = func(tree[leftTreeIndex], tree[rightTreeIndex]);
  }
  // Folds func over [queryL, queryR], which must lie inside node's [l, r].
  T queryImpl(int treeIndex, int l, int r, int queryL, int queryR) {
    if (l == queryL && r == queryR) {
      return tree[treeIndex];
    }
    int mid = l + (r - l) / 2;
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    if (queryL >= mid + 1) {
      return queryImpl(rightTreeIndex, mid + 1, r, queryL, queryR);
    } else if (queryR <= mid) {
      return queryImpl(leftTreeIndex, l, mid, queryL, queryR);
    } else {
      // The range straddles mid: split it and combine both halves.
      T leftResult = queryImpl(leftTreeIndex, l, mid, queryL, mid);
      T rightResult = queryImpl(rightTreeIndex, mid + 1, r, mid + 1, queryR);
      return func(leftResult, rightResult);
    }
  }
  int leftChild(int index) {
    return 2 * index + 1;
  }
  int rightChild(int index) {
    return 2 * index + 2;
  }
  // Fills node treeIndex (covering [l, r], l <= r) and its subtree from data.
  void buildSegment(int treeIndex, int l, int r) {
    if (l == r) {
      tree[treeIndex] = data[l];
      return;
    }
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    int mid = l + (r - l) / 2;
    buildSegment(leftTreeIndex, l, mid);
    buildSegment(rightTreeIndex, mid + 1, r);
    tree[treeIndex] = func(tree[leftTreeIndex], tree[rightTreeIndex]);
  }
private:
  int len;                  // number of elements
  T* data;                  // owned copy of the elements
  T* tree;                  // owned node array, root at index 0
  function<T(T, T)> func;   // combine operation
};
/// LeetCode 307: range sums over an array with point updates, backed by a
/// SegmentTree<int> that combines with addition.
class NumArray {
public:
    /// Builds the tree from nums; an empty nums leaves segTree null.
    NumArray(vector<int>& nums) {
        int n = (int)nums.size();
        if (n > 0) {
            // SegmentTree copies the input, so the vector's buffer is passed
            // directly instead of through a temporary heap array.
            segTree = new SegmentTree<int>(nums.data(), n, [](int x, int y){return x + y;});
        }
    }
    ~NumArray() {
        delete segTree;
    }
    // Owns segTree through a raw pointer; copying would double-delete it.
    NumArray(const NumArray&) = delete;
    NumArray& operator=(const NumArray&) = delete;

    /// Sets element i to val (no-op on an empty array).
    void update(int i, int val) {
        if (segTree == nullptr) {
            return;
        }
        segTree->modify(i, val);
    }

    /// Sum of elements i..j inclusive (0 on an empty array).
    int sumRange(int i, int j) {
        if (segTree == nullptr) {
            return 0;
        }
        return segTree->query(i, j);
    }
private:
    // Null when constructed from an empty array; initialized so the null
    // checks above never read an indeterminate pointer.
    SegmentTree<int>* segTree = nullptr;
};

/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray* obj = new NumArray(nums);
 * obj->update(i,val);
 * int param_2 = obj->sumRange(i,j);
 */

线段树区间更新、区间查询模板

/// Range-add / range-sum segment tree with lazy propagation.
/// Nodes are stored heap-style (children of tId are 2*tId+1 and 2*tId+2);
/// the recursive methods take a node id plus the segment [l, r] it covers,
/// so callers start every operation at (0, 0, n - 1).
/// Naming note: `pull` hands a node's pending add down to its children,
/// and `push` recomputes a node's sum from its children.
class segtree {
public:
  int n;
  int* data;        // copy of the initial values, read by build()
  long long* tree;  // tree[tId]: sum over node tId's segment, pending adds included
  long long* lazy;  // lazy[tId]: add applied to tId but not yet to its children
  segtree(int* arr, int _n) : n (_n) {
    data = new int[_n];
    tree = new long long[_n * 4];
    lazy = new long long[_n * 4];
    fill_n(tree, _n * 4, 0);
    fill_n(lazy, _n * 4, 0);
    for (int i = 0; i < n; i++) {
      data[i] = arr[i];
    }
    // build(0, 0, -1) would recurse without end, so an empty tree stays unbuilt.
    if (n > 0) {
      build(0, 0, n - 1);
    }
  }
  ~segtree() {
    delete []data;
    delete []tree;
    delete []lazy;
  }
  // Owns raw arrays; a member-wise copy would double-delete them.
  segtree(const segtree&) = delete;
  segtree& operator=(const segtree&) = delete;
  /// Fills node tId (segment [l, r], l <= r) and its subtree from data.
  void build(int tId, int l, int r) {
    if (l == r) {
      tree[tId] = data[l];
      lazy[tId] = 0;
      return;
    }
    int mid = (l + r) >> 1;
    build((tId << 1) | 1, l, mid);
    build((tId << 1) + 2, mid + 1, r);
    push(tId);
  }
  /// Adds v to every element of [ml, mr] that lies in node tId's segment [l, r].
  /// An empty (ml > mr) or disjoint range is a no-op.
  void modify(int tId, int l, int r, int ml, int mr, int v) {
    // Rejecting empty/disjoint ranges here guarantees a leaf is always either
    // fully covered or skipped, so pull() is never called on a leaf.
    if (l > r || ml > mr || mr < l || r < ml) {
      return;
    }
    if (ml <= l && r <= mr) {
      // Widen before multiplying: (r - l + 1) * v can exceed the int range.
      tree[tId] += static_cast<long long>(r - l + 1) * v;
      lazy[tId] += v;
      return;
    }
    if (lazy[tId] != 0) {
      pull(tId, l, r);
    }
    int mid = (l + r) >> 1;
    modify((tId * 2) + 1, l, mid, ml, mr, v);
    modify((tId * 2) + 2, mid + 1, r, ml, mr, v);
    push(tId);
  }
  /// Returns the sum of the elements of [gl, gr] that lie in node tId's
  /// segment [l, r]; an empty or disjoint range sums to 0.
  long long get(int tId, int l, int r, int gl, int gr) {
    if (l > r || gl > gr || gr < l || r < gl) {
      return 0;
    }
    if (gl <= l && r <= gr) {
      return tree[tId];
    }
    if (lazy[tId]) {
      pull(tId, l, r);
    }
    int mid = (l + r) >> 1;
    return get((tId * 2) + 1, l, mid, gl, gr) + get((tId * 2) + 2, mid + 1, r, gl, gr);
  }
  /// Moves node tId's pending add onto both children and clears it.
  void pull(int tId, int l, int r) {
    int mid = (l + r) / 2;
    tree[(tId * 2) + 1] += (mid - l + 1) * lazy[tId];
    tree[(tId * 2) + 2] += (r - mid) * lazy[tId];
    lazy[(tId * 2) + 1] += lazy[tId];
    lazy[(tId * 2) + 2] += lazy[tId];
    lazy[tId] = 0;
  }
  /// Recomputes node tId's sum from its two children.
  void push(int tId) {
    tree[tId] = tree[(tId * 2) + 1] + tree[(tId * 2) + 2];
  }
};

例题:poj3468

A Simple Problem with Integers

Description

You have $N$ integers, $A_1, A_2, \ldots, A_N$. You need to deal with two kinds of operations. One type of operation is to add some given number to each number in a given interval. The other is to ask for the sum of numbers in a given interval.

Input

The first line contains two numbers N and Q. $1 \le N, Q \le 100000$.
The second line contains N numbers, the initial values of $A_1, A_2, \ldots, A_N$. $-10^9 \le A_i \le 10^9$.
Each of the next Q lines represents an operation.
"C a b c" means adding c to each of $A_a, A_{a+1}, \ldots, A_b$. $-10000 \le c \le 10000$.
"Q a b" means querying the sum of $A_a, A_{a+1}, \ldots, A_b$.

Output

You need to answer all Q commands in order. One answer in a line.

Sample Input

10 5
1 2 3 4 5 6 7 8 9 10
Q 4 4
Q 1 10
Q 2 4
C 3 6 3
Q 2 4

Sample Output

4
55
9
15

Hint

The sums may exceed the range of 32-bit integers.

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>

using namespace std;

/// Range-add / range-sum segment tree with lazy propagation.
/// Nodes are stored heap-style (children of tId are 2*tId+1 and 2*tId+2);
/// the recursive methods take a node id plus the segment [l, r] it covers,
/// so callers start every operation at (0, 0, n - 1).
/// Naming note: `pull` hands a node's pending add down to its children,
/// and `push` recomputes a node's sum from its children.
class segtree {
public:
  int n;
  int* data;        // copy of the initial values, read by build()
  long long* tree;  // tree[tId]: sum over node tId's segment, pending adds included
  long long* lazy;  // lazy[tId]: add applied to tId but not yet to its children
  segtree(int* arr, int _n) : n (_n) {
    data = new int[_n];
    tree = new long long[_n * 4];
    lazy = new long long[_n * 4];
    fill_n(tree, _n * 4, 0);
    fill_n(lazy, _n * 4, 0);
    for (int i = 0; i < n; i++) {
      data[i] = arr[i];
    }
    // build(0, 0, -1) would recurse without end, so an empty tree stays unbuilt.
    if (n > 0) {
      build(0, 0, n - 1);
    }
  }
  ~segtree() {
    delete []data;
    delete []tree;
    delete []lazy;
  }
  // Owns raw arrays; a member-wise copy would double-delete them.
  segtree(const segtree&) = delete;
  segtree& operator=(const segtree&) = delete;
  /// Fills node tId (segment [l, r], l <= r) and its subtree from data.
  void build(int tId, int l, int r) {
    if (l == r) {
      tree[tId] = data[l];
      lazy[tId] = 0;
      return;
    }
    int mid = (l + r) >> 1;
    build((tId << 1) | 1, l, mid);
    build((tId << 1) + 2, mid + 1, r);
    push(tId);
  }
  /// Adds v to every element of [ml, mr] that lies in node tId's segment [l, r].
  /// An empty (ml > mr) or disjoint range is a no-op.
  void modify(int tId, int l, int r, int ml, int mr, int v) {
    // Rejecting empty/disjoint ranges here guarantees a leaf is always either
    // fully covered or skipped, so pull() is never called on a leaf.
    if (l > r || ml > mr || mr < l || r < ml) {
      return;
    }
    if (ml <= l && r <= mr) {
      // Widen before multiplying: (r - l + 1) * v can exceed the int range.
      tree[tId] += static_cast<long long>(r - l + 1) * v;
      lazy[tId] += v;
      return;
    }
    if (lazy[tId] != 0) {
      pull(tId, l, r);
    }
    int mid = (l + r) >> 1;
    modify((tId * 2) + 1, l, mid, ml, mr, v);
    modify((tId * 2) + 2, mid + 1, r, ml, mr, v);
    push(tId);
  }
  /// Returns the sum of the elements of [gl, gr] that lie in node tId's
  /// segment [l, r]; an empty or disjoint range sums to 0.
  long long get(int tId, int l, int r, int gl, int gr) {
    if (l > r || gl > gr || gr < l || r < gl) {
      return 0;
    }
    if (gl <= l && r <= gr) {
      return tree[tId];
    }
    if (lazy[tId]) {
      pull(tId, l, r);
    }
    int mid = (l + r) >> 1;
    return get((tId * 2) + 1, l, mid, gl, gr) + get((tId * 2) + 2, mid + 1, r, gl, gr);
  }
  /// Moves node tId's pending add onto both children and clears it.
  void pull(int tId, int l, int r) {
    int mid = (l + r) / 2;
    tree[(tId * 2) + 1] += (mid - l + 1) * lazy[tId];
    tree[(tId * 2) + 2] += (r - mid) * lazy[tId];
    lazy[(tId * 2) + 1] += lazy[tId];
    lazy[(tId * 2) + 2] += lazy[tId];
    lazy[tId] = 0;
  }
  /// Recomputes node tId's sum from its two children.
  void push(int tId) {
    tree[tId] = tree[(tId * 2) + 1] + tree[(tId * 2) + 2];
  }
};

/// POJ 3468 driver: reads N, Q and the N initial values, then answers Q
/// commands. "Q a b" prints the sum of A[a..b]; "C a b c" adds c to A[a..b]
/// (indices 1-based, inclusive). Stops early on malformed input.
int main() {
  int n, q;
  if (scanf("%d %d", &n, &q) != 2 || n <= 0) {
    return 0;
  }
  vector<int> arr(n);  // released automatically; segtree keeps its own copy
  for (int i = 0; i < n; i++) {
    if (scanf("%d", &arr[i]) != 1) {
      return 0;
    }
  }
  segtree st(arr.data(), n);
  while (q--) {
    char str[10];
    int l, r;
    // %9s bounds the command token so it cannot overflow str.
    if (scanf("%9s %d %d", str, &l, &r) != 3) {
      break;
    }
    l--;  // input is 1-based; the tree is 0-based
    r--;
    if (str[0] == 'Q') {
      cout << st.get(0, 0, n - 1, l, r) << "\n";
    } else {
      int val;
      if (scanf("%d", &val) != 1) {
        break;
      }
      st.modify(0, 0, n - 1, l, r, val);
    }
  }
  return 0;
}

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值