【思路要点】
- 我们先来考虑这个问题在序列上的形式。
- 我们要将序列分成 $k$ 段,使得各段中所有数到该段中位数的距离之和的总和最小。
- 由于代价函数 $w$ 满足四边形不等式 $w(i,k)+w(j,l)\le w(i,l)+w(j,k)\ (i\le j\le k\le l)$,因此该 DP 的决策点满足决策单调性。
- 那么利用决策单调性进行分治,这个问题在序列上的形式就有一种简单的 $O(NK\log N)$ 的做法。
- 另外,打表注意到:令 $opt(x)$ 为 $k=x$ 时的最优解,那么 $opt(x)$ 是一个关于 $x$ 的下凸函数。因此,通过二分斜率凸优化(WQS 二分),我们可以在 $O(N\log N\log V)$ 的时间内解决这个问题在序列上的形式。
- 这里涉及到了一个凸优化输出方案的小技巧:设在当前二分出的斜率 $x$ 下,最优解至少分成 $k_1$ 段、至多分成 $k_2$ 段,且 $k_1\le k\le k_2$。那么若存在一组位置 $i\le j\le k\le l$(此处的 $k$ 表示位置,而非段数),其中 $i,l$ 为最少分段中相邻的两个断点,$j,k$ 为最多分段中相邻的两个断点,由 $w(i,k)+w(j,l)\le w(i,l)+w(j,k)$,并且原有的两个分段均为最优解,我们取最少分段中 $i$ 以及其之前的断点、取最多分段中 $k$ 以及其之后的断点,形成的解一定也是一个最优解。可以证明,我们一定可以通过这种方式找到一种调整方案,将分的段数调整至 $k$。
- 考虑环上的问题,假设全局最优解为 $(p_0,p_1,p_2,\dots,p_k)$,那么显然 $(p_1,p_2,p_3,\dots,p_k,p_0)$ 也是全局最优解。由决策单调性,对于任意 $p_0\le q_0\le p_1$,以 $q_0$ 开头的最优解 $(q_0,q_1,q_2,\dots,q_k)$ 一定满足 $p_0\le q_0\le p_1,\ p_1\le q_1\le p_2,\ \dots,\ p_k\le q_k\le p_0$(位置均在环上理解)。
- 不妨令 $p_0\le 0\le p_1$,用上述凸优化的方式求出以 $0$ 开头的最优解 $(0,q_1,q_2,\dots,q_k)$,那么区间 $[0,q_1],[q_1,q_2],\dots,[q_{k-1},q_k]$ 中的每一段都会包含一个全局最优解上的断点。选择其中最短的一段,其长度必定在 $O(\frac{N}{k})$ 内。不妨令选择了 $[0,q_1]$,我们需要求出以其中每一个点开始的最优解,这些最优解中一定包含了全局最优解。
- 注意到我们已经确定了每一个决策点的范围,用最开始提到的分治做法求以一个点开始的最优解是 $O(N\log N)$ 的。
- 假设我们当前需要求 $[l,r]$ 中每一个点开始的最优解,我们可以先求出以 $mid=\frac{l+r}{2}$ 开始的最优解,并借此进一步确定 $[l,mid-1]$ 和 $[mid+1,r]$ 中各点的决策点范围,如此递归处理,时间复杂度为 $O(N\log^2 N)$。
- 总时间复杂度为 $O(N\log^2 N+N\log N\log V)$。
【代码】
// Complete solution program.
//
// Input : n, k, l, then the n values a[1..n].
// Output: the minimal total cost on the first line, then k positions
//         (each reduced modulo l) in ascending order, space-separated.
//
// The values are duplicated as a[i + n] = a[i] + l, so the input is treated
// as points on a circle (presumably of circumference l) unrolled twice. A
// solution "starting at from" cuts the circle at index `from` and splits the
// indices (from, from + n] into k consecutive segments.
// NOTE(review): weight() uses a[mid] as the segment median, which is only
// correct if a[1..n] is sorted ascending and lies in [0, l) -- presumably
// guaranteed by the problem statement; confirm.
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 4e5 + 5;      // arrays are indexed up to 2 * n (circle unrolled twice)
const long long INF = 1e18;    // "infinity" for DP costs and upper end of the penalty search
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// Generic helpers; chkmax/chkmin are not used by the code in this file.
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
template <typename T> void chkmin(T &x, T y) {x = min(x, y); }
// Fast integer reader: skips non-digit characters (every '-' seen flips the
// sign), then accumulates the digits.
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
// Recursive decimal writer; prints a leading '-' for negative values.
template <typename T> void write(T x) {
	if (x < 0) x = -x, putchar('-');
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
// write() followed by a newline.
template <typename T> void writeln(T x) {
	write(x);
	puts("");
}
// A penalized DP state: val = total cost (segment costs plus the per-segment
// penalty), cnt = number of segments used so far.
struct info {ll val; int cnt; };
// Comparison operators on info.
// IMPORTANT: these are preference relations, NOT a mirror-symmetric order.
// Whenever the values differ, all four return a.val < b.val ("a is cheaper");
// only the tie-break on cnt differs:
//   a <= b : a is at least as good when ties prefer FEWER segments
//   a >= b : a is at least as good when ties prefer MORE segments
// work() relies on exactly this meaning of <= and >=; < and > are not used
// by the code in this file.
bool operator < (info a, info b) {
	if (a.val == b.val) return a.cnt < b.cnt;
	else return a.val < b.val;
}
bool operator > (info a, info b) {
	if (a.val == b.val) return a.cnt > b.cnt;
	else return a.val < b.val;
}
bool operator <= (info a, info b) {
	if (a.val == b.val) return a.cnt <= b.cnt;
	else return a.val < b.val;
}
bool operator >= (info a, info b) {
	if (a.val == b.val) return a.cnt >= b.cnt;
	else return a.val < b.val;
}
// Extends a state by one more segment whose cost is `val`.
info operator + (info a, ll val) {
	a.val += val;
	a.cnt += 1;
	return a;
}
// One entry of the decision queue in work(): candidate predecessor `pos` is
// currently the best choice for every target index in [l, r].
struct range {
	int pos;
	int l, r;
};
// n, k: read from input; pos[0..k]: breakpoints written by calc();
// home: start index of the best solution found so far.
int n, k, pos[MAXN], home;
// ans: best total cost found so far; anspos: the positions that get printed.
ll ans, anspos[MAXN];
// a: input values (second copy shifted by l), sum: prefix sums of a,
// l: the shift / modulus read from input (presumably the circle length).
ll a[MAXN], sum[MAXN], l;
// Cost of one segment covering indices l+1..r: the sum of |a[x] - a[mid]| for
// x in l+1..r, with the center at index mid = (l + r + 1) / 2 (the median
// when a is sorted). Evaluated in O(1) with prefix sums.
ll weight(int l, int r) {
	int mid = (l + r + 1) / 2;
	return (mid - l) * a[mid] - (sum[mid] - sum[l]) + (sum[r] - sum[mid]) - (r - mid) * a[mid];
}
// Penalized DP states for work(): least breaks ties toward fewer segments,
// most toward more segments. pathl/pathm hold each index's predecessor.
info least[MAXN], most[MAXN];
int pathl[MAXN], pathm[MAXN];
// Penalized ("Aliens" / WQS) DP over indices (from, from + n]: every segment
// (j, i] costs weight(j, i) + cost. Two DPs with identical values are run
// side by side:
//   least[i] / pathl[i] : optimum, ties broken toward fewer segments
//   most[i]  / pathm[i] : optimum, ties broken toward more segments
// Each DP keeps a queue q*[l*..r*] of (candidate, target interval) entries.
// The code assumes decision monotonicity: a newer candidate that becomes
// better at some target stays better for all later targets. So it only ever
// takes over a suffix of the remaining targets, found by binary search.
void work(int from, ll cost) {
	least[from] = most[from] = (info) {0, 0};
	// Queues are reused across calls (function-local static storage).
	static range qleast[MAXN], qmost[MAXN];
	int lleast = 1, rleast = 1, lmost = 1, rmost = 1;
	// Initially `from` is the only candidate, for every target.
	qleast[1] = qmost[1] = (range) {from, from + 1, from + n};
	for (int i = from + 1; i <= from + n; i++) {
		// The front entry of each queue holds the best predecessor of target i.
		pathl[i] = qleast[lleast].pos, pathm[i] = qmost[lmost].pos;
		most[i] = most[qmost[lmost].pos] + (weight(qmost[lmost].pos, i) + cost);
		least[i] = least[qleast[lleast].pos] + 
			(weight(qleast[lleast].pos, i) + cost);
		// Both tie-breaking variants must agree on the optimal value.
		assert(most[i].val == least[i].val);
		// Target i is consumed: drop or shrink the front entries.
		if (i == qmost[lmost].r) lmost++;
		else qmost[lmost].l++;
		if (i == qleast[lleast].r) lleast++;
		else qleast[lleast].l++;
		// The last index has no later targets, so it never becomes a candidate.
		if (i == from + n) break;
		// "most" queue: pop back entries where i is at least as good (ties ->
		// more segments) already at the entry's left end, i.e. on its whole
		// interval; the popped interval is merged into the new back entry.
		while (lmost <= rmost && most[i] + weight(i, qmost[rmost].l) >= most[qmost[rmost].pos] + weight(qmost[rmost].pos, qmost[rmost].l)) {
			rmost--;
			if (rmost >= lmost) qmost[rmost].r = qmost[rmost + 1].r;
		}
		// Queue emptied: i is the best candidate for all remaining targets.
		if (rmost < lmost) qmost[++rmost] = (range) {i, i + 1, from + n};
		// Otherwise, if i wins at the back entry's right end, binary-search the
		// first target where i wins and split the back entry there.
		else if (most[i] + weight(i, qmost[rmost].r) >= most[qmost[rmost].pos] + weight(qmost[rmost].pos, qmost[rmost].r)) {
			int l = qmost[rmost].l, r = qmost[rmost].r;
			while (l < r) {
				int mid = (l + r) / 2;
				if (most[i] + weight(i, mid) >= most[qmost[rmost].pos] + weight(qmost[rmost].pos, mid)) r = mid;
				else l = mid + 1;
			}
			qmost[rmost].r = l - 1;
			qmost[++rmost] = (range) {i, l, from + n};
		}
		// "least" queue: identical procedure with ties toward fewer segments.
		while (lleast <= rleast && least[i] + weight(i, qleast[rleast].l) <= least[qleast[rleast].pos] + weight(qleast[rleast].pos, qleast[rleast].l)) {
			rleast--;
			if (rleast >= lleast) qleast[rleast].r = qleast[rleast + 1].r;
		}
		if (rleast < lleast) qleast[++rleast] = (range) {i, i + 1, from + n};
		else if (least[i] + weight(i, qleast[rleast].r) <= least[qleast[rleast].pos] + weight(qleast[rleast].pos, qleast[rleast].r)) {
			int l = qleast[rleast].l, r = qleast[rleast].r;
			while (l < r) {
				int mid = (l + r) / 2;
				if (least[i] + weight(i, mid) <= least[qleast[rleast].pos] + weight(qleast[rleast].pos, mid)) r = mid;
				else l = mid + 1;
			}
			qleast[rleast].r = l - 1;
			qleast[++rleast] = (range) {i, l, from + n};
		}
	}
}
// Optimal split of (from, from + n] into exactly k segments.
// Method: binary search on the per-segment penalty `cost` in [0, INF]. A
// larger penalty gives fewer segments, so cost goes up when even the
// fewest-segment optimum uses more than k segments.
// Once a penalty gives least.cnt <= k <= most.cnt:
//   - both optimal breakpoint lists are rebuilt;
//   - they are spliced into one list with exactly k segments;
//   - the result is written to the global pos[0..k] (pos[0] = from,
//     pos[k] = from + n);
//   - the un-penalized cost is returned.
// Returns -1 if the search never succeeds, e.g. when k > n: work() only
// creates non-empty segments, so at most n segments are possible.
// NOTE(review): callers do not check for -1.
ll calc(int from) {
	ll l = 0, r = INF;
	while (l <= r) {
		ll mid = (l + r) / 2;
		work(from, mid);
		if (least[from + n].cnt <= k && most[from + n].cnt >= k) {
			static int posl[MAXN], posm[MAXN];
			// Walk the predecessor links back from from + n
			// (the loop-local `pos` shadows the global array).
			for (int i = least[from + n].cnt, pos = from + n; i >= 0; i--)
				posl[i] = pos, pos = pathl[pos];
			for (int i = most[from + n].cnt, pos = from + n; i >= 0; i--)
				posm[i] = pos, pos = pathm[pos];
			// Find a least-segment [posl[i], posl[i + 1]] that contains the
			// most-segment [posm[tmp], posm[tmp + 1]]. Splice least's
			// breakpoints 0..i with most's breakpoints tmp + 1..most.cnt.
			// That gives i + (most.cnt - tmp) = k segments, and the exchange
			// argument in the write-up keeps the splice optimal.
			for (int i = 0; i < least[from + n].cnt; 
				i++) {
				int tmp = most[from + n].cnt - k + i;
				pos[i] = posl[i];
				if (posl[i] <= posm[tmp] && posl[i + 1] >= posm[tmp + 1]) {
					int now = i;
					for (int j = tmp + 1; j <= most[from + n].cnt; j++)
						pos[++now] = posm[j];
					// Remove the k penalties to get the true cost.
					return least[from + n].val - k * mid;
				}
			}
			// The exchange argument guarantees a valid splice exists.
			assert(false);
		}
		if (least[from + n].cnt > k) l = mid + 1;
		else r = mid - 1;
	}
	return -1;
}
// dp[x][j]: minimal cost of covering (start, x] with exactly j segments for
// the start most recently passed to getdp(); path[x][j]: the predecessor
// breakpoint that achieves it. Keyed by layer in maps, so memory is only
// used for pairs that are touched.
// NOTE: map::operator[] inserts 0 for a pair never computed. The breakpoint
// ranges must ensure every predecessor read in conquer() was filled in the
// previous layer.
map <int, ll> dp[MAXN];
map <int, int> path[MAXN];
// rangel[j]..ranger[j]: allowed interval for the j-th breakpoint (j = 0 is
// the start); dppos: breakpoints found by the most recent getdp() call.
vector <int> rangel, ranger, dppos;
// Divide-and-conquer optimization for a single layer.
// Fills dp[x][layer] for every x in [l, r], trying predecessors only in
// [ql, qr]. The optimal predecessor of the middle target bounds the
// predecessor range of each half.
void conquer(int layer, int l, int r, int ql, int qr) {
	if (l > r) return;
	int mid = (l + r) / 2;
	dp[mid][layer] = INF;
	for (int i = ql; i <= qr; i++) {
		ll tmp = dp[i][layer - 1] + weight(i, mid);
		if (tmp < dp[mid][layer]) dp[mid][layer] = tmp, path[mid][layer] = i;
	}
	assert(dp[mid][layer] != INF);
	conquer(layer, l, mid - 1, ql, path[mid][layer]);
	conquer(layer, mid + 1, r, path[mid][layer], qr);
}
// Exact k-layer DP for the start `from`. Breakpoint j is restricted to
// [rangel[j], ranger[j]] for j = 1..k. Writes the optimal breakpoints to
// dppos[0..k] (dppos[0] = from, dppos[k] = from + n) and returns the cost.
ll getdp(int from) {
	dp[from][0] = 0;
	path[from][0] = from;
	int lastl = from, lastr = from;
	for (int i = 1; i <= k; i++) {
		conquer(i, rangel[i], ranger[i], lastl, lastr);
		lastl = rangel[i], lastr = ranger[i];
	}
	dppos.resize(k + 1);
	for (int i = k, pos = from + n; i >= 0; i--)
		dppos[i] = pos, pos = path[pos][i];
	return dp[from + n][k];
}
// Solves every start in [l, r].
// The middle start is solved first. Its breakpoints then serve as upper
// bounds for the left half and lower bounds for the right half. The best
// cost is kept in ans and its start in home. rangel/ranger are restored
// before returning.
void divide(int l, int r) {
	if (l > r) return;
	vector <int> bakl = rangel, bakr = ranger;
	int mid = (l + r) / 2;
	ll tmp = getdp(mid);
	vector <int> bak = dppos;
	if (tmp < ans) ans = tmp, home = mid;
	rangel = bakl, ranger = bak;
	divide(l, mid - 1);
	rangel = bak, ranger = bakr;
	divide(mid + 1, r);
	rangel = bakl, ranger = bakr;
}
// Steps:
//   1. Read the input and solve from start 0.
//   2. Try every start inside the shortest segment of that solution; the
//      write-up argues one of them begins a global optimum.
//   3. Re-solve from the best start to recover its breakpoints.
//   4. Print the cost, then the segment medians reduced modulo l, sorted.
int main() {
	read(n), read(k), read(l);
	// Unroll the circle: the second copy of the values is shifted by l.
	for (int i = 1; i <= n; i++) read(a[i]), a[i + n] = a[i] + l;
	for (int i = 1; i <= 2 * n; i++) sum[i] = sum[i - 1] + a[i];
	// Initial solution from start 0; fills pos[0..k].
	ans = calc(0), home = 0;
	// Min: index of the shortest segment (pos[Min - 1], pos[Min]].
	int Min = 1;
	for (int i = 1; i <= k; i++) if (pos[i] - pos[i - 1] < pos[Min] - pos[Min - 1]) Min = i;
	// Breakpoint ranges rotated so that range 0 is the shortest segment.
	// Ranges that wrap past pos[k] are shifted by +n. This yields k + 1
	// ranges in total.
	for (int i = Min; i <= k; i++) {
		rangel.push_back(pos[i - 1]);
		ranger.push_back(pos[i]);
	}
	for (int i = 1; i <= Min; i++) {
		rangel.push_back(pos[i - 1] + n);
		ranger.push_back(pos[i] + n);
	}
	divide(rangel[0], 
		ranger[0]);
	// Re-solve from the best start to fill pos[] with its breakpoints.
	writeln(ans = calc(home));
	// Each center is the median element of its segment, reduced modulo l.
	for (int j = 1; j <= k; j++) anspos[j] = a[(pos[j] + pos[j - 1] + 1) / 2] % l;
	sort(anspos + 1, anspos + k + 1);
	for (int i = 1; i <= k; i++) printf("%lld ", anspos[i]);
	return 0;
}