题解转自:ITNXD
链接:https://www.acwing.com/solution/content/18676/
1. 题目来源
美团 2019 笔试题
2. 题目说明
3. 题目解析
思路如下:
- 按行进行区间合并
- 按列进行区间合并
- 判断行列的重叠部分减去多加的
存储结构:
对于行和列我们要存储三个值,分别为区间左右或上下端点以及一个标识表示那一行或那一列。
- 行或列的标识:行或列相同的哪一个数字
- 左右端点:不相同的一组中较小值和较大值
对于排序:
可以重载小于号,也可以直接在外面写一个 cmp
函数传入 sort
进行比较!
优先按照行货列的标号从下到大,然后就是按照左端点和右端点!
区间合并:
-
保证在同一行
k == seg.k
- 区间无法合并
ed < seg.l
,则进行上一个区间的保存,同时更新左右端点 - 可以合并,则进行合并,更新右端点
- 区间无法合并
-
不在同一行,直接将上一个区间保存,同时更新新的区间的左右端点以及行或列的标识k
-
记得最后一个区间的保存,for循环无法处理最后一个区间的保存
-
最后将保存的容器还原到原始的
segs
,通过引用传回去!
注意: 对于每次合并当然都得判断起始点是否是-2e9
计数:
每次保存就是一个新区间,将区间长度累加一下即可,cnt += ed - st + 1
去重:
即只要是横着的和竖着的有相交就去减掉一个重合点,判断条件(画个图就知道了):row.k >= col.l && row.k <= col.r && row.r >= col.k && row.l <= col.k
时间复杂度:
- 区间合并: O ( n ) O(n) O(n)
- 快排: O ( l o g n ) O(logn) O(logn)
- 去重: O ( n 2 ) O(n^2) O(n2)
- 总复杂度为: O ( n 2 ) O(n^2) O(n2)
参见代码如下:
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long LL;
struct Node {
int k, l, r;
bool operator<(const Node& w) const { // 按行列、左端点、右端点优先级排序
if (k != w.k) return k < w.k;
else if (l!= w.l) return l < w.l;
else return r < w.r;
}
};
int n;
LL cnt;
vector<Node> cols, rows;
void merge(vector<Node> &segs) {
sort(segs.begin(), segs.end());
vector<Node> res;
int st = -2e9, ed = -2e9, k = -2e9;
for (auto seg:segs) {
if (seg.k == k) { // 保证在同一行/列
if (ed < seg.l) { // 无法区间合并,即发现新区间
if (st != -2e9) { // 若不是初始区间
res.push_back({k, st, ed});
cnt += ed - st + 1;
}
st = seg.l, ed = seg.r; // 无法合并,更新区间左右端点
} else {
ed = max(ed, seg.r); // 合并区间,更新右端点
}
} else { // 不在同一行/列,直接就是个新区间,不能进行区间合并
if (st != -2e9) { // 若不为初始区间
res.push_back({k, st, ed});
cnt += ed - st + 1;
}
k = seg.k, st = seg.l, ed = seg.r; // 更新成这个新区间的左右端点和行/列k值
}
}
// 最后一个区间一个区间for循环无法处理,自行保存,且需要防止为空区间输入
if (st != -2e9) res.push_back({k, st, ed}), cnt += ed - st + 1;
segs = res;
}
int main() {
cin >> n;
for (int i = 0; i < n; ++i) {
int x1, y1, x2, y2;
cin >> x1 >> y1 >> x2 >> y2;
if (x1 == x2) cols.push_back({x1, min(y1, y2), max(y1, y2)});
else rows.push_back({y1, min(x1, x2), max(x1, x2)});
}
merge(cols), merge(rows);
for (auto col:cols)
for (auto row:rows)
if (row.k >= col.l && row.k <= col.r && row.r >= col.k && row.l <= col.k) cnt --;
cout << cnt << endl;
return 0;
}