题面
解法
这道题竟然是贪心……
- 显然,我们可以把那些一定为 0 0 的先去掉,然后我们考虑剩下的那些位置
- 把每一个原来的区间稍作改动一下,保证新的区间的两端满足它可以不为
- 然后我们可以去掉一些包含其他区间的大区间,因为小区间确定了自然大区间的条件也被满足了
- 然后我们考虑一下剩下的区间,这些区间一定满足两两不包含,且左端点和右端点均单调递增
- 因为总共就 k k 个为的位置,那么我们可以先求出最少可以用多少个 1 1 来满足所有的要求,然后剩下的就可以随意放了
- 设表示前 i i 个区间至少需要有多少个使得这些区间的要求被满足, g[i] g [ i ] 表示 i i 以后(包括)的区间至少需要有多少个 1 1 使得这些区间的要求被满足,这个显然可以用贪心解决。求解的时候一定是把 1 1 放在区间的右端点上,反之
- 然后对于每一个区间分别考虑。如果区间长度为 1 1 ,那么显然这个位置必须取。否则,在最优的情况下一定是取区间的右端点,假设为 x x 。然后我们考虑取而不取 x x 是否可行,因为一定比 x−2 x − 2 等位置要优,如果取 x−1 x − 1 不可行,那么 x x 这个位置一定会被取
- 二分最近的且不能覆盖到 x−1 x − 1 的两个区间 l,r l , r ,如果 f[l]+g[r]+1>K f [ l ] + g [ r ] + 1 > K ,那么说明 x x 这个位置使一定不能不取的,所以即为一个答案
- 时间复杂度: O(mlogm) O ( m log m )
代码
#include <bits/stdc++.h>
#define N 100010
using namespace std;
template <typename node> void chkmax(node &x, node y) {x = max(x, y);}
template <typename node> void chkmin(node &x, node y) {x = min(x, y);}
template <typename node> void read(node &x) {
x = 0; int f = 1; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
struct Node {
int l, r, x;
bool operator < (const Node &a) const {
if (l == a.l) return r > a.r;
return l < a.l;
}
} a[N], b[N], c[N];
int f[N], g[N], s[N], id[N], pre[N], nxt[N], num[N];
int main() {
int n, k, m; read(n), read(k), read(m);
for (int i = 1; i <= m; i++) {
read(a[i].l), read(a[i].r), read(a[i].x);
if (!a[i].x) s[a[i].l]++, s[a[i].r + 1]--;
}
for (int i = 1; i <= n; i++) s[i] += s[i - 1];
int tot = 0, q = 0;
for (int i = 1; i <= n; i++)
if (!s[i]) pre[i] = nxt[i] = id[i] = ++tot, num[tot] = i;
if (tot == k) {
for (int i = 1; i <= tot; i++) cout << num[i] << "\n";
return 0;
}
for (int i = 1; i <= n; i++)
if (!pre[i]) pre[i] = pre[i - 1];
for (int i = n; i; i--)
if (!nxt[i]) nxt[i] = nxt[i + 1];
for (int i = 1; i <= m; i++) {
if (!a[i].x) continue;
int l = nxt[a[i].l], r = pre[a[i].r];
if (l <= r) b[++q] = (Node) {l, r, 1};
}
sort(b + 1, b + q + 1); int cnt = 0;
for (int i = 1; i <= q; i++) {
while (cnt && c[cnt].r >= b[i].r) cnt--;
c[++cnt] = b[i];
}
int mx = 0;
for (int i = 1; i <= cnt; i++)
if (mx < c[i].l) f[i] = f[i - 1] + 1, mx = c[i].r;
else f[i] = f[i - 1];
mx = INT_MAX;
for (int i = cnt; i; i--)
if (mx > c[i].r) g[i] = g[i + 1] + 1, mx = c[i].l;
else g[i] = g[i + 1];
bool flag = false;
for (int i = 1; i <= cnt; i++) {
if (f[i] != f[i - 1] + 1) continue;
if (c[i].l == c[i].r) {
cout << num[c[i].l] << "\n";
flag = true; continue;
}
int ll = 0, rr = cnt + 1, l = 1, r = i - 1;
while (l <= r) {
int mid = (l + r) >> 1;
if (c[mid].r < c[i].r - 1) ll = mid, l = mid + 1;
else r = mid - 1;
}
l = i + 1, r = cnt;
while (l <= r) {
int mid = (l + r) >> 1;
if (c[mid].l > c[i].r - 1) rr = mid, r = mid - 1;
else l = mid + 1;
}
if (f[ll] + g[rr] + 1 > k) cout << num[c[i].r] << "\n", flag = true;
}
if (!flag) cout << "-1\n";
return 0;
}