给定一棵树,树中包含 n 个节点(编号 1∼n),其中第 i 个节点的权值为 ai。
初始时,1号节点为树的根节点。
现在要对该树进行 m 次操作,操作分为以下 44 种类型:
1 u v k
,修改路径上节点权值,将节点 u 和节点 v 之间路径上的所有节点(包括这两个节点)的权值增加 k。2 u k
,修改子树上节点权值,将以节点 u 为根的子树上的所有节点的权值增加 k。3 u v
,询问路径,询问节点 u 和节点 v 之间路径上的所有节点(包括这两个节点)的权值和。4 u
,询问子树,询问以节点 u 为根的子树上的所有节点的权值和。输入格式
第一行包含一个整数 n,表示节点个数。
第二行包含 n 个整数,其中第 i 个整数表示 ai。
接下来 n−1行,每行包含两个整数 x,y,表示节点 x 和节点 y之间存在一条边。
再一行包含一个整数 m,表示操作次数。
接下来 m 行,每行包含一个操作,格式如题目所述。
输出格式
对于每个操作 3和操作 4,输出一行一个整数表示答案。
数据范围
1≤n,m≤105
0≤ai,k≤105
1≤u,v,x,y≤n输入样例:
5 1 3 7 4 5 1 3 1 4 1 5 2 3 5 1 3 4 3 3 5 4 1 3 5 10 2 3 5 4 1
输出样例:
16 69
难度:中等 时/空限制:1s / 64MB 总通过数:1909 总尝试数:3322 来源:模板题,AcWing 算法标签
通过dfs进行树链剖分,将树变成线性,然后可以通过线段树,树状数组,分块等知识求解
#include <iostream>
#include <cstring>
using namespace std;
constexpr int N=1e5+7,M=N*2;
typedef long long ll;
int h[N],w[N],e[M],ne[M],idx;
int dep[N],sz[N],top[N],fa[N],son[N];
int id[N],nw[N],cnt;
struct node{
int l,r;
ll add,sum;
}tr[N*4];
void add(int a,int b){
e[idx]=b;
ne[idx]=h[a];
h[a]=idx++;
}
void dfs1(int u,int father,int depth){
dep[u]=depth,fa[u]=father,sz[u]=1;
for(int i=h[u];i!=-1;i=ne[i]){
int j=e[i];
if(j!=father){
dfs1(j,u,depth+1);
sz[u]+=sz[j];
if(sz[son[u]]<sz[j]) son[u]=j;
}
}
}
void dfs(int u,int t){
id[u]=++cnt,nw[cnt]=w[u],top[u]=t;
if (!son[u]) return;
dfs(son[u], t);
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa[u] || j == son[u]) continue;
dfs(j, j);
}
}
void pushup(int u){
tr[u].sum=tr[u<<1].sum+tr[u<<1|1].sum;
}
void pushdown(int u){
auto &root=tr[u],&left=tr[u<<1],&right=tr[u<<1|1];
if(root.add){
left.add+=root.add;
left.sum+=root.add*(left.r-left.l+1);
right.add+=root.add;
right.sum+=root.add*(right.r-right.l+1);
root.add=0;
}
}
void build(int u,int l,int r){
tr[u]={l,r,0,nw[r]};
if(l==r)return;
int mid=l+r>>1;
build(u<<1,l,mid);
build(u<<1|1,mid+1,r);
pushup(u);
}
void update(int u,int l,int r,int k){
if (l <= tr[u].l && r >= tr[u].r)
{
tr[u].add += k;
tr[u].sum += k * (tr[u].r - tr[u].l + 1);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) update(u << 1, l, r, k);
if (r > mid) update(u << 1 | 1, l, r, k);
pushup(u);
}
ll query(int u,int l,int r){
if (l <= tr[u].l && r >= tr[u].r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
ll res = 0;
if (l <= mid) res += query(u << 1, l, r);
if (r > mid) res += query(u << 1 | 1, l, r);
return res;
}
void update_tree(int u,int k){
update(1,id[u],id[u]+sz[u]-1,k);
}
ll query_tree(int u){
return query(1,id[u],id[u]+sz[u]-1);
}
void update_path(int u,int v,int k){
while (top[u] != top[v]) //向上爬找到相同重链
{
if (dep[top[u]] < dep[top[v]]) swap(u, v);
update(1, id[top[u]], id[u], k); //dfs序原因,上面节点的id必然小于下面节点的id
u = fa[top[u]];
}
if (dep[u] < dep[v]) swap(u, v);
update(1, id[v], id[u], k);
}
ll query_path(int u,int v){
ll res = 0;
while (top[u] != top[v]) //向上爬找到相同重链
{
if (dep[top[u]] < dep[top[v]]) swap(u, v);
res += query(1, id[top[u]], id[u]);
u = fa[top[u]];
}
if (dep[u] < dep[v]) swap(u, v);
res += query(1, id[v], id[u]); //在同一重链中,处理剩余区间
return res;
}
int main(){
int n;
scanf("%d",&n);
memset(h,-1,sizeof h);
for(int i=1;i<=n;i++){
scanf("%d",&w[i]);
}
for(int i=1;i<n;i++){
int a,b;
scanf("%d%d",&a,&b);
add(a,b);
add(b,a);
}
dfs1(1,-1,1);
dfs(1,1);
build(1, 1, n);
int m;
scanf("%d",&m);
while(m--){
int op,u,v,k;
scanf("%d",&op);
if(op==1){
scanf("%d%d%d",&u,&v,&k);
update_path(u,v,k);
}
else if(op==2){
scanf("%d%d",&u,&k);
update_tree(u,k);
}
else if(op==3){
scanf("%d%d",&u,&v);
printf("%lld\n",query_path(u,v));
}
else{
scanf("%d",&u);
printf("%lld\n",query_tree(u));
}
}
}