题意: 给定一个含有n个节点的树(相邻两点间的距离为1),以及q个询问。
每一个询问给出5个整数:x,y,a,b,k。指如果像树中的x节点和y节点之间加一条边,问是否存在一条路径从a到b长度为k的路径?(注:每一次询问添加的边不会互相干扰,即只在该次询问有效)
思路:
- 首先我们看a与b之间的连通性,有三种有意义的最短路径L:
- a -> b,即不需要x与y之间的边a和b也能连通,距离记录为dist_ab。
- a -> x -> y -> b,记作dist_axyb。
- a -> y -> x -> b,记作dist_ayxb。
- 现在考虑最短路径与L的关系,显然若 L>k 是不合法的。那么对于 L <= k,我们可以考虑在某个(些)点反复走过,即找到一个z,使得 k == L+z*2;简而言之即L与k同奇偶性便是合法的。
- 而对于求(a,b)间的最短距离利用朴素的倍增LCA算法即可。在此感谢大佬博客1带给我的启发,尤其大佬博客2的讲解特别详细。
代码实现:
#include<bits/stdc++.h>
#define endl '\n'
#define null NULL
#define ll long long
#define int long long
#define pii pair<int, int>
#define lowbit(x) (x &(-x))
#define ls(x) x<<1
#define rs(x) (x<<1+1)
#define me(ar) memset(ar, 0, sizeof ar)
#define mem(ar,num) memset(ar, num, sizeof ar)
#define rp(i, n) for(int i = 0, i < n; i ++)
#define rep(i, a, n) for(int i = a; i <= n; i ++)
#define pre(i, n, a) for(int i = n; i >= a; i --)
#define IOS ios::sync_with_stdio(0); cin.tie(0);cout.tie(0);
const int way[4][2] = {{1, 0}, {-1, 0}, {0, 1}, {0, -1}};
using namespace std;
const int inf = 0x7fffffff;
const double PI = acos(-1.0);
const double eps = 1e-6;
const ll mod = 1e9 + 7;
const int N = 3e5 + 5;
int n, q;
int d[N], fa[N][22];
vector<int> g[N];
void dfs(int rt, int pre){
d[rt] = d[pre] + 1;
fa[rt][0] = pre;
for(int i = 1; i < 20; i ++) fa[rt][i] = fa[fa[rt][i-1]][i-1];
for(int i = 0; i < g[rt].size(); i ++){
if(g[rt][i] == pre) continue;
dfs(g[rt][i], rt);
}
}
int LCA(int x, int y){
if(d[x] < d[y]) swap(x, y);
for(int i = 19; ~i; i --)
if(d[x]-(1<<i) >= d[y]) x = fa[x][i];
if(x == y) return x;
for(int i = 19; ~i; i --){
if(fa[x][i] != fa[y][i]){
x = fa[x][i];
y = fa[y][i];
}
}
return fa[x][0];
}
int dist(int a, int b){
int c = LCA(a, b);
return d[a] + d[b] - 2ll*d[c];
}
signed main()
{
IOS;
cin >> n;
d[0] = -1;
for(int i = 1; i < n; i ++){
int u, v; cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0);
cin >> q;
while(q --){
int x, y, a, b, k;
cin >> x >> y >> a >> b >> k;
int dist_ab = dist(a, b);
int dist_axyb = dist(a, x) + 1 + dist(y, b);
int dist_ayxb = dist(a, y) + 1 + dist(x, b);
int ans = inf;
if(dist_ab%2 == k%2) ans = min(ans, dist_ab);
if(dist_axyb%2 == k%2) ans = min(ans, dist_axyb);
if(dist_ayxb%2 == k%2) ans = min(ans, dist_ayxb);
cout << (ans<=k ? "YES" : "NO") << endl;
}
return 0;
}