【C++】线段树学习


今天我看了B站UP主正月点灯笼的 【数据结构】线段树(Segment Tree),决定写一篇学习笔记

一代版本(2021)

1、构建一个线段树

void build_tree(int arr[], int tree[], int node,
 int start, int end) {
/*
	arr为存储原数据的数组,tree为线段树数组,node为结点,
	start和end为数据范围,表示原数据总共为第一个到第几个
*/
	if (start == end) {	//递归出口,此时到达叶节点
		tree[node] = arr[start];
	}
	else {	//未到达叶节点,继续递归
		int mid = (start + end) / 2;	
		int left_node  = 2 * node + 1;	//线段树左节点公式
		int right_node = 2 * node + 2;	//线段树右节点公式
		
		//进入下一次递归

		//往当前节点的左边走,mid作为end进入下一次递归
		build_tree(arr, tree, left_node, start, mid);	
		
		//往当前节点的右边走,mid+1作为start进入下一次递归
		build_tree(arr, tree, right_node, mid + 1, end);//
		
		/*
			tree[node]即为当前节点的值,为左值加右值,
			当上面两个build_tree完成递归后才执行,此时
			tree[left_node]和tree[right_node]已完成赋值
		*/
		tree[node] = tree[left_node] + tree[right_node];
	}

}

代码实现

	int arr[] = { 1,3,5,7,9,11 };			//单个数据
	int size = 6;							//数据个数
	int tree[MAXI] = { 0 };					//线段树(初始化为0)

	//建立线段树,从node=0开始建立
	build_tree(arr, tree, 0, 0, size - 1);	

在这里插入图片描述

2、更新线段树的内容

void update_tree(int arr[], int tree[], 
int node, int start, int end, int idx, int val) {
	/*
		arr为存储原数据的数组,tree为线段树数组,node为结点,
		start和end为数据范围,表示原数据总共为第一个到第几个,
		idx为要修改的数据是第几个,val为要修改的值
	*/
	if (start == end) {	//递归出口,更改叶节点的值
		arr[idx]   = val;
		tree[node] = val;

	}
	else {	//未到达叶节点,继续递归
		int mid = (start + end) / 2;
		int left_node  = 2 * node + 1;	//线段树左节点公式
		int right_node = 2 * node + 2;	//线段树右节点公式
		
		//idx不在当前节点的左节点所保存的数据范围内
		if (start <= idx && idx <= mid) {	
			update_tree(arr, tree, left_node, start, mid, idx, val);
		}
		
		//idx在当前节点的左节点所保存的数据范围内
		else {
			update_tree(arr, tree, right_node, mid + 1, end, idx, val);
		}

		//完成上面的递归后对经过的节点所保留的数据继续更改
		tree[node] = tree[left_node] + tree[right_node];
	}

}

代码实现

	/*
		将位于4的数(其实是第五个数(9))改为6,因为原来的数据
		在数组中的排列是从arr[0]开始的
	*/
	update_tree(arr, tree, 0, 0, size - 1, 4, 6);	

代码实现

3、查询线段树的内容

//需要返回int类型的数值,所以不用void
//查询的结果为指定区间内的数的和
int query_tree(int arr[], int tree[], int node, int start, int end, int L, int R) {
	/*
		arr为存储原数据的数组,tree为线段树数组,node为结点,
		start和end为数据范围,表示原数据总共为第一个到第几个,
		L为要查询的数据是从第几个开始,R为要查询的数据到第几个结束
	*/
	
	//查询范围不在数据存储区间内,比如总共就6个数,却要查10到20个数的和
	if (R < start || L > end) {	
		return 0;
	}
	else if (L <= start && end <= R) {
		return tree[node];
	}
	/*
		上面那个else if是对下面这个的优化,
		上面那个可以通过节点保存了一个范围内的数据的总和,不需要查询到
		每一个叶节点才返回;
		下面这个需要查询到叶节点才返回,会进行不必要的运算,浪费性能
		else if (start == end) {
			return tree[node];
		}
	*/
	else {
		int mid = (start + end) / 2;
		int left_node  = 2 * node + 1;	//线段树左节点公式
		int right_node = 2 * node + 2;	//线段树右节点公式
		
		int sum_left  = query_tree(arr, tree, left_node, start, mid, L, R);
		int sum_right = query_tree(arr, tree, right_node, mid + 1, end, L, R);
		
		//返回左右两节点的数据总和
		return sum_left + sum_right;
	}
}

代码实现

//输出2到4区间(5,7,6)的和
cout << query_tree(arr, tree, 0, 0, size - 1, 0, 0);	

原数据arr

综合代码实现

#define _CRT_SECURE_NO_DEPRECATE
#include <iostream>
#include <cstdio>
using namespace std;

#define MAXI 10000

void build_tree(int arr[], int tree[], int node, int start, int end) {

	if (start == end) {
		tree[node] = arr[start];
	}
	else {
		int mid = (start + end) / 2;
		int left_node  = 2 * node + 1;
		int right_node = 2 * node + 2;

		build_tree(arr, tree, left_node, start, mid);
		build_tree(arr, tree, right_node, mid + 1, end);
		tree[node] = tree[left_node] + tree[right_node];
	}

}

void update_tree(int arr[], int tree[], int node, int start, int end, int idx, int val) {
	
	if (start == end) {
		arr[idx]   = val;
		tree[node] = val;

	}
	else {
		int mid = (start + end) / 2;
		int left_node  = 2 * node + 1;
		int right_node = 2 * node + 2;
		if (idx >= start && idx <= mid) {
			update_tree(arr, tree, left_node, start, mid, idx, val);
		}
		else {
			update_tree(arr, tree, right_node, mid + 1, end, idx, val);
		}
		tree[node] = tree[left_node] + tree[right_node];
	}

}

int query_tree(int arr[], int tree[], int node, int start, int end, int L, int R) {

	if (R < start || L > end) {
		return 0;
	}
	else if (L <= start && end <= R) {
		return tree[node];
	}
	else {
		int mid = (start + end) / 2;
		int left_node  = 2 * node + 1;
		int right_node = 2 * node + 2;
		int sum_left  = query_tree(arr, tree, left_node, start, mid, L, R);
		int sum_right = query_tree(arr, tree, right_node, mid + 1, end, L, R);
		return sum_left + sum_right;
	}
}

int main() {
	
	int arr[] = { 1,3,5,7,9,11 };			//单个数据
	int size = 6;							//数据个数
	int tree[MAXI] = { 0 };					//线段树(初始化为0)

	build_tree(arr, tree, 0, 0, size - 1);	//建立线段树

	update_tree(arr, tree, 0, 0, size - 1, 4, 6);	//将位于4的数(其实是第五个数(9))改为6
													
	cout << query_tree(arr, tree, 0, 0, size - 1, 2, 4);	//输出第2个数到第4个数的(5,7,6)的和

	return 0;
}

进行了部分优化,将arr[1]设为第一个数,更加直观

#include <iostream>
#include <cstdio>
using namespace std;

#define MAX 100000	

int arr[MAX];		//原始数据   
int tree[4 * MAX];	//线段树,建立线段树需要四倍空间
//(可不考虑实际需要大小,在作为参数传入函数时会根据需要拓展空间)

//arr为原始数据,tree为线段树,node为线段树的节点,start为数据开始的位置
//end为数据结束的位置(长度),L为查询区间的左值,R为查询区间的右值

void build_tree(int arr[], int tree[], int node, int start, int end) {

	if (start == end) {
		tree[node] = arr[start];
	}
	else {
		int mid = (start + end) / 2;
		int left_node = 2 * node + 1;
		int right_node = 2 * node + 2;

		build_tree(arr, tree, left_node, start, mid);
		build_tree(arr, tree, right_node, mid + 1, end);
		tree[node] = tree[left_node] + tree[right_node];
	}

}

void update_tree(int tree[], int node, int start, int end, int idx, int val) {

	if (start == end) {
		tree[node] = val;
	}
	else {
		int mid = (start + end) / 2;
		int left_node = 2 * node + 1;
		int right_node = 2 * node + 2;
		if (idx >= start && idx <= mid) {
			update_tree(tree, left_node, start, mid, idx, val);
		}
		else {
			update_tree(tree, right_node, mid + 1, end, idx, val);
		}
		tree[node] = tree[left_node] + tree[right_node];
	}

}

int query_tree(int tree[], int node, int start, int end, int L, int R) {

	if (R < start || L > end) {
		return 0;
	}
	else if (L <= start && end <= R) {
		return tree[node];
	}
	else {
		int mid = (start + end) / 2;
		int left_node = 2 * node + 1;
		int right_node = 2 * node + 2;
		int sum_left = query_tree(tree, left_node, start, mid, L, R);
		int sum_right = query_tree(tree, right_node, mid + 1, end, L, R);
		return sum_left + sum_right;
	}

}


int main() {

	memset(arr, 0, sizeof(arr));
	memset(tree, 0, sizeof(tree));

	int size = 1234;	//数据大小

	for (int i = 1; i <= size; i++) {	//将 arr[1] 作为第一个数开始建立原始数据
		arr[i] = 1;
	}

	build_tree(arr, tree, 0, 1, size);	//建立线段树

	update_tree(tree, 0, 1, size, 500, 1000);	//将第500个数的值改为1000

	cout << query_tree(tree, 0, 1, size, 1, 100) << endl;	//输出1到100个数的和
	//100

	cout << query_tree(tree, 0, 1, size, 500, 500) << endl;	//输出第500个数的值
	//1000

	return 0;
}



二代版本(2022)

1、结构体

struct Node
{
	int L, R;	//区间左、右两端
	int sum;	//区间和
} tr[N * 4];	//四倍空间

2、由子节点更新父节点

void pushup(int u)	//可不单独定义,可单独写在其他函数内
{
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
    //u<<1 = u*2 代表左子节点	u<<1|1 = u*2+1 代表右子节点
}

3、创建线段树

void build(int u, int l, int r)
{
	if (l == r)tr[u] = { l,r,w[r] };	//叶子节点
	else
	{
		tr[u] = { l,r };
		int mid = l + r >> 1;
		build(u << 1, l, mid);
		build(u << 1 | 1, mid + 1, r);
		pushup(u);
	}
}

4、查询

int query(int u, int l, int r)
{
	if (tr[u].L >= l && tr[u].R <= r)return tr[u].sum;
	//特别注意是 tr[u].L >= l && tr[u].R <= r
	int mid = tr[u].L + tr[u].R >> 1;
	int sum = 0;
	if (l <= mid)sum += query(u << 1, l, r);
	if (r > mid)sum += query(u << 1 | 1, l, r);
	return sum;
}

5、修改

void modify(int u, int x, int v)
{
	if (tr[u].L == tr[u].R)tr[u].sum += v;	//这里是+=,依情况而定
	else
	{
		int mid = tr[u].L + tr[u].R >> 1;
		if (x <= mid)modify(u << 1, x, v);
		else modify(u << 1 | 1, x, v);
		pushup(u);
	}
}

综合代码实现

原题链接:
AcWing-1264. 动态求连续区间和
题目描述:
给定 n 个数组成的一个数列,规定有两种操作,一是修改某个元素,二是求子数列 [a,b] 的连续和。

输入格式
第一行包含两个整数 n 和 m,分别表示数的个数和操作次数。

第二行包含 n 个整数,表示完整数列。

接下来 m 行,每行包含三个整数 k,a,b (k=0,表示求子数列[a,b]的和;k=1,表示第 a 个数加 b)。

数列从 1 开始计数。

输出格式
输出若干行数字,表示 k=0 时,对应的子数列 [a,b] 的连续和。

数据范围
1≤n≤100000,
1≤m≤100000,
1≤a≤b≤n,
数据保证在任何时候,数列中所有元素之和均在 int 范围内。

输入样例:
10 5
1 2 3 4 5 6 7 8 9 10
1 1 5
0 1 3
0 4 8
1 7 5
0 4 8
输出样例:
11
30
35


AC代码如下:

#define _CRT_SECURE_NO_DEPRECATE
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
using namespace std;

const int N = 100010;
int n, m;
int w[N];

struct Node
{
	int L, R;
	int sum;
} tr[N * 4];

void pushup(int u)
{
	tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void build(int u, int l, int r)
{
	if (l == r)tr[u] = { l,r,w[r] };
	else
	{
		tr[u] = { l,r };
		int mid = l + r >> 1;
		build(u << 1, l, mid);
		build(u << 1 | 1, mid + 1, r);
		pushup(u);
	}
}

int query(int u, int l, int r)
{
	if (tr[u].L >= l && tr[u].R <= r)return tr[u].sum;
	int mid = tr[u].L + tr[u].R >> 1;
	int sum = 0;
	if (l <= mid)sum += query(u << 1, l, r);
	if (r > mid)sum += query(u << 1 | 1, l, r);
	return sum;
}

void modify(int u, int x, int v)
{
	if (tr[u].L == tr[u].R)tr[u].sum += v;
	else
	{
		int mid = tr[u].L + tr[u].R >> 1;
		if (x <= mid)modify(u << 1, x, v);
		else modify(u << 1 | 1, x, v);
		pushup(u);
	}
}

int main()
{
	cin >> n >> m;
	for (int i = 1; i <= n; i++)scanf("%d", &w[i]);
	build(1, 1, n);

	int k, a, b;
	while (m--)
	{
		scanf("%d %d %d", &k, &a, &b);
		if (k == 0)printf("%d\n", query(1, a, b));
		else modify(1, a, b);
	}
	return 0;
}
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值