题意:给出一棵根节点为1的树,执行m次修改操作,每次修改为a,b,c,表示a节点的子树中,距离a小于等于b的子节点的权值加上c,求m次操作后每个节点的权值
分析:用线段树维护每层节点的权值,然后dfs遍历这颗树,当前节点有操作时,把当前节点的深度到被修改的最大深度都加上c(实际上只有当前节点的子节点才加c),而回朔的时候再将这个区间减c,这样就避免了对非子节点的影响
AC代码(线段树的区间更新):
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=3e5+10;
ll f[maxn],nex[maxn*2],to[maxn*2],cnt,level[maxn],n,m;
ll ans[maxn],num[maxn*4],lazy[maxn*4];
vector<pair<ll,ll>>ve[maxn];
void add(int a,int b)
{
cnt++;
to[cnt]=b;
nex[cnt]=f[a];
f[a]=cnt;
}
void updata(int st,int en,int l,int r,int rt,ll x)
{
//cout<<l<<" "<<r<<endl;
if(st<=l&&en>=r)
{
num[rt]+=x*(r-l+1);
lazy[rt]+=x;
return ;
}
int mid=(l+r)/2;
if(lazy[rt])
{
lazy[rt*2]+=lazy[rt];
lazy[rt*2+1]+=lazy[rt];
num[rt*2]+=lazy[rt]*(mid-l+1);
num[rt*2+1]+=lazy[rt]*(r-mid);
lazy[rt]=0;
}
if(st<=mid)updata(st,en,l,mid,rt*2,x);
if(en>=mid+1)updata(st,en,mid+1,r,rt*2+1,x);
num[rt]=num[rt*2]+num[rt*2+1];
}
ll quer(int x,int l,int r,int rt)
{
int mid=(l+r)/2;
if(l==r)return num[rt];
if(lazy[rt])
{
lazy[rt*2]+=lazy[rt];
lazy[rt*2+1]+=lazy[rt];
num[rt*2]+=lazy[rt]*(mid-l+1);
num[rt*2+1]+=lazy[rt]*(r-mid);
lazy[rt]=0;
}
if(x<=(l+r)/2)return quer(x,l,(l+r)/2,rt*2);
else return quer(x,(l+r)/2+1,r,rt*2+1);
num[rt]=num[rt*2]+num[rt*2+1];
}
void getlevel(int x,int l)
{
level[x]=l;
for(int i=f[x]; i; i=nex[i])
{
int v=to[i];
if(!level[v])getlevel(v,l+1);
}
}
void dfs(int x,int pre)
{
for(int i=0; i<ve[x].size(); i++)
{
updata(level[x],level[x]+ve[x][i].first,1,n,1,ve[x][i].second);
}
ans[x]=quer(level[x],1,n,1);
for(int i=f[x]; i; i=nex[i])
{
int v=to[i];
if(v!=pre)
dfs(v,x);
}
for(int i=0; i<ve[x].size(); i++)
{
updata(level[x],level[x]+ve[x][i].first,1,n,1,-ve[x][i].second);
}
}
int main()
{
ios::sync_with_stdio(false);
cin>>n;
for(int i=1; i<=n-1; i++)
{
int a,b;
cin>>a>>b;
add(a,b);
add(b,a);
}
getlevel(1,1);
cin>>m;
for(int i=1; i<=m; i++)
{
ll a,b,c;
cin>>a>>b>>c;
ve[a].push_back(make_pair(b,c));
}
dfs(1,-1);
for(int i=1; i<=n; i++)
cout<<ans[i]<<" ";
cout<<endl;
return 0;
}
AC代码(树状数组的区间更新):
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn=3e5+10;
ll f[maxn],nex[maxn*2],to[maxn*2],cnt,n,m;
ll ans[maxn],c1[maxn],c2[maxn];
vector<pair<int,ll>>ve[maxn];
void add1(int a,int b)
{
cnt++;
to[cnt]=b;
nex[cnt]=f[a];
f[a]=cnt;
}
void add(int x,ll y)
{
for(int i=x; i<=n; i+=(i&-i))c1[i]+=y,c2[i]+=y*x;
}
ll quer(int x)
{
ll res=0;
for(int i=x; i>0; i-=(i&-i))res+=((x+1)*c1[i]-c2[i]);
return res;
}
void in_add(int l,int r,ll x)
{
add(l,x);
add(r+1,-x);
}
void dfs(int x,int pre,int l)
{
for(int i=0; i<ve[x].size(); i++)
in_add(l,l+ve[x][i].first,ve[x][i].second);
ans[x]=quer(l)-quer(l-1);
for(int i=f[x]; i; i=nex[i])
{
int v=to[i];
if(v!=pre)
dfs(v,x,l+1);
}
for(int i=0; i<ve[x].size(); i++)
in_add(l,l+ve[x][i].first,-ve[x][i].second);
}
int main()
{
ios::sync_with_stdio(false);
cin>>n;
for(int i=1; i<=n-1; i++)
{
int a,b;
cin>>a>>b;
add1(a,b);
add1(b,a);
}
cin>>m;
for(int i=1; i<=m; i++)
{
ll a,b,c;
cin>>a>>b>>c;
ve[a].push_back(make_pair(b,c));
}
dfs(1,-1,1);
for(int i=1; i<=n; i++)
cout<<ans[i]<<" ";
cout<<endl;
return 0;
}