这个题的 n^2 dp 是很显然的 线段树优化dp 也是很显然的
这个题的价值在于增加线段树合并技能熟练度
#include<cmath>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<iomanip>
#include<vector>
#include<string>
#include<bitset>
#include<queue>
#include<map>
#include<set>
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned int uint;
inline int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch<='9'&&ch>='0'){x=10*x+ch-'0';ch=getchar();}
return x*f;
}
void print(int x)
{if(x<0)putchar('-'),x=-x;if(x>=10)print(x/10);putchar(x%10+'0');}
const int N=300100,mod=998244353;
inline int qpow(int x,int y)
{
int res(1);
while(y)
{
if(y&1) res=(ll)x*res%mod;
x=(ll)x*x%mod;
y>>=1;
}
return res;
}
int ch[N][2];
int tot;
int V[N],P[N];
struct president_tree{int ls,rs,sum,tag;}tr[N*30];
int root[N],sz;
inline void pushdown(int k)
{
if(tr[k].tag)
{
int tag=tr[k].tag,ls=tr[k].ls,rs=tr[k].rs;
tr[ls].sum=(ll)tr[ls].sum*tag%mod,
tr[rs].sum=(ll)tr[rs].sum*tag%mod,
tr[ls].tag=(ll)tr[ls].tag*tag%mod,
tr[rs].tag=(ll)tr[rs].tag*tag%mod,
tr[k].tag=1;
}
}
void insert(int &k,int l,int r,int x,int val)
{
k=++sz;
tr[k].tag=1;
tr[k].sum+=val;
if(l==r) return ;
int mid=(l+r)>>1;
x<=mid ? insert(tr[k].ls,l,mid,x,val) : insert(tr[k].rs,mid+1,r,x,val);
}
int merger(int x,int y,int sum_x,int sum_y,int u)
{
if(!x)
{
tr[y].sum=(ll)tr[y].sum*sum_x%mod,
tr[y].tag=(ll)tr[y].tag*sum_x%mod;
return y;
}
if(!y)
{
tr[x].sum=(ll)tr[x].sum*sum_y%mod,
tr[x].tag=(ll)tr[x].tag*sum_y%mod;
return x;
}
int val_x[2],val_y[2];
pushdown(x),pushdown(y);
val_x[0]=tr[tr[x].ls].sum,
val_x[1]=tr[tr[x].rs].sum,
val_y[0]=tr[tr[y].ls].sum,
val_y[1]=tr[tr[y].rs].sum;
tr[x].ls=merger(tr[x].ls,tr[y].ls,(sum_x+(ll)val_x[1]*(1+mod-P[u]))%mod,(sum_y+(ll)val_y[1]*(1+mod-P[u]))%mod,u);
tr[x].rs=merger(tr[x].rs,tr[y].rs,(sum_x+(ll)val_x[0]*P[u])%mod,(sum_y+(ll)val_y[0]*P[u])%mod,u);
tr[x].sum=(tr[tr[x].ls].sum+tr[tr[x].rs].sum)%mod;
return x;
}
void dfs(int u)
{
if(!u) return ;
if(!ch[u][0])
{
insert(root[u],1,tot,lower_bound(V+1,V+1+tot,P[u])-V,1);
return ;
}
dfs(ch[u][0]),dfs(ch[u][1]);
if(!ch[u][1]) root[u]=root[ch[u][0]];
else root[u]=merger(root[ch[u][0]],root[ch[u][1]],0,0,u);
}
int ans(0);
void cal(int k,int l,int r)
{
if(l==r)
{
(ans+=(ll)l*V[l]%mod*tr[k].sum%mod*tr[k].sum%mod)%=mod;
return ;
}
pushdown(k);
int mid=(l+r)>>1;
cal(tr[k].ls,l,mid),cal(tr[k].rs,mid+1,r);
}
int main()
{
int n=read();
register int i,x;
for(i=1;i<=n;++i)
x=read(),ch[x][ch[x][0] ? 1 : 0]=i;
int inv_w=qpow(10000,mod-2);
for(i=1;i<=n;++i)
P[i]=read(),
ch[i][0] ? P[i]=(ll)P[i]*inv_w%mod : V[++tot]=P[i];
sort(V+1,V+1+tot);
dfs(1);
cal(root[1],1,tot);
cout<<ans<<endl;
return 0;
}