这个条件给的有点诡异:对于任意的 a p j = p k a_{p_j}=p_k apj=pk,都有 k < j k<j k<j。
那么对于某个 a x = y a_x=y ax=y,意思就是 y y y 在 p p p 中的位置小于 x x x 在 p p p 中的位置。
那么如果我们连边 ( a x , x ) (a_x,x) (ax,x),就是要求图中没有环,是一棵树,而且父亲在 p p p 中的位置要小于儿子在 p p p 中的位置。
再看一下要求:按 p p p 的顺序把 i × w p i i\times w_{p_i} i×wpi 加起来得到总权值,找到一种 p p p 的顺序使得总权值最大。
是个树上全序问题(可参考 AGC023F),现在有若干个连通块,我们要给这些连通块一种顺序使得总权值最大。
仍然考虑最优顺序下,相邻的两个连通块 a a a 和 b b b,它们的大小分别为 s z a , s z b sz_a,sz_b sza,szb,它们各自的 w w w 的和分别为 W a , W b W_a,W_b Wa,Wb,它们各自的总权值分别为 V a , V b V_a,V_b Va,Vb。设 a a a 之前的总权值为 V V V, a a a 之前的点的总数为 S S S。
交换前: a n s 1 = V + S ⋅ W a + V a + ( S + s z a ) ⋅ W b + V b ans_1=V+S\cdot W_a+V_a+(S+sz_a)\cdot W_b+V_b ans1=V+S⋅Wa+Va+(S+sza)⋅Wb+Vb。
交换后: a n s 2 = V + S ⋅ W b + V b + ( S + s z b ) ⋅ W a + V a ans_2=V+S\cdot W_{b}+V_b+(S+sz_b)\cdot W_a+V_a ans2=V+S⋅Wb+Vb+(S+szb)⋅Wa+Va。
强制令 a n s 1 ≥ a n s 2 ans_1\geq ans_2 ans1≥ans2,得到 W a s z a ≤ W b s z b \dfrac{W_a}{sz_a}\leq \dfrac{W_b}{sz_b} szaWa≤szbWb,于是我们直接按 W a s z a \dfrac{W_a}{sz_a} szaWa 从小到大排序即可。
#include<bits/stdc++.h>
#define N 500010
#define ll long long
using namespace std;
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=(x<<1)+(x<<3)+(ch^'0');
ch=getchar();
}
return x*f;
}
int n,fa[N],rt[N];
int sz[N];
ll w[N],v[N];
bool vis[N];
bool del[N];
vector<int>e[N];
void dfs(int u)
{
if(vis[u])
{
puts("-1");
exit(0);
}
vis[u]=1;
for(int v:e[u]) dfs(v);
}
int find(int x)
{
return x==rt[x]?x:(rt[x]=find(rt[x]));
}
struct cmp
{
inline bool operator()(int a,int b) const//这里加inline会大大地优化时间,不然过不去
{
if(w[a]*sz[b]==w[b]*sz[a]) return a<b;
return w[a]*sz[b]<w[b]*sz[a];
}
};
set<int,cmp>s;
int main()
{
n=read();
for(int i=1;i<=n;i++)
{
rt[i]=i;
fa[i]=read();
if(fa[i]) e[fa[i]].push_back(i);
}
for(int i=1;i<=n;i++)
if(!fa[i]) dfs(i);
for(int i=1;i<=n;i++)
{
if(!vis[i])
{
puts("-1");
return 0;
}
}
for(int i=1;i<=n;i++)
v[i]=w[i]=read(),sz[i]=1;
for(int i=1;i<=n;i++)
s.insert(i);
del[0]=1;
ll V=0,S=0;
while(!s.empty())
{
int u=(*s.begin());
s.erase(s.begin());
int f=find(fa[u]);
if(del[f])
{
del[u]=1;
V+=S*w[u]+v[u];
S+=sz[u];
continue;
}
s.erase(f);
rt[u]=f;
v[f]+=sz[f]*w[u]+v[u];
w[f]+=w[u];
sz[f]+=sz[u];
s.insert(f);
}
printf("%lld\n",V);
return 0;
}