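// 倍增理解 (Understanding binary lifting / the doubling technique):
// https://blog.csdn.net/jarjingx/article/details/8180560
// 倍增求LCA (Computing the lowest common ancestor with binary lifting):
// https://blog.csdn.net/wjh2622075127/article/details/81060586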
import java.util.*;
public class Main {
    static final int N = (int)1e5 + 5;
    /** Largest jump exponent: 2^17 = 131072 >= N, so jumps of 2^0..2^17 cover any depth. */
    static final int LOG = 17;
    static int n, m, idx;
    // Adjacency list (array-based linked lists): h[u] is the first edge index out of u, -1 if none.
    static int[] h = new int[N];
    // ne[i] is the next edge after edge i in the same list.
    static int[] ne = new int[N*2];
    // e[i] is the target node of edge i.
    static int[] e = new int[N*2];
    // Parent of each node as read from the input; -1 if the node has no parent.
    static int[] from = new int[N*2];
    // BFS queue.
    static int[] q = new int[N];
    // Depth of each node (root = 1; node 0 is a sentinel with depth 0).
    static int[] depth = new int[N];
    // fa[v][k] is the node reached by jumping 2^k levels up from v (0 once past the root).
    static int[][] fa = new int[N][LOG + 1];

    /** Adds the directed edge a -> b (a is the parent of b). */
    static void add(int a, int b) {
        e[idx] = b;
        ne[idx] = h[a];
        h[a] = idx ++;
    }

    /** Clears the graph so a new tree can be loaded. */
    static void init() {
        Arrays.fill(h, -1);
        Arrays.fill(from, -1);
        // Restart edge storage; otherwise repeated runs keep growing idx.
        idx = 0;
    }

    /**
     * Runs a BFS from {@code root}, computing every reachable node's depth and its binary-lifting
     * table {@code fa}.
     *
     * @param root the tree root; it gets depth 1 and no ancestors
     */
    static void bfs(int root) {
        Arrays.fill(depth, 0x3f3f3f3f);
        // Node 0 stands for "above the root": depth 0 makes any jump past the root fail the
        // depth checks in lca().
        depth[0] = 0;
        depth[root] = 1;
        // The root has no ancestors; clear anything left over from a previous run.
        Arrays.fill(fa[root], 0);
        int hh = 0;
        int tt = 0;
        q[hh] = root;
        while (hh <= tt) {
            int t = q[hh ++];
            for (int i = h[t]; i != -1; i = ne[i]) {
                int j = e[i];
                if (depth[j] > depth[t] + 1) {
                    depth[j] = depth[t] + 1;
                    q[++ tt] = j;
                    fa[j][0] = t;
                    // Every ancestor of j was dequeued earlier, so its table is already complete:
                    // a 2^k jump is two consecutive 2^(k-1) jumps.
                    for (int k = 1; k <= LOG; ++ k) {
                        fa[j][k] = fa[fa[j][k-1]][k-1];
                    }
                }
            }
        }
    }

    /**
     * Returns the lowest common ancestor of {@code a} and {@code b}. A node counts as its own
     * ancestor.
     *
     * <p>Assumes both nodes were reached by the last {@link #bfs(int)} call.
     */
    static int lca(int a, int b) {
        if (depth[a] < depth[b]) {
            int t = a;
            a = b;
            b = t;
        }
        // Lift the deeper node a up to b's depth, taking the largest jumps first.
        for (int k = LOG; k >= 0; -- k) {
            if (depth[fa[a][k]] >= depth[b]) {
                a = fa[a][k];
            }
        }
        if (a == b) return a;
        // Now at equal depth: lift both while their 2^k-th ancestors differ. They end as
        // distinct children of the LCA. The ancestors themselves must be compared; their
        // depths are always equal here.
        for (int k = LOG; k >= 0; -- k) {
            if (fa[a][k] != fa[b][k]) {
                a = fa[a][k];
                b = fa[b][k];
            }
        }
        return fa[a][0];
    }

    /** Returns the first node in 1..n that has no recorded parent, or 0 if none exists. */
    private static int findRoot() {
        for (int i = 1; i <= n; ++i) {
            if (from[i] == -1) {
                return i;
            }
        }
        return 0;
    }

    /**
     * Reads {@code n m}, then n-1 edges {@code a b} (a is the parent of b), then m queries
     * {@code a b}. For each query it prints YES when a is an ancestor of b (or a == b),
     * otherwise NO.
     */
    public static void main(String[] args) {
        Scanner cin = new Scanner(System.in);
        init();
        n = cin.nextInt();
        m = cin.nextInt();
        for (int i = 0; i < n - 1; ++i) {
            int a = cin.nextInt();
            int b = cin.nextInt();
            add(a, b);
            from[b] = a;
        }
        bfs(findRoot());
        // Buffer answers and print once instead of issuing one println per query.
        String nl = System.lineSeparator();
        StringBuilder out = new StringBuilder();
        while (m-- > 0) {
            int a = cin.nextInt();
            int b = cin.nextInt();
            out.append(lca(a, b) == a ? "YES" : "NO").append(nl);
        }
        System.out.print(out);
        System.out.flush();
    }
}