最后附上java建树和c++建树完整代码
1、创建节点
线段树可以对一个区间求和、求最大值和最小值。与暴力循环的 O(n) 相比,单次区间查询和区间修改的复杂度可以降到 O(log n)。建树的方法大致如下,这里用指针链接的节点(链式结构)来建树,首先创建节点:
// One node of a pointer-linked segment tree.
//   L, R        inclusive index bounds of the interval this node covers
//   value       sum of the elements in [L, R]
//   maxvalue    maximum of the elements in [L, R]
//   minvalue    minimum of the elements in [L, R]
//   lazy        increment already applied to this node's aggregates but not
//               yet pushed down to its children
//   left/right  child nodes; nullptr until a child is attached
struct Node {
    int L;
    int R;
    int value;
    int maxvalue;
    int minvalue;
    int lazy = 0;
    // Default to nullptr so a freshly constructed node (and every leaf) never
    // carries an indeterminate pointer.
    Node *left = nullptr;
    Node *right = nullptr;

    // Creates a node covering [l, r] whose sum is `value`. The max/min start
    // at the same number so no field is ever read uninitialised.
    Node(int l, int r, int value)
        : L(l), R(r), value(value), maxvalue(value), minvalue(value) {}
};
2、建树
这里建树所用的数据存放在一个数组(vector)中
// Recursively builds the segment tree for the inclusive index range [l, r] of
// `value` and returns its root, or nullptr when the range is empty (l > r).
//
// `value` is indexed directly with l..r and no bounds check is made, so those
// indices must be valid. Nodes are allocated with `new`; nothing in this
// function frees them, so the caller is responsible for the returned tree.
Node *building(int l, int r, vector<int> &value) {
    // Empty range: nothing to build.
    if (l > r)
        return nullptr;

    // Seed the node with value[l]; for a leaf this is already the final state.
    Node *node = new Node(l, r, value[l]);
    node->maxvalue = value[l];
    node->minvalue = value[l];
    // Set the children explicitly so a leaf never keeps indeterminate pointers.
    node->left = nullptr;
    node->right = nullptr;
    if (l == r)
        return node;

    // Here l < r, hence l <= mid < r: both halves are non-empty and both
    // recursive calls return a real node, so no null checks are needed below.
    int mid = (r - l) / 2 + l;
    node->left = building(l, mid, value);
    node->right = building(mid + 1, r, value);

    // Combine the children into this node's aggregates.
    node->value = node->left->value + node->right->value;
    node->maxvalue = max(node->left->maxvalue, node->right->maxvalue);
    node->minvalue = min(node->left->minvalue, node->right->minvalue);
    return node;
}
3、区间修改值
线段树最大的优化手段就是lazy标签,可以说是灵魂所在;否则对于某些数据,复杂度反而得不到改善
// Adds `number` to every element in the inclusive range [l, r] of the subtree
// rooted at `node`.
//
// [l, r] is expected to lie inside [node->L, node->R]. When it matches the
// node's interval exactly, the node's aggregates are updated in place and the
// increment is recorded in `lazy` instead of descending further. Otherwise any
// pending increment is first pushed to the children, the request is forwarded
// to the half (or halves) it overlaps, and the node is recomputed from its
// children.
void alter(Node *node, int l, int r, int number) {
    if (node == nullptr)
        return;

    // Exact cover: update the aggregates here and defer the children.
    // `+=` on lazy because an earlier increment may still be pending.
    if (node->L == l && node->R == r) {
        node->value += number * (node->R - node->L + 1);
        node->lazy += number;
        node->maxvalue += number;
        node->minvalue += number;
        return;
    }

    // A leaf that is not covered exactly means the requested range does not
    // fit this tree; there are no children to descend into.
    if (node->L == node->R)
        return;

    int mid = (node->R - node->L) / 2 + node->L;

    // Push a pending increment down one level before touching the children.
    // Compare against 0 rather than `> 0` so negative increments are pushed too.
    if (node->lazy != 0) {
        alter(node->left, node->L, mid, node->lazy);
        alter(node->right, mid + 1, node->R, node->lazy);
        node->lazy = 0;
    }

    if (r <= mid) {
        // Entirely in the left half.
        alter(node->left, l, r, number);
    } else if (mid < l) {
        // Entirely in the right half.
        alter(node->right, l, r, number);
    } else {
        // Straddles the midpoint: split the request.
        alter(node->left, l, mid, number);
        alter(node->right, mid + 1, r, number);
    }

    // Recompute from the children. Adding `number` to max/min here would be
    // wrong, because only part of this node's interval was changed.
    Node *lhs = node->left;
    Node *rhs = node->right;
    node->value = lhs->value + rhs->value;
    node->maxvalue = lhs->maxvalue > rhs->maxvalue ? lhs->maxvalue : rhs->maxvalue;
    node->minvalue = lhs->minvalue < rhs->minvalue ? lhs->minvalue : rhs->minvalue;
}
4、取出最大值最小值,区间和
// The query functions just walk the tree recursively, mirroring alter().
//
// Returns the sum of the elements in the inclusive range [l, r] of the subtree
// rooted at `node`. [l, r] is expected to lie inside [node->L, node->R].
// Returns -1 when `node` is null or when a leaf is reached that the range does
// not cover exactly; note that -1 cannot be told apart from a genuine sum of -1.
int getValue(Node *node, int l, int r) {
    if (node == nullptr)
        return -1;
    // Exact cover: the stored sum already includes any pending increment.
    if (node->L == l && node->R == r)
        return node->value;
    // Non-matching leaf: the range does not fit this tree.
    if (node->L == node->R)
        return -1;

    int mid = (node->R - node->L) / 2 + node->L;

    // The children are stale while an increment is pending: push it down one
    // level before reading them. `!= 0` so negative increments are pushed too.
    if (node->lazy != 0) {
        alter(node->left, node->L, mid, node->lazy);
        alter(node->right, mid + 1, node->R, node->lazy);
        node->lazy = 0;
    }

    if (mid < l)
        return getValue(node->right, l, r);
    if (r <= mid)
        return getValue(node->left, l, r);
    // Straddles the midpoint: go straight to the children instead of
    // re-entering this node.
    return getValue(node->left, l, mid) + getValue(node->right, mid + 1, r);
}
// Returns the maximum element in the inclusive range [l, r] of the subtree
// rooted at `node`. [l, r] is expected to lie inside [node->L, node->R].
// Returns -1 when `node` is null or when a leaf is reached that the range does
// not cover exactly; -1 cannot be told apart from a genuine maximum of -1.
int getMaxvalue(Node *node, int l, int r) {
    if (node == nullptr)
        return -1;
    // Exact cover: the stored maximum already includes any pending increment.
    if (node->L == l && node->R == r)
        return node->maxvalue;
    // Non-matching leaf: the range does not fit this tree.
    if (node->L == node->R)
        return -1;

    int mid = (node->R - node->L) / 2 + node->L;

    // Push a pending increment down one level before reading the children.
    // `!= 0` so negative increments are pushed too.
    if (node->lazy != 0) {
        alter(node->left, node->L, mid, node->lazy);
        alter(node->right, mid + 1, node->R, node->lazy);
        node->lazy = 0;
    }

    if (r <= mid)
        return getMaxvalue(node->left, l, r);
    if (mid < l)
        return getMaxvalue(node->right, l, r);

    // Straddles the midpoint: the answer is the larger of the two halves.
    int leftmax = getMaxvalue(node->left, l, mid);
    int rightmax = getMaxvalue(node->right, mid + 1, r);
    return max(leftmax, rightmax);
}
// Returns the minimum element in the inclusive range [l, r] of the subtree
// rooted at `node`. [l, r] is expected to lie inside [node->L, node->R].
// Returns -1 when `node` is null or when a leaf is reached that the range does
// not cover exactly; -1 cannot be told apart from a genuine minimum of -1.
int getMinvalue(Node *node, int l, int r) {
    if (node == nullptr)
        return -1;
    // Exact cover: the stored minimum already includes any pending increment.
    if (node->L == l && node->R == r)
        return node->minvalue;
    // Non-matching leaf: the range does not fit this tree.
    if (node->L == node->R)
        return -1;

    int mid = (node->R - node->L) / 2 + node->L;

    // Push a pending increment down one level before reading the children.
    // `!= 0` so negative increments are pushed too.
    if (node->lazy != 0) {
        alter(node->left, node->L, mid, node->lazy);
        alter(node->right, mid + 1, node->R, node->lazy);
        node->lazy = 0;
    }

    if (r <= mid)
        return getMinvalue(node->left, l, r);
    if (mid < l)
        return getMinvalue(node->right, l, r);

    // Straddles the midpoint: the answer is the smaller of the two halves.
    int leftmin = getMinvalue(node->left, l, mid);
    int rightmin = getMinvalue(node->right, mid + 1, r);
    return min(leftmin, rightmin);
}
5、完整c++代码
#include <bits/stdc++.h>
#define int long long
using namespace std;
// One node of a pointer-linked segment tree.
//   L, R        inclusive index bounds of the interval this node covers
//   value       sum of the elements in [L, R]
//   maxvalue    maximum of the elements in [L, R]
//   minvalue    minimum of the elements in [L, R]
//   lazy        increment already applied to this node's aggregates but not
//               yet pushed down to its children
//   left/right  child nodes; nullptr until a child is attached
// Note: this program defines `int` as `long long` via a macro, so every `int`
// field below is 64-bit when compiled as part of it.
struct Node {
    int L;
    int R;
    int value;
    int maxvalue;
    int minvalue;
    int lazy = 0;
    // Default to nullptr so a freshly constructed node (and every leaf) never
    // carries an indeterminate pointer.
    Node *left = nullptr;
    Node *right = nullptr;

    // Creates a node covering [l, r] whose sum is `value`. The max/min start
    // at the same number so no field is ever read uninitialised.
    Node(int l, int r, int value)
        : L(l), R(r), value(value), maxvalue(value), minvalue(value) {}
};
// Recursively builds the segment tree for the inclusive index range [l, r] of
// `value` and returns its root, or nullptr when the range is empty (l > r).
//
// `value` is indexed directly with l..r and no bounds check is made, so those
// indices must be valid. Nodes are allocated with `new`; nothing in this
// function frees them, so the caller is responsible for the returned tree.
Node *building(int l, int r, vector<int> &value) {
    // Empty range: nothing to build.
    if (l > r)
        return nullptr;

    // Seed the node with value[l]; for a leaf this is already the final state.
    Node *node = new Node(l, r, value[l]);
    node->maxvalue = value[l];
    node->minvalue = value[l];
    // Set the children explicitly so a leaf never keeps indeterminate pointers.
    node->left = nullptr;
    node->right = nullptr;
    if (l == r)
        return node;

    // Here l < r, hence l <= mid < r: both halves are non-empty and both
    // recursive calls return a real node, so no null checks are needed below.
    int mid = (r - l) / 2 + l;
    node->left = building(l, mid, value);
    node->right = building(mid + 1, r, value);

    // Combine the children into this node's aggregates.
    node->value = node->left->value + node->right->value;
    node->maxvalue = max(node->left->maxvalue, node->right->maxvalue);
    node->minvalue = min(node->left->minvalue, node->right->minvalue);
    return node;
}
// Adds `number` to every element in the inclusive range [l, r] of the subtree
// rooted at `node`.
//
// [l, r] is expected to lie inside [node->L, node->R]. When it matches the
// node's interval exactly, the node's aggregates are updated in place and the
// increment is recorded in `lazy` instead of descending further. Otherwise any
// pending increment is first pushed to the children, the request is forwarded
// to the half (or halves) it overlaps, and the node is recomputed from its
// children.
void alter(Node *node, int l, int r, int number) {
    if (node == nullptr)
        return;

    // Exact cover: update the aggregates here and defer the children.
    // `+=` on lazy because an earlier increment may still be pending.
    if (node->L == l && node->R == r) {
        node->value += number * (node->R - node->L + 1);
        node->lazy += number;
        node->maxvalue += number;
        node->minvalue += number;
        return;
    }

    // A leaf that is not covered exactly means the requested range does not
    // fit this tree; there are no children to descend into.
    if (node->L == node->R)
        return;

    int mid = (node->R - node->L) / 2 + node->L;

    // Push a pending increment down one level before touching the children.
    // Compare against 0 rather than `> 0` so negative increments are pushed too.
    if (node->lazy != 0) {
        alter(node->left, node->L, mid, node->lazy);
        alter(node->right, mid + 1, node->R, node->lazy);
        node->lazy = 0;
    }

    if (r <= mid) {
        // Entirely in the left half.
        alter(node->left, l, r, number);
    } else if (mid < l) {
        // Entirely in the right half.
        alter(node->right, l, r, number);
    } else {
        // Straddles the midpoint: split the request.
        alter(node->left, l, mid, number);
        alter(node->right, mid + 1, r, number);
    }

    // Recompute from the children. Adding `number` to max/min here would be
    // wrong, because only part of this node's interval was changed.
    Node *lhs = node->left;
    Node *rhs = node->right;
    node->value = lhs->value + rhs->value;
    node->maxvalue = lhs->maxvalue > rhs->maxvalue ? lhs->maxvalue : rhs->maxvalue;
    node->minvalue = lhs->minvalue < rhs->minvalue ? lhs->minvalue : rhs->minvalue;
}
// Returns the sum of the elements in the inclusive range [l, r] of the subtree
// rooted at `node`. [l, r] is expected to lie inside [node->L, node->R].
// Returns -1 when `node` is null or when a leaf is reached that the range does
// not cover exactly; note that -1 cannot be told apart from a genuine sum of -1.
int getValue(Node *node, int l, int r) {
    if (node == nullptr)
        return -1;
    // Exact cover: the stored sum already includes any pending increment.
    if (node->L == l && node->R == r)
        return node->value;
    // Non-matching leaf: the range does not fit this tree.
    if (node->L == node->R)
        return -1;

    int mid = (node->R - node->L) / 2 + node->L;

    // The children are stale while an increment is pending: push it down one
    // level before reading them. `!= 0` so negative increments are pushed too.
    if (node->lazy != 0) {
        alter(node->left, node->L, mid, node->lazy);
        alter(node->right, mid + 1, node->R, node->lazy);
        node->lazy = 0;
    }

    if (mid < l)
        return getValue(node->right, l, r);
    if (r <= mid)
        return getValue(node->left, l, r);
    // Straddles the midpoint: go straight to the children instead of
    // re-entering this node.
    return getValue(node->left, l, mid) + getValue(node->right, mid + 1, r);
}
// Returns the maximum element in the inclusive range [l, r] of the subtree
// rooted at `node`. [l, r] is expected to lie inside [node->L, node->R].
// Returns -1 when `node` is null or when a leaf is reached that the range does
// not cover exactly; -1 cannot be told apart from a genuine maximum of -1.
int getMaxvalue(Node *node, int l, int r) {
    if (node == nullptr)
        return -1;
    // Exact cover: the stored maximum already includes any pending increment.
    if (node->L == l && node->R == r)
        return node->maxvalue;
    // Non-matching leaf: the range does not fit this tree.
    if (node->L == node->R)
        return -1;

    int mid = (node->R - node->L) / 2 + node->L;

    // Push a pending increment down one level before reading the children.
    // `!= 0` so negative increments are pushed too.
    if (node->lazy != 0) {
        alter(node->left, node->L, mid, node->lazy);
        alter(node->right, mid + 1, node->R, node->lazy);
        node->lazy = 0;
    }

    if (r <= mid)
        return getMaxvalue(node->left, l, r);
    if (mid < l)
        return getMaxvalue(node->right, l, r);

    // Straddles the midpoint: the answer is the larger of the two halves.
    int leftmax = getMaxvalue(node->left, l, mid);
    int rightmax = getMaxvalue(node->right, mid + 1, r);
    return max(leftmax, rightmax);
}
// Returns the minimum element in the inclusive range [l, r] of the subtree
// rooted at `node`. [l, r] is expected to lie inside [node->L, node->R].
// Returns -1 when `node` is null or when a leaf is reached that the range does
// not cover exactly; -1 cannot be told apart from a genuine minimum of -1.
int getMinvalue(Node *node, int l, int r) {
    if (node == nullptr)
        return -1;
    // Exact cover: the stored minimum already includes any pending increment.
    if (node->L == l && node->R == r)
        return node->minvalue;
    // Non-matching leaf: the range does not fit this tree.
    if (node->L == node->R)
        return -1;

    int mid = (node->R - node->L) / 2 + node->L;

    // Push a pending increment down one level before reading the children.
    // `!= 0` so negative increments are pushed too.
    if (node->lazy != 0) {
        alter(node->left, node->L, mid, node->lazy);
        alter(node->right, mid + 1, node->R, node->lazy);
        node->lazy = 0;
    }

    if (r <= mid)
        return getMinvalue(node->left, l, r);
    if (mid < l)
        return getMinvalue(node->right, l, r);

    // Straddles the midpoint: the answer is the smaller of the two halves.
    int leftmin = getMinvalue(node->left, l, mid);
    int rightmin = getMinvalue(node->right, mid + 1, r);
    return min(leftmin, rightmin);
}
// Reads n followed by n integers from standard input and builds the segment
// tree over them using 1-based indices. Returns 1 if n cannot be read or is
// negative, 0 otherwise.
//
// `int` is a macro for `long long` in this program, which is why the entry
// point is spelled `signed main`.
signed main() {
    int n;
    // A failed read or a negative count would otherwise reach the vector
    // constructor with a bogus size.
    if (!(cin >> n) || n < 0)
        return 1;

    // Slot 0 is unused so that element i of the input lives at arr[i].
    vector<int> arr(n + 1);
    for (int i = 1; i <= n; i++) {
        cin >> arr[i];
    }

    // The tree is built but not queried or freed here; it lives until the
    // process exits.
    Node *head = building(1, n, arr);
    (void)head;
    return 0;
}
6、完整java代码
import java.util.*;
/**
 * One node of a reference-linked segment tree.
 *
 * <p>{@code L}/{@code R} are the inclusive bounds of the covered interval,
 * {@code value} its sum, {@code maxvalue}/{@code minvalue} its extremes,
 * {@code lazy} an increment not yet pushed to the children, and
 * {@code left}/{@code right} the child nodes.
 */
class Node {
    int L, R;
    int value;
    int maxvalue, minvalue;
    int lazy = 0;
    Node left, right;

    /** Creates a node covering [l, r] with the given sum; other fields keep their defaults. */
    Node(int l, int r, int value) {
        L = l;
        R = r;
        this.value = value;
    }
}
/**
 * Segment tree with lazy range-add over {@link Node}, plus a small stdin driver.
 *
 * <p>All aggregates are {@code int}; sums that exceed the {@code int} range
 * will overflow silently.
 */
class Main {
    /**
     * Recursively builds the tree for the inclusive index range [l, r] of
     * {@code value}.
     *
     * @return the root of the built subtree, or {@code null} when l &gt; r
     */
    static Node building(int l, int r, ArrayList<Integer> value) {
        if (l > r)
            return null;
        // Seed the node with value[l]; for a leaf this is already final.
        int first = value.get(l);
        Node node = new Node(l, r, first);
        node.maxvalue = first;
        node.minvalue = first;
        if (l == r)
            return node;
        // Here l < r, hence l <= mid < r: both halves are non-empty, so both
        // children are real nodes and no null checks are needed.
        int mid = (r - l) / 2 + l;
        node.left = building(l, mid, value);
        node.right = building(mid + 1, r, value);
        pullUp(node);
        return node;
    }

    /** Applies an increment to a whole node and records it as pending for its children. */
    private static void applyDelta(Node node, int delta) {
        node.value += delta * (node.R - node.L + 1);
        node.lazy += delta;
        node.maxvalue += delta;
        node.minvalue += delta;
    }

    /** Pushes a pending increment (positive or negative) down to both children. */
    private static void pushDown(Node node) {
        if (node.lazy == 0)
            return;
        applyDelta(node.left, node.lazy);
        applyDelta(node.right, node.lazy);
        node.lazy = 0;
    }

    /** Recomputes a node's sum, max and min from its two children. */
    private static void pullUp(Node node) {
        node.value = node.left.value + node.right.value;
        node.maxvalue = Math.max(node.left.maxvalue, node.right.maxvalue);
        node.minvalue = Math.min(node.left.minvalue, node.right.minvalue);
    }

    /** A node without two children cannot be descended into. */
    private static boolean isLeaf(Node node) {
        return node.left == null || node.right == null;
    }

    /**
     * Adds {@code number} to every element in the inclusive range [l, r] of the
     * subtree rooted at {@code node}. [l, r] is expected to lie inside the
     * node's interval; a leaf that is not covered exactly is left untouched.
     */
    static void alter(Node node, int l, int r, int number) {
        if (node == null)
            return;
        if (node.L == l && node.R == r) {
            applyDelta(node, number);
            return;
        }
        if (isLeaf(node))
            return;
        pushDown(node);
        int mid = (node.R - node.L) / 2 + node.L;
        if (r <= mid) {
            alter(node.left, l, r, number);
        } else if (mid < l) {
            alter(node.right, l, r, number);
        } else {
            alter(node.left, l, mid, number);
            alter(node.right, mid + 1, r, number);
        }
        // Recompute from the children: only part of this interval changed, so
        // simply adding `number` to max/min would be wrong.
        pullUp(node);
    }

    /**
     * Returns the sum over the inclusive range [l, r].
     *
     * @return the sum, or -1 when {@code node} is null or a leaf is reached
     *         that the range does not cover exactly (indistinguishable from a
     *         genuine sum of -1)
     */
    static int getValue(Node node, int l, int r) {
        if (node == null)
            return -1;
        if (node.L == l && node.R == r)
            return node.value;
        if (isLeaf(node))
            return -1;
        pushDown(node);
        int mid = (node.R - node.L) / 2 + node.L;
        if (mid < l)
            return getValue(node.right, l, r);
        if (r <= mid)
            return getValue(node.left, l, r);
        // Go straight to the children instead of re-entering this node.
        return getValue(node.left, l, mid) + getValue(node.right, mid + 1, r);
    }

    /**
     * Returns the maximum over the inclusive range [l, r].
     *
     * @return the maximum, or -1 when {@code node} is null or a leaf is reached
     *         that the range does not cover exactly
     */
    static int getMaxvalue(Node node, int l, int r) {
        if (node == null)
            return -1;
        if (node.L == l && node.R == r)
            return node.maxvalue;
        if (isLeaf(node))
            return -1;
        pushDown(node);
        int mid = (node.R - node.L) / 2 + node.L;
        if (r <= mid)
            return getMaxvalue(node.left, l, r);
        if (mid < l)
            return getMaxvalue(node.right, l, r);
        int leftmax = getMaxvalue(node.left, l, mid);
        int rightmax = getMaxvalue(node.right, mid + 1, r);
        return Math.max(leftmax, rightmax);
    }

    /**
     * Returns the minimum over the inclusive range [l, r].
     *
     * @return the minimum, or -1 when {@code node} is null or a leaf is reached
     *         that the range does not cover exactly
     */
    static int getMinvalue(Node node, int l, int r) {
        if (node == null)
            return -1;
        if (node.L == l && node.R == r)
            return node.minvalue;
        if (isLeaf(node))
            return -1;
        pushDown(node);
        int mid = (node.R - node.L) / 2 + node.L;
        if (r <= mid)
            return getMinvalue(node.left, l, r);
        if (mid < l)
            return getMinvalue(node.right, l, r);
        int leftmin = getMinvalue(node.left, l, mid);
        int rightmin = getMinvalue(node.right, mid + 1, r);
        return Math.min(leftmin, rightmin);
    }

    /**
     * Reads {@code n m}, then n values, then m operations from standard input.
     * Operation code 1 is followed by {@code x y k} and adds k to the range
     * between x and y; any other code is followed by {@code x y} and records
     * the sum of that range. All recorded sums are printed at the end, one per
     * line.
     */
    public static void main(String[] args) {
        try (Scanner scanner = new Scanner(System.in)) {
            int n = scanner.nextInt();
            int m = scanner.nextInt();
            ArrayList<Integer> arr = new ArrayList<Integer>();
            // Occupy index 0 so that the i-th input value sits at index i.
            arr.add(0);
            for (int i = 1; i <= n; i++) {
                arr.add(scanner.nextInt());
            }
            Node head = building(1, n, arr);
            StringBuilder res = new StringBuilder();
            for (int i = 0; i < m; i++) {
                int a = scanner.nextInt();
                int x = scanner.nextInt();
                int y = scanner.nextInt();
                // Accept the endpoints in either order, for updates and queries alike.
                int minnum = Math.min(x, y);
                int maxnum = Math.max(x, y);
                if (a == 1) {
                    int k = scanner.nextInt();
                    alter(head, minnum, maxnum, k);
                } else {
                    res.append(getValue(head, minnum, maxnum)).append("\n");
                }
            }
            System.out.println(res);
        }
    }
}