题目链接:
题意:
给定一个以 1 为根节点的树,初始每个节点的权值为 0 。有 m 次操作,每次把以 vi 为祖先且离 vi 的距离小于 di 的所有节点(包括 vi 本身)的权值加上 xi 。问所有操作结束后,每个节点的权值。
思路:
一个点只会影响它的子孙节点(包括自己),且只会被它的祖先节点影响(包括自己)。
所以,我们 dfs 到一个节点 v 时,我们把以它为操作对象的操作完成,那么相当于对于该节点的所有操作都已被完成,那么我们可以直接算出这个节点的最终权值。当 dfs 回溯到该节点时,我们把之前做的操作删除即可,因为关于该节点的操作不能影响不以该节点为祖先的节点。这样一次 dfs 就能算出所有节点的最终答案。
那么,操作怎样进行呢?我们知道:到当前为止的所有操作都会对该节点的子孙产生影响(如果在范围内的话),所以比如说操作为v,d,x,当前节点深度为dep,那么所有深度在 [dep,dep+d] 范围内的节点都要加 x 。因此我们可以用线段树来维护区间和。当前节点的最终权值就是用线段树查找区间 [dep,dep] 的权值和。
Code:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAX = 3e5+10;
const ll mod = 1e9+7;
typedef struct{
int d;
ll x;
}Point;
int n,m;
vector<int>mp[MAX];
vector<Point>op[MAX];
ll sum[MAX<<2],Add[MAX<<2];
ll res[MAX];
/*线段树*/
void Pushup(int root)
{
sum[root]=sum[root<<1]+sum[root<<1|1];
}
void Build(int l,int r,int root)
{
if(l==r){
sum[root]=0;
return;
}
int mid = (l+r)>>1;
Build(l,mid,root<<1);
Build(mid+1,r,root<<1|1);
Pushup(root);
}
void Pushdown(int root,int ln,int rn)
{
if(Add[root]){
Add[root<<1]+=Add[root];
Add[root<<1|1]+=Add[root];
sum[root<<1]+=Add[root]*ln;
sum[root<<1|1]+=Add[root]*rn;
Add[root]=0;
}
}
void Update(int L,int R,ll c,int l,int r,int root)
{
if(L<=l&&r<=R){
sum[root]+=c*(r-l+1);
Add[root]+=c;
return;
}
int mid = (l+r)>>1;
Pushdown(root,mid-l+1,r-mid);
if(L<=mid) Update(L,R,c,l,mid,root<<1);
if(R>mid) Update(L,R,c,mid+1,r,root<<1|1);
Pushup(root);
}
ll Query(int L,int R,int l,int r,int root)
{
if(L<=l&&r<=R){
return sum[root];
}
int mid = (l+r)>>1;
Pushdown(root,mid-l+1,r-mid);
ll ans=0;
if(L<=mid) ans+=Query(L,R,l,mid,root<<1);
if(R>mid) ans+=Query(L,R,mid+1,r,root<<1|1);
return ans;
}
/*线段树*/
void dfs(int root,int fa,int dep)
{
//添加操作
for(int i=0;i<op[root].size();i++){
Point now = op[root][i];
Update(dep,min(dep+now.d,n),now.x,1,n,1);
}
res[root]=Query(dep,dep,1,n,1);
for(int i=0;i<mp[root].size();i++){
int v = mp[root][i];
if(v==fa) continue;
dfs(v,root,dep+1);
}
//删除操作
for(int i=0;i<op[root].size();i++){
Point now = op[root][i];
Update(dep,min(dep+now.d,n),-1*now.x,1,n,1);
}
}
int main()
{
scanf("%d",&n);
for(int i=0;i<n-1;i++){
int x,y;
scanf("%d%d",&x,&y);
mp[x].push_back(y);
mp[y].push_back(x);
}
Build(1,n,1);
scanf("%d",&m);
while(m--){
int v,d;
ll x;
scanf("%d%d%lld",&v,&d,&x);
op[v].push_back(Point{d,x});
}
dfs(1,-1,1);
printf("%lld",res[1]);
for(int i=2;i<=n;i++){
printf(" %lld",res[i]);
}
printf("\n");
return 0;
}