题目大意
给出 n n n个数,有 m m m个操作
1 1 1 x x x w w w :将 x x x位置上的数改为 w w w
2 2 2 l l l r r r :查询区间 [ l , r ] [l,r] [l,r]内有多少个连续的不下降区间(单个数也算)
思路
我们需要维护的是区间内符合条件的个数,初步考虑的话似乎是有区间相加的性质的,如果考虑两个区间相互独立,不会对各自的递增性质产生影响。
这时我们想到线段树的 p u s h u p pushup pushup操作。
在两个区间合并成一个大区间的,有可能中间的部分会产生连续的递增区间,这个时候合并的时候需要维护一个中点向左边和右边拓展的最大长度,然后判断两个端点是否能进行合并。(详细看代码)
#include <cstdio>
#include <iostream>
#define ll long long
using namespace std;
const int maxn = 2e5 + 9;
int a[maxn], b[maxn];
struct num_node
{
int l, r;
int lm, rm;
// lm从左边这个数字开始往右最长不下降序列长度(往右端最长延伸)
// rm从右边这个数字开始往左最长不上升序列长度(往左端最长延伸)
ll sum; //表示[l,r]区间内有多少个连续递增区间
} s[maxn << 2];
int n, m;
void pushup(int k)
{
s[k].sum = s[k << 1].sum + s[k << 1 | 1].sum;
s[k].lm = s[k << 1].lm, s[k].rm = s[k << 1 | 1].rm;
if (a[s[k << 1].r] <= a[s[k << 1 | 1].l])//可以合并
{
if (s[k << 1].lm == s[k << 1].r - s[k << 1].l + 1)
s[k].lm += s[k << 1 | 1].lm;
if (s[k << 1 | 1].rm == s[k << 1 | 1].r - s[k << 1 | 1].l + 1)
s[k].rm += s[k << 1].rm;
s[k].sum += 1ll * s[k << 1].rm * s[k << 1 | 1].lm;
}
}
void build(int k, int l, int r)
{
s[k].l = l;
s[k].r = r;
if (l == r)
{
s[k].sum = 1;
s[k].lm = s[k].rm = 1;
return;
}
int mid = (l + r) >> 1;
build(k << 1, l, mid);
build(k << 1 | 1, mid + 1, r);
pushup(k);
}
void update(int k, int x, int w)
{
int l = s[k].l, r = s[k].r;
if (l == r)
{
a[x] = w;
return;
}
int mid = (l + r) >> 1;
if (mid >= x)
update(k << 1, x, w);
if (mid < x)
update(k << 1 | 1, x, w);
pushup(k);
}
ll query(int k, int L, int R)
{
int l = s[k].l, r = s[k].r;
if (l > R || r < L)
return 0;
if (L <= l && r <= R)
return s[k].sum;
pushup(k);
int mid = (l + r) >> 1;
ll ans = 0;
if (mid >= L)
ans += query(k << 1, L, R);
if (mid < R)
ans += query(k << 1 | 1, L, R);
if (a[s[k << 1].r] <= a[s[k << 1 | 1].l])
{
int lsum = min(mid - L + 1, s[k << 1].rm), rsum = min(R - (mid + 1) + 1, s[k << 1 | 1].lm);
if (lsum > 0 && rsum > 0)//注意要大于0
ans += 1ll * lsum * rsum;
}
return ans;
}
int L[maxn], R[maxn];
int main()
{
int n, m;
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; ++i)
scanf("%d", &a[i]);
build(1, 1, n);
while (m--)
{
int wx;
scanf("%d", &wx);
if (wx == 1)
{
int x, w;
scanf("%d %d", &x, &w);
update(1, x, w);
}
else
{
int L, R;
scanf("%d %d", &L, &R);
printf("%lld\n", query(1, L, R));
}
}
return 0;
}