线段树的初级操作

线段树的初级操作

简介:

问题背景:

把问题建模成数轴上的问题或者数列的问题。一般是每次对数轴或者数列的一个区间进行相同的处理。

线段树的结构:

一棵平衡的二叉树。
举例说明:
这里写图片描述
区间:处理前闭后开的区间 [ a , b ) [a,b) [a,b)
线段树结点T(a,b):维护原序列中 [ a , b ) [a,b) [a,b)的信息
内部结点:对于结点T(a,b),有 b − a > 1 b-a>1 ba>1,那么T(a,b)的左孩子是T(a,(a+b)/2),右孩子是T((a+b)/2,b)
叶结点: 对于结点T(a,b),有 b − a = 1 b-a=1 ba=1
因此,假设一个序列有n个结点,那么根结点是T(1,n+1),第k个叶子结点是T(k,k+1)

性质:

结点数:小于等于2n,n是序列中元素的个数
深度:线段树去除最后一层后,是满二叉树, h = 1 + log ⁡ 2 ( n − 1 ) h=1+\log_2(n-1) h=1+log2(n1)
线段分解数量级:可以把任意的长度是 L L L的线段分解成不超过 2 log ⁡ 2 L 2\log_2L 2log2L条的子线段。可以让绝大多数查询在 O ( log ⁡ 2 n ) O(\log_2n) O(log2n)内解决
存储空间 O ( n ) O(n) O(n)

实现方式:

这里,以修改元素的值为例。修改某个或某个区间的元素的值,然后查询某个区间的元素的和。

修改单个元素的值

问题1:

长度位n的数列,初始化全是0.现在执行m次操作,每次执行下面两种操作之一:

  1. 数列中某个数加上某个值
  2. 询问给定区间中所有数的和

解析:
朴素算法的复杂度是 O ( m n ) O(mn) O(mn),因为传统的线性查找时间是 O ( n ) O(n) O(n)。引入线段树,查询的复杂度是 O ( log ⁡ 2 n ) O(\log_2n) O(log2n)。在使用修改算法时,最好是人为地添加下标x的范围合法性判断;否则,如果x小于最小下标,delta会累加到第一个元素上,如果大于等于最大下标,delta会累加到最后一个元素上!!!

#include <iostream>
#include <utility>
#include <memory>

struct Node {
    int l, r, val;
    std::shared_ptr<Node> lc, rc;
    Node(int _l = 0, int _r = 0, int _v = 0):
        l(_l), r(_r), val(_v) {}
};

void build(std::shared_ptr<Node>& cur, int l, int r) {
    cur  = std::make_shared<Node>(l, r);
    if (l + 1 < r) {
        int mid = l + (r - l) / 2;
        build(cur->lc, l, mid);
        build(cur->rc, mid, r);
    }
}

int query(const std::shared_ptr<Node>& cur, int l, int r) {
    if (l <= cur->l && cur->r <= r) {
        return cur->val;
    }
    int ans =  0;
    int mid = cur->l + (cur->r - cur->l) / 2;
    if (l < mid) {
        ans += query(cur->lc, l, r);
    }
    if (r > mid) {
        ans += query(cur->rc, l, r);
    }
    return ans;
}

void change(std::shared_ptr<Node>& cur, int x, int delta) {
    if (cur->l + 1 == cur->r) {
        cur->val += delta;
        return;
    }
    int mid = cur->l + (cur->r - cur->l) / 2;
    if (x < mid) {
        change(cur->lc, x, delta);
    }
    if (x >= mid) {
        change(cur->rc, x, delta);
    }
    cur->val = cur->lc->val + cur->rc->val;
}

int main() {
    std::shared_ptr<Node> root;
    build(root, 1, 11);
    for (int i = 1; i <= 10; ++i) {
        change(root, i, 1);
    }
        for(int i = 1; i <= 10; ++i) {
        std::cout << i << "th: " << query(root, i, i + 1) << std::endl;
    }
    return 0;
}

修改整个区间的值

问题2:

长度位n的数列,初始化全是0.现在执行m次操作,每次执行下面两种操作之一:

  1. 数列中某个区间的所有数加上某个值
  2. 询问给定区间中所有数的和

解析:
如果修改的是一个区间的值,假设区间长度是k,使用上面的算法,由于每次查询的时间是 O ( log ⁡ 2 n ) O(\log_2n) O(log2n),所以处理一个区间的复杂度是 O ( k log ⁡ 2 n ) O(k\log_2n) O(klog2n),如果k很大,复杂度甚至会超过朴素的模拟算法,因此引入下面的改进算法。
算法的核心在于不直接计算叶子结点的值,而是每个结点增加一个delta域,用于记录当前结点的延迟修改量。只有当前结点需要继续向下查询或者更改当前结点的子区间时,才把当前结点的延迟修改量传递给子区间,同时当前结点的修改量清零,如果不清零,会导致重复计算!
这种算法,保证了不会有过多的递归下降而浪费时间。只有需要向下时,才会根据父结点累积的增量,计算子结点有关的值,保证了时间复杂度较低,减少不必要的递归过程。

#include <bits/stdc++.h>
using namespace std;

struct Node {
    int l, r, sum, delta;
    struct Node *lc, *rc;
    Node(): l(0), r(0), sum(0), delta(0), lc(nullptr), rc(nullptr) {}
};

void build(Node* &cur, int l, int r) { // 建立算法和单个元素的一样
    cur = new Node;
    cur->l = l;
    cur->r = r;
    if(l + 1 < r) {
        build(cur->lc, l, (l + r) / 2);
        build(cur->rc, (l + r) / 2, r);
    }
}

void update(Node* cur) {   // 更新算法,处理累计状态
    // 向下传递累积和,等效成后计算的,注意是累加
    cur->lc->sum += cur->delta * (cur->lc->r - cur->lc->l);
    cur->rc->sum += cur->delta * (cur->rc->r - cur->rc->l);
    // 孩子的delta状态进行累计,注意是累加
    cur->lc->delta += cur->delta;
    cur->rc->delta += cur->delta;
    // 一定要把父结点的清零
    cur->delta = 0;
}

void change(Node* cur, int l, int r, int delta) {
    if(l <= cur->l && cur->r <= r) {
        cur->sum += delta * (cur->r - cur->l);
        cur->delta += delta;
    } else {
        if(cur->delta != 0) {  // 先检查当前结点是否有孩子的累计状态,有的话向下传递
            update(cur);
        }
        if(l < (cur->l + cur->r) / 2) {
            change(cur->lc, l, r, delta);
        }
        if(r > (cur->l + cur->r) / 2) {  // 注意这里没有等号!!!!
            change(cur->rc, l, r, delta);
        }
        cur->sum = cur->lc->sum + cur->rc->sum;
    }
}

int query(Node* cur, int l, int r) {
    if(l <= cur->l && cur->r <= r) {
        return cur->sum;
    } else {
        if(cur->delta != 0) {  // 检查是否有孩子结点的累计状态
            update(cur);       // 计算之前延迟的累积和
        }
        int ans = 0;
        if(l < (cur->l + cur->r) / 2) {
            ans += query(cur->lc, l, r);
        }
        if(r > (cur->l + cur->r) / 2) {
            ans += query(cur->rc, l, r);
        }
        return ans;
    }
}

int main() {
    Node* root = nullptr;
    build(root, 1, 11);
    for(int i = 1; i <= 10; ++i) {
        change(root, i, i + 3, 1);
    }
    for(int i = 1; i <= 10; ++i) {
        cout << i << "th:" << query(root, i, i + 1) << endl;
    }
    cout << "sum 1~10:" << query(root, 1, 11) << endl;
    return 0;
}

更一般的方法:

对于当前区间[l,r)
if 达到某种边界条件(比如叶子结点或整个区间被完全包含)
	then 对维护或者询问进行相应的处理
else
	将第二类标记传递下去(注意,查询的过程也要处理)
	根据区间的关系,对两个孩子递归地处理
	利用递推关系,根据孩子结点的情况维护第一类信息

根据一般方法改进的问题:

问题3:

长度位n的数列,初始化全是0.现在执行m次操作,每次执行下面两种操作之一:

  1. 数列中某个区间的所有数加上某个值
  2. 数列中某个区间的所有数改成某个值
  3. 询问给定区间中所有数的和
  4. 询问给定区间的最值
#include <bits/stdc++.h>
using namespace std;
const int INF = 10000000;
struct Node {
    int l, r, value, sum, maxm, minm, delta;
    bool tag;
    struct Node *lc, *rc;
    Node(): tag(false), l(0), r(0), maxm(0), minm(0),
        delta(0), value(0), sum(0), lc(nullptr), rc(nullptr) {}
};

void build(Node* &cur, int l, int r) {
    cur = new Node;
    cur->l = l;
    cur->r = r;
    if(l + 1 < r) {
        build(cur->lc, l, (l + r) / 2);
        build(cur->rc, (l + r) / 2, r);
    }
}

// 统一更新
void update(Node* cur) {
    // 更新值和最值
    cur->lc->value = cur->rc->value = cur->value;
    cur->lc->maxm = cur->rc->value = cur->value;
    cur->lc->minm = cur->rc->minm = cur->value;
    cur->lc->tag = cur->rc->tag = true;
    cur->tag = false;
    cur->lc->sum += cur->delta * (cur->lc->r - cur->lc->l);
    cur->rc->sum += cur->delta * (cur->rc->r - cur->rc->r);
    cur->lc->delta += cur->delta;
    cur->rc->delta += cur->delta;
    cur->delta = 0;
}

// 把区间的值改成value
void change_to(Node* cur, int l, int r, int value) {
    if(l <= cur->l && cur->r <= r) {
        cur->value = value;
        cur->maxm = cur->minm = value;
        cur->tag = true;
    } else {
        if(cur->tag) {
            update(cur);
        }
        if(l < (cur->l + cur->r) / 2) {
            change_to(cur->lc, l, r, value);
        }
        if(r > (cur->l + cur->r) / 2) {
            change_to(cur->rc, l, r, value);
        }
        cur->maxm = max(cur->lc->maxm, cur->rc->maxm);
        cur->minm = min(cur->lc->minm, cur->rc->minm);
    }
}

// 更改累积和
void change_sum(Node* cur, int l, int r, int delta) {
    if(l <= cur->l && cur->r <= r) {
        cur->sum += delta * (cur->r - cur->l);
        cur->delta += delta;
    } else {
        if(cur->delta != 0) {
            update(cur);
        }
        if(l < (cur->l + cur->r) / 2) {
            change_sum(cur->lc, l, r, delta);
        }
        if(r > (cur->l + cur->r) / 2) {
            change_sum(cur->rc, l, r, delta);
        }
        cur->sum = cur->lc->sum + cur->rc->sum;
    }
}

// 查询最大值
int query_max(Node* cur, int l, int r) {
    if(l <= cur->l && cur->r <= r) {
        return cur->maxm;
    } else {
        if(cur->tag) {
            update(cur);
        }
        int ml = -INF, mr = -INF;
        if(l < (cur->l + cur->r) / 2) {
            ml = query_max(cur->lc, l, r);
        }
        if(r > (cur->l + cur->r) / 2) {
            mr = query_max(cur->rc, l, r);
        }
        return max(ml, mr);
    }
}

// 查询最小值
int query_min(Node* cur, int l, int r) {
    if(l <= cur->l && cur->r <= r) {
        return cur->minm;
    } else {
        if(cur->tag) {
            update(cur);
        }
        int ml = INF, mr = INF;
        if(l < (cur->l + cur->r) / 2) {
            ml = query_max(cur->lc, l, r);
        }
        if(r > (cur->l + cur->r) / 2) {
            mr = query_max(cur->rc, l, r);
        }
        return min(ml, mr);
    }
}

// 查询和
int query_sum(Node* cur, int l, int r) {
    if(l <= cur->l && cur->r <= r) {
        return cur->sum;
    } else {
        if(cur->delta != 0) {
            update(cur);
        }
        int ans = 0;
        if(l < (cur->l + cur->r) / 2) {
            ans += query_sum(cur->lc, l, r);
        }
        if(r > (cur->l + cur->r) / 2) {
            ans += query_sum(cur->rc, l, r);
        }
        return ans;
    }
}

int main() {
    srand(time(unsigned(0)));
    Node* root = nullptr;
    build(root, 1, 11);
    cout << "rand res:" << endl;
    for(int i = 1; i <= 10; ++i) {
        int t = rand() % 30;
        cout << i << "th:" << t << endl;
        change_to(root, i, i + 1, t);
        change_sum(root, i, i + 1, t);
    }
    cout << "max:" << query_max(root, 1, 11) << endl;
    cout << "min:" << query_min(root, 1, 11) << endl;
    cout << "sum:" << query_sum(root, 1, 11) << endl;
    return 0;
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值