线段树模板类设计及其应用(C++)

前言

线段树是一种实用的数据结构,在各种算法题里经常用到。但是假如每次使用 都需要重新写一遍线段树代码,不仅费时费力,还容易出错。为了能够轻松使用线段树,本人综合考虑了各类线段树的特点, 用C++的模板功能编写了一套线段树的类模板,并给出了其使用方法和解题案例。希望能给各位读者一点帮助。

线段树简介

线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
对于线段树中的每一个非叶子节点 [ a , b ] [a,b] [a,b] ,它的左儿子表示的区间为 [ a , a + b 2 ] \left[a,\frac{a+b}{2}\right] [a,2a+b] ,右儿子表示的区间为 [ a + b 2 + 1 , b ] \left[\frac{a+b}{2}+1,b\right] [2a+b+1,b]。因此线段树是平衡二叉树,最后的子节点数目为 N N N,即整个线段区间的长度。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为 O ( log ⁡ N ) O(\log N) O(logN)。而未优化的空间复杂度为 2 N 2N 2N,因此有时需要离散化让空间压缩。

线段树原理

源码分享

文件:s_tree.h

#ifndef S_TREE_H
#define S_TREE_H

#include <cstring>
#define ROOT 1  //根节点下标
template <typename Node, typename Label = int>
class STree {
   public:
    STree(int start, int end, Node* arr,
          Node (*merge)(const Node& lchild, const Node& rchild),
          void (*modify)(Node& node, const Label& value) = NULL);
    virtual ~STree();
    void modify(int loc, const Node& new_data);           //单点修改
    void modify(int start, int end, const Label& value);  //区间修改
    const Node query(int start, int end);                 //区间查询
    const Node query(int loc);                            //单点查询
   private:
    Node* tree;    //节点数组
    Label* slack;  //懒标记
    const int length, loc_min, loc_max;
    Node (*f_merge)(const Node& lchild, const Node& rchild);  //区间合并函数
    void (*f_modify)(Node& node, const Label& value);  //区间修改函数

    void build_tree(int ind, int left, int right, Node* arr);
    Node query_in(int ind, int left, int right, int start, int end);
    void modify_in(int ind, int left, int right, int start, int end,
                   const Label& value);
    void down(int ind, int left, int right);  //懒标记下沉,同时更新信息
    void update(int ind);                     //更新信息
};

template <typename Node, typename Label>
STree<Node, Label>::STree(int start, int end, Node* arr,
                          Node (*merge)(const Node& lchild, const Node& rchild),
                          void (*modify)(Node& node, const Label& value))
    : length(4 * (end - start + 1)), loc_min(start), loc_max(end) {
    if (start > end) {
        throw "STree constructor error";
    }
    tree = new Node[length];
    slack = new Label[length];
    f_merge = merge;
    f_modify = modify;
    std::memset(slack, 0, length * sizeof(Label));
    build_tree(ROOT, start, end, arr);
}
template <typename Node, typename Label>
STree<Node, Label>::~STree() {
    delete[] tree;
    delete[] slack;
}
template <typename Node, typename Label>
void STree<Node, Label>::modify(int loc, const Node& new_data) {
    if (loc < loc_min || loc > loc_max) {
        throw "STree modify error";
    }
    int ind = ROOT;
    int left = loc_min;
    int right = loc_max;
    while (left < right) {//向下寻找节点
        if (slack[ind]) {
            down(ind, left, right);
        }
        int mid = (left + right) >> 1;
        if (loc <= mid) {
            right = mid;
            ind = ind << 1;
        } else {
            left = mid + 1;
            ind = ind << 1 | 1;
        }
    }
    tree[ind] = new_data;//修改节点
    int ind_p = ind >> 1;
    while (ind_p >= ROOT) {//向上更新
        Node temp = tree[ind_p];
        update(ind_p);
        if (temp == tree[ind_p]) {
            break;
        }
        ind_p = ind_p >> 1;
    }
}
template <typename Node, typename Label>
void STree<Node, Label>::modify(int start, int end, const Label& value) {
    if (start > end || start < loc_min || end > loc_max) {
        throw "STree modify error";
    }
    modify_in(ROOT, loc_min, loc_max, start, end, value);//递归区间修改
}
template <typename Node, typename Label>
const Node STree<Node, Label>::query(int loc) {
    return query_in(ROOT, loc_min, loc_max, loc, loc);//递归区间查询
}
template <typename Node, typename Label>
const Node STree<Node, Label>::query(int start, int end) {
    if (start > end || start < loc_min || end > loc_max) {
        throw "STree query error";
    }
    return query_in(ROOT, loc_min, loc_max, start, end);//递归区间查询
}

template <typename Node, typename Label>
void STree<Node, Label>::build_tree(int ind, int left, int right, Node* arr) {
    if (left == right) {
        tree[ind] = arr[left];
        return;
    }
    int mid = (left + right) >> 1;
    build_tree(ind << 1, left, mid, arr);
    build_tree(ind << 1 | 1, mid + 1, right, arr);
    update(ind);
}
template <typename Node, typename Label>
Node STree<Node, Label>::query_in(int ind, int left, int right, int start,
                                  int end) {
    if (start <= left && right <= end) {
        return tree[ind];
    }
    if (slack[ind] && left != right){
         down(ind, left, right);
    }
    //判断是否需要两边查询
    int mid = (left + right) >> 1;
    if (mid >= end) {
        return query_in(ind << 1, left, mid, start, end);
    } else if (mid < start) {
        return query_in(ind << 1 | 1, mid + 1, right, start, end);
    } else {
        return f_merge(query_in(ind << 1, left, mid, start, end),
                       query_in(ind << 1 | 1, mid + 1, right, start, end));
    }
}
template <typename Node, typename Label>
void STree<Node, Label>::modify_in(int ind, int left, int right, int start,
                                   int end, const Label& value) {
    if (value == 0 || right < start || left > end) {
        return;
    }
    if (start <= left && right <= end) {  //本区间包含在内
        f_modify(tree[ind], value);
        slack[ind] += value;
        return;
    }
    if (slack[ind] && left != right) down(ind, left, right);  //清除懒标记
    int mid = (left + right) >> 1;
    modify_in(ind << 1, left, mid, start, end, value);
    modify_in(ind << 1 | 1, mid + 1, right, start, end, value);
    update(ind);
    return;
}

template <typename Node, typename Label>
void STree<Node, Label>::down(int ind, int left, int right) {
    //需要调用者负责检查 是否应该调用该函数
    int mid = (left + right) >> 1;
    if (slack[ind << 1] && left != mid) {
        down(ind << 1, left, mid);
    }
    if (slack[ind << 1 | 1] && mid + 1 != right) {
        down(ind << 1 | 1, mid + 1, right);
    }
    f_modify(tree[ind << 1], slack[ind]);      //标记下沉
    f_modify(tree[ind << 1 | 1], slack[ind]);  //标记下沉
    update(ind);
    //懒标记更新
    slack[ind << 1] = slack[ind];
    slack[ind << 1 | 1] = slack[ind];
    slack[ind] = 0;
}

template <typename Node, typename Label>
void STree<Node, Label>::update(int ind) {
    tree[ind] = f_merge(tree[ind << 1], tree[ind << 1 | 1]);
}

#undef ROOT
#endif

代码分析

使用方法

  1. 明确线段树的类型,涉及线段树哪些操作
  2. 明确数据类型 Node 需要维护的信息
  3. 明确区间合并时如何维护 Node 信息,设计区间合并函数
  4. 如果涉及区间修改操作,需要明确 Label 的数据类型,以及区间修改函数
  5. 构造 Node 数组,明确需要构造线段树的区间的范围
  6. 使用 Node 数组指针、区间两端点下标、合并函数指针、修改函数指针 构造线段树对象
  7. 使用成员函数进行 区间查询、单点修改、区间修改、单点查询 等操作

应用案例

1. 区间最大值

1.1题目


给定一个 n 位数组和两种操作: ​操作1:修改数组中某个位置的值 操作2:查询数组中某个区间的最大值

输入
​ 第一行输入两个整数 n , m n,m n,m ( 1 ≤ n ≤ 10000 , 3 ≤ m ≤ 100000 ) (1\le n\le 10000,3\le m\le 100000) (1n100003m100000)​,分别代表数组大小和操作数。
​ 第二行包含 n n n 个整数,代表数组中相应的数字,数字大小不会超过 int 表示范围。

​接下来 m m m 行,每行三个整数 a , b , c ( a ∈ [ 1 , 2 ] ) a,b,c (a\in [1,2]) a,b,c(a[1,2])
a = 1 a=1 a=1 时,代表将数组 b b b 位置上的值修改成 c , ( 1 ≤ b ≤ n ) c,(1\le b\le n) c,(1bn) c is int 。
a = 2 a=2 a=2 时,代表询问 [ b , c ] [b,c] [b,c] 区间内的最大值 ( 1 ≤ b , c ≤ n ) (1\le b,c\le n) (1b,cn),当 b > c b>c b>c 时,输出: − 2147483648 -2147483648 2147483648

输出
​ 对于每个 a = 2 a=2 a=2 的操作,输出查询区间内的最大值。

样例输入

6 5
6 9 4 3 2 1
2 2 5
1 2 3
2 1 6
1 5 12
2 1 6

样例输出

9
6
12


1.2思路

1.3代码

#include "s_tree.h"
#include <iostream>
#include <cstdio>
using namespace std;

int arr[10005];
int func(const int& l, const int& r) { return max(l, r); }
int main()
{
	int n, m;
        cin >> n >> m;
	for (int i = 1; i <= n; i++){
		cin >> arr[i];
	}
    STree<int> stree(1, n, arr, func);
	for (int i = 0; i < m; i++){
		int a, b, c;
		cin >> a >> b >> c;
		if (a == 2){
			//查询
			if (b > c){
				cout << -2147483648 << endl;
			}
			else {
				
				cout << stree.query(b, c) << endl;
			}
		}
		else {
			//修改
			stree.modify(b, c);
			arr[b] = c;
		}
	}
	return 0;
}

2. 区间和值

2.1题目


​ 给定一个 n n n 位数组和两种操作:
操作1:数组中某个区间的所有数字加上一个值
​操作2:查询数组中某个区间的所有数字之和

输入
​ 第一行输入两个整数 n , m ( 1 ≤ n ≤ 10000 , 3 ≤ m ≤ 100000 ) n,m (1\le n\le 10000,3\le m\le 100000) n,m(1n100003m100000)​,分别代表数组大小和操作数。
​ 第二行包含 n n n 个整数,代表数组中相应的数字,数字大小不会超过 int 表示范围。
​ 接下来 m m m行,每行三个或四个整数 a , b , c , d ( a ∈ [ 1 , 2 ] ) a,b,c,d(a\in [1,2]) a,b,c,da[1,2])
a = 1 a=1 a=1​ 时,接下来输入 b , c , d b,c,d b,c,d​,代表将数组 [ b , c ] [b,c] [b,c]​区间内的数字加上 d , ( 1 ≤ b , c ≤ n ) d,(1\le b,c\le n) d(1b,cn) d d d is int​
a = 2 a=2 a=2 时,接下来输入 b , c b,c b,c,输入代表询问 [ b , c ] [b,c] [b,c] 区间内的和值 ( 1 ≤ b , c ≤ n ) (1\le b,c\le n) (1b,cn),当 b > c b>c b>c 时,输出 0 0 0

输出
​ 对于每个 a = 2 a=2 a=2 的操作,输出查询区间内数字的和值,答案不会超过64位整型(long long)的表示范围。

样例输入

6 5
6 9 4 3 2 1
2 2 5
1 2 3 5
2 1 6
1 5 12 3
2 1 6

样例输出

18
35
41


2.2思路

2.3代码

#include "s_tree.h"
#include <iostream>
#include <cstdio>
using namespace std;
struct Node{
	int length;
	long long sum;
};
Node value[10005];
Node sum(const Node& x1, const Node& x2){
	Node ans;
	ans.sum = x1.sum + x2.sum;
	ans.length = x1.length + x2.length;
	return ans;
}
void modify(Node& node,const int& value){
	node.sum += value * node.length;
}
bool operator==(const Node& n1, const Node& n2){
	return (n1.sum == n2.sum && n1.length == n2.length);
}

int main()
{
	int n, m;
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= n; i++){
		scanf("%lld",&value[i].sum);
		value[i].length = 1;
	}
	STree<Node> stree(1, n + 5, value, sum, modify);
	for (int i = 0; i < m; i++){
		int a;
		scanf("%d", &a);
		if (a == 2){
			//查询
			int b, c;
			scanf("%d%d", &b, &c);
			if (b > c){
				printf("0\n");
			}
			else {
				printf("%lld\n", stree.query(b, c).sum);
				
			}
		}
		else {
			//修改
			int b, c;
			int d;
			scanf("%d%d%d", &b, &c, &d);
			if (b > c){
				continue;
			}
			c = min(c, n);
        stree.modify(b, c, d);
		}
	}
	return 0;
}

3. 区间最大子段和

3.1题目


给定长度为 N N N 的数列 A A A,以及 M M M 条指令 ( N ≤ 500000 , M ≤ 100000 ) (N\le 500000, M\le 100000) (N500000,M100000),每条指令可能是以下两种之一:
​ “2 x y”,把 A [ x ] A[x] A[x] 改成 y y y
​ “1 x y”,查询区间 [ x , y ] [x,y] [x,y] 中的最大连续子段和,即 m a x ( ∑ i = l r A [ i ] ) ( x ≤ l ≤ r ≤ y ) max(\sum_{i = l}^r A[i]) _{(x\le l\le r\le y)} max(i=lrA[i])(xlry)
​ 对于每个询问,输出一个整数表示答案。

输入
​ 第一行两个整数 N ( 1 ≤ N ≤ 500000 ) , M ( 1 ≤ M ≤ 100000 ) N(1≤N≤500000),M(1≤M≤100000) N(1N500000),M(1M100000)
​ 第二行 N N N 个整数 A i ( 1 ≤ ∣ A i ∣ ≤ 1000 ) Ai(1≤|Ai|≤1000) Ai(1Ai1000)
​ 接下来 M M M 行每行 3 3 3 个整数 k , x , y , k = 1 k,x,y,k=1 k,x,yk=1 表示查询(此时如果 x > y x>y x>y,请交换 x , y x,y x,y), k = 2 k=2 k=2表示修改

输出
​ 对于每个询问输出一个整数表示答案。

输入样例1

5 3
1 2 -3 4 5
1 2 3
2 2 -1
1 3 2

输出样例1

2
-1


3.2思路

3.3代码

#include "s_tree.h"
#include <iostream>
#include <cstdio>
using namespace std;
struct Node {
	long long sum, in_max, l_max, r_max;
};
Node arr[500005];
Node func(const Node& l, const Node& r) {
	Node ans;
	ans.in_max = max(l.in_max, r.in_max);
	ans.in_max = max(ans.in_max, l.r_max + r.l_max);
	ans.sum = l.sum + r.sum;
	ans.l_max = max(l.l_max, l.sum + r.l_max);
	ans.r_max = max(r.r_max, r.sum + l.r_max);
	return ans; 
}
bool operator==(const Node& x1, const Node& x2){
	return (x1.sum == x2.sum && x1.in_max == x2.in_max && 
			x1.r_max == x2.r_max && x1.l_max == x2.l_max);
}
int main()
{
	int n, m;
   scanf("%d%d",&n, &m);
	for (int i = 1; i <= n; i++){
		cin >> arr[i].sum;
		arr[i].in_max = arr[i].sum;
		arr[i].l_max = arr[i].in_max;
		arr[i].r_max = arr[i].in_max;
	}
    STree<Node> stree(1, n, arr, func);
	for (int i = 0; i < m; i++){
		int a, b, c;
		scanf("%d%d%d",&a, &b, &c);
		if (a == 1){
			//查询
			if (b > c){
				swap(b, c);
			}
			printf("%lld\n", stree.query(b, c).in_max);
		}
		else {
			//修改
			arr[b].sum = c;
			arr[b].in_max = c;
			arr[b].l_max = c;
			arr[b].r_max = c;
			stree.modify(b, arr[b]);
			
		}
	}
	return 0;
}
  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值