[loj2537] Minimax
用线段树合并优化dp做到nlogn
- 代码
#include<bits/stdc++.h>
using namespace std;
const int N=3e5+5;
const int mod=998244353;
typedef long long ll;
inline int add(int a,int b){a+=b;return a>=mod?a-mod:a;}
inline int sub(int a,int b){a-=b;return a<0?a+mod:a;}
inline int mul(int a,int b){return (ll)a*b%mod;}
inline int qpow(int a,int b){int ret=1;while(b){if(b&1)ret=mul(ret,a);b>>=1,a=mul(a,a);}return ret;}
int p[N],fa[N],val[N],ch[N][2],id[N];
bool lve[N];
int tot=0;
typedef pair<int,int> pii;
using namespace std;
pii V[N];
int n;
/*-------------------------------------------*/
int ans=0;
struct segT{
int root[N],lch[N*30],rch[N*30],t[N*30],tag[N*30],cnt;
void init(int id){lch[id]=rch[id]=t[id]=0,tag[id]=1;}
inline void ins(int &x,int l,int r,int pos,int val){
if(!x)x=init(++cnt);int mid=(l+r)>>1;
t[x]=add(t[x],val);
if(l==r) return;
if(pos>mid)ins(rch[x],mid+1,r,pos,val);
else ins(lch[x],l,mid,pos,val);
}
inline void pushdown(int x){
if(!x)return;
if(tag[x]==1)return ;
t[x]=mul(t[x],tag[x]);
if(lch[x])tag[lch[x]]=mul(tag[lch[x]],tag[x]);
if(rch[x])tag[rch[x]]=mul(tag[rch[x]],tag[x]);
tag[x]=1;
}
inline void pushup(int x){
pushdown(x);
if(!lch[x]&&!rch[x])return;
t[x]=0;
if(lch[x])pushdown(lch[x]),t[x]=add(t[x],t[lch[x]]);
if(rch[x])pushdown(rch[x]),t[x]=add(t[x],t[rch[x]]);
}
inline int merge(int x,int y,int ch_max,int ch_min,int sum_x,int sum_y){
if(!x&&!y)return 0;
if(!x) {tag[y]=mul(tag[y],sum_y);return y;}
if(!y) {tag[x]=mul(tag[x],sum_x);return x;}
pushdown(x),pushdown(y);
int y1=mul(t[rch[y]],tag[rch[y]]),y0=mul(t[lch[y]],tag[lch[y]]),x1=mul(t[rch[x]],tag[rch[x]]),x0=mul(t[lch[x]],tag[lch[x]]);
lch[x]=merge(lch[x],lch[y],ch_max,ch_min,add(sum_x,mul(ch_min,y1)),add(sum_y,mul(ch_min,x1)));
rch[x]=merge(rch[x],rch[y],ch_max,ch_min,add(sum_x,mul(ch_max,y0)),add(sum_y,mul(ch_max,x0)));
pushup(x);
return x;
}
inline void print(int x,int l,int r){
if(!x)return ;
pushdown(x);
int mid=(l+r)>>1;
if(l==r){
ans=add(ans,mul(mul(l,V[l].first),mul(t[x],t[x])));
return ;
}
print(lch[x],l,mid);print(rch[x],mid+1,r);
}
}T;
void dfs(int x){
if(!lve[x]) return;
if(ch[x][0])dfs(ch[x][0]); if(ch[x][1])dfs(ch[x][1]);
int p1=mul(val[x],qpow(10000,mod-2));
int p2=sub(1,p1);
if(!ch[x][1])T.root[x]=T.root[ch[x][0]];
else T.root[x]=T.merge(T.root[ch[x][0]],T.root[ch[x][1]],p1,p2,0,0);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&fa[i]);
if(fa[i]){
if(ch[fa[i]][0]==0) ch[fa[i]][0]=i;
else ch[fa[i]][1]=i;
lve[fa[i]]=true;
}
}
for(int i=1;i<=n;i++){
scanf("%d",&val[i]);
if(!lve[i])
V[++tot]=pii(val[i],i);
}
sort(V+1,V+tot+1);
for(int i=1;i<=tot;i++) val[V[i].second]=i;
for(int i=1;i<=n;i++)if(!lve[i])
T.ins(T.root[i],1,tot,val[i],1);
dfs(1);
T.print(T.root[1],1,tot);
printf("%d\n",ans);
}