斜率DP以前就听说过但是一直都搁着没学,趁着还有时间赶紧补了吧。这东西乍看上去公式一堆推来推去很是可怕,但是其实没有那么难学,核心思想就是把DP的式子转换成斜率的形式之后利用凸壳去维护 更新当前DP值的最优点。
这篇博客写的比较简洁,没有那么多公式可能会比较好理解
然后就开始刷题巩固知识点吧。
题目链接
题目大意:
一共n个点,每个点有权值ai,画m个分界线把他们分成m+1段,每一段[l,r]的价值是 ∑ i = l r ∑ j = l + 1 r a i ∗ a j \sum_{i=l}^r\sum_{j=l+1}^ra_i*a_j ∑i=lr∑j=l+1rai∗aj,求能得到的最小价值是多少。
解题思路:
设
s
u
m
[
i
]
=
∑
i
=
1
i
a
[
i
]
,
g
[
i
]
=
∑
i
=
1
i
a
i
2
sum[i]=\sum_{i=1}^ia[i],g[i]=\sum_{i=1}^ia_i^2
sum[i]=∑i=1ia[i],g[i]=∑i=1iai2,则有
∑
i
=
l
r
∑
j
=
l
+
1
r
a
i
a
j
=
(
s
u
m
[
r
]
−
s
u
m
[
l
−
1
]
)
2
−
(
g
[
r
]
−
g
[
l
−
1
]
)
2
\sum_{i=l}^r\sum_{j=l+1}^ra_ia_j=\frac{(sum[r]-sum[l-1])^2-(g[r]-g[l-1])}{2}
i=l∑rj=l+1∑raiaj=2(sum[r]−sum[l−1])2−(g[r]−g[l−1])
设
d
p
[
t
]
[
i
]
dp[t][i]
dp[t][i]为到i为止,前面画了t个分界线的最小价值,那么显然:
d
p
[
t
]
[
i
]
=
m
i
n
(
d
p
[
t
−
1
]
[
j
]
+
(
s
u
m
[
i
]
−
s
u
m
[
j
]
)
2
−
(
g
[
i
]
−
g
[
j
]
)
2
dp[t][i]=min(dp[t-1][j]+\frac{(sum[i]-sum[j])^2-(g[i]-g[j])}{2}
dp[t][i]=min(dp[t−1][j]+2(sum[i]−sum[j])2−(g[i]−g[j])
这样就可以j比k更优的条件为:
d
p
[
t
−
1
]
[
j
]
+
(
s
u
m
[
i
]
−
s
u
m
[
j
]
)
2
−
(
g
[
i
]
−
g
[
j
]
)
2
≤
d
p
[
t
−
1
]
[
k
]
+
(
s
u
m
[
i
]
−
s
u
m
[
k
]
)
2
−
(
g
[
i
]
−
g
[
k
]
)
2
dp[t-1][j]+\frac{(sum[i]-sum[j])^2-(g[i]-g[j])}{2}\le dp[t-1][k]+\frac{(sum[i]-sum[k])^2-(g[i]-g[k])}{2}
dp[t−1][j]+2(sum[i]−sum[j])2−(g[i]−g[j])≤dp[t−1][k]+2(sum[i]−sum[k])2−(g[i]−g[k])
化简得到:
f
(
t
−
1
,
j
)
−
f
(
t
−
1
,
k
)
s
u
m
[
j
]
−
s
u
m
[
k
]
≤
2
s
u
m
[
i
]
\frac{f(t-1,j)-f(t-1,k)}{sum[j]-sum[k]}\le2sum[i]
sum[j]−sum[k]f(t−1,j)−f(t−1,k)≤2sum[i]
其中
f
(
t
,
j
)
=
2
d
p
[
t
]
[
j
]
+
s
u
m
[
j
]
2
+
g
[
j
]
f(t,j)=2dp[t][j]+sum[j]^2+g[j]
f(t,j)=2dp[t][j]+sum[j]2+g[j]
然后就是套路维护下凸包了。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 1e3 + 50;
int n, m;
ll dp[2][maxn];
ll sum[maxn], g[maxn];
int q[2][maxn], head[2], tail[2];
ll a[maxn];
ll up(int i, int j, int t){//i > j
return 2*dp[t][i] + sum[i]*sum[i]+g[i] - (2*dp[t][j] + sum[j]*sum[j]+g[j]);
}
ll down(int i, int j){
return sum[i] - sum[j];
}
void init(){
for(int i = 1; i <= n; ++i) scanf("%lld", &a[i]), sum[i] = sum[i-1]+a[i], g[i] = g[i-1]+a[i]*a[i];
}
void sol(){
int cur = 0, pre = 1;
head[0] = tail[0] = 0;
q[0][tail[0]++] = 0;
for(int t = 0; t <= m; ++t){
swap(cur, pre);
head[cur] = tail[cur] = 0;
q[cur][tail[cur]++] = 0;
for(int i = 1; i <= n; ++i){
int p1 = q[pre][head[pre]], p2 = q[pre][head[pre]+1];
while(tail[pre] - head[pre] > 1 && up(p2, p1, pre) <= 2*sum[i]*down(p2, p1)){//注意别写反
head[pre]++; p1 = q[pre][head[pre]], p2 = q[pre][head[pre]+1];
}
dp[cur][i] = dp[pre][p1] + ( (sum[i] - sum[p1])*(sum[i] - sum[p1]) - (g[i] - g[p1]))/2;
//这里dp[pre][p1]写成了dp[p1][pre]导致RE
if(tail[cur] - head[cur] > 1) p1 = q[cur][tail[cur]-1], p2 = q[cur][tail[cur]-2];//注意别越界
while(tail[cur] - head[cur] > 1 && up(i,p1,cur)*down(p1, p2) <= up(p1,p2,cur)*down(i,p1)) {
tail[cur]--;
if(tail[cur] - head[cur] > 1) p1 = q[cur][tail[cur]-1], p2 = q[cur][tail[cur]-2];//注意别越界
}
q[cur][tail[cur]++] = i;
}
}
cout<<dp[cur][n]<<endl;
}
int main()
{
// freopen("1.in", "r", stdin);
while(cin>>n>>m){
if(n == 0 && m == 0) break;
init();
sol();
}
}