考虑向上合并 !
考虑一个数, 合并上去是它的概率 ----
它到当前点本身的概率 p * ( 另一棵子树比它小的概率和 * 选大的概率 + 另一棵子树比它大的概率和 * 选小的概率 )
想到了线段树合并, 每个节点表示选当前区间的概率的和, 边往下走边统计和
也就是 , 最后乘到 i 上面去
怎么统计呢 ? 我们记x这颗线段树的贡献为 sx, y这颗的为 sy, 现在合并x,y
如果当前要往左边走,
反之
#include<bits/stdc++.h>
#define N 300050
using namespace std;
int read(){
int cnt = 0, f = 1; char ch = 0;
while(!isdigit(ch)){ ch = getchar(); if(ch == '-') f = -1;}
while(isdigit(ch)) cnt = cnt*10 + (ch-'0'), ch = getchar();
return cnt * f;
}
const int Mod = 998244353, inv = 796898467;
typedef long long ll;
ll add(ll a, ll b){ return (a+b) % Mod;}
ll mul(ll a, ll b){ return (a*b) % Mod;}
int n, m;
int ch[N][2];
ll val[N], b[N], p[N];
int rt[N], tot;
struct Node{ int ls, rs; ll sum, tag;} t[N * 20];
void Pushup(int x){ t[x].sum = add(t[t[x].ls].sum, t[t[x].rs].sum);}
void Insert(int &x, int l, int r, int pos){
if(!x) x = ++tot; t[x].sum = t[x].tag = 1;
if(l == r) return; int mid = (l+r) >> 1;
if(pos <= mid) Insert(t[x].ls, l, mid, pos);
else Insert(t[x].rs, mid+1, r, pos);
}
void Mul(int x, ll v){ t[x].sum = mul(t[x].sum, v); t[x].tag = mul(t[x].tag, v);}
void Pushdown(int x){ if(t[x].tag != 1) Mul(t[x].ls, t[x].tag), Mul(t[x].rs, t[x].tag), t[x].tag = 1; }
int Merge(int x, int y, ll sx, ll sy, ll p){
if(!x){ Mul(y, sx); return y;}
if(!y){ Mul(x, sy); return x;}
Pushdown(x); Pushdown(y);
ll x0 = t[t[x].ls].sum, x1 = t[t[x].rs].sum, y0 = t[t[y].ls].sum, y1 = t[t[y].rs].sum;
t[x].ls = Merge(t[x].ls, t[y].ls, add(sx, mul(x1, add(1, Mod - p))), add(sy, mul(y1, add(1, Mod - p))), p);
t[x].rs = Merge(t[x].rs, t[y].rs, add(sx, mul(x0, p)), add(sy, mul(y0, p)), p);
Pushup(x); return x;
}
int Solve(int u){
if(!ch[u][0]){
Insert(rt[u], 1, m, lower_bound(b+1, b+m+1, val[u]) - b);
return rt[u];
}
else{
int l = Solve(ch[u][0]);
if(!ch[u][1]) return l;
int r = Solve(ch[u][1]);
return Merge(l, r, 0, 0, p[u]);
}
}
ll calc(int x, int l, int r){
if(l == r) return mul(mul(l, b[l]), mul(t[x].sum, t[x].sum));
Pushdown(x); int mid = (l+r) >> 1; ll ans = 0;
ans = add(ans, calc(t[x].ls, l, mid));
ans = add(ans, calc(t[x].rs, mid+1, r));
return ans;
}
int main(){
n = read();
for(int i=1; i<=n; i++){
int x = read(); if(!ch[x][0]) ch[x][0] = i;
else ch[x][1] = i;
}
for(int i=1; i<=n; i++){
if(!ch[i][0]) val[i] = read(), b[++m] = val[i];
else p[i] = read(), p[i] = mul(p[i], inv);
} sort(b+1, b+m+1);
printf("%lld", calc(Solve(1), 1, m));
}