仓鼠找 sugar
题目描述
小仓鼠的和他的基(mei)友(zi)sugar住在地下洞穴中,每个节点的编号为 1 1 1~ n n n。地下洞穴是一个树形结构。这一天小仓鼠打算从从他的卧室( a a a)到餐厅( b b b),而他的基友同时要从他的卧室( c c c)到图书馆( d d d)。他们都会走最短路径。现在小仓鼠希望知道,有没有可能在某个地方,可以碰到他的基友?
小仓鼠那么弱,还要天天被 zzq 大爷虐,请你快来救救他吧!
输入格式
第一行两个正整数 n n n 和 q q q,表示这棵树节点的个数和询问的个数。
接下来 n − 1 n-1 n−1 行,每行两个正整数 u u u 和 v v v,表示节点 u u u 到节点 v v v 之间有一条边。
接下来 q q q 行,每行四个正整数 a 、 b 、 c a、b、c a、b、c 和 d d d,表示节点编号,也就是一次询问,其意义如上。
输出格式
对于每个询问,如果有公共点,输出大写字母 “Y”;否则输出 “N”。
样例 #1
样例输入 #1
5 5
2 5
4 2
1 3
1 4
5 1 5 1
2 2 1 4
4 1 3 4
3 1 1 5
3 5 1 4
样例输出 #1
Y
N
Y
Y
Y
提示
20 20 20%的数据 n < = 200 , q < = 200 n<=200,q<=200 n<=200,q<=200
40 40 40%的数据 n < = 2000 , q < = 2000 n<=2000,q<=2000 n<=2000,q<=2000
70 70 70%的数据 n < = 50000 , q < = 50000 n<=50000,q<=50000 n<=50000,q<=50000
100 100 100%的数据 n < = 100000 , q < = 100000 n<=100000,q<=100000 n<=100000,q<=100000
笨蒟蒻的个人题解
先解释一下题意:
就是从树上两个起点分别到两个终点之间走,是否有重合的可能?
OK,咱们翻译一下这是啥意思啊,就是树上两条路径是否有相交。
就拿样例来说吧:
用笔画一下两条路径看看就OK。(不多说了
笨蒟蒻初次看到这题是这样想的,首先找两个起点的最近公共祖先 u u u,然后判断 u u u 是否分别在两条路径上。(肥肠煎蛋
然后就 WA 了hhh。
if((dist(a, b) == dist(a, u) + dist(u, b)) && (dist(c, d) == dist(c, u) + dist(u, d))) cout << "Y" << endl;
else cout << "N" << endl;
就这样想的,明显有问题啊。
然后我注意到了那个第二个 3514 3514 3514 的样例,总觉得这不对啊,这样例只是凑巧,换一些就不成立了。
然后我换了种思路,我这么想啊,就是我先找一条路径的结点最浅的点,然后判断其是否在另一条路径上,似乎没啥问题啊,于是我改成了这样的写法。
if((dist(a, b) == dist(a, u) + dist(u, b)) && (dist(c, d) == dist(c, u) + dist(u, d)))
{
cout << "Y" << endl;
continue;
}
if(depth[a] > depth[b]) swap(a, b);
if(dist(a, u) == dist(a, b) + dist(b, u))
{
if((dist(c, u) == dist(c, b) + dist(b, u)) || dist(d, u) == dist(d, b) + dist(b, u))
{
cout << "Y" << endl;
continue;
}
}
if(depth[c] > depth[d]) swap(c, d);
if(dist(c, u) == dist(c, d) + dist(d, u))
{
if((dist(a, u) == dist(a, d) + dist(d, u)) || dist(b, u) == dist(b, d) + dist(d, u))
{
cout << "Y" << endl;
continue;
}
}
cout << "N" << endl;
果不其然又 WA 了 hhh,然后想了十几分钟,终于知道怎么写了啊。
这题得要考虑一下这种情况。
当时直播写题的时候也是画的这个图。图中两条线我已经标记好了,请问这种如果用上面的方法,就是错误的。所以我们不能这么做。那正确的做法应该是啥呢?我首先想到是先找一条路径的最近公共祖先 x 1 x1 x1,然后判断其是否在另一条路径上,另外那个也是这样操作。
凭啥这样可以?当时我也是突发奇想,不管了,只是抱着试试猜想的态度,可一不小心就 ac 呀。
那现在我来仔细分析一下。首先因为这是一棵树,它只有一个父节点,换句话说,一个结点向上走的路径有且只有一条。
a
−
>
b
a->b
a−>b 有可能有好多条路径,当然你可以往上继续爬啊,但是有这些结点是必须要走的
a
−
>
u
,
u
−
>
b
a->u,u->b
a−>u,u−>b,可毕竟是最短路径啊,所以只有一条。那么这个最近公共祖先就尤为的重要,它是最浅的点,过了这个点就会向下拐而
a
−
>
c
a->c
a−>c 或者
b
−
>
c
b->c
b−>c 也会存在一个点
v
v
v,假如说这个
d
e
p
t
h
[
u
]
>
d
e
p
t
h
[
v
]
depth[u] > depth[v]
depth[u]>depth[v] 那么就一定有解了,而且
u
u
u 的孙子结点一定有
v
v
v,因为它在路径上,正如我之前说的,到了那个点会往下拐。
关键代码如下:
int x1 = lca(a, b), x2 = lca(c, d);
if((dist(c, x2) == dist(c, x1) + dist(x1, x2)) || (dist(d, x2) == dist(d, x1) + dist(x1, x2)))
{
cout << "Y" << endl;
continue;
}
else if((dist(a, x1) == dist(x2, x1) + dist(a, x2)) || (dist(b, x1) == dist(b, x2) + dist(x2, x1)))
{
cout << "Y" << endl;
continue;
}
cout << "N" << endl;
那么怎么找 lca 呢?这里我们可以用倍增,倍增是一种在线的做法,先预处理一下所有结点的深度
d
e
p
t
h
depth
depth,以及到根节点的距离
d
i
s
t
dist
dist,还有每个点跳
2
k
2^k
2k 的祖宗结点记为
f
a
fa
fa,最后求 lca 的时候记得从后往前跳,用的思想是二进制拼凑,最后的两点间距离就是dist[a] + dist[b] - 2 * dist[lca(a, b)]
。时间复杂度
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)。
做个小总结吧:
— > —> —> 要询问树上两条路径是否相交:
第一步:分别找两条路径的最近公共祖先结点
x
1
,
x
2
x1, x2
x1,x2;
第二步:分别判断他们各自的一个最近公共祖先结点是否在另一条路径上。
献上笨蒟蒻当时直播写的一个多小时的代码。
AC_Code
#include<iostream>
#include<algorithm>
#include<string>
#include<string.h>
#include<vector>
#include<cmath>
#include<map>
using namespace std;
#define int long long
#define fi first
#define se second
#define inf 0x3f3f3f3f
typedef pair<int, int>PII;
const int N = 1e5 + 10, M = N * 2;
int n, m;
int e[M], ne[M], h[N], idx;
int depth[N], q[N], fa[N][18];
int dis[N];
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void bfs(int u)
{
memset(depth, inf, sizeof(depth));
depth[0] = 0, depth[u] = 1, q[0] = 1;
int hh = 0, tt = 0;
while(hh <= tt)
{
int t = q[hh++];
for(int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if(depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
q[++tt] = j;
fa[j][0] = t;
for(int k = 1; k < 18; k++)
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
int lca(int a, int b)
{
if(depth[a] < depth[b]) return lca(b, a);
for(int k = 17; k >= 0; k--)
if(depth[fa[a][k]] >= depth[b]) a = fa[a][k];
if(a == b) return a;
// cout << a << ' ' << b << ',';
for(int k = 17; k >= 0; k--)
if(fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}
return fa[a][0];
}
int dist(int a, int b)
{
return depth[a] + depth[b] - 2 * depth[lca(a, b)];
}
void solve()
{
cin >> n >> m;
memset(h, -1, sizeof(h));
for(int i = 1; i < n; i++)
{
int a, b;
cin >> a >> b;
add(a, b), add(b, a);
}
bfs(1);
while(m--)
{
int a, b, c, d;
cin >> a >> b >> c >> d;
int x1 = lca(a, b), x2 = lca(c, d);
if((dist(c, x2) == dist(c, x1) + dist(x1, x2)) || (dist(d, x2) == dist(d, x1) + dist(x1, x2)))
{
cout << "Y" << endl;
continue;
}
else if((dist(a, x1) == dist(x2, x1) + dist(a, x2)) || (dist(b, x1) == dist(b, x2) + dist(x2, x1)))
{
cout << "Y" << endl;
continue;
}
cout << "N" << endl;
}
}
signed main()
{
int T = 1;
// cin >> T;
while(T--)
solve();
return 0;
}