[python刷题模板] 珂朵莉树 ODT (基于SortedList 20220928弃用,请看新文)
一、 算法&数据结构
1. 描述
- 注意,本篇文章的珂朵莉树实现基于SortedList,assign操作常数较大(2000),因此大数据会TLE。
- 我重新整理了基于跳表的珂朵莉,见新文[python刷题模板] 珂朵莉树 ODT (基于支持随机访问的跳表
- 在网络上搜了很久没怎么找到珂朵莉树的python实现,自己写一个。
- 由于python的有序列表SortedList不是红黑树,所以split要先算begin,再算end。下标代替迭代器,这里和C++实现相反。
- SortedList批量删除使用del tree[begin:end],切片,原理是会用tree[:begin] 连接tree[end:]更新原来的tree。
- 珂朵莉树是建立在数据随机的情况下的一个乱搞算法。
- 当操作中有大量
区间赋值
动作,尤其是大区间
赋值,会把这个区间的值推平
,并且将原本这里边的很多小区间合并
。 - 如果数据随机,可以预见合并区间后,保留下来的区间不会很多。
- 那么查询区间如果不大,只需要处理几个区间即可。
- 底层用有序列表储存,c++set是红黑树,所以珂朵莉树也算树,
- 但python的sortedlist是list套list。。。可能名不副实了。
- 如果有心造数据卡珂朵莉树是可能实现的,数据不保证随机,大范围赋值少,查询多。
- 但这种的可以做特判,因为如果大范围赋值少,那数据范围应该小,没准可以暴力。
2. 复杂度分析
- 在随机数据下,珂朵莉树可达到O(nloglogn)
- split,珂朵莉熟的核心操作,在pos位置切割区间,返回以pos为左端点的区间。O(lgn).
- 之后所有动作都要包含两个split
- 更新assign,O(log2n)
- 查询query, O(lgn)+O©,c是区间内节点数,但c应该小所以速度很客观。
- 所有的query都是暴力操作。
- 找到这个区间所有节点,然后再算。
具体复杂度分析见知乎-珂朵莉树的复杂度分析
3. 常见应用
注意,这类题通常正解是线段树,但珂朵莉在特定情况下吊打正解。
这个特定情况是指:操作中存在大量assign碾平操作。
如果操作中不存在assign,请尽量不要用珂朵莉。(但数据弱的话也没准)
- 区间赋值,区间询问最大最小值。
- 区间赋值,区间询问第K小。
- 区间赋值,区间询问求和
- 区间赋值,区间询问n次方和(一般会有mod)。
这些操作全部暴力处理,因为我们认为:- 在随机数据下,大量区间被合并,询问的区间里不会有太多节点。
4. 常用优化
实现时我做了数次优化,提升不大记一下。
- 这里要用SortedKeyList,因为只能比较左端点。
- 从元组变成结构体,这样就可以直接修改。而且在结构体里实现小于运算,则改用SortedList。
- split时,本应删除原节点,加入两个节点。但实际上插入的左节点和原节点只差了一个右边界,因此可以直接修改。
- split生成时调用节点的解包函数,好写还快一点。
二、 模板代码
1. 区间赋值,区间询问最小值
例题: 715. Range 模块
这题当然可以线段树,套一下板就行。
珂朵莉树实测比线段树快3倍。
当然最佳方案是有序列表合并拆分线段。
class ODTNode:
__slots__ = ['l','r','v']
def __init__(self,l,r,v):
self.l,self.r,self.v = l,r,v
def __lt__(self,other):
return self.l<other.l
def jiebao(self):
return self.l,self.r,self.v
class ODT:
def __init__(self,l,r,v):
from sortedcontainers import SortedList
self.tree = SortedList([ODTNode(l,r,v)])
def split(self,pos):
""" 在pos位置切分,返回左边界l为pos的线段下标
"""
tree = self.tree
p = tree.bisect_left(ODTNode(pos,0,0))
if p != len(tree) and tree[p].l == pos:
return p
p -= 1
l,r,v = tree[p].jiebao()
tree[p].r = pos-1
# tree.pop(p)
# tree.add(ODTNode(l,pos-1,v))
tree.add(ODTNode(pos,r,v))
return p+1
def assign(self,l,r,v):
"""
把[l,r]区域全变成val
"""
tree = self.tree
begin = self.split(l)
end = self.split(r+1)
del tree[begin:end]
# for i in range(begin,end):
# tree.pop(begin)
tree.add(ODTNode(l,r,v))
# 以下操作全是暴力,寄希望于这里边元素不多。
def add_interval(self,l,r,val):
"""区间挨个加
"""
tree = self.tree
begin = self.split(l)
end = self.split(r+1)
for i in range(begin,end):
tree[i].v += val
def query_min(self,l,r):
"""
查找x,y区间的最小值
"""
begin = self.split(l)
end = self.split(r+1)
return min(node.v for node in self.tree[begin:end])
def query_kth(self,l,r,k):
"""查找[x,y]区间第k小的数
"""
begin = self.split(l)
end = self.split(r+1)
vs = [(node.v,node.r-node.l+1) for node in self.tree[begin:end]] # v和v的个数,排序
for v,d in sorted(vs): # 挨个往外丢,缩小k
k -= d
if k <= 0:
return v
def query_sum_of_pow(self,l,r,x):
"""求区间x次方的和,一般还得mod,那就要手写快速幂
"""
s = 0
begin = self.split(l)
end = self.split(r+1)
for node in self.tree[begin:end]:
s += (node.v**x) * (node.r-node.l+1)
return s
def query_cnt_v(self,l,r,v):
"""求区间[l,r]里有多少值等于v"""
s = 0
begin = self.split(l)
end = self.split(r+1)
for node in self.tree[begin:end]:
if node.v == v:
s += node.r-node.l + 1
return s
class RangeModule:
def __init__(self):
self.tree = ODT(1,10**9,0)
def addRange(self, left: int, right: int) -> None:
self.tree.assign(left,right-1,1)
def queryRange(self, left: int, right: int) -> bool:
return 1 == self.tree.query_min(left,right-1)
def removeRange(self, left: int, right: int) -> None:
self.tree.assign(left,right-1,0)
2. 区间赋值,区间查询
链接: 729. 我的日程安排表 I
book是给区间全赋值1,区间操作前检查是否这个区间有非0的值,sum或者max都可以。
线段树
也可以,珂朵莉树
快一点
class ODTNode:
__slots__ = ['l','r','v']
def __init__(self,l,r,v):
self.l,self.r,self.v = l,r,v
def __lt__(self,other):
return self.l<other.l
def jiebao(self):
return self.l,self.r,self.v
class ODT:
def __init__(self,l,r,v):
from sortedcontainers import SortedList
self.tree = SortedList([ODTNode(l,r,v)])
def split(self,pos):
""" 在pos位置切分,返回左边界l为pos的线段下标
"""
tree = self.tree
p = tree.bisect_left(ODTNode(pos,0,0))
if p != len(tree) and tree[p].l == pos:
return p
p -= 1
l,r,v = tree[p].jiebao()
tree[p].r = pos-1
# tree.pop(p)
# tree.add(ODTNode(l,pos-1,v))
tree.add(ODTNode(pos,r,v))
return p+1
def assign(self,l,r,v):
"""
把[l,r]区域全变成val
"""
tree = self.tree
begin = self.split(l)
end = self.split(r+1)
del tree[begin:end]
tree.add(ODTNode(l,r,v))
# 以下操作全是暴力,寄希望于这里边元素不多。
def query_max(self,l,r):
"""
查找x,y区间的最小值
"""
begin = self.split(l)
end = self.split(r+1)
return max(node.v for node in self.tree[begin:end])
class MyCalendar:
def __init__(self):
self.odt = ODT(0,10**9,0)
def book(self, start: int, end: int) -> bool:
if self.odt.query_max(start,end-1) == 1:
return False
self.odt.assign(start,end-1,1)
return True
3. 不存在区间赋值,只有区间加,区间查询
链接: 731. 我的日程安排表 II
这题没有assign本不该用珂朵莉做,但这个数据比较弱,确实能过,而且吊打线段树。
由于是求区间内一个超过1的数,因此可以区间查询时,提前退出。
class ODTNode:
__slots__ = ['l','r','v']
def __init__(self,l,r,v):
self.l,self.r,self.v = l,r,v
def __lt__(self,other):
return self.l<other.l
def jiebao(self):
return self.l,self.r,self.v
class ODT:
def __init__(self,l,r,v):
from sortedcontainers import SortedList
self.tree = SortedList([ODTNode(l,r,v)])
def split(self,pos):
""" 在pos位置切分,返回左边界l为pos的线段下标
"""
tree = self.tree
p = tree.bisect_left(ODTNode(pos,0,0))
if p != len(tree) and tree[p].l == pos:
return p
p -= 1
l,r,v = tree[p].jiebao()
tree[p].r = pos-1
# tree.pop(p)
# tree.add(ODTNode(l,pos-1,v))
tree.add(ODTNode(pos,r,v))
return p+1
def assign(self,l,r,v):
"""
把[l,r]区域全变成val
"""
tree = self.tree
begin = self.split(l)
end = self.split(r+1)
del tree[begin:end]
tree.add(ODTNode(l,r,v))
# 以下操作全是暴力,寄希望于这里边元素不多。
def add_interval(self,l,r,val):
"""区间挨个加
"""
tree = self.tree
begin = self.split(l)
end = self.split(r+1)
for i in range(begin,end):
tree[i].v += val
def query_max(self,l,r):
"""
查找x,y区间的最小值
"""
begin = self.split(l)
end = self.split(r+1)
return max(node.v for node in self.tree[begin:end])
def query_has_greater_than(self,l,r,val):
"""
查找x,y区间是否有大于val的数
"""
begin = self.split(l)
end = self.split(r+1)
return any(node.v>val for node in self.tree[begin:end])
class MyCalendarTwo:
def __init__(self):
self.odt = ODT(0,10**9,0)
def book(self, start: int, end: int) -> bool:
if self.odt.query_has_greater_than(start,end-1,1) :
return False
self.odt.add_interval(start,end-1,1)
return True
4. 不存在区间赋值,只有区间加,区间查询
链接: 732. 我的日程安排表 III
同上题,没有assign,且每次区间询问都是整体询问,我本以为不行的,但是试了一下,数据更友好~
这里也可以优化,每次最大值更新都在插入后询问,因此最大值可以储存下来,每次和更新后的区间比较更新即可。
class ODTNode:
__slots__ = ['l','r','v']
def __init__(self,l,r,v):
self.l,self.r,self.v = l,r,v
def __lt__(self,other):
return self.l<other.l
def jiebao(self):
return self.l,self.r,self.v
class ODT:
def __init__(self,l,r,v):
from sortedcontainers import SortedList
self.tree = SortedList([ODTNode(l,r,v)])
def split(self,pos):
""" 在pos位置切分,返回左边界l为pos的线段下标
"""
tree = self.tree
p = tree.bisect_left(ODTNode(pos,0,0))
if p != len(tree) and tree[p].l == pos:
return p
p -= 1
l,r,v = tree[p].jiebao()
tree[p].r = pos-1
# tree.pop(p)
# tree.add(ODTNode(l,pos-1,v))
tree.add(ODTNode(pos,r,v))
return p+1
def assign(self,l,r,v):
"""
把[l,r]区域全变成val
"""
tree = self.tree
begin = self.split(l)
end = self.split(r+1)
del tree[begin:end]
tree.add(ODTNode(l,r,v))
# 以下操作全是暴力,寄希望于这里边元素不多。
def add_interval(self,l,r,val):
"""区间挨个加
"""
tree = self.tree
begin = self.split(l)
end = self.split(r+1)
m = 0
for i in range(begin,end):
tree[i].v += val
m = max(m,tree[i].v)
return m
def query_max(self,l,r):
"""
查找x,y区间的最大值
"""
begin = self.split(l)
end = self.split(r+1)
return max(node.v for node in self.tree[begin:end])
def query_all_max(self,):
"""
查找x,y区间的最大值
"""
begin = self.split(0)
end = self.split(10**9+1)
return max(node.v for node in self.tree[begin:end])
class MyCalendarThree:
def __init__(self):
self.odt = ODT(0,10**9,0)
self.m = 0
def book(self, start: int, end: int) -> int:
self.m = max(self.m,self.odt.add_interval(start,end-1,1))
return self.m
# return self.odt.query_all_max()
5. 存在区间赋值,整体查询
链接: 699. 掉落的方块
优化维护最大值思路类似上题。
但是存在赋值,因此可以珂朵莉,实测效率和线段树差不多。
class ODTNode:
__slots__ = ['l','r','v']
def __init__(self,l,r,v):
self.l,self.r,self.v = l,r,v
def __lt__(self,other):
return self.l<other.l
def jiebao(self):
return self.l,self.r,self.v
class ODT:
def __init__(self,l,r,v):
from sortedcontainers import SortedList
self.tree = SortedList([ODTNode(l,r,v)])
def split(self,pos):
""" 在pos位置切分,返回左边界l为pos的线段下标
"""
tree = self.tree
p = tree.bisect_left(ODTNode(pos,0,0))
if p != len(tree) and tree[p].l == pos:
return p
p -= 1
l,r,v = tree[p].jiebao()
tree[p].r = pos-1
# tree.pop(p)
# tree.add(ODTNode(l,pos-1,v))
tree.add(ODTNode(pos,r,v))
return p+1
def assign(self,l,r,v):
"""
把[l,r]区域全变成val
"""
tree = self.tree
begin = self.split(l)
end = self.split(r+1)
del tree[begin:end]
tree.add(ODTNode(l,r,v))
# 以下操作全是暴力,寄希望于这里边元素不多。
def add_interval(self,l,r,val):
"""区间挨个加
"""
tree = self.tree
begin = self.split(l)
end = self.split(r+1)
for i in range(begin,end):
tree[i].v += val
def query_max(self,l,r):
"""
查找x,y区间的最大值
"""
begin = self.split(l)
end = self.split(r+1)
return max(node.v for node in self.tree[begin:end])
def query_all_max(self,):
"""
查找x,y区间的最大值
"""
begin = self.split(0)
end = self.split(10**9+1)
return max(node.v for node in self.tree[begin:end])
class Solution:
def fallingSquares(self, positions: List[List[int]]) -> List[int]:
odt = ODT(1,10**9,0)
ans = []
m = 0
for l,d in positions:
h = odt.query_max(l,l+d-1)
odt.assign(l,l+d-1,h+d)
m = max(m,h+d)
ans.append(m)
return ans
6. 单点赋值,合并区间
链接: 352. 将数据流变为多个不相交区间
实际上是线段合并。
class ODTNode:
def __init__(self,l,r,v):
self.l,self.r,self.v = l,r,v
def __lt__(self,other):
return self.l<other.l
def jiebao(self):
return self.l,self.r,self.v
# def __str__(self):
# return str((self.l,self.r,self.v))
# def __repr__(self):
# return str((self.l,self.r,self.v))
class ODT:
def __init__(self,l,r,v):
from sortedcontainers import SortedList
self.tree = SortedList([ODTNode(l,r,v)])
def split(self,pos):
""" 在pos位置切分,返回左边界l为pos的线段下标
"""
tree = self.tree
p = tree.bisect_left(ODTNode(pos,0,0))
if p != len(tree) and tree[p].l == pos:
return p
p -= 1
l,r,v = tree[p].jiebao()
tree[p].r = pos-1
# tree.pop(p)
# tree.add(ODTNode(l,pos-1,v))
tree.add(ODTNode(pos,r,v))
return p+1
def assign(self,l,r,v):
"""
把[l,r]区域全变成val
"""
tree = self.tree
begin = self.split(l)
end = self.split(r+1)
del tree[begin:end]
tree.add(ODTNode(l,r,v))
# 以下操作全是暴力,寄希望于这里边元素不多。
def query_all_intervals(self):
tree = self.tree
lines = []
l = r = -1
for node in tree :
if node.v == 0:
if l != -1:
lines.append([l,r])
l = -1
else:
if l == -1:
l = node.l
r = node.r
# for line in lines:
# self.assign(line[0],line[1],1)
# print(self.tree)
return lines
class SummaryRanges:
def __init__(self):
self.odt = ODT(0,10**4+1,0)
def addNum(self, val: int) -> None:
self.odt.assign(val,val,1)
def getIntervals(self) -> List[List[int]]:
return self.odt.query_all_intervals()
三、其他
- 待补充