我要成为线段树领域大神
题目大意
有 n n n 个数,按照一下条件分组:
- 每组至少有 k k k 个数。
- 每组的极差不超过 d d d 。
问这 n n n 个数是否都可以恰好被分入一个组中。
分析
首先有一个很简单的操作:这 n n n 个数可以排序以后进行分组。
随后我们可以想到一个很简单的 D P DP DP :记 d p i dp_i dpi 表示前 i i i 个数是否可以被恰好分到一个组里。则有
d p i = { 1 ( i − j ≥ k , a [ i ] − a [ j + 1 ] ≤ d , d p [ j ] = 1 ) 0 o t h e r w i s e dp_i=\left\{ \begin{aligned} 1 & \quad (i-j \ge k,a[i]-a[j+1] \le d,dp[j]=1)\\ 0 & \quad otherwise\\ \end{aligned} \right . dpi={10(i−j≥k,a[i]−a[j+1]≤d,dp[j]=1)otherwise
初始化则设
d
p
0
=
1
dp_0=1
dp0=1 。
于是我们便得到了一个
O
(
n
2
)
O(n^2)
O(n2) 的写法。
dp[0] = 1;
for (ll i = 1;i <= n;i++) {
ll j = i - k;
for (;j >= 0;j--) {
if (a[i] - a[j + 1] > d) {
break;
}
if (dp[j]) {
dp[i] = 1;
break;
}
}
}
然而很明显, n 2 n^2 n2 冲不过 n ≤ 5 × 1 0 5 n \le 5 \times 10^5 n≤5×105 。所以我们考虑怎么优化这个过程。
优化
我们观察 d p dp dp 数组的转移过程,可以发现一个很明显的贪心:对于每一个 d p i dp_i dpi ,转移的 d p j dp_j dpj 一定是 所有满足条件的 j j j 中的最大的 j j j 。
所以可以我们考虑线段树,对于每一个 i i i 直接求出 $[1,i-k] $ 中满足条件的最大的 j j j ,再判断 a [ i ] − a [ j ] ≤ d a[i]-a[j] \le d a[i]−a[j]≤d 并更新答案以及线段树即可。
Code
喜提 O ( n l o g n ) O(nlogn) O(nlogn) 最劣解
#include<bits/stdc++.h>
#define ll long long
#define ma 620000
using namespace std;
const ll inf = 1e18;
const ll mod = 998244353;
//---------------------------------------------
ll read() {
ll x = 0, f = 1;
char ch = getchar();
while (!isdigit(ch)) {
if (ch == '-') f = -1;
ch = getchar();
}
while (isdigit(ch)) {
x = x * 10 + ch - '0';
ch = getchar();
}
return x * f;
}
//---------------------------------------------
ll n, k, d;
ll a[ma];
ll dp[ma];
//---------------------------------------------
struct seg {
ll maxn;
}t[ma << 2];
#define ls p<<1
#define rs p<<1|1
void pushup(ll p) {
t[p].maxn = max(t[ls].maxn, t[rs].maxn);
}
void update(ll p, ll l, ll r, ll x, ll k) {
if (l == r) {
t[p].maxn = k;
return;
}
ll mid = (l + r) >> 1;
if (x <= mid) update(ls, l, mid, x, k);
else update(rs, mid + 1, r, x, k);
pushup(p);
}
ll ask(ll p, ll l, ll r, ll x, ll y) {
if (x <= l && r <= y) return t[p].maxn;
ll mid = (l + r) >> 1;
ll ans = 0;
if (x <= mid) ans = max(ans, ask(ls, l, mid, x, y));
if (y > mid) ans = max(ans, ask(rs, mid + 1, r, x, y));
return ans;
}
//---------------------------------------------
int main() {
// freopen("debug.in", "r", stdin);
// freopen("debug.out", "w", stdout);
n = read(), k = read(), d = read();
for (ll i = 1;i <= n;i++) a[i] = read();
sort(a + 1, a + 1 + n);
dp[0] = 1;// 初始化
for (ll i = 1;i <= n;i++) {
ll pos = ask(1, 1, n, 1, max(i - k, 1ll));// 求满足条件的j
if (a[i] - a[pos + 1] <= d && i - pos >= k) {// 判断是否满足条件
dp[i] = 1;
update(1, 1, n, i, i);// 更新
}
}
cout << (dp[n] ? "YES" : "NO") << endl;
return 0;
}