题目:点击打开链接
关键:首先是dpi,j表示第i个点选第j小的概率。转移用前缀和优化。
然后发现转移时只涉及到子树内的值,于是想到线段树合并。
合并x,y时,若都有则递归,否则直接乘相应系数。
递归时直接算前、后缀和,因为当x或y为空时不需计算
合并的时候相当于对值域分治,所以把贡献分别乘一下,再加起来就好
O(nlogn)考虑每个点只会被合并logn次,因为每次sz至少翻倍
注意:min,max不要写反
用宏定义写线段树,简化代码,少开变量,加速
#include<bits/stdc++.h>
using namespace std;
#define maxn 300020
#define ls(x) sgt[x].ls //用宏定义减少码量,不用开新变量。简单清晰!
#define rs(x) sgt[x].rs
typedef long long ll;
const ll mod = 998244353;
struct node{
int ls,rs;
ll sum,mul;
}sgt[maxn * 20];
struct node2{
int d,id;
bool operator < (node2 a)const{
return d < a.d;
}
}dt[maxn];
int tot,rt[maxn];
int fa[maxn],son[maxn][2],n,val[maxn],m,id[maxn];
ll p[maxn],ans;
inline ll power(ll x,int y){
ll res = 1;
while ( y ){
if ( y & 1 ) res = res * x % mod;
y >>= 1;
x = x * x % mod;
}
return res;
}
inline void mul(int x,ll d){
if ( !x ) return;
sgt[x].sum = sgt[x].sum * d % mod;
sgt[x].mul = sgt[x].mul * d % mod;
}
inline void pushdown(int x){
if ( sgt[x].mul != 1 ){
mul(ls(x),sgt[x].mul);
mul(rs(x),sgt[x].mul);
sgt[x].mul = 1;
}
}
inline void update(int x){
sgt[x].sum = (sgt[ls(x)].sum + sgt[rs(x)].sum) % mod;
}
void insert(int &x,int l,int r,int id,int d){
if ( !x ) x = ++tot , sgt[x].mul = 1;
if ( l == r ){
sgt[x].sum = d;
return;
}
pushdown(x);
int mid = (l + r) >> 1;
if ( id <= mid ) insert(sgt[x].ls,l,mid,id,d);
else insert(sgt[x].rs,mid + 1,r,id,d);
update(x);
}
int Merge(int x,int y,ll sumx,ll sumy,ll pmin,ll pmax){
if ( !x ){ mul(y,sumx); return y; }
if ( !y ){ mul(x,sumy); return x; }
pushdown(x) , pushdown(y);
ll x0 = sgt[ls(x)].sum , x1 = sgt[rs(x)].sum , y0 = sgt[ls(y)].sum , y1 = sgt[rs(y)].sum;
//++tot , sgt[tot].mul = 1; 不用新建节点,默认合并到x上即可
ls(x) = Merge(ls(x),ls(y),(sumx + x1 * pmin) % mod,(sumy + y1 * pmin) % mod,pmin,pmax);
rs(x) = Merge(rs(x),rs(y),(sumx + x0 * pmax) % mod,(sumy + y0 * pmax) % mod,pmin,pmax);
update(x);
return x;
}
void calc(int x,int l,int r){
if ( l == r ){
ans = (ans + sgt[x].sum * sgt[x].sum % mod * dt[l].d % mod * l) % mod;
return;
}
pushdown(x);
int mid = (l + r) >> 1;
calc(ls(x),l,mid);
calc(rs(x),mid + 1,r);
}
void dfs(int x){
if ( !son[x][0] ) insert(rt[x],1,m,id[x],1);
else if ( son[x][1] ){
dfs(son[x][0]) , dfs(son[x][1]);
rt[x] = Merge(rt[son[x][0]],rt[son[x][1]],0,0,(1 - p[x] + mod),p[x]);
}
else{
dfs(son[x][0]);
rt[x] = rt[son[x][0]];
}
}
int main(){
freopen("input.txt","r",stdin);
scanf("%d",&n);
for (int i = 1 ; i <= n ; i++){
scanf("%d",&fa[i]);
if ( fa[i] ){
if ( son[fa[i]][0] ) son[fa[i]][1] = i;
else son[fa[i]][0] = i;
}
}
for (int i = 1 ; i <= n ; i++){
scanf("%d",&val[i]);
if ( son[i][0] ) p[i] = power(10000,mod - 2) * val[i] % mod;
else dt[++m] = (node2){val[i],i};
}
sort(dt + 1,dt + m + 1);
for (int i = 1 ; i <= m ; i++) id[dt[i].id] = i;
dfs(1);
calc(rt[1],1,m);
cout<<ans<<endl;
return 0;
}