题意:
一棵树每个点都有权值,有三种操作
1 从x节点走到y节点,并将路径中的点的权值都取出来
2 将x节点的权值减去
3 将以x为根节点的子树的所有节点的值取出来。
每次操作后查询一次现在取出来的值为多少。
解题思路:
入门树剖+线段树题吧
第一次写树剖,写错了一个地方wa了好几发。
while(ty!=tx)
{
if(deep[tx]>deep[ty]) //比较深度的应该是tx和ty不是x和y
{
update(1, 1, n, tid[top[x]], tid[x], 1);
x=fa[top[x]], tx=top[x];
}
else
{
update(1, 1, n, tid[top[y]], tid[y], 1);
y=fa[top[y]], ty=top[y];
}
}
dfs序可以把树上的节点编号,而树剖可以让一条链上的节点编号都是连续的,然后当我们需要两个点路径的权值的时候,我们就可以通过让两个点所在的链不断上升,直到他们在同一条链上,由于树剖的性质,这个过程最多只需要logn次,这是时间上的优化。
第一种操作就是用树剖实现了,通过树剖找到我们需要更新的连续的编号区间, 然后用线段树维护。具体的我就不讲了,看下卿学姐的视频就清楚了。
第二种操作就是普通的线段树单点更新
第三种的话,由于一颗子树上的dfs序也一定是连续的,所以维护也简单了。
代码:
#include <bits/stdc++.h>
#define ps push_back
#define lson o<<1
#define rson o<<1|1
#define LL long long
using namespace std;
const int maxn=1e5+5;
struct p
{
LL x;
int lazy;
LL sum;
void init()
{
x=0;
lazy=-1;
}
}t[maxn<<4];
int re[maxn];
int son[maxn];
int fa[maxn];
int top[maxn];
int deep[maxn];
int tid[maxn];
int siz[maxn];
vector<int>edg[maxn];
LL val[maxn];
int cnt, n;
void dfs1(int x, int f)
{
int i, j;
fa[x]=f;
son[x]=-1;
siz[x]=1;
for(i=0; i<(int)edg[x].size(); i++)
{
if(edg[x][i]!=f)
{
dfs1(edg[x][i], x);
siz[x]+=siz[edg[x][i]];
if(son[x]==-1||siz[son[x]]<siz[edg[x][i]])
{
son[x]=edg[x][i];
}
}
}
return;
}
void dfs2(int x, int TOP, int de)
{
re[cnt]=x, tid[x]=cnt++, top[x]=TOP, deep[x]=de;
int i;
if(son[x]!=-1)
{
dfs2(son[x], TOP, de+1);
}
for(i=0; i<(int)edg[x].size(); i++)
{
if(edg[x][i]!=fa[x] && edg[x][i]!=son[x])
{
dfs2(edg[x][i], edg[x][i], de+1);
}
}
return;
}
void update(int o, int l, int r, int ll, int rr, int x)
{
if(ll<=l && r<=rr)
{
if(x==0)t[o].x=0;
else t[o].x=t[o].sum,t[o].lazy=x;
return;
}
if(t[o].lazy!=-1)
{
t[lson].x=t[lson].sum;
t[rson].x=t[rson].sum;
t[lson].lazy=t[rson].lazy=t[o].lazy;
t[o].lazy=-1;
}
int mid=(l+r)>>1;
if(ll<=mid)update(lson, l, mid, ll, rr, x);
if(rr>mid)update(rson, mid+1, r, ll, rr, x);
t[o].x=t[lson].x+t[rson].x;
// printf("%d %d %d\n", l, r, t[o].x);
return;
}
void UPD(int x, int y)
{
int ty=top[y], tx=top[x];
while(ty!=tx)
{
// printf("%d %d %d %d\n", x, tx, y, ty);
if(deep[tx]>deep[ty])
{
// printf("l,r %d %d %d\n", top[x], tid[top[x]], tid[x]);
update(1, 1, n, tid[top[x]], tid[x], 1);
x=fa[top[x]], tx=top[x];
}
else
{
update(1, 1, n, tid[top[y]], tid[y], 1);
y=fa[top[y]], ty=top[y];
}
}
// printf("%d %d\n", x, y);
if(deep[x]<deep[y])
{
update(1, 1, n, tid[x], tid[y], 1);
}
else
{
// printf("l,r %d %d\n", tid[y], tid[x]);
update(1, 1, n, tid[y], tid[x], 1);
}
return;
}
void build(int o, int l, int r)
{
t[o].init();
if(l==r)
{
t[o].sum=val[re[l]];
return;
}
int mid=(l+r)>>1;
build(lson, l, mid);
build(rson, mid+1, r);
t[o].sum=t[lson].sum+t[rson].sum;
return;
}
int main()
{
int m, i, j;
cin>>m;
while(m--)
{
scanf("%d", &n);
for(i=1; i<=n; i++)
{
scanf("%lld", &val[i]);
edg[i].clear();
}
int x, y;
for(i=1; i<n; i++)
{
scanf("%d%d", &x, &y);
edg[x].ps(y);
edg[y].ps(x);
}
cnt=1;
dfs1(1, 0);
dfs2(1, 1, 1);
// for(i=1; i<=n; i++)printf("%d %d\n", son[i], tid[i]);printf("\n");
build(1, 1, n);
int q, op;
scanf("%d", &q);
while(q--)
{
scanf("%d", &op);
if(op==1)
{
scanf("%d%d", &x, &y);
UPD(x, y);
printf("%lld\n", t[1].x);
}
else if(op==2)
{
scanf("%d", &x);
update(1, 1, n, tid[x], tid[x], 0);
printf("%lld\n", t[1].x);
}
else
{
scanf("%d", &x);
update(1, 1, n, tid[x], tid[x]+siz[x]-1, 1);
printf("%lld\n", t[1].x);
}
}
}
}