题目地址:
https://leetcode.com/problems/range-module/
要求设计一个数据结构,可以做如下操作:
1、跟踪一个区间
[
l
,
r
)
[l,r)
[l,r)里的数;
2、放弃跟踪一个区间
[
l
,
r
)
[l,r)
[l,r)里的数;
3、查询区间
[
l
,
r
)
[l,r)
[l,r),返回是否这个区间里的所有的数都被跟踪了。
同一个数被跟踪多次和跟踪一次是等效的。
法1:线段树。这一题是经典的带懒标记的线段树的应用。我们思考懒标记应当如何打。假设递归到了当前节点,如果当前节点所维护的区间被完全包含了,那么跟踪,可以打标记
1
1
1,不跟踪,可以打标记
−
1
-1
−1,此外我们还需要判断未打标记的情形,那么未打标记,就规定标记是
0
0
0。当然每个节点我们需要开一个变量mark记录当前区间是否被跟踪。这样三个操作如下进行:
1、跟踪区间
[
l
,
r
)
[l,r)
[l,r),等价于跟踪
[
l
,
r
−
1
]
[l,r-1]
[l,r−1],如果当前区间完全包含了,则置mark为true,并打上懒标记
1
1
1并且不下传;否则pushdown将lazy下传,然后递归处理左右孩子,再pushup更新。
2、与1类似;
3、如果当前区间被完全包含了,则直接返回mark,否则pushdown将懒标记下传,然后递归求解左右子树,只有左右子树都mark了才返回true。
注意:这题的查询范围非常大,是 [ 1 , 1 0 9 ] [1,10^9] [1,109],如果用数组写线段树会导致空间消耗太多,但是这题又无法做离散化(因为查询是在线的),所以只能用二叉树的链表写法来做。当当前区间没有被完全覆盖的时候,做split,new出左右孩子并pushdown,然后再做递归。代码如下:
public class RangeModule {
class SegTree {
class Node {
int l, r;
// mark标记[l, r]这个区间是否整个被跟踪
boolean mark;
// lazy取0表示没有懒标记,1表示跟踪的懒标记,-1表示不跟踪的懒标记
int lazy;
Node left, right;
public Node(int l, int r) {
this.l = l;
this.r = r;
}
}
private Node tr;
public SegTree(int l, int r) {
tr = new Node(l, r);
}
private void pushup(Node u) {
u.mark = u.left.mark && u.right.mark;
}
private void pushdown(Node u) {
if (u.lazy != 0) {
u.left.mark = u.right.mark = u.lazy == 1;
u.left.lazy = u.right.lazy = u.lazy;
// 取消当前懒标记
u.lazy = 0;
}
}
public void split(Node u) {
if (u.left != null && u.right != null) {
return;
}
int m = u.l + (u.r - u.l >> 1);
u.left = new Node(u.l, m);
u.right = new Node(m + 1, u.r);
}
public void add(int l, int r) {
add(tr, l, r);
}
private void add(Node u, int l, int r) {
if (l <= u.l && u.r <= r) {
u.mark = true;
u.lazy = 1;
return;
}
split(u);
pushdown(u);
int m = u.l + (u.r - u.l >> 1);
if (l <= m) {
add(u.left, l, r);
}
if (m + 1 <= r) {
add(u.right, l, r);
}
pushup(u);
}
public void remove(int l, int r) {
remove(tr, l, r);
}
private void remove(Node u, int l, int r) {
if (l <= u.l && u.r <= r) {
u.mark = false;
u.lazy = -1;
return;
}
split(u);
pushdown(u);
int m = u.l + (u.r - u.l >> 1);
if (l <= m) {
remove(u.left, l, r);
}
if (m + 1 <= r) {
remove(u.right, l, r);
}
pushup(u);
}
public boolean query(int l, int r) {
return query(tr, l, r);
}
private boolean query(Node u, int l, int r) {
if (l <= u.l && u.r <= r) {
return u.mark;
}
split(u);
pushdown(u);
int m = u.l + (u.r - u.l >> 1);
if (l <= m && !query(u.left, l, r)){
return false;
}
if (m + 1 <= r && !query(u.right, l, r)) {
return false;
}
return true;
}
}
final int N = (int) 1e9;
private SegTree segTree;
public RangeModule() {
segTree = new SegTree(1, N);
}
public void addRange(int left, int right) {
segTree.add(left, right - 1);
}
public boolean queryRange(int left, int right) {
return segTree.query(left, right - 1);
}
public void removeRange(int left, int right) {
segTree.remove(left, right - 1);
}
}
初始化时间复杂度 O ( 1 ) O(1) O(1),其余操作时间复杂度 O ( log r ) O(\log r) O(logr), r r r是查询范围,空间 O ( n ) O(n) O(n), n n n是被查询过的最大范围。
法2:平衡树。用平衡树存储所有被跟踪的区间是比较方便的办法。我们可以维护一个TreeSet专门存已经被跟踪的不相交区间(左闭右开),规定区间之间的比较函数是比较其右端点(右端点相同的话是可以合并的,这样就能保证不存在相同元素)。三个操作如下进行:
1、询问
[
l
,
r
)
[l,r)
[l,r),先找到右端点大于等于
r
r
r的第一个区间,如果不存在则返回false,如果存在但是其左端点大于
l
l
l则也返回false。否则返回true;
2、跟踪
[
l
,
r
)
[l,r)
[l,r),先找到第一个右端点大于等于
l
l
l的区间,然后从该区间向后找后继,直到找到第一个左端点大于
r
r
r的区间为止,中间遍历过的区间(除了那个第一个左端点大于
r
r
r的区间)都要被合并,合并的结果是这些区间左端点的最小值作为左端点,右端点的最大值作为右端点;
3、取消跟踪
[
l
,
r
)
[l,r)
[l,r),先找到第一个右端点大于
l
l
l的区间(右端点等于
l
l
l的区间不会被删,因为这是左闭右开区间),那么这个区间
[
l
′
,
r
′
)
[l',r')
[l′,r′)的左端点如果小于
l
l
l,就会产生一个新区间
[
l
′
,
l
)
[l',l)
[l′,l);接着找这个区间的后继,依次删除(
[
l
′
,
r
′
)
[l',r')
[l′,r′)也要被删),直到找到第一个左端点大于等于
r
r
r的区间(这个区间不用删),那么最后一个不用删的区间如果其右端点
r
′
>
r
r'>r
r′>r,那么就会产生一个新区间
[
r
,
r
‘
)
[r,r‘)
[r,r‘)。综上,只需要将中间遍历的区间该删的删,然后把可能产生的新区间加回去即可。
代码如下:
import java.util.Iterator;
import java.util.TreeSet;
public class RangeModule {
private TreeSet<int[]> treeSet;
public RangeModule() {
treeSet = new TreeSet<>((a, b) -> Integer.compare(a[1], b[1]));
}
public void addRange(int left, int right) {
Iterator<int[]> iter = treeSet.tailSet(new int[]{0, left}, true).iterator();
while (iter.hasNext()) {
int[] temp = iter.next();
if (temp[0] > right) {
break;
}
left = Math.min(left, temp[0]);
right = Math.max(right, temp[1]);
iter.remove();
}
treeSet.add(new int[]{left, right});
}
public boolean queryRange(int left, int right) {
int[] ceiling = treeSet.ceiling(new int[]{0, right});
return ceiling != null && ceiling[0] <= left;
}
public void removeRange(int left, int right) {
Iterator<int[]> iter = treeSet.tailSet(new int[]{0, left}, false).iterator();
int[] prev = null, next = null;
while (iter.hasNext()) {
int[] temp = iter.next();
if (temp[0] >= right) {
break;
}
if (temp[0] < left) {
prev = new int[]{temp[0], left};
}
if (right < temp[1]) {
next = new int[]{right, temp[1]};
}
iter.remove();
}
if (prev != null) {
treeSet.add(prev);
}
if (next != null) {
treeSet.add(next);
}
}
}
初始化时间复杂度 O ( 1 ) O(1) O(1),add和remove时间 O ( n ) O(n) O(n)( n n n是已经跟踪了多少个不相交区间),query时间 O ( log n ) O(\log n) O(logn)。空间 O ( n ) O(n) O(n)。
C++:
class RangeModule {
public:
struct Node {
int lc, rc;
int l, r;
// 当前区间track了没
bool mark;
// 懒标记,0表示无标记,1表示track了,-1表示未track
int tag;
};
vector<Node> tr;
int new_node(int l, int r) {
tr.push_back({0, 0, l, r});
return tr.size() - 1;
}
void pushup(int u) { tr[u].mark = tr[tr[u].lc].mark && tr[tr[u].rc].mark; }
void pushdown(int u) {
Node &lc = tr[tr[u].lc], &rc = tr[tr[u].rc];
if (tr[u].tag) {
lc.mark = rc.mark = tr[u].tag == 1;
lc.tag = rc.tag = tr[u].tag;
tr[u].tag = 0;
}
}
void split(int u) {
if (tr[u].lc) return;
int mid = tr[u].l + tr[u].r >> 1;
tr[u].lc = new_node(tr[u].l, mid);
tr[u].rc = new_node(mid + 1, tr[u].r);
}
void add(int u, int l, int r) {
if (l <= tr[u].l && tr[u].r <= r) {
tr[u].mark = true;
tr[u].tag = 1;
return;
}
split(u);
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) add(tr[u].lc, l, r);
if (r > mid) add(tr[u].rc, l, r);
pushup(u);
}
void remove(int u, int l, int r) {
if (l <= tr[u].l && tr[u].r <= r) {
tr[u].mark = false;
tr[u].tag = -1;
return;
}
split(u);
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) remove(tr[u].lc, l, r);
if (r > mid) remove(tr[u].rc, l, r);
pushup(u);
}
bool query(int u, int l, int r) {
if (l <= tr[u].l && tr[u].r <= r) {
return tr[u].mark;
}
split(u);
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid && !query(tr[u].lc, l, r)) return false;
if (r > mid && !query(tr[u].rc, l, r)) return false;
return true;
}
RangeModule() {
tr.push_back({});
new_node(1, (int)1e9);
}
void addRange(int left, int right) { add(1, left, right - 1); }
bool queryRange(int left, int right) { return query(1, left, right - 1); }
void removeRange(int left, int right) { remove(1, left, right - 1); }
};
时空复杂度一样。