leetcode笔记-union find并查集

        parents = list(range(n))  # list(): a Python 3 range object does not support item assignment
        rank = [0] * n

        # Several interchangeable union-find templates follow. A later definition of
        # the same name replaces an earlier one, so copy the variant you want instead
        # of pasting all of them together.

        # find, variant 1: recursive with path compression.
        def find(edge):
            """Return the root of ``edge`` and point every node on its path at that root."""
            if parents[edge] != edge:
                parents[edge] = find(parents[edge])
            return parents[edge]

        # find, variant 2: no path compression. Simpler, but chains can grow long,
        # so deep recursion is possible on large inputs.
        def find(a):
            """Return the root of ``a`` without modifying ``parents``."""
            if parents[a] != a:
                return find(parents[a])
            return a

        # union, variant 1: union by rank (attach the lower-rank root under the other).
        def union(x, y):
            """Merge the sets containing ``x`` and ``y``, using ``rank`` to keep trees shallow."""
            x, y = find(x), find(y)
            if x != y:
                if rank[x] >= rank[y]:
                    parents[y] = x
                    rank[x] += rank[x] == rank[y]  # the bool adds 1 only when the ranks tie
                else:
                    parents[x] = y

        # union, variant 2: plain linking of roots (no rank).
        def union(a, b):
            """Merge the sets containing ``a`` and ``b`` by linking b's root under a's root."""
            pa, pb = find(a), find(b)
            parents[pb] = pa  # previously ``parent[pb]``: a NameError, the list is ``parents``

1970. Last Day Where You Can Still Cross

class Solution:
    def latestDayToCross(self, n: int, m: int, cells: List[List[int]]) -> int:
        """Return the last day on which the land can still be crossed top-to-bottom.

        ``n`` is the number of rows and ``m`` the number of columns. ``cells[i]``
        is the 1-based ``[row, col]`` that floods on day ``i + 1``.

        Land (4-connected) is cut off exactly when some 8-connected component of
        water reaches both the leftmost and the rightmost column. Because a
        connected component moves at most one column per step, touching both
        edges is the same as covering every column. Water cells are added one day
        at a time and unioned with their flooded 8-neighbours. Each root keeps
        "touches left" / "touches right" flags, so every merge is O(1) instead of
        copying a set of columns.
        """
        total = n * m
        parent = list(range(total))
        size = [1] * total
        touches_left = [False] * total
        touches_right = [False] * total
        flooded = [False] * total
        neighbours = ((0, 1), (0, -1), (1, 0), (-1, 0),
                      (-1, -1), (-1, 1), (1, -1), (1, 1))

        def find(a):
            """Return the root of ``a``. The find is iterative so long chains cannot
            hit the recursion limit, and it compresses the path."""
            root = a
            while parent[root] != root:
                root = parent[root]
            while parent[a] != root:
                parent[a], a = root, parent[a]
            return root

        def union(a, b):
            """Merge the components of ``a`` and ``b`` by size, OR-ing the edge flags."""
            ra, rb = find(a), find(b)
            if ra == rb:
                return
            if size[ra] < size[rb]:
                ra, rb = rb, ra
            parent[rb] = ra
            size[ra] += size[rb]
            touches_left[ra] = touches_left[ra] or touches_left[rb]
            touches_right[ra] = touches_right[ra] or touches_right[rb]

        for day, (row, col) in enumerate(cells):
            r, c = row - 1, col - 1  # convert to 0-based
            cell = r * m + c
            flooded[cell] = True
            touches_left[cell] = c == 0
            touches_right[cell] = c == m - 1
            for dr, dc in neighbours:
                nr, nc = r + dr, c + dc
                if 0 <= nr < n and 0 <= nc < m and flooded[nr * m + nc]:
                    union(cell, nr * m + nc)
            root = find(cell)
            if touches_left[root] and touches_right[root]:
                # Today (day index ``day``, i.e. day ``day + 1``) blocks the crossing,
                # so the previous day, number ``day``, was the last crossable one.
                return day
        return 0

1102. Path With Maximum Minimum Value

class Solution:
    def maximumMinimumPath(self, grid: List[List[int]]) -> int:
        """Return the best possible score of a 4-directional path from the top-left
        to the bottom-right cell, where a path's score is its minimum value.

        Union-find approach: take the cells from largest to smallest value,
        regardless of position, and union each with its already-taken neighbours.
        The first value at which the two corners fall into the same component is
        the answer. ``_maximum_minimum_path_heap`` holds the max-heap
        (Dijkstra-like) alternative.
        """
        m, n = len(grid), len(grid[0])
        parent = list(range(m * n))
        size = [1] * (m * n)
        visited = [False] * (m * n)
        cells = [(r, c) for r in range(m) for c in range(n)]
        cells.sort(key=lambda rc: grid[rc[0]][rc[1]], reverse=True)  # large to small

        def find(x):
            """Return the root of ``x``. The find is iterative, so no recursion-depth
            limit applies, and it compresses the path."""
            root = x
            while parent[root] != root:
                root = parent[root]
            while parent[x] != root:
                parent[x], x = root, parent[x]
            return root

        def union(a, b):
            """Merge the components of ``a`` and ``b``, attaching the smaller under the larger."""
            ra, rb = find(a), find(b)
            if ra == rb:
                return
            if size[ra] < size[rb]:
                ra, rb = rb, ra
            parent[rb] = ra
            size[ra] += size[rb]

        target = m * n - 1
        for x, y in cells:
            pos = x * n + y
            visited[pos] = True
            for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nx, ny = x + dx, y + dy
                if 0 <= nx < m and 0 <= ny < n and visited[nx * n + ny]:
                    union(pos, nx * n + ny)
            if find(0) == find(target):
                return grid[x][y]

    def _maximum_minimum_path_heap(self, grid: List[List[int]]) -> int:
        """Max-heap variant: always expand the largest reachable cell next (a
        Dijkstra variant). The running minimum when the target is popped is the answer.
        """
        import heapq

        m, n = len(grid), len(grid[0])
        best = grid[0][0]
        heap = [(-grid[0][0], 0, 0)]  # (negated value, row, col), so heapq acts as a max-heap
        seen = {(0, 0)}
        while heap:
            neg_val, x, y = heapq.heappop(heap)
            best = min(best, -neg_val)
            if x == m - 1 and y == n - 1:
                return best
            for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nx, ny = x + dx, y + dy
                if 0 <= nx < m and 0 <= ny < n and (nx, ny) not in seen:
                    heapq.heappush(heap, (-grid[nx][ny], nx, ny))
                    seen.add((nx, ny))
        return best

=======================================

765. Couples Holding Hands

  • 2 <= n <= 30
class Solution:
    def minSwapsCouples(self, row: List[int]) -> int:
        '''
        Return the minimum number of swaps needed so every couple sits side by side.

        Person ``p`` belongs to couple ``p // 2``. Seats ``2i`` and ``2i + 1`` form
        sofa ``i``.

        Treat each couple as a node. For every sofa whose two occupants belong to
        different couples, union those two couples. A component containing s
        couples is a cycle that needs exactly s - 1 swaps, because each swap seats
        one more couple correctly. With n couples in k components the answer is
        (n1 - 1) + ... + (nk - 1) = n - k.

        Example: row = [0, 2, 1, 3]. Both sofas hold one member each of couples 0
        and 1, so those couples form a single component: 2 - 1 = 1 swap.
        '''
        n = len(row) // 2
        parent = list(range(n))  # union-find over couple ids 0 .. n-1

        def find(a):
            """Return the root of couple ``a`` (iterative, with path halving)."""
            while parent[a] != a:
                parent[a] = parent[parent[a]]
                a = parent[a]
            return a

        for seat in range(0, 2 * n, 2):
            root_a = find(row[seat] // 2)
            root_b = find(row[seat + 1] // 2)
            if root_a != root_b:
                parent[root_b] = root_a

        components = len({find(c) for c in range(n)})
        return n - components

1101. The Earliest Moment When Everyone Become Friends

class Solution:
    def earliestAcq(self, logs: List[List[int]], n: int) -> int:
        """Return the earliest timestamp at which all ``n`` people are connected.

        ``logs`` holds ``[timestamp, a, b]`` entries meaning a and b became
        friends at that time. Returns -1 if everyone is never connected.
        The group counter is a local variable rather than an attribute on
        ``self``, so calls do not leave state behind on the instance.
        """
        parent = list(range(n))
        groups = n

        def find(x):
            """Return the root of ``x`` (iterative, with path compression)."""
            root = x
            while parent[root] != root:
                root = parent[root]
            while parent[x] != root:
                parent[x], x = root, parent[x]
            return root

        for t, a, b in sorted(logs):  # process events in chronological order
            ra, rb = find(a), find(b)
            if ra != rb:
                parent[rb] = ra
                groups -= 1
            if groups == 1:
                return t

        return -1

帮忙做的华为笔试的一道题,具体题目不知道,只有一张图片,推测是找环:

随手写了一下代码: 也可以用拓扑排序做,dfs的时候把path加上,如果找到环就把path加到res里面,思路和union-find类似。

class Solution(object):
    def find_circle(self, graph=None):
        """Return every simple directed cycle of ``graph`` as a sorted list of nodes.

        ``graph`` maps a node to the list of nodes it points to. It defaults to
        the example graph from the original exercise.

        The earlier union-find version had two problems. It set
        ``parents[child_node]`` even after detecting a cycle, which could make the
        parent array itself cyclic, so ``find`` recursed forever depending on dict
        iteration order. It also relied on ``range(5)`` being a mutable list,
        which is Python 2 only. This version uses DFS with backtracking, as the
        note above suggests: for each start node ``s``, in sorted order, it only
        walks through nodes greater than ``s``. Each cycle is therefore reported
        exactly once, from its smallest node. Results are also stored in
        ``self.res``.
        """
        if graph is None:
            graph = {1: [2, 3], 4: [1], 3: [4], 2: [4]}
        self.res = []

        def dfs(start, node, path, on_path):
            """Extend ``path`` from ``node``; record a cycle whenever an edge returns to ``start``."""
            for nxt in graph.get(node, []):
                if nxt == start:
                    self.res.append(sorted(path))
                elif nxt > start and nxt not in on_path:
                    path.append(nxt)
                    on_path.add(nxt)
                    dfs(start, nxt, path, on_path)
                    path.pop()
                    on_path.discard(nxt)

        for start in sorted(graph):
            dfs(start, start, [start], {start})

        return self.res


if __name__ == '__main__':
    s = Solution()
    print(s.find_circle())  # call syntax works on Python 2 and 3; the bare print statement is Py2-only
"""
Google 第一轮:一条直路 分布在坐标y1, y2 之间, 这条路上分布了很多灯(x,y),
灯光照亮范围是半径为r的圆形,照亮的地方无法通过, 问从左往右走,是否能通过这条路
"""
import math
from collections import defaultdict
 
 
class Circle:
    """A lamp at ``(x, y)`` whose light covers a disc of radius ``r``."""

    def __init__(self, x, y, r):
        self.x = x
        self.y = y
        self.r = r

    def intersect(self, other: 'Circle'):
        """Return True if this disc and ``other`` overlap or touch.

        Squared distances are compared with plain arithmetic. For integer
        inputs this is exact, whereas ``math.pow`` converts to float and can
        misjudge large coordinates.
        """
        dx = self.x - other.x
        dy = self.y - other.y
        reach = self.r + other.r
        return dx * dx + dy * dy <= reach * reach

    def __repr__(self):
        return f'<{(self.x, self.y)}>'

    def __lt__(self, other: 'Circle'):
        # id() breaks ties, so distinct circles with identical geometry still order strictly.
        return (self.x, self.y, self.r, id(self)) < (other.x, other.y, other.r, id(other))


class Solution:
    def can_pass(self, y1, y2, circles):
        """Return True if one can walk left to right along the road y1 <= y <= y2.

        Overlapping lamps are merged into groups. The road is blocked exactly
        when some connected group's lit area reaches from at or below ``y1`` to
        at or above ``y2``.

        Raises:
            ValueError: if ``y1 > y2``.
        """
        if y1 > y2:
            raise ValueError("y1 must not exceed y2")
        fathers = self._union_find(circles)
        groups = defaultdict(list)
        for c in fathers:
            # Group by the real root. ``fathers[c]`` can be an intermediate node
            # after merges, which would split one connected wall into pieces.
            groups[self._find(fathers, c)].append(c)
        for members in groups.values():
            group_upper_y, group_lower_y = self.get_upper_lower(members)
            if group_upper_y >= y2 and group_lower_y <= y1:
                return False
        return True

    def get_upper_lower(self, circles):
        """Return ``(highest lit y, lowest lit y)`` over ``circles``."""
        return (
            max(c.y + c.r for c in circles),
            min(c.y - c.r for c in circles),
        )

    def _union_find(self, circles):
        """Union every pair of intersecting circles; return the parent map (O(n^2) pairs)."""
        fathers = {c: c for c in circles}
        for i, c1 in enumerate(circles):
            for c2 in circles[i + 1:]:
                if c1.intersect(c2):
                    self._merge(fathers, c1, c2)
        return fathers

    def _merge(self, fathers, c1, c2):
        """Link the roots of ``c1`` and ``c2``, always keeping the "smaller" root."""
        root1 = self._find(fathers, c1)
        root2 = self._find(fathers, c2)
        if root1 is root2:
            return
        if root1 < root2:
            fathers[root2] = root1
        else:
            fathers[root1] = root2

    def _find(self, fathers, c):
        """Return the root of ``c`` and point every node on its path directly at it."""
        root = c
        while fathers[root] is not root:
            root = fathers[root]
        while fathers[c] is not c:
            c, fathers[c] = fathers[c], root
        return root
 
 
import unittest


class Test(unittest.TestCase):
    """Regression tests for the lamp-road ``Solution.can_pass``."""

    def test1(self):
        # Unit lamps stacked along x = 1 light the whole road between y = 1 and y = 2.
        circles = [
            Circle(x, y, r) for x, y, r in [
                (1, 1, 1),
                (1, 2, 1),
                (1, 5, 1),
                (1, 6, 1),
                (1, 3, 1),
            ]
        ]
        sol = Solution()
        self.assertEqual(
            sol.can_pass(1, 2, circles),
            False
        )

    def test2(self):
        # Two separate lamp clusters, neither spanning y in [-1, 3]: the road stays open.
        circles = [
            Circle(x, y, r) for x, y, r in [
                (0, 1, 1),
                (1, 2, 1),
                (3, 0, 1),
                (3, 1, 1),
            ]
        ]
        sol = Solution()
        self.assertEqual(
            sol.can_pass(-1, 3, circles),
            True
        )


# Guarded so that importing this module neither runs the tests nor calls sys.exit().
if __name__ == '__main__':
    unittest.main(verbosity=2)

945. Minimum Increment to Make Array Unique

class Solution:
    def minIncrementForUnique(self, nums: List[int]) -> int:
        """Return the minimum number of +1 moves that make every value in ``nums`` unique.

        Union-find over values, roughly O(n). ``root`` maps each occupied value to
        the last slot handed out from its run of occupied values, so the search
        for a free slot resumes at ``root[v] + 1``. ``find`` claims the first free
        slot at or above ``x`` and points every value it passed through at that
        slot, so later lookups skip the whole run.
        ``_min_increment_by_sorting`` is the O(n log n) sorting alternative.
        """
        root = {}

        def find(x):
            """Claim and return the first free slot reachable from ``x``.

            The find is iterative: a long run of occupied values (e.g. 1..N
            followed by another 1) would otherwise recurse N levels deep and
            exceed Python's recursion limit.
            """
            path = []
            while x in root:
                path.append(x)
                x = root[x] + 1
            root[x] = x  # x was free; it is now occupied
            for v in path:
                root[v] = x
            return x

        return sum(find(v) - v for v in nums)

    @staticmethod
    def _min_increment_by_sorting(nums: List[int]) -> int:
        """Sorting alternative: after sorting, each value must be at least prev + 1."""
        moves = 0
        prev = None
        for v in sorted(nums):
            if prev is not None and v <= prev:
                moves += prev + 1 - v
                v = prev + 1
            prev = v
        return moves

Traveling is Fun

Loading...

def connectedCities(n, g, originCities, destinationCities):
    """Answer connectivity queries between cities 1..n.

    Two cities are directly linked when they share a divisor strictly greater
    than ``g``. For every candidate divisor d > g, all of its multiples are
    unioned with d. The i-th answer is 1 if ``originCities[i]`` and
    ``destinationCities[i]`` end up in the same set, else 0.
    """
    size = n + 1  # slot 0 is unused; cities are 1-based
    root_of = list(range(size))
    depth = [0] * size

    def find(node):
        """Return the representative of ``node`` (iterative, with path compression)."""
        top = node
        while root_of[top] != top:
            top = root_of[top]
        while root_of[node] != top:
            root_of[node], node = top, root_of[node]
        return top

    def link(u, v):
        """Union by rank: the deeper tree's root becomes the new root."""
        ru, rv = find(u), find(v)
        if ru == rv:
            return
        if depth[ru] < depth[rv]:
            root_of[ru] = rv
        else:
            root_of[rv] = ru
            if depth[ru] == depth[rv]:
                depth[ru] += 1

    for divisor in range(g + 1, size):
        for multiple in range(2 * divisor, size, divisor):
            link(divisor, multiple)

    answers = []
    for src, dst in zip(originCities, destinationCities):
        answers.append(1 if find(src) == find(dst) else 0)
    return answers

947. Most Stones Removed with Same Row or Column

Loading...

Time complexity: $O(N \log N)$, where $N$ is the length of `stones`. (With union-by-rank as well, this becomes $O(N \cdot \alpha(N))$, where $\alpha$ is the inverse Ackermann function.)

class DSU:
    """Disjoint-set union over the integers ``0 .. N-1`` (path compression only)."""

    def __init__(self, N):
        # list(): on Python 3 a range object cannot be assigned to.
        self.p = list(range(N))

    def find(self, x):
        """Return the representative of ``x`` and compress its path.

        The find is iterative, so the long parent chains that linking without
        rank can build cannot exceed Python's recursion limit.
        """
        root = x
        while self.p[root] != root:
            root = self.p[root]
        while self.p[x] != root:
            self.p[x], x = root, self.p[x]
        return root

    def union(self, x, y):
        """Merge the sets of ``x`` and ``y`` by pointing x's root at y's root."""
        xr = self.find(x)
        yr = self.find(y)
        self.p[xr] = yr

        
class Solution(object):
    def removeStones(self, stones):
        """Return the maximum number of stones that can be removed.

        A stone may be removed while another remaining stone shares its row or
        column. Within one connected component of stones (linked through shared
        rows or columns), all but one stone can be removed, so the answer is
        ``len(stones) - number_of_components``.

        Each stone unions its row node with its column node. Rows and columns are
        kept apart by tagging the keys (``('r', x)`` vs ``('c', y)``) rather than
        by a fixed ``y + 10000`` offset. With coordinates up to 10**4 inclusive,
        that offset made row 10000 collide with column 0 and pushed column 10000
        outside a 20000-slot array.
        """
        parent = {}

        def find(node):
            """Return the root of ``node``, creating it on first use (iterative compression)."""
            parent.setdefault(node, node)
            root = node
            while parent[root] != root:
                root = parent[root]
            while parent[node] != root:
                parent[node], node = root, parent[node]
            return root

        for x, y in stones:
            row_root = find(('r', x))
            col_root = find(('c', y))
            parent[row_root] = col_root

        components = {find(('r', x)) for x, _ in stones}
        return len(stones) - len(components)

305. Number of Islands II

不用Union by rank 也能通过 

class Solution(object):
    def numIslands2(self, m, n, positions):
        """
        Return the number of islands after each land addition in ``positions``.

        :type m: int
        :type n: int
        :type positions: List[List[int]]
        :rtype: List[int]
        """
        parent, rank = {}, {}

        def find(x):
            if parent[x] != x:
                parent[x] = find(parent[x])  # path compression
            return parent[x]

        def union(x, y):
            """Merge the islands of ``x`` and ``y``; return 1 if two islands became one, else 0."""
            p1, p2 = find(x), find(y)
            if p1 == p2:
                return 0
            if rank[p1] < rank[p2]:  # union by rank
                parent[p1] = p2
            else:
                parent[p2] = p1
                if rank[p1] == rank[p2]:
                    rank[p1] += 1
            return 1

        res, count = [], 0
        for i, j in positions:
            x = (i, j)
            if x in parent:
                # Repeated position: it is already land. Re-initialising it would
                # detach it from its island and count it a second time.
                res.append(count)
                continue
            parent[x] = x
            rank[x] = 0
            count += 1
            for y in ((i + 1, j), (i - 1, j), (i, j + 1), (i, j - 1)):
                if y in parent:
                    count -= union(x, y)
            res.append(count)

        return res

323. Number of Connected Components in an Undirected Graph

class Solution(object):
    def countComponents(self, n, edges):
        """Return the number of connected components in an undirected graph on nodes 0..n-1."""
        parents = list(range(n))  # list(): range objects are immutable on Python 3

        def find(a):
            """Return the root of ``a`` (iterative; does not modify ``parents``)."""
            while parents[a] != a:
                a = parents[a]
            return a

        def union(b, p1):
            """Point every node on b's path, including its old root, at ``p1``."""
            while True:
                p = parents[b]
                parents[b] = p1
                if p == b:
                    break
                b = p

        for a, b in edges:
            p1, p2 = find(a), find(b)
            if p1 != p2:
                union(b, p1)
        return len({find(x) for x in range(n)})  # range: xrange does not exist on Python 3
                

721. Accounts Merge

class Solution(object):
    def init(self, accounts):
        """Build the union-find forest over e-mail addresses.

        The first e-mail of each account (``acc[1]``) represents that account.
        ``self.owner`` maps such a representative to the account's name, and
        ``self.parents`` maps every e-mail to its parent.
        """
        self.parents = {}
        self.owner = {}

        for acc in accounts:
            owner = acc[0]
            root = acc[1]  # the first e-mail acts as root; a root points to itself
            self.owner[root] = owner
            for email in acc[1:]:
                if email not in self.parents:
                    self.parents[email] = root
                else:
                    self.union(root, self.parents[email])

    def find(self, x):
        """Return the root of ``x`` (iterative, so long chains cannot hit the recursion limit)."""
        root = x
        while self.parents[root] != root:
            root = self.parents[root]
        while self.parents[x] != root:
            self.parents[x], x = root, self.parents[x]
        return root

    def union(self, x, y):
        """Merge the groups of ``x`` and ``y`` by attaching y's root under x's root."""
        px = self.find(x)
        py = self.find(y)
        self.parents[py] = px

    def accountsMerge(self, accounts):
        """
        Merge accounts that share an e-mail address.

        Each result is ``[name, *sorted_emails]``. Grouping is a single pass over
        all e-mails, replacing the earlier scan of every e-mail once per root.

        :type accounts: List[List[str]]
        :rtype: List[List[str]]
        """
        self.init(accounts)
        groups = {}
        for email in self.parents:
            groups.setdefault(self.find(email), []).append(email)

        res = []
        for root, emails in groups.items():
            res.append([self.owner[root]] + sorted(emails))
        return res

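下面这个版本最初提交时 TLE 了。可能的原因有两个。一是 `union` 直接用 `parent[e1]`、`parent[e2]`,而不是用 `find()` 取根节点,会把非根节点挂到别处,导致已有的合并丢失,甚至形成环。二是最后按根分组时,对每个根都要遍历一遍全部邮箱,复杂度是 O(根数 × 邮箱数)。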

class Solution(object):
    def accountsMerge(self, accounts):
        """Merge accounts that share an e-mail address; each result is ``[name, *sorted_emails]``.

        Fixes over the earlier version:
        - ``union`` now links the two *roots* (via ``find``) instead of the
          immediate parents ``parent[e1]`` / ``parent[e2]``. Re-pointing a
          non-root node could drop earlier merges or even create a cycle.
        - Grouping is a single pass instead of scanning every e-mail once per root.
        """
        parent = {}
        owner = {}  # representative e-mail -> account name

        def find(email):
            """Return the root of ``email`` (iterative, with path compression)."""
            root = email
            while parent[root] != root:
                root = parent[root]
            while parent[email] != root:
                parent[email], email = root, parent[email]
            return root

        def union(e1, e2):
            """Merge the groups of ``e1`` and ``e2`` by attaching e2's root under e1's root."""
            p1, p2 = find(e1), find(e2)
            if p1 != p2:
                parent[p2] = p1

        for acc in accounts:
            name = acc[0]
            root = acc[1]  # the first e-mail represents the account
            owner[root] = name
            for e in acc[1:]:
                if e not in parent:
                    parent[e] = root
                else:
                    union(root, e)

        groups = {}
        for e in parent:
            groups.setdefault(find(e), []).append(e)

        re = []
        for root, emails in groups.items():
            re.append([owner[root]] + sorted(emails))
        return re

261. Graph Valid Tree

class Solution(object):
    def validTree(self, n, edges):
        """
        Return True if ``edges`` form a single tree over nodes 0..n-1.

        The edges must contain no cycle, and every node must share one root.

        :type n: int
        :type edges: List[List[int]]
        :rtype: bool
        """
        parents = list(range(n))  # list(): range objects are immutable on Python 3

        def find(b):
            """Return the root parent of ``b`` (iterative)."""
            while parents[b] != b:
                b = parents[b]
            return b

        def union(b, p1):
            """Point ``b`` and every ancestor of ``b`` at ``p1``."""
            while True:
                p = parents[b]
                parents[b] = p1
                if p == b:
                    break
                b = p

        for a, b in edges:
            p1 = find(a)
            p2 = find(b)
            if p1 == p2:
                return False  # a and b are already connected: this edge closes a cycle
            union(b, p1)

        if n == 0:
            return True
        # More than one distinct root means the graph is a forest, not a single tree.
        root = find(0)
        return all(find(i) == root for i in range(1, n))
            
            
            

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值