Description
编号为1~n的n个城市,每个城市有两个权值Ai和Bi。
对于两个城市i和j,i可到j当且仅当j>i,而费用为(j-i)*Ai+Bj。
求从城市1到城市n的最小费用。
Input
第一行一个正整数n。
第二行n个正整数,第i个表示Ai。
第三行n个正整数,第i个表示Bi。
Output
一个数,表示最小的费用。
Sample Input
4
2 9 5 4
9 1 2 2
Sample Output
8
Data Constraint
对于20%的数据,1<=n<=100;
对于50%的数据,1<=n<=3000;
对于100%的数据,1<=n<=100000,1<=Ai,Bi<=10^9。
Solution
可以轻易的发现DP方程(由后往前推的)
f[i]=min(f[j]+(j-i)*a[i]+b[j]) (j>i)
但是,若每个i都枚举之前的所有j,时间复杂度为O(n^2),最大的数据会TLE。
所以,要使用斜率优化。
考虑有i<j<k,若j优于k,则有
f[j]+(j-i)*a[i]+b[j]<f[k]+(k-i)*a[i]+b[k]
设g[i]=f[i]+b[i]
则有
g[j]+j*a[i]<g[k]+k*a[i]
(j-k)*a[i]<g[k]-g[j]
(k-j)(-a[i])<g[k]-g[j] (因为k>j,所以k-j>0)
-a[i]<(g[k]-g[j])/(k-j)
(g[k]-g[j])/(k-j)即点(k,g[k])与点(j,g[j])连线的斜率
当满足>-a[i]时,j更优。
于是,我们维护一个下凸壳(上面边的斜率递增),每一次二分找出最优的j,计算i的值
然后把i放入下凸壳,删除队尾至满足凸壳性质。
需要注意斜率是小数而带来的精度问题。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
#define N 100100
#define INF 4611686018427387904
#define LL long long
LL a[N],b[N],f[N];
LL q[N];
int n,tail;
void init()
{
scanf("%d",&n);
for (int i=1;i<=n;++i)
scanf("%lld",&a[i]);
for (int i=1;i<=n;++i)
scanf("%lld",&b[i]);
}
double js(LL x,LL y)
{
double fm,fz;
fm=f[x]+b[x]-f[y]-b[y];
fz=x-y;
double re=fm/fz;
//cout<<(f[x]+b[x]-f[y]-b[y])<<' '<<(x-y)<<' '<<re<<endl;
return re;
}
void push(LL x)
{
while (tail>0 && js(x,q[tail])>js(q[tail],q[tail-1]))
tail--;
tail++;
q[tail]=x;
}
int find(LL x,int l,int r)
{
int ans,mid;
if (l>r) return 0;
while (l<=r)
{
mid=(l+r)/2;
if (js(q[mid],q[mid-1])-x>=0.000 && js(q[mid],q[mid+1])-x<=0.000) return mid;
if (js(q[mid],q[mid-1])-x>=0.000) l=mid+1;
else r=mid-1;
}
return 0;
}
int main()
{
//freopen("1749.in","r",stdin);
//freopen("1749.out","w",stdout);
init();
memset(q,0,sizeof(q));
tail=0;
for (int i=1;i<=n;i++)
f[n]=INF/3;
f[n]=0;
push(n);
for (int i=n-1;i>=1;i--)
{
//if (i==1) cout<<js(q[1],q[2])<<endl;
int j;
j=find(-a[i],2,tail-1);
if (j!=0) j=q[j];
//printf("%d ",j);
if (j==0)
{
//cout<<js(q[1],q[2])<<' '<<js(q[tail],q[tail-1])<<' '<<-a[i]<<endl;
if (js(q[1],q[2])+a[i]<=0.000) j=q[1];
if (js(q[tail],q[tail-1])+a[i]>=0.000) j=q[tail];
}
if (tail==1) j=q[1];
//printf("%d\n",j);
f[i]=f[j]+(j-i)*a[i]+b[j];
push(i);
/*for (int k=1;k<=tail;k++)
{
printf("%lld ",f[q[k]]+b[q[k]]);
printf("%lld\n",q[k]);
}
printf("\n");*/
}
printf("%lld\n",f[1]);
return 0;
}