给出n,m(3000),和一个长为n的数列(和<30000),将其分成连续的m段,使得每段和的方差最小。
输出
方
差
∗
m
2
方差*m^2
方差∗m2(一定是整数)。
利用方差公式, 答 案 = m Σ x i 2 − ( Σ x ) 2 答案 = m\Sigma x_i^2 - (\Sigma x)^2 答案=mΣxi2−(Σx)2,其中 x i x_i xi表示第 i i i段的数字之和。
所以目标是让 Σ x i 2 \Sigma x_i^2 Σxi2最小。
设
d
i
j
d_{ij}
dij表示把前
i
i
i个数字分成
j
j
j段时的最小代价,
d
00
=
0
d_{00}=0
d00=0.
d
i
j
=
m
i
n
{
d
k
,
j
−
1
+
(
s
i
−
s
k
)
2
}
d_{ij}=min\{d_{k,j-1}+(s_i-s_k)^2\}
dij=min{dk,j−1+(si−sk)2},
j
−
1
<
=
k
<
i
j-1<=k<i
j−1<=k<i,
s
i
s_i
si表示前
i
i
i个数的和。
先按 j j j外层循环,内层再按 i i i循环。
整理得, d i = s i 2 + m i n { d k + s k 2 − 2 s i s k } d_i=s_i^2+min\{d_k+s_k^2-2s_is_k\} di=si2+min{dk+sk2−2sisk}
设
a
<
b
<
i
a<b<i
a<b<i且从
b
b
b转移优于从
a
a
a转移,那么:
(
d
b
+
s
b
2
)
−
(
d
a
+
s
a
2
)
s
b
−
s
a
<
2
s
i
\frac{(d_b+s_b^2)-(d_a+s_a^2)}{s_b-s_a}<2s_i
sb−sa(db+sb2)−(da+sa2)<2si
由斜率优化原理, y i = d i + s i 2 , x i = s i , k = 2 s i y_i=d_i+s_i^2,x_i=s_i,k=2s_i yi=di+si2,xi=si,k=2si,维护下凸包,斜率递增。
WA点:
- subx和suby写混
- 分析过程中出错,目标斜率少了常数2
/* LittleFall : Hello! */
#include <bits/stdc++.h>
using namespace std; using ll = long long; inline int read();
const int M = 3016, MOD = 1000000007;
ll sum[M];
ll dp[M][M];
int segs; //做指示变量,因为调用太麻烦
inline ll subx(ll a, ll b)
{
return sum[b] - sum[a];
}
inline ll suby(ll a, ll b)
{
return (dp[b][segs-1]+sum[b]*sum[b]) - (dp[a][segs-1]+sum[a]*sum[a]);
}
inline ll cal(int i, int k)
{
return dp[k][segs-1]+(sum[i]-sum[k])*(sum[i]-sum[k]);
}
int q[M], l, r;
int main(void)
{
#ifdef _LITTLEFALL_
freopen("in.txt","r",stdin);
#endif
int n = read(), m = read();
for(int i=1; i<=n; ++i)
sum[i] = sum[i-1] + read();
for(int i=1; i<=n; ++i)
dp[i][1] = sum[i] * sum[i];
for(segs=2; segs<=m; ++segs)
{
l = r = 0; q[r++] = segs-1;
for(int i=segs; i<=n; ++i)
{
while(r-l>=2 && suby(q[l], q[l+1]) <= subx(q[l],q[l+1])*2*sum[i] ) ++l;
dp[i][segs] = cal(i, q[l]);
while(r-l>=2 &&
suby(q[r-2], q[r-1]) * subx(q[r-1],i) >= subx(q[r-2], q[r-1]) * suby(q[r-1], i)
) --r;
q[r++] = i;
}
}
cout << m*dp[n][m]-sum[n]*sum[n] << "\n";
return 0;
}
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}