线段树-python模板和例题

线段树

线段是处理区间问题的高级数据结构,同时也是分治和递归思想的一种实践。
常见的是求区间最小值,最大值,区间和。单点修改,区间修改。

题目列表

Box of Box

给定一组盒子,盒子的长宽高分别为h, w, d。判断是否存在一个盒子可以完全装下另一个盒子。即存在i和j, hi<hj,wi<wj,di<dj。数据范围1<=h<=1e9。
思路,根据数据范围和只需要判断大小关系,所以可以离散化。由于每个盒子可以旋转,直接对h,w,d排序,这一点不会证明。排序并且分组。把相同高度的盒子分为一组。然后从左往右按组,一组一组地遍历w和d。对于w,我们需要查询[1, w)区间内的最小值是否小于d。如果小于d,则Yes。这一组遍历完成后,把该组的w和h,更新到区间内。
固定一维,维护一维,查询一维。解决三维问题。

Mex and Update

给定一个数组a,数组长度1e5,0<=a[i]<=1e9,求数组的Mex。Mex是数组中未出现的最小自然数。操作一:修改a[i]的值为x。操作二:查询数组的Mex。
显然,数组Mex的值不会大于数组长度n。因此,我们需要维护构建数组b,统计0-n-1的出现次数,找到一个出现零的位置。
转化线段树,就是求最大的右边界,使得区间[left, right]中的最小值大于0。线段树可以在log(n)的时间复杂度查询区间最小值。左边界left=0,因为单调性的原因,我们可以通过二分法来确定右边界,因此算法的总体复杂度是nlognlogn。

代码细节:
初始化线段树的几个参数。

  • n 表示数组的长度
  • op 表示区间的某种操作,类似max, min, sum
  • e 幺元。线段树维护值的幺元函数。对于任意x,op(x, e) = x
  • mapping 父结点的懒标记更新子结点的值的函数
  • composition 父结点的懒标记更新(结合)子结点的懒标记的函数
  • id 懒标记的幺元函数。
    类中的几个函数:
  • 单点修改 set(i, x)
  • 查询最大右边界 max_right(left, func)。第一个参数是左边界,第二个参数是一个函数,该函数完成对区间值的布尔判断。
  • 区间查询prod,左闭右开[L, R)。即op(L, L+1, , R-1)的结果。

线段树的实现细节:懒标记。空间复杂度是4n还是2n。从下到上更新 VS 从上到下更新。

class LazySegmentTree:
    """
    Reference
    https://github.com/atcoder/ac-library/blob/master/atcoder/lazysegtree.hpp
    https://github.com/atcoder/ac-library/blob/master/document_en/lazysegtree.md
    https://github.com/atcoder/ac-library/blob/master/document_ja/lazysegtree.md
    https://leetcode.cn/circle/discuss/4rJDBt/

    """

    def __init__(self, n, op, e, mapping, composition, id):
        self.n = n
        self.op = op
        self.e = e
        self.mapping = mapping
        self.composition = composition
        self.id = id
        self.log = (n - 1).bit_length()
        self.size = 1 << self.log
        self.d = [e] * (2 * self.size)
        self.lz = [id] * self.size

    def update(self, k):
        self.d[k] = self.op(self.d[2 * k], self.d[2 * k + 1])

    def all_apply(self, k, f):
        self.d[k] = self.mapping(f, self.d[k])
        if k < self.size:
            self.lz[k] = self.composition(f, self.lz[k])

    def push(self, k):
        self.all_apply(2 * k, self.lz[k])
        self.all_apply(2 * k + 1, self.lz[k])
        self.lz[k] = self.id

    def build(self, v):
        assert len(v) <= self.n
        for i in range(len(v)):
            self.d[self.size + i] = v[i]
        for i in range(self.size - 1, 0, -1):
            self.update(i)

    def set(self, p, x):
        assert 0 <= p < self.n
        p += self.size
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        self.d[p] = x
        for i in range(1, self.log + 1):
            self.update(p >> i)

    def get(self, p):
        assert 0 <= p < self.n
        p += self.size
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        return self.d[p]

    def prod(self, l, r):
        assert 0 <= l <= r <= self.n
        if l == r:
            return self.e
        l += self.size
        r += self.size
        for i in range(self.log, 0, -1):
            if (l >> i) << i != l:
                self.push(l >> i)
            if (r >> i) << i != r:
                self.push((r - 1) >> i)
        sml = smr = self.e
        while l < r:
            if l & 1:
                sml = self.op(sml, self.d[l])
                l += 1
            if r & 1:
                r -= 1
                smr = self.op(self.d[r], smr)
            l >>= 1
            r >>= 1
        return self.op(sml, smr)

    def all_prod(self):
        return self.d[1]

    def apply(self, l, r, f):
        assert 0 <= l <= r <= self.n
        if l == r:
            return
        l += self.size
        r += self.size
        for i in range(self.log, 0, -1):
            if (l >> i) << i != l:
                self.push(l >> i)
            if (r >> i) << i != r:
                self.push((r - 1) >> i)
        l2 = l
        r2 = r
        while l < r:
            if l & 1:
                self.all_apply(l, f)
                l += 1
            if r & 1:
                r -= 1
                self.all_apply(r, f)
            l >>= 1
            r >>= 1
        l = l2
        r = r2
        for i in range(1, self.log + 1):
            if (l >> i) << i != l:
                self.update(l >> i)
            if (r >> i) << i != r:
                self.update((r - 1) >> i)

    def max_right(self, l, g):
        assert 0 <= l <= self.n
        # assert g(self.e)
        if l == self.n:
            return self.n
        l += self.size
        for i in range(self.log, 0, -1):
            self.push(l >> i)
        sm = self.e
        while True:
            while l % 2 == 0:
                l >>= 1
            if not g(self.op(sm, self.d[l])):
                while l < self.size:
                    self.push(l)
                    l = 2 * l
                    if g(self.op(sm, self.d[l])):
                        sm = self.op(sm, self.d[l])
                        l += 1
                return l - self.size
            sm = self.op(sm, self.d[l])
            l += 1
            if (l & -l) == l:
                return self.n

    def min_left(self, r, g):
        assert 0 <= r <= self.n
        assert g(self.e)
        if r == 0:
            return 0
        r += self.size
        for i in range(self.log, 0, -1):
            self.push((r - 1) >> i)
        sm = self.e
        while True:
            r -= 1
            while r > 1 and r % 2:
                r >>= 1
            if not g(self.op(self.d[r], sm)):
                while r < self.size:
                    self.push(r)
                    r = 2 * r + 1
                    if g(self.op(self.d[r], sm)):
                        sm = self.op(self.d[r], sm)
                        r -= 1
                return r + 1 - self.size
            sm = self.op(self.d[r], sm)
            if (r & -r) == r:
                return 0


import sys
from collections import Counter
from heapq import heappop, heappush

sys.stdin = open('./../input.txt', 'r')
I = lambda: int(input())
MI = lambda: map(int, input().split())
GMI = lambda: map(lambda x: int(x) - 1, input().split())
LI = lambda: list(MI())
LGMI = lambda: list(GMI())
mod = 1000000007
mod2 = 998244353

n, q = MI()
a = LI()

N = n + 1
lst = LazySegmentTree(N, min, N, min, min, N)
b = [0] * N
for i, x in enumerate(a):
    if x < n:
        b[x] += 1

# print(b)
for i, x in enumerate(b):
    lst.set(i, x)

# for i in range(N):
#     print(lst.get(i), end=' ')
# print()

for _ in range(q):
    i, x = MI()
    i -= 1
    if a[i] < n:
        lst.set(a[i], lst.get(a[i]) - 1)
    a[i] = x
    if a[i] < n:
        lst.set(a[i], lst.get(a[i]) + 1)

    # for i in range(N):
    #     print(lst.get(i), end=' ')
    # print()

    j = lst.max_right(0, lambda x: x > 0)
    print(j)

小白逛公园

给定一个数组a,-1000<a[i]<1000。操作一:修改a[i]值为x。操作二:查询区间[a,b]的最大连续子数组。

这一题算是,对线段树操作的高阶应用。
思路分析:
对于每个区间,我们需要维护区间最大值子数组,区间前缀和最大数组,区间后缀和最大子数组,区间和。
当单点修改的时,我们需要从下往上依次更新该点所在的区间的四个值。
当区间查询时,依次查询区间里四个值的最大值即可。
pypy3的代码是参考是洛谷大佬qishui的(侵删)。

# Problem: P4513 小白逛公园
# Contest: Luogu
# URL: https://www.luogu.com.cn/problem/P4513
# Memory Limit: 128 MB
# Time Limit: 1000 ms

import sys
from math import inf

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda *x: sys.stderr.write(f'{str(x)}\n')
print = lambda d: sys.stdout.write(
    str(d) + "\n")  # 打开可以快写,但是无法使用print(*ans,sep=' ')这种语法,需要print(' '.join(map(str, p))),确实会快。

PROBLEM = """
"""


class IntervalTree:
    """区间加,区间求和"""

    def __init__(self, size, a=None):
        self.size = size
        self.ms = [0 for _ in range(size * 4)]
        self.s = [0 for _ in range(size * 4)]
        self.ls = [0 for _ in range(size * 4)]
        self.rs = [0 for _ in range(size * 4)]
        if a:
            self.a = a
            self.build(1, 1, size)

    def update_by_son(self, p, l, r):
        ms, s, ls, rs = self.ms, self.s, self.ls, self.rs
        x, y = p << 1, p << 1 | 1
        ms[p] = max(ms[x], ms[y], ls[y] + rs[x])
        s[p] = s[x] + s[y]
        ls[p] = max(s[x] + ls[y], ls[x])
        rs[p] = max(s[y] + rs[x], rs[y])

    def build(self, p, l, r):
        ms, s, ls, rs = self.ms, self.s, self.ls, self.rs
        if l == r:
            ms[p] = s[p] = ls[p] = rs[p] = self.a[l - 1]
            return
        mid = (l + r) >> 1
        self.build(p << 1, l, mid)
        self.build(p << 1 | 1, mid + 1, r)
        self.update_by_son(p, l, r)

    def update_point(self, p, l, r, x, val):
        """
        把x位置变成val
        """
        if x < l or r < x:
            return
        ms, s, ls, rs = self.ms, self.s, self.ls, self.rs
        if l == r:
            ms[p] = s[p] = ls[p] = rs[p] = val
            return

        mid = (l + r) // 2

        if x <= mid:
            self.update_point(p << 1, l, mid, x, val)
        if mid < x:
            self.update_point(p << 1 | 1, mid + 1, r, x, val)
        self.update_by_son(p, l, r)

    def query_interval(self, p, l, r, x, y):
        """
        查找x,y区间的最大子段和        """

        # if y < l or r < x:
        #     return -inf, 0, -inf, -inf
        ms, s, ls, rs = self.ms, self.s, self.ls, self.rs
        if x <= l and r <= y:
            return ms[p], s[p], ls[p], rs[p]

        mid = (l + r) >> 1
        a1, a2, a3, a4 = -inf, 0, -inf, -inf
        b1, b2, b3, b4 = -inf, 0, -inf, -inf
        # if x <= mid < y:
        #     a1, a2, a3, a4 = self.query_interval(p << 1, l, mid, x, y)
        #     b1, b2, b3, b4 = self.query_interval(p << 1 | 1, mid + 1, r, x, y)
        #     return max(a1, b1, b3 + a4), a2 + b2, max(a3, a2 + b3), max(b4, b2 + a4)

        if x <= mid:
            a1, a2, a3, a4 = self.query_interval(p << 1, l, mid, x, y)
        if mid < y:
            b1, b2, b3, b4 = self.query_interval(p << 1 | 1, mid + 1, r, x, y)
        return max(a1, b1, b3 + a4), a2 + b2, max(a3, a2 + b3), max(b4, b2 + a4)


class ZKW:
    # n = 1
    # size = 1
    # log = 2
    # d = [0]
    # op = None
    # e = 10 ** 15
    """自低向上非递归写法线段树,0_indexed
    tmx = ZKW(pre, max, -2 ** 61)
    """

    # __slots__ = ('n', 'op', 'e', 'log', 'size', 'd')

    def __init__(self, V):
        """
        V: 原数组
        OP: 操作:max,min,sum
        E: 每个元素默认值
        """
        self.n = len(V)
        self.log = (self.n - 1).bit_length()
        self.size = 1 << self.log
        self.ms = [0 for i in range(2 * self.size)]
        self.s = [0 for i in range(2 * self.size)]
        self.ls = [0 for i in range(2 * self.size)]
        self.rs = [0 for i in range(2 * self.size)]
        # ms = self.ms = [0] * (2 * self.size)
        # s = self.s = [0] * (2 * self.size)
        # ls = self.ls = [0] * (2 * self.size)
        # rs = self.rs = [0] * (2 * self.size)
        for i in range(self.n):
            self.ms[self.size + i] = self.s[self.size + i] = self.ls[self.size + i] = self.rs[self.size + i] = V[i]
        for i in range(self.size - 1, 0, -1):
            self.update(i)

    def set(self, p, x):
        # assert 0 <= p and p < self.n
        update = self.update
        p += self.size
        ms, s, ls, rs = self.ms, self.s, self.ls, self.rs
        ms[p] = s[p] = ls[p] = rs[p] = x
        for i in range(1, self.log + 1):
            update(p >> i)

    def get(self, p):
        # assert 0 <= p and p < self.n
        return self.s[p + self.size]

    def query(self, l, r):  # [l,r)左闭右开
        # assert 0 <= l and l <= r and r <= self.n
        ms, s, ls, rs = self.ms, self.s, self.ls, self.rs
        l += self.size
        r += self.size

        a1, a2, a3, a4 = -inf, 0, -inf, -inf
        b1, b2, b3, b4 = -inf, 0, -inf, -inf
        while l < r:
            if l & 1:
                a1, a2, a3, a4 = max(a1, ms[l], ls[l] + a4), a2 + s[l], max(a3, a2 + ls[l]), max(rs[l], s[l] + a4)
                l += 1
            if r & 1:
                b1, b2, b3, b4 = max(ms[r - 1], b1, b3 + rs[r - 1]), s[r - 1] + b2, max(ls[r - 1], s[r - 1] + b3), max(
                    b4, b2 + rs[r - 1])
                r -= 1
            l >>= 1
            r >>= 1
        return max(a1, b1, b3 + a4)

    def update(self, p):
        ms, s, ls, rs = self.ms, self.s, self.ls, self.rs
        # self.d[k] = self.op(self.d[2 * k], self.d[2 * k + 1])
        x, y = p << 1, p << 1 | 1
        ms[p] = max(ms[x], ms[y], ls[y] + rs[x])
        s[p] = s[x] + s[y]
        ls[p] = max(s[x] + ls[y], ls[x])
        rs[p] = max(s[y] + rs[x], rs[y])

    def __str__(self):
        return str([self.get(i) for i in range(self.n)])


#    必须开快写 979  ms
def solve():
    n, m = RI()
    a = []
    for _ in range(n):
        v, = RI()
        a.append(v)

    tree = ZKW(a)
    for _ in range(m):
        t, a, b = RI()
        if t == 1:
            if a > b:
                a, b = b, a
            print(tree.query(a - 1, b))
        else:
            tree.set(a - 1, b)


#     TLE+MLE  ms
def solve1():
    n, m = RI()
    a = []
    for _ in range(n):
        v, = RI()
        a.append(v)

    tree = IntervalTree(n, a)
    for _ in range(m):
        t, a, b = RI()
        if t == 1:
            if a > b:
                a, b = b, a
            print(tree.query_interval(1, 1, n, a, b)[0])
        else:
            tree.update_point(1, 1, n, a, b)


solve()

代码模版

线段树是维护区间查询的高级数据结构,可以支持单点修改,区间修改,学会线段树,也是入门树形DP的一种手段。
直接上板子。板子是求区间和。

N = 10**5+5
a = [0] * n
# 树
tree = [0] * (N*4)
# 懒修改标记
lazy = [0] * (N*4)

def pushup(p):
	tree[p] = tree[p*2] + tree[p*2]

def build(s, e, p):
	if s == e:
		tree[p] = a[s]
	m = s + (e-s>>1)
	build(s, m, p*2)
	build(m+1, e, p*2)
	pushup(p)

def update(l, r, c, s, e, p):
	if s == e:
		tree[p] = c
		return
	m = s + (e-s>>1)
	if l<=m:
		update(l, r, c, s, m+1, p)
	if r>m:
		update(l, r, c, m, e, p)
	pushup(p)

def query(l, r, s, e, p):
	if l<=s and e<=r:
		return tree[p]
	m = s + (e-s>>1)
	sum = 0
	if l<=m:
		sum += query(l, r, s, m+1, p)
	if r>m:
		sum += query(l, r,  m, e, p)
	return sum

Atcoder线段树板子 常数较大 刷题时遇到卡常 需优化。

class LazySegmentTree:
    """
    Reference
    https://github.com/atcoder/ac-library/blob/master/atcoder/lazysegtree.hpp
    https://github.com/atcoder/ac-library/blob/master/document_en/lazysegtree.md
    https://github.com/atcoder/ac-library/blob/master/document_ja/lazysegtree.md
    https://leetcode.cn/circle/discuss/4rJDBt/

    """

    def __init__(self, n, op, e, mapping, composition, id):
        self.n = n
        self.op = op
        self.e = e
        self.mapping = mapping
        self.composition = composition
        self.id = id
        self.log = (n - 1).bit_length()
        self.size = 1 << self.log
        self.d = [e] * (2 * self.size)
        self.lz = [id] * (self.size)

    def update(self, k):
        self.d[k] = self.op(self.d[2 * k], self.d[2 * k + 1])

    def all_apply(self, k, f):
        self.d[k] = self.mapping(f, self.d[k])
        if k < self.size:
            self.lz[k] = self.composition(f, self.lz[k])

    def push(self, k):
        self.all_apply(2 * k, self.lz[k])
        self.all_apply(2 * k + 1, self.lz[k])
        self.lz[k] = self.id

    def build(self, v):
        assert len(v) <= self.n
        for i in range(len(v)):
            self.d[self.size + i] = v[i]
        for i in range(self.size - 1, 0, -1):
            self.update(i)

    def set(self, p, x):
        assert 0 <= p < self.n
        p += self.size
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        self.d[p] = x
        for i in range(1, self.log + 1):
            self.update(p >> i)

    def get(self, p):
        assert 0 <= p < self.n
        p += self.size
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        return self.d[p]

    def prod(self, l, r):
        assert 0 <= l <= r <= self.n
        if l == r:
            return self.e
        l += self.size
        r += self.size
        for i in range(self.log, 0, -1):
            if (l >> i) << i != l:
                self.push(l >> i)
            if (r >> i) << i != r:
                self.push((r - 1) >> i)
        sml = smr = self.e
        while l < r:
            if l & 1:
                sml = self.op(sml, self.d[l])
                l += 1
            if r & 1:
                r -= 1
                smr = self.op(self.d[r], smr)
            l >>= 1
            r >>= 1
        return self.op(sml, smr)

    def all_prod(self):
        return self.d[1]

    def apply(self, l, r, f):
        assert 0 <= l <= r <= self.n
        if l == r:
            return
        l += self.size
        r += self.size
        for i in range(self.log, 0, -1):
            if (l >> i) << i != l:
                self.push(l >> i)
            if (r >> i) << i != r:
                self.push((r - 1) >> i)
        l2 = l
        r2 = r
        while l < r:
            if l & 1:
                self.all_apply(l, f)
                l += 1
            if r & 1:
                r -= 1
                self.all_apply(r, f)
            l >>= 1
            r >>= 1
        l = l2
        r = r2
        for i in range(1, self.log + 1):
            if (l >> i) << i != l:
                self.update(l >> i)
            if (r >> i) << i != r:
                self.update((r - 1) >> i)

    def max_right(self, l, g):
        assert 0 <= l <= self.n
        assert g(self.e)
        if l == self.n:
            return self.n
        l += self.size
        for i in range(self.log, 0, -1):
            self.push(l >> i)
        sm = self.e
        while True:
            while l % 2 == 0:
                l >>= 1
            if not g(self.op(sm, self.d[l])):
                while l < self.size:
                    self.push(l)
                    l = 2 * l
                    if g(self.op(sm, self.d[l])):
                        sm = self.op(sm, self.d[l])
                        l += 1
                return l - self.size
            sm = self.op(sm, self.d[l])
            l += 1
            if (l & -l) == l:
                return self.n

    def min_left(self, r, g):
        assert 0 <= r <= self.n
        assert g(self.e)
        if r == 0:
            return 0
        r += self.size
        for i in range(self.log, 0, -1):
            self.push((r - 1) >> i)
        sm = self.e
        while True:
            r -= 1
            while r > 1 and r % 2:
                r >>= 1
            if not g(self.op(self.d[r], sm)):
                while r < self.size:
                    self.push(r)
                    r = 2 * r + 1
                    if g(self.op(self.d[r], sm)):
                        sm = self.op(self.d[r], sm)
                        r -= 1
                return r + 1 - self.size
            sm = self.op(self.d[r], sm)
            if (r & -r) == r:
                return 0
CLICK ME
Hello world!

动态开点线段树(待补充)

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值