终于完全靠自己过掉了一道基础的斜率优化DP
DP的思路很直接,一开始看数据范围,还以为不用优化可能卡过去,然后就直接写了三重循环DP,果然超时没商量= =
for(int i=1 ;i<=n ;i++) scanf("%d",&a[i]);
sort(a+1,a+n+1);
for(int i=1 ;i<=n ;i++){
for(int j=1 ;j<=m ;j++){
if(i == 1){
dp[i][j] = 0;
continue;
}
if(j == 1) dp[i][j] = (a[i] - a[1]) * (a[i] - a[1]);
else{
int Min = dp[1][j-1] + (a[i] - a[2])*(a[i] - a[2]);
for(int k=2 ;k<i ;k++){
Min = min(Min,dp[k][j-1] + (a[i] - a[k+1])*(a[i] - a[k+1]));
}
dp[i][j] = Min;
}
}
}
还是要从公式变形入手:
dp[i][j] 表示以i为终点,分成j段时的最优解。
排序后易得:
d p [ i ] [ j ] = d p [ k ] [ j − 1 ] + ( a [ i ] − a [ k + 1 ] ) 2 ( 1 < = k < i ) dp[i][j] = dp[k][j-1] + (a[i] - a[k+1])^2 (1 <= k < i) dp[i][j]=dp[k][j−1]+(a[i]−a[k+1])2(1<=k<i)
假设 1 < = k 2 < = k 1 < = n 1<=k2 <= k1 <= n 1<=k2<=k1<=n
当选择 k1 比 k2 更优时,有:
d
p
[
k
1
]
[
j
−
1
]
+
(
a
[
i
]
−
a
[
k
1
+
1
]
)
2
<
=
d
p
[
k
2
]
[
j
−
1
]
+
(
a
[
i
]
−
a
[
k
2
+
1
]
)
2
dp[k1][j-1] +(a[i] - a[k1+1])^2 <= dp[k2][j-1] + (a[i] - a[k2+1])^2
dp[k1][j−1]+(a[i]−a[k1+1])2<=dp[k2][j−1]+(a[i]−a[k2+1])2
化简得:
a [ i ] > = ( d p [ k 1 ] [ j − 1 ] + a [ k 1 + 1 ] 2 ) − ( d p [ k 2 ] [ j − 1 ] + a [ k 2 + 1 ] 2 ) 2 ∗ ( a [ k 1 + 1 ] − a [ k 2 + 1 ] ) a[i] >= \frac{(dp[k1][j-1] + a[k1+1]^2) - (dp[k2][j-1] + a[k2+1]^2)} {2*(a[k1+1] - a[k2+1])} a[i]>=2∗(a[k1+1]−a[k2+1])(dp[k1][j−1]+a[k1+1]2)−(dp[k2][j−1]+a[k2+1]2)
Y1 = dp[k1][j-1] + a[k1+1]^2
Y2 = dp[k2][j-1] + a[k2+1]^2
X1 = a[k1+1]
X2 = a[k2+1]
则可进行斜率优化。
然后个人感觉斜率优化的写法挺套路的,主要分三部分:
1.从前往后遍历队列,找到一个元素满足上述公式。(通过对比左右队列元素得到)
2.更新答案。
3.将当前点i插入队列前,先从后往前遍历队列,去除多余的点。
那如何判断当前点是否是多余的?
设队列尾部的两个点为p1,p2,根据入队顺序一定有p1 < p2.
而当前点 p3 一定有 p1 < p2 < p3。
当
Y
3
−
Y
2
2
∗
(
X
3
−
X
2
)
<
=
a
[
i
]
<
=
Y
2
−
Y
1
2
∗
(
X
2
−
X
1
)
\frac{Y3 - Y2} {2*(X3 - X2)} <= a[i] < = \frac{Y2 - Y1} {2*(X2 - X1)}
2∗(X3−X2)Y3−Y2<=a[i]<=2∗(X2−X1)Y2−Y1
说明 p3 比 p2 更优且 p1 比 p2 更优。
故队列最尾部的元素p2 可以弹出去。
以上。
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int INF = 1e9 + 7;
const int A = 1e4 + 10;
const int B = 5e3 + 10;
int a[A];
int dp[A][B],que[A];
int head,tail;
int get_up(int i,int j,int len){
return (dp[i][len-1] + a[i+1]*a[i+1]) - (dp[j][len-1] + a[j+1]*a[j+1]);
}
int get_down(int i,int j){
return (a[i+1] - a[j+1]);
}
int main(){
int T,_=1;
scanf("%d",&T);
while(T--){
int n,m;
scanf("%d%d",&n,&m);
for(int i=1 ;i<=n ;i++) scanf("%d",&a[i]);
sort(a+1,a+n+1);
for(int i=1 ;i<=n ;i++) dp[i][1] = (a[i] - a[1])*(a[i] - a[1]);
for(int j=2 ;j<=m ;j++){
head = tail = 0;
que[tail++] = j - 1;
for(int i=j ;i<=n ;i++){
while(head+1 < tail && get_up(que[head+1],que[head],j) <= 2*a[i]*get_down(que[head+1],que[head])) head++;
int k = que[head];
dp[i][j] = dp[k][j-1] + (a[i] - a[k+1])*(a[i] - a[k+1]);
while(head+1 < tail && get_up(i,que[tail-1],j)*get_down(que[tail-1],que[tail-2]) <= get_up(que[tail-1],que[tail-2],j)*get_down(i,que[tail-1])) tail--;
que[tail++] = i;
}
}
printf("Case %d: %d\n",_++,dp[n][m]);
}
return 0;
}