线段树
线段树是一棵满二叉树,所以我们用一维数组存整棵树(编号为x的节点,父节点为x >> 1
,左儿子为x << 1
,右儿子为x << 1 | 1
),n个点最多会有4n-1个位置,我们一般开4n的空间。
线段树一般分为五个函数:pushup()
(用子节点更新父节点)、pushdown()
(用父节点的懒标记更新子节点)、build()
(创建线段树)、update()
(区间修改)、query()
(求解)
typedef long long LL;
struct Node
{
int l, r;
LL add; //当前区间的所有儿子加上add
// TODO: 需要维护的信息
} tr[N * 4];
void pushup(int u)
{
// TODO: 利用左右儿子信息维护当前节点的信息
}
void pushdown(int u)
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.add)
{
left.add += root.add, left.sum += (left.r - left.l + 1) * root.add;
right.add += root.add, right.sum += (right.r - right.l + 1) * root.add;
root.add = 0;
}
// TODO: 将懒标记下传
}
void build(int u, int l, int r) //将区间[l,r]初始化为线段树
{
if (l == r)
tr[u] = {l, r};
else
{
tr[u] = {l, r};
int mid = l + ((r - l) >> 1);
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void update(int u, int l, int r, int d)
{
if (tr[u].l >= l && tr[u].r <= r)
{
// TODO: 修改区间
}
else
{
pushdown(u);
int mid = tr[u].l + ((tr[u].r - tr[u].l) >> 1);
if (l <= mid)
update(u << 1, l, r, d);
if (r > mid)
update(u << 1 | 1, l, r, d);
pushup(u);
}
}
int query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
{
return; // TODO 需要补充返回值
}
else
{
pushdown(u);
int mid = tr[u].l + ((tr[u].r - tr[u].l) >> 1);
int res = 0;
if (l <= mid)
res = query(u << 1, l, r);
if (r > mid)
res += query(u << 1 | 1, l, r);
return res;
}
}
扫描线
#include <bits/stdc++.h>
using namespace std;
const int N = 100010;
int n;
struct Segment //线段
{
double x, y1, y2;
int k;
bool operator<(const Segment &t) const
{
return x < t.x;
}
} seg[N * 2];
struct Node
{
int l, r;
int cnt; //当前区间整个被覆盖的次数
double len; //不考虑祖先节点cnt的前提下cnt>0的区间总长
} tr[N * 8];
vector<double> ys; //离散化
int find(double y)
{
return lower_bound(ys.begin(), ys.end(), y) - ys.begin();
}
void pushup(int u)
{
if (tr[u].cnt)
tr[u].len = ys[tr[u].r + 1] - ys[tr[u].l];
else if (tr[u].l != tr[u].r)
{
tr[u].len = tr[u << 1].len + tr[u << 1 | 1].len;
}
else
tr[u].len = 0;
}
void build(int u, int l, int r)
{
tr[u] = {l, r, 0, 0};
if (l != r)
{
int mid = l + ((r-l) >> 1);
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}
}
void modify(int u, int l, int r, int k)
{
if (tr[u].l >= l && tr[u].r <= r)
{
tr[u].cnt += k;
pushup(u);
}
else
{
int mid = tr[u].l + ((tr[u].r-tr[u].l) >> 1);
if (l <= mid)
modify(u << 1, l, r, k);
if (r > mid)
modify(u << 1 | 1, l, r, k);
pushup(u);
}
}
int main()
{
ys.clear();
for (int i = 0, j = 0; i < n; i++)
{
double x1, y1, x2, y2;
scanf("%lf%lf%lf%lf", &x1, &y1, &x2, &y2);
seg[j++] = {x1, y1, y2, 1};
seg[j++] = {x2, y1, y2, -1};
ys.push_back(y1), ys.push_back(y2);
}
sort(ys.begin(), ys.end());
ys.erase(unique(ys.begin(), ys.end()), ys.end()); //判重
build(1, 0, ys.size() - 2);
sort(seg, seg + n * 2);
double res = 0;
for (int i = 0; i < n * 2; i++)
{
if (i > 0)
res += tr[1].len * (seg[i].x - seg[i - 1].x);
modify(1, find(seg[i].y1), find(seg[i].y2) - 1, seg[i].k);
}
return 0;
}