NOIP2023 T4 天天爱打卡
24 年补题。去年怎么我连蓝题都做不出来? 昨天做这题调了很久,后面看题解发现我的线段树写法和大多数人不一样。他们把区间修改都整上了?为什么我只需要单点修改+区间查询?不过我在合并左右子的信息的时候,需要的写法就不是很简单了。
暑假以来,线段树维护复杂信息写得挺多,一般这种时候都需要在 query()
的时候返回整个线段,重复利用 merge()
函数。
题意简化
题目背景:上来就致敬经典题。
[ 1 , n ] [1, n] [1,n] 的值域, m m m 个闭区间,第 i i i( 1 ≤ i ≤ m 1\le i \le m 1≤i≤m)个区间右端点为 x i x_i xi,长度为 y i y_i yi,有权值 v i v_i vi。
要求选择若干个区间,并且这些区间所覆盖的值中,不能出现连续长度超过 k k k 的情况。即不能存在 1 ≤ x ≤ n − k 1\le x\le n-k 1≤x≤n−k,第 x x x 到第 x + k x+k x+k 个位置都被覆盖。
已知 d d d,设所选的区间权值和为 s s s,这些区间覆盖了 p p p 个位置,求 s − d ⋅ p s-d\cdot p s−d⋅p 的最大值。
t t t 组数据,对于所有测试数据:
1 ≤ t ≤ 10 1\le t\le 10 1≤t≤10, 1 ≤ k ≤ n ≤ 1 0 9 1\le k\le n\le 10^9 1≤k≤n≤109, 1 ≤ m ≤ 1 0 5 1\le m\le 10^5 1≤m≤105, 1 ≤ l i ≤ r i ≤ n 1\le l_i\le r_i\le n 1≤li≤ri≤n, 1 ≤ d , v i ≤ 1 0 9 1\le d,v_i\le 10^9 1≤d,vi≤109。
思路
先看一眼数据范围:
多测要先清空。估计需要
O
(
m
log
m
)
O(m \log m)
O(mlogm) 的算法,
1
≤
d
,
v
i
≤
1
0
9
1\le d,v_i\le 10^9
1≤d,vi≤109,显然算答案的时候需要开 long long
。
没给我保证
y
i
≤
k
y_i \le k
yi≤k,那么
y
i
>
k
y_i > k
yi>k 的情况读入的时候就不要存了,令 --i, --m;
,这个区间显然不能选。
对合法的区间按照右端点进行排序,然后按这个顺序进行处理。在处理的过程中,显然右端点是递增的。
我们维护两个东西:(因为是拆分开来讲的,有不明白随时跳到完整代码那里看)
线段树
将所有的左端点 l i l_i li 进行离散化,在此基础上建立线段树。
记离散化的数组为 hdld[]
(handled),长度为 cnt
。
线段树结构体
struct Segment {
int l, r;
long long sum, max;
} s[MAXM*4];
sum
维护:当前,已处理的 hdld[l] <= l <= hdld[r] 的区间,这些区间的权值之和。
max
维护:当前,不考虑区间右端点位置,选择一个合适的 l (hdld[this->l] <= l <= hdld[this->r]
) 的位置出发,连续一直到 hdld[r]
,这些位置都被覆盖,并且第 l-1
个位置不被覆盖,然后所能够达到的最大值。
为什么这样做?先看看 merge
操作。
merge
Segment merge(Segment lc, Segment rc) {
Segment x;
x.l = lc.l, x.r = rc.r;
x.sum = lc.sum + rc.sum;
// long long lmax = lc.max + rc.sum - d * (hdld[rc.r] - hdld[rc.l] + 1); WA
long long lmax = lc.max + rc.sum - d * (hdld[rc.r] - hdld[lc.r]); // OK
x.max = max(lmax, rc.max);
return x;
}
看到第 6~7 行代码,这个是关键。
如果要从左边的某个位置出发,它在左边的那一部分一定是已经处理好了的。
既然要延伸到右边那一段去,就一定会加上右边的所有权值之和。而右边是要被全部覆盖的,所以还要再减去新增覆盖的那一部分 (hdld[rc.r] - hdld[lc.r])*d
,最后就得到了那一行式子。
注意我注释掉的那一部分:
新增覆盖的部分是 hdld[rc.r]-hdld[lc.r]
而非 hdld[rc.r]-hdld[rc.l]+1
,因为这是离散化后的,hdld[rc.l]
和 hdld[lc.r]
可能不连续。
update
void update(int cur, int p, long long val) {
if (s[cur].l == s[cur].r) {
// sum 是不计损的,max 是计损的
s[cur].sum += val;
// s[cur].max += val; WA
// printf("%lld:", s[cur].max);
if (!s[cur].max) {
int idx = lower_bound(rt+1, rt+tot+1, hdld[p]-1) - rt - 1;
s[cur].max = val + ans[idx] - d;
}
// 取 max 应该是针对不同的点取 max,对同一点还是累计
else s[cur].max += val; // 前面的不能重复加
return;
}
int mid = (s[cur].l + s[cur].r) >> 1;
if (p <= mid) update(cur*2, p, val);
else update(cur*2+1, p, val);
pushup(cur);
}
初始化的时候 s[cur].sum = s[cur].max = 0
。
p
就是要修改的左端点的位置,val
是当前这个区间的权值。需要讨论是不是第一次处理这个位置。
首先,ans[i]
表示截止到 rt[i]
,能取到的最大值,是这个 dp 过程的另一核心,后面会讲。rt
记录相应的右端点。
int idx = lower_bound(rt+1, rt+tot+1, hdld[p]-1) - rt - 1;
查询之前的右端点小于 hdld[p]-1 的最大值。因为是按照右端点从小到大的顺序处理的,左端点一定不大于右端点。所以这个最大值当前是已经确定了的。
如果是第一次处理的话,s[cur].max = val + ans[idx] - d;
,不理解的话向上翻到 Segment::max
的定义。就算只选这里这一个位置,这一个位置的损耗也需要在 max
上面减掉。
如果不是第一次,直接把 val
累加上去就行了。反正 ans[idx]-d
不会再发生改变。
ans数组
前面已经提到,ans[i]
表示截止到 rt[i]
,能取到的最大值。
具体来说,它表示的是:大于 rt[i]
的先不管它,把当前情况和 ans[i-1]
比较,取较大者。
当前情况指的是:从 rt[i]
出发,向左连续不超过 k
个单位长度的区间,断开一个再往前,这里面的最大值。
那么线段树维护的东西就用上了。
请读者思考为什么线段树维护的时候,不需要考虑维护该区间的右端点的问题?
因为我们把右端点先排序了。之前处理过的那些区间,右端点小于等于当前,从它的左端点出发连续到这里,一定会也覆盖掉它的右端点。
注意到我们查询的内容,区间右端点只能到 hdld[maxL]
,而 (hdld[maxL], task[i].r] 这个区间我们必须覆盖它,所以还需要扣除 d * (task[i].r-hdld[maxL])
,才是真正的“当前情况”。
void solve() {
int maxL = 0;
tot = 0;
for (int i = 1; i <= m; ++i) {
update(1, getRk(task[i].l), task[i].v);
if (i < m && task[i].r == task[i+1].r) continue;
maxL = max(maxL, task[i].l); // r 相同,l 从小到大排序,这样维护 maxL 没问题
rt[++tot] = task[i].r; // 记录对应的右端点位置,供 update() 中二分查找使用
ans[tot] = query(1, getRk(task[i].r-k+1), getRk(maxL)).max;
ans[tot] -= d * (task[i].r - maxL);
ans[tot] = max(ans[tot], ans[tot-1]); // 当前和以往的最大值比较
}
}
难点集中在上面三个函数里面,其他的都是线段树+排序离散化的模板,还不会的自己补去。
代码
#include <cstdio>
#include <algorithm>
#define MAXM 100005
using namespace std;
int n, m, k;
long long d;
int hdld[MAXM], cnt;
struct Seq {
int l, r, v;
} task[MAXM];
struct Segment {
int l, r;
long long sum, max;
} s[MAXM*4];
int rt[MAXM];
long long ans[MAXM];
int tot;
// Seq
bool cmp(Seq x, Seq y) {
if (x.r == y.r) return x.l < y.l;
return x.r < y.r;
}
// Segment Tree
Segment merge(Segment lc, Segment rc) {
Segment x;
x.l = lc.l, x.r = rc.r;
x.sum = lc.sum + rc.sum;
// long long lmax = lc.max + rc.sum - d * (hdld[rc.r] - hdld[rc.l] + 1); WA
long long lmax = lc.max + rc.sum - d * (hdld[rc.r] - hdld[lc.r]); // OK
x.max = max(lmax, rc.max);
return x;
}
inline void pushup(int cur) {
s[cur] = merge(s[cur*2], s[cur*2+1]);
}
void build(int cur, int l, int r) {
s[cur].l = l, s[cur].r = r;
if (l == r) {
s[cur].sum = s[cur].max = 0;
return;
}
int mid = (s[cur].l + s[cur].r) >> 1;
build(cur*2, l, mid);
build(cur*2+1, mid+1, r);
pushup(cur);
}
Segment query(int cur, int l, int r) {
if (l <= s[cur].l && r >= s[cur].r) {
return s[cur];
}
int mid = (s[cur].l + s[cur].r) >> 1;
if (r <= mid) return query(cur*2, l, r);
if (l > mid) return query(cur*2+1, l, r);
return merge(query(cur*2, l, r), query(cur*2+1, l, r));
}
void update(int cur, int p, long long val) {
if (s[cur].l == s[cur].r) {
// sum 是不计损的,max 是计损的
s[cur].sum += val;
// s[cur].max += v; WA
// printf("%lld:", s[cur].max);
if (!s[cur].max) {
int idx = lower_bound(rt+1, rt+tot+1, hdld[p]-1) - rt - 1;
s[cur].max = val + ans[idx] - d;
}
// 取 max 应该是针对不同的点取 max,对同一点还是累计
else s[cur].max += val; // 前面的不能重复加
return;
}
int mid = (s[cur].l + s[cur].r) >> 1;
if (p <= mid) update(cur*2, p, val);
else update(cur*2+1, p, val);
pushup(cur);
}
inline int getRk(int x) {
return lower_bound(hdld+1, hdld+cnt+1, x) - hdld;
}
void solve() {
int maxL = 0;
tot = 0;
for (int i = 1; i <= m; ++i) {
update(1, getRk(task[i].l), task[i].v);
if (i < m && task[i].r == task[i+1].r) continue;
maxL = max(maxL, task[i].l);
rt[++tot] = task[i].r;
ans[tot] = query(1, getRk(task[i].r-k+1), getRk(maxL)).max;
ans[tot] -= d * (task[i].r - maxL);
ans[tot] = max(ans[tot], ans[tot-1]);
}
}
int main() {
freopen("run.in", "r", stdin);
freopen("run.out", "w", stdout);
int C, T;
scanf("%d%d", &C, &T);
while (T--) {
scanf("%d%d%d%lld", &n, &m, &k, &d);
for (int i = 1; i <= m; ++i) {
int x, y, v;
scanf("%d%d%d", &x, &y, &v);
if (y > k) {
--i, --m;
continue;
}
task[i].l = x - y + 1, task[i].r = x;
task[i].v = v;
hdld[i] = task[i].l;
}
sort(task+1, task+m+1, cmp);
sort(hdld+1, hdld+m+1);
cnt = unique(hdld+1, hdld+m+1) - hdld - 1;
build(1, 1, cnt);
solve();
printf("%lld\n", ans[tot]);
}
return 0;
}