本蒟蒻的blog:https://startcraft.cn
ST表概述
ST是解决区间RMQ(区间最值)问题的一种数据结构,它不支持在线修改,预处理\(O(nlogn)\),查询\(O(1)\)
实现
以区间最小值为例子
预处理
\(ans[i][j]\)表示区间\([i,i+2^{j}-1]\)的最小值,同时它可以表示成前半个区间的最小值和后半个区间的最小值
前半个区间是\([i,i+2^{j-1}-1]\), 可以表示成\(ans[i][j-1]\),那么后半个区间是\([i+2^{j-1}, i+2^{j}-1]\), 可以表示成\(ans[i+2^{j-1}][j-1]\)
即
a
n
s
[
i
]
[
j
]
=
m
i
n
(
a
n
s
[
i
]
[
j
−
1
]
,
a
n
s
[
i
+
2
j
−
1
]
[
j
−
1
]
)
ans[i][j]=min(ans[i][j-1],ans[i+2^{j-1}][j-1])
ans[i][j]=min(ans[i][j−1],ans[i+2j−1][j−1])
所以初始化的时间复杂度是\(O(nlogn)\)
实现代码
int en = log2(n);
wfor(i, 0, 31)
{
bit[i] = 1 << i;//表示2^i
}
for(i=1;i<n + 1;i++)//初始化区间长度为1的
{
minnum[i][0] = num[i];
}
for(j=1;j<en + 2;j++)
{
for(i=1;i<n + 1;i++)
{
if (i + bit[j - 1] <= n)
{
minnum[i][j] = min(minnum[i][j - 1], minnum[i + bit[j - 1]][j - 1]);
}
}
}
查询
假设我们要查询的区间为\([l,r]\),那么区间长度为\(len=r-l+1\)
设\(t=\log(len)\)
我们知道\(2^{\log(a)}>a/2\)
所以\(2^{t}>len/2\)
所以区间\([l,l+2^{t}-1]\) 大于等于所求区间的一半,同理从\(r\)往前数\(2^{t}\)的区间也大于等于所求区间长度的一半,那么我们只要求出这两个区间最小值的最小值就是所求区间的最小值
这两个区间我们前面已经预处理出来了,从\(r\)往前数\(2^{t}\) 是\(r-2^{t}+1\).
所以所求区间的最小值为\(min(ans[l][t]),ans[r-2^{t}+1][t]\)
实现代码
int t = log2(len);
int ans = min(minnum[l][t], minnum[r - bit[t] + 1][t]);
例题 POJ-3264
AC代码
#include <iostream>
#include <cmath>
#include <cstdio>
using namespace std;
typedef long long ll;
#define wfor(i,j,k) for(i=j;i<k;++i)
#define mfor(i,j,k) for(i=j;i>=k;--i)
// void read(int &x) {
// char ch = getchar(); x = 0;
// for (; ch < '0' || ch > '9'; ch = getchar());
// for (; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + ch - '0';
// }
int bit[32];
void init()
{
int i;
wfor(i, 0, 31)
{
bit[i] = 1 << i;
}
}
const int maxn = 5e4 + 5;
int num[maxn];
int maxnum[maxn][30];
int minnum[maxn][30];
int main()
{
int n, q;
// cin >> n >> q;
scanf("%d%d", &n, &q);
int i;
init();
wfor(i, 1, n + 1)
{
// cin >> num[i];
scanf("%d", &num[i]);
}
int j;
int en = log2(n);
wfor(i, 1, n + 1)
{
maxnum[i][0] = num[i];
minnum[i][0] = num[i];
}
wfor(j, 1, en + 2)
{
wfor(i, 1, n + 1)
{
if (i + bit[j - 1] <= n)
{
maxnum[i][j] = max(maxnum[i][j - 1], maxnum[i + bit[j - 1]][j - 1]);
minnum[i][j] = min(minnum[i][j - 1], minnum[i + bit[j - 1]][j - 1]);
}
}
}
wfor(i, 0, q)
{
int l, r;
// cin >> l >> r;
scanf("%d%d", &l, &r);
int len = r - l + 1;
int t = log2(len);
int x = max(maxnum[l][t], maxnum[r - bit[t] + 1][t]);
int y = min(minnum[l][t], minnum[r - bit[t] + 1][t]);
// cout << x - y << endl;
printf("%d\n", x - y);
}
return 0;
}