方法: LCT
解析:
如果bzoj 2752没做过的话推荐先去做那个题。
这道题的话,前三问太裸了没什么说的。
第四种操作的话根据之前那个2752的题我们只需要考虑对于一个根节点,它的左儿子的期望值已经知道了,右儿子的期望值也已经知道了,怎么维护该节点的期望值。
popoqqq爷讲的非常清晰!好评
如果懒得点链接的话那我就稍微说一下。(就按照大爷说的栗子)
以下如果您做过2752的话应该能看懂。
对于左子树
假设有四个节点那么期望的贡献是什么
1*4*a1+2*3*a2+3*2*a3+4*1*a4
再对于右子树
假设有两个节点那么期望的贡献是什么
1*2*b1+2*1*b2
不妨设根的值为val
那么显然合并之后是什么呢
1*7*a1+2*6*a2+3*5*a3+4*4*a4+5*3*val+6*2*b1+7*1*b2
然后用这个总期望的前4项减掉原来的左子树期望
3*(a1+2*a2+3*a3+4*a4)->3为右子树size+1
后两项减去原来的右子树期望
5*(2*b1+b2);->5为左子树size+1
中间的val呢?
5*3*val->5为左子树size+1,3为右子树size+1
所以我们要维护什么?
第一种:普通的sum
第二种:siz
第三种:1*a1+2*a2+3*a3+4*a4+…+n*an
第四种:n*a1+(n-1)*a2+(n-2)*a3+…+1*an
怎么更新?数学公式啊!
还有一个公式您一定会用到
1*n+2*(n-1)+3*(n-2)+…+n*1=n*(n+1)*(n+2)/6
好题!
代码:
using namespace std;
typedef unsigned long long ll;
ll sum1n[N];
ll sumn1[N];
ll sum[N];
ll siz[N];
int ch[N][2];
ll exp[N];
ll col[N];
ll val[N];
int rt[N];
int fa[N];
int rev[N];
int head[N];
int n,m,cnt;
struct node
{
int to,next;
}edge[N<<1];
ll gcd(ll a,ll b)
{
while(b)
{
ll t=b;
b=a%b;
a=t;
}
return a;
}
void pushup(int x)
{
if(!x)return;
siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+1;
sum[x]=sum[ch[x][0]]+sum[ch[x][1]]+val[x];
sum1n[x]=sum1n[ch[x][0]]+sum1n[ch[x][1]]+(siz[ch[x][0]]+1)*(sum[ch[x][1]]+val[x]);
sumn1[x]=sumn1[ch[x][1]]+sumn1[ch[x][0]]+(siz[ch[x][1]]+1)*(sum[ch[x][0]]+val[x]);
exp[x]=exp[ch[x][0]]+exp[ch[x][1]]+(siz[ch[x][1]]+1)*sum1n[ch[x][0]]+(siz[ch[x][0]]+1)*sumn1[ch[x][1]]+val[x]*(siz[ch[x][0]]+1)*(siz[ch[x][1]]+1);
}
void reverse(int x)
{
swap(ch[x][0],ch[x][1]);
swap(sum1n[x],sumn1[x]);
rev[x]^=1;
}
void pushdown(int x)
{
if(!x)return;
if(rev[x])
{
reverse(ch[x][0]);
reverse(ch[x][1]);
rev[x]=0;
}
if(col[x])
{
if(ch[x][0]!=0)
{
val[ch[x][0]]+=col[x];
col[ch[x][0]]+=col[x];
sum[ch[x][0]]+=col[x]*siz[ch[x][0]];
sum1n[ch[x][0]]+=((1+siz[ch[x][0]])*siz[ch[x][0]]*col[x])/2;
sumn1[ch[x][0]]+=((1+siz[ch[x][0]])*siz[ch[x][0]]*col[x])/2;
exp[ch[x][0]]+=(siz[ch[x][0]]*(siz[ch[x][0]]+1)*(siz[ch[x][0]]+2)*col[x])/6;
}
if(ch[x][1]!=0)
{
val[ch[x][1]]+=col[x];
col[ch[x][1]]+=col[x];
sum[ch[x][1]]+=col[x]*siz[ch[x][1]];
sum1n[ch[x][1]]+=((1+siz[ch[x][1]])*siz[ch[x][1]]*col[x])/2;
sumn1[ch[x][1]]+=((1+siz[ch[x][1]])*siz[ch[x][1]]*col[x])/2;
exp[ch[x][1]]+=(siz[ch[x][1]]*(siz[ch[x][1]]+1)*(siz[ch[x][1]]+2)*col[x])/6;
}
col[x]=0;
}
}
void down(int x)
{
if(!rt[x])down(fa[x]);
pushdown(x);
}
void rotate(int x)
{
int y=fa[x],kind=ch[y][1]==x;
ch[y][kind]=ch[x][!kind];
fa[ch[y][kind]]=y;
ch[x][!kind]=y;
fa[x]=fa[y];
fa[y]=x;
if(rt[y])rt[y]=0,rt[x]=1;
else ch[fa[x]][ch[fa[x]][1]==y]=x;
pushup(y);
}
void splay(int x)
{
down(x);
while(!rt[x])
{
int y=fa[x],z=fa[y];
if(rt[y])rotate(x);
else if((ch[y][1]==x)==(ch[z][1]==y))rotate(y),rotate(x);
else rotate(x),rotate(x);
}
pushup(x);
}
void access(int x)
{
int y=0;
while(x)
{
splay(x);
rt[ch[x][1]]=1,rt[y]=0;
ch[x][1]=y;
pushup(x);
y=x,x=fa[x];
}
}
int find_root(int x)
{
while(fa[x])x=fa[x];
return x;
}
void movetoroot(int x)
{
access(x);
splay(x);
reverse(x);
}
void link(int x,int y)
{
if(find_root(x)==find_root(y))return;
movetoroot(x);movetoroot(y);
fa[x]=y;
}
void cut(int x,int y)
{
if(find_root(x)!=find_root(y))return;
movetoroot(x);
access(y);
splay(y);
if(ch[y][0]==x&&ch[x][1]==0)
{
fa[x]=0;
ch[y][0]=0;
rt[x]=1;
pushup(y);
}
}
void dfs(int now,int fffa)
{
fa[now]=fffa;
for(int i=head[now];i!=-1;i=edge[i].next)
{
int to=edge[i].to;
if(to==fffa)continue;
dfs(to,now);
}
}
void initedge()
{
memset(head,-1,sizeof(head));
cnt=1;
}
void edgeadd(int from,int to)
{
edge[cnt].to=to;
edge[cnt].next=head[from];
head[from]=cnt++;
}
void init()
{
initedge();
for(int i=1;i<=n;i++)
{
scanf("%lld",&val[i]);
sum[i]=sum1n[i]=sumn1[i]=exp[i]=val[i];
siz[i]=rt[i]=1;
}
int x,y;
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
edgeadd(x,y);edgeadd(y,x);
}
dfs(1,0);
}
void update(int x,int y,ll z)
{
if(find_root(x)!=find_root(y))return;
movetoroot(x);
access(y);
splay(y);
col[y]+=z;
val[y]+=z;
sum[y]+=z*siz[y];
sum1n[y]+=((1+siz[y])*siz[y]*z)/2;
sumn1[y]+=((1+siz[y])*siz[y]*z)/2;
exp[y]+=(siz[y]*(siz[y]+1)*(siz[y]+2)*z)/6;
pushdown(y);
}
void query(int x,int y)
{
if(find_root(x)!=find_root(y))
{
puts("-1");
return;
}
movetoroot(x);
access(y);
splay(y);
ll ans=exp[y];
ll tmpcnt=(siz[y]*(siz[y]+1))/2;
ll tmpgcd=gcd(ans,tmpcnt);
printf("%llu/%llu\n",ans/tmpgcd,tmpcnt/tmpgcd);
}
int main()
{
scanf("%d%d",&n,&m);
init();
for(int i=1;i<=m;i++)
{
int jd,x,y;
ll z;
scanf("%d%d%d",&jd,&x,&y);
switch(jd)
{
case 1:cut(x,y);break;
case 2:link(x,y);break;
case 3:
scanf("%llu",&z);
update(x,y,z);
break;
case 4:
query(x,y);
break;
}
}
}