斜率优化DP
先把方差推一下:
v
m
2
=
m
∑
(
x
i
−
x
ˉ
)
2
=
m
∑
(
x
i
−
∑
x
i
m
)
2
=
m
∑
x
i
2
−
2
(
∑
x
i
)
2
+
(
∑
x
i
)
2
=
m
∑
x
i
2
−
∑
x
i
\begin{aligned} vm^2&=m\sum(x_i-\bar x)^2\\ &=m\sum(x_i-\frac{\sum x_i}m)^2\\ &=m\sum x_i^2-2(\sum x_i)^2+(\sum x_i)^2\\ &=m\sum x_i^2-\sum x_i \end{aligned}
vm2=m∑(xi−xˉ)2=m∑(xi−m∑xi)2=m∑xi2−2(∑xi)2+(∑xi)2=m∑xi2−∑xi
∑
x
i
\sum x_i
∑xi是个定值不用管,我们只要维护
∑
x
i
2
\sum x_i^2
∑xi2就好了。
设 f [ i ] [ j ] f[i][j] f[i][j]表示前 i i i条路走 j j j天的答案,则有 f [ i ] [ j ] = m i n { f [ k ] [ j − 1 ] + ( s [ i ] − s [ k ] ) 2 } f[i][j]=min\{f[k][j-1]+(s[i]-s[k])^2 \} f[i][j]=min{f[k][j−1]+(s[i]−s[k])2},这是 n 3 n^3 n3的。看到有平方这个东西就可以考虑斜率优化。下面是推导过程:
设
x
<
y
x<y
x<y且
x
x
x优于
y
y
y
f
[
x
]
+
(
s
[
i
]
−
s
[
x
]
)
2
<
f
[
y
]
+
(
s
[
i
]
−
s
[
y
]
)
2
f
[
x
]
−
f
[
y
]
+
s
[
x
]
2
−
s
[
y
]
2
<
2
(
s
[
x
]
−
s
[
y
]
)
s
[
i
]
f
[
x
]
−
f
[
y
]
+
s
[
x
]
2
−
s
[
y
]
2
2
(
s
[
x
]
−
s
[
y
]
)
>
s
[
i
]
\begin{aligned} &f[x]+(s[i]-s[x])^2<f[y]+(s[i]-s[y])^2\\ &f[x]-f[y]+s[x]^2-s[y]^2<2(s[x]-s[y])s[i]\\ &\frac{f[x]-f[y]+s[x]^2-s[y]^2}{2(s[x]-s[y])}>s[i] \end{aligned}
f[x]+(s[i]−s[x])2<f[y]+(s[i]−s[y])2f[x]−f[y]+s[x]2−s[y]2<2(s[x]−s[y])s[i]2(s[x]−s[y])f[x]−f[y]+s[x]2−s[y]2>s[i]
单调队列维护一个递增序列即可。
代码:
#include<cctype>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 3005
#define F inline
using namespace std;
typedef long long LL;
int n,m,l,r,q[N]; LL s[N],f[N][N];
F char readc(){
static char buf[100000],*l=buf,*r=buf;
if (l==r) r=(l=buf)+fread(buf,1,100000,stdin);
return l==r?EOF:*l++;
}
F int _read(){
int x=0; char ch=readc();
while (!isdigit(ch)) ch=readc();
while (isdigit(ch)) x=(x<<3)+(x<<1)+(ch^48),ch=readc();
return x;
}
#define calc(x,y) ((f[x][j-1]-f[y][j-1]+s[x]*s[x]-s[y]*s[y])/(s[x]-s[y]<<1ll))
int main(){
n=_read(),m=_read(),l=1;
for (int i=1;i<=n;i++) s[i]=_read()+s[i-1],f[i][1]=s[i]*s[i];
for (int j=2;j<=m;j++,l=1,r=0)
for (int i=1;i<=n;q[++r]=i,i++){
while (l<r&&calc(q[l],q[l+1])<s[i]) l++;
f[i][j]=f[q[l]][j-1]+(s[i]-s[q[l]])*(s[i]-s[q[l]]);
while (l<r&&calc(q[r-1],q[r])>calc(q[r],i)) r--;
}
return printf("%lld\n",f[n][m]*m-s[n]*s[n]),0;
}