E. Count Paths:
题目大意:
思路解析:
根据题目中定义的美丽路径,我们可以发现路径只有两种情况:
当前结点作为起始结点,那我们只需要知道它的子树下有多少个相同颜色的结点,并且相同颜色的结点会被祖先结点阻挡,那我们只需要统计每个子树拥有x颜色的结点有多少个,记录时排除阻挡造成的影响,这样就可以完成对当前结点作为起始结点的答案统计
当前结点中间结点,那么答案只可能出现他的两个子树之间,我们发现上一步我们需要求所以子树当前结点的颜色共有多少个结点,即发现我们需要对这个子树的颜色统计进行合并,并且在合并过程中对此情况下答案进行统计。
这样就可以把这两种情况的答案统计完毕,
代码实现:
import java.io.*;
import java.math.BigInteger;
import java.util.*;
public class Main {
static long inf = (long) 2e18;
public static void main(String[] args) throws IOException {
int t = f.nextInt();
while (t > 0) {
solve();
t--;
}
w.flush();
w.close();
br.close();
}
static long ans = 0;
static Vector<Integer>[] g;
static int[] c;
static HashMap<Integer, Integer>[] cnt;
public static void solve() {
ans = 0;
int n = f.nextInt();
g = new Vector[n];
cnt = new HashMap[n];
for (int i = 0; i < n; i++) {
g[i] = new Vector<>();
cnt[i] = new HashMap<>();
}
c = new int[n];
for (int i = 0; i < n; i++) {
c[i] = f.nextInt();
}
for (int i = 0; i < n - 1; i++) {
int x = f.nextInt() - 1; int y = f.nextInt() - 1;
g[x].add(y); g[y].add(x);
}
dfs(0, -1);
w.println(ans);
}
public static void dfs(int x, int fa){
int bst = -1;
for (int i = 0; i < g[x].size(); i++) {
int y = g[x].get(i);
if (y == fa) continue;
dfs(y, x);
if (bst == -1 || cnt[bst].size() < cnt[y].size()) bst = y;
}
if (bst != -1) cnt[x] = cnt[bst];
for (int i = 0; i < g[x].size(); i++){
int y = g[x].get(i);
if (y == fa) continue;
if (bst == y) continue;
for (Integer a : cnt[y].keySet()) {
if (a != c[x]) ans += (long) cnt[x].getOrDefault(a, 0) * cnt[y].get(a);
cnt[x].put(a, cnt[x].getOrDefault(a, 0) + cnt[y].get(a));
}
}
ans += cnt[x].getOrDefault(c[x], 0);
cnt[x].put(c[x], 1);
}
static PrintWriter w = new PrintWriter(new OutputStreamWriter(System.out));
static Input f = new Input(System.in);
static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
static class Input {
public BufferedReader reader;
public StringTokenizer tokenizer;
public Input(InputStream stream) {
reader = new BufferedReader(new InputStreamReader(stream), 32768);
tokenizer = null;
}
public String next() {
while (tokenizer == null || !tokenizer.hasMoreTokens()) {
try {
tokenizer = new StringTokenizer(reader.readLine());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return tokenizer.nextToken();
}
public String nextLine() {
String str = null;
try {
str = reader.readLine();
} catch (IOException e) {
// TODO 自动生成的 catch 块
e.printStackTrace();
}
return str;
}
public int nextInt() {
return Integer.parseInt(next());
}
public long nextLong() {
return Long.parseLong(next());
}
public Double nextDouble() {
return Double.parseDouble(next());
}
public BigInteger nextBigInteger() {
return new BigInteger(next());
}
}
}