Problem Description
wls有一棵有根树,其中的点从1到n标号,其中1是树根。每次wls可以执行两种操作中的一个:
(1)选定一个点x,将以x为根的子树变成一条按照编号排序的链,其中编号最大的作为新的子树的根(成为原来x的父亲节点的儿子,如果原来x没有父亲节点则新的子树的根也没有父亲节点)。
(2)查询两个点之间的最短路径上经过了多少边。
Input
第一行一个整数t表示数据组数 ( t ≤ 10 ) (t\le10) (t≤10)。
每组数据第一行一个正整数n表示树上的点数 ( 1 ≤ n ≤ 100000 ) (1\le n\le 100000) (1≤n≤100000)。
接下来n-1行每行两个1到n之间的正整数表示一条树边。
接下来一行一个正整数q表示询问的个数 ( 1 ≤ q ≤ 200000 ) (1\le q\le 200000) (1≤q≤200000)。
接下来q行每行表示一个操作。第一种操作格式为 1 x 1\ x 1 x,其中x为指定的树根。第二种操作格式为 2 x y 2\ x\ y 2 x y,表示查询从x到y的路径。
Output
对于每个第二种操作,输出一行一个正整数表示答案。
解题思路:
对于操作1,可以看成是缩点操作,并且因为缩完之后是按编号排序的链,所以对于一个结点u,查询它在这整个子树有多少个比它大的编号,就可以得到它距离链的头部的距离。
所以查询的时候有三种情况:
- x,y都被缩过点,此时分为此时判断一下x,y是否属于同一条链。如果属于同一条链就求它们在链内的距离,否则用lca求两个子树的根的距离,再加上它们分别在链内距离链头的距离。
- 一个被缩过,求lca得到链头和另一个点的距离,然后加上缩过的点在链内的距离。
- 都没被缩过,直接求lca可以得到距离。
ac代码:
#include<bits/stdc++.h>
#define ll long long
#define mid ((l+r)>>1)
using namespace std;
const int maxn = 2e5 + 50;
int T[maxn], lc[maxn*20], rc[maxn*20], sum[maxn*20];//
int fa[maxn], f[maxn][20], dep[maxn];
int dfn[maxn];
int id[maxn], sz[maxn];
int idx;
vector<int> g[maxn];
int fnd(int x)
{
if(x == fa[x]) return x;
return fa[x] = fnd(fa[x]);
}
int n, q;
int lca(int u, int v){
if(dep[u] > dep[v]) swap(u, v);
int d = dep[v] - dep[u];
for(int i = 19; i >= 0; --i) if(d&(1<<i)) v = f[v][i];
if(u == v) return u;
for(int i = 19; i >= 0; --i){
if(f[u][i] == f[v][i]) continue;
u = f[u][i], v = f[v][i];
}
return f[u][0];
}
int dis(int u, int v){
return dep[u] + dep[v] - 2*dep[lca(u, v)];
}
void dfs(int u, int father, int d){//获取子树大小,结点深度,dfs序列,结点dfs的编号,父节点倍增数组
sz[u] = 1; dep[u] = d;
dfn[++idx] = u; id[u] = idx;
f[u][0] = father;
for(int i = 1; i < 20; ++i) f[u][i] = f[f[u][i-1]][i-1];
for(int i = 0; i < g[u].size(); ++i){
int v = g[u][i];
if(v == father) continue;
dfs(v, u, d + 1);
sz[u] += sz[v];
}
return;
}
void init(){
scanf("%d", &n);
for(int i = 1; i <= n; ++i) fa[i] = 0, g[i].clear();
for(int i = 1; i < n; ++i){
int u, v;scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
idx = 0;
dfs(1, 0, 1);
}
void go(int u, int dio){//把x子树的点全缩到kono Dio哒!
if(fa[u]){
fa[fnd(fa[u])] = dio; return;
}
fa[u] = dio;
for(int i = 0; i < g[u].size(); ++i){
int v = g[u][i];
if(v == f[u][0]) continue;
go(v, dio);
}
}
int tot;
void build(int pre, int &cur, int l, int r, int pos)
{
cur = ++tot;
lc[cur] = lc[pre]; rc[cur] = rc[pre]; sum[cur] = sum[pre]+1;
if(l == r) return;
if(pos <= mid)
build(lc[pre], lc[cur], l, mid, pos);
else
build(rc[pre], rc[cur], mid+1, r, pos);
}
int qry(int pre, int cur, int l, int r, int pos){//找区间内比pos大的数有几个
//cout<<"l:"<<l<<" r:"<<r<<" sz:"<<sum[cur]-sum[pre]<<endl;
if(l > pos)
return sum[cur] - sum[pre];
if(pos >= r || l == r)
return 0;
int res = qry(rc[pre], rc[cur], mid+1, r, pos);
if(pos < mid) res += qry(lc[pre], lc[cur], l, mid, pos);
return res;
}
int opt2(int u, int v)
{
if(fa[u] && fa[v]){//两个都被缩过
int ru = fnd(u), rv = fnd(v);
if(ru == rv){
int d1 = qry(T[id[ru]-1], T[id[ru]+sz[ru]-1], 1, n, u);
int d2 = qry(T[id[ru]-1], T[id[ru]+sz[ru]-1], 1, n, v);
return abs(d1 - d2);
}
int d = dis(ru, rv);
int d1 = qry(T[id[ru]-1], T[id[ru]+sz[ru]-1], 1, n, u);
int d2 = qry(T[id[rv]-1], T[id[rv]+sz[rv]-1], 1, n, v);
return d+d1+d2;
}
else if(!fa[u] && !fa[v]){
return dis(u, v);
}
else{
if(!fa[u]) swap(u, v);//u被缩过,v没有被缩过
int ru = fnd(u);
int d = dis(ru, v);
int d1 = qry(T[id[ru]-1], T[id[ru]+sz[ru]-1], 1, n, u);
return d+d1;
}
}
void sol()
{
scanf("%d", &q);
tot = 0;
for(int i = 1; i <= n; ++i)
build(T[i-1], T[i], 1, n, dfn[i]);
while(q--){
int op;
scanf("%d", &op);
if(op == 1){
int x; scanf("%d", &x);
if(fa[x]) continue;//缩过点了
go(x, x);
}
else {
int u, v; scanf("%d%d", &u, &v);
int ans = opt2(u, v);
printf("%d\n", ans);
}
}
}
int main()
{
int T; cin>>T;
while(T--){
init();sol();
}
}
/*
1
9
1 2
1 7
2 4
2 5
7 8
7 6
8 9
6 3
*/