前言
大概是线段树合并类板子题
题意简介
题目大意
给出一棵二叉树,这棵二叉树的每一个叶子节点都有一个权值
w
w
w(每个叶子节点权值互不相同),每一个非叶子节点都有一个概率
p
i
10000
\frac{p_i}{10000}
10000pi取儿子节点的的最大值,而有
1
−
p
i
10000
1-\frac{p_i}{10000}
1−10000pi的概率取儿子节点的最小值,
p
i
p_i
pi为小于
10000
10000
10000的正整数,求
∑
i
=
1
i
∗
v
i
2
∗
s
u
m
i
m
o
d
  
998244353
\sum_{i=1}i*v_i^2*sum_i\mod998244353
i=1∑i∗vi2∗sumimod998244353
v
i
v_i
vi是可能取到的第
i
i
i小的权值,
s
u
m
i
sum_i
sumi是取到第
i
i
i小的权值概率
数据范围
1
≤
n
≤
300000
,
1
≤
w
i
≤
1000000000
,
0
<
p
i
<
10000
1\le n\le300000,1\le w_i\le 1000000000,0<p_i<10000
1≤n≤300000,1≤wi≤1000000000,0<pi<10000
保证所有节点的权值均不同,任意一个节点至多两个子节点
前置知识(线段树合并)
题解
我就不写考场上的心路历程了,游记里都写了
我们发现我们可以用线段树维护每种权值取到的概率
但是我们发现每一个位置的转移并不好做(感觉五五开,只是要
Θ
(
n
l
o
g
2
n
)
\Theta(nlog^2n)
Θ(nlog2n),因为要启发式合并,启发式合并真的比这个难写吗?)
我们考虑更改定义,把
x
x
x节点的定义为取到权值
≥
x
\ge x
≥x的概率
如果我们求出了根节点的线段树,直接遍历一遍整棵树就可以输出答案了
我们考虑怎么求
我这里有个很烦的方法
假如当前我们要求当前节点取到
≥
x
\ge x
≥x的概率
设左子树取到
≥
x
\ge x
≥x的概率为
L
L
L,设右子树取到
≥
x
\ge x
≥x的概率为
R
R
R,当前节点取子树大值的概率为
P
P
P
我们列出式子
A
n
s
=
P
(
L
R
+
L
(
1
−
R
)
+
(
1
−
L
)
R
)
+
(
1
−
P
)
L
R
=
P
(
L
R
+
L
−
L
R
+
R
−
L
R
)
+
L
R
−
P
L
R
=
P
(
L
+
R
−
L
R
)
+
L
R
−
P
L
R
=
P
L
+
P
R
−
2
P
L
R
+
L
R
\begin{aligned} Ans&=P(LR+L(1-R)+(1-L)R)+(1-P)LR\\ &=P(LR+L-LR+R-LR)+LR-PLR\\ &=P(L+R-LR)+LR-PLR\\ &=PL+PR-2PLR+LR \end{aligned}
Ans=P(LR+L(1−R)+(1−L)R)+(1−P)LR=P(LR+L−LR+R−LR)+LR−PLR=P(L+R−LR)+LR−PLR=PL+PR−2PLR+LR
我们发现,如果现在对于一些不同的
L
L
L和一个
R
R
R求结果(如果是不同的
R
R
R和一个
L
L
L本质相同)
我们要把式子化成和
L
L
L有关的
A
n
s
=
L
(
P
+
R
−
2
P
R
)
+
P
R
Ans=L(P+R-2PR)+PR
Ans=L(P+R−2PR)+PR
直接线段树维护区间乘和区间加即可
注意!!!线段树合并的这句话千万不能忘(我因为这个调了很久)
if(x==Null&&y==Null)return Null;
我写的写法真的非常不清真
代码
#include<cstdio>
#include<cctype>
#include<algorithm>
#include<vector>
namespace fast_IO
{
const int IN_LEN=10000000,OUT_LEN=10000000;
char ibuf[IN_LEN],obuf[OUT_LEN],*ih=ibuf+IN_LEN,*oh=obuf,*lastin=ibuf+IN_LEN,*lastout=obuf+OUT_LEN-1;
inline char getchar_(){return (ih==lastin)&&(lastin=(ih=ibuf)+fread(ibuf,1,IN_LEN,stdin),ih==lastin)?EOF:*ih++;}
inline void putchar_(const char x){if(oh==lastout)fwrite(obuf,1,oh-obuf,stdout),oh=obuf;*oh++=x;}
inline void flush(){fwrite(obuf,1,oh-obuf,stdout);}
}
using namespace fast_IO;
#define getchar() getchar_()
#define putchar(x) putchar_((x))
#define rg register
typedef long long LL;
template <typename T> inline T max(const T a,const T b){return a>b?a:b;}
template <typename T> inline T min(const T a,const T b){return a<b?a:b;}
template <typename T> inline void mind(T&a,const T b){a=a<b?a:b;}
template <typename T> inline void maxd(T&a,const T b){a=a>b?a:b;}
template <typename T> inline T abs(const T a){return a>0?a:-a;}
template <typename T> inline void swap(T&a,T&b){T c=a;a=b;b=c;}
template <typename T> inline T gcd(const T a,const T b){if(!b)return a;return gcd(b,a%b);}
template <typename T> inline T lcm(const T a,const T b){return a/gcd(a,b)*b;}
template <typename T> inline T square(const T x){return x*x;};
template <typename T> inline void read(T&x)
{
char cu=getchar();x=0;bool fla=0;
while(!isdigit(cu)){if(cu=='-')fla=1;cu=getchar();}
while(isdigit(cu))x=x*10+cu-'0',cu=getchar();
if(fla)x=-x;
}
template <typename T> inline void printe(const T x)
{
if(x>=10)printe(x/10);
putchar(x%10+'0');
}
template <typename T> inline void print(const T x)
{
if(x<0)putchar('-'),printe(-x);
else printe(x);
}
const int maxn=300001,mod=998244353,INV=796898467;
inline int Md(const int x){return x>=mod?x-mod:x;}
int n,fa[maxn],p[maxn],son[maxn];
std::vector<int>E[maxn];
int lsh[maxn],top;
struct node
{
node*lson,*rson;
int val,mark,add;
}P[maxn*100],*root[maxn],empty,*Null;
int usd;
inline node*newnode()
{
usd++,P[usd].lson=P[usd].rson=Null,P[usd].mark=1;
return &P[usd];
}
int lasx,lasy;
void down(node*x)
{
if(x->mark==1&&x->add==0)return;
if(x->lson!=Null)
{
x->lson->mark=(LL)x->lson->mark*x->mark%mod;
x->lson->val=((LL)x->lson->val*x->mark+x->add)%mod;
x->lson->add=((LL)x->lson->add*x->mark+x->add)%mod;
}
if(x->rson!=Null)
{
x->rson->mark=(LL)x->rson->mark*x->mark%mod;
x->rson->val=((LL)x->rson->val*x->mark+x->add)%mod;
x->rson->add=((LL)x->rson->add*x->mark+x->add)%mod;
}
x->mark=1,x->add=0;
}
int GG;
node*merge(node*x,node*y,const int G)
{
if(x==Null&&y==Null)return Null;
if(x==Null)
{
const int Val=Md(Md(lasx+G)+mod-(LL)2*lasx*G%mod);
const int Add=(LL)lasx*G%mod;
lasy=y->val;
y->mark=(LL)y->mark*Val%mod;
y->add=((LL)y->add*Val+Add)%mod;
y->val=((LL)y->val*Val+Add)%mod;
return y;
}
if(y==Null)
{
const int Val=Md(Md(lasy+G)+mod-(LL)2*lasy*G%mod);
const int Add=(LL)lasy*G%mod;
lasx=x->val;
x->mark=(LL)x->mark*Val%mod;
x->add=((LL)x->add*Val+Add)%mod;
x->val=((LL)x->val*Val+Add)%mod;
return x;
}
down(x),down(y);
x->rson=merge(x->rson,y->rson,G);
x->lson=merge(x->lson,y->lson,G);
if(x->lson!=Null)x->val=x->lson->val;
else x->val=x->rson->val;
return x;
}
void insert(node*rt,const int l,const int r,const int wan)
{
if(l==r){rt->val=1;return;}
const int mid=(l+r)>>1;
if(wan<=mid)rt->lson=newnode(),insert(rt->lson,l,mid,wan);
else rt->rson=newnode(),insert(rt->rson,mid+1,r,wan);
rt->val=1;
}
void dfs(const int u)
{
if(son[u])
{
root[u]=Null;
for(std::vector<int>::iterator Pos=E[u].begin();Pos!=E[u].end();Pos++)
{
dfs(*Pos);GG=u;
node*NXT=root[*Pos];
if(root[u]==Null)root[u]=NXT;
else lasx=lasy=0,root[u]=merge(NXT,root[u],(LL)p[u]*INV%mod);
}
}
else root[u]=newnode(),insert(root[u],1,1000000000,p[u]);
}
int las,tot,ans;
void calc(node*rt,const int l,const int r)
{
if(l==r)
{
const int me=Md(rt->val+mod-las);
las=rt->val;
ans=(ans+(LL)tot*l%mod*me%mod*me)%mod;
tot--;
return;
}
down(rt);
const int mid=(l+r)>>1;
if(rt->rson!=Null)calc(rt->rson,mid+1,r);
if(rt->lson!=Null)calc(rt->lson,l,mid);
}
int main()
{
Null=∅
empty.lson=empty.rson=Null;
read(n);
for(rg int i=1;i<=n;i++)read(fa[i]),son[fa[i]]++,E[fa[i]].push_back(i);
for(rg int i=1;i<=n;i++)read(p[i]),tot+=son[i]==0;
dfs(1);
calc(root[1],1,1000000000);
print(ans);
return flush(),0;
}
总结
经典线段树合并题,做完之后让我对线段树合并的认识大大提高了