《算法竞赛·快冲300题》将于2024年出版,是《算法竞赛》的辅助练习册。
所有题目放在自建的OJ New Online Judge。
用C/C++、Java、Python三种语言给出代码,以中低档题为主,适合入门、进阶。
“ 01树” ,链接: http://oj.ecustacm.cn/problem.php?id=1715
题目描述
【题目描述】 现在给你一个n个节点的树,而且每个节点有一个权值为0或者1。
现在有m次询问,每次询问输入两个节点x和y,以及一个权值k。
请你判断x和y的路径中是否存在权值为k的点。(包括x和y本身)
【输入格式】 输入第一行为两个正整数n和m,均为不超过10^5次方的正整数。
第二行是一个长度为n的01字符串,表示从节点1到节点n的权值。
接下来n-1行,每行两个数字u和v,表示节点u和v之间存在边。
接下来m行,每行输入三个数字x,y,k。其中x,y不相同,k为0或者1。 。
【输出格式】 对于每一次询问,如果x和y的路径中包含权值为k的点,输出Yes,否则输出No 。
【输入样例】
5 5
11010
1 2
2 3
2 4
1 5
1 4 1
1 4 0
1 3 0
1 3 1
5 5 1
【输出样例】
Yes
No
Yes
Yes
No
题解
本题简单的做法是先建树,然后每次查询用DFS搜索路径。任意两点之间有且只有一条路径,做一次DFS能找到这条路径,计算量O(n)。一共做m次查询,总复杂度O(mn),超时。
不过,本题特殊在于每个点的权值是0或1,查询也是查有没有等于0或1的点。查询一条路径时,如果能确定所有点都是1,或所有点都是0,或有0有1,那么就得到了答案。
把所有点按0和1分成多个子集,其中一些连通的1是一个子集,一些连通的0是一个子集。最后把整棵树分成很多权值为1的子集、权值为0的子集。权值为0的子集和权值为1的子集相邻。
对一个查询“x,y,k”:
(1)如果{x,y}属于一个子集,它们必然连通,且权值相同,权值为0或1。
(2)如果{x,y}不属于一个子集,它们要么是相邻的两个不同权值的子集,要么它们之间的路径穿过了一个不同权值的子集,两种情况下的路径上有1也有0。
以上讨论的实际上是并查集的操作。下面用带路径压缩的并查集编码,一次查询约为O(1),m次查询的总复杂度约为O(m)。。
【笔记】 。
C++代码
#include<bits/stdc++.h>
using namespace std;
char str[100010];
int s[100010]; //并查集
int find_set(int x){ //查询并查集,返回x的根
if(x != s[x]) s[x] = find_set(s[x]); //路径压缩
return s[x];
}
void merge_set(int x, int y){ //合并
x = find_set(x); y = find_set(y);
if(x != y) s[x] = s[y]; //把x合并到y上,y的根成为x的根
}
int main(){
int n, m;
scanf("%d %d",&n,&m);
scanf("%s",str+1);
for(int i = 1; i <= n; i++) s[i] = i; //并查集初始化
for(int i = 1; i < n; i++){
int u, v; scanf("%d %d",&u,&v);
if(str[u] == str[v]) merge_set(u,v); //合并
}
for(int i = 1; i <= m; i++){
int x, y; char k; scanf("%d %d %c",&x,&y,&k);
if(find_set(x) == find_set(y) && str[x] != k) //属于同一个子集,且权值不等于k
puts("No"); //比cout快
else //其他情况,既有0也有1
puts("Yes"); //比cout快
}
return 0;
}
Java代码
import java.util.Scanner;
public class Main {
static char[] str = new char[100010];
static int[] s = new int[100010];
static int findSet(int x) {
if (x != s[x]) s[x] = findSet(s[x]);
return s[x];
}
static void mergeSet(int x, int y) {
x = findSet(x);
y = findSet(y);
if (x != y) s[x] = s[y];
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int m = sc.nextInt();
String strInput = sc.next();
strInput.getChars(0, strInput.length(), str, 1);
for (int i = 1; i <= n; i++) s[i] = i;
for (int i = 1; i < n; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
if (str[u] == str[v]) mergeSet(u, v);
}
for (int i = 1; i <= m; i++) {
int x = sc.nextInt();
int y = sc.nextInt();
char k = sc.next().charAt(0);
if (findSet(x) == findSet(y) && str[x] != k) System.out.println("No");
else System.out.println("Yes");
}
}
}
Python代码
import sys
sys.setrecursionlimit(1000000) #注意要扩栈
str = [0] * 100010
s = [0] * 100010
def find_set(x):
if x != s[x]: s[x] = find_set(s[x])
return s[x]
def merge_set(x, y):
x = find_set(x)
y = find_set(y)
if x != y: s[x] = s[y]
n, m = map(int, input().split())
str[1:n+1] = input()
for i in range(1, n+1): s[i] = i
for i in range(n-1):
u, v = map(int, input().split())
if str[u] == str[v]: merge_set(u, v)
for i in range(m):
x, y, k = input().split()
x = int(x)
y = int(y)
if find_set(x) == find_set(y) and str[x] != k: print("No")
else: print("Yes")