Description
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
Input
第一行包含两个整数 N, M 。表示点数和操作数。接下来一行 N 个整数,表示树中节点的初始权值。接下来 N-1
行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。再接下来 M 行,每行分别表示一次操作。其中
第一个数表示该操作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。
Output
对于每个询问操作,输出该询问的答案。答案之间用换行隔开。
Sample Input
5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
Sample Output
6
9
13
9
13
HINT
对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。
树剖,裸题,注意开long long。
#include<iostream>
#include<cstdio>
using namespace std;
const int N=100005;
int n,m,cnt,dcnt,hd[N],mp[N];
long long c[N];
struct edge
{
int to,nxt;
}v[2*N];
struct node
{
int son,fa,sz,w,e,dep,tp;
}tr[N];
struct segtree
{
int l,r;
long long tag,sum;
}st[4*N];
void addedge(int x,int y)
{
v[++cnt].to=y;
v[cnt].nxt=hd[x];
hd[x]=cnt;
}
void dfs1(int u)
{
tr[u].sz=1;
for(int i=hd[u];i;i=v[i].nxt)
if(v[i].to!=tr[u].fa)
{
tr[v[i].to].fa=u;
tr[v[i].to].dep=tr[u].dep+1;
dfs1(v[i].to);
tr[u].sz+=tr[v[i].to].sz;
if(tr[v[i].to].sz>tr[tr[u].son].sz)
tr[u].son=v[i].to;
}
}
void dfs2(int u,int top)
{
tr[u].tp=top;
tr[u].w=++dcnt;
mp[dcnt]=u;
if(tr[u].son)
{
dfs2(tr[u].son,top);
for(int i=hd[u];i;i=v[i].nxt)
if(v[i].to!=tr[u].son&&v[i].to!=tr[u].fa)
dfs2(v[i].to,v[i].to);
}
tr[u].e=dcnt;
}
void pushup(int num)
{
st[num].sum=st[2*num].sum+st[2*num+1].sum;
}
void pushdown(int num)
{
if(st[num].tag)
{
if(st[num].l!=st[num].r)
{
st[2*num].sum+=(st[2*num].r-st[2*num].l+1)*st[num].tag;
st[2*num+1].sum+=(st[2*num+1].r-st[2*num+1].l+1)*st[num].tag;
st[2*num].tag+=st[num].tag;
st[2*num+1].tag+=st[num].tag;
}
st[num].tag=0;
}
}
void build(int num,int l,int r)
{
st[num].l=l,st[num].r=r;
if(l==r)
{
st[num].sum=c[mp[l]];
return ;
}
int mid=(l+r)/2;
build(2*num,l,mid),build(2*num+1,mid+1,r);
pushup(num);
}
void change(int num,int x,int y,long long z)
{
if(st[num].l>y||st[num].r<x)
return ;
if(st[num].l>=x&&st[num].r<=y)
{
st[num].sum+=(st[num].r-st[num].l+1)*z;
st[num].tag+=z;
return ;
}
pushdown(num);
change(2*num,x,y,z),change(2*num+1,x,y,z);
pushup(num);
}
long long query(int num,int x,int y)
{
if(st[num].l>y||st[num].r<x)
return 0;
if(st[num].l>=x&&st[num].r<=y)
return st[num].sum;
pushdown(num);
return query(2*num,x,y)+query(2*num+1,x,y);
}
long long ask(int x,int y)
{
long long res=0;
while(tr[x].tp!=tr[y].tp)
{
if(tr[tr[x].tp].dep>tr[tr[y].tp].dep)
{
res+=query(1,tr[tr[x].tp].w,tr[x].w);
x=tr[tr[x].tp].fa;
}
else
{
res+=query(1,tr[tr[y].tp].w,tr[y].w);
y=tr[tr[y].tp].fa;
}
}
if(tr[x].dep<tr[y].dep)
res+=query(1,tr[x].w,tr[y].w);
else
res+=query(1,tr[y].w,tr[x].w);
return res;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%lld",&c[i]);
for(int i=1;i<=n-1;i++)
{
int x,y;
scanf("%d%d",&x,&y);
addedge(x,y),addedge(y,x);
}
dfs1(1);
dfs2(1,1);
build(1,1,n);
int opt,x;
long long y;
while(m--)
{
scanf("%d",&opt);
switch(opt)
{
case 1:
scanf("%d%lld",&x,&y);
change(1,tr[x].w,tr[x].w,y);
break;
case 2:
scanf("%d%lld",&x,&y);
change(1,tr[x].w,tr[x].e,y);
break;
case 3:
scanf("%d",&x);
printf("%lld\n",ask(1,x));
break;
}
}
return 0;
}