题意:
有一棵点数为 N 的树,以点 1 为根,且树点有边权。然后有 M 个操作,分为三种:操作 1 :把某个节点 x 的点权增加 a 。操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
一道没有什么思维容量的树链剖分,比洛谷的模板还要简单,唯一值得注意的是结果会爆int!
然而当我发现之后改代码时一开始只把一部分int改成了long long,然后发现交上去还是WA,于是索性把除了主函数前面的int之外的所有int都改成了long long,反正是不会爆内存的。
下面附上代码。
#include <bits/stdc++.h>
using namespace std;
struct node
{
long long next,to;
}a[200010];
struct tree
{
long long l,r;
long long sum,tag;
}tr[400005];
long long n,m,b[200010],head[200010],cnt;
long long tot[200010],dep[200010],fa[200010],son[200010];
long long top[200010],ys[200010],yss[200010],z,end[200010];
void ins(long long from,long long to)
{
a[++cnt].next=head[from];
a[cnt].to=to;
head[from]=cnt;
}
void dfs1(long long x)
{
tot[x]=1;
for(long long i=head[x];i;i=a[i].next)
{
long long y=a[i].to;
if(y!=fa[x])
{
fa[y]=x;
dep[y]=dep[x]+1;
dfs1(y);
tot[x]+=tot[y];
if(tot[y]>tot[son[x]])
son[x]=y;
}
}
}
void dfs2(long long x,long long tp)
{
top[x]=tp;
ys[x]=++z;
yss[z]=x;
if(son[x])
{
dfs2(son[x],tp);
for(long long i=head[x];i;i=a[i].next)
{
long long y=a[i].to;
if(y!=fa[x]&&y!=son[x])
dfs2(y,y);
}
}
end[x]=z;
}
void update(long long num)
{
tr[num].sum=tr[num<<1].sum+tr[num<<1|1].sum;
}
void build(long long num,long long l,long long r)
{
tr[num].l=l;
tr[num].r=r;
if(l==r)
{
tr[num].sum=b[yss[l]];
return;
}
long long mid=(l+r)>>1;
build(num<<1,l,mid);
build(num<<1|1,mid+1,r);
update(num);
}
void pushdown(long long num)
{
if(tr[num].tag)
{
long long x=tr[num].tag;
tr[num<<1].tag+=x;
tr[num<<1|1].tag+=x;
tr[num<<1].sum+=(tr[num<<1].r-tr[num<<1].l+1)*x;
tr[num<<1|1].sum+=(tr[num<<1|1].r-tr[num<<1|1].l+1)*x;
tr[num].tag=0;
}
}
void change(long long num,long long l,long long r,long long x)
{
if(tr[num].l>r||tr[num].r<l)
return;
if(tr[num].l>=l&&tr[num].r<=r)
{
tr[num].tag+=x;
tr[num].sum+=(tr[num].r-tr[num].l+1)*x;
return;
}
pushdown(num);
change(num<<1,l,r,x);
change(num<<1|1,l,r,x);
update(num);
}
long long query(long long num,long long l,long long r)
{
if(r<tr[num].l||l>tr[num].r)
return 0;
if(tr[num].l>=l&&r>=tr[num].r)
return tr[num].sum;
pushdown(num);
return query(num<<1,l,r)+query(num<<1|1,l,r);
}
long long find(long long x,long long y)
{
long long f1=top[x],f2=top[y];
long long ans=0;
while(f1!=f2)
{
if(dep[x]<dep[y])
{
swap(x,y);
swap(f1,f2);
}
ans+=query(1,ys[f1],ys[x]);
x=fa[f1];
f1=top[x];
}
if(dep[x]<dep[y])
ans+=query(1,ys[x],ys[y]);
else
ans+=query(1,ys[y],ys[x]);
return ans;
}
int main()
{
scanf("%lld%lld",&n,&m);
for(long long i=1;i<=n;i++)
scanf("%lld",&b[i]);
for(long long i=1;i<=n-1;i++)
{
long long x,y;
scanf("%lld%lld",&x,&y);
ins(x,y);
ins(y,x);
}
dfs1(1);
dfs2(1,1);
build(1,1,n);
for(long long i=1;i<=m;i++)
{
long long x,y,z;
scanf("%lld",&x);
if(x==1)
{
scanf("%lld%lld",&y,&z);
change(1,ys[y],ys[y],z);
}
if(x==2)
{
scanf("%lld%lld",&y,&z);
change(1,ys[y],end[y],z);
}
if(x==3)
{
scanf("%lld",&y);
printf("%lld\n",find(y,1));
}
}
return 0;
}