dfs序+线段树+LCA。
难点:更换树的根节点后,如何用线段树更新一子树和查询一子树。
更新时:
如果当前的根节点不在要修改的子树中时,直接成段更新区间即可。
若当前的根节点在要修改的子树中时,有两种情况:
先更新【1~n】整个区间的值
①(root既不是u的祖先也不是v的祖先)时,多修改了一部分,把多修改的那部分该回来,即u,v节点中包含root节点的那颗子树。
②(与①条件相反)时,无须多做改动。
查询时:与更新的情况类似。
#include"bits/stdc++.h"
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
using namespace std;
typedef long long LL;
const int MAXN = 1e5+7;
struct noded {
int l,r;
} p[MAXN];
LL sum[MAXN<<2],lazy[MAXN<<2];
int a[MAXN],val[MAXN],dep[MAXN],f[MAXN][20];
int tt,n,m,root;
vector <int> G[MAXN];
void dfs(int deep, int u, int fa) {
p[u].l = ++tt;
val[tt] = a[u];
dep[u] = deep; f[u][0] = fa;
for(int i = 0; i < 18; i++)
f[u][i+1] = f[f[u][i]][i];
for(int i = 0; i < G[u].size(); i++) {
int v = G[u][i];
if(v == fa) continue;
dfs(deep+1,v,u);
}
p[u].r = tt;
}
int lca(int x, int y)
{
int i;
if(dep[x] < dep[y]) swap(x,y);
for(int i = 18; i >= 0; i--)
if(dep[f[x][i]] >= dep[y])
x = f[x][i];
if(x == y) return x;
for(int i = 18; i >= 0; i--)
if(f[x][i] != f[y][i])
x = f[x][i], y = f[y][i];
return f[x][0];
}
int get_son(int x, int y)
{
int i;
if(dep[x] < dep[y]) swap(x,y);
for(int i = 18; i >= 0; i--)
if(dep[f[x][i]] >= dep[y]+1)
x = f[x][i];
return x;
}
inline void pushup(int rt)
{
sum[rt] = sum[rt<<1] + sum[rt<<1|1];
}
inline void pushdown(int rt, int l)
{
if(!lazy[rt]) return;
lazy[rt<<1] += lazy[rt];
lazy[rt<<1|1] += lazy[rt];
sum[rt<<1] += lazy[rt]*(l - (l>>1));
sum[rt<<1|1] += lazy[rt]*(l>>1);
lazy[rt] = 0;
}
void build(int l, int r, int rt)
{
if(l == r) {
sum[rt] = val[l];
return;
}
int mid = l+r>>1;
build(lson);
build(rson);
pushup(rt);
}
void updata(int L, int R,LL c,int l, int r, int rt)
{
if(L <= l && r <= R){
sum[rt] += c*(r-l+1);
lazy[rt] += c;
return;
}
pushdown(rt,r-l+1);
int mid = l+r>>1;
if(L <= mid) updata(L,R,c,lson);
if(mid < R) updata(L,R,c,rson);
pushup(rt);
}
LL query(int L, int R, int l, int r, int rt)
{
if(L <= l && r <= R){
return sum[rt];
}
pushdown(rt,r-l+1);
int mid = l+r>>1;
LL ret = 0;
if(L <= mid) ret += query(L,R,lson);
if(mid < R) ret += query(L,R,rson);
return ret;
}
void work2(int u, int v, int c)
{
int fa = lca(u,v);
if(fa == root){
updata(1,n,c,1,n,1);
return;
}
if(root != 1){
if(dep[fa] < dep[root]){
int son = get_son(root,fa);
if(f[son][0] == fa){
updata(1,n,c,1,n,1);
int lcau = lca(root,u), lcav = lca(root,v);
int lcax = dep[lcau] > dep[lcav] ? lcau : lcav;
if(dep[root] > dep[lcax]){
son = get_son(root,lcax);
updata(p[son].l,p[son].r,-c,1,n,1);
}
}
else updata(p[fa].l,p[fa].r,c,1,n,1);
}
else updata(p[fa].l,p[fa].r,c,1,n,1);
}
else updata(p[fa].l,p[fa].r,c,1,n,1);
}
LL work3(int u)
{
if(u == root){
return query(1,n,1,n,1);
}
if(root != 1){
if(dep[u] < dep[root]){
int son = get_son(root,u);
if(f[son][0] == u){
return query(1,n,1,n,1)-query(p[son].l,p[son].r,1,n,1);
}
else return query(p[u].l,p[u].r,1,n,1);
}
else return query(p[u].l,p[u].r,1,n,1);
}
else return query(p[u].l,p[u].r,1,n,1);
}
int main() {
#ifdef __LOCAL__
freopen("input.txt","r",stdin);
#endif // __LOCAL__
root = 1;
scanf("%d%d",&n,&m);
for(int i = 1; i <= n; i++)
scanf("%d",&a[i]);
for(int i = 1; i < n; i++) {
int u,v;
scanf("%d%d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,1,0);
build(1,n,1);
for(int i = 1; i <= m; i++){
int op,u,v,c;
scanf("%d",&op);
if(op == 1) scanf("%d",&root);
if(op == 2){
scanf("%d%d%d",&u,&v,&c);
work2(u,v,c);
}
if(op == 3){
scanf("%d",&u);
printf("%I64d\n",work3(u));
}
}
return 0;
}