数据结构——线段树
本文主要内容
- 基础复习
- 线段树动态求区间最值
- 线段树动态求区间最大区间和
1 基础复习
- 主要功能
- 动态实现单点修改和区间查询
- 时间复杂度
单点修改和区间查询的时间复杂度都是o(logn)- 基本思路:
- 建立一个结构体数组,数组中每个元素维护一段区间(包括两个端点及区间性质),每个结构体元素代表线段树的一个节点。
- 线段树的节点维护一个区间,则将区间分平均为两半,该节点子节点维护左子区间,右子节点维护右子区间。
- 线段树的叶子节点l=r=数组元素下标,代表单个数组元素
- 注意事项:
- 下标从1开始
- 区间查询与单点修改函数中的mid与建树函数中的mid有所区别
- 线段树数组要开原数组长度的4倍
1.1 Acwing1264. 动态求连续区间和
模板题~
#include<iostream>
using namespace std;
const int N = 1e5 + 5;
//T表示线段树的节点类型
struct T
{
int l, r;
int sum; //该节点维护的区间性质为区间和
}tr[N * 4];
int n, m, a[N];
//更新函数:用两个子节点更新当前节点区间性质(这里是求和所以相加)
void push_up(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
//递归建树:建立一个根节点为u且维护区间[l, r]的线段树
void build(int u, int l, int r)
{
if(l == r) //u为叶子节点,停止递归
{
tr[u] = {l, r, a[l]};
return;
}
int mid = l + r >> 1;
tr[u] = {l, r};
build(u << 1, l, mid); //递归建立左子树
build(u << 1 | 1, mid + 1, r); //递归建立右子树
push_up(u);
}
//查询函数:查询[l, r]这段区间性质(这里是区间元素之和)
int query(int u, int l, int r)
{
if(l <= tr[u].l && r >= tr[u].r)return tr[u].sum; //l与r包含当前节点维护的区间,则直接返回当前节点维护的值
int sum = 0;
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid)sum += query(u << 1, l, r); //l在左边这段则查询左子树
if(r > mid)sum += query(u << 1 | 1, l, r);//r在右边这段则查询右子树
return sum;
}
//修改函数:(这里是x位置元素加d)
void add(int u, int x, int d)
{
if(tr[u].l == tr[u].r)
{
tr[u].sum += d;
return;
}
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid)add(u << 1, x, d); //x在左边这段则在左子树中更新x节点
else add(u << 1 | 1, x, d); //x在右边这段则在右子树中更新x节点
push_up(u); //更新当前节点
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++)
cin >> a[i];
build(1, 1, n);
while(m--)
{
int k, l, r;
cin >> k >> l >> r;
if(k == 0)cout << query(1, l, r) << endl;
else add(1, l, r);
}
return 0;
}
2 线段树动态求区间最值
2.1 Acwing1275. 最大数
- 思路:
- 树状数组维护的区间性质为区间最大值
- 开始时创建一棵区间长度为m的元素值全为0(最小值)的线段树,插入一个元素即修改相应树状数组相应位置为该元素,利用一个变量n来维护线段树实际区间长度
#include<iostream>
#include<cstring>
using namespace std;
typedef long long LL;
const int N = 2e5 + 5;
struct
{
int l, r;
int maxn;
}tr[N * 4];
int n, m, p, w[N];
void push_up(int u)
{
tr[u].maxn = max(tr[u << 1].maxn, tr[u << 1 | 1].maxn);
}
void build(int u, int l, int r)
{
if(l == r)tr[u] = {l, r, w[l]};
else
{
int mid = l + r >> 1;
tr[u] = {l, r};
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
push_up(u);
}
}
int query(int u, int l)
{
if(l <= tr[u].l)return tr[u].maxn;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid)return max(query(u << 1, l), query(u << 1 | 1, l));
else return query(u << 1 | 1, l);
}
}
void add(int u, int x, int d)
{
if(tr[u].l == tr[u].r)tr[u].maxn += d;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid)add(u << 1, x, d);
else add(u << 1 | 1, x, d);
push_up(u);
}
}
int main()
{
LL a = 0;
cin >> m >> p;
build(1, 1, m);
while(m--)
{
char op;
int x;
cin >> op >> x;
if(op == 'A')
{
w[++n] = (LL)(x + a) % p;
add(1, n, w[n]);
}
else if(op == 'Q')
{
a = query(1, n - x + 1);
cout << a << endl;
}
}
return 0;
}
3 线段树动态求区间最大区间和
3.1 Acwing245. 你能回答这些问题吗
- 思路:
- 因为是求最大区间和,所以线段树节点设一个变量max来维护节点区间的最大区间和
- 对于任何一个区间来说,其和最大的区间可能在其左子区间,右子区间,或横跨左右子区间,对于横跨左右子区间的这部分最大区间和等于左子区间后缀区间和最大值加上右子区间的前缀区间和最大值,因此线段树节点需要设一个l_max和一个r_max分别维护本段区间的前缀区间和最大值和后缀区间和最大值
- 当前区间的max可以用左子区间的max、右子区间的max和左子区间的r_max+右子区间的l_max三者中的最大值来更新,而当前区间的l_max可以用左子区间的l_max和左子区间和+右子区间的l_max两者之中的较大值来更新,因此要设一个变量sum维护区间元素之和
- 查询操作不会改变值,不需要push_up,而查询时对于和最大的区间可能横跨两边的这种情况就需要分别查询两边,要用到左右子区间查询结果的max,l_max,r_max,因此查询函数的返回类型为线段树节点类型
#include<iostream>
using namespace std;
typedef long long LL;
const int N = 5e5 + 5;
int a[N], n, m;
struct T
{
int l, r;
LL sum, l_max, r_max, max;
}
tr[N * 4];
void push_up(int i)
{
tr[i].sum = tr[i << 1].sum + tr[i << 1 | 1].sum;
tr[i].l_max = max(tr[i << 1].l_max, tr[i << 1].sum + tr[i << 1 | 1].l_max);
tr[i].r_max = max(tr[i << 1 | 1].r_max, tr[i << 1 | 1].sum + tr[i << 1].r_max);
tr[i].max = max(max(tr[i << 1].max, tr[i << 1 | 1].max), tr[i << 1].r_max + tr[i << 1 | 1].l_max);
}
void build(int u, int l, int r)
{
tr[u].l = l, tr[u].r = r;
if(tr[u].l == tr[u].r)
{
tr[u].l_max = tr[u].r_max = tr[u].max = tr[u].sum = a[l];
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
push_up(u);
}
void change(int u, int x, int d)
{
if(tr[u].l == tr[u].r)
{
tr[u].l_max = tr[u].r_max = tr[u].max = tr[u].sum = d;
return;
}
int mid = tr[u].l + tr[u].r >> 1;
if(x <= mid)change(u << 1, x, d);
else change(u << 1 | 1, x, d);
push_up(u);
}
T query(int u, int l, int r) //返回一个节点,这个节点存了四个信息(包含max),因为对于每个节点要用其两个子节点的信息更新当前节点max
{
if(l <= tr[u].l && r >= tr[u].r)
return tr[u];
int mid = tr[u].l + tr[u].r >> 1;
if(r <= mid)return query(u << 1, l, r);
if(l > mid)return query(u << 1 | 1, l, r);
T left = query(u << 1, l, r), right = query(u << 1 | 1, l, r);
T root;
root.sum = left.sum + right.sum;
root.l_max = max(left.l_max, left.sum + right.l_max);
root.r_max = max(right.r_max, right.sum + left.r_max);
root.max = max(max(left.max, right.max), left.r_max + right.l_max);
return root;
}
int main()
{
cin >> n >> m;
for(int i = 1; i <= n; i++)cin >> a[i];
build(1, 1, n);
while(m--)
{
int k, x, y;
cin >> k >> x >> y;
if(k == 1)
{
if(x > y)swap(x, y);
//cout << "q " << x << ' ' << y << endl;
cout << query(1, x, y).max << endl;
}
else
{
change(1, x, y);
//cout << "c " << x << ' ' << y << endl;
}
}
return 0;
}