说在前面
几百年没写斜率优化
找了一道题练练手,结果用了三个多小时才A….简直弱到不忍直视
题目
题面
Pine开始了从S地到T地的征途。
从S地到T地的路可以划分成N段,相邻两段路的分界点设有休息站。
Pine计划用M天到达T地。除第M天外,每一天晚上Pine都必须在休息站过夜。所以,一段路必须在同一天中走完。
Pine希望每一天走的路长度尽可能相近,所以他希望每一天走的路的长度的方差尽可能小。
帮助Pine求出最小方差是多少。设方差是v,可以证明,v×M^2是一个整数。为了避免精度误差,输出结果时输出v×M^2。
输入输出格式
输入格式:
第一行两个整数N,M,含义如题。
第二行N个数,表示N段路的长度
输出格式:
输出一行一个整数表示答案
规模:1≤N≤3000,保证从 S 到 T 的总路程不超过 30000
解法
首先需要知道方差是什么……
对于一组数据
L1,L2⋯LM
,平均数
x=L1+L2+⋯+LMM
则方差为
v=(x−L1)2+(x−L2)2+⋯+(x−LM)2M
题目要求
vM2
,打开括号发现是这样的:
vM2=M∗[(x−L1)2+(x−L2)2+⋯+(x−LM)2]
,发现每一天走的长度对答案的贡献刚好是
M∗(x−Li)2
于是定义dp[i][j]表示前i天走了j段路,转移:
dp[i][j]=min(dp[i−1][k]+M(x−sum[j]+sum[k])2)
发现把平方打开正好就是斜率优化的式子,然后就可以愉悦的切题了
————–10分钟过后————–
卧槽?为什么转移式子里面有一个
Mx2
啊,这不是个小数吗???说好的保证是个整数的呢???
请忽略上面的过程,再来推导一次
题目要求
vM2
,这个平方比较奇怪,但是不妨先画一下柿子:
vM2=M∗[(x−L1)2+(x−L2)2+⋯+(x−LM)2]
=M∗[(x2−2x∗L1+L21)+(x2−2x∗L2+L22)+⋯+(x2−2x∗LM+L2M)]
=M∗[M∗x2−2x∗(L1+L2+⋯+LM)+(L21+L22+⋯+L2M)]
=sum2−2∗sum2+M∗(L21+L22+⋯+L2M)
=M∗(L21+L22+⋯+L2M)−sum2
愉悦的发现,现在只需要求出 每天走的长度的平方和 的最小值
定义dp[i][j]表示前i天走了j段路,转移:
dp[i][j]=min(dp[i−1][k]+(sum[j]−sum[k])2)
化成斜率式:
(2sum[j]∗sum[k])+(dp[i][j]−sum[j]2)=( dp[i−1][k]+sum[las]2 )
正好就是
k∗x+b=y
的形式,然后就可以上斜率优化了
下面是自带大长度的代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
int N , M , now , pre ;
long long dp[2][3005] , ans , sum[3005] ;
struct Points{
long long x , y ;
long long get_b( int k ){
return y - x * k ;
}
}p[2][3005] ;
bool cmp( Points i , int j , int k ){
long long w1 = ( i.y - p[pre][j].y ) * ( i.x - p[pre][k].x ) ;
long long w2 = ( i.y - p[pre][k].y ) * ( i.x - p[pre][j].x ) ;
return w1 <= w2 ;
}
void solve(){
pre = 1 , now = 0 ;
for( int j = 1 ; j <= N ; j ++ )
dp[now][j] = sum[j] * sum[j] ;
ans = dp[now][N] ;
for( int i = 2 ; i <= M ; i ++ ){
swap( now , pre ) ;
int head = 1 , tail = 0 ;
for( int j = 1 ; j <= N ; j ++ ){
int k = 2 * sum[j] ;
Points tmp = (Points){ sum[j] , dp[pre][j] + sum[j]*sum[j] } ;
while( head < tail && cmp( tmp , tail , tail-1 ) ) tail -- ;
p[pre][++tail] = tmp ;
while( head < tail && p[pre][head].get_b( k ) >= p[pre][head+1].get_b( k ) ) head ++ ;
dp[now][j] = p[pre][head].get_b( k ) + sum[j] * sum[j] ;
}
ans = min( ans , dp[now][N] ) ;
}
printf( "%lld\n" , ans * M - sum[N] * sum[N] ) ;
}
int main(){
scanf( "%d%d" , &N , &M ) ;
for( int i = 1 ; i <= N ; i ++ ){
scanf( "%lld" , &sum[i] ) ;
sum[i] += sum[i-1] ;
}
solve() ;
}