区间问题算法整理

贪心 + 排序

1、合并区间

class Solution:
    def merge(self, intervals: List[List[int]]) -> List[List[int]]:
        """Merge all overlapping closed intervals.

        Intervals are sorted (in place) by start; each one is folded into
        the current run while it overlaps it (touching endpoints count as
        overlapping).  Empty input returns an empty list instead of raising.
        """
        if not intervals:
            return []
        intervals.sort(key=lambda x: x[0])
        res = []
        cur = [intervals[0][0], intervals[0][1]]
        for a in intervals:
            start, end = a[0], a[1]
            if cur[1] < start:
                # Gap before this interval: close the current run.
                res.append(cur)
                cur = [start, end]
            elif cur[1] < end:
                # Overlaps and reaches further right: extend the run.
                cur[1] = end
        res.append(cur)
        return res

2、插入区间

class Solution:
    def insert(self, intervals: List[List[int]], newInterval: List[int]) -> List[List[int]]:
        """Insert ``newInterval`` into sorted, non-overlapping ``intervals``.

        Single pass: intervals entirely left of the new one are copied,
        overlapping ones are absorbed into it, and the (possibly widened)
        new interval is emitted just before the first interval entirely to
        its right, or at the end when there is none.
        """
        lo, hi = newInterval
        merged = []
        inserted = False
        for a, b in intervals:
            if a > hi:
                # strictly right of the new interval, no overlap
                if not inserted:
                    merged.append([lo, hi])
                    inserted = True
                merged.append([a, b])
                continue
            if b < lo:
                # strictly left of the new interval, no overlap
                merged.append([a, b])
                continue
            # overlapping: widen the new interval to the union
            lo, hi = min(lo, a), max(hi, b)
        if not inserted:
            merged.append([lo, hi])
        return merged

3、无重叠区间

class Solution:
    def eraseOverlapIntervals(self, intervals: List[List[int]]) -> int:
        """Return the minimum number of intervals to remove so the rest don't overlap.

        Greedy by earliest end (sorted in place): keep every interval that
        starts at or after the end of the last kept one (touching endpoints
        are allowed) and report how many were not kept.
        """
        intervals.sort(key=lambda iv: iv[1])
        kept = 0
        last_end = float('-inf')
        for start, end in intervals:
            if start < last_end:
                continue
            last_end = end
            kept += 1
        return len(intervals) - kept

4、寻找右区间

class Solution:
    def findRightInterval(self, intervals: List[List[int]]) -> List[int]:
        """For each interval, return the index of the interval with the smallest
        start >= its end, or -1 if there is none.

        Starts and ends are each sorted together with their original index
        and swept with two pointers.  Empty input yields an empty list
        (previously ``zip(*[])`` made the unpacking raise ValueError).
        """
        n = len(intervals)
        if n == 0:
            return []
        starts = sorted((iv[0], i) for i, iv in enumerate(intervals))
        ends = sorted((iv[1], i) for i, iv in enumerate(intervals))

        ans, j = [-1] * n, 0
        for end, idx in ends:
            # ends ascend, so the first start >= end only moves forward
            while j < n and starts[j][0] < end:
                j += 1
            if j < n:
                ans[idx] = starts[j][1]
        return ans

5、删除被覆盖区间

class Solution:
    def removeCoveredIntervals(self, intervals: List[List[int]]) -> int:
        """Return how many intervals remain after removing those covered by another.

        Sorting (in place) by (start asc, end desc) means any interval that
        could cover the current one has already been seen, so the current
        one is covered exactly when its end does not exceed the largest end
        so far.  The running maximum starts at -inf, not 0, so intervals with
        negative ends are not misreported as covered.
        """
        intervals.sort(key=lambda x: (x[0], -x[1]))
        covered = 0
        rmax = float('-inf')
        for _, b in intervals:
            if b <= rmax:
                covered += 1
            rmax = max(rmax, b)
        return len(intervals) - covered

6、删除区间

class Solution:
    def removeInterval(self, intervals: List[List[int]], rm: List[int]) -> List[List[int]]:
        """Subtract the range ``rm`` from each interval in ``intervals``.

        An interval that does not overlap ``rm`` is kept as is; otherwise it
        contributes its remainder left of ``rm`` and/or right of ``rm``.
        """
        cut_lo, cut_hi = rm[0], rm[1]
        out = []
        for lo, hi in intervals:
            disjoint = hi <= cut_lo or lo >= cut_hi
            if disjoint:
                out.append([lo, hi])
            else:
                if lo < cut_lo:
                    out.append([lo, cut_lo])
                if hi > cut_hi:
                    out.append([cut_hi, hi])
        return out

7、Finding the Number of Visible Mountains-- 去除完全被覆盖的区间

class Solution:
    def visibleMountains(self, peaks: List[List[int]]) -> int:
        """Count mountains whose peak is not covered by another mountain.

        Peak (x, y) is turned into the base interval [x - y, x + y]; a peak
        is hidden exactly when its interval lies inside another one.  After
        sorting by (left asc, right desc), an interval is uncovered iff its
        right end exceeds every right end seen before it.  Identical
        mountains hide each other, and duplicates are adjacent after
        sorting, so an uncovered interval only counts if the next interval
        is not identical to it.  Empty input returns 0 instead of raising.
        """
        if not peaks:
            return 0
        intervals = sorted(((x - y, x + y) for x, y in peaks), key=lambda iv: (iv[0], -iv[1]))
        visible = 0
        max_right = float('-inf')
        for i, (lo, hi) in enumerate(intervals):
            if hi <= max_right:
                continue  # covered by an interval seen earlier
            max_right = hi
            if i + 1 < len(intervals) and intervals[i + 1] == (lo, hi):
                continue  # duplicated mountain: every copy is hidden
            visible += 1
        return visible

双指针

1、安排会议日程

class Solution:
    def minAvailableDuration(self, slots1: List[List[int]], slots2: List[List[int]], duration: int) -> List[int]:
        """Return the earliest [start, start + duration] free in both calendars, or [].

        Both slot lists are sorted in place and walked with two pointers.
        Disjoint slots advance whichever one lies entirely earlier; for
        overlapping slots, the common part is tested and then the slot that
        finishes first is dropped.
        """
        slots1.sort()
        slots2.sort()
        i = j = 0
        while i < len(slots1) and j < len(slots2):
            a, b = slots1[i], slots2[j]
            if a[1] <= b[0]:
                i += 1
                continue
            if b[1] <= a[0]:
                j += 1
                continue
            lo = max(a[0], b[0])
            hi = min(a[1], b[1])
            if hi - lo >= duration:
                return [lo, lo + duration]
            if a[1] < b[1]:
                i += 1
            else:
                j += 1
        return []

差分数组

1、区间加法

class Solution:
    def getModifiedArray(self, length: int, updates: List[List[int]]) -> List[int]:
        """Apply inclusive range increments [start, end, inc] to a zero array.

        Uses a difference array with one spare slot, so ``end + 1`` is
        always a valid index, and returns its running prefix sums.
        """
        delta = [0] * (length + 1)
        for u in updates:
            lo, hi, inc = u[0], u[1], u[2]
            delta[lo] += inc
            delta[hi + 1] -= inc
        res = []
        running = 0
        for k in range(length):
            running += delta[k]
            res.append(running)
        return res

2、字母移位 II -- 向前、向后移位分别用差分数组统计次数(也可合并为一个差分数组记录净移位量),最后按净移位量对 26 取模

class Solution:
    def shiftingLetters(self, s: str, shifts: List[List[int]]) -> str:
        """Apply range shifts [start, end, direction] to lowercase string ``s``.

        direction == 0 shifts each letter in [start, end] back by one, any
        other value shifts it forward by one.  Backward and forward counts
        are kept in two difference arrays and combined per position.  The
        result is built with ``str.join`` rather than repeated ``+=``, which
        is quadratic unless CPython's in-place optimisation applies.
        """
        back = [0] * (len(s) + 1)
        fwd = [0] * (len(s) + 1)
        for f, t, d in shifts:
            diff = back if d == 0 else fwd
            diff[f] += 1
            diff[t + 1] -= 1
        out = []
        cur_back = cur_fwd = 0
        base = ord('a')
        for i, c in enumerate(s):
            cur_back += back[i]
            cur_fwd += fwd[i]
            # Python's % is non-negative for a positive modulus.
            out.append(chr(base + (ord(c) - base - cur_back + cur_fwd) % 26))
        return ''.join(out)

3、Count Positions on Street With Required Brightness 

class Solution:
    def meetRequirement(self, n: int, lights: List[List[int]], requirement: List[int]) -> int:
        """Count positions i lit by at least requirement[i] lamps.

        Lamp (c, r) lights the closed range [c - r, c + r], clipped to
        [0, n - 1].  The length-n difference array cannot record "end + 1",
        so it subtracts at the end itself.  A second array counts lamps
        ending at each position and adds them back there.
        """
        starts_minus_ends = [0] * n
        ending_here = [0] * n
        for pos, rng in lights:
            lo = max(0, pos - rng)
            hi = min(n - 1, pos + rng)
            starts_minus_ends[lo] += 1
            starts_minus_ends[hi] -= 1
            ending_here[hi] += 1
        running = 0
        count = 0
        for i, need in enumerate(requirement):
            running += starts_minus_ends[i]
            # lamps ending exactly at i were already subtracted from `running`
            count += (running + ending_here[i] >= need)
        return count

注:对于左闭右闭区间 [l, r],差分时需要在 r + 1 处做减法;如果差分数组长度只有 n(r + 1 可能越界),就需要像本题这样额外用一个数组单独记录在右端点处结束的区间数;更简单的做法是把差分数组长度开到 n + 1。

有序集合

1、我的日程安排表 I

from sortedcontainers import SortedDict as SD

class MyCalendar:
    """Calendar that rejects any booking overlapping an existing one.

    ``end_start`` maps each booked end time to its start time and is kept
    sorted by end (a SortedDict).  Accepted bookings never overlap, so only
    the first booking ending after the new start can conflict.
    """

    def __init__(self):
        self.end_start = SD()

    def book(self, start: int, end: int) -> bool:
        # index of the first booking whose end is strictly greater than `start`
        pos = self.end_start.bisect_right(start)
        if pos < len(self.end_start) and self.end_start.values()[pos] < end:
            # that booking begins before the new one ends -> overlap
            return False
        self.end_start[end] = start
        return True

2、我的日程安排表 II

from sortedcontainers import SortedDict
class MyCalendarTwo:
    """Calendar that rejects any booking that would create a triple overlap.

    ``counter`` is a sorted difference map: +1 at each booking's start and
    -1 at its end.  Sweeping the keys in order gives the number of
    simultaneous bookings.  Reaching 3 means the tentative booking is undone.
    """

    def __init__(self):
        self.counter = SortedDict()

    def book(self, start: int, end: int) -> bool:
        counter = self.counter
        # tentatively add the booking
        counter[start] = counter.get(start, 0) + 1
        counter[end] = counter.get(end, 0) - 1

        running = 0
        overflow = False
        for delta in counter.values():
            running += delta
            if running >= 3:
                overflow = True
                break

        if overflow:
            # roll back; the keys themselves stay (with their prior counts)
            counter[start] -= 1
            counter[end] += 1
            return False
        return True

优先队列(堆)

1、两个最好的不重叠活动

class Solution:
    def maxTwoEvents(self, events: List[List[int]]) -> int:
        """Best total value of at most two non-overlapping events [start, end, value].

        Events are visited by start time (sorted in place).  A min-heap keyed
        by end time releases every earlier event that finished strictly
        before the current start.  The best released value is paired with
        the current event.
        """
        events.sort()
        ended = []          # heap of (end, value) for events already visited
        best_finished = 0   # max value among events that ended before `start`
        answer = 0
        for start, end, value in events:
            while ended and ended[0][0] < start:
                best_finished = max(best_finished, heapq.heappop(ended)[1])
            answer = max(answer, best_finished + value)
            heapq.heappush(ended, (end, value))
        return answer

2、将区间分为最少组数

class Solution:
    def minGroups(self, intervals: List[List[int]]) -> int:
        """Minimum number of groups so that no two intervals in a group intersect.

        Intervals are inclusive and sorted in place.  A min-heap holds the
        right end of the last interval in each group.  A new interval reuses
        the earliest-freed group when that group ended strictly before the
        interval starts.
        """
        intervals.sort()
        group_ends = []
        for lo, hi in intervals:
            if group_ends and group_ends[0] < lo:
                heappop(group_ends)
            heappush(group_ends, hi)
        return len(group_ends)

3、员工空闲时间

class Solution(object):
    def employeeFreeTime(self, avails):
        """Return the finite free intervals common to all employees.

        ``avails[e]`` is employee e's list of working intervals; each item
        exposes ``.start``/``.end`` and is assumed sorted by start.  A
        min-heap merges all lists by start time.  Whenever the next start
        lies beyond the furthest end seen so far (``anchor``), the gap is a
        common free interval.  Employees with an empty schedule are skipped
        instead of raising IndexError on ``emp[0]``.
        """
        pq = [(emp[0].start, ei, 0) for ei, emp in enumerate(avails) if emp]
        if not pq:
            return []
        heapq.heapify(pq)
        anchor = min(iv.start for emp in avails for iv in emp)
        ans = []
        while pq:
            t, e_id, e_jx = heapq.heappop(pq)
            if anchor < t:
                ans.append(Interval(anchor, t))
            anchor = max(anchor, avails[e_id][e_jx].end)
            if e_jx + 1 < len(avails[e_id]):
                heapq.heappush(pq, (avails[e_id][e_jx + 1].start, e_id, e_jx + 1))
        return ans

区间并查集

1、每天绘制新区域的数量

class Solution:
    def amountPainted(self, paint: List[List[int]]) -> List[int]:
        """For each day's half-open range [start, end), count newly painted units.

        ``line[p]`` is 0 while unit p is unpainted.  Once painted, it stores
        the furthest end known to be painted from p, so later sweeps jump
        over painted stretches and shorten those jumps as they go.  The array
        is sized from the input's largest end instead of a hard-coded 50001,
        so larger coordinates no longer raise IndexError.
        """
        size = max((end for _, end in paint), default=0) + 1
        line, res = [0] * size, [0] * len(paint)
        for i, (start, end) in enumerate(paint):
            while start < end:
                jump = max(start + 1, line[start])
                if line[start] == 0:
                    res[i] += 1
                line[start] = max(line[start], end)  # compression
                start = jump
        return res

动态规划

1、销售利润最大化

class Solution:
    def maximizeTheProfit(self, n: int, events: List[List[int]]) -> int:
        """Max gold from non-overlapping offers [start, end, gold] on houses 0..n-1.

        Offers are processed by end (sorted in place).  ``dp[h]`` collects
        the best total of offers ending at house h.  Before an offer ending
        at ``e`` is applied, all entries below ``e`` are converted to prefix
        maxima, so ``dp[s - 1]`` means "best using only houses < s".
        """
        events.sort(key=lambda x: x[1])
        dp = [0] * n
        best_so_far = 0
        filled = 0  # dp[:filled] already holds prefix maxima
        for s, e, v in events:
            while filled < e:
                best_so_far = max(dp[filled], best_so_far)
                dp[filled] = best_so_far
                filled += 1
            before = dp[s - 1] if s > 0 else 0
            dp[e] = max(dp[e], before + v)
        return max(dp)

2、 规划兼职工作

由于数据量比较大,如果用上面的方法会导致超时,所以再利用二分查找

class Solution:
    def jobScheduling(self, startTime: List[int], endTime: List[int], profit: List[int]) -> int:
        jobs = sorted(zip(startTime, endTime, profit), key=lambda x: x[1])
        dp = [[0, 0]]
        for s, e, p in jobs:
            i = bisect.bisect(dp, [s + 1]) - 1
            if dp[i][1] + p > dp[-1][1]:
                dp.append([e, dp[i][1] + p])
        return dp[-1][1]

如果有区间数量的限制,则需要加入一维状态空间用来表示区间数目,例如 -- 

最多可以参加的会议数目 II

树状数组

1、区域和检索 - 数组可修改 

class BitTree:
    """Fenwick (binary indexed) tree over 0-based positions 0..n-1 with ``+``.

    ``add(i, val)`` adds ``val`` at position i.  ``query(i)`` returns the
    sum of positions 0..i, or 0 when i < 0.  The internal array is 1-based,
    hence the ``+ 1`` shifts.
    """

    def __init__(self, n_: int):
        self.n = n_
        self.tree = [0] * (n_ + 1)
        self.merge = lambda a, b: a + b

    def lowbit(self, x: int) -> int:
        # value of the lowest set bit of x
        return x & -x

    def add(self, i: int, val: int) -> None:
        pos = i + 1
        while pos <= self.n:
            self.tree[pos] = self.merge(self.tree[pos], val)
            pos += self.lowbit(pos)

    def query(self, i: int) -> int:
        total = 0
        pos = i + 1
        while pos > 0:
            total = self.merge(total, self.tree[pos])
            pos -= self.lowbit(pos)
        return total


class NumArray:
    """Mutable array with range-sum queries, backed by a ``BitTree`` Fenwick tree."""

    def __init__(self, nums: List[int]):
        self.nums = nums
        self.BT = BitTree(len(nums))
        for idx, value in enumerate(nums):
            self.BT.add(idx, value)

    def update(self, index: int, val: int) -> None:
        # point assignment: remove the old value, then add the new one
        self.BT.add(index, -self.nums[index])
        self.BT.add(index, val)
        self.nums[index] = val

    def sumRange(self, left: int, right: int) -> int:
        # inclusive range sum as a difference of prefix sums
        # (relies on BitTree.query(-1) returning 0 when left == 0)
        upto_right = self.BT.query(right)
        before_left = self.BT.query(left - 1)
        return upto_right - before_left

2、极值树状数组 -- 价格递增的最大利润三元组 II

class BitTree:
    """Fenwick tree whose combine step is ``merge`` (default ``max``).

    Positions are 0-based externally and 1-based internally.  ``a`` keeps
    the merged value stored at each single position, so ``query(x, y)`` can
    answer arbitrary ranges.  ``query1(x)`` answers prefixes 0..x.  The
    identity is -inf for ``max``, +inf for ``min`` and 0 otherwise.
    ``merge`` used to be hard-coded to ``max`` even though the identity
    logic already handled other merges; it is now a backward-compatible
    parameter.  ``float('inf')`` replaces the previously undefined bare
    ``inf``.
    """

    def __init__(self, n_: int, merge=max):
        self.n = n_
        self.merge = merge
        if merge is max:
            self.init_val = -float('inf')
        elif merge is min:
            self.init_val = float('inf')
        else:
            self.init_val = 0
        self.a = [self.init_val] * (n_ + 1)
        self.tree = [self.init_val] * (n_ + 1)

    def lowbit(self, x: int) -> int:
        # value of the lowest set bit of x
        return x & (-x)

    def update(self, x: int, val: int) -> None:
        """Merge ``val`` into position ``x`` (0-based)."""
        x += 1
        self.a[x] = self.merge(val, self.a[x])
        i = x
        while i <= self.n:
            self.tree[i] = self.merge(self.tree[i], val)
            i += self.lowbit(i)

    def query(self, x, y):
        """Merged value over positions x..y (0-based, inclusive); identity if x > y."""
        x += 1
        y += 1
        if x > y:
            return self.init_val
        res = self.init_val
        while True:
            res = self.merge(res, self.a[y])
            if x == y:
                break
            y -= 1
            # take whole tree blocks (y - lowbit(y), y] that stay inside [x, y]
            while y - x >= self.lowbit(y):
                res = self.merge(res, self.tree[y])
                y -= self.lowbit(y)
        return res

    def query1(self, x):
        """Merged value over positions 0..x (0-based); identity if x < 0."""
        x += 1
        res = self.init_val
        while x > 0:
            res = self.merge(self.tree[x], res)
            x -= self.lowbit(x)
        return res
  

class Solution:
    def maxProfit(self, prices: List[int], profits: List[int]) -> int:
        """Max profits[i] + profits[j] + profits[k] over i < j < k with
        prices[i] < prices[j] < prices[k], or -1 if no such triple exists.

        The left pass uses a prefix-max Fenwick tree keyed by price to find
        the best profit to the left with a strictly smaller price.  The
        right pass uses the same kind of tree over mirrored prices
        (max_p - p) to find the best profit to the right with a strictly
        larger price.
        """
        n = len(prices)
        max_p = max(prices)
        best_left = [-inf] * n
        tree = BitTree(max_p + 1)
        for i in range(n):
            best_left[i] = tree.query1(prices[i] - 1)
            tree.update(prices[i], profits[i])

        mirrored = BitTree(max_p + 1)
        answer = -1
        for i in range(n - 1, -1, -1):
            flipped = max_p - prices[i]
            best_right = mirrored.query1(flipped - 1)
            answer = max(answer, best_left[i] + profits[i] + best_right)
            mirrored.update(flipped, profits[i])
        return answer

线段树

1、区间极值线段树  --  天际线问题

class Solution:
    def getSkyline(self, buildings: List[List[int]]) -> List[List[int]]:
        """Compute skyline key points with a coordinate-compressed max segment tree.

        Every building edge is a candidate x.  Each building max-merges its
        height into the compressed range [left, right).  Scanning the sorted
        edges, a key point is emitted whenever the height changes.  Uncovered
        positions read as -inf from the tree and are emitted as height 0.
        """
        xs = sorted({x for left, right, _ in buildings for x in (left, right)})
        rank = {x: i for i, x in enumerate(xs)}
        # compressed-coordinate range update
        segment = PeakIntervalSegmentTree(len(xs), update_type='other')
        for left, right, height in buildings:
            # the right edge itself is not covered: [left, right) -> [rank_l, rank_r - 1]
            segment.update(rank[left], rank[right] - 1, height)
        # query the height at every edge and keep the changes
        skyline = []
        prev = -1
        for x in xs:
            h = segment.query(rank[x], rank[x])
            if h != prev:
                skyline.append([x, max(h, 0)])
                prev = h
        return skyline


class PeakIntervalSegmentTree:
    """Dict-backed segment tree with range update and range max/min query.

    ``update_type='reset'`` assigns ``val`` to every position in the range.
    Any other value merges ``val`` into the range (max/min of old and new).
    Nodes are created lazily through ``defaultdict``, so ``n`` may be large.
    Note: the identity value doubles as the "no pending tag" marker, so
    resetting a range to that identity (e.g. -inf for max) is not pushed down.
    """

    def __init__(self, n, merge=max, update_type='reset'):
        self.init_val = {max: -float('inf'), min: float('inf')}
        self.tree = defaultdict(lambda: self.init_val[merge])
        self.lazy = defaultdict(lambda: self.init_val[merge])
        self.n = n
        self._merge = merge
        self.update_type = update_type

    def query(self, l, r):
        return self._query(l, r, 0, self.n - 1, 1)

    def update(self, l, r, val):
        return self._update(l, r, 0, self.n - 1, val, 1)

    def _apply(self, i, val):
        """Apply tag ``val`` to node ``i`` according to ``update_type``."""
        if self.update_type == 'reset':
            self.tree[i] = self.lazy[i] = val
        else:
            self.tree[i] = self._merge(self.tree[i], val)
            self.lazy[i] = self._merge(self.lazy[i], val)

    def push_down(self, i):
        # Hand node i's pending tag to both children.  In 'reset' mode the
        # tag must overwrite them.  Merging (as before) kept stale values
        # from earlier resets, e.g. reset to 5 then reset to 2 read back 5.
        tag = self.lazy[i]
        if tag != self.init_val[self._merge]:
            self._apply(2 * i, tag)
            self._apply(2 * i + 1, tag)
            self.lazy[i] = self.init_val[self._merge]

    def _update(self, l, r, s, t, val, i):
        if l <= s and t <= r:
            self._apply(i, val)
            return
        self.push_down(i)
        m = s + (t - s) // 2
        if l <= m:
            self._update(l, r, s, m, val, 2 * i)
        if r > m:
            self._update(l, r, m + 1, t, val, 2 * i + 1)
        self.tree[i] = self._merge(self.tree[2 * i], self.tree[2 * i + 1])

    def _query(self, l, r, s, t, i):
        if l <= s and t <= r:
            return self.tree[i]
        self.push_down(i)
        m = s + (t - s) // 2
        res = self.init_val[self._merge]
        if l <= m:
            res = self._merge(res, self._query(l, r, s, m, 2 * i))
        if r > m:
            res = self._merge(res, self._query(l, r, m + 1, t, 2 * i + 1))
        return res

2、区间求和线段树 -- 统计区间中的整数数目

class CountIntervals:
    """Union of added integer ranges within [0, 10**9], with a covered-integer count.

    Adding [left, right] assigns 1 to every point in it on a dict-backed
    sum segment tree.  The count is the tree's sum over the whole domain.
    """

    def __init__(self):
        size = 10 ** 9 + 1  # positions 0 .. 10**9 inclusive
        self.s = SumIntervalSegmentTree(size, update_type='reset')

    def add(self, left: int, right: int) -> None:
        # mark every integer in [left, right] as covered
        self.s.update(left, right, 1)

    def count(self) -> int:
        last = 10 ** 9
        return self.s.query(0, last)
 
class SumIntervalSegmentTree:
    """Dict-backed segment tree with range update and range-sum query.

    ``update_type='add'`` adds ``val`` to every position in the range.  Any
    other value (e.g. ``'reset'``) assigns ``val`` to every position, and
    ``_update`` and ``push_down`` now treat such values consistently.
    """

    def __init__(self, n, update_type='add'):
        self.tree = defaultdict(int)
        self.init_val = 0
        self.n = n
        self.update_type = update_type
        # Marker for "no pending tag".  A zero delta is a no-op for 'add'.
        # For assignment, 0 is a real value (reset to zero), so a distinct
        # sentinel is needed; otherwise such resets would never reach the
        # children.
        self._no_tag = self.init_val if update_type == 'add' else None
        self.lazy = defaultdict(lambda: self._no_tag)

    def query(self, l, r):
        return self._query(l, r, 0, self.n - 1, 1)

    def update(self, l, r, val):
        return self._update(l, r, 0, self.n - 1, val, 1)

    def _apply(self, i, val, length):
        """Apply tag ``val`` to node ``i``, which covers ``length`` positions."""
        if self.update_type == 'add':
            self.tree[i] += val * length
            self.lazy[i] += val
        else:
            self.tree[i] = val * length
            self.lazy[i] = val

    def push_down(self, i, l, r):
        """Hand node i's pending tag (covering [l, r]) to its two children."""
        tag = self.lazy[i]
        if tag == self._no_tag:
            return
        m = l + (r - l) // 2
        self._apply(2 * i, tag, m - l + 1)
        self._apply(2 * i + 1, tag, r - m)
        self.lazy[i] = self._no_tag

    def _update(self, l, r, s, t, val, i):
        if l <= s and t <= r:
            self._apply(i, val, t - s + 1)
            return
        self.push_down(i, s, t)
        m = s + (t - s) // 2
        if l <= m:
            self._update(l, r, s, m, val, 2 * i)
        if r > m:
            self._update(l, r, m + 1, t, val, 2 * i + 1)
        self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def _query(self, l, r, s, t, i):
        if l <= s and t <= r:
            return self.tree[i]
        self.push_down(i, s, t)
        m = s + (t - s) // 2
        res = 0
        if l <= m:
            res += self._query(l, r, s, m, 2 * i)
        if r > m:
            res += self._query(l, r, m + 1, t, 2 * i + 1)
        return res

3、二维求和线段树 -- 二维区域和检索 - 可变

class NumMatrix:
    """2D range sums with point updates, delegated to a sum-merging SegmentTree2D."""

    def __init__(self, matrix: List[List[int]]):
        self.segmentTree2D = SegmentTree2D(matrix, merge=SegmentTree2D.add)

    def update(self, row: int, col: int, val: int) -> None:
        # point assignment (the cell becomes `val`)
        tree = self.segmentTree2D
        tree.update(row, col, val)

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        # inclusive rectangle from (row1, col1) to (row2, col2)
        tree = self.segmentTree2D
        return tree.getRegion(row1, col1, row2, col2)


class Node:
    """Quad-tree node for rows [topIndex, bottomIndex] x cols [leftIndex, rightIndex]."""

    def __init__(self, topIndex, bottomIndex, leftIndex, rightIndex, val):
        self.topIndex = topIndex
        self.bottomIndex = bottomIndex
        self.leftIndex = leftIndex
        self.rightIndex = rightIndex
        self.val = val  # merged value of the covered sub-matrix
        # Up to four quadrant children; quadrants that would be empty stay None.
        self.leftTop = None
        self.rightTop = None
        self.leftBottom = None
        self.rightBottom = None

class SegmentTree2D:
    """2D segment tree that splits each node into four quadrants.

    Supports point assignment (``update``) and inclusive rectangle queries
    (``getRegion``).  ``merge`` combines values and must be ``max``, ``min``
    or ``SegmentTree2D.add``.
    """

    add = lambda a, b: a + b
    init_val_map = {max: -float('inf'), min: float('inf'), add: 0}

    def __init__(self, matrix, merge=max):
        self.matrix = matrix
        self.init_val = SegmentTree2D.init_val_map[merge]
        self.merge = merge
        self.root = None
        # Check for emptiness before touching matrix[0] (it used to raise IndexError).
        if not matrix or not matrix[0]:
            return
        n, m = len(matrix), len(matrix[0])
        self.root = self._buildTree(0, n - 1, 0, m - 1)

    def _pull(self, root):
        """Recompute ``root.val`` from its existing children, starting from the identity."""
        val = self.init_val
        for child in (root.leftTop, root.rightTop, root.leftBottom, root.rightBottom):
            if child is not None:
                val = self.merge(val, child.val)
        root.val = val

    def _buildTree(self, topIndex, bottomIndex, leftIndex, rightIndex):
        # Reject empty ranges (a quadrant past the matrix edge).
        if topIndex > bottomIndex or leftIndex > rightIndex:
            return None
        if topIndex == bottomIndex and leftIndex == rightIndex:
            return Node(topIndex, bottomIndex, leftIndex, rightIndex, self.matrix[topIndex][leftIndex])

        root = Node(topIndex, bottomIndex, leftIndex, rightIndex, self.init_val)
        rowMid = (topIndex + bottomIndex) // 2
        colMid = (leftIndex + rightIndex) // 2
        root.leftTop = self._buildTree(topIndex, rowMid, leftIndex, colMid)
        root.rightTop = self._buildTree(topIndex, rowMid, colMid + 1, rightIndex)
        root.leftBottom = self._buildTree(rowMid + 1, bottomIndex, leftIndex, colMid)
        root.rightBottom = self._buildTree(rowMid + 1, bottomIndex, colMid + 1, rightIndex)
        self._pull(root)
        return root

    def update(self, row, col, val):
        """Assign ``val`` to cell (row, col)."""
        self._update(self.root, row, col, val)

    def _update(self, root, row, col, val):
        if root.topIndex == root.bottomIndex == row and root.leftIndex == root.rightIndex == col:
            root.val = val
            return
        rowMid = (root.topIndex + root.bottomIndex) // 2
        colMid = (root.leftIndex + root.rightIndex) // 2
        if row <= rowMid and col <= colMid:
            self._update(root.leftTop, row, col, val)
        elif row <= rowMid and col >= colMid + 1:
            self._update(root.rightTop, row, col, val)
        elif row >= rowMid + 1 and col <= colMid:
            self._update(root.leftBottom, row, col, val)
        else:
            self._update(root.rightBottom, row, col, val)
        # Recompute from the merge identity.  The old code restarted from 0,
        # which is wrong for max/min (e.g. max over negative values).
        self._pull(root)

    def getRegion(self, row1, col1, row2, col2):
        """Merged value of the inclusive rectangle (row1, col1)..(row2, col2)."""
        return self._getRegion(self.root, row1, col1, row2, col2)

    def _getRegion(self, root, row1, col1, row2, col2):
        if root.topIndex == row1 and root.bottomIndex == row2 and root.leftIndex == col1 and root.rightIndex == col2:
            return root.val
        rowMid = (root.topIndex + root.bottomIndex) // 2
        colMid = (root.leftIndex + root.rightIndex) // 2
        region = self.init_val
        # Clip the query rectangle to each quadrant it touches.
        if row1 <= rowMid and col1 <= colMid:
            region = self.merge(region, self._getRegion(root.leftTop, row1, col1, min(rowMid, row2), min(colMid, col2)))
        if row1 <= rowMid and col2 >= colMid + 1:
            region = self.merge(region, self._getRegion(root.rightTop, row1, max(colMid + 1, col1), min(rowMid, row2), col2))
        if row2 >= rowMid + 1 and col1 <= colMid:
            region = self.merge(region, self._getRegion(root.leftBottom, max(rowMid + 1, row1), col1, row2, min(colMid, col2)))
        if row2 >= rowMid + 1 and col2 >= colMid + 1:
            region = self.merge(region, self._getRegion(root.rightBottom, max(rowMid + 1, row1), max(colMid + 1, col1), row2, col2))
        return region

4、二维区间覆盖线段树(四分树) -- 矩形面积 II

class Interval:
    """Axis-aligned rectangle spanning x in [left, right] and y in [down, up]."""

    def __init__(self, left=0, right=0, down=0, up=0):
        self.left = left
        self.right = right
        self.down = down
        self.up = up

    def __str__(self):
        return f"({self.left}, {self.right}, {self.down}, {self.up})"

    def area(self):
        return (self.right - self.left) * (self.up - self.down)

class SegmentNode:
    """Adaptively split rectangle node that tracks covered area.

    ``value`` is the covered area inside ``interval``.  A childless node is
    either fully covered or uncovered.  A leaf that a target only partially
    overlaps is split along the target's edges into up to four children.
    Children are collapsed again once the node is uniformly covered or empty.
    """

    def __init__(self, interval: Interval = None, value: int = 0, children=None):
        # Build fresh defaults per node.  The former defaults ``Interval()``
        # and ``[]`` were single objects shared by every node created without
        # arguments.
        self.interval = interval if interval is not None else Interval()
        self.value = value
        self.children = children if children is not None else []

    def __repr__(self):
        return str(self.value) + '__' + str(self.interval)

    def area(self):
        cur = self.interval
        return (cur.right - cur.left) * (cur.up - cur.down)

    def contains(self, target):
        """True if rectangle ``target`` fully covers this node's rectangle."""
        current = self.interval
        return target.left <= current.left \
            and target.down <= current.down \
            and target.right >= current.right \
            and target.up >= current.up

    def twoDSplit(self, midx, midy):
        """Split this rectangle at x = midx and y = midy into up to four parts.

        Zero-area parts, and a "part" equal to the whole rectangle, are
        dropped.  Rectangles of area <= 1 are never split.
        """
        cur = self.interval
        left, right, up, down = cur.left, cur.right, cur.up, cur.down
        if left > right or down > up:
            return []
        if (right - left) * (up - down) <= 1:
            return []
        mx, my = midx, midy
        result = [(l, r, d, u) for l, r in [(left, mx), (mx, right)] for d, u in [(down, my), (my, up)]]

        def fa(a):
            left, right, down, up = a
            return left <= right and down <= up \
                and (right - left) * (up - down) and not (left == cur.left and
                down == cur.down and up == cur.up and right == cur.right)
        return list(filter(fa, result))

    def change(self, target, value):
        """Mark the part of this node overlapped by ``target`` as covered.

        Afterwards ``self.value`` is the covered area inside this node
        (``value * area`` for a fully covered leaf).
        """
        current = self.interval
        # No intersection at all (rectangles that merely touch are not rejected here).
        if target.up < current.down or \
                target.right < current.left or \
                target.down > current.up or \
                target.left > current.right:
            return
        if self.contains(target):
            if len(self.children) == 0:
                self.value = value * self.area()
                return
        else:
            if len(self.children) == 0:
                # Split this leaf at the first target edge strictly inside it on each axis.
                t = [a for a in (target.left, target.right) if current.left < a < current.right]
                midx = t[0] if t else current.right
                t = [a for a in (target.down, target.up) if current.down < a < current.up]
                midy = t[0] if t else current.up
                subinterval = self.twoDSplit(midx, midy)

                if len(subinterval) > 0:
                    # Children inherit this leaf's state: fully covered if value > 0.
                    covered = int(self.value > 0)
                    self.children = [
                        SegmentNode(it, it.area() * covered)
                        for it in (Interval(*c) for c in subinterval)
                    ]

        if len(self.children) > 0:
            for child in self.children:
                child.change(target, value)
            self.value = sum(child.value for child in self.children)
            # Collapse back into a leaf once uniformly covered or uncovered.
            if self.value == value * self.area() or self.value == 0:
                self.children = []
                
class Solution:
    def rectangleArea(self, rectangles: List[List[int]]) -> int:
        """Area of the union of rectangles [x1, y1, x2, y2], modulo 10**9 + 7.

        One SegmentNode spans the bounding box of all rectangles.  Each
        rectangle is marked as covered, and the root's value is then the
        union area.
        """
        left = down = inf
        right = up = -inf
        for x1, y1, x2, y2 in rectangles:
            left, down = min(left, x1), min(down, y1)
            right, up = max(right, x2), max(up, y2)
        root = SegmentNode(Interval(left, right, down, up))
        for x1, y1, x2, y2 in rectangles:
            root.change(Interval(x1, x2, y1, y2), 1)
        return root.value % (10 ** 9 + 7)

5、更新数组后处理求和查询(区间 0/1 翻转)

class Solution:
    def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:
        """Process flip / add / report queries; ``nums1`` is assumed to hold 0/1 bits.

        Type 1 [1, l, r] flips nums1[l..r].  Type 2 [2, p, _] adds p * nums1[i]
        to every nums2[i], which changes sum(nums2) by p * (number of ones).
        Any other type records the current sum(nums2).
        """
        seg = LazySegmentTree(nums1)
        last = len(nums1) - 1
        total = sum(nums2)
        out = []
        for kind, a, b in queries:
            if kind == 1:
                seg.update(a, b)
            elif kind == 2:
                total += a * seg.query(0, last)
            else:
                out.append(total)
        return out

class LazySegmentTree:
    """Segment tree over a 0/1 array with range flip and range count of ones.

    ``ones[k]`` is the number of ones in node k's range.  ``lazy[k]`` is True
    when node k has NO pending flip for its children, and False when its
    children still need flipping.  Nodes use 0-based heap order
    (children 2k+1 and 2k+2).
    """

    def __init__(self, nums):
        self.n = len(nums)
        self.nums = nums
        self.ones = [0] * (4 * self.n)
        self.lazy = [True] * (4 * self.n)
        self._build(0, 0, self.n - 1)

    def _build(self, tree_index, l, r):
        if l == r:
            self.ones[tree_index] = self.nums[l]
            return
        mid = (l + r) // 2
        lc, rc = 2 * tree_index + 1, 2 * tree_index + 2
        self._build(lc, l, mid)
        self._build(rc, mid + 1, r)
        self.ones[tree_index] = self.ones[lc] + self.ones[rc]

    def _push(self, tree_index, l, r):
        """Flip both children of ``tree_index`` (range [l, r]) if a flip is pending."""
        if self.lazy[tree_index]:
            return
        mid = (l + r) // 2
        self._update(2 * tree_index + 1, l, mid, l, mid)
        self._update(2 * tree_index + 2, mid + 1, r, mid + 1, r)
        self.lazy[tree_index] = True

    def update(self, ql, qr):
        self._update(0, 0, self.n - 1, ql, qr)

    def _update(self, tree_index, l, r, ql, qr):
        if (l, r) == (ql, qr):
            # exact match: flip this node, defer its children
            self.lazy[tree_index] = not self.lazy[tree_index]
            self.ones[tree_index] = (r - l + 1) - self.ones[tree_index]
            return
        self._push(tree_index, l, r)
        mid = (l + r) // 2
        lc, rc = 2 * tree_index + 1, 2 * tree_index + 2
        if qr <= mid:
            self._update(lc, l, mid, ql, qr)
        elif ql > mid:
            self._update(rc, mid + 1, r, ql, qr)
        else:
            self._update(lc, l, mid, ql, mid)
            self._update(rc, mid + 1, r, mid + 1, qr)
        self.ones[tree_index] = self.ones[lc] + self.ones[rc]

    def query(self, ql, qr):
        return self._query(0, 0, self.n - 1, ql, qr)

    def _query(self, tree_index, l, r, ql, qr):
        if (l, r) == (ql, qr):
            return self.ones[tree_index]
        self._push(tree_index, l, r)
        mid = (l + r) // 2
        lc, rc = 2 * tree_index + 1, 2 * tree_index + 2
        if qr <= mid:
            return self._query(lc, l, mid, ql, qr)
        if ql > mid:
            return self._query(rc, mid + 1, r, ql, qr)
        return self._query(lc, l, mid, ql, mid) + self._query(rc, mid + 1, r, mid + 1, qr)

6、线段树 + 摩尔投票(子数组中占绝大多数的元素)

class Node:
    """Boyer-Moore majority-vote summary: candidate ``x`` with surplus count ``cnt``.

    ``+=`` folds another summary in.  The resulting ``x`` is only a
    candidate; callers must verify its real frequency (as
    ``MajorityChecker.query`` does).
    """

    def __init__(self, x: int = 0, cnt: int = 0):
        self.x = x
        self.cnt = cnt

    def __iadd__(self, that: "Node") -> "Node":
        if self.x == that.x:
            self.cnt += that.cnt
        elif self.cnt >= that.cnt:
            self.cnt -= that.cnt
        else:
            self.x = that.x
            self.cnt = that.cnt - self.cnt
        return self

class MajorityChecker:
    """Range majority queries: a Boyer-Moore segment tree proposes a candidate,
    and its sorted position list confirms the count by binary search."""

    def __init__(self, arr: List[int]):
        self.n = len(arr)
        self.arr = arr
        self.loc = defaultdict(list)  # value -> ascending positions where it occurs

        for i, val in enumerate(arr):
            self.loc[val].append(i)

        self.tree = [Node() for _ in range(self.n * 4)]
        # An empty array has no root range; seg_build(0, 0, -1) would recurse forever.
        if self.n:
            self.seg_build(0, 0, self.n - 1)

    def query(self, left: int, right: int, threshold: int) -> int:
        """Element occurring >= ``threshold`` times in arr[left..right], else -1."""
        if not self.n:
            return -1
        ans = Node()
        self.seg_query(0, 0, self.n - 1, left, right, ans)
        # .get: the candidate may not occur at all (e.g. the default x=0), and
        # indexing the defaultdict would insert a stray empty entry.
        pos = self.loc.get(ans.x, [])
        occ = bisect_right(pos, right) - bisect_left(pos, left)
        return ans.x if occ >= threshold else -1

    def seg_build(self, idx: int, l: int, r: int):
        arr_ = self.arr
        tree_ = self.tree

        if l == r:
            tree_[idx] = Node(arr_[l], 1)
            return

        mid = (l + r) // 2
        self.seg_build(idx * 2 + 1, l, mid)
        self.seg_build(idx * 2 + 2, mid + 1, r)
        tree_[idx] += tree_[idx * 2 + 1]
        tree_[idx] += tree_[idx * 2 + 2]

    def seg_query(self, idx: int, l: int, r: int, ql: int, qr: int, ans: Node):
        """Fold the summaries of all nodes fully inside [ql, qr] into ``ans`` (left to right)."""
        tree_ = self.tree

        if l > qr or r < ql:
            return

        if ql <= l and r <= qr:
            ans += tree_[idx]
            return

        mid = (l + r) // 2
        self.seg_query(idx * 2 + 1, l, mid, ql, qr, ans)
        self.seg_query(idx * 2 + 2, mid + 1, r, ql, qr, ans)

也可以用随机化 + 二分查找的方法解决。
