题目
https://loj.ac/problem/2187
思路
好难啊!!!!
首先有一个很显然的性质,对于每个修改,影响到的节点一定是从野子节点往上走的一条路径。
现在问题是怎么维护这条路径的终点。
继续观察性质,发现一个节点改变颜色,当且仅当(sum为儿子为1的个数)
- 0->1 sum=1
- 1->0 sum=2
所以我们的splay需要维护的是第一个sum!=1和第一个sum!=2的位置
有一种思路是二分,但是这样多了一个log,考虑另一种:在splay里维护id[1],id[2]表示最深的sum!=1和sum!=2的点
对于LCTm,如果一个点存在id1,那就吧sum改一下,否则这条路径将走到根,直接改根的sum即可
这道题细节很多:
- 修改时不能把叶子节点放splay里,因为sum=0没意义
- 在把fa[x]旋根之后,一定是修改右子树,而不是修改整个子树,因为左子树的信息并不会改变。
- 在修改了右子树之后,不要忘记对fa[x]进行单点修改。
代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=2e6+77;
int n,m,cnt;
struct node
{
int ch[2],id[3],fa,bj,val,sum;
}tr[N];
void update(int t)
{
tr[t].id[1]=tr[tr[t].ch[1]].id[1];
tr[t].id[2]=tr[tr[t].ch[1]].id[2];
if(!tr[t].id[1])
{
if(tr[t].sum!=1)tr[t].id[1]=t;
else tr[t].id[1]=tr[tr[t].ch[0]].id[1];
}
if(!tr[t].id[2])
{
if(tr[t].sum!=2)tr[t].id[2]=t;
else tr[t].id[2]=tr[tr[t].ch[0]].id[2];
}
}
void add(int t,int x)
{
tr[t].sum+=x;tr[t].val=tr[t].sum>1;
swap(tr[t].id[1],tr[t].id[2]);
tr[t].bj+=x;
}
void down(int t)
{
if(tr[t].bj)
{
if(tr[t].ch[0])add(tr[t].ch[0],tr[t].bj);
if(tr[t].ch[1])add(tr[t].ch[1],tr[t].bj);
tr[t].bj=0;
}
}
bool isroot(int x)
{
return tr[tr[x].fa].ch[0]==x||tr[tr[x].fa].ch[1]==x;
}
int st[N];
void rotate(int x)
{
int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x,w=tr[x].ch[k^1];
if(isroot(y))tr[z].ch[tr[z].ch[1]==y]=x;
tr[x].ch[k^1]=y;tr[y].ch[k]=w;
if(w)tr[w].fa=y;tr[y].fa=x;tr[x].fa=z;
update(y);update(x);
}
void splay(int x)
{
int top=0,y=x;
st[++top]=x;
while(isroot(y))
{
y=tr[y].fa;
st[++top]=y;
}
while(top)down(st[top--]);
int z;
while(isroot(x))
{
y=tr[x].fa,z=tr[y].fa;
if(isroot(y))(tr[z].ch[0]==y)^(tr[y].ch[0]==x)?rotate(x):rotate(y);
rotate(x);
}
update(x);
}
void access(int x)
{
for(int y=0; x; y=x,x=tr[x].fa)
{
splay(x);
tr[x].ch[1]=y;
update(x);
}
}
struct E
{
int to,next;
}e[N<<1];
int ls[N];
void addedge(int u,int v)
{
e[++cnt].to=v; e[cnt].next=ls[u]; ls[u]=cnt;
}
void dfs(int x,int f)
{
tr[x].sum=0;
int v;
for(int i=ls[x]; i; i=e[i].next)
{
v=e[i].to;
if(v==f)continue;
dfs(v,x);
tr[x].sum+=tr[v].val;
}
if(x<=n) tr[x].val=tr[x].sum>1;
}
int main()
{
scanf("%d",&n);
int x;
for(int i=1; i<=n; i++)
{
for(int j=1; j<=3; j++)
{
scanf("%d",&x);
tr[x].fa=i;
addedge(x,i); addedge(i,x);
}
}
for(int i=n+1; i<=n*3+1; i++) scanf("%d",&tr[i].val);
dfs(1,0);
scanf("%d",&m);
int last,w,addtag,ans=tr[1].val;
while(m--)
{
scanf("%d",&last);x=tr[last].fa;
addtag=tr[last].val?-1:1;
access(x);splay(x);
w=tr[x].id[(tr[last].val?2:1)];
if(w)
{
splay(w);
add(tr[w].ch[1],addtag);update(tr[w].ch[1]);
tr[w].sum+=addtag;tr[w].val=tr[w].sum>1;update(w);
}
else ans^=1,add(x,addtag),update(x);
tr[last].val^=1;
printf("%d\n",ans);
}
}