一、题目
二、解法
考虑用
lct
\text{ lct }
lct 维护这个期望,我们先把柿子展开(分子考虑每个点的贡献,分母用等差数列求和):
e
x
p
=
a
1
×
1
×
n
+
a
2
×
2
×
(
n
−
1
)
+
.
.
.
.
+
a
n
×
n
×
1
n
×
(
n
+
1
)
2
exp=\frac{a_1\times1\times n+a_2\times2\times(n-1)+....+a_n\times n\times 1}{\frac{n\times(n+1)}{2}}
exp=2n×(n+1)a1×1×n+a2×2×(n−1)+....+an×n×1我们维护两个值,
l
s
=
a
1
+
2
a
2
+
.
.
.
+
n
a
n
,
r
s
=
n
a
1
+
(
n
−
1
)
a
2
+
.
.
.
+
a
n
ls=a_1+2a_2+...+na_n,rs=na_1+(n-1)a_2+...+a_n
ls=a1+2a2+...+nan,rs=na1+(n−1)a2+...+an,那么根节点期望也可以维护:
e
x
p
x
=
e
x
p
l
+
e
x
p
r
+
v
a
l
x
×
(
s
i
z
l
+
1
)
×
(
s
i
z
r
+
1
)
+
l
s
l
×
(
s
i
z
r
+
1
)
+
r
s
r
×
(
s
i
z
l
+
1
)
exp_x=exp_l+exp_r+val_x\times(siz_l+1)\times(siz_r+1)+ls_l\times(siz_r+1)+rs_r\times(siz_l+1)
expx=expl+expr+valx×(sizl+1)×(sizr+1)+lsl×(sizr+1)+rsr×(sizl+1)
l
s
,
r
s
ls,rs
ls,rs也比较容易维护:
l
s
=
l
s
l
+
l
s
r
+
s
u
m
r
×
(
s
i
z
l
+
1
)
+
v
a
l
x
×
(
s
i
z
l
+
1
)
ls=ls_l+ls_r+sum_r\times(siz_l+1)+val_x\times(siz_l+1)
ls=lsl+lsr+sumr×(sizl+1)+valx×(sizl+1)
r
s
=
r
s
l
+
r
s
r
+
s
u
m
l
×
(
s
i
z
r
+
1
)
+
v
a
l
x
×
(
s
i
z
r
+
1
)
rs=rs_l+rs_r+sum_l\times(siz_r+1)+val_x\times(siz_r+1)
rs=rsl+rsr+suml×(sizr+1)+valx×(sizr+1)那么我们就完成了一半的工作(
p
u
s
h
_
u
p
push\_up
push_up)
考虑修改,
v
a
l
,
s
u
m
,
l
a
z
y
val,sum,lazy
val,sum,lazy都很好维护,
l
s
ls
ls的维护也不难,用等差数列(
r
s
rs
rs也同理):
l
s
=
l
s
+
p
×
s
i
z
x
(
s
i
z
x
+
1
)
2
ls=ls+p\times\frac{siz_x(siz_x+1)}{2}
ls=ls+p×2sizx(sizx+1)
e
x
p
exp
exp就有点**了,我们设
n
=
s
i
z
x
n=siz_x
n=sizx,一开始有这个柿子:
e
x
p
=
e
x
p
+
p
×
(
1
×
n
+
2
×
(
n
−
1
)
+
.
.
.
+
n
×
1
)
exp=exp+p\times(1\times n+2\times (n-1)+...+n\times1)
exp=exp+p×(1×n+2×(n−1)+...+n×1)问题变成了后面柿子的化简,网上大多是推导,我讲一种无脑方法,就是用拉格朗日插值(我的博客,看公式就行了),先盲猜它通项是一个三次多项式,我们选取四个点插它:
(
0
,
0
)
,
(
1
,
1
)
,
(
2
,
4
)
,
(
3
,
10
)
(0,0),(1,1),(2,4),(3,10)
(0,0),(1,1),(2,4),(3,10),然后带入公式里化简,过程略,给出结果(因式分解后):
n
(
n
+
1
)
(
n
+
2
)
6
\frac{n(n+1)(n+2)}{6}
6n(n+1)(n+2)然后
e
x
p
exp
exp也可以维护了,
lct
\text{lct}
lct 打标记即可解决这个问题,总结一些坑点:
- f l i p flip flip时,不仅左右儿子要交换, l s , r s ls,rs ls,rs也要交换。
- 操作三跳过是要把第三个数个读进来(虽然用不到)。
#include <cstdio>
#include <iostream>
#define lc ch[x][0]
#define rc ch[x][1]
#define int long long
using namespace std;
const int M = 50005;
int read()
{
int x=0,flag=1;char c;
while((c=getchar())<'0' || c>'9') if(c=='-') flag=-1;
while(c>='0' && c<='9') x=(x<<3)+(x<<1)+(c^48),c=getchar();
return x*flag;
}
int n,m,par[M],siz[M],ch[M][2],val[M],exp[M];
int fl[M],st[M],ls[M],rs[M],la[M],sum[M];
int nrt(int x)
{
return ch[par[x]][0]==x || ch[par[x]][1]==x;
}
int chk(int x)
{
return ch[par[x]][1]==x;
}
void push_up(int x)
{
if(!x) return ;
siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+1;
sum[x]=sum[ch[x][0]]+sum[ch[x][1]]+val[x];
ls[x]=ls[lc]+ls[rc]+sum[rc]*(siz[lc]+1)+val[x]*(siz[lc]+1);
rs[x]=rs[lc]+rs[rc]+sum[lc]*(siz[rc]+1)+val[x]*(siz[rc]+1);
exp[x]=exp[lc]+exp[rc]+val[x]*(siz[lc]+1)*(siz[rc]+1)+ls[lc]*(siz[rc]+1)+rs[rc]*(siz[lc]+1);
}
void flip(int x)
{
if(!x) return ;
swap(ch[x][0],ch[x][1]);
swap(ls[x],rs[x]);
fl[x]^=1;
}
void add(int x,int p)
{
if(!x) return ;
val[x]+=p;la[x]+=p;
sum[x]+=p*siz[x];
ls[x]+=p*siz[x]*(siz[x]+1)/2;
rs[x]+=p*siz[x]*(siz[x]+1)/2;
exp[x]+=p*siz[x]*(siz[x]+1)*(siz[x]+2)/6;
}
void push_down(int x)
{
if(!x) return ;
if(fl[x])
{
flip(ch[x][0]);flip(ch[x][1]);
fl[x]=0;
}
if(la[x])
{
add(ch[x][0],la[x]);
add(ch[x][1],la[x]);
la[x]=0;
}
}
void rotate(int x)
{
int y=par[x],z=par[y],k=chk(x),w=ch[x][k^1];
ch[y][k]=w;par[w]=y;
if(nrt(y)) ch[z][chk(y)]=x;par[x]=z;
ch[x][k^1]=y;par[y]=x;
push_up(y);push_up(x);
}
void splay(int x)
{
int y=x,z=0;
st[++z]=y;
while(nrt(y)) st[++z]=y=par[y];
while(z) push_down(st[z--]);
while(nrt(x))
{
int y=par[x],z=par[y];
if(nrt(y))
{
if(chk(y)==chk(x)) rotate(y);
else rotate(x);
}
rotate(x);
}
}
void access(int x)
{
for(int y=0;x;x=par[y=x])
splay(x),ch[x][1]=y,push_up(x);
}
void makeroot(int x)
{
access(x);splay(x);
flip(x);
}
int findroot(int x)
{
access(x);splay(x);
while(ch[x][0]) push_down(x),x=ch[x][0];
splay(x);
return x;
}
void split(int x,int y)
{
makeroot(x);
access(y);splay(y);
}
void link(int x,int y)
{
makeroot(x);
if(findroot(y)!=x) par[x]=y;
}
void cut(int x,int y)
{
makeroot(x);
if(findroot(y)==x && par[y]==x && !ch[y][0])
{
par[y]=ch[x][1]=0;
push_up(x);
}
}
int gcd(int a,int b)
{
return !b?a:gcd(b,a%b);
}
signed main()
{
n=read();m=read();
for(int i=1;i<=n;i++)
{
val[i]=read();
push_up(i);
}
for(int i=1;i<n;i++)
link(read(),read());
while(m--)
{
int op=read(),u=read(),v=read();
if(op==1)
cut(u,v);
if(op==2)
link(u,v);
if(op==3)
{
if(findroot(u)^findroot(v))
{
read();//一定要读。。
continue ;
}
split(u,v);
add(v,read());
}
if(op==4)
{
if(findroot(u)^findroot(v))
{
puts("-1");
continue ;
}
split(u,v);
int a=exp[v],b=siz[v]*(siz[v]+1)/2,t=gcd(a,b);
a/=t;b/=t;
printf("%lld/%lld\n",a,b);
}
}
}