Union-Find

板子:

非递归写法(基于链表)

public class LinkedDSU {
    public static final int illegal_next=-1;
    private static class Node{
        int equiv;
        int next;
        int length;

        Node(int e,int n,int len){
            equiv = e;
            next = n;
            length = len;
        }
    }

    private final Node[] ns;

    public LinkedDSU(int n){
        ns = new Node[n];
        for (int i = 0; i < ns.length; i++) {
            ns[i] = new Node(i,illegal_next,1);
        }
    }

    public void union(int x,int y){
        // same equiv?
        int ex = find(x);
        int ey = find(y);
        if(ex == ey){
            return;
        }
        // merge small to large
        if(ns[ex].length>ns[ey].length){
            int et = ex;
            ex = ey;
            ey = et;
        }
        // change equiv
        int header = ex;
        while(ns[header].next!=illegal_next){
            ns[header].equiv = ey;
            header = ns[header].next;
        }
        ns[header].equiv = ey;
        // linked list insertion
        ns[header].next = ns[ey].next;
        ns[ey].next = ex;
        // update size
        ns[ey].length += ns[ex].length;
    }

    public int find(int x){
        return ns[x].equiv;
    }
}

递归写法

// LC765 official writeup
public class RecursiveDSU {
    private final int[] f;

    public RecursiveDSU(int n){
        f = new int[n];
        for (int i = 0; i < f.length; i++) {
            f[i] = i;
        }
    }

    public int getf(int x){
        if(f[x]==x){
            return x;
        }
        int newf = getf(f[x]);
        f[x] = newf; // path compression
        return newf;
    }

    public void add(int x,int y){
        int fx = getf(x);
        int fy = getf(y);
        f[fx] = fy;
    }
}

核心想法:如果这个元素的源头不是他自己 说明它被归到别的等价类去了,深搜它的源头(链)

多说无益,看题。

1. LC 2812 找出最安全路径

这题我一开始二分T了。预计算写的暴力,判断写的深搜。这题预计算所有点的安全距离的方式应该是多源BFS。判断连通性的方式是DSU。所以不放在二分题单里,放在DSU。

多源BFS的大致思路:

  1. 选中初始源列表
  2. 对于当前轮次的源列表,遍历,访问每个源的所有可行的邻居(在表里并且没被访问过),放到下一轮访问的源列表中
  3. 这样就类似于一圈一圈的扩散出去

并查集判连通:由于我们想要的是最大安全系数,所以倒着搜各个安全距离对应的点集。如果发现它的邻居的安全距离大于等于它的,那么可以把它的邻居对应的等级类直接归到这个点对应的等级类。由于我们倒着搜答案,因此这个等价类的门槛会越来越低,直至把所有点都囊括进去,那个时候安全系数也就只能为0了。

import java.util.ArrayList;
import java.util.List;

class Solution {
    static int[][] dirs = new int[][]{
            {-1,0},
            {1,0},
            {0,-1},
            {0,1}
    };
    public int maximumSafenessFactor(List<List<Integer>> grid) {
        // 两个关键问题
        // 怎么标记不能走的格子
        // 怎么判断能否从左上角走到右下角
        int n = grid.size();
        ArrayList<int[]> q = new ArrayList<>(); // 所有为1的网格多源bfs
        int[][] dis = new int[n][n];
        // 统计所有1
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if(grid.get(i).get(j)==1){
                    q.add(new int[]{i,j});
                }else{
                    dis[i][j] = -1; // 顺便初始化访问数组
                }
            }
        }

        ArrayList<List<int[]>> groups = new ArrayList<>();
        groups.add(q);
        List<int[]> tmp;

        /**
         * 多源bfs的滚动数组trick
         * 相当于一开始现有一个源的列表(多源)
         * 然后每次根据当前的源,访问所有源的所有邻接的邻居 (当前的源就是tmp列表)
         * 把每个邻居放在下一轮源的列表中
         * 由于用tmp接替了q的位置,所以q就变成了下一轮源的列表,这样while(!q.isEmpty())就会判断下一轮是否还有源可以用
         */
        while(!q.isEmpty()){
            tmp = q; // 滚动数组省空间
            q = new ArrayList<>();
            for (int[] p : tmp) {
                for (int[] d : dirs) {
                    int x = p[0]+d[0];
                    int y = p[1]+d[1];
                    // 下标合法并且未访问过(如果访问过,说明它被赋予了更小的值,就没必要更新了)
                    if( x>=0 && x<n && y>=0 && y<n && dis[x][y] == -1){
                        q.add(new int[]{x,y});
                        // 试想如果当前是最初的那一批源会怎么样
                        // groups.size() == 1
                        // 也就意味着当前访问到的邻居距离最初的源的距离为1(这是因为d[0] + d[1]的abs为1,最多能造成1的位移)
                        dis[x][y] = groups.size();
                    }
                }
            }
            groups.add(q); // 最终会多出来一个空的列表
        }

        // 求最大安全系数 所以倒着搜
        // 并查集
        int[] fa = new int[n*n];
        for (int i = 0; i < n * n; i++) {
            fa[i] = i;
        }

        // 由于多源bfs多加了一个数组,倒着搜的时候要-1
        for(int ans = groups.size()-2;ans>0;ans--){
            // 安全距离为ans的点集合
            List<int[]> g = groups.get(ans);
            for (int[] p : g) {
                int i = p[0];
                int j = p[1];

                for (int[] d : dirs) {
                    int x = i+d[0];
                    int y = j+d[1];

                    // 邻居的安全距离大于要判定的安全距离
                    if(x>=0 && x<n && y>=0 && y<n && dis[x][y] >= ans){
                        // 点(x,y)的类 归入到 (i,j)的等价类
                        fa[find(fa,x*n+y)] = find(fa,i*n+j);
                    }
                }
            }
            // 等级类归类完毕查看起点终点是否连通
            if(find(fa,0)==find(fa,n*n-1)){
                return ans;
            }
        }

        return 0;
    }

    // recursive dsu template
    private int find(int[] fa,int x){
        // 如果这个元素的源头不是他自己 说明它被归到别的等价类去了,深搜它的源头(链)
        if(fa[x]!=x) fa[x] = find(fa,fa[x]);
        return fa[x];
    }
}

2. LC 778 水位上升的泳池中游泳

这题本来在二分题单里的,感觉能用DSU,就写了下。

思路:

  1. 将每个单元格划分为一个等级类
  2. 由于我们要找的是最短的时间,所以对时间正着搜索
  3. 搜索的思路类似于BFS,把当前的每个可以到达的位置存放入队列,随后BFS,BFS的过程中利用DSU连通节点
  4. 需要注意两点。一是如果一个节点的四周的节点并没有被访问完,那么下一轮扩散仍需要使用这个节点。另一个是扩散要持续到不能扩散为止。假设(0,0)扩散到了(1,1),那么(1,1)也需要在本轮完成扩散。
  5. 在实现4时,可以利用两个队列来回倒。在我的代码里,tmp代表了下一轮的节点。q代表当前轮的节点。所以可以看到如果这一轮扩散的某个节点的邻居没有被访问完,就把它放到tmp,下一轮接着扩散。如果一个节点的邻居被扩散到,将置入q,于本轮继续扩散。
  6. 每轮扩散,检查(0,0)和(n-1,n-1)的连通性,若连通返回答案。
import java.util.ArrayDeque;

class Solution {
    static int[][] dirs = new int[][]{
            {-1,0},
            {0,1},
            {1,0},
            {0,-1}
    };
    public int swimInWater(int[][] grid) {
        int n = grid.length;
        int[] fa = new int[n * n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                fa[i*n+j] = i*n+j;
            }
        }

        ArrayDeque<int[]> q = new ArrayDeque<>();
        ArrayDeque<int[]> tmp;
        q.push(new int[]{0,0});
        boolean[] visited = new boolean[n*n];
        visited[0] = true;

        for(int i=0;i<n*n;i++){
            tmp = new ArrayDeque<>();
            while(!q.isEmpty()){
                int[] poll = q.poll();
                int x = poll[0];
                int y = poll[1];

                if(grid[x][y]>i){
                    tmp.offer(poll);
                    break;
                }
                boolean flag = true;
                for (int[] d : dirs) {
                    int nx = x + d[0];
                    int ny = y + d[1];
                    if(legal(nx,ny,n) && !visited[nx*n+ny]){
                        if(grid[nx][ny]<=i){
                            visited[nx*n+ny] = true;
                            fa[nx*n+ny]= find(fa,x*n+y);
                            q.push(new int[]{nx,ny});
                        }
                        flag &= visited[nx*n+ny];
                    }
                }

                if(!flag){
                    tmp.offer(poll);
                }
            }
            q= tmp;
            if(find(fa,0) == find(fa,n*n-1)){
                return i;
            }
        }
        return n*n-1;
    }

    private boolean legal(int nx,int ny,int n){
        return nx>=0 && nx<n && ny>=0 && ny<n;
    }

    private int find(int[] fa,int x){
        if(fa[x]!=x){
            fa[x] = find(fa,fa[fa[x]]);
        }
        return fa[x];
    }
}

3. LC 100199 判断一个数组是否可以变为有序

122双周赛T2。这把状态极差,T2 WA 2发才交上。但是思路其实很简单。二进制数位为1的数目相同代表连通,可以直接对照有序数组查看有序时的索引和当前索引是否连通。注意DSU扩散时只能检查相邻的元素,比如索引2和3,索引3和4,这样索引2和4可以是连通的。但不能直接判断索引2、4是否连通,可以中转但不能直达。就因为这个我WA2发,罚时吃闷了。

import java.util.Arrays;
import java.util.HashMap;

class Solution {
    int[] fa;
    public boolean canSortArray(int[] nums) {
        fa = new int[nums.length];
        for (int i = 0; i < fa.length; i++) {
            fa[i] = i;
        }
				// 哈希表存储每个元素原来的索引位置,由于元素两两不同,所以不怕冲掉
        HashMap<Integer, Integer> m = new HashMap<Integer, Integer>();
        for (int i = 0; i < nums.length; i++) {
            m.put(nums[i],i);
        }
        for(int i = 0;i<nums.length-1;i++){
            if(isConnected(nums[i],nums[i+1])){
                fa[i+1] = find(i);
            }
        }
        Arrays.sort(nums);
        for (int i = 0; i < nums.length; i++) {
            Integer index = m.get(nums[i]);
            if(fa[index]!=fa[i]){
                return false;
            }
        }
        return true;
    }

    private boolean isConnected(int num1,int num2){
        String s1 = Integer.toBinaryString(num1);
        String s2 = Integer.toBinaryString(num2);
        return BCount(s1)==BCount(s2);
    }

    private int BCount(String s){
        char[] ch = s.toCharArray();
        int cnt = 0;
        for (char c : ch) {
            cnt += c-'0';
        }
        return cnt;
    }

    private int find(int x){
        if(fa[x]!=x){
            fa[x] = find(fa[fa[x]]);
        }
        return fa[x];
    }
}

4. LC 100244 带权图里旅途的最小代价

其实这题我也不确定应该放在位运算里还是DSU里,后来想想只是用了位运算一个很基础的性质,所以还是放DSU了。

&的性质是:越&越小。所以把一个连通块中的所有边权与起来一定是整个连通块任意两个节点之间的最小代价。既然都提到连通块了,那么就是并查集。开一个fa数组维护连通块,再开一个cnt数组维护代价就行了。

这里测试用例给拉了,他没说某个节点到自己的代价怎么算,其实默认是0。

import java.util.Arrays;

class Solution {
    int[] fa;
    public int[] minimumCost(int n, int[][] edges, int[][] query) {
        fa = new int[n];
        for (int i = 0; i < n; i++) {
            fa[i] = i;
        }

        int[] cnt = new int[n];
        Arrays.fill(cnt,-1);

        for (int[] edge : edges) {
            int u = edge[0];
            int v = edge[1];
            int w = edge[2];

            int u_fa = find(u);
            int v_fa = find(v);
            fa[v_fa] = u_fa;

            cnt[u_fa] &= cnt[v_fa];
            cnt[u_fa] &= w;
        }

        int[] res = new int[query.length];
        int i = 0;
        for (int[] q : query) {
            int start = q[0];
            int end = q[1];
            if(start==end){
                res[i++] = 0;
                continue;
            }

            int s_fa = find(start);
            int e_fa = find(end);
            if(s_fa!=e_fa){
                res[i] = -1;
            }else{
                res[i] = cnt[s_fa];
            }

            i++;
        }

        return res;
    }

    private int find(int x){
        if(fa[x]!=x){
            fa[x] = find(fa[x]);
        }
        return fa[x];
    }
}

  • 12
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值