题解:
对w求前缀和s,可以得到一个
O
(
n
2
)
O(n^2)
O(n2)的dp:令f[i]为桥到i结束时花的最小代价,则
f
[
i
]
=
m
i
n
{
f
[
j
]
+
(
h
[
i
]
−
h
[
j
]
)
2
+
(
s
[
i
−
1
]
−
s
[
j
]
)
}
f[i]=min\{ f[j]+(h[i]-h[j])^2+(s[i-1]-s[j]) \}
f[i]=min{f[j]+(h[i]−h[j])2+(s[i−1]−s[j])}。
这个时间复杂度显然不够,考虑斜率优化。
经过一通化简,令
g
[
i
]
=
f
[
i
]
+
h
[
i
]
2
−
s
[
i
]
g[i]=f[i]+h[i]^2-s[i]
g[i]=f[i]+h[i]2−s[i],若
j
1
j_1
j1比
j
2
j_2
j2优,得
2
h
[
i
]
(
h
[
j
1
]
−
h
[
j
2
]
)
>
g
[
j
1
]
−
g
[
j
2
]
2h[i](h[j_1]-h[j_2])>g[j_1]-g[j_2]
2h[i](h[j1]−h[j2])>g[j1]−g[j2]。
h没有单调性,考虑cdq分治。
至今还是不熟悉cdq分治优化dp和斜率优化的套路(承认吧我就没做几道这种题)
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define maxn 100005
#define D 20
#define LL long long
#define INF 1e16
using namespace std;
int n,tp,num[D][maxn];
LL f[maxn];
struct node { int i,h; LL s; } a[maxn],tmp[maxn];
bool cmpi(node p,node q) { return p.i<q.i; }
struct Point
{
LL x,y,i;
Point(LL x=0,LL y=0,LL i=0):x(x),y(y),i(i){};
}stk[maxn];
typedef Point Vector;
Vector operator - (Vector a,Vector b) { return Vector(a.x-b.x,a.y-b.y,0); }
LL Cross(Vector a,Vector b) { return a.x*b.y-a.y*b.x; }
double T(Point a,Point b)
{
if(a.x==b.x) return a.y<b.y?INF:-INF;
return 1.0*(a.y-b.y)/(a.x-b.x);
}
void Stk(int l,int r)
{
tp=0;
for(int i=l;i<=r;i++)
{
Point now=Point(a[i].h,f[a[i].i]+1ll*a[i].h*a[i].h-a[i].s,a[i].i);
while(tp>1&&Cross(stk[tp]-stk[tp-1],now-stk[tp])<=0) tp--;
stk[++tp]=now;
}
}
void cdq(int k,int l,int r)
{
if(l==r) return;
int mid=(l+r)>>1;
cdq(k+1,l,mid);
for(int i=l;i<=r;i++) a[i]=tmp[num[k+1][i]];
Stk(l,mid);
int fr=1;
for(int i=mid+1;i<=r;i++)
{
while(fr<tp&&2.0*a[i].h>T(stk[fr],stk[fr+1])) fr++;
int j=stk[fr].i,id=a[i].i;
f[id]=min(f[id],f[j]+1ll*(tmp[j].h-tmp[id].h)*(tmp[j].h-tmp[id].h)+tmp[id-1].s-tmp[j].s);
}
for(int i=mid+1;i<=r;i++) a[i]=tmp[i];
cdq(k+1,mid+1,r);
}
void Merge(int k,int l,int r)
{
if(l==r) { num[k][l]=l; return; }
int mid=(l+r)>>1;
Merge(k+1,l,mid); Merge(k+1,mid+1,r);
for(int i=l,j=mid+1,cnt=l;cnt<=r;cnt++)
if(i<=mid&&(j>r||tmp[num[k+1][i]].h<=tmp[num[k+1][j]].h)) num[k][cnt]=num[k+1][i++];
else num[k][cnt]=num[k+1][j++];
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) { scanf("%d",&a[i].h); a[i].i=i; }
for(int i=1;i<=n;i++) { scanf("%lld",&a[i].s); a[i].s+=a[i-1].s; }
memset(f,0x3f,sizeof(f)); f[1]=0;
memcpy(tmp,a,sizeof(a));
Merge(1,1,n); cdq(1,1,n);
printf("%lld\n",f[n]);
}