题意:
一棵有根树,初始根为1,点有权值,有5种操作:
换根,链加,子树加,查询链和,查询子树和
数据范围
节 点 数 , 操 作 数 , 点 权 ≤ 1 e 5 节点数,操作数,点权\le 1e5 节点数,操作数,点权≤1e5
解法
树链剖分,对于后面4个操作都很好维护,然后考虑换根带来的影响,对于链上的操作是没有影响的,主要是子树操作,可以发现,如果根和查询的点x是同一个点,就是直接查询全树的和,如果根在x的子树外,就仍然查询这个子树。比较麻烦的是根在子树内,这时需要用全树的和减去根到x的这段路径上距离查询点最近的一个点p的子树。
这里解释的很到位
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+5;
inline int read(){
char c=getchar();int t=0,f=1;
while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
return t*f;
}
int n,a[maxn],fa[maxn];
struct edge{
int v,p;
}e[maxn<<1];
int h[maxn],cnt;
inline void add(int a,int b){
e[++cnt].p=h[a];
e[cnt].v=b;
h[a]=cnt;
e[++cnt].p=h[b];
e[cnt].v=a;
h[b]=cnt;
}
int sz[maxn],id[maxn],dfn,son[maxn],dep[maxn],top[maxn];
struct node{
ll sum,tag;
}t[maxn<<2];
int st[maxn][20];
void dfs1(int u,int fa){
sz[u]=1,dep[u]=dep[fa]+1;st[u][0]=fa;
for(int i=1;st[u][i-1];i++)st[u][i]=st[st[u][i-1]][i-1];
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa)continue;
dfs1(v,u);sz[u]+=sz[v];
if(sz[son[u]]<sz[v]){son[u]=v;}
}
}
int pos[maxn];
void dfs2(int u,int tp,int fa){
id[u]=++dfn;top[u]=tp;pos[dfn]=u;
if(son[u]){
dfs2(son[u],tp,u);
}
for(int i=h[u];i;i=e[i].p){
int v=e[i].v;
if(v==fa||v==son[u])continue;
dfs2(v,v,u);
}
}
#define lc rt<<1
#define rc rt<<1|1
inline void pushup(int rt){
t[rt].sum=t[lc].sum+t[rc].sum;
}
void build(int rt,int l,int r){
if(l==r){
t[rt].sum=a[pos[l]];
return ;
}
int mid=l+r>>1;
build(lc,l,mid);build(rc,mid+1,r);
pushup(rt);
}
int m,rt;
inline void pushdown(int rt,int l,int r){
int mid=l+r>>1;
t[lc].tag+=t[rt].tag;
t[rc].tag+=t[rt].tag;
t[lc].sum+=(mid-l+1)*t[rt].tag;
t[rc].sum+=(r-mid)*t[rt].tag;
t[rt].tag=0;
}
void adt(int rt,int l,int r,int x,int y,int val){
if(x<=l&&r<=y){
t[rt].sum+=(r-l+1)*val;
t[rt].tag+=val;
return ;
}
if(t[rt].tag)pushdown(rt,l,r);
int mid=l+r>>1;
if(x<=mid)adt(lc,l,mid,x,y,val);
if(y>mid)adt(rc,mid+1,r,x,y,val);
pushup(rt);
}
void add(int x,int y,int val){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]){
swap(x,y);
}
adt(1,1,n,id[top[x]],id[x],val);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
adt(1,1,n,id[x],id[y],val);
}
inline int find(int rt,int x){
for(int j=19;j>=0;j--){
if(dep[st[x][j]]>dep[rt]){x=st[x][j];}
}
return x;
}
ll quert(int rt,int l,int r,int x,int y){
if(x<=l&&r<=y){
return t[rt].sum;
}
pushdown(rt,l,r);
int mid=l+r>>1;
ll ans=0;
if(x<=mid)ans=ans+quert(lc,l,mid,x,y);
if(y>mid)ans=ans+quert(rc,mid+1,r,x,y);
return ans;
}
inline ll query(int x,int y){
ll ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]])swap(x,y);
ans=ans+quert(1,1,n,id[top[y]],id[y]);
y=fa[top[y]];
}
if(dep[x]>dep[y])swap(x,y);
ans=ans+quert(1,1,n,id[x],id[y]);
return ans;
}
int main(){
//freopen("139.in","r",stdin);
n=read();
for(int i=1;i<=n;i++){
a[i]=read();
}
for(int i=2;i<=n;i++){
fa[i]=read();
add(i,fa[i]);
}rt=1;
dfs1(1,0);
dfs2(1,1,0);
build(1,1,n);
m=read();
int opt,u,v,k;
while(m--){
opt=read();
//printf("%d\n",opt);
if(opt==1){
u=read();rt=u;
}
if(opt==2){
u=read(),v=read(),k=read();
add(u,v,k);
}
if(opt==3){
u=read(),k=read();
if(u==rt){
adt(1,1,n,1,n,k);
}
else if(id[rt]>=id[u]+sz[u]||id[rt]<id[u]){
adt(1,1,n,id[u],id[u]+sz[u]-1,k);
}
else{
adt(1,1,n,1,n,k);
int pos=find(u,rt);
adt(1,1,n,id[pos],id[pos]+sz[pos]-1,-k);
}
}
if(opt==4){
u=read(),v=read();
printf("%lld\n",query(u,v));
}
if(opt==5){
u=read();
if(u==rt)printf("%lld\n",quert(1,1,n,1,n));
else if(id[rt]>=id[u]+sz[u]||id[rt]<id[u]){
printf("%lld\n",quert(1,1,n,id[u],id[u]+sz[u]-1));
}
else{
ll ans=quert(1,1,n,1,n);
int pos=find(u,rt);
ans=ans-quert(1,1,n,id[pos],id[pos]+sz[pos]-1);
printf("%lld\n",ans);
}
}
}
return 0;
}