花了差不多两个小时才 AC 这道题，所以写一篇题解。
这道题很明显是在基环树上求 LCA：每个点恰好有一条出边，整张图是若干棵内向基环树组成的森林。
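对于基环树的题，常见做法是先把环"断开"：以环上的每个点为根，对挂在环上的每一棵树各搜索一次，求出深度和倍增祖先；查询时再把环考虑进来。具体分三种情况：两点所在的环不同，则无解，输出 -1 -1；两点在同一棵树中，直接求 LCA；两点在同一个环上、但属于不同的树，则两点先各自走到自己的根，再比较"x 的根沿环走到 y 的根"和"y 的根沿环走到 x 的根"两种方案。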
特别要注意题目对答案的限定条件（即代码中的比较顺序）：先使 max(x, y) 最小；若相同，再使 min(x, y) 最小；若仍相同，取 x ≥ y 的那一组。
参考代码
#include <cctype>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
// Reads one (optionally negative) decimal integer from stdin into x.
// Non-digit characters before the number are skipped; any '-' among them makes
// the result negative. If EOF is hit before a digit appears, x is left at 0
// instead of looping forever.
template<typename T>inline void read(T &x) {
    x = 0;
    bool neg = false;
    // Keep getchar()'s result as int: EOF must stay distinguishable, and isdigit()
    // is only defined for unsigned-char values and EOF (a plain char may be negative).
    int c = getchar();
    while (c != EOF && !isdigit(c)) {
        if (c == '-') neg = true;
        c = getchar();
    }
    while (c != EOF && isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    if (neg) x = -x;
}
int ss, buf[31];  // scratch digit stack for print(); ss is the current top (0 = empty)
// Writes the decimal representation of x to stdout, with no separator.
// For negative x the digits are peeled off while x is still negative, so the
// most negative value of T (e.g. INT_MIN) prints correctly instead of
// overflowing on `-x`. Integer division truncates toward zero, so x % 10 lies in (-10, 0].
template<typename T>inline void print(T x) {
    if (x < 0) {
        putchar('-');
        do { buf[++ss] = -int(x % 10); x /= 10; } while (x);
    } else {
        do { buf[++ss] = int(x % 10); x /= 10; } while (x);
    }
    while (ss) putchar(buf[ss--] + '0');
}
// Prints x, then a single space.
template<typename T>inline void write(T x) {
    print(x);
    putchar(' ');
}
// Prints x, then a newline.
template<typename T>inline void writeln(T x) {
    print(x);
    putchar('\n');
}
constexpr int N = 500000 + 10;  // array capacity (node count plus slack)

int n;  // number of nodes
int q;  // number of queries

bool vis[N];  // vis[u]: u has already been reached by Dfs
bool bk[N];   // bk[u]: u is on the current Dfs path; Dfs2 later reuses it as "already handled"

int d[N];       // d[u]: distance from u down to the cycle node at the root of its tree
int fa[N][23];  // fa[u][k]: 2^k-th ancestor of u; fa[u][0] starts as u's successor read from input

int tp;      // index of the top element of sta (0 = empty)
int sta[N];  // path stack used by Dfs

int cnt;  // number of cycles discovered so far
int id[N];  // id[u]: 1-based index of the cycle containing u, or 0 if u is not on a cycle
int ys[N];  // ys[u]: 1-based position of u within its cycle
// size[c]: number of nodes on cycle c.
// NOTE: with `using namespace std;` under C++17, unqualified `size` is ambiguous
// with std::size; refer to this array as `::size`.
int size[N];
void Dfs(int u) {
vis[u] = bk[u] = 1; sta[++tp] = u;
int v = fa[u][0];
if (bk[v]) {
int j = tp, z = 0; cnt++;
while(sta[j] != v) {
id[sta[j]] = cnt;
ys[sta[j]] = ++z;
j--;
}
id[sta[j]] = cnt;
ys[sta[j]] = ++z;
size[cnt] = z;
}
if (!vis[v] && !bk[v]) Dfs(v);
bk[u] = 0; sta[tp--] = 0;
}
int anc[N];
void Dfs2(int u) {
bk[u] = 1;
if (id[u]) {
anc[u] = u;
d[u] = 0;
fa[u][0] = 0;
return;
}
if (!bk[fa[u][0]]) Dfs2(fa[u][0]);
anc[u] = anc[fa[u][0]];
d[u] = d[fa[u][0]] + 1;
for (int i = 1; i <= 20; i++)
fa[u][i] = fa[fa[u][i-1]][i-1];
}
// Lowest common ancestor of x and y, using the depth array d[] and the lifting table fa[][].
// NOTE(review): only meaningful when x and y share a tree root; main checks anc[] first.
int Lca(int x, int y) {
    if (d[y] > d[x]) {
        const int tmp = x;
        x = y;
        y = tmp;
    }
    // Raise x to y's depth, taking each power-of-two jump that still fits in the remaining gap.
    int gap = d[x] - d[y];
    for (int k = 20; k >= 0; --k) {
        if (gap < (1 << k)) continue;
        x = fa[x][k];
        gap -= 1 << k;
    }
    if (x == y) return x;
    // Climb both while their ancestors differ; afterwards their shared parent is the LCA.
    for (int k = 20; k >= 0; --k) {
        const int ax = fa[x][k];
        const int ay = fa[y][k];
        if (ax == ay) continue;
        x = ax;
        y = ay;
    }
    return fa[x][0];
}
int main() {
cin >> n >> q;
for (int u = 1, v; u <= n; u++)
read(v), fa[u][0] = v;
for (int i = 1; i <= n; i++)
if (!vis[i]) Dfs(i);
for (int i = 1; i <= n; i++) Dfs2(i);
int x, y, s1, s2, p1, p2, ans1, ans2;
while (q--) {
read(x), read(y);
if (id[anc[x]] != id[anc[y]]) puts("-1 -1");
else if (anc[x] == anc[y]) {
int lca = Lca(x,y);
s1 = d[x] - d[lca], s2 = d[y] - d[lca];
write(s1); writeln(s2);
}
else {
s1 = d[x], s2 = d[y];
x = anc[x], y = anc[y];
bool bk = 0;
if (ys[x] < ys[y]) swap(x, y), swap(s1, s2), bk = 1;
p1 = ys[x] - ys[y] , p2 = size[id[y]] - p1;
if (bk) swap(p1, p2), swap(s1, s2);
if (max(s1+p1,s2) > max(s1,s2+p2))
ans1 = s1, ans2 = s2+p2;
else if(max(s1+p1,s2) == max(s1,s2+p2)) {
if (min(s1+p1,s2) < min(s1,s2+p2))
ans1 = s1+p1, ans2 = s2;
else if (min(s1+p1,s2) == min(s1,s2+p2)){
if (s1 + p1 >= s2) ans1 = s1+p1, ans2 = s2;
else ans1 = s1, ans2 = s2+p2;
}
else
ans1 = s1, ans2 = s2+p2;
}
else
ans1 = s1+p1, ans2 = s2;
write(ans1), writeln(ans2);
}
}
return 0;
}