我们先说明几个结论:
- 一个人的能力变化只会影响自己和上司
- 当变化的下属的能力小于等于上司的中位数 那么这个中位数会向后移动一位 否则不变 (这个画图就能体会到)
所以我们需要用主席树去查询 第mid个能力值 和 第mid+1个能力值 用树状数组去维护差值 枚举每一个点 并记录可以获得的最大差值 最后答案就是初始的能力值之和+最大差值 具体可以看代码解释
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = 1e5+10,M = 1e5;
ll c[N],sum[N*30],ans,mx;
int L[N*30],R[N*30],tot,rt[N],a[N],val[N],mid[N];
int h[N],to[N<<1],nex[N<<1],cur;
//a是初始能力值 val[i]是变化i结点后结点i得到的差值 mid是中值 ans初始能力值之和 mx是最大差值
int siz[N],dfn[N],cnt,n;
void init(){//清空
ans=mx=cur=tot=cnt=0;
memset(h,0,sizeof(h));
}
void add_edge(int x,int y){
to[++cur]=y;nex[cur]=h[x];h[x]=cur;
}
void add(int x,ll val){//树状数组
while(x<=M){
c[x]+=val;
x+=x&-x;
}
}
ll query(int x){
ll ret = 0;
while(x){
ret+=c[x];
x-=x&-x;
}return ret;
}
void update(int &rt,int lasrt,int l,int r,int pos){//主席树
rt=++tot;sum[rt]=sum[lasrt]+1;
if(l==r) return;
L[rt]=L[lasrt],R[rt]=R[lasrt];
int mid = l+r>>1;
if(pos<=mid) update(L[rt],L[lasrt],l,mid,pos);
else {update(R[rt],R[lasrt],mid+1,r,pos);}
}
int Query(int ql,int qr,int l,int r,int k){
if(l==r) return l;
int o = sum[L[qr]]-sum[L[ql]],mid = l+r>>1;
if(k<=o) return Query(L[ql],L[qr],l,mid,k);
else {return Query(R[ql],R[qr],mid+1,r,k-o);}
}
void dfs1(int u){
siz[u]=1,dfn[u]=++cnt;//计算大小和dfs序
update(rt[dfn[u]],rt[dfn[u]-1],1,M,a[u]);
for(int i = h[u]; i; i = nex[i]) dfs1(to[i]),siz[u]+=siz[to[i]];
if(!h[u]){//如果这是一个叶子结点
mid[u]=a[u],val[u]=M-mid[u];//中值就是自己 差值就是1e5-自己
ans+=a[u];
}else{
int k = siz[u]+1>>1;
mid[u]=Query(rt[dfn[u]-1],rt[cnt],1,M,k);//找到中值
val[u]=Query(rt[dfn[u]-1],rt[cnt],1,M,k+1)-mid[u];//差值就是 第mid+1个位置的值-中值
ans+=mid[u];
}
}
void dfs2(int u){
add(mid[u],val[u]);//遍历到一个点时我们把差值放入树状数组的mid[u]位置
mx=max(mx,query(M)-query(a[u]-1));//如果当前点u是产生变化的点
//那么差值就是他的所有上司里面mid[]值 大于等于a[u]的人的val[]值(差值)之和
for(int i = h[u]; i; i = nex[i]) dfs2(to[i]);
add(mid[u],-val[u]);//当我搜完所有儿子就可以撤掉我的差值了
}
int main(){
while(~scanf("%d",&n)){
init();
for(int i = 1; i <= n; i++) scanf("%d",&a[i]);
for(int i = 2; i <= n; i++){
int x;
scanf("%d",&x);
add_edge(x,i);
}
dfs1(1);dfs2(1);
printf("%lld\n",ans+mx);
}
return 0;
}