[python刷题模板] 稀疏表(ST)
一、 算法&数据结构
1. 描述
稀疏表(SparseTable)通常用来解决区间查询问题,如RMQ(Range Max/min Query)。
区间问题通常可以用线段树来解决,但Python线段树由于空间问题经常TLE,因此掌握一个st模板还是很有必要的。
st优势:查询快(O(1)),常数小。
st劣势:
1. 初始化耗时大O(nlgn),且空间大。
2. 查询过程不支持更新。
3. 应用场景小,仅支持可重复贡献问题(max/min/gcd/or_/and_ 等)的区间查询。
- 涉及区间问题,我们通常的手段能想到:
- 线段树 - 最全,最强大,支持RUPQ\PURQ\RURQ、支持几乎所有merge操作、支持在线修改;但常数很大。空间要开4*n
- 树状数组 - 支持RUPQ和PURQ、支持在线修改、常数很小;但常规操作仅支持sum\sub,难理解,易编码(套板);骚操作可以max。
- 前缀和 - 区间和问题首先应该想到。 初始化O(n),查询O(1),常数小, 不支持修改。
- 稀疏表(ST) - 也就是本文。仅支持可重复贡献的区间查询问题,优势是常数小;不支持离线。
2. 复杂度分析
- 时间复杂度:
- 查询query, O(1)
- 预处理 O(nlgn)
- 更新update,st不支持修改,因此必须一次性获得全部原数据,询问过程不允许修改st。
- 空间复杂度:O(nlogn)
3. 常见应用
- 单点更新,区间求极值(最入门)
- 单点更新,区间求和(稍复杂)
- 区间更新,单点或区间求值,如果卡常数需要用到lazytag
4. 常用优化
- log和pow的计算可以预处理出结果,重复使用效率较高。
二、 模板代码
1. 询问区间或和+二分,对每个点找它到右边特定的或值的位置。
例题: 2411. 按位或最大的最小子数组长度
用st维护
class SparseTable:
def __init__(self, data: list, func=or_):
# 稀疏表,O(nlgn)预处理,O(1)查询区间最值/或和/gcd/lcm
# 下标从0开始
self.func = func
self.st = st = [list(data)]
i, N = 1, len(st[0])
while 2 * i <= N+1:
pre = st[-1]
st.append([func(pre[j], pre[j + i]) for j in range(N - 2 * i + 1)])
i <<= 1
def query(self, begin: int, end: int): # 查询闭区间[begin, end]的最大值
lg = (end - begin+1).bit_length() - 1
return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])
class Solution:
def smallestSubarrays(self, nums: List[int]) -> List[int]:
st = SparseTable(nums)
n = len(nums)
# print(st.query(0,n-1))
def calc(i):
mx = st.query(i, n - 1)
j = bisect_left(range(n), mx, lo=i, key=lambda j: st.query(i, j))
# print(mx, i, j)
return j - i + 1
return [calc(i) for i in range(n)]
2. 区间询问gcd+二分
链接: 题目 2744: 蓝桥杯2022年第十三届决赛真题-最大公约数(Python组)
- 这题思路是找1,因为1肯定是不用变的。
- 且1可以用一步操作让相邻的一个数变成1。
- 每一步也最多只能增加一个1。
- 当原数组含c(c>1)个1的时候,显然答案是n-c。
- 当原数组不含1的时候,我们gcd(arr),如果整个数组gcd起来都不是1,那么必没有答案;反之必有答案。
- 那么我们只需要找到一段gcd起来是1,且这段长度(k)最小,首先用k-1的代价在这段生成1个1;答案就是k+n-1。
- 遍历每个点作为区间的右端点,它左边每段的gcd满足递增,因此可以二分。这样就nlogn了。
- 我们只需要降低单次区间查询gcd的复杂度即可。查询O1的st就很合适。
- 这个版本的py二分bisect不支持key参数,因此我自己实现了my_bisect_left和right。
- 查询每个点作为左端点也是可以的,但是二分时要取负数。
import io
import os
import sys
from collections import deque
from functools import reduce, lru_cache
from math import gcd, inf
from operator import or_
if sys.hexversion == 50923504:
sys.stdin = open('input.txt')
else:
input = sys.stdin.buffer.readline
MOD = 10 ** 9 + 7
class IntervalTree:
def __init__(self, size, nums=None):
self.size = size
self.nums = nums
self.interval_tree = [0] * (size * 4)
if nums:
self.build_tree(1, 1, size)
def build_tree(self, p, l, r):
interval_tree = self.interval_tree
nums = self.nums
if l == r:
interval_tree[p] = nums[l - 1]
return
mid = (l + r) // 2
self.build_tree(p * 2, l, mid)
self.build_tree(p * 2 + 1, mid + 1, r)
interval_tree[p] = gcd(interval_tree[p * 2], interval_tree[p * 2 + 1])
@lru_cache()
def gcd_interval(self, p, l, r, x, y):
if y < l or r < x:
return 1
interval_tree = self.interval_tree
if x <= l and r <= y:
return interval_tree[p]
mid = (l + r) // 2
a = b = None
if x <= mid:
a = self.gcd_interval(p * 2, l, mid, x, y)
if mid < y:
b = self.gcd_interval(p * 2 + 1, mid + 1, r, x, y)
if not a:
return b
if not b:
return a
return gcd(a, b)
class SparseTable:
def __init__(self, data: list, func=or_):
# 稀疏表,O(nlgn)预处理,O(1)查询区间最值/或和/gcd
# 下标从0开始
self.func = func
self.st = st = [list(data)]
i, N = 1, len(st[0])
while 2 * i <= N+1:
pre = st[-1]
st.append([func(pre[j], pre[j + i]) for j in range(N - 2 * i + 1)])
i <<= 1
def query(self, begin: int, end: int): # 查询闭区间[begin, end]的最大值
lg = (end - begin+1).bit_length() - 1
# print(lg)
return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])
def my_bisect_left(a, x, lo=None, hi=None, key=None):
"""
由于3.10才能用key参数,因此自己实现一个。
:param a: 需要二分的数据
:param x: 查找的值
:param lo: 左边界
:param hi: 右边界(闭区间)
:param key: 数组a中的值会依次执行key方法,
:return: 第一个大于等于x的下标位置
"""
if not lo:
lo = 0
if not hi:
hi = len(a) - 1
else:
hi = min(hi, len(a) - 1)
size = hi - lo + 1
if not key:
key = lambda _x: _x
while size:
half = size >> 1
mid = lo + half
if key(a[mid]) < x:
lo = mid + 1
size = size - half - 1
else:
size = half
return lo
def my_bisect_right(a, x, lo=None, hi=None, key=None):
if not lo:
lo = 0
if not hi:
hi = len(a) - 1
else:
hi = min(hi, len(a) - 1)
size = hi - lo + 1
if not key:
key = lambda _x: _x
while size:
half = size >> 1
mid = lo + half
if key(a[mid]) > x:
size = half
else:
lo = mid + 1
size = size - half - 1
return lo
def solve(n, a):
if reduce(gcd, a) != 1:
return print(-1)
c = a.count(1)
if c:
return print(n - c)
# tree = IntervalTree(n, a)
st = SparseTable(a, gcd)
ans = inf
# 每个点找右边相邻区间
# for i in range(n):
# def query(k):
# return -st.query(i, k)
#
# pos = my_bisect_left(range(n), -1, lo=i + 1, hi=n-1, key=query)
# # print(pos,i)
# if pos < n and query(pos) == -1:
# ans = min(ans, pos - i)
# 每个点找左边相邻区间
for i in range(n):
def query(k):
return st.query(k, i)
pos = my_bisect_right(range(i), 1, lo=0, key=query)
# print(pos,i)
if pos and query(pos-1) == 1:
ans = min(ans, i-pos+1)
print(ans + n - 1)
if __name__ == '__main__':
n = int(input())
a = list(map(int, input().split()))
solve(n, a)
3. st模板题,区间max
链接: P3865 【模板】ST 表
- 注意,这份代码并不能通过洛谷的判题,只得了80分,报MLE,我看了一遍,没有一份py代码AC。
import sys
from collections import *
from itertools import *
from math import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from operator import or_
if sys.hexversion == 50923504:
sys.stdin = open('cfinput.txt')
RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
MOD = 10 ** 9 + 7
"""https://www.luogu.com.cn/problem/P3865
st求区间max模板题
"""
class SparseTable:
def __init__(self, data: list, func=or_):
# 稀疏表,O(nlgn)预处理,O(1)查询区间最值/或和/gcd
# 下标从0开始
self.func = func
self.st = st = [list(data)]
i, N = 1, len(st[0])
while 2 * i <= N:
pre = st[-1]
st.append([func(pre[j], pre[j + i]) for j in range(N - 2 * i + 1)])
i <<= 1
def query(self, begin: int, end: int): # 查询闭区间[begin, end]的最大值
lg = (end - begin+1).bit_length() - 1
return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])
if __name__ == '__main__':
n, m = RI()
st = SparseTable(RILST(), max)
for _ in range(m):
l, r = RILST()
print(st.query(l - 1, r - 1))