[python刷题模板] 线段树-树形结构(非数组实现,用于动态开点)

@[TOC]([python刷题模板] 线段树-树形结构(非数组实现,用于动态开点) )

一、 算法&数据结构

1. 描述

线段树通常用来维护区间询问,通过二分的形式对数组分块,维护一个个小块上的属性(极值、求和等),用以实现O(logn)时间的查询和更新。
  • 我之前写过一遍线段树,那里用的都是数组结构,涉及到动态开点时,就得用映射表。参见[python刷题模板] 线段树
  • 这里留存一份树形结构实现的。
  • 树形实现时,需要初始化一个root作为根节点,储存总线段信息。
  • ~~慢慢施工。~~施工完毕。
  • 目前经我测试,树形写法和字典写法执行用时不相上下。

2. 复杂度分析

  1. 查询query, O(log2n)
  2. 更新update,O(log2n)

3. 常见应用

  1. 单点更新,区间求极值(最入门)
  2. 单点更新,区间求和(稍复杂)
  3. 区间更新,单点或区间求值,如果卡常数需要用到lazytag

4. 常用优化

  1. 设置lazytag,用于区间更新,判断全包含时,不再向下递归,一般卡常数可以搞,每次update和query都需要give_lazy_to_son
  2. 离散化,因为线段树维护的是整数,如果题目给的是实数(浮点、复数、过大的数),那么可以把数据离散化,毕竟数组长度一般不会太大。

二、 模板代码

1. 区间更新,区间询问最大值(IUIQ)

例题: 699. 掉落的方块
用线段树维护x轴上每个线段(区间)的最大高度

class ITreeNode:
    __slots__ = ['val','l','r','lazy']
    def __init__(self,l=None,r=None,v=0,lazy=0):
        self.val,self.l,self.r,self.lazy = v,l,r,lazy
        
class IntervalTree:
    def __init__(self):        
        self.root = ITreeNode()        

    def give_lay_to_son(self,node):
        if not node.l:
            node.l = ITreeNode()
        if not node.r:
            node.r = ITreeNode()
        if node.lazy == 0:
            return
        node.l.val = node.l.lazy = node.lazy
        node.r.val = node.r.lazy = node.lazy
        node.lazy = 0
        
    def update_from_son(self,node):
        node.val = max(node.l.val, node.r.val)
        
    def update(self,node,l,r,x,y,val):
        """
        把[x,y]区域全变成val
        """   
        if x <= l and r<=y:
            node.val= val
            node.lazy = val
            return
        self.give_lay_to_son(node)
        mid = (l+r)//2
        if x <= mid:
            self.update(node.l,l,mid,x,y,val)
        if mid < y:
            self.update(node.r,mid+1,r,x,y,val)
        self.update_from_son(node)
    
    def query(self,node,l,r,x,y):
        """
        查找x,y区间的最大值
        """        
        if x<=l and r<=y:
            return node.val
        self.give_lay_to_son(node)
        mid = (l+r)//2
        s = 0
        if x <= mid:
            s = max(s,self.query(node.l,l,mid,x,y))
        if mid < y:
            s = max(s,self.query(node.r,mid+1,r,x,y))
        return s

class Solution:
    def fallingSquares(self, positions: List[List[int]]) -> List[int]:
        n = len(positions)
        hashes = [left for left,_ in positions] + [left+side for left,side in positions] 
        hashes = sorted(list(set(hashes)))
        # 用线段树维护x轴区间最大值,记录每个点的高度:比如[1,2]这个方块,会使线段[1,2]闭区间这个线段上的每个高度都变成2
        # 落下一个新方块时,查询它的底边所在线段[x,y]的最大高度h,这个方块会落在这个高度h,把新高度h+side插入线段树[x,y]的部分
        # 每次插入结束,树根存的高度就是当前最大高度
        # 由于数据范围大1 <= lefti <= 108,需要散列化
        # 散列化的值有left和right(线段短点)
        # print(hashes)
        tree_size = len(hashes)
        tree = IntervalTree()
        heights = []
        for left,d in positions:
            right = left + d 
            l = bisect_left(hashes,left)
            r = bisect_left(hashes,right)
            h = tree.query(tree.root,1,tree_size,l+1,r)
            tree.update(tree.root,1,tree_size,l+1,r,h+d)
            heights.append(tree.root.val)
        return heights

2. 矩形面积并

链接: 850. 矩形面积 II

线段树经典案例,涉及离散化扫描线,细节很多,非常难写

class IntervalTreeNode:
    __slots__ = 'l','r','length','cover'
    def __init__(self, l=None,r=None,length=0,cover=0):
        self.length = length
        self.cover = cover
        self.l = l
        self.r = r
class IntervalTree:
    def __init__(self,ys):
        self.root = IntervalTreeNode()
        self.ys = ys
    def update_from_son(self,node,l,r):
        if node.cover > 0:
            node.length = self.ys[r-1]-self.ys[l-1]
        else:
            if l+1 ==r:
                node.length = 0
            else:
                # if not node.l:
                #     node.l = IntervalTreeNode()
                # if not node.r:
                #     node.r = IntervalTreeNode()
                node.length = (node.l.length if node.l else 0 )+ (node.r.length if node.r else 0 )

    def insert(self,node,l,r,x,y,cover):        
        if y < l or r < x:
            return    
        if x<=l and r<=y:
            node.cover += cover
            self.update_from_son(node,l,r)
            return
        mid = (l+r)//2
        if x < mid:
            if not node.l:
                node.l = IntervalTreeNode()
            self.insert(node.l,l,mid,x,y,cover)
        if y > mid:
            if not node.r:
                node.r = IntervalTreeNode()
            self.insert(node.r,mid,r,x,y,cover)
        self.update_from_son(node,l,r)
            
class LineY:
    __slots__ = 'x','y1','y2','cover'
    def __init__(self,x,y1,y2,cover):
        self.x = x
        self.y1 = y1
        self.y2 = y2
        self.cover = cover
        
class Solution:
    def rectangleArea(self, rectangles: List[List[int]]) -> int:
        lines = []  # 所有竖线线段
        ys = set()  # 离散化
        for x1,y1,x2,y2 in rectangles:
            lines.append(LineY(x1,y1,y2,1))
            lines.append(LineY(x2,y1,y2,-1))
            ys.add(y1)
            ys.add(y2)
        lines.sort(key=lambda x:x.x)
        line_count = len(lines)
        ys = list(ys)
        ys.sort()
        size = len(ys)
        interval_tree = IntervalTree(ys=ys)

        ans = 0
        mod = int(1e9+7)
        for i in range(0,line_count):
            line = lines[i]
            # print(line.x,line.y1,line.y2,line.cover)
            y1 = bisect_left(ys,line.y1) 
            y2 = bisect_left(ys,line.y2)
            # if y1==y2:
            #     continue
            if i >0:
                ans += (line.x-lines[i-1].x) * interval_tree.root.length
                ans %= mod
            interval_tree.insert(interval_tree.root,1,size,y1+1,y2+1,line.cover)
        
        return ans

3.单点更新,区间求和

链接: 307. 区域和检索 - 数组可修改

线段树经典案例,比区间求极值麻烦一点点

class ITreeNode:
    __slots__ = ['val','l','r']
    def __init__(self,l=None,r=None,v=0):
        self.val,self.l,self.r = v,l,r
class IntervalTree:
    def __init__(self, nums):
        self.root = ITreeNode() 
        self.nums = nums 

    def build_tree(self,node,l,r):
        nums = self.nums
        if not node:
            node = ITreeNode()
        if l == r:            
            node.val = nums[l-1]
            return node
        mid = (l+r)//2
        node.l = self.build_tree(node.l,l,mid)        
        node.r = self.build_tree(node.r,mid+1,r)
        node.val = node.l.val + node.r.val
        return node
    
    def add_point(self,node,l,r,index,add):        
        if index < l or r < index:
            return 
        node.val += add
        if l == r:
            return
        mid = (l+r)//2
        if index <= mid:
            self.add_point(node.l,l,mid,index,add)
        else:
            self.add_point(node.r,mid+1,r,index,add)
    
    def sum_interval(self,node,l,r,x,y):        
        if y < l or r < x:
            return 0
        if x<=l and r<=y:
            return node.val
        mid = (l+r)//2
        s = 0
        if x <= mid:
            s += self.sum_interval(node.l,l,mid,x,y)
        if mid < y:
            s += self.sum_interval(node.r,mid+1,r,x,y)
        return s

class NumArray:

    def __init__(self, nums: List[int]):
        self.size = len(nums)
        self.nums = nums
        self.interval_tree = IntervalTree( nums)
        self.interval_tree.build_tree(self.interval_tree.root,1,self.size)
        
    def update(self, index: int, val: int) -> None:         
        add = val - self.nums[index]
        self.nums[index] = val
        self.interval_tree.add_point(self.interval_tree.root,1,self.size,index+1,add)
        
    def sumRange(self, left: int, right: int) -> int:        
        return self.interval_tree.sum_interval(self.interval_tree.root,1,self.size,left+1,right+1)        

4.单点更新,区间求和

链接: 327. 区间和的个数

线段树这题麻烦一点,求区间内数字数量,每个数初始化为0,插入时候+1,计数转化为求和
参考链接: [LeetCode解题报告]327. 区间和的个数

class ITreeNode:
    __slots__ = ['val','l','r']
    def __init__(self,l=None,r=None,v=0):
        self.val,self.l,self.r = v,l,r
class IntervalTree:
    def __init__(self):
        self.root = ITreeNode()
    def update_from_son(self,node):
        node.val = (node.l.val if node.l else 0 )  + (node.r.val if node.r else 0)
        return node
    def insert(self,node,l,r,index):
        if index < l or r < index:
            return None        
        if not node:
            node = ITreeNode()
        if l == r:
            node.val += 1
            return node
        mid = (l+r)//2
        if index <= mid:
            node.l = self.insert(node.l,l,mid,index)
        else:
            node.r = self.insert(node.r,mid+1,r,index)        
        
        return self.update_from_son(node) 
    
    def query(self,node,l,r,x,y):
        if not node or y < l or r < x :
            return 0
        if x<=l and r<=y:
            return node.val
        mid = (l+r)//2
        s = 0
        if x <= mid:
            s += self.query(node.l,l,mid,x,y) 
        if mid < y:
            s += self.query(node.r,mid+1,r,x,y) 
        return s

class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:         
        s = list(accumulate(nums,initial=0))
        hashes = s + [ x-lower for x in s] + [ x-upper for x in s]
        hashes = sorted(list(set(hashes)))
        # 生成前缀和,问题转化为,对于每个j,找左边的i,判断 s[j]-upper<=s[i]<=s[j]-lower,统计这些i的数量
        # 把所有前缀和数组中的数字插入线段树,并对这些数字划分区间,线段树维护当前区间数字数量,
        # 所以需要对这些数字都散列化
        tree_size = len(hashes)
        tree = IntervalTree()
        cnt = 0
        for i in s:
            x = bisect_left(hashes,i-upper)
            y = bisect_left(hashes,i-lower)
            j = bisect_left(hashes,i)
            c = tree.query(tree.root,1,tree_size, x+1,y+1)
            # print(x,y,j,c)
            cnt += c
            tree.insert(tree.root,1,tree_size,j+1)

        return cnt

5.区间更新,区间查询,无法离散化,动态开点。

链接: 732. 我的日程安排表 III

线段树这题由于只能在线做,不能做离散化,因此需要用字典维护线段端点,实现动态开点。
参考链接: [LeetCode解题报告] 732. 我的日程安排表 III


class ITreeNode:
    __slots__ = ['val','l','r','lazy']
    def __init__(self,l=None,r=None,v=0,lazy=0):
        self.val,self.l,self.r,self.lazy = v,l,r,lazy
class IntervalTree:
    def __init__(self):
        self.root = ITreeNode()    

    def give_lay_to_son(self,node):        
        if node.lazy == 0:
            return
        if not node.l:
            node.l = ITreeNode()
        node.l.val += node.lazy
        node.l.lazy += node.lazy
        if not node.r:
            node.r = ITreeNode()
        node.r.val += node.lazy
        node.r.lazy += node.lazy
        node.lazy = 0

    def update_from_son(self,node):
        node.val = max(node.l.val if node.l else 0, node.r.val if node.r else 0)
        return node

    def add(self,node,l,r,x,y,val):
        """
        把[x,y]区域全+val
        """
        if not node:
            node = ITreeNode()
        if r < x or y < l:  # 这里不加就会TLE
            return  node
        
        if x <= l and r<=y:
            node.val += val
            node.lazy += val
            return node
        self.give_lay_to_son(node)  #这题由于永远不会询问子区间,所以其实可以不向下give,直接在return的时候+lazy,会快一点。
        mid = (l+r)//2
        if x <= mid:
            node.l = self.add(node.l,l,mid,x,y,val)
        if mid < y:
            node.r = self.add(node.r,mid+1,r,x,y,val)
        return self.update_from_son(node) 
    
    def query(self,node,l,r,x,y):
        """
        查找x,y区间的最大值
        """        
        if not node:
            return 0
        if x<=l and r<=y:
            return node.val
        self.give_lay_to_son(node)
        mid = (l+r)//2
        s = 0
        if x <= mid:
            s = max(s,self.query(node.l,l,mid,x,y))
        if mid < y:
            s = max(s,self.query(node.r,mid+1,r,x,y))
        return s

class MyCalendarThree:

    def __init__(self):
        self.tree = IntervalTree()
   
    def book(self, start: int, end: int) -> int:
        self.tree.add(self.tree.root,1,10**9+1,start,end-1,1)
        return self.tree.query(self.tree.root,1,10**9+1,1,10**9+1)

6.单点更新,区间查询最小


class ITreeNode:
    __slots__ = ['val', 'l', 'r']

    def __init__(self, l=None, r=None, v=0):
        self.val, self.l, self.r = v, l, r


class IntervalTree:
    def __init__(self):
        self.root = ITreeNode()

    def update_from_son(self, node):
        node.val = min(node.l.val if node.l else inf, node.r.val if node.r else inf)
        return node

    def update_point(self, node, l, r, index, val):
        if index < l or r < index:
            return None
        if not node:
            node = ITreeNode(v=inf)
        if l == r:
            node.val = val
            return node
        mid = (l + r) // 2
        if index <= mid:
            node.l = self.update_point(node.l, l, mid, index, val)
        else:
            node.r = self.update_point(node.r, mid + 1, r, index, val)

        return self.update_from_son(node)

    def query(self, node, l, r, x, y):
        if not node or y < l or r < x:
            return inf
        if x <= l and r <= y:
            return node.val
        mid = (l + r) // 2
        s = inf
        if x <= mid:
            s = min(s, self.query(node.l, l, mid, x, y))
        if mid < y:
            s = min(s, self.query(node.r, mid + 1, r, x, y))
        return s

三、其他

  1. 如果还是卡常数,有些区间问题可以转化为树状数组,常数小,代码短,不过真的很难理解,还是线段树好写。遇到就套板吧。
  2. 有的区间更新题,珂朵莉也是可以过得。
  3. 我的测试来看,树形写法和map写法执行用时差不多,可以说是不相上下。

四、更多例题

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值