题目大意:
给你
n(n<=3∗105)
n
(
n
<=
3
∗
10
5
)
个点的一棵树,初始根为
1
1
,支持种操作。
1.
1.
把根换成
x
x
。
把
x
x
,两点的
lca
l
c
a
的子数每个点权值
+x
+
x
。
3.
3.
询问以
x
x
为根的子树权值和。
分析:
主要是换根操作。先以为根跑
dfs
d
f
s
,然后对两点
lca
l
c
a
分类讨论。
如果
x
x
和都为
root
r
o
o
t
的儿子,且
lca
l
c
a
不为
root
r
o
o
t
,那么给他们的子树直接加。
如果一个是
root
r
o
o
t
的儿子,另一个不是,或者
lca
l
c
a
为
root
r
o
o
t
,给整棵树加权值。
如果两个都不是
root
r
o
o
t
儿子的,且
lca
l
c
a
不是
root
r
o
o
t
的父亲,直接给子树加。
如果两个都不是
root
r
o
o
t
儿子的,
lca
l
c
a
是
root
r
o
o
t
的父亲,此时找到
x
x
和和
y
y
与的
lca
l
c
a
,设深度大的那个
lca
l
c
a
为
d
d
,那么这棵树除了中包含
root
r
o
o
t
的子树,其他的都加权值。线段树维护即可。
代码:
#include <iostream>
#include <cmath>
#include <cstdio>
#define LL long long
const int maxn=3e5+7;
using namespace std;
struct edge{
int y,next;
}g[maxn*2];
struct node{
LL lazy;
LL sum;
}t[maxn*4];
int n,test,x,y,cnt,root,op;
int ls[maxn],dfn[maxn],last[maxn],dep[maxn];
LL a[maxn],k;
int f[maxn][20];
void add(int x,int y)
{
g[++cnt]=(edge){y,ls[x]};
ls[x]=cnt;
}
void dfs(int x,int fa)
{
dfn[x]=++cnt;
last[x]=dfn[x];
f[x][0]=fa;
for (int i=ls[x];i>0;i=g[i].next)
{
int y=g[i].y;
if (y==fa) continue;
dep[y]=dep[x]+1;
dfs(y,x);
last[x]=max(last[x],last[y]);
}
}
void clean(int p,int l,int r)
{
if (t[p].lazy)
{
int mid=(l+r)/2;
t[p*2].lazy+=t[p].lazy;
t[p*2].sum+=(LL)(mid-l+1)*t[p].lazy;
t[p*2+1].lazy+=t[p].lazy;
t[p*2+1].sum+=(LL)(r-mid)*t[p].lazy;
t[p].lazy=0;
}
}
void ins(int p,int l,int r,int x,int y,LL k)
{
if ((l==x) && (r==y))
{
t[p].sum+=(LL)(r-l+1)*k;
t[p].lazy+=k;
return;
}
int mid=(l+r)/2;
clean(p,l,r);
if (y<=mid) ins(p*2,l,mid,x,y,k);
else if (x>mid) ins(p*2+1,mid+1,r,x,y,k);
else
{
ins(p*2,l,mid,x,mid,k);
ins(p*2+1,mid+1,r,mid+1,y,k);
}
t[p].sum=t[p*2].sum+t[p*2+1].sum;
}
LL getsum(int p,int l,int r,int x,int y)
{
if ((l==x) && (r==y)) return t[p].sum;
int mid=(l+r)/2;
clean(p,l,r);
if (y<=mid) return getsum(p*2,l,mid,x,y);
else if (x>mid) return getsum(p*2+1,mid+1,r,x,y);
else
{
return getsum(p*2,l,mid,x,mid)+getsum(p*2+1,mid+1,r,mid+1,y);
}
}
bool isfa(int x,int y)
{
return ((dfn[x]<=dfn[y]) && (last[x]>=last[y]));
}
int up(int x,int d)
{
int k=19,t=1<<19;
while (d)
{
if (d>=t) x=f[x][k],d-=t;
t/=2; k--;
}
return x;
}
int lca(int x,int y)
{
if (dep[x]>dep[y]) swap(x,y);
int d=dep[y]-dep[x];
y=up(y,d);
if (x==y) return x;
int k=19;
while (k>=0)
{
if (f[x][k]!=f[y][k])
{
x=f[x][k];
y=f[y][k];
}
k--;
}
return f[x][0];
}
int main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
scanf("%d%d",&n,&test);
for (int i=1;i<=n;i++) scanf("%lld",&a[i]);
for (int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
cnt=0;
dfs(1,0);
for (int j=1;j<20;j++)
{
for (int i=1;i<=n;i++)
{
f[i][j]=f[f[i][j-1]][j-1];
}
}
root=1;
for (int i=1;i<=n;i++) ins(1,1,n,dfn[i],dfn[i],a[i]);
for (int i=1;i<=test;i++)
{
scanf("%d",&op);
if (op==1) scanf("%d",&root);
if (op==2)
{
scanf("%d%d%lld",&x,&y,&k);
if (isfa(root,x) && isfa(root,y))
{
int d=lca(x,y);
if (d==root) ins(1,1,n,1,n,k);
else ins(1,1,n,dfn[d],last[d],k);
}
else
{
if (isfa(root,x) || isfa(root,y)) ins(1,1,n,1,n,k);
else
{
int d=lca(x,y);
if (!isfa(d,root))
{
ins(1,1,n,dfn[d],last[d],k);
}
else
{
int d1=lca(x,root);
int d2=lca(y,root);
if (dep[d1]<dep[d2]) swap(d1,d2);
int c=up(root,dep[root]-dep[d1]-1);
if (dfn[c]-1>=1) ins(1,1,n,1,dfn[c]-1,k);
if (last[c]+1<=n) ins(1,1,n,last[c]+1,n,k);
}
}
}
}
if (op==3)
{
scanf("%d",&x);
if (isfa(root,x))
{
if (x==root) printf("%lld\n",getsum(1,1,n,1,n));
else printf("%lld\n",getsum(1,1,n,dfn[x],last[x]));
}
else
{
if (!isfa(x,root)) printf("%lld\n",getsum(1,1,n,dfn[x],last[x]));
else
{
int d=lca(x,root);
int c=up(root,dep[root]-dep[d]-1);
LL ans=0;
if (dfn[c]-1>=1) ans+=getsum(1,1,n,1,dfn[c]-1);
if (last[c]+1<=n) ans+=getsum(1,1,n,last[c]+1,n);
printf("%lld\n",ans);
}
}
}
}
}