题意
给你
n
n
个整数和
w[i]
w
[
i
]
,然后一个排列
a[p[i]]
a
[
p
[
i
]
]
合法当且仅当
p[k]=a[p[j]],k<j
p
[
k
]
=
a
[
p
[
j
]
]
,
k
<
j
问所有的合法排列中,
w[p[1]]+w[p[2]]+...+w[p[n]]
w
[
p
[
1
]
]
+
w
[
p
[
2
]
]
+
.
.
.
+
w
[
p
[
n
]
]
n≤5×105,0≤a[i]≤n,1≤w[i]≤109 n ≤ 5 × 10 5 , 0 ≤ a [ i ] ≤ n , 1 ≤ w [ i ] ≤ 10 9
分析
首先连边(a[i],i)表示a[i]在i的前面,这样就是一棵树
现在的问题就变成了给这个树编一个dfs序,然后对应dfs序的点的值乘上dfs序,使得总和最大
我们可以先手玩一下,假设现在有两种方案可以选,一个是选x,一个是选ab(当然这个b一定是粘着a的,也就是说,选了a下一个就一定选b)
那么要不就是xab,要不就是abx,而如果选xab的话,贡献当且仅当
x的贡献1<ab贡献2
x
的
贡
献
1
<
a
b
贡
献
2
这样才会使贡献最大
我们把这些平均值从小到大排序,然后按全局最小值向父亲合并就可以了
合并的操作就是在原来的序列上加上这一段,每一个连通块代表的就是一段
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ldb;
const ll N = 1000010;
inline ll read()
{
char ch=getchar(); ll p=0; ll f=1;
while(ch<'0' || ch>'9'){if(ch=='-') f=-1; ch=getchar();}
while(ch >='0' && ch<='9'){p=p*10+ch-'0'; ch=getchar();}
return p*f;
}
ll n , a[N] , w[N];
ll fa[N];
ll Find(ll x){return (fa[x] == x) ? fa[x] : fa[x] = Find(fa[x]);}
struct node
{
ll sum , x , id;
node(){}
node(ll _sum,ll _x,ll _id){sum = _sum; x = _x; id = _id;}
friend bool operator < (const node &x,const node &y)
{
if((ldb) x.sum / (ldb) x.x != (ldb) y.sum / (ldb) y.x) return (ldb) x.sum / (ldb) x.x < (ldb) y.sum / (ldb) y.x;
return x.id < y.id;
}
};
set <node> s;
ll sum[N],x[N],f[N];
int main()
{
n = read();
for(ll i=1;i<=n;i++) a[i] = read();
for(ll i=1;i<=n;i++) w[i] = read();
for(ll i=0;i<=n;i++) fa[i] = i;
for(ll i=1;i<=n;i++)
{
ll xx = Find(a[i]); ll yy = Find(i);
if(xx!=yy) fa[xx] = yy;
else{printf("-1\n"); return 0;}
}
ll ans = 0;
fa[0] = 0; for(ll i=1;i<=n;i++) f[i] = a[i],fa[i] = i;
// s.insert(node(0,1,0)); x[0] = 1;
for(ll i=1;i<=n;i++) s.insert(node(w[i],1,i)),sum[i] = w[i] , x[i] = 1 , ans += w[i];
while(s.size())
{
node xx = *s.begin(); s.erase(s.begin()); ll y = Find(f[xx.id]);
if(y)
{
// printf("%lld\n",*s.find(node(sum[y],x[y],y)).id);
s.erase(s.find(node(sum[y],x[y],y)));
ans += xx.sum * x[y]; sum[y] += xx.sum; x[y] += xx.x; fa[xx.id] = y;
s.insert(node(sum[y] , x[y] , y));
// printf("%lld %lld %d\n",xx.id,y,s.size());
}
else
{
ans += xx.sum * x[y]; sum[y] += xx.sum; x[y] += xx.x; fa[xx.id] = y;
}
}
return printf("%lld\n",ans),0;
}