For an undirected graph with tree characteristics, we can choose any node as the root. The resulting graph is then a rooted tree. Among all possible rooted trees, those with minimum height are called minimum height trees (MHTs). Given such a graph, write a function to find all the MHTs and return a list of their root labels.
Format
The graph contains n nodes which are labeled from 0 to n - 1. You will be given the number n and a list of undirected edges (each edge is a pair of labels).
You can assume that no duplicate edges will appear in edges. Since all edges are undirected, [0, 1] is the same as [1, 0] and thus will not appear together in edges.
Example 1:
Given n = 4, edges = [[1, 0], [1, 2], [1, 3]]
        0
        |
        1
       / \
      2   3
return [1]
Example 2:
Given n = 6, edges = [[0, 3], [1, 3], [2, 3], [4, 3], [5, 4]]
     0  1  2
      \ | /
        3
        |
        4
        |
        5
return [3, 4]
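Both examples can be checked directly against the definition of height in note (2): root the tree at every node in turn and measure the longest downward path. The sketch below does this with one BFS per root; `heights_by_root` is a hypothetical helper name, and its O(n^2) cost makes it suitable only for small inputs like these.

```python
from collections import deque


def heights_by_root(n, edges):
    """Return a list whose entry r is the height of the tree rooted at r."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    heights = []
    for root in range(n):
        # BFS from root; the height is the largest distance reached.
        dist = [-1] * n
        dist[root] = 0
        queue = deque([root])
        while queue:
            cur = queue.popleft()
            for nxt in adj[cur]:
                if dist[nxt] == -1:
                    dist[nxt] = dist[cur] + 1
                    queue.append(nxt)
        heights.append(max(dist))
    return heights


print(heights_by_root(4, [[1, 0], [1, 2], [1, 3]]))  # [2, 1, 2, 2]
print(heights_by_root(6, [[0, 3], [1, 3], [2, 3], [4, 3], [5, 4]]))  # [3, 3, 3, 2, 2, 3]
```

The minimum entries sit at index 1 for Example 1 and at indices 3 and 4 for Example 2, matching the expected answers.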
Hint:
- How many MHTs can a graph have at most?
Note:
(1) According to the definition of tree on Wikipedia: “a tree is an undirected graph in which any two vertices are connected by exactly one path. In other words, any connected graph without simple cycles is a tree.”
(2) The height of a rooted tree is the number of edges on the longest downward path between the root and a leaf.
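Note (2) leads to a short answer to the hint. The relations below are the standard tree-center argument, stated without proof; d(u, v) denotes the number of edges on the unique path between u and v, and D is the length of the longest leaf-to-leaf path (the diameter).

```latex
h(r) = \max_{v} d(r, v), \qquad D = \max_{u, v} d(u, v)

\min_{r} h(r) = \left\lceil \frac{D}{2} \right\rceil
```

The minimum is attained only at the middle vertex of a longest path when D is even, and at its two middle vertices when D is odd, so a tree has at most two MHTs. In Example 2 the longest path 0, 3, 4, 5 has D = 3, giving minimum height 2 at roots 3 and 4.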
Credits:
Special thanks to @dietpepsi for adding this problem and creating all test cases.
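There are two approaches. The first computes the degree of every node, puts the nodes with degree 1 into a list or queue, removes those nodes from their neighbours' adjacency lists, and then collects the next set of nodes whose degree has dropped to 1. The 1 or 2 nodes left at the end are the roots of the MHTs.
The second uses the same idea as finding, among many points, the one whose distance to the others is smallest: find the longest leaf-to-leaf path, then take the one or two middle nodes (the median) of that path.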
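The second approach can be sketched with two breadth-first searches: a BFS from any node ends at one endpoint of a longest path (a standard property of trees), and a BFS from that endpoint reaches the other end while recording parents, so the path can be traced back and its middle node(s) taken. This is not the author's code; the function names are hypothetical, and the input is assumed to be a valid tree as the problem guarantees.

```python
from collections import deque


def bfs_farthest(start, adj):
    """BFS from start; return (last node dequeued, parent list).

    BFS dequeues nodes in non-decreasing distance order, so the last
    node dequeued is one of the farthest from start.
    """
    parent = [-1] * len(adj)
    seen = [False] * len(adj)
    seen[start] = True
    queue = deque([start])
    last = start
    while queue:
        last = queue.popleft()
        for nxt in adj[last]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = last
                queue.append(nxt)
    return last, parent


def find_min_height_trees_by_diameter(n, edges):
    """Hypothetical helper: MHT roots via the middle of a longest path."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    # First BFS: the farthest node from node 0 is one end of a longest path.
    a, _ = bfs_farthest(0, adj)
    # Second BFS from that end reaches the other end; parents trace the path.
    b, parent = bfs_farthest(a, adj)
    path = []
    node = b
    while node != -1:
        path.append(node)
        node = parent[node]
    # Odd node count: one middle node; even node count: two middle nodes.
    k = len(path)
    if k % 2 == 1:
        return [path[k // 2]]
    return [path[k // 2 - 1], path[k // 2]]


print(find_min_height_trees_by_diameter(4, [[1, 0], [1, 2], [1, 3]]))  # [1]
print(find_min_height_trees_by_diameter(6, [[0, 3], [1, 3], [2, 3], [4, 3], [5, 4]]))  # [3, 4]
```

Both BFS passes are O(n), the same cost as leaf trimming; the trade-off is that this version has to rebuild the path explicitly instead of letting the trimming loop shrink the graph to its center.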
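from collections import defaultdict


class Solution(object):
    def findMinHeightTrees(self, n, edges):
        # A single node is its own MHT (height 0).
        if n == 1:
            return [0]
        # Build adjacency sets so a neighbour can be removed in O(1).
        adj = defaultdict(set)
        for i, j in edges:
            adj[i].add(j)
            adj[j].add(i)
        # The first layer of leaves: nodes with degree 1.
        leaves = [i for i in range(n) if len(adj[i]) == 1]
        remaining = n
        # Trim one layer of leaves per round until 1 or 2 nodes are left.
        while remaining > 2:
            remaining -= len(leaves)
            new_leaves = []
            for leaf in leaves:
                # While more than two nodes remain, no two leaves are adjacent,
                # so each leaf still has exactly one neighbour.
                neighbour = adj[leaf].pop()
                adj[neighbour].remove(leaf)
                # A neighbour whose degree just dropped to 1 is a new leaf.
                # Degrees only decrease, so this fires at most once per node
                # and no duplicate check is needed.
                if len(adj[neighbour]) == 1:
                    new_leaves.append(neighbour)
            leaves = new_leaves
        return leaves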
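# Brute-force alternative, kept for reference: try every node as the root,
# measure the tree height with a DFS, and keep the roots of minimum height.
# It runs in O(n^2) time and recurses once per tree level, so it is only
# practical for small inputs.
def find_min_height_trees_brute_force(n, edges):
    adj = [[] for _ in range(n)]
    for i, j in edges:
        adj[i].append(j)
        adj[j].append(i)

    def height(cur, parent, depth, best):
        # Prune: a root that reaches a depth above the best height found so
        # far cannot be an MHT, so report an infinite height for it.
        if depth > best:
            return float('inf')
        h = depth
        for nxt in adj[cur]:
            if nxt != parent:
                h = max(h, height(nxt, cur, depth + 1, best))
        return h

    best = float('inf')
    heights = []
    for root in range(n):
        h = height(root, -1, 0, best)
        best = min(best, h)
        heights.append(h)
    return [i for i in range(n) if heights[i] == best]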