线段树的本质是二分。查询时,都是将大区间二分成左右两个子区间分别查询,再将结果合并,比如求区间和时将结果相加,求最值时取最值等。区间内最大连续和也是同理,将大区间二分为2个子区间,连续最大和要么在左子区间,要么在右子区间,要么是左子区间的最大后缀和加上右子区间的最大前缀和。(所以query函数的返回值应为结构体,平时线段树query函数不用有返回值是因为用_sum之类的变量在每一小段都把答案维护了,但最大连续和没法这样)
查询需要什么,线段树就维护什么。为了维护这些,可能还需要多维护一些才能够维护这些。(比如这题里的sumv,即区间和)。
例题:P4513
#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
#include<algorithm>
#include<vector>
#include<string.h>
#define ll (o*2,L,m)
#define rr (o*2+1,m+1,R)
#define int long long
using namespace std;
const int maxn = 5e5 + 5;
const int INF = 1e9;
struct node
{
int pre, suf, msum, sumv;
};
struct segment_tree
{
int msum[4 * maxn];//最大连续和
int pre[4 * maxn];//最大前缀和
int suf[4 * maxn];//最大后缀和
int sumv[4 * maxn];//和
int x,v; int n; int* arr;
int y1, y2;
int ans;
void dfs(int o, int L, int R)
{
if (L == R)
{
sumv[o]=pre[o] = suf[o] = msum[o] = arr[L];
//不用和0再比一下,负的就负的,不会丢失正确答案的
}
else
{
int m = L + (R - L) / 2;
dfs ll; dfs rr;
int lc = o * 2; int rc = lc + 1;
pre[o] = max(pre[lc], sumv[lc] + pre[rc]);
suf[o] = max(suf[rc], sumv[rc] + suf[lc]);
msum[o] = max(msum[lc], msum[rc]);
msum[o] = max(msum[o], suf[lc] + pre[rc]);
sumv[o] = sumv[lc] + sumv[rc];
}
}
void build(int n, int* a)
{
arr = a; this->n = n;
dfs(1, 1, n);
}
void maintain(int o, int L, int R)
{
int m = L + (R - L) / 2;
int lc = o * 2; int rc = lc + 1;
pre[o] = max(pre[lc], sumv[lc] + pre[rc]);
suf[o] = max(suf[rc], sumv[rc] + suf[lc]);
msum[o] = max(msum[lc], msum[rc]);
msum[o] = max(msum[o], suf[lc] + pre[rc]);
sumv[o] = sumv[lc] + sumv[rc];
}
void update(int o, int L, int R)
{
if (L == R)
{
sumv[o] = pre[o] = suf[o] = msum[o] = v; return;
//一般这里不直接return,而是在else后面统一maintain,
//这边是因为已经相当于maintain了
}
else
{
int m = L + (R - L) / 2;
if (x <= m)update ll;
if (x > m)update rr;
maintain(o, L, R);
}
}
node query(int o, int L, int R)
{
if (y1 <= L && R <= y2)
{
return node{ pre[o],suf[o],msum[o],sumv[o] };
}
else
{
int m = L + (R - L) / 2;
node n1{ -INF,-INF,-INF,0 }; node n2{ -INF,-INF,-INF,0 };
//这边的-INF和0有讲究,原则就是没有左子节点时n1不会使答案出错
if (y1 <= m)n1 = query ll;
if (y2 > m)n2 = query rr;
node ret;
ret.pre = max(n1.pre, n1.sumv + n2.pre);
ret.suf = max(n2.suf, n2.sumv + n1.suf);
ret.msum = max(n1.msum, n2.msum);
ret.msum = max(ret.msum, n1.suf + n2.pre);
ret.sumv = n2.sumv + n1.sumv;
return ret;
}
}
int interval_query(int a, int b)
{
y1 = a; y2 = b;
node t = query(1, 1, n);
return t.msum;
}
void change(int x, int v)
{
this->x = x; this->v = v;
update(1, 1, n);
}
}t;
int arr[maxn];
signed main()
{
int n, m; cin >> n >> m;
for (int i = 1; i <= n; i++)cin >> arr[i];
t.build(n, arr);
for (int i = 0; i < m; i++)
{
int k; cin >> k;
if (k == 1)
{
int a, b; cin >> a >> b;
if (a > b)
{
int t = a; a = b; b = t;
}
cout<<t.interval_query(a, b)<<endl;
}
else
{
int p, s; cin >> p >> s;
t.change(p, s);
}
}
return 0;
}