题目描述
传送门
题目大意:给定一棵树,边的颜色为黑或白,初始时全部为白色。维护两个操作:
1.查询u到根路径上的第一条黑色边的标号。
2.将u到v 路径上的所有边的颜色设为黑色。
Notice:这棵树的根节点为1
题解
先将所有操作正着进行一遍,将所有的黑边相邻的点按照关系合并,就是一个集合中的代表元素一定是深度最小的点。
然后找出所有自始至终都是白色的边,以及每条边变黑的时间。将白边用并查集合并
倒着做所有的操作,对于染黑操作如果我们撤销相当于染白,将是所有在当前操作中变黑的边的两端用并查集合并,可以直接遍历路径,做法与合并黑边时类似。(也可以按照每条边变黑的时间排序,然后直接合并每一条边,这样就不用遍历路径了)
对于每次的查询操作直接找集合的代表元素,代表元素与其父节点之间的边就是答案。
代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 1000003
using namespace std;
int tot,point[N],nxt[N*2],v[N*2],c[N*2];
int fa[N],belong[N],size[N],son[N],q[N];
int n,m,mark[N],pos[N],sz,deep[N],f[N],pd[N],pd1[N],ans[N];
struct data{
int opt,x,y;
}e[N],p[N];
int read()
{
char ch = getchar();
for ( ; ch > '9' || ch < '0'; ch = getchar());
int tmp = 0;
for ( ; '0' <= ch && ch <= '9'; ch = getchar())
tmp = tmp * 10 + int(ch) - 48;
return tmp;
}
void add(int x,int y,int num)
{
tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=num;
tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x; c[tot]=num;
}
void dfs(int x,int f)
{
deep[x]=deep[f]+1;
for (int i=point[x];i;i=nxt[i]) {
if (v[i]==f) continue;
int t=c[i];
e[t].x=x; e[t].y=v[i];
fa[v[i]]=x;
dfs(v[i],x);
mark[v[i]]=c[i];
}
}
int find(int x)
{
if (f[x]==x) return x;
f[x]=find(f[x]);
return f[x];
}
void change(int x,int y,int opt)
{
x=find(x); y=find(y);
while (x!=y) {
if (deep[x]<deep[y]) swap(x,y);
if (!pd[x]) f[x]=f[fa[x]],pd[x]=opt;
x=f[x];
}
}
void solve(int x,int y,int opt)
{
x=find(x); y=find(y);
while (x!=y) {
if (deep[x]<deep[y]) swap(x,y);
if (pd[x]==opt) f[x]=f[fa[x]];
x=fa[x];
}
}
int main()
{
freopen("a.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=1;i<n;i++) {
int x,y; x=read(); y=read();
add(x,y,i);
}
dfs(1,0);
for (int i=1;i<=n;i++) f[i]=i;
for (int i=1;i<=m;i++) {
scanf("%d%d",&p[i].opt,&p[i].x);
if (p[i].opt==2) scanf("%d",&p[i].y),change(p[i].x,p[i].y,i);
}
for (int i=1;i<=n;i++) f[i]=i;
for (int i=1;i<=n;i++) pd1[i]=pd[i];
for (int i=2;i<=n;i++)
if (!pd1[i]) {
int t=mark[i];
int r1=find(e[t].x); int r2=find(e[t].y);
f[r2]=r1;
}
int cnt=0;
for (int i=m;i>=1;i--) {
if (p[i].opt==1) {
int r1=find(p[i].x);
ans[++cnt]=mark[r1];
}
else solve(p[i].x,p[i].y,i);
}
for (int i=cnt;i>=1;i--) printf("%d\n",ans[i]);
}