r
题意:把n分成任意段(每段中连续),每一段代价(∑ ci ) +m 求总的最小代价
i=l
作为提醒自己的经验题,即使再水也要记住坑点(其实是自己被坑的地方- -)
斜率优化裸题- -
首先定义dp[i]表示把前i个处理好的最小代价
dp[i]=min{dp[j]+m+(sum[i]-sum[j])^2}
一看这个转移就是n^2的,我们来搞成O(n)的。
h[i]=m+sum[i]^2,g[j]=sum[j]^2
->> dp[i]=dp[j]+h[i]+g[j]-2sum[i]sum[j]
考虑对于一个j1<j2<i,如果j2更优
dp[j1]+g[j1]-2sum[i]sum[j1]>dp[j2]+g[j2]-2sum[i]sum[j2]
令y=dp[j]+g[j]
(y[j1]-y[j2])/(sum[j1]-sum[j2])<2sum[i] (sum[j1]<sum[j2],递增的)
然后就搞成了斜率
令左边为T(j1,j2),T(j1,j2)<2sum[i]时,j2更优
对于x<y,T(x,y)<2sum[i]<2sum[i+1]<````,也就是T(x,y)永远满足条件,y一定最优,x可以删去
我们要维护的就是T(x,y)>2sum[i]这一部分
我们考虑对于T(j1,j2),T(j2,j3)应该怎么维护:
对于T(j1,j2)>2sum[i]时,j1更优,j2不是最优;
对于T(j1,j2)<2sum[i]时,如果T(j2,j3)<T(j1,j2)<2sum[i],那么j2比j1优,j3比j2优,j2不是最优,
所以T(j1,j2)>T(j2,j3)的时候,j2一定不是最优的。
所以我们要维护T(j1,j2)<T(j2,j3)
综上,
我们要维护2sum[i]<T(j1,j2)<T(j2,j3)<T(j3,j4)<T(j4,j5)<T(j5,j6)<`````
此时j1最优。
就维护一个T>sum[i]的序列并且递增
好了,接下来说明坑点:
以前写Toy那道题,我直接double求斜率,然而这一次不适用,我才发现直接求double斜率会挂精度!!!!
所以我们先看成维护斜率,在判断的时候把分母乘到右边去,用乘法代替除法,大概估算用long long不会超!
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<queue>
#include<cmath>
#define ll long long
using namespace std;
const ll maxn=500000+20;
ll dp[maxn];//dp[i] means min cost of [1,i]
/*dp[i]=min{dp[j]+(sum[i]-sum[j])^2+M}
dp[i]=dp[j]+sum[i]^2+sum[j]^2-2sum[i]sum[j]+M
h[i]=sum[i]^2+M;
g[j]=sum[j]^2
dp[i]=dp[j]+h[i]+g[j]-2sum[i]sum[j]
对于j1<j2 j2更优
dp[j1]+g[j1]-2sum[i]sum[j1]>dp[j2]+g[j2]-2sum[i]sum[j2]
y[j]=dp[j]+g[j]
(y[j1]-y[j2]) /(sum[j1]-sum[j2])<2sum[i]
T(j1,j2)表示(j1<j2)的斜率
T(j1,j2)<2sum[i]表示j2更优 sum[i]递增
即T(j1,j2)>2sum[i]表示j1更优
ifx<y T(x,y)<2sum[i]<2sum[i+1]<``` y一定更优,x舍去
所以维护T(x,y)>2sum[i]
考虑T(j1,j2)>2sum[i],j1更优
sum[i]增大,要满足这个条件,T()也应该递增
T(j1,j2)>2sum[i],j1更优
T(j1,j2)<2sum[i]时,如果T(j2,j3)<T(j1,j2)<2sum[i],一定有j3更优
所以斜率递减时j2不会成为最优解
2sum[i]<T(j1,j2)<T(j2,j3)<```
*/
ll n,m;
ll q[maxn];
ll sum[maxn];
ll head,tail;
ll up(ll x,ll y)
{
return dp[x]-dp[y]+sum[x]*sum[x]-sum[y]*sum[y];
}
ll down(ll x,ll y)
{
return sum[x]-sum[y];
}
int main()
{
while(scanf("%I64d%I64d",&n,&m)!=EOF)
{
head=tail=0;
dp[0]=0;
q[0]=0;
sum[0]=0;
for(ll i=1;i<=n;i++)
{
ll x;
scanf("%I64d",&x);
sum[i]=sum[i-1]+x;
}
for(ll i=1;i<=n;i++)
{
while(head<tail&&up(q[head],q[head+1])>=2*sum[i]*down(q[head],q[head+1]))head++;
ll j=q[head];
dp[i]=dp[j]+m+(sum[i]-sum[j])*(sum[i]-sum[j]);
while(head<tail&&up(q[tail-1],q[tail])*down(q[tail],i)>=up(q[tail],i)*down(q[tail-1],q[tail]))tail--;
q[++tail]=i;
}
printf("%I64d\n",dp[n]);
}
return 0;
}