wqs二分入门题,这题可以直接 dp,且复杂度不会比 wqs 二分高。
问题具有一个性质:限制你建 m 个 post office,但是显然建的越多代价越少,代价和建的 post office 的数量在二维平面上是一个上凸包的形状。
考虑在给每个 post office 加一个代价 x x x ,每建一个 post office 都要付出额外 x x x 的代价,显然 x x x 越大,最优解建的 post office 越少,反之越多,即具有单调性。
可以二分 x x x,然后在不考虑建几个 post office 的限制下进行 dp,求出最小代价 以及对应的建立的 post office 的数量。
考虑如何 dp:首先考虑,在 [l,r] 区间内建一座 post office,建在哪最优。
画一下可以发现是建在中间的点
m
i
d
mid
mid 最优,设刚开始建在
x
x
x(
x
<
m
i
d
x < mid
x<mid),
x
x
x 的左边有 a 个点,右边有 b 个点,且
d
i
s
(
x
,
x
+
1
)
=
d
dis(x, x + 1) = d
dis(x,x+1)=d,如果将post ofiice 右移到下一个点,对答案的贡献改变了:
(
a
+
1
)
∗
d
−
b
∗
d
(a + 1) * d - b * d
(a+1)∗d−b∗d,当
a
+
1
≤
b
a + 1 \leq b
a+1≤b 时都可以移动
x
x
x,当
x
>
m
i
d
x > mid
x>mid 时可以类似的方法证明。
容易列出转移方程:
d
p
[
i
]
=
d
p
[
j
]
+
w
(
j
,
i
)
dp[i] = dp[j] + w(j,i)
dp[i]=dp[j]+w(j,i),其中
w
(
j
,
i
)
w(j,i)
w(j,i) 表示在 [j,i] 建一个 post office 的最小代价。
预处理
w
(
j
,
i
)
w(j,i)
w(j,i):可以发现如果
w
(
j
,
i
−
1
)
w(j,i - 1)
w(j,i−1) 已经求出,
w
(
j
,
i
)
=
w
(
j
,
i
−
1
)
+
a
[
j
]
−
a
[
⌊
i
+
j
2
⌋
]
w(j,i) = w(j,i - 1) + a[j] - a[\lfloor\frac{i+j}{2}\rfloor]
w(j,i)=w(j,i−1)+a[j]−a[⌊2i+j⌋]
复杂度为 n 2 log v n^2\log v n2logv, v v v 较大,可以取所有 a a a 的和。
代码:
#include<iostream>
#include<string.h>
#include<stdio.h>
#include<algorithm>
using namespace std;
const int maxn = 1e3 + 10;
const int inf = 0x3f3f3f3f;
int n,m,a[maxn],sum;
int dp[maxn],w[maxn][maxn],d[maxn];
int solve(int x) {
memset(dp,inf,sizeof dp);
memset(d,0,sizeof d);
dp[0] = 0; d[0] = 0; //在取得最优的情况下尽可能的建post office
for (int i = 1; i <= n; i++) {
for (int j = 0; j < i; j++) {
if (dp[j] + w[j + 1][i] + x < dp[i]) {
dp[i] = dp[j] + w[j + 1][i] + x;
d[i] = d[j] + 1;
} else if (dp[j] + w[j + 1][i] + x == dp[i]) {
if (d[j] + 1 > d[i])
d[i] = d[j] + 1;
}
}
}
return d[n];
}
int main() {
scanf("%d%d",&n,&m);
for (int i = 1; i <= n; i++) {
scanf("%d",&a[i]);
sum += a[i];
}
sort(a + 1,a + n + 1);
for (int i = 1; i <= n; i++)
for (int j = i + 1; j <= n; j++)
w[i][j] = w[i][j - 1] + a[j] - a[i + j >> 1];
memset(dp,0,sizeof dp);
memset(d,0,sizeof d);
int l = 0, r = sum;
while (l < r) {
int mid = l + r >> 1;
if (solve(mid) < m) r = mid;
else l = mid + 1;
}
solve(l - 1);
printf("%d\n",dp[n] - m * (l - 1));
return 0;
}