题目链接
https://www.lydsy.com/JudgeOnline/problem.php?id=5461
题解
线段树合并。对每个节点维护一棵以权值排名为下标的线段树,区间 $[l,r]$ 表示该节点取到第 $l$ 小到第 $r$ 小的权值的概率之和。每个非叶节点的线段树由它的两个儿子的线段树合并而来。容易发现,在点 $u$ 处(左儿子为 $ls$,右儿子为 $rs$,$p_u$ 为点 $u$ 取两个儿子中较大值的概率),对于第 $i$ 小的权值,假设这个权值是由左儿子贡献而来,则取到这个权值的概率是

$$f_{u,i}=f_{ls,i}\left(p_u\sum_{j<i}f_{rs,j}+(1-p_u)\sum_{j>i}f_{rs,j}\right)$$
右儿子同理。
代码
#include <cstdio>
#include <algorithm>
// Read the next (optionally negative) decimal integer from stdin.
// Every '-' seen while skipping non-digit characters flips the sign.
// Returns 0 if end of input is reached before any digit, instead of
// spinning forever as the previous version did.
int read()
{
    int x=0,f=1;
    // getchar() returns int so that EOF stays distinguishable from every
    // real character; storing it in a char loses that distinction.
    int ch=getchar();
    while((ch!=EOF)&&((ch<'0')||(ch>'9')))
    {
        if(ch=='-')
        {
            f=-f;
        }
        ch=getchar();
    }
    while((ch>='0')&&(ch<='9'))
    {
        x=x*10+ch-'0';
        ch=getchar();
    }
    return x*f;
}
// Upper bound on the number of tree nodes; per-node arrays are sized maxn+10.
constexpr int maxn=300000;
// Modulus for all probability arithmetic.
constexpr int mod=998244353;
// Modular inverse of 10000: 10000*inv == 1 (mod `mod`).
constexpr int inv=796898467;
// One segment-tree node. A NULL child pointer stands for a subtree whose
// values are all zero.
struct node
{
    node *son[2]; // son[0]: left half of the range, son[1]: right half
    int sum;      // sum of the leaf values in this subtree (mod `mod`)
    int tag;      // pending multiplicative factor not yet pushed to children
};
// Backing storage for every segment-tree node. Slots are handed out as
// &tree[++cnt], so slot 0 is never used.
node tree[maxn<<5];
// Index of the most recently handed-out pool slot.
int cnt=0;
// Reset a freshly allocated node: no children, zero sum, identity tag (1).
int clear(node *x)
{
    x->sum=0;
    x->tag=1;
    x->son[1]=NULL;
    x->son[0]=NULL;
    return 0;
}
// Lazily multiply every value in the subtree rooted at x by v (mod `mod`).
// x's own sum is updated now; the factor is folded into x->tag for the
// children. A NULL subtree holds only zeros, so it is left alone.
int puttag(node *x,int v)
{
    if(x!=NULL)
    {
        x->sum=1ll*x->sum*v%mod;
        x->tag=1ll*x->tag*v%mod;
    }
    return 0;
}
// Hand x's pending multiplicative tag to both children, then reset it to 1.
int pushdown(node *x)
{
    for(int k=0; k<2; ++k)
    {
        puttag(x->son[k],x->tag);
    }
    x->tag=1;
    return 0;
}
// Subtree sum stored at x; an absent (NULL) subtree contributes 0.
int getsum(node *x)
{
    if(x==NULL)
    {
        return 0;
    }
    return x->sum;
}
// Recompute x->sum from its two children, folding the result back into
// [0, mod) with a single conditional subtraction.
int updata(node *x)
{
    int total=getsum(x->son[0])+getsum(x->son[1]);
    x->sum=(total>=mod)?(total-mod):total;
    return 0;
}
// Merge the trees x and y of a node's two children. Both trees cover the
// same range of weight ranks. p is the probability (as a residue mod `mod`)
// that the node takes the larger of its children's values.
// xl / xr: total probability mass of x on ranks strictly below / above the
// current range; yl / yr: the same for y.
// A rank i supplied by x survives with probability
//   p * (mass of y below i) + (1 - p) * (mass of y above i),
// and symmetrically for y. l and r are passed along but never read.
// A subtree present on only one side is reused in place; positions present
// on both sides get a fresh node from the pool.
node *merge(node *x,node *y,int l,int r,int p,int xl,int yl,int xr,int yr)
{
    if(x==NULL||y==NULL)
    {
        // Only one side has mass here. The other side's mass below and above
        // is therefore the same for every rank in the range, so the whole
        // surviving subtree is scaled by one factor.
        node *keep=(x==NULL)?y:x;
        int below=(x==NULL)?xl:yl;
        int above=(x==NULL)?xr:yr;
        puttag(keep,(1ll*p*below+1ll*(mod+1-p)*above)%mod);
        return keep;
    }
    node *res=&tree[++cnt];
    clear(res);
    pushdown(x);
    pushdown(y);
    // Modular addition of two values already in [0, mod).
    auto addmod=[](int a,int b)
    {
        int s=a+b;
        return (s>=mod)?(s-mod):s;
    };
    // Capture the children's sums before recursing. The left recursion may
    // rescale x->son[0] / y->son[0] in place.
    int x0=getsum(x->son[0]),x1=getsum(x->son[1]);
    int y0=getsum(y->son[0]),y1=getsum(y->son[1]);
    // Descending left: the right halves join the "above" mass.
    res->son[0]=merge(x->son[0],y->son[0],l,r,p,xl,yl,addmod(xr,x1),addmod(yr,y1));
    // Descending right: the left halves join the "below" mass.
    res->son[1]=merge(x->son[1],y->son[1],l,r,p,addmod(xl,x0),addmod(yl,y0),xr,yr);
    updata(res);
    return res;
}
// Add v to the leaf at position pos of the tree rooted at x, which covers
// [l, r]. Missing nodes on the path are taken from the pool. Returns the
// root, which is newly allocated when x was NULL.
node *add(node *x,int l,int r,int pos,int v)
{
    if(x==NULL)
    {
        x=&tree[++cnt];
        clear(x);
    }
    if(l==r)
    {
        x->sum+=v;
        return x;
    }
    pushdown(x);
    int mid=(l+r)>>1;
    if(pos>mid)
    {
        x->son[1]=add(x->son[1],mid+1,r,pos,v);
    }
    else
    {
        x->son[0]=add(x->son[0],l,mid,pos,v);
    }
    updata(x);
    return x;
}
// Value stored at leaf position pos of the tree rooted at x, which covers
// [l, r]. Lazy tags along the path are pushed down so the leaf is up to
// date. Returns 0 if the path runs into a missing subtree.
int getsum(node *x,int l,int r,int pos)
{
    while(x!=NULL&&l!=r)
    {
        pushdown(x);
        int mid=(l+r)>>1;
        if(pos<=mid)
        {
            x=x->son[0];
            r=mid;
        }
        else
        {
            x=x->son[1];
            l=mid+1;
        }
    }
    return getsum(x);
}
// Children are stored as singly linked edge lists:
// - now[a] is the most recently added edge out of a (0 = none);
// - pre[e] is the next edge in the same list;
// - son[e] is the child that edge e points to;
// - tot counts the edges added so far.
int pre[maxn+10],now[maxn+10],son[maxn+10],tot;
// p[u]: for an internal node, the probability read in main, scaled into mod
// arithmetic. For a leaf, its weight as read, later replaced by the leaf's
// rank among all leaf weights.
// n: number of nodes. top: number of leaves.
int p[maxn+10],n,top;
// root[u]: segment tree over weight ranks for the subtree of u.
node *root[maxn+10];
// Record b as a child of a by prepending a new edge to a's edge list.
int ins(int a,int b)
{
    int e=++tot;
    son[e]=b;
    pre[e]=now[a];
    now[a]=e;
    return 0;
}
int search(int u)
{
if(!now[u])
{
root[u]=add(root[u],1,top,p[u],1);
return 0;
}
node *ls=NULL,*rs=NULL;
for(int i=now[u]; i; i=pre[i])
{
int v=son[i];
search(v);
if(ls==NULL)
{
ls=root[v];
}
else
{
rs=root[v];
}
}
if(rs==NULL)
{
root[u]=ls;
}
else
{
root[u]=merge(ls,rs,1,top,p[u],0,0,0,0);
}
return 0;
}
// A leaf record (node id, weight), ordered by weight alone.
struct data
{
    int id,val;
    data(int nodeId=0,int weight=0):id(nodeId),val(weight){}
    bool operator <(const data &rhs) const
    {
        return rhs.val>val;
    }
};
// Leaves as (node id, weight). After sorting, d[r] is the leaf of rank r.
data d[maxn+10];
int main()
{
    n=read();
    // A parent of 0 means "no parent"; any other parent gets a child edge.
    for(int v=1; v<=n; ++v)
    {
        int parent=read();
        if(parent!=0)
        {
            ins(parent,v);
        }
    }
    // Childless nodes are leaves and carry a weight. Every other node carries
    // a probability in units of 1/10000, converted to mod arithmetic via inv.
    for(int v=1; v<=n; ++v)
    {
        p[v]=read();
        if(now[v]!=0)
        {
            p[v]=1ll*p[v]*inv%mod;
            continue;
        }
        d[++top]=data(v,p[v]);
    }
    std::sort(d+1,d+top+1);
    // Replace each leaf's weight by its rank, the segment-tree coordinate.
    for(int r=1; r<=top; ++r)
    {
        p[d[r].id]=r;
    }
    search(1);
    // Answer = sum over ranks r of r * weight_r * prob_r^2 (mod `mod`), where
    // prob_r is read from root 1's tree at rank r.
    int ans=0;
    for(int r=1; r<=top; ++r)
    {
        int prob=getsum(root[1],1,top,r);
        ans=(ans+1ll*r*d[r].val%mod*prob%mod*prob)%mod;
    }
    printf("%d\n",ans);
    return 0;
}