[python刷题模板] 稀疏表(ST)

一、 算法&数据结构

1. 描述

稀疏表(SparseTable)通常用来解决区间查询问题,如RMQ(Range Max/min Query)。
区间问题通常可以用线段树来解决,但Python线段树由于空间问题经常TLE,因此掌握一个st模板还是很有必要的。
st优势:查询快(O(1)),常数小。
st劣势: 		
				1. 初始化耗时大O(nlgn),且空间大。
				2. 查询过程不支持更新。
				3. 应用场景小,仅支持可重复贡献问题(max/min/gcd/or_/and_ 等)的区间查询。
  • 涉及区间问题,我们通常的手段能想到:
    1. 线段树 - 最全,最强大,支持RUPQ\PURQ\RURQ、支持几乎所有merge操作、支持在线修改;但常数很大。空间要开4*n
    2. 树状数组 - 支持RUPQ和PURQ、支持在线修改、常数很小;但常规操作仅支持sum\sub,难理解,易编码(套板);骚操作可以max。
    3. 前缀和 - 区间和问题首先应该想到。 初始化O(n),查询O(1),常数小, 不支持修改。
    4. 稀疏表(ST) - 也就是本文。仅支持可重复贡献的区间查询问题,优势是常数小;不支持离线。

2. 复杂度分析

  • 时间复杂度:
  1. 查询query, O(1)
  2. 预处理 O(nlgn)
  3. 更新update,st不支持修改,因此必须一次性获得全部原数据,询问过程不允许修改st。
  • 空间复杂度:O(nlogn)

3. 常见应用

  1. 单点更新,区间求极值(最入门)
  2. 单点更新,区间求和(稍复杂)
  3. 区间更新,单点或区间求值,如果卡常数需要用到lazytag

4. 常用优化

  1. log和pow的计算可以预处理出结果,重复使用效率较高。

二、 模板代码

1. 询问区间或和+二分,对每个点找它到右边特定的或值的位置。

例题: 2411. 按位或最大的最小子数组长度
用st维护

class SparseTable:
    def __init__(self, data: list, func=or_):
        # 稀疏表,O(nlgn)预处理,O(1)查询区间最值/或和/gcd/lcm
        # 下标从0开始
        self.func = func
        self.st = st = [list(data)]
        i, N = 1, len(st[0])
        while 2 * i <= N+1:
            pre = st[-1]
            st.append([func(pre[j], pre[j + i]) for j in range(N - 2 * i + 1)])
            i <<= 1

    def query(self, begin: int, end: int):  # 查询闭区间[begin, end]的最大值
        lg = (end - begin+1).bit_length() - 1
        return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])

class Solution:
    def smallestSubarrays(self, nums: List[int]) -> List[int]:
        st = SparseTable(nums)
        n = len(nums)
        # print(st.query(0,n-1))
        def calc(i):
            mx = st.query(i, n - 1)
            j = bisect_left(range(n), mx, lo=i, key=lambda j: st.query(i, j))
            # print(mx, i, j)
            return j - i + 1
        
        return [calc(i) for i in range(n)]

2. 区间询问gcd+二分

链接: 题目 2744: 蓝桥杯2022年第十三届决赛真题-最大公约数(Python组)

在这里插入图片描述

  • 这题思路是找1,因为1肯定是不用变的。
  • 且1可以用一步操作让相邻的一个数变成1。
  • 每一步也最多只能增加一个1。
  • 当原数组含c(c>1)个1的时候,显然答案是n-c。
  • 当原数组不含1的时候,我们gcd(arr),如果整个数组gcd起来都不是1,那么必没有答案;反之必有答案。
  • 那么我们只需要找到一段gcd起来是1,且这段长度(k)最小,首先用k-1的代价在这段生成1个1;答案就是k+n-1。
  • 遍历每个点作为区间的右端点,它左边每段的gcd满足递增,因此可以二分。这样就nlogn了。
  • 我们只需要降低单次区间查询gcd的复杂度即可。查询O1的st就很合适。
  • 这个版本的py二分bisect不支持key参数,因此我自己实现了my_bisect_left和right。

  • 查询每个点作为左端点也是可以的,但是二分时要取负数。
import io
import os
import sys
from collections import deque
from functools import reduce, lru_cache
from math import gcd, inf
from operator import or_

if sys.hexversion == 50923504:
    sys.stdin = open('input.txt')
else:
    input = sys.stdin.buffer.readline
MOD = 10 ** 9 + 7


class IntervalTree:
    def __init__(self, size, nums=None):
        self.size = size
        self.nums = nums
        self.interval_tree = [0] * (size * 4)
        if nums:
            self.build_tree(1, 1, size)

    def build_tree(self, p, l, r):
        interval_tree = self.interval_tree
        nums = self.nums
        if l == r:
            interval_tree[p] = nums[l - 1]
            return
        mid = (l + r) // 2
        self.build_tree(p * 2, l, mid)
        self.build_tree(p * 2 + 1, mid + 1, r)
        interval_tree[p] = gcd(interval_tree[p * 2], interval_tree[p * 2 + 1])

    @lru_cache()
    def gcd_interval(self, p, l, r, x, y):
        if y < l or r < x:
            return 1
        interval_tree = self.interval_tree
        if x <= l and r <= y:
            return interval_tree[p]
        mid = (l + r) // 2
        a = b = None
        if x <= mid:
            a = self.gcd_interval(p * 2, l, mid, x, y)
        if mid < y:
            b = self.gcd_interval(p * 2 + 1, mid + 1, r, x, y)
        if not a:
            return b
        if not b:
            return a
        return gcd(a, b)


class SparseTable:
    def __init__(self, data: list, func=or_):
        # 稀疏表,O(nlgn)预处理,O(1)查询区间最值/或和/gcd
        # 下标从0开始
        self.func = func
        self.st = st = [list(data)]
        i, N = 1, len(st[0])
        while 2 * i <= N+1:
            pre = st[-1]
            st.append([func(pre[j], pre[j + i]) for j in range(N - 2 * i + 1)])
            i <<= 1

    def query(self, begin: int, end: int):  # 查询闭区间[begin, end]的最大值
        lg = (end - begin+1).bit_length() - 1
        # print(lg)
        return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])




def my_bisect_left(a, x, lo=None, hi=None, key=None):
    """
    由于3.10才能用key参数,因此自己实现一个。
    :param a: 需要二分的数据
    :param x: 查找的值
    :param lo: 左边界
    :param hi: 右边界(闭区间)
    :param key: 数组a中的值会依次执行key方法,
    :return: 第一个大于等于x的下标位置
    """
    if not lo:
        lo = 0
    if not hi:
        hi = len(a) - 1
    else:
        hi = min(hi, len(a) - 1)
    size = hi - lo + 1

    if not key:
        key = lambda _x: _x
    while size:
        half = size >> 1
        mid = lo + half
        if key(a[mid]) < x:
            lo = mid + 1
            size = size - half - 1
        else:
            size = half
    return lo

def my_bisect_right(a, x, lo=None, hi=None, key=None):
    if not lo:
        lo = 0
    if not hi:
        hi = len(a) - 1
    else:
        hi = min(hi, len(a) - 1)
    size = hi - lo + 1

    if not key:
        key = lambda _x: _x
    while size:
        half = size >> 1
        mid = lo + half
        if key(a[mid]) > x:
            size = half
        else:
            lo = mid + 1
            size = size - half - 1
    return lo
def solve(n, a):
    if reduce(gcd, a) != 1:
        return print(-1)
    c = a.count(1)
    if c:
        return print(n - c)
    # tree = IntervalTree(n, a)
    st = SparseTable(a, gcd)
    ans = inf

    # 每个点找右边相邻区间
    # for i in range(n):
        # def query(k):
        #     return -st.query(i, k)
        #
        # pos = my_bisect_left(range(n), -1, lo=i + 1, hi=n-1, key=query)
        # # print(pos,i)
        # if pos < n and query(pos) == -1:
        #     ans = min(ans, pos - i)

    # 每个点找左边相邻区间
    for i in range(n):
        def query(k):
            return st.query(k, i)

        pos = my_bisect_right(range(i), 1, lo=0, key=query)
        # print(pos,i)
        if pos and query(pos-1) == 1:
            ans = min(ans, i-pos+1)

    print(ans + n - 1)

if __name__ == '__main__':
    n = int(input())
    a = list(map(int, input().split()))
    solve(n, a)


3. st模板题,区间max

链接: P3865 【模板】ST 表

  • 注意,这份代码并不能通过洛谷的判题,只得了80分,报MLE,我看了一遍,没有一份py代码AC。
import sys
from collections import *
from itertools import *
from math import *
from array import *
from functools import lru_cache
import heapq
import bisect
import random
import io, os
from operator import or_

if sys.hexversion == 50923504:
    sys.stdin = open('cfinput.txt')

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())

MOD = 10 ** 9 + 7
"""https://www.luogu.com.cn/problem/P3865
st求区间max模板题
"""


class SparseTable:
    def __init__(self, data: list, func=or_):
        # 稀疏表,O(nlgn)预处理,O(1)查询区间最值/或和/gcd
        # 下标从0开始
        self.func = func
        self.st = st = [list(data)]
        i, N = 1, len(st[0])
        while 2 * i <= N:
            pre = st[-1]
            st.append([func(pre[j], pre[j + i]) for j in range(N - 2 * i + 1)])
            i <<= 1

    def query(self, begin: int, end: int):  # 查询闭区间[begin, end]的最大值
        lg = (end - begin+1).bit_length() - 1
        return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])



if __name__ == '__main__':
    n, m = RI()
    st = SparseTable(RILST(), max)

    for _ in range(m):
        l, r = RILST()
        print(st.query(l - 1, r - 1))



三、其他

四、更多例题

五、参考链接

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值