wls有一棵有根树,其中的点从1到n标号,其中1是树根。每次wls可以执行两种操作中的一个:
(1)选定一个点x,将以x为根的子树变成一条按照编号排序的链,其中编号最大的作为新的子树的根(成为原来x的父亲节点的儿子,如果原来x没有父亲节点则新的子树的根也没有父亲节点)。
(2)查询两个点之间的最短路径上经过了多少边。
对每一个点都建一个线段树。
对于操作1
将x和其子树所有点进行合并。
对于操作2的查询
如果两个点都没有被拉成链,则直接求ans=dis(u)+dis(v)-dis(lca(u,v))*2。
如果两个点在同一个链上,则通过线段树求出在链上2点的距离(即U,V间有多少个数)。
如果不再同一个链上,则ans=dis(U的链头)+dis(V的链头)-lca(U的链头,V的链头)*2+U到链头的距离(即大于U的点的个数)+V到链头的距离(即大于V的点的个数)
#include<map>
#include<stack>
#include<queue>
#include<cstdio>
#include<algorithm>
#include<vector>
#include <assert.h>
#include<cstring>
#include<cmath>
#include<iostream>
#include<string>
#include<bitset>
using namespace std;
typedef long long ll;
#define mid ((l+r)>>1)
const int N = 200020;
int rt[N * 40], sum[N * 40], ls[N * 40], rs[N * 40], fa[N][20], depth[N], flag[N];
int fb[N];
vector<int>e[N];
int n;
int tot;
int newNode() {
tot++;
ls[tot] = rs[tot] = sum[tot] = 0;
return tot;
}
int build(int &p, int l, int r, int x) {
p = newNode();
if (l == r) {
sum[p] = 1;
return p;
}
if (x <= mid)
build(ls[p], l, mid, x);
else
build(rs[p], mid + 1, r, x);
sum[p] = 1;
}
void dfs(int p, int f, int dep) {
flag[p] = 0;
build(rt[p], 1, n, p);
fa[p][0] = f;
depth[p] = dep;
for (int i = 0; i < e[p].size(); i++) {
int v = e[p][i];
if (v != f) {
dfs(v, p, dep + 1);
}
}
}
int Union(int u, int v, int l, int r) {
if (u == 0 || v == 0)return u + v;
int p = newNode();
if (l == r) {
sum[p] = sum[u] + sum[v];
return p;
}
ls[p] = Union(ls[u], ls[v], l, mid);
rs[p] = Union(rs[u], rs[v], mid + 1, r);
sum[p] = sum[rs[p]] + sum[ls[p]];
return p;
}
int find(int x) {
if (x == fb[x])return x;
return fb[x] = find(fb[x]);
}
void dfs1(int p, int fg) {
if (flag[p]) {
fb[p] = fg;
return;
}
fb[p] = fg;
flag[p] = 1;
for (int i = 0; i < e[p].size(); i++) {
int v = e[p][i];
if (v != fa[p][0]) {
dfs1(v, fg);
rt[p] = Union(rt[p], rt[v], 1, n);
}
}
}
int getlca(int x, int y) {
if (depth[x] < depth[y]) {
swap(x, y);
}
for (int i = 17; i >= 0; i--) {
if ((1 << i) <= depth[x] - depth[y]) {
x = fa[x][i];
}
}
if (x == y)return x;
for (int i = 17; i >= 0; i--) {
if (fa[x][i] != fa[y][i]) {
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
int query(int p, int l, int r, int x, int y) {
//if (p == 0)return 0;
if (l == x && y == r) {
return sum[p];
}
if (y <= mid) {
return query(ls[p], l, mid, x, y);
}
else if (x > mid) {
return query(rs[p], mid + 1, r, x, y);
}
else {
return query(ls[p], l, mid, x, mid) + query(rs[p], mid + 1, r, mid + 1, y);
}
}
int main()
{
int u, v;
int t;
scanf("%d", &t);
while (t--) {
tot = 0;
scanf("%d", &n);
for (int i = 1; i < n; i++) {
scanf("%d%d", &u, &v);
e[u].push_back(v);
e[v].push_back(u);
}
for (int i = 1; i <= n; i++)fb[i] = i;
memset(fa, 0, sizeof(fa));
dfs(1, 0, 0);
for (int i = 1; i <= 17; i++) {
for (int j = 1; j <= n; j++) {
fa[j][i] = fa[fa[j][i - 1]][i - 1];
}
}
int q, f;
scanf("%d", &q);
for (int i = 0; i < q; i++) {
scanf("%d", &f);
if (f == 1) {
scanf("%d", &u);
if (!flag[u])dfs1(u, u);
}
else {
scanf("%d%d", &u, &v);
int x = find(u);
int y = find(v);
int ans = 0;
if (x == y) {
if (u < v)swap(u, v);
ans = query(rt[x], 1, n, v, u) - 1;
}
else {
int lca = getlca(x, y);
ans = depth[x] + depth[y] - depth[lca] * 2;
ans += sum[rt[x]] - query(rt[x], 1, n, 1, u) + sum[rt[y]] - query(rt[y], 1, n, 1, v);
}
printf("%d\n", ans);
}
}
for (int i = 1; i <= n; i++)e[i].clear();
}
return 0;
}