codeforces 线段树题单
step1
1.A. Segment Tree for the Sum
题意:
给你一个长度为 $n$ 的数组 $a$,执行 $m$ 个操作:
分为两类:
1. 把数组中下标为 $i$ 的元素修改为 $v$;
2. 询问下标在 $[l, r-1]$ 内(即左闭右开区间 $[l, r)$)的元素之和。
数据范围
$1 \leqslant n, m \leqslant 10^5$
$0 \leqslant l < r < n$
$0 \leqslant a[i], v \leqslant 10^9$
思路:
线段树维护区间和即可。
时间复杂度:
$O(n + m\log n)$
代码
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
// Reads one integer from stdin into t. Characters before the first digit are skipped
// (any '-' among them makes the value negative), then consecutive digits are accumulated.
// NOTE: keeps spinning if the input ends before a digit is found.
template <typename T> void read(T &t) {
    int sign = 1;
    char c = getchar();
    for (; c < '0' || c > '9'; c = getchar())
        if (c == '-') sign = -1;
    t = 0;
    for (; '0' <= c && c <= '9'; c = getchar())
        t = t * 10 + (c - '0');
    t *= sign;
}
// Writes t to stdout in decimal, prefixed by '-' when it is negative.
template <typename T> void write(T t) {
    if (t < 0) {
        putchar('-');
        t = -t;
    }
    char digits[40];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + t % 10);
        t /= 10;
    } while (t > 0);
    while (len > 0) putchar(digits[--len]);
}
// Writes t followed by a newline.
template <typename T> void writeln(T t) {
    write(t);
    putchar('\n');
}
const int N = 100010;
// Segment tree node: the sum of the values in the node's segment.
struct Node {
    LL sum;
} seg[N * 4];
int n, m;  // array length, number of operations
int a[N];  // input values, stored 1-based
// Recomputes node u from its children 2u and 2u+1.
void pushup(int u) {
    const int lc = u * 2, rc = u * 2 + 1;
    seg[u].sum = seg[lc].sum + seg[rc].sum;
}
// Builds the subtree rooted at u over a[l..r].
void build(int u, int l, int r) {
    if (l != r) {
        const int mid = (l + r) / 2;
        build(u * 2, l, mid);
        build(u * 2 + 1, mid + 1, r);
        pushup(u);
    } else {
        seg[u].sum = a[l];
    }
}
// Sets position id to val inside the subtree of u (covering [l, r]) and refreshes the
// sums on the way back up.
void update(int u, int l, int r, int id, int val) {
    if (l == r && r == id) {
        seg[u].sum = val;
        return;
    }
    const int mid = (l + r) / 2;
    if (id > mid) update(u * 2 + 1, mid + 1, r, id, val);
    else update(u * 2, l, mid, id, val);
    pushup(u);
}
// Sum of positions [ql, qr]; requires l <= ql <= qr <= r.
LL query(int u, int l, int r, int ql, int qr) {
    if (ql == l && qr == r) return seg[u].sum;
    const int mid = (l + r) / 2;
    LL total = 0;
    if (ql <= mid) total += query(u * 2, l, mid, ql, qr < mid ? qr : mid);
    if (qr > mid) total += query(u * 2 + 1, mid + 1, r, ql > mid + 1 ? ql : mid + 1, qr);
    return total;
}
int main(){
    read(n), read(m);
    for (int i = 1; i <= n; ++i) read(a[i]);
    build(1, 1, n);
    while (m--) {
        // "1 i v" sets a[i] = v; "2 l r" prints the sum of a[l..r-1]. Input indices are
        // 0-based and the tree is 1-based, so only the first index is shifted: the exclusive
        // right end r already equals the 1-based inclusive end.
        int op, p, q;
        read(op), read(p), read(q);
        ++p;
        if (op != 1) printf("%lld\n", query(1, 1, n, p, q));
        else update(1, 1, n, p, q);
    }
    return 0;
}
B. Segment Tree for the Minimum
题意:
给你一个长度为 $n$ 的数组 $a$,执行 $m$ 个操作:
分为两类:
1. 把数组中下标为 $i$ 的元素修改为 $v$;
2. 询问下标在 $[l, r-1]$ 内(即左闭右开区间 $[l, r)$)的元素最小值。
数据范围
$1 \leqslant n, m \leqslant 10^5$
$0 \leqslant l < r < n$
$0 \leqslant a[i], v \leqslant 10^9$
思路:
线段树维护区间最小值即可。
时间复杂度:
$O(n + m\log n)$
代码
#include <bits/stdc++.h>
using namespace std;;
// Fast integer input: skips characters until a digit (a '-' seen on the way makes the value
// negative), then accumulates consecutive digits. NOTE: spins forever if input ends first.
template <typename T> void read(T &t) {
    t = 0; char ch = getchar(); int f = 1;
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    do { (t *= 10) += ch - '0'; ch = getchar(); } while ('0' <= ch && ch <= '9');
    t *= f;
}
// Writes t to stdout in decimal (leading '-' for negatives).
template <typename T> void write(T t) {
    if (t < 0) { putchar('-'); write(-t); return; }
    if (t > 9) write(t / 10);
    putchar('0' + t % 10);
}
// Writes t followed by a newline.
template <typename T> void writeln(T t) { write(t); puts(""); }
const int N = 100010;
// Segment tree node: the minimum over the node's segment.
struct Node{
    int minval;
}seg[N * 4];
int n, m;  // array length, number of operations
int a[N];  // input values, stored 1-based
// Recomputes node u from its two children.
void pushup(int u){
    seg[u].minval = min(seg[u << 1].minval, seg[u << 1 | 1].minval);
}
// Builds the subtree rooted at u over a[l..r].
void build(int u, int l, int r){
    if(l == r){
        seg[u].minval = a[l];
        return;
    }
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}
// Sets position id to val and refreshes the minima on the path back to u.
void update(int u, int l, int r, int id, int val){
    if(l == r && r == id){
        seg[u].minval = val;
        return;
    }
    int mid = l + r >> 1;
    if(id <= mid) update(u << 1, l, mid, id, val);
    else update(u << 1 | 1, mid + 1, r, id, val);
    pushup(u);
}
// Minimum over positions [ql, qr]; requires l <= ql <= qr <= r.
int query(int u, int l, int r, int ql, int qr){
    if(l == ql && r == qr) return seg[u].minval;
    int mid = l + r >> 1;
    if(qr <= mid) return query(u << 1, l, mid, ql, qr);
    if(ql > mid) return query(u << 1 | 1, mid + 1, r, ql, qr);
    // the range straddles mid: combine the answers of both halves
    return min(query(u << 1, l, mid, ql, mid), query(u << 1 | 1, mid + 1, r, mid + 1, qr));
}
int main(){
    read(n), read(m);
    for (int i = 1; i <= n; ++i) read(a[i]);
    build(1, 1, n);
    while (m--) {
        // "1 i v" sets a[i] = v; "2 l r" prints min(a[l..r-1]). Input indices are 0-based and
        // the tree is 1-based, so only the first index needs shifting.
        int op, p, q;
        read(op), read(p), read(q);
        ++p;
        if (op == 1) {
            update(1, 1, n, p, q);
            continue;
        }
        writeln(query(1, 1, n, p, q));
    }
    return 0;
}
C. Number of Minimums on a Segment
题意:
给你一个长度为 $n$ 的数组 $a$,执行 $m$ 个操作:
分为两类:
1. 把数组中下标为 $i$ 的元素修改为 $v$;
2. 询问下标在 $[l, r-1]$ 内(即左闭右开区间 $[l, r)$)的元素最小值,以及最小值出现的次数。
数据范围
$1 \leqslant n, m \leqslant 10^5$
$0 \leqslant l < r < n$
$0 \leqslant a[i], v \leqslant 10^9$
思路:
线段树维护区间最小值和最小值出现次数。
时间复杂度:
$O(n + m\log n)$
部分代码
可以先重载一下 `+` 号,用来合并两个区间的信息,后面 `pushup` 和 `query` 里可以直接用,很方便:
// Summary of a segment: its minimum value and how many times that minimum occurs.
struct info{
    int minval;
    int cnt;
};
// Merges two summaries: keeps the smaller minimum and adds up the counts of every side
// that attains it.
info operator + (const info &x, const info &y){
    info a = {2000000000, 0};  // easy to get wrong: start from "infinity, zero occurrences"
    a.minval = min(x.minval, y.minval);
    a.cnt = (x.minval == a.minval ? x.cnt : 0) + (y.minval == a.minval ? y.cnt : 0);
    return a;
}
易错点
1. 如果在函数里直接定义 `info a;` 而不初始化,`a` 里面的成员是随机的值。所以要先初始化 `a`:`cnt` 后面是用 `+=` 累加的,必须从 $0$ 开始;`minval` 一开始可以设为无穷大(随后会被两边的较小值覆盖)。
2. `update` 以后记得要 `pushup`。
完整代码
#include <bits/stdc++.h>
using namespace std;
// Reads one integer from stdin into t. Characters before the first digit are skipped
// (any '-' among them makes the value negative), then consecutive digits are accumulated.
// NOTE: keeps spinning if the input ends before a digit is found.
template <typename T> void read(T &t) {
    int sign = 1;
    char c = getchar();
    for (; c < '0' || c > '9'; c = getchar())
        if (c == '-') sign = -1;
    t = 0;
    for (; '0' <= c && c <= '9'; c = getchar())
        t = t * 10 + (c - '0');
    t *= sign;
}
// Writes t to stdout in decimal, prefixed by '-' when it is negative.
template <typename T> void write(T t) {
    if (t < 0) {
        putchar('-');
        t = -t;
    }
    char digits[40];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + t % 10);
        t /= 10;
    } while (t > 0);
    while (len > 0) putchar(digits[--len]);
}
// Writes t followed by a newline.
template <typename T> void writeln(T t) {
    write(t);
    putchar('\n');
}
const int N = 100010;
// Summary of a segment: its minimum value and how many times that minimum occurs.
struct info{
    int minval;
    int cnt;
};
// Merges two summaries: keeps the smaller minimum and adds up the counts of every side
// that attains it.
info operator + (const info &x, const info &y){
    info a = {2000000000, 0};  // start from "infinity, zero occurrences": cnt must begin at 0
    a.minval = min(x.minval, y.minval);
    a.cnt = (x.minval == a.minval ? x.cnt : 0) + (y.minval == a.minval ? y.cnt : 0);
    return a;
}
struct Node{
    info val;
}seg[N * 4];
int n, m;  // array length, number of operations
int a[N];  // input values, stored 1-based
// Recomputes node u from its children 2u and 2u+1.
void pushup(int u){
    seg[u].val = seg[u * 2].val + seg[u * 2 + 1].val;
}
// Builds the subtree rooted at u over a[l..r]; each leaf is "value, seen once".
void build(int u, int l, int r){
    if (l != r) {
        const int mid = (l + r) / 2;
        build(u * 2, l, mid);
        build(u * 2 + 1, mid + 1, r);
        pushup(u);
    } else {
        seg[u].val = info{a[l], 1};
    }
}
// Sets position id to val and refreshes the summaries on the way back up.
void update(int u, int l, int r, int id, int val){
    if (l == r && r == id) {
        seg[u].val = info{val, 1};
        return;
    }
    const int mid = (l + r) / 2;
    if (id > mid) update(u * 2 + 1, mid + 1, r, id, val);
    else update(u * 2, l, mid, id, val);
    pushup(u);
}
// Summary of positions [ql, qr]; requires l <= ql <= qr <= r.
info query(int u, int l, int r, int ql, int qr){
    if (ql == l && qr == r) return seg[u].val;
    const int mid = (l + r) / 2;
    if (ql > mid) return query(u * 2 + 1, mid + 1, r, ql, qr);
    if (qr <= mid) return query(u * 2, l, mid, ql, qr);
    const info left = query(u * 2, l, mid, ql, mid);
    const info right = query(u * 2 + 1, mid + 1, r, mid + 1, qr);
    return left + right;
}
int main(){
read(n), read(m);for(int i = 1; i <= n; i ++ ) read(a[i]);
build(1, 1, n);
while(m -- ){
int op, x, y;read(op), read(x), read(y);++x;
if(op == 1) update(1, 1, n, x, y);
else{
info ans = query(1, 1, n, x, y);
printf("%d %d\n",ans.minval, ans.cnt);
}
}
return 0;
}
step2
A. Segment with the Maximum Sum
题意:
给你一个长度为 $n$ 的数组 $a$,执行 $m$ 次操作:
每次操作把数组中下标为 $i$ 的元素修改为 $v$。
在所有操作之前输出一次数组的最大连续子段和,之后每次操作完成后再输出一次(允许选空段,所以答案至少为 $0$)。
数据范围
$1 \leqslant n, m \leqslant 10^5$
$-10^9 \leqslant a[i], v \leqslant 10^9$
思路:
线段树需要维护区间最大连续子段和。
每个节点需要维护的信息:区间和 `sum`、从区间左端开始的最大连续子段和(最大前缀和)`lsum`、以区间右端结尾的最大连续子段和(最大后缀和)`rsum`、区间最大连续子段和 `maxsum`。
更新方式
// Merges the summary of a left segment A with that of the right segment B that follows it.
info operator + (const info &A, const info &B){
    info a = {0, 0, 0, 0};
    a.sum = A.sum + B.sum;
    // Best prefix: stays inside A, or takes all of A and continues into B's best prefix.
    a.lsum = max(A.lsum, A.sum + B.lsum);
    // Best suffix: stays inside B, or takes all of B and extends into A's best suffix.
    a.rsum = max(B.rsum, A.rsum + B.sum);
    // Best subarray: entirely in A, entirely in B, or crossing the boundary.
    a.maxsum = max(max(A.maxsum, B.maxsum), A.rsum + B.lsum);
    return a;
}
时间复杂度:
$O(n + m\log n)$
代码
#include <bits/stdc++.h>
#define debug(a) cout<<#a<<"="<<a<<endl;
using namespace std;
using LL = long long;
// Reads one integer from stdin into t. Characters before the first digit are skipped
// (any '-' among them makes the value negative), then consecutive digits are accumulated.
// NOTE: keeps spinning if the input ends before a digit is found.
template <typename T> void read(T &t) {
    int sign = 1;
    char c = getchar();
    for (; c < '0' || c > '9'; c = getchar())
        if (c == '-') sign = -1;
    t = 0;
    for (; '0' <= c && c <= '9'; c = getchar())
        t = t * 10 + (c - '0');
    t *= sign;
}
// Writes t to stdout in decimal, prefixed by '-' when it is negative.
template <typename T> void write(T t) {
    if (t < 0) {
        putchar('-');
        t = -t;
    }
    char digits[40];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + t % 10);
        t /= 10;
    } while (t > 0);
    while (len > 0) putchar(digits[--len]);
}
// Writes t followed by a newline.
template <typename T> void writeln(T t) {
    write(t);
    putchar('\n');
}
const int N = 100010;
// Per-segment summary for the maximum-subarray problem (leaves hold one element, so every
// field below refers to a non-empty piece of the segment):
//   lsum   - best prefix sum, rsum - best suffix sum,
//   maxsum - best contiguous subarray sum, sum - total sum.
struct info{
    LL lsum, rsum, maxsum, sum;
};
// Merges the summary of a left segment A with that of the right segment B that follows it.
info operator + (const info &A, const info &B){
    info a = {0, 0, 0, 0};
    a.sum = A.sum + B.sum;
    a.lsum = max(A.lsum, A.sum + B.lsum);   // stays in A, or all of A plus B's best prefix
    a.rsum = max(B.rsum, A.rsum + B.sum);   // stays in B, or A's best suffix plus all of B
    a.maxsum = max(max(A.maxsum, B.maxsum), A.rsum + B.lsum);  // or crosses the boundary
    return a;
}
struct Node{
    info val;
}seg[N * 4];
int n, m;  // array length, number of operations
LL a[N];   // input values, stored 1-based
// Recomputes node u from its children; order matters because + is not commutative.
void pushup(int u){
    const info &left = seg[u * 2].val, &right = seg[u * 2 + 1].val;
    seg[u].val = left + right;
}
// Builds the subtree rooted at u over a[l..r].
void build(int u, int l, int r){
    if (l != r) {
        const int mid = (l + r) / 2;
        build(u * 2, l, mid);
        build(u * 2 + 1, mid + 1, r);
        pushup(u);
    } else {
        const LL v = a[l];
        seg[u].val = info{v, v, v, v};
    }
}
// Sets position id to value and refreshes the summaries on the way back up.
void update(int u, int l, int r, int id, int value){
    if (l == r && r == id) {
        const LL v = value;
        seg[u].val = info{v, v, v, v};
        return;
    }
    const int mid = (l + r) / 2;
    if (id > mid) update(u * 2 + 1, mid + 1, r, id, value);
    else update(u * 2, l, mid, id, value);
    pushup(u);
}
// Summary of positions [ql, qr]; requires l <= ql <= qr <= r.
info query(int u, int l, int r, int ql, int qr){
    if (ql == l && qr == r) return seg[u].val;
    const int mid = (l + r) / 2;
    if (ql > mid) return query(u * 2 + 1, mid + 1, r, ql, qr);
    if (qr <= mid) return query(u * 2, l, mid, ql, qr);
    const info left = query(u * 2, l, mid, ql, mid);
    const info right = query(u * 2 + 1, mid + 1, r, mid + 1, qr);
    return left + right;
}
int main(){
read(n);read(m); for(int i = 1; i <= n; i ++ ) read(a[i]);
build(1, 1, n);
writeln(max(query(1, 1, n, 1, n).maxsum, 0ll));
while(m -- ){
int x, y;
read(x), read(y); ++ x;
update(1, 1, n, x, y);
writeln(max(query(1, 1, n, 1, n).maxsum, 0ll));
}
return 0;
}
B. K-th one
题意:
给你一个长度为 $n$ 的二进制数组 $a$,执行 $m$ 个操作:
分为两类:
1. 把数组中下标为 $i$ 的元素取反($0$ 变 $1$,$1$ 变 $0$);
2. 询问数组中第 $x$ 个 $1$ 的下标($x$ 从 $0$ 开始编号)。
数据范围
$1 \leqslant n, m \leqslant 10^5$
$a[i] = 0$ 或者 $a[i] = 1$
思路:
线段树维护区间和。查询时先把 $x$ 加 $1$(转成从 $1$ 开始的编号 $k$),然后二分出最小的下标 $p$,使得前缀 $[0, p]$ 的区间和 $\geqslant k$,$p$ 就是答案。
时间复杂度:
$O(n + m\log^2 n)$
代码:
#include <bits/stdc++.h>
#define debug(x) cout << #x << " = " << x << endl
using namespace std;
// Reads one integer from stdin into t. Characters before the first digit are skipped
// (any '-' among them makes the value negative), then consecutive digits are accumulated.
// NOTE: keeps spinning if the input ends before a digit is found.
template <typename T> void read(T &t) {
    int sign = 1;
    char c = getchar();
    for (; c < '0' || c > '9'; c = getchar())
        if (c == '-') sign = -1;
    t = 0;
    for (; '0' <= c && c <= '9'; c = getchar())
        t = t * 10 + (c - '0');
    t *= sign;
}
// Writes t to stdout in decimal, prefixed by '-' when it is negative.
template <typename T> void write(T t) {
    if (t < 0) {
        putchar('-');
        t = -t;
    }
    char digits[40];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + t % 10);
        t /= 10;
    } while (t > 0);
    while (len > 0) putchar(digits[--len]);
}
// Writes t followed by a newline.
template <typename T> void writeln(T t) {
    write(t);
    putchar('\n');
}
const int N = 100010;
// Segment tree node: the number of ones in the node's segment.
struct Node{
    int sum;
}seg[N * 4];
int a[N];  // input bits, stored 0-based
// Recomputes node u from its children 2u and 2u+1.
void pushup(int u){
    seg[u].sum = seg[u * 2].sum + seg[u * 2 + 1].sum;
}
// Builds the subtree rooted at u over a[l..r].
void build(int u, int l, int r){
    if (l != r) {
        const int mid = (l + r) / 2;
        build(u * 2, l, mid);
        build(u * 2 + 1, mid + 1, r);
        pushup(u);
    } else {
        seg[u].sum = a[l];
    }
}
// Flips the bit at position id (1 becomes 0, anything else becomes 1).
void update(int u, int l, int r, int id){
    if (l == r && l == id) {
        seg[u].sum = (seg[u].sum == 1) ? 0 : 1;
        return;
    }
    const int mid = (l + r) / 2;
    if (id > mid) update(u * 2 + 1, mid + 1, r, id);
    else update(u * 2, l, mid, id);
    pushup(u);
}
// Number of ones in [ql, qr] (requires l <= ql <= qr <= r). The caller wraps this in an
// outer binary search, which is why this version costs two log factors per query.
int query(int u, int l, int r, int ql, int qr){
    if (ql == l && qr == r) return seg[u].sum;
    const int mid = (l + r) / 2;
    int total = 0;
    if (ql <= mid) total += query(u * 2, l, mid, ql, qr < mid ? qr : mid);
    if (qr > mid) total += query(u * 2 + 1, mid + 1, r, ql > mid + 1 ? ql : mid + 1, qr);
    return total;
}
int n, m;  // array length, number of operations
int main(){
    read(n); read(m);
    for (int i = 0; i < n; ++i) read(a[i]);
    build(1, 0, n - 1);
    while (m--) {
        int op, x;
        read(op); read(x);
        if (op == 1) {
            update(1, 0, n - 1, x);  // flip a[x]
            continue;
        }
        // x is 0-based: find the smallest p whose prefix [0, p] contains x + 1 ones.
        const int need = x + 1;
        int lo = 0, hi = n - 1;
        while (lo < hi) {
            const int mid = (lo + hi) / 2;
            if (query(1, 0, n - 1, 0, mid) < need) lo = mid + 1;
            else hi = mid;
        }
        writeln(lo);
    }
    return 0;
}
优化:
可以直接在线段树上二分:从根节点往下走,根据左儿子的区间和判断要找的那个 $1$ 在左儿子还是右儿子(走向右儿子时,要减去左儿子里 $1$ 的个数),直到走到叶子节点,即为答案。
时间复杂度:
$O(n + m\log n)$
#include <bits/stdc++.h>
#define debug(x) cout << #x << " = " << x << endl
using namespace std;
// Fast integer input: skips characters until a digit (a '-' seen on the way makes the value
// negative), then accumulates consecutive digits. NOTE: spins forever if input ends first.
template <typename T> void read(T &t) {
    t = 0; char ch = getchar(); int f = 1;
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    do { (t *= 10) += ch - '0'; ch = getchar(); } while ('0' <= ch && ch <= '9');
    t *= f;
}
// Writes t to stdout in decimal (leading '-' for negatives).
template <typename T> void write(T t) {
    if (t < 0) { putchar('-'); write(-t); return; }
    if (t > 9) write(t / 10);
    putchar('0' + t % 10);
}
// Writes t followed by a newline.
template <typename T> void writeln(T t) { write(t); puts(""); }
const int N = 100010;
// Segment tree node: the number of ones in the node's segment.
struct Node{
    int sum;
}seg[N * 4];
int a[N];  // input bits, stored 0-based
// Recomputes node u from its two children.
void pushup(int u){
    seg[u].sum = seg[u << 1].sum + seg[u << 1 | 1].sum;
}
// Builds the subtree rooted at u over a[l..r].
void build(int u, int l, int r){
    if(l == r){
        seg[u].sum = a[l];
        return;
    }
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    pushup(u);
}
// Flips the bit at position id (1 becomes 0, anything else becomes 1).
void update(int u, int l, int r, int id){
    if(l == r && l == id){
        if(seg[u].sum == 1) seg[u].sum = 0;
        else seg[u].sum = 1;
        return;
    }
    int mid = l + r >> 1;
    if(id <= mid) update(u << 1, l, mid, id);
    else update(u << 1 | 1, mid + 1, r, id);
    pushup(u);
}
// Returns the index of the x-th one (x counted from 1) inside the subtree of node u, which
// covers [l, r]; returns -1 when that segment holds fewer than x ones (or x < 1).
// Walks down from u: if the left child holds at least x ones the answer lies there,
// otherwise skip the left child's ones and continue in the right child. One log per query.
int search(int u, int l, int r, int x){
    if (x < 1 || seg[u].sum < x) return -1;  // every path now returns a value
    while (l < r) {
        int mid = l + r >> 1;
        if (seg[u << 1].sum >= x) {
            u = u << 1;
            r = mid;
        } else {
            x -= seg[u << 1].sum;
            u = u << 1 | 1;
            l = mid + 1;
        }
    }
    return l;
}
int n, m;
int main(){
read(n);read(m);
for(int i = 0; i <= n - 1; i ++ ) read(a[i]);
build(1, 0, n - 1);
while(m -- ){
int op, x; read(op);read(x);
if(op == 1) update(1, 0, n - 1, x);
else{
++x;
int l = 0, r = n - 1;
printf("%d\n",search(1, 0, n - 1, x));
}
}
return 0;
}
C. First element at least X
题意:
给你一个长度为 $n$ 的数组 $a$,执行 $m$ 个操作:
分为两类:
1. 把数组中下标为 $i$ 的元素修改为 $v$;
2. 询问数组中第一个大于等于 $x$ 的元素的下标(不存在则输出 $-1$)。
数据范围
$1 \leqslant n, m \leqslant 10^5$
$1 \leqslant a[i], x \leqslant 10^5$
思路:
线段树维护区间最大值,然后在线段树上二分:如果左儿子的最大值 $\geqslant x$ 就往左儿子走,否则往右儿子走,直到叶子节点,即为第一个大于等于 $x$ 的位置。
时间复杂度
$O(n + m\log n)$
代码
#include <bits/stdc++.h>
using namespace std;
// Reads one integer from stdin into t. Characters before the first digit are skipped
// (any '-' among them makes the value negative), then consecutive digits are accumulated.
// NOTE: keeps spinning if the input ends before a digit is found.
template <typename T> void read(T &t) {
    int sign = 1;
    char c = getchar();
    for (; c < '0' || c > '9'; c = getchar())
        if (c == '-') sign = -1;
    t = 0;
    for (; '0' <= c && c <= '9'; c = getchar())
        t = t * 10 + (c - '0');
    t *= sign;
}
// Writes t to stdout in decimal, prefixed by '-' when it is negative.
template <typename T> void write(T t) {
    if (t < 0) {
        putchar('-');
        t = -t;
    }
    char digits[40];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + t % 10);
        t /= 10;
    } while (t > 0);
    while (len > 0) putchar(digits[--len]);
}
// Writes t followed by a newline.
template <typename T> void writeln(T t) {
    write(t);
    putchar('\n');
}
const int N = 100010;
// Segment tree node: the maximum over the node's segment.
struct Node{
    int maxn;
}seg[N * 4];
int n, m;  // array length, number of operations
int a[N];  // input values, stored 0-based
// Recomputes node u from its children 2u and 2u+1.
void pushup(int u){
    seg[u].maxn = max(seg[u * 2].maxn, seg[u * 2 + 1].maxn);
}
// Builds the subtree rooted at u over a[l..r].
void build(int u, int l, int r){
    if (l != r) {
        const int mid = (l + r) / 2;
        build(u * 2, l, mid);
        build(u * 2 + 1, mid + 1, r);
        pushup(u);
    } else {
        seg[u].maxn = a[l];
    }
}
// Sets position id to x and refreshes the maxima on the way back up.
void update(int u, int l, int r, int id, int x){
    if (l == r && id == l) {
        seg[u].maxn = x;
        return;
    }
    const int mid = (l + r) / 2;
    if (id > mid) update(u * 2 + 1, mid + 1, r, id, x);
    else update(u * 2, l, mid, id, x);
    pushup(u);
}
// Leftmost index in [l, r] (the segment of node u) whose value is >= x, or -1 if none.
// Once the segment is known to contain such a value, walking down always keeps that
// property: go left when the left child qualifies, otherwise the right child must.
int search(int u, int l, int r, int x){
    if (seg[u].maxn < x) return -1;
    while (l != r) {
        const int mid = (l + r) / 2;
        if (seg[u * 2].maxn >= x) {
            u = u * 2;
            r = mid;
        } else {
            u = u * 2 + 1;
            l = mid + 1;
        }
    }
    return l;
}
int main(){
    read(n); read(m);
    for (int i = 0; i < n; ++i) read(a[i]);
    build(1, 0, n - 1);
    while (m--) {
        int op;
        read(op);
        if (op != 1) {
            // "2 x": first index whose value is at least x (or -1)
            int x;
            read(x);
            writeln(search(1, 0, n - 1, x));
            continue;
        }
        // "1 i v": a[i] = v
        int id, x;
        read(id); read(x);
        update(1, 0, n - 1, id, x);
    }
    return 0;
}
D. First element at least X - 2
题意:
给你一个长度为 $n$ 的数组 $a$,执行 $m$ 个操作:
分为两类:
1. 把数组中下标为 $i$ 的元素修改为 $v$;
2. 询问下标大于等于 $l$ 的元素中,第一个大于等于 $x$ 的元素的下标(不存在则输出 $-1$)。
数据范围
$1 \leqslant n, m, l \leqslant 10^5$
$1 \leqslant a[i], x \leqslant 10^5$
思路:
线段树维护区间最大值。查询时只考虑下标 $\geqslant l$ 的部分,在其中按区间最大值二分:先尝试左边的部分,找不到再去右边,最终找到第一个大于等于 $x$ 的位置。
时间复杂度
$O(n + m\log n)$
代码
#include <bits/stdc++.h>
#define debug(x) cout << #x << " = " << x << endl
using namespace std;
// Fast integer input: skips characters until a digit (a '-' seen on the way makes the value
// negative), then accumulates consecutive digits. NOTE: spins forever if input ends first.
template <typename T> void read(T &t) {
    t = 0; char ch = getchar(); int f = 1;
    while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
    do { (t *= 10) += ch - '0'; ch = getchar(); } while ('0' <= ch && ch <= '9');
    t *= f;
}
// Writes t to stdout in decimal (leading '-' for negatives).
template <typename T> void write(T t) {
    if (t < 0) { putchar('-'); write(-t); return; }
    if (t > 9) write(t / 10);
    putchar('0' + t % 10);
}
// Writes t followed by a newline.
template <typename T> void writeln(T t) { write(t); puts(""); }
const int N = 100010;
// Segment tree node: the maximum over the node's segment.
struct Node{
    int maxn;
}seg[N * 4];
int n, m;  // array length, number of operations
int a[N];  // input values, stored 0-based
// Recomputes node u from its two children.
void pushup(int u){
    seg[u].maxn = max(seg[u << 1].maxn, seg[u << 1 | 1].maxn);
}
// Builds the subtree rooted at u over a[l..r].
void build(int u, int l, int r){
    if(l == r){
        seg[u].maxn = a[l];
        return;
    }
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    pushup(u);
}
// Sets position id to x and refreshes the maxima on the way back up.
void update(int u, int l, int r, int id, int x){
    if(l == r && id == l){
        seg[u].maxn = x;
        return;
    }
    int mid = l + r >> 1;
    if(id <= mid) update(u << 1, l, mid, id, x);
    else update(u << 1 | 1, mid + 1, r, id, x);
    pushup(u);
}
// Smallest index j in [l, r] (the segment of node u) with j >= ql and value >= x,
// or -1 if there is none.
// Prunes segments lying entirely left of ql and segments whose maximum is below x, then
// tries the left child before the right one so the leftmost match wins.
int search(int u, int l, int r, int ql, int x){
    if (r < ql || seg[u].maxn < x) return -1;
    if (l == r) return l;
    int mid = l + r >> 1;
    int pos = search(u << 1, l, mid, ql, x);
    if (pos != -1) return pos;  // previously this result was dropped (fell off the end)
    return search(u << 1 | 1, mid + 1, r, ql, x);
}
int main(){
    read(n); read(m);
    for (int i = 0; i < n; ++i) read(a[i]);
    build(1, 0, n - 1);
    while (m--) {
        int op;
        read(op);
        if (op != 1) {
            int l, x;
            read(x); read(l);  // query line is "2 x l": threshold first, then start index
            writeln(search(1, 0, n - 1, l, x));
            continue;
        }
        // "1 i v": a[i] = v
        int id, x;
        read(id); read(x);
        update(1, 0, n - 1, id, x);
    }
    return 0;
}