题意
小凸和小方相约玩密室逃脱,这个密室是一棵有n个节点的完全二叉树,每个节点有一个灯泡。点亮所有灯泡即可逃出密室。每个灯泡有个权值Ai,每条边也有个权值bi。点亮第1个灯泡不需要花费,之后每点亮1个新的灯泡V的花费,等于上一个被点亮的灯泡U到这个点V的距离Du,v,乘以这个点的权值Av。在点灯的过程中,要保证任意时刻所有被点亮的灯泡必须连通,在点亮一个灯泡后必须先点亮其子树所有灯泡才能点亮其他灯泡。请告诉他们,逃出密室的最少花费是多少。
1≤N≤2∗105
1
≤
N
≤
2
∗
10
5
分析
一开始默认了出发点必然是节点1,以为这是一道sb题,打完才发现原来自己看错题了。
假如我们现在已经确定了起点s,那么走过的路径必然是先把s的子树走完,然后走到s的父亲,再走到s的兄弟,如此类推。这样我们要求的实际上就是从某个点出发,遍历完其子树后回到其某个祖先的最小代价。
设g[i,j]表示从i开始,遍历完i的子树后再走到深度为j的祖先时的最小代价。
如果i是叶节点,那么g[i,j]就等于i到深度为j的祖先的路径长度*祖先的权值。如果i只有一个儿子,那就是先走到儿子再从儿子走到祖先。否则的话就枚举先走左儿子还是先走右儿子。
但我们发现先走左儿子再从左儿子走到右儿子的代价并不好求,于是我们要多设一个f[i,j]表示从i开始,遍历完i的子树后,走到其深度为j的祖先的兄弟的最小代价。转移类似。
求完g之后,我们可以枚举起点,然后按上述方式模拟一遍即可。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
typedef long long LL;
const int N=200005;
const LL inf=(LL)1e16;
int n,w[N],dep[N];
LL dis[N],f[N][20],g[N][20];
void pre(int x,int fa)
{
dep[x]=dep[fa]+1;
if (x*2<=n) dis[x*2]+=dis[x],pre(x*2,x);
if (x*2+1<=n) dis[x*2+1]+=dis[x],pre(x*2+1,x);
}
LL get_dis(int x,int y)
{
LL ans=dis[x]+dis[y];
while (x!=y)
{
if (x<y) std::swap(x,y);
x/=2;
}
return ans-dis[x]*2;
}
void dp()
{
for (int i=n;i>=1;i--)
for (int j=dep[i]-1,x=i/2,ls=i;j>=1;j--,ls=x,x/=2)
{
int y=(x*2==ls)?x*2+1:x*2;
if (y>n) {f[i][j]=inf;continue;}
if (i*2>n) f[i][j]=(LL)get_dis(i,y)*w[y];
else if (i*2+1>n) f[i][j]=(LL)(dis[i*2]-dis[i])*w[i*2]+f[i*2][j];
else
{
f[i][j]=(LL)(dis[i*2]-dis[i])*w[i*2]+f[i*2][dep[i]]+f[i*2+1][j];
f[i][j]=std::min(f[i][j],(LL)(dis[i*2+1]-dis[i])*w[i*2+1]+f[i*2+1][dep[i]]+f[i*2][j]);
}
}
for (int i=n;i>=1;i--)
for (int j=dep[i]-1,x=i/2;j>=0;j--,x/=2)
{
if (i*2>n) g[i][j]=(LL)get_dis(i,x)*w[x];
else if (i*2+1>n) g[i][j]=(LL)(dis[i*2]-dis[i])*w[i*2]+g[i*2][j];
else
{
g[i][j]=(LL)(dis[i*2]-dis[i])*w[i*2]+f[i*2][dep[i]]+g[i*2+1][j];
g[i][j]=std::min(g[i][j],(LL)(dis[i*2+1]-dis[i])*w[i*2+1]+f[i*2+1][dep[i]]+g[i*2][j]);
}
}
}
LL solve(int s)
{
LL ans=g[s][dep[s]-1];int ls=s;s/=2;
while (s)
{
if ((ls^1)<=n) ans+=(LL)(dis[ls^1]-dis[s])*w[ls^1]+g[ls^1][dep[s]-1];
else ans+=(LL)(dis[s]-dis[s/2])*w[s/2];
ls=s;s/=2;
}
return ans;
}
int main()
{
scanf("%d",&n);
for (int i=1;i<=n;i++) scanf("%d",&w[i]);
for (int i=2;i<=n;i++) scanf("%lld",&dis[i]);
pre(1,0);
dp();
LL ans=inf;
for (int i=1;i<=n;i++)
ans=std::min(ans,solve(i));
printf("%lld",ans);
return 0;
}