public class P928_MinimizeMalwareSpreadII {
private Set<Integer> visit = new HashSet<>();
Map<Integer, Set<Integer>> map = new HashMap<>();
boolean[] record;
public int minMalwareSpread(int[][] graph, int[] initial) {
Arrays.sort(initial);
record = new boolean[graph.length];
for (int i = 0; i < initial.length; i++) {
record[initial[i]] = true;
}
for (int i = 0; i < graph.length; i++) {
map.put(i, new HashSet<>());
}
for (int i = 0; i < graph.length; i++) {
for (int j = 0; j < graph[0].length; j++) {
if (graph[i][j] == 1 && i != j) {
map.get(i).add(j);
map.get(j).add(i);
}
}
}
int max = -1;
int ret = -1;
for (int i = 0; i < initial.length; i++) {
Set<Integer> sons = map.get(initial[i]);
int totalEffect = 0;
visit.clear();
visit.add(initial[i]);
for (Integer e : sons) {
int tmp = dfs(e);
// 遇到有感染的节点,那么整个分支都舍弃
if (tmp < 0) {
continue;
}
// 累加没有感染节点的union的数量
totalEffect += tmp;
}
if (totalEffect > max) {
max = totalEffect;
ret = initial[i];
}
}
return ret;
}
// 计算x群组未被访问的节点总量;如果有感染的返回-1;如果被访问过了返回0;
private int dfs(int x) {
if (!visit.add(x)) {
return 0;
}
if (record[x]) {
return -1;
}
int ret = 1;
Set<Integer> sons = map.get(x);
for (Integer e : sons) {
int tmp = dfs(e);
if (tmp < 0) {
// 关键:可以加快收敛速度,记录间接感染的节点;这句不是必须的,但是可以提高运行速度。
record[x] = true;
return -1;
}
ret += tmp;
}
return ret;
}
}