给你这棵「无向树」,请你测算并返回它的「直径」:这棵树上最长简单路径的 边数。
我们用一个由所有「边」组成的数组 edges 来表示一棵无向树,其中 edges[i] = [u, v] 表示节点 u 和 v 之间的双向边。
树上的节点都已经用 {0, 1, ..., edges.length} 中的数做了标记,每个节点上的标记都是独一无二的。
示例 1:
输入:edges = [[0,1],[0,2]]
输出:2
解释:
这棵树上最长的路径是 1 - 0 - 2,边数为 2。
示例 2:
输入:edges = [[0,1],[1,2],[2,3],[1,4],[4,5]]
输出:4
解释:
这棵树上最长的路径是 3 - 2 - 1 - 4 - 5,边数为 4。
提示:
0 <= edges.length < 10^4
edges[i][0] != edges[i][1]
0 <= edges[i][j] <= edges.length
edges 会形成一棵无向树
解题思路:
两次DFS
先任选一个起点,BFS/DFS找到最长路的终点,再从终点进行第二次BFS/DFS找到最长路即为树的直径。
原理:设起点为u,第一次BFS/DFS找到的终点v一定是树的直径的一个端点。
证明:1)如果u是直径上的点,则v就是直径的终点(因为如果v不是的话,则必定存在另一个点w使得u到w的距离更长,则与BFS/DFS找到了v矛盾。
2) 如果u不是直径上的点,则u到v必然与树的直径相交(反证),那么交点到v必然就是直径的后半段了,所以v一定是直径的一个端点,所以从v进行BFS/DFS得到的一定是直径长度。
题解 | #树的直径#c++/python3/java(1)贪心--2次dfs(自上而下)_牛客博客
Python代码(LeetCode):
class Solution:
def treeDiameter(self, edges: List[List[int]]) -> int:
def dfs(i, dist):
visited[i] = True
res_i, res_dist = i, dist
for j in table[i]:
if not visited[j]:
temp_i, temp_dist = dfs(j, dist + 1)
if temp_dist > res_dist:
res_dist = temp_dist
res_i = temp_i
visited[i] = False
return res_i, res_dist
# 1 init neighbor table
if not edges:
return 0
table = defaultdict(list)
for i in range(len(edges)):
x = edges[i][0]
y = edges[i][1]
table[x].append(y)
table[y].append(x)
# 2 search
visited = [False] * len(table)
node1, _ = dfs(0, 0)
node2, res = dfs(node1, 0)
return res
Python代码(NC):
# class Interval:
# def __init__(self, a=0, b=0):
# self.start = a
# self.end = b
#
# 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可
#
# 树的直径
# @param n int整型 树的节点个数
# @param Tree_edge Interval类一维数组 树的边
# @param Edge_value int整型一维数组 边的权值
# @return int整型
#
from collections import defaultdict
class Solution:
def solve(self , n: int, Tree_edge: List[Interval], Edge_value: List[int]) -> int:
# write code here
def dfs(i, dist):
visited[i] = True
res_i, res_dist = i, dist
for j, val in table[i]:
if not visited[j]:
temp_i, temp_dist = dfs(j, dist + val)
if temp_dist > res_dist:
res_dist = temp_dist
res_i = temp_i
visited[i] = False
return res_i, res_dist
# 1 init neighbor table
table = defaultdict(list)
for i in range(n - 1):
x = Tree_edge[i].start
y = Tree_edge[i].end
val = Edge_value[i]
table[x].append((y, val))
table[y].append((x, val))
# 2 search
visited = [False] * len(table)
node1, _ = dfs(0, 0)
node2, res = dfs(node1, 0)
return res