题目链接: Non-Decreasing Dilemma
大致题意
给定长度为 n n n的序列 a a a, 有两种操作:
1 pos c
修改
a
p
o
s
=
c
a_{pos} = c
apos=c.
2 l r
询问区间存在多少个连续非递减子序列.
解题思路
思维
我们先只考虑查询操作: 对于一段长度为 l e n len len的连续非递减子序列而言, 其中会有 ∑ i = 1 l e n i \displaystyle \sum_{i = 1}^{len}i i=1∑leni个符合要求的序列.
因此我们只需统计 [ l , r ] [l, r] [l,r]中每个最长的非递减子序列的长度即可.
线段树 由于题目带修改操作, 我们考虑用线段树维护区间信息.
考虑到对于两个区间 l e f t , r i g h t left, right left,right进行合并时, 当 l e f t left left的右端点值小于等于 r i g h t right right的左端点值时, 此时才会产生更长的连续子序列. 因此树中需要维护 l n u m , r n u m lnum, rnum lnum,rnum, 表示当前区间左(右)起的连续子序列长度. 在合并区间时, 减去原先的贡献, 再加上合并后的贡献即可.
AC代码
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 1; i <= (n); ++i)
using namespace std;
typedef long long ll;
const int N = 2E5 + 10;
ll sum[N];
void init(int n = N - 5) { rep(i, n) sum[i] = sum[i - 1] + i; }
int w[N];
struct node {
int l, r;
ll sum; int lnum, rnum;
}t[N << 2];
void pushup(node& p, node& l, node& r) {
p.l = l.l, p.r = r.r;
p.sum = l.sum + r.sum;
if (w[l.r] <= w[r.l]) {
p.sum -= sum[l.rnum] + sum[r.lnum];
p.sum += sum[l.rnum + r.lnum];
int llen = l.r - l.l + 1, rlen = r.r - r.l + 1;
p.lnum = l.lnum == llen ? llen + r.lnum : l.lnum;
p.rnum = r.rnum == rlen ? rlen + l.rnum : r.rnum;
}
else p.lnum = l.lnum, p.rnum = r.rnum;
}
void pushup(int x) { pushup(t[x], t[x << 1], t[x << 1 | 1]); }
void build(int l, int r, int x = 1) {
t[x] = { l, r, 1, 1, 1 };
if (l == r) return;
int mid = l + r >> 1;
build(l, mid, x << 1), build(mid + 1, r, x << 1 | 1);
pushup(x);
}
void modify(int a, int x = 1) {
if (t[x].l == t[x].r) return;
int mid = t[x].l + t[x].r >> 1;
modify(a, x << 1 | (a > mid));
pushup(x);
}
auto ask(int l, int r, int x = 1) {
if (l <= t[x].l and r >= t[x].r) return t[x];
int mid = t[x].l + t[x].r >> 1;
if (r <= mid) return ask(l, r, x << 1);
if (l > mid) return ask(l, r, x << 1 | 1);
node left = ask(l, r, x << 1), right = ask(l, r, x << 1 | 1);
node res = { 0, 0, 0, 0, 0 };
pushup(res, left, right);
return res;
}
int main()
{
init();
int n, m; cin >> n >> m;
rep(i, n) scanf("%d", &w[i]);
build(1, n);
rep(i, m) {
int tp; scanf("%d", &tp);
if (tp == 1) {
int a, c; scanf("%d %d", &a, &c);
w[a] = c;
modify(a);
}
else {
int l, r; scanf("%d %d", &l, &r);
printf("%lld\n", ask(l, r).sum);
}
}
return 0;
}