# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None
class Solution:
    def maxProduct(self, root: "TreeNode") -> int:
        """Return the maximum product of the two subtree sums obtained by
        removing one edge from the binary tree, modulo 10**9 + 7.

        Removing the edge above a node whose subtree sums to ``s`` splits the
        tree into parts summing to ``s`` and ``total - s``. The answer is the
        largest such product over all nodes. (The root's own term is
        ``total * 0 == 0``, which is harmless.)

        The traversal is iterative, so very deep (e.g. chain-shaped) trees do
        not hit Python's recursion limit.

        Args:
            root: Root of the tree; nodes expose ``val``, ``left`` and
                ``right``. May be ``None``.

        Returns:
            The maximum product modulo 10**9 + 7, or 0 for an empty tree.
        """
        mod = 10 ** 9 + 7
        if root is None:
            return 0

        # Pre-order walk with an explicit stack. Every parent is appended to
        # `order` before its children, so walking `order` backwards sees
        # children first, which is what a post-order sum needs.
        order = []
        stack = [root]
        while stack:
            node = stack.pop()
            order.append(node)
            if node.left is not None:
                stack.append(node.left)
            if node.right is not None:
                stack.append(node.right)

        # Subtree sum per node, keyed by id(). `order` keeps every node alive,
        # so the ids stay unique for the lifetime of this dict.
        subtree_sum = {}
        for node in reversed(order):
            s = node.val
            if node.left is not None:
                s += subtree_sum[id(node.left)]
            if node.right is not None:
                s += subtree_sum[id(node.right)]
            subtree_sum[id(node)] = s

        total = subtree_sum[id(root)]
        # Compare exact (arbitrary-precision) products first and reduce modulo
        # only at the end. Taking the modulo earlier would corrupt the max.
        best = max(s * (total - s) for s in subtree_sum.values())
        return best % mod