裸的树链剖分。。
不会的同学可以戳这。。
我觉得可以直接上代码。。
这次并没有动态开点。。
Code:
#include<cstdio>
#include<cstdlib>
#define ll long long
struct node{int y,next;}a[200010];
struct tree{int l,r,lc,rc;ll c,lz;}tr[200010];
int first[100010],dep[100010],tot[100010],son[100010],top[100010],fa[100010],image[100010],fact[100010],d[100010];
int n,m,len(0);
void ins(int x,int y){a[++len]=(node){y,first[x]};first[x]=len;}
void dfs_1(int x)
{
tot[x]=1;
for(int i=first[x];i;i=a[i].next)
{
int y=a[i].y;
if(y!=fa[x])
{
fa[y]=x;
dep[y]=dep[x]+1;
dfs_1(y);
if(tot[y]>tot[son[x]]) son[x]=y;
tot[x]+=tot[y];
}
}
}
void dfs_2(int x,int tp)
{
top[x]=tp;image[x]=++len;
if(son[x]) dfs_2(son[x],tp);
for(int i=first[x];i;i=a[i].next)
{
int y=a[i].y;
if(y!=fa[x] && y!=son[x]) dfs_2(y,y);
}
fact[x]=len;
}
void bt(int l,int r)
{
len++;int now=len;
tr[now]=(tree){l,r,-1,-1,0,0};
if(l<r)
{
int mid=(l+r)/2;
tr[now].lc=len+1;bt(l,mid);
tr[now].rc=len+1;bt(mid+1,r);
}
}
void pushdown(int now)
{
if(!tr[now].lz) return;
int lc=tr[now].lc,rc=tr[now].rc;
tr[lc].c+=(tr[lc].r-tr[lc].l+1)*tr[now].lz;
tr[rc].c+=(tr[rc].r-tr[rc].l+1)*tr[now].lz;
tr[lc].lz+=tr[now].lz;
tr[rc].lz+=tr[now].lz;
tr[now].lz=0;
}
void change(int now,int l,int r,ll c)
{
if(tr[now].l==l && tr[now].r==r)
{
tr[now].c+=(r-l+1)*c;
tr[now].lz+=c;
return;
}
pushdown(now);
int mid=(tr[now].l+tr[now].r)/2;
int lc=tr[now].lc,rc=tr[now].rc;
if(r<=mid) change(lc,l,r,c);
else if(mid<l) change(rc,l,r,c);
else
{
change(lc,l,mid,c);
change(rc,mid+1,r,c);
}
tr[now].c=tr[lc].c+tr[rc].c;
}
ll findsum(int now,int l,int r)
{
if(tr[now].l==l && tr[now].r==r) return tr[now].c;
pushdown(now);
int mid=(tr[now].l+tr[now].r)/2;
int lc=tr[now].lc,rc=tr[now].rc;
if(r<=mid) return findsum(lc,l,r);
else if(mid<l) return findsum(rc,l,r);
else return findsum(lc,l,mid)+findsum(rc,mid+1,r);
}
void solve(int x)
{
int tx=top[x];
ll ans=0;
while(x!=0)
{
ans+=findsum(1,image[tx],image[x]);
x=fa[tx];tx=top[x];
}
printf("%lld\n",ans);
}
int main()
{
scanf("%d %d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&d[i]);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d %d",&x,&y);
ins(x,y);ins(y,x);
}
dep[1]=1;fa[1]=0;dfs_1(1);
len=0;dfs_2(1,1);
len=0;bt(1,n);
for(int i=1;i<=n;i++) change(1,image[i],image[i],d[i]);
for(int i=1;i<=m;i++)
{
int p,x;ll c;
scanf("%d %d",&p,&x);
if(p==1 || p==2) scanf("%lld",&c);
if(p==1) change(1,image[x],image[x],c);
if(p==2) change(1,image[x],fact[x],c);
if(p==3) solve(x);
}
}