4034: [HAOI2015]树上操作
Description
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个
操作,分为三种:
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
Input
第一行包含两个整数 N, M 。表示点数和操作数。
接下来一行 N 个整数,表示树中节点的初始权值。
接下来 N-1 行每行三个正整数 fr, to , 表示该树中存在一条边 (fr, to) 。
再接下来 M 行,每行分别表示一次操作。其中第一个数表示该操
作的种类( 1-3 ) ,之后接这个操作的参数( x 或者 x a ) 。
Output
对于每个询问操作,输出该询问的答案。答案之间用换行隔开。
Sample Input
5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
Sample Output
6
9
13
HINT
对于 100% 的数据, N,M<=100000 ,且所有输入数据的绝对值都不会超过 10^6 。
思路:
树链剖分之后每个结点下面子树的编号是连续的,所以in[],out[]代表一个区间,用线段树维护这个序列,结合重链top就可以支持点修改,链查询,子树修改,子树查询。
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define LL long long
using namespace std;
const int N = 400010;
int head[N], next[N<<1], to[N<<1], idc, idx;
int in[N], out[N], son[N], fa[N], siz[N], deep[N], top[N];
LL val[N], val1[N<<1], sum[N<<1], flag[N<<1];
void adde(int u, int v){
to[++idc] = v;
next[idc] = head[u];
head[u] = idc;
}
void dfs1(int u, int f){
deep[u] = deep[f] + 1;
fa[u] = f; siz[u] = 1;
for(int k=head[u]; k; k=next[k]){
int v = to[k];
if(v == f) continue;
dfs1(v, u);
siz[u] += siz[v];
if( !son[u] || siz[son[u]] < siz[v]) son[u] = v;
}
}
void dfs2(int u, int tp){
top[u] = tp; out[u] = in[u] = ++idx;
val1[in[u]] = val[u];
if(son[u] == -1) return;
dfs2(son[u], tp);
for(int k=head[u]; k; k=next[k]){
int v = to[k];
if(v==fa[u] || v==son[u]) continue;
dfs2(v, v);
}
out[u] = idx;
}
void build(int l, int r, int pos){
flag[pos] = 0;
if(l == r){
sum[pos] = val1[l];
return ;
}
int mid = (l+r)>>1;
build(l, mid, pos<<1);
build(mid+1, r, pos<<1|1);
sum[pos] = sum[pos<<1] + sum[pos<<1|1];
}
void pushdown(int pos, int l, int r){
if( !flag[pos] ) return ;
int mid = (l+r)>>1;
flag[pos<<1] += flag[pos];
flag[pos<<1|1] += flag[pos];
sum[pos<<1] += (mid-l+1) * flag[pos];
sum[pos<<1|1] += (r-mid) * flag[pos];
flag[pos] = 0;
}
void modify(int pos, int ll, int rr, int l, int r, int val){
if(ll==l && r==rr){
flag[pos] += val;
sum[pos] += (LL)(r-l+1)*val;
return ;
}
pushdown(pos, l, r);
int mid = (l+r)>>1;
if(rr<=mid) modify(pos<<1, ll, rr, l, mid, val);
else if(ll>mid) modify(pos<<1|1, ll, rr, mid+1, r, val);
else{
modify(pos<<1, ll, mid, l, mid, val);
modify(pos<<1|1, mid+1, rr, mid+1, r, val);
}
sum[pos] = sum[pos<<1] + sum[pos<<1|1];
}
LL query(int pos, int ll, int rr, int l, int r){
if(ll==l && r==rr){
return sum[pos];
}
pushdown(pos, l, r);
int mid = (l+r)>>1;
if(rr<=mid) return query(pos<<1, ll, rr, l, mid);
else if(ll>mid) return query(pos<<1|1, ll, rr, mid+1, r);
else return query(pos<<1, ll, mid, l, mid) + query(pos<<1|1, mid+1, rr, mid+1, r);
}
LL solve(int u, int v){
LL sum = 0;
int f1 = top[u], f2 = top[v];
while(f1 != f2){
if(deep[f1] < deep[f2]) swap(f1, f2), swap(u, v);
sum += query(1, in[f1], in[u], 1, idx);
u = fa[f1];
f1 = top[u];
}
if(deep[u] > deep[v]) swap(u, v);
sum += query(1, in[u], in[v], 1, idx);
return sum;
}
int main(){
int n, m;
scanf("%d%d", &n, &m);
memset(head, 0, sizeof(head));
memset(son, -1, sizeof(son));
idc = 0; idx = 0;
for(int i=1; i<=n; i++)
scanf("%lld", &val[i]);
for(int i=1; i<n; i++){
int u, v;
scanf("%d%d", &u, &v);
adde(u, v);
adde(v, u);
}
deep[1] = 0;
dfs1(1, -1);
dfs2(1, 1);
build(1, idx, 1);
while( m-- ){
int opt;
scanf("%d", &opt);
if(opt == 1){
int pos, val;
scanf("%d%d", &pos, &val);
modify(1, in[pos], in[pos], 1, idx, val);
}
if(opt == 2){
int pos, val;
scanf("%d%d", &pos, &val);
modify(1, in[pos], out[pos], 1, idx, val);
}
if(opt == 3){
int pos;
scanf("%d", &pos);
LL ans = solve(1, pos);
printf("%lld\n", ans);
}
}
return 0;
}