Description
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
Input
第一行包含两个整数 N, M 。表示点数和操作数。
接下来一行 N 个整数,表示树中节点的初始权值。
接下来 N-1 行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。
再接下来 M 行,每行分别表示一次操作。其中第一个数表示该操
作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。
Output
对于每个询问操作,输出该询问的答案。答案之间用换行隔开。
Sample Input
5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
Sample Output
6
9
13
HINT
对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不
会超过 10^6 。
题解
树链剖分。。。
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#define ll long long
int n,m,tim,tot;
int fa[100005],son[100005];
int pos[100005],re[100005],v[100005];
int top[100005],size[100005];
int ret[200005],Next[200005],Head[200005];
ll lazy[400005],sum[400005];
using namespace std;
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void ins(int x,int y){
tot++;
ret[tot]=y;
Next[tot]=Head[x];
Head[x]=tot;
}
void init(){
n=read();m=read();
for (int i=1;i<=n;i++){
v[i]=read();
}
for (int i=1;i<n;i++){
int x,y;
x=read();y=read();
ins(x,y);
ins(y,x);
}
}
void dfs1(int u){
size[u]=1;
for (int i=Head[u];i!=0;i=Next[i]){
int v=ret[i];
if (v!=fa[u]){
fa[v]=u;
dfs1(v);
size[u]+=size[v];
if (son[u]==0||size[son[u]]<size[v]){
son[u]=v;
}
}
}
}
void dfs2(int u,int chain){
tim++;
pos[u]=tim;
top[u]=chain;
re[u]=tim;
if (son[u]!=0){
dfs2(son[u],chain);
re[u]=max(re[u],re[son[u]]);
}
for (int i=Head[u];i!=0;i=Next[i]){
int v=ret[i];
if (v!=fa[u]&&v!=son[u]){
dfs2(v,v);
re[u]=max(re[u],re[v]);
}
}
}
void pushdown(int k,int l,int r){
if (l==r) return;
int mid=(l+r)/2;
ll t=lazy[k];lazy[k]=0;
lazy[k<<1]+=t;lazy[k<<1|1]+=t;
sum[k<<1]+=t*(mid-l+1);
sum[k<<1|1]+=t*(r-mid);
}
void change(int k,int l,int r,int x,int y,ll a){
if (lazy[k]) pushdown(k,l,r);
if (l==x&&r==y){
lazy[k]+=a;
sum[k]+=(r-l+1)*a;
return;
}
int mid=(l+r)/2;
if (y<=mid) change(k*2,l,mid,x,y,a);
if (x>mid) change(k*2+1,mid+1,r,x,y,a);
if (x<=mid&&y>mid){
change(k*2,l,mid,x,mid,a);
change(k*2+1,mid+1,r,mid+1,y,a);
}
sum[k]=sum[k*2+1]+sum[k*2];
}
ll query(int k,int l,int r,int x,int y){
if (lazy[k]) pushdown(k,l,r);
if (l==x&&r==y){
return sum[k];
}
int mid=(l+r)/2;
if (y<=mid) return query(k*2,l,mid,x,y);
if (x>mid) return query(k*2+1,mid+1,r,x,y);
ll ans=query(k*2,l,mid,x,mid)+query(k*2+1,mid+1,r,mid+1,y);
return ans;
}
ll solveask(int x){
ll ans=0;
while (top[x]!=0){
ans+=query(1,1,n,pos[top[x]],pos[x]);
x=fa[top[x]];
}
return ans;
}
void solve(){
for (int i=1;i<=n;i++){
change(1,1,n,pos[i],pos[i],v[i]);
}
int opt,x,y;
for (int i=1;i<=m;i++){
opt=read();x=read();
if (opt==1){
y=read();
change(1,1,n,pos[x],pos[x],y);
}
if (opt==2){
y=read();
change(1,1,n,pos[x],re[x],y);
}
if (opt==3){
printf("%lld\n",solveask(x));
}
}
}
int main(){
init();
dfs1(1);
dfs2(1,1);
solve();
}