学习了国家集训队员周源的这篇论文《浅谈数形结合思想在信息学竞赛中的应用》,很具启发性。受益颇多。
这题的一个简单的O(n^2)的DP是 dp[ i ] = min {dp[ j ] + sum[ i ] - sum[ j ] - ( i - j ) * a[ j+1 ] | i - j >= k } (sum[ i ]为前缀和) 。 直接这么做肯定TLE,重点是如何优化它。
假设对于当前的i 有 j1<j2 , 且j2 不差于 j1 , 则 有:
dp[ j1 ] + sum[ i ] - sum[ j1 ] - ( i - j1 ) *a[ j1+1 ] >= dp[ j2 ] + sum[ i ] - sum[ j2 ] - ( i - j2 ) *a[ j2+1 ]
整理得:
{ ( dp[ j1 ] - sum[ j1 ] + j1 * a[ j1 + 1 ] ) - (dp[ j2 ] - sum[ j2 ] + j2 * a[ j2 + 1 ] ) } / { a[ j1+1] - a[ j2 + 1 ] } <= i
(注意a[ j1+1] - a[ j2 + 1 ] < 0)
记 Y[ j ] = dp[ j ] - sum[ j ] + j * a[ j+1 ] , X[ j ] = a[ j+1 ] , 可画出Y - X 的函数图。
记 rate( a, b ) = (Y[b]-Y[a]) / (X[b] - X[a] ) 即a ,b两点间的斜率。
点j2 不差于j1 在图像上的直观意义是 两点间的斜率 小于等于 i , 即rate(j1 , j2 ) <= i (记住这个结论).
有以下两个结论:
① rate(a , b) <= i 是当前及以后 b点不差于a点的充分条件(i是递增的)
② rate(a , b) > i 只能证明当前 a 优于b ,不能证明以后的情况(这是显而易见的,因为i 递增)
可以使用单调队列(斜率递增)来维护可能是最优的决策点,而剔除不可能最优的点。
首先给个直观的印象:
队列里存放的是斜率递增的点 , 可以知道: 第一个使rate( b , c ) > i 的b 是最佳点,(若不存在rate(b , c) > i , 则b是队尾元素) , 证明:b之前rate(a,b)<=i , 即后者比前者优 ,b之后rate(b,c) > i , 即前者比后者优。最佳决策点b具有这样的性质: [ a, b ,c 相邻 且a<b<c时] rate(a,b) <= i < rate(b , c)
再说说是怎么具体实现的:(两个操作)
队首: a < b是队首的两个元素 ,若rate(a , b)<=i , 可直接删除a 。重复此操作后 ,队首元素将是最佳决策点!
队尾的操作不是很好理解:
a<b<c (c为新加入决策点)是队尾的元素 , 有rate(a , b) > i (否则a已经被删除了) ,即在当前a优于b , 并不能保证以后a也优于b,是否就没法操作了呢?
并非如此,若新加入的c使rate(a,b) >= rate(b , c) 破坏了下凸性(即斜率递增,类似凹函数)时可直接删除b , 这是因为b永远不可能成为最佳决策点 : 对于任意的i , rate(a ,b ) <= i < rate(b , c) 恒不成立!。同时此操作删除了上凸点,维护了斜率递增,可谓精妙!
以上只是简单说明,若想严格证明,可以参考数学归纳法。
当然以上的说明故意忽略了rate(a , b)为无穷大的情况,请读者仔细思考。
参考代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=500050;
typedef long long LL;
LL a[maxn] , dp[maxn] ,sum[maxn];
int que[maxn] , head , tail;
int N , K;
LL dy(int x,int y){
return (dp[y] - sum[y] + y*a[y+1]) - (dp[x] - sum[x] + x*a[x+1]) ;
}
LL dx(int x,int y){
return a[y+1] - a[x+1] ;
}
int readint(){
char c=getchar();
int ans = 0;
while(c<'0' || c>'9') c=getchar();
while(c>='0' && c<='9') ans = ans *10 + c-'0' , c=getchar() ;
return ans;
}
int main()
{
int T;
T = readint();
while(T--){
N = readint() , K = readint();
sum[0] = 0 , dp[0] = 0;
for(int i=1; i<=N; i++) a[i]=readint(), sum[i] = sum[i-1] + a[i];
head = tail = 0 , que[0] = 0;
for(int i=1; i<=N; i++){
while(head < tail && dy(que[head] , que[head+1]) <=i*dx(que[head], que[head+1]) )
head++;
int t = que[head];
dp[i] = dp[t] + sum[i] - sum[t] - a[t+1]*(i-t);
if(i+1 >= (K<<1)){
int z = i-K+1;
while(head < tail){
int x = que[tail-1] , y = que[tail];
if(dy(x,y)*dx(y,z) >= dy(y,z)*dx(x,y)) tail--;
else break;
}
que[++tail] = z;
}
}
printf("%I64d\n",dp[N]);
}
return 0;
}