【简概】
ST 算法在 RMQ(区间最值)中用来求得一个区间的最值,但却不能维护最值。
也就是说,过程中不能改变区间中的某个元素的值。(不能修改,快速查询)
O(nlogn) 的预处理和 O(1) 的查询对于需要大量询问的场景是非常适用的。
比如有如下长度为 10 的数组 :
1 3 2 4 9 5 6 7 8 0
我们要查询 [1, 7] 之间的最大值,如果采用朴素的线性查找,复杂度O(n),
而 ST 算法却只需要 O(1)的时间复杂度,因为 ST 算法预处理了一个 dp 数组。
我们 用 dp[i][j] 表示从 i 开始的 2^j 个数的最值,表示 dp[i][j] “管辖” index=i 开始的 2^j 个数字,
那么很显然,任何一段区间都能被两个 dp 元素管辖到。
比如上面说的 [1, 7],就能被dp[1][2] 和 dp[4][2]管辖到,max(dp[1][2], dp[4][2])也就是[1, 7] 的最值了。
如何得出是 dp[1][2] 和 dp[4][2] 这两个元素?
让dp[1][n](2^n <= 区间个数)中的n尽可能大就得到了第一个元素,
可以推得第二个元素,两个元素的管辖范围大小是一样的。
这样我们只需预处理一个 dp 数组就可以了,而这个预处理是一个动态规划的过程,转移方程为:
dp[i][j] = max(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]);
而 dp 数组的预处理和 RMQ 的求解过程正好是个逆过程。
【作用】用来求解给定区间RMQ的最值,以最小值为例。
【举例】给出一数组A[0~5] = {5,4,6,10,1,12},则区间[2,5]之间的最值为1。
【方法】ST算法分成两部分:离线预处理 (nlogn)和 在线查询(O(1))。
虽然还可以使用线段树、树状链表等求解区间最值,但是ST算法要比它们更快,而且适用于在线查询。
(1)离线预处理:运用DP思想,用于求解区间最值,并保存到一个二维数组中。
ST算法使用DP思想求解区间最值,貌似属于区间动态规划,
不过区间在增加时,使用倍增的思想,每次增加2^i个长度。
使用F[i,j]表示以i为起点,区间长度为2^j的区间最值,此时区间为[i,i + 2^j - 1]。
比如,F[0,2]表示区间[0,3]的最小值,即等于4,F[2,2]表示区间[2,5]的最小值,即等于1。
在求解F[i,j]时,ST算法是先对长度为2^j的区间[i,i + 2^j - 1]分成两等份,每份长度均为2^(j - 1)。
之后在分别求解这两个区间的最值F[i,j - 1]和F[i + 2^(j - 1),j - 1]。
最后在结合这两个区间的最值,求出整个区间的最值。
特殊情况,当j = 0时,区间长度等于0,即区间中只有一个元素,此时F[i,0]应等于每一个元素的值。
举例:要求解F[1,2]的值,即求解区间[1,4] = {4,6,10,1}的最小值。此时需要把这个区间分成两个等长的区间,
即为[1,2]和[3,4],之后分别求解这两个区间的最小值。此时这两个区间最小值分别对应着F[1,1] 和 F[3,1]的值。
状态转移方程是 F[i,j] = min(F[i,j - 1],F[i + 2^(j - 1),j - 1])
初始状态为:F[i,0] = A[i]。
在根据状态转移方程递推时,是对每一元素,先求区间长度为1的区间最值,
之后再求区间长度为2的区间最值,之后再求区间长度为4的区间最值....,
最后,对每一个元素,在求解区间长度为log2^n的区间最值后,算法结束,其中n表示元素个数。
即:先求F[0][1],F[1][1],F[2][1],F[3][1],F[n][1],再求F[0][2],F[1][2],F[2][2],F[3][2],F[m][2] 。
【简便理解】
(2)在线处理:已知待查询的区间[x,y],求解其最值。
在预处理期间,每一个状态对应的区间长度都为2^i。
由于给出的待查询区间长度不一定恰好为2^i,因此我们应对待查询的区间进行处理。
这里我们把待查询的区间分成两个小区间,这两个小区间满足两个条件:
(1)这两个小区间要能覆盖整个区间。
(2)为了利用预处理的结果,要求小区间长度相等且都为2^i。注意两个小区间可能重叠。
如:待查询的区间为[3,11],先尽量等分两个区间,则先设置为[3,7]和[8,11]。
之后再扩大这两个区间,让其长度都等于为2^i。刚划分的两个区间长度分别为5和4,
之后继续增加区间长度,直到其成为2^i。此时满足两个条件的最小区间长度为8,此时i = 3。
在程序计算求解区间长度时,并没有那么麻烦,我们可以直接得到i,
即等于直接对区间长度取以2为底的对数。这里,对于区间[3,11],
其分解的区间长度为int(log(11 - 3 + 1)) = 3,这里log是以2为底的。
根据上述思想,可以把待查询区间[x,y]分成两个小区间 [x,x + 2^i - 1] 和 [y - 2^i + 1,y] ,
其又分别对应着F[x,i]和F[y - 2^i + 1,i],此时为了求解整个区间的最小值,
我们只需求这两个值得最小值即可,此时复杂度是O(1)。
#include <iostream>
#include <math.h>
using namespace std;
/*方程
F[i,j]: 区间[i,i + 2^j - 1]的最小值,此时区间长度为2^j
F[i,j] = min(F[i,j - 1],F[i + 2^(j - 1),j - 1])
F[i,0] = nArr[i]; */
int F[1000000][20];//待比较元素的个数最大为1百万
void SparseTable(int nArr[],int nLen) { //初始化
for (int i = 0;i < nLen;i++)
F[i][0] = nArr[i]; //递推
int nLog = int(log(double(nLen))/log(2.0));
for (int j = 1;j <= nLog;j++)
for (int i = 0;i < nLen;i++)
if ((i + (1 << j) - 1) < nLen)//区间的端点不能超过数组最后一位的下标
F[i][j] = min(F[i][j - 1],F[i + (1 << (j - 1))][j - 1]);
}
int RMQ(int nArr[],int nLen,int nStart,int nEnd) {
int nLog = (int)(log(double(nEnd - nStart + 1)/log(2.0)));
return min(F[nStart][nLog],F[nEnd - (1 << nLog) + 1][nLog]);
}
int main() {
int nArr[6] = {5,4,6,10,1,12};
SparseTable(nArr,6);
cout<<RMQ(nArr,6,0,5)<<endl; //查询具体区间
cout<<RMQ(nArr,6,1,3)<<endl;
cout<<RMQ(nArr,6,2,5)<<endl;
cout<<RMQ(nArr,6,2,2)<<endl;
system("pause");
return 1;
}
【例题1】数列区间的最大值
- m个询问,每次询问区间x到y的最大数。
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=1e6+5,logn=20; //10^6+5
int log[maxn],f[maxn][logn+5],a[maxn];
int n,m,x,y;
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
log[0]=-1; //这样才能使log[1]=0
for(int i=1;i<=n;i++) //预处理出长度1~n的log值
f[i][0]=a[i],log[i]=log[i>>1]+1;
for(int j=1;j<=logn;j++=)
for(int i=1;i+(1<<j)-1<=n;i++) //注意区间边界不能超过n
f[i][j]=max(f[i][j-1],f[i+(1<<j-1)][j-1]);
//注意,加减乘除运算符优先级高于<<
while(m--){
scanf("%d%d",&x,&y);
int s=log[y-x+1]; //log2(y-x-1)向下取整得到的值
printf("%d\n",max(f[x][s],f[y-(1<<s)+1][s]));
}
return 0;
}
【例题2】滑动窗口最值
- 共n个数,把每连续k个数中最大最小值求出来。
【分析】可以用单调队列做(bfs维护队列首尾),此处采用倍增法。
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=1e5*10,logn=17; //10^6+5
int log[maxn],f[maxn][logn+5],z[maxn][logn+5],a[maxn];
int n,k;
int main(){
scanf("%d%d",&n,&k); log[0]=-1; //这样才能使log[1]=0
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<=n;i++){ //预处理出长度1~n的log值
f[i][0]=z[i][0]=a[i];
log[i]=log[i>>1]+1;
}
for(int j=1;j<=logn;j++=)
for(int i=1;i+(1<<j)-1<=n;i++) //注意区间边界不能超过n
f[i][j]=max(f[i][j-1],f[i+(1<<j-1)][j-1]);
z[i][j]=min(z[i][j-1],z[i+(1<<j-1)][j-1]);
//注意,加减乘除运算符优先级高于<<
n=n-k+1; int x,y,s;
for(int i=1;i<=n;i++){
x=i; y=i+k-1; s=log[y-x+1];
printf("%d %d\n",max(f[x][s],f[y-(1<<s)+1][s],
min(z[x][s],z[y-(1<<s)+1][s]);
}
return 0;
}
——时间划过风的轨迹,那个少年,还在等你。