ST表
- ST表 (sparse table)即稀疏表
它可以在 O ( n l o g 2 n ) O(nlog_2 n) O(nlog2n)内预处理, O ( 1 ) O(1) O(1)内查询:
-
1. 区间最大值 2. 区间最小值 3. 区间最大公约数 4. 区间最小公倍数
满足吸收率的操作貌似都可以? (吸收率,详见离散数学代数系统的吸收率)
- ST表为离线算法,因此区间给定后,不能进行修改,否则整张表将要重新计算,时间复杂度将会变得非常高。对于带有修改的操作可以使用线段树或树状数组。
那我还学什么ST表?树状数组它不香吗? - ST表依旧被广为使用的原因是其优秀的时间复杂度,以及码量少。
其实树状数组的码量好像更少,但是权当学习倍增思想了。
题目
洛谷-> P3865 【模板】ST表
当然这道题用线段树、树状数组。
ST表的引出
我们以上面这道题为例,详细的引出ST表。
- 最好想的朴素算法
它查询哪个区间,我们直接循环遍历那个区间,求出最大值。
#include <bits/stdc++.h>
using namespace std;
#define int long long
//快读
inline int read(){
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;
}
#define read read() //我觉得写俩括号太难了
const int maxn = 1e5+9;
int a[maxn],n,m;
signed main(){
n=read,m=read;
for(int i=1;i<=n;++i)
a[i] = read;
while(m--){
int l=read,r=read;
int mn = LONG_LONG_MIN;
for(int i=l;i<=r;++i)
mn = max(mn,a[i]);
printf("%lld\n",mn);
}
return 0;
}
很显然时间复杂度是 O ( n m ) O(nm) O(nm) 的,对这题来说,必超时。
我们考虑一下能不能优化一下。考虑到它的区间是不变的,我们可以进行打表,预处理出所有区间的最大值,查询的时候可以直接查出来。
- 打表优化(动态规划)
预处理出所有区间的最大值,查询的时候直接输出。
首先我们定义一个 i 行 j 列的数组 int ans[i][j]
,用来表示区间
[
i
,
j
]
[i,j]
[i,j] 的最大值。
我们需要找一下打表的方法(状态转移方程)
- 很显然当
i
=
=
j
i==j
i==j 时
ans[i][j] = a[j]
- 由上图不难看出当
i
≠
j
i ≠ j
i=j 时,状态转移方程为
ans[i][j] = max(ans[i][j-1] , a[j]);
因此我们可以写出打表代码:
#include <bits/stdc++.h>
using namespace std;
#define int long long
//快读
inline int read(){
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;
}
#define read read() //我觉得写俩括号太难了
const int maxn = 1e5+9;
int a[maxn],n,m;
int ans[maxn][maxn];
signed main(){
n=read,m=read;
for(int i=1;i<=n;++i)
a[i] = read;
//预处理
for(int i=1;i<=n;++i)
for(int j=i;j<=n;j++){
if(i==j) ans[i][j] = a[j];
else ans[i][j] = max(ans[i][j-1],a[j]);
}
while(m--){
int l=read,r=read;
printf("%lld\n",ans[l][r]);
}
return 0;
}
对代码进行分析分析(
事实上这个代码在我电脑上根本运行不了,内存太大了),首先是空间复杂度,很显然的 O ( n 2 ) O(n^2) O(n2),其次是时间复杂度,刚开始进行了一次预处理,预处理是 O ( n 2 ) O(n^2) O(n2) 的,然后进行了 m 次询问,所以总时间复杂度为 O ( n 2 + m ) O(n^2+m) O(n2+m) ,这个复杂度比朴素算法要好一些 (因为m比n大的多),但是无论是空间还是时间都还是不能满足该题目。
我们需要继续优化,在上面的预处理中,我们可以发现,每次更新都只是将区间扩大了1个,这必然会导致非常多的重复值,我们能不能一次将区间扩大很多个,同时又能保证每个区间都能被覆盖到。于是乎考虑—倍增!千呼万唤始出来
ST表实现
首先定义一个数组 int ans[i][j]
其表示的意义不再是区间
[
i
,
j
]
[i,j]
[i,j]的最大值,而是借助了倍增思想,每次扩充
2
j
2^j
2j 个数,故表示的是从 i 开始长度为
2
j
2^j
2j 的区间 , 即
[
i
,
i
+
2
j
−
1
]
[i,i+2^j-1]
[i,i+2j−1] 这个区间的最大值。
预处理
现在我们来看一下倍增思想的预处理是什么样的。
- 很显然
ans[i][0]
表示的是区间 [ i , i + 2 0 − 1 ] [i,i+2^0-1] [i,i+20−1] 即 [ i , i ] [i,i] [i,i] , 所以ans[i][0] = a[i]
- 由于 a [ i ] a[i] a[i] 已经被用过了,所以我们可以知道 a n s [ i ] [ j ] ans[i][j] ans[i][j] 的转移方程不会再与 a [ i ] 或 a [ j ] a[i]或a[j] a[i]或a[j] 产生关系
- 我们需要找到两个已经处理完毕的区间,并且这两个小区间能够覆盖住新的更大的区间,很明显我们能够想到
ans[i][j-1]
,通过上面的图我们可以找到另一块更小的区间 a n s [ i + 2 j − 1 ] [ j − 1 ] ans[i+2^{j-1}][j-1] ans[i+2j−1][j−1], 因此可以得到状态转移方程 :ans[i][j] = max(ans[i][j-1],ans[i+(1<<(j-1))][j-1])
- 并且我们需要保证每次转移状态时 a n s [ i + 2 j − 1 ] [ j − 1 ] ans[i+2^{j-1}][j-1] ans[i+2j−1][j−1] 已经被更新过,因此我们的外层循环应该是 j ,而内层循环应该是 i,因为 i 更新的速度快。
于是我们的预处理代码为:
void proc(){
for(int i=1;i<=n;++i)
ans[i][0] = read; //ans[i][0] = a[i] 所以我们没有必要再开一个a数组,直接输入即可
for(int j=1;j<=log2(n);++j)
for(int i=1;i+(1<<j)-1<=n;++i)//i+(1<<j)-1是为了防止下次操作越界
ans[i][j] = max(ans[i][j-1],ans[i+(1<<(j-1))][j-1]);
}
此时预处理的时间复杂度为 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n)
查询
在上面我们已经预处理出了所有 [ i , i + 2 j − 1 ] [i,i+2^j-1] [i,i+2j−1]的区间。给定 l , r l,r l,r 我们怎么查询 [ l , r ] [l,r] [l,r]的最大值呢?
- 不妨以 [ 1 , 14 ] [1,14] [1,14] 来说明一下,为了查询我们需要将区间拆分成长度为 2 k 2^k 2k的小区间,不难算出 [ 1 , 14 ] = [ 1 , 8 ] ∪ [ 9 , 12 ] ∪ [ 13 , 14 ] [1,14] = [1,8] ∪ [9,12] ∪ [13,14] [1,14]=[1,8]∪[9,12]∪[13,14]
- 因此区间
[
l
,
r
]
[l,r]
[l,r] 的最大值为
max(ans[1][3],ans[9][2],ans[13][1])
- 而更一般的该如何拆分呢,考虑二进制, [ 1 , 14 ] [1,14] [1,14]区间长度为14,14的二进制为 1110,也即 14 = 2 3 + 2 2 + 2 1 14=2^3+2^2+2^1 14=23+22+21 ,这个时候你应该明白了我们的小区间是如何拆分的,事实上任何一个数 n 都能拆成形如 2 a 1 + 2 a 2 + 2 a 3 + . . . + . . . 2 a n 2^{a_1}+2^{a_2}+2^{a_3}+...+...2^{a_n} 2a1+2a2+2a3+...+...2an 的形式的
- 所以我们可以求出任何一个区间长度的形如上述的表示方法,然后求出所有小区间的最大值即为要求区间的最大值。
此时的单次查询时间复杂度为 O ( l o g 2 n ) O(log_2n) O(log2n),总的时间复杂度为 O ( n l o g 2 n + m l o g 2 n ) O(nlog_2n+mlog_2n) O(nlog2n+mlog2n),此时仍不能达到通过题目的程度。我们在开始已经说过ST表查询时间复杂度可以达到 O ( 1 ) O(1) O(1),我们需要继续优化。
事实上,我们真的有必要将区间划分为这么多的小区间来查询吗?
- 对于一段区间,拆分成这样
和拆分成这样
似乎没有什么区别,因为最大值满足吸收率,所以就算出现交集也不会影响到最终结果。所以我们根本没必要将区间划分成这么多份,而仅仅需要划分成两份即可。 - 假设区间 [ l , r ] [l,r] [l,r] 的长度为 s ,那么我们只要找到不大于 s 的最大的 2 k 2^k 2k ,然后将区间划分成为 [ l , l + 2 k − 1 ] , [ r − 2 k + 1 , r ] [l,l+2^k-1],[r-2^k+1,r] [l,l+2k−1],[r−2k+1,r] ,即可。
- 为什么非要找到不大于 s 的最大的 2 k 2^k 2k 呢,因为我们要保证 r − 2 k + 1 ≤ l + 2 k − 1 r-2^k+1≤ l+2^k-1 r−2k+1≤l+2k−1 ,即这俩区间必须能够覆盖住整个区间 [ l , r ] [l,r] [l,r]
- 而要找到不大于 s 的最大的
2
k
2^k
2k ,直接使用
int k = log2(r-l+1)
即可获得。 - 这样无论什么区间,我们只需要返回
max(ans[l][k],ans[r-(1<<k)+1][k])
即可。
查询代码:
int query(int l,int r){
int k = log2(r-l+1);
return max(ans[l][k],ans[r-(1<<k)+1][k]);
}
至此我们已经将ST表写完!读到这里,你应该已经懂得为什么ST表只能处理满足吸收率的运算,因为ST表查询的区间是重叠起来的 ! 因此不能用来查询区间和等问题。
例题总代码
#include <bits/stdc++.h>
using namespace std;
#define int long long
//快读
inline int read(){
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;
}
#define read read() //我觉得写俩括号太难了
const int maxn = 1e5+9;
int a[maxn],n,m;
int ans[maxn][30];
void proc(){
for(int i=1;i<=n;++i)
ans[i][0] = read;
for(int j=1;j<=log2(n);++j)
for(int i=1;i+(1<<j)-1<=n;++i)
ans[i][j] = max(ans[i][j-1],ans[i+(1<<(j-1))][j-1]);
}
int query(int l,int r){
int k = log2(r-l+1);
return max(ans[l][k],ans[r-(1<<k)+1][k]);
}
signed main(){
n=read,m=read;
proc();
while(m--){
int l=read,r=read;
printf("%lld\n",query(l,r));
}
return 0;
}
- 但是由于
log2()
的复杂度不明确,所以对于 m 次查询(而且m远远大于n),我们没有必要每次都计算一个 log2(r-l+1) ,我们可以预处理出 1-n 的所有 log2 值,然后直接使用即可。 - 代码如下:
#include <bits/stdc++.h>
using namespace std;
#define int long long
//快读
inline int read(){
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;
}
#define read read() //我觉得写俩括号太难了
const int maxn = 1e5+9;
int a[maxn],n,m;
int ans[maxn][30];
int Log[maxn];
void proc(){
for(int i=1;i<=n;++i)
ans[i][0] = read;
for(int j=1;j<=Log[n];++j)
for(int i=1;i+(1<<j)-1<=n;++i)
ans[i][j] = max(ans[i][j-1],ans[i+(1<<(j-1))][j-1]);
}
int query(int l,int r){
return max(ans[l][Log[r-l+1]],ans[r-(1<<Log[r-l+1])+1][Log[r-l+1]]);
}
signed main(){
n=read,m=read;
for(int i=1;i<=n;++i)
Log[i] = log2(i);//预处理出log2
proc();
while(m--){
int l=read,r=read;
printf("%lld\n",query(l,r));
}
return 0;
}
时间大约提升了200ms
- 此外还有另一种预处理 log2() 的方式,利用递推思想,避免调用内置的
log2()
函数:
递推式为:
- 代码如下:
#include <bits/stdc++.h>
using namespace std;
#define int long long
//快读
inline int read(){
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;
}
#define read read() //我觉得写俩括号太难了
const int maxn = 1e5+9;
int a[maxn],n,m;
int ans[maxn][30];
int Log[maxn];
void proc(){
for(int i=1;i<=n;++i)
ans[i][0] = read;
for(int j=1;j<=Log[n];++j)
for(int i=1;i+(1<<j)-1<=n;++i)
ans[i][j] = max(ans[i][j-1],ans[i+(1<<(j-1))][j-1]);
}
int query(int l,int r){
return max(ans[l][Log[r-l+1]],ans[r-(1<<Log[r-l+1])+1][Log[r-l+1]]);
}
signed main(){
n=read,m=read;
for(int i=2;i<=n;++i)
Log[i] = Log[i/2]+1;//递推式
proc();
while(m--){
int l=read,r=read;
printf("%lld\n",query(l,r));
}
return 0;
}
时间相对于 log2 的预处理提升了大约40ms,貌似意义不太大。
额外经验
- 求区间最小值,只需把上面的max换成min即可
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define putlen putchar('\n')
//快读
inline int read(){
int X=0; bool flag=1; char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') flag=0; ch=getchar();}
while(ch>='0'&&ch<='9') {X=(X<<1)+(X<<3)+ch-'0'; ch=getchar();}
if(flag) return X;
return ~(X-1);
}
#define read read()
//快输
inline void print(int x){
if(x<0){putchar('-');x=-x;}
if(x>9) print(x/10);
putchar(x%10+'0');
}
int st[1000006][40];
int n,m;
int query(int l,int r){
int k = log2(r-l+1);
return min(st[l][k],st[r-(1<<k)+1][k]);
}
signed main(){
n=read,m=read;
for(int i=1;i<=n;++i) st[i][0] = read;
for(int j=1;j<=log2(n);++j)
for(int i=1;i<=n-(1<<j)+1;++i)
st[i][j] = min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
for(int i=1;i<=n-m+1;++i){
print(query(i,m+i-1));
putlen;
}
return 0;
}
P2880 [USACO07JAN]Balanced Lineup G
- 求区间最大值与最小值之差,建两个表就行了,一个维护最大,一个维护最小
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define putlen putchar('\n')
//快读
inline int read(){
int X=0; bool flag=1; char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-') flag=0; ch=getchar();}
while(ch>='0'&&ch<='9') {X=(X<<1)+(X<<3)+ch-'0'; ch=getchar();}
if(flag) return X;
return ~(X-1);
}
#define read read()
//快输
inline void print(int x){
if(x<0){putchar('-');x=-x;}
if(x>9) print(x/10);
putchar(x%10+'0');
}
int stMAX[50004][40];
int stMIN[50004][40];
int n,m;
int query(int l,int r){
int k = log2(r-l+1);
return max(stMAX[l][k],stMAX[r-(1<<k)+1][k]) - min(stMIN[l][k],stMIN[r-(1<<k)+1][k]);
}
signed main(){
n=read,m=read;
for(int i=1;i<=n;++i) stMAX[i][0] = stMIN[i][0] = read;
for(int j=1;j<=log2(n);++j)
for(int i=1;i<=n-(1<<j)+1;++i){
stMIN[i][j] = min(stMIN[i][j-1],stMIN[i+(1<<(j-1))][j-1]);
stMAX[i][j] = max(stMAX[i][j-1],stMAX[i+(1<<(j-1))][j-1]);
}
for(int i=1;i<=m;++i){
int l=read,r=read;
print(query(l,r));
putlen;
}
return 0;
}