这里讲一下斜率优化,其实也是给自己复习一下。
我们从一道例题开始:
BZOJ1597: [Usaco2008 Mar]土地购买
Description
农夫John准备扩大他的农场,他正在考虑N (1 <= N <= 50,000) 块长方形的土地. 每块土地的长宽满足(1 <= 宽 <= 1,000,000; 1 <= 长 <= 1,000,000). 每块土地的价格是它的面积,但FJ可以同时购买多快土地. 这些土地的价格是它们最大的长乘以它们最大的宽, 但是土地的长宽不能交换. 如果FJ买一块3x5的地和一块5x3的地,则他需要付5x5=25. FJ希望买下所有的土地,但是他发现分组来买这些土地可以节省经费. 他需要你帮助他找到最小的经费.
Input
第1行: 一个数: N
第2..N+1行: 第i+1行包含两个数,分别为第i块土地的长和宽
Output
- 第一行: 最小的可行费用.
Sample Input
4
100 1
15 15
20 5
1 100
输入解释:
共有4块土地.
Sample Output
500
HINT
FJ分3组买这些土地: 第一组:100x1, 第二组1x100, 第三组20x5 和 15x15 plot. 每组的价格分别为100,100,300, 总共500.
Source
Gold
这题我们先想一想基本思路,我们可以先按长排序(从大到小,若长相等则按宽从大到小),然后我们可以扫一遍宽(因为长已经从大到小了),如果i+1的宽小于i的宽,那么就不用把i放入数组了,因为i可以买下i+1。所以我们将一个问题转换成了将数组分成连续的多段(因为每段的两端点的宽和长即为最长和最宽),然后直接DP就行。
令f[i]为前i个土地全买下的最小价格,x[i]为长,y[i]为宽。
则我们可以得到DP转移方程
f[i]=min{f[j−1]+x[j]∗y[i]}(1<=j<=i)
这里的DP方程可能于代码中的不同(
代码中可能为f[i]=min{f[j]+x[j+1]∗y[i]}(0<=j<=i)
),但是不影响学习斜率优化。
但是我们会发现,直接枚举两层循环是会超时的,于是我们要想个办法,这时我可以用斜率优化。
设 j< k< i
当 k比j更优秀时,我们可以发现
所以我们可以得到,当j,k满足这个样子时,我们可以将j舍去。
接着我们又可以发现
令g(j,k)=
f[k−1]−f[j−1]x[j]−x[k]
(我们上面已经知道,当g(j,k)<=y[i]时,k比j优)
设j< k< i< i’
Case1:当
g(j,k)>=g(k,i)
时
y[i′]
有三种可能
1.
y[i′]>=g(j,k)>=g(k,i)
时,k比j优,i比k优,所以i最优。
2.
g(j,k)>=y[i′]>=g(k,i)
时,j比k优,i比k优,所以i,j都比k优。
3.
g(j,k)>=g(k,i)>=y[i′]
时,k比i优,j比k优,所以j最优。
所以通过分类讨论我们可以得到,无论如何k都可以被舍去。
Case2:当
g(i,j)<g(k,i)
时 我们可以进行同样的分类讨论
但我们发现当
g(j,k)<=y[i′]<g(k,i)
时 k是最优的,所以我们不能将k舍去。
这里的g(j,k)是不是很像求斜率的公式 k=
y−y′x−x′
所以我们可以用一个单调队列来维护这个过程。
具体维护可以参考代码(由于除可能会有精度误差,故代码中装换成乘来比较)。
#include<algorithm>
#include<cstdio>
#include<iostream>
using namespace std;
struct node
{
long long x,y;
}a[50010],b[50010];
int n;
long long q[100000],f[100000];
bool cmp(node a,node b)
{
if (a.x>b.x) return true;
if (a.x==b.x&&a.y>b.y) return true;
return false;
}
bool check(int i,int j,int k)
{
long long x1=a[i].y*(a[k+1].x-a[j+1].x);
long long y1=f[j]-f[k];
if (x1<=y1) return true;else return false;
}
bool check1(int k,int j,int i)
{
long long x1=(f[k]-f[j])*(a[i+1].x-a[j+1].x);
long long y1=(f[j]-f[i])*(a[j+1].x-a[k+1].x);
if (x1>=y1) return true;else return false;
}
int main()
{
scanf("%d",&n);
for (int i=1;i<=n;i++) scanf("%d%d",&a[i].x,&a[i].y);
sort(a+1,a+n+1,cmp);
int tot=1;
b[1]=a[1];
for (int i=1;i<=n;i++)
if (a[i].y>b[tot].y) tot++,b[tot]=a[i];
for (int i=1;i<=tot;i++) a[i]=b[i];
int head=1,tail=2;
q[1]=0;f[0]=0;
for (int i=1;i<=tot;i++)
{
while (head<tail-1&&check(i,q[head],q[head+1])) head++;
f[i]=f[q[head]]+a[q[head]+1].x*a[i].y;
while (head<tail-1&&check1(q[tail-2],q[tail-1],i)) tail--;
q[tail]=i;tail++;
}
printf("%lld\n",f[tot]);
return 0;
}