[python刷题模板] 树的直径/换根DP

一、 算法&数据结构

1. 描述

树的直径代表树上最远的两个点的距离。
在某些特定的时候会用到。
  • 求树的直径通常可以
    1. 树形DP:仅用来计算直径的值,一次dfs即可。由于直径一定是某个子树的两个连接根的最长简单路径拼起来(其中一个可能是空),因此令dfs(u)计算u的子树树高(最长路径),计算v时,同时维护v的兄弟节点最大那个,每次都加一下尝试更新答案即可。
    2. 两次遍历(bfs/dfs均可):可以同时计算直径的值且找到两个端点。第一次遍历从任意一点root找最远节点u,第二次从u出发,找最远节点v,u和v就是直径的两端。
    3. 换根DP(大炮打蚊子):进用来计算直径的值,假设直径两端是u,v,那么以u为树根,v一定是最远节点,求最大树高即可。

换根DP是一种树上DP,可以用O(n)的复杂度计算出:分别以每个节点作为整个树的树根,树的某个属性值(如:树高)。
之所以把换根dp和树的直径写在一起,是因为,求树的直径时不论是树形dp还是换根DP,都可以结合着思考更容易吃透。
目前遇到两种换根dp:求max和求sum的形式。从代码量上来说sum更短一些。
  • 记录三个数组down1,down2,up,分别表示:
    • down1[u]: 以u为根的子树,向下最大的值(比如求树高,就是最长简单路径)
    • down2[u]: 以u为根的子树,向下次大的值(比如求树高,就是最长简单路径)
    • up[u]: 以u为根的子树,向上最大的值(比如求树高,就是最长简单路径)
  • 那么当整颗树以u为根时,最大值就是max(down1[u],up[u])。可以想象一下揪着u节点往上提,所有路径其实就是down和up。
  • 第一遍dfs,后根遍历求出down1和down2。
  • 第二遍dfs,先根遍历求出up,注意讨论:如果当前节点v在u的最大路径(down1)下,up[v]应该从应该从u的次大(down2[u])来,或up[u]来,见途中绿色轨迹;否则从最大来。
  • 在这里插入图片描述

  • sum版换根dp的特点是:一对邻居分别作根时,状态的差别只跟这条边有关,因此可以按次序求出每个节点为根时的状态。
    • 想象以0为根时状态已求出,0-1有边,揪住1向上提,让1做根。
    • 发现0的其它子树状态不变,1的其它子树状态也不变,只有0-1这条边变了。
    • 以此类推。
  • max版由于远端的部分可能有影响,所以要记录三个数组的状态。

2. 复杂度分析

  1. O(ln)

3. 常见应用

4. 常用优化

  1. 换根DP的三个数组,可以写成f = [[0,0,0] for _ in range(n)],但这是负优化,还是写三个好使。
  2. 换根DP可以用BFS先列出dp序,然后遍历,省去dfs爆栈的开销。

二、 模板代码

1. 单纯询问树的直径值

例题: 4799. 最远距离
这题是一道裸的树的直径。下边的代码大部分会用这题测试。

  • 如果没特殊需求,建议用两次bfs免去爆栈风险。
def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrappedfunc


#    4994    ms
def solve2():
    n, m = RI()
    g = [[] for _ in range(n)]
    for _ in range(m):
        u, v = RI()
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)
    start = 0
    ans = 0
    for _ in range(2):
        @bootstrap
        def dfs(u, fa, depth=0):
            nonlocal ans, start
            if depth > ans:
                start = u
                ans = depth
            for v in g[u]:
                if v != fa:
                    yield dfs(v, u, depth + 1)
            yield

        dfs(start, -1)

    print(ans)


#   3618    ms
def solve1():
    n, m = RI()
    g = [[] for _ in range(n)]
    for _ in range(m):
        u, v = RI()
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)
    start = 0
    ans = 0
    for _ in range(2):
        q = deque([start])
        fas = [-1] * n
        step = 0
        while q:
            step += 1
            for _ in range(len(q)):
                u = q.popleft()
                start = u
                for v in g[u]:
                    fas[v] = u
                    if v != fas[u]:
                        q.append(v)
        ans = max(ans, step)
    print(ans - 1)

2. 求出树的直径两端搞事情

链接: abc267_f - Exactly K Steps
在这里插入图片描述

  • 这题找出一条直径,就是最大支持的距离,如果够就可以。
  • 因此还需要多一次dfs,记录当前路径。
  • 方便的写法就是三次dfs,前两次找直径,第三次算答案,每次都记录最远的端点作为下一次的起始点即可。
    bfs找直径模板

    def get_tree_diameter(g, root=0):  # bfs两次找直径的端点
        """
        求树的直径,g是0-indexed,默认第一次root是从0
        返回某条直径的两个端点u,v,以及直径值d(边数而不是点数)
        简述:求树的直径时,可以通过树形DP做,也可以通过两次遍历找最远的点(bfs或dfs都可以)
            第二次的起始点就是某条直径的端点
        """
        if not g[root]:
            return root, root, 0

        def bfs(start):
            q = deque([(start, -1)])
            step = -1
            while q:
                step += 1
                for _ in range(len(q)):
                    u, fa = q.popleft()
                    for v in g[u]:
                        if v != fa:
                            q.append((v, u))
            return u, step

        x, _ = bfs(root)
        y, d = bfs(x)
        return x, y, d

ac代码

import sys
import heapq
import bisect
import random
import io, os
from bisect import *
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, gcd, inf
from array import *
from functools import lru_cache
from types import GeneratorType

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda *x: sys.stderr.write(f'{str(x)}\n')

MOD = 10 ** 9 + 7
PROBLEM = """https://atcoder.jp/contests/abc267/tasks/abc267_f

输入 n(≤2e5) 和一棵树的 n-1 条边(节点编号从 1 开始)。
然后输入 q(≤2e5) 和 q 个询问,每个询问输入 u 和 k。
输出到 u 的距离为 k 的任意一个点。如果这个点不存在则输出 -1。
距离指两点最短路上的边的数目。
输入
5
1 2
2 3
3 4
3 5
3
2 2
5 3
3 3
输出
4
1
-1
"""
"""https://atcoder.jp/contests/abc267/submissions/37595672

求出树的任意一条直径,设直径端点为 x 和 y。

从 x 出发 dfs,同时记录 dfs 路径上的点。
如果点 u 的深度 d >= k,那么 dfs 路径上的第 d-k 个点就是答案。

一次 dfs 不一定能满足所有点,再从 y 出发 dfs 一次就能保证所有点都有答案(除了 k 非常大的)。
"""


def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrappedfunc


# 1629 ms
if __name__ == '__main__':
    n, = RI()
    g = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = RI()
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)

    q, = RI()
    qs = defaultdict(list)
    for i in range(q):
        u, d = RI()
        qs[u - 1].append([i, d])
    ans = [-1] * q

    leaf = mx = 0
    path = [-1] * n


    @bootstrap
    def dfs(u, fa, d=0):
        path[d] = u
        global leaf, mx
        if d > mx:
            leaf = u
            mx = d
        for i, k in qs[u]:
            if d >= k:
                ans[i] = path[d - k] + 1
        for v in g[u]:
            if v != fa:
                yield dfs(v, u, d + 1)
        yield


    for _ in range(3):
        dfs(leaf, -1)

    print('\n'.join(map(str, ans)))

# # 1423 ms
# if __name__ == '__main__':
#     n, = RI()
#     g = [[] for _ in range(n)]
#     for _ in range(n - 1):
#         u, v = RI()
#         g[u - 1].append(v - 1)
#         g[v - 1].append(u - 1)
#
#
#     def get_tree_diameter(g, root=0):  # bfs两次找直径的端点
#         if not g[root]:
#             return root, root
#
#         def bfs(start):
#             q = deque([(start, -1)])
#             while q:
#                 u, fa = q.popleft()
#                 for v in g[u]:
#                     if v != fa:
#                         q.append((v, u))
#             return u
#
#         x = bfs(root)
#         y = bfs(x)
#         return x, y
#
#
#     x, y = get_tree_diameter(g)
#
#     q, = RI()
#     qs = defaultdict(list)
#     for i in range(q):
#         u, d = RI()
#         qs[u - 1].append([i, d])
#     # print(qs)
#     ans = [-1] * q
#     path = [0] * n  # 当前深度链接到根的路径
#
#
#     @bootstrap
#     def dfs(u, fa, d=0):
#         path[d] = u
#         for i, k in qs[u]:
#             if d >= k:
#                 ans[i] = path[d - k] + 1
#         for v in g[u]:
#             if v != fa:
#                 yield dfs(v, u, d + 1)
#         yield
#
#
#     dfs(x, -1)
#     dfs(y, -1)
#     print('\n'.join(map(str, ans)))

# # 1892ms
# if __name__ == '__main__':
#     n, = RI()
#     g = [[] for _ in range(n)]
#     for _ in range(n - 1):
#         u, v = RI()
#         g[u - 1].append(v - 1)
#         g[v - 1].append(u - 1)
#
#
#     def get_tree_diameter(g, root=0):
#         ans = (0, root, root)
#         if not g[root]:
#             return ans
#
#         dp = {}
#
#         @bootstrap
#         def dfs(u, fa, depth=0):  # 返回树高以及最深的叶子
#             if len(g[u]) == 1 and u != root:  # 没有子节点了,它就是一个端点(叶子),高度1
#                 dp[u] = (1, u)
#                 yield
#
#             hs = []
#             for v in g[u]:
#                 if v != fa:
#                     yield dfs(v, u, depth + 1)
#                     h, o = dp[v]
#                     if len(hs) < 2:
#                         heapq.heappush(hs, (h, o))
#                     else:
#                         heapq.heappushpop(hs, (h, o))
#
#             if len(hs) == 2:
#                 l, r = max((depth, root), hs[0]), hs[1]
#             else:
#                 l, r = (depth, root), hs[0]
#             p = (l[0] + r[0], l[1], r[1])
#             # print(p)
#             nonlocal ans
#             if p > ans:
#                 ans = p
#             dp[u] = (hs[-1][0] + 1, hs[-1][1])
#             yield
#
#         dfs(root, -1)
#         return ans
#
#
#     d, x, y = get_tree_diameter(g)
#
#     q, = RI()
#     qs = defaultdict(list)
#     for i in range(q):
#         u, d = RI()
#         qs[u - 1].append([i, d])
#     # print(qs)
#     ans = [-1] * q
#     path = [0] * n  # 当前深度链接到根的路径
#
#
#     @bootstrap
#     def dfs(u, fa, d=0):
#         path[d] = u
#         for i, k in qs[u]:
#             if d >= k:
#                 ans[i] = path[d - k] + 1
#         for v in g[u]:
#             if v != fa:
#                 yield dfs(v, u, d + 1)
#         yield
#
#
#     dfs(x, -1)
#     dfs(y, -1)
#     print('\n'.join(map(str, ans)))

3. max版换根DP求树的直径(大炮打蚊子,别这么做,只是用来帮助理解换根DP)


#  换根dp   4122    ms
def solve():
    n, m = RI()
    g = [[] for _ in range(n)]
    for _ in range(m):
        u, v = RI()
        g[u - 1].append(v - 1)
        g[v - 1].append(u - 1)

    def get_tree_diameter(g, root=0):  # bfs两次找直径的端点
        """ ms
        求树的直径,g是0-indexed,默认第一次root是从0
        返回直径值(边数而不是点数)
        简述:换根dp,假设直径两端是u,v,那么以u为树根,v一定是最远节点,求最大树高即可。
        """
        if not g[root]:
            return 0
        down1, down2, up = [0] * n, [0] * n, [0] * n  # 初始化向下最大/次大树高、向上树高(其实不是树高,是最远简单路径)
        order = []  # dp序
        fas = [-1] * n  # 记录父节点
        q = deque([root])  # bfs求order
        while q:
            u = q.popleft()
            order.append(u)
            for v in g[u]:
                if v != fas[u]:
                    fas[v] = u
                    q.append(v)

        for u in order[::-1]:  # 第一遍,自底向上求每个子树的最大/次大树高
            for v in g[u]:
                if v == fas[u]:
                    continue
                h = down1[v] + 1  # 高度
                if h > down1[u]:
                    down1[u], down2[u] = h, down1[u]
                elif h > down2[u]:
                    down2[u] = h
        for u in order:
            for v in g[u]:
                if v == fas[u]:
                    continue
                if down1[u] == down1[v] + 1:  # v在u的最大路径上,则往上的路径应该可能从次大走
                    up[v] = max(down2[u], up[u]) + 1
                else:  # 否则一定从最大走
                    up[v] = max(down1[u], up[u]) + 1

        return max(max(x, y) for x, y in zip(down1, up))

    print(get_tree_diameter(g))

4. max版换根dp求特定值(另附小日子模板)

链接: abc222_f - Expensive Expense
在这里插入图片描述

  • 这题由于要附加一个路径上的端点节点值,因此更新答案时多了个,可能是d[u]+w。
  • solve2/3分别是用上述方法三个数组down1/2+up计算的过程。
  • 另附一个从周赛抄来的模板,我尽可能的按照自己的理解把注释写成了中文。
    • 模板需要调整三个地方 e/op/composition方法,具体可以看代码。
    • 注意e通常是0,我目前做的题都是。
    • op目前我做的题都是max。
    • composition是最关键的,知道子树和节点,如何拼接答案。
# Problem: F - Expensive Expense
# Contest: AtCoder - Exawizards Programming Contest 2021(AtCoder Beginner Contest 222)
# URL: https://atcoder.jp/contests/abc222/tasks/abc222_f
# Memory Limit: 1024 MB
# Time Limit: 4000 ms

import sys
import bisect
import random
import io, os
from bisect import *
from collections import *
from contextlib import redirect_stdout
from itertools import *
from math import sqrt, gcd, inf
from array import *
from functools import lru_cache
from types import GeneratorType
from heapq import *

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda *x: sys.stderr.write(f'{str(x)}\n')

MOD = 10 ** 9 + 7
PROBLEM = """https://atcoder.jp/contests/abc222/tasks/abc222_f

输入 n(2≤n≤2e5) 和一棵树的 n-1 条边(节点编号从 1 开始),每条边输入两个端点和边权。
然后输入 n 个数 d,d[i] 表示点 i 的点权。

定义 f(x,y) = 从 x 到 y 的简单路径的边权之和,再加上 d[y]。
定义 g(x) = max{f(x,i)},这里 i 取遍 1~n 的所有不为 x 的点。
输出 g(1),g(2),...,g(n)。
输入
3
1 2 2
2 3 3
1 2 3
输出
8
6
6
"""

from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")
E = Callable[[int], T]
"""identify element of op, and answer of leaf"""
Op = Callable[[T, T], T]
"""merge value of child node"""
Composition = Callable[[T, int, int, int], T]
"""return value from child node to parent node"""


#    1187    ms
def solve():
    class Rerooting(Generic[T]):
        __slots__ = ("g", "_n", "_decrement", "_root", "_parent", "_order")

        def __init__(self, n: int, decrement: int = 0, edges=None):
            """
            n: 节点个数
            decrement: 节点id可能需要偏移 (1-indexed则-1, 0-indexed则0)
            """
            self.g = g = [[] for _ in range(n)]
            self._n = n
            self._decrement = decrement
            self._root = None  # 一开始的根
            if edges:
                for u, v in edges:
                    u -= decrement
                    v -= decrement
                    g[u].append(v)
                    g[v].append(u)

        def add_edge(self, u: int, v: int):
            """
            无向树加边
            """
            u -= self._decrement
            v -= self._decrement
            self.g[u].append(v)
            self.g[v].append(u)

        def rerooting(
                self, e: E["T"], op: Op["T"], composition: Composition["T"], root=0
        ) -> List["T"]:
            """
            - e: 初始化每个节点的价值
              (root) -> res
              mergeの単位元
              例:求最长路径 e=0

            - op: 两个子树答案如何组合或取舍
              (childRes1,childRes2) -> newRes
              例:求最长路径 return max(childRes1,childRes2)

            - composition: 知道子子树答案和节点值,如何更新子树答案
              (from_res,fa,u,use_fa) -> new_res
              use_fa: 0表示用u更新fa的dp1,1表示用fa更新u的dp2
              例:最长路径return from_res+1

            - root: 可能要设置初始根,默认是0
            <概要> 换根DP模板,用线性时间获取以每个节点为根整颗树的情况。
            注意最终返回的dp[u]代表以u为根时,u的所有子树的最优情况(不包括u节点本身),因此如果要整颗子树情况,还要再额外计算。
            1. 记录dp1,dp2。其中:
                dp1[u] 代表 以u为根的子树,它的孩子子树的最优值,即u节点本身不参与计算。注意,和我们一般定义的f[u]代表以u为根的子树2情况不同。
                dp2[v] 代表 除了v以外,它的兄弟子树的最优值。依然注意,v不参与,同时u也不参与(u是v的父节点)。
                建议画图理解。
            2. dp2[v]的含义后边将进行一次变动,变更为v的兄弟、u的父过来的路径,merge上u节点本身最后得出来的值。即v以父亲为邻居向外延伸的最优值(不含v,但含父)。
            3. 同时dp1[u]的含义更新为目标的含义:以u为根,u的子节点们所在子树的最优情况。
            4. 这样dp1,dp2将分别代表u的向下子树的最优,u除了向下子树以外的最优(一定从父节点来,但父节点可能从兄弟来或祖宗来)
            <步骤>
            1. 先从任意root出发(一般是0),获取bfs层序。这里是为了方便dp,或者直接dfs树形DP其实也是可以的,但可能会爆栈。
            2. 自底向上dp,用自身子树情况更新dp1,除自己外的兄弟子树情况更新dp2。
            3. 自顶向下dp,变更dp2和dp1的含义。这时对于u来说存在三种子树(强烈建议画图观察):
                ① u本身的子树,它们的最优解已经存在于之前的dp1[u]。
                ② u的兄弟子树+fa,它们的最优解=composition(dp2[u],fa,u,use_fa=1)。
                ③ 连接到fa的最优子树+fa,最优解=composition(dp2[fa],fa,u,use_fa=1)。
                    注意这里的dp2含义已变更,由于我们是自顶向下计算,因此dp2[fa]已更新。
                    ②和③可以写一起来更新dp2[u]

            計算量 O(|V|) (Vは頂点数)
            参照 https://qiita.com/keymoon/items/2a52f1b0fb7ef67fb89e
            """
            # step1
            root -= self._decrement
            assert 0 <= root < self._n
            self._root = root
            g = self.g
            _fas = self._parent = [-1] * self._n  # 记录每个节点的父节点
            _order = self._order = [root]  # bfs记录遍历层序,便于后续dp
            q = deque([root])
            while q:
                u = q.popleft()
                for v in g[u]:
                    if v == _fas[u]:
                        continue
                    _fas[v] = u
                    _order.append(v)
                    q.append(v)

            # step2
            dp1 = [e(i) for i in range(self._n)]  # !子树部分的dp值,假设u是当前子树的根,vs是第一层儿子(它的非父邻居),则dp1[u]=op(dp1(vs))
            dp2 = [e(i) for i in
                   range(
                       self._n)]  # !非子树部分的dp值,假设u是当前子树的根,vs={v1,v2..vi..}是第一层儿子(它的非父邻居),则dp2[vi]=op(dp1(vs-vi)),即他的兄弟们

            for u in _order[::-1]:  # 从下往上拓扑序dp
                res = e(u)
                for v in g[u]:
                    if _fas[u] == v:
                        continue
                    dp2[v] = res
                    res = op(res, composition(dp1[v], u, v, 0))  # op从下往上更新dp1
                # 由于最大可能在后边,因此还得倒序来一遍
                res = e(u)
                for v in g[u][::-1]:
                    if _fas[u] == v:
                        continue
                    dp2[v] = op(res, dp2[v])
                    res = op(res, composition(dp1[v], u, v, 0))
                dp1[u] = res

            # step3 自顶向下计算每个节点作为根时的dp1,dp2的含义变更为:dp2[u]为u的兄弟+父。这样对v来说dp1[u] = op(dp1[fa],dp1[u])

            for u in _order[1:]:
                fa = _fas[u]
                dp2[u] = composition(
                    op(dp2[u], dp2[fa]), fa, u, 1
                )  # op从上往下更新dp2
                dp1[u] = op(dp1[u], dp2[u])

            return dp1

    n, = RI()
    r = Rerooting(n)
    ws = {}
    for _ in range(n - 1):
        u, v, w = RI()
        u -= 1
        v -= 1
        ws[u, v] = w
        ws[v, u] = w
        r.add_edge(u, v)
    d = RILST()

    def e(root: int) -> int:
        # 转移时单个点不管相邻子树的贡献
        # 例:最も遠い点までの距離を求める場合 e=0
        return 0

    def op(child_res1: int, child_res2: int) -> int:
        # 如何组合/取舍两个子树的答案
        # 例:求最长路径 return max(childRes1,childRes2)
        return max(child_res1, child_res2)

    def composition(from_res: int, fa: int, u: int, use_fa: int = 0) -> int:
        # 知道子树的每个子树和节点值,如何更新子树答案;
        # 例子:求最长路径 return from_res+1
        if use_fa == 0:  # cur -> parent
            return max(from_res, d[u]) + ws[u, fa]
        return max(from_res, d[fa]) + ws[fa, u]

    res = r.rerooting(e, op, composition)
    print(*res, sep='\n')



def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrappedfunc


#    927     ms
def solve2():
    n, = RI()
    g = [[] for _ in range(n)]

    for _ in range(n - 1):
        u, v, w = RI()
        u -= 1
        v -= 1
        g[u].append((v, w))
        g[v].append((u, w))
    d = RILST()
    down1, down2, up = [0] * n, [0] * n, [0] * n

    @bootstrap
    def dfs(u, fa):
        for v, w in g[u]:
            if v == fa:
                continue
            yield dfs(v, u)
            s = max(down1[v], d[v]) + w
            if s > down1[u]:
                down2[u] = down1[u]
                down1[u] = s
            elif s > down2[u]:
                down2[u] = s
        yield

    @bootstrap
    def reroot(u, fa):
        for v, w in g[u]:
            if v != fa:
                if down1[u] == down1[v] + w or down1[u] == d[v] + w:
                    up[v] = max(down2[u] + w, up[u] + w, d[u] + w)
                else:
                    up[v] = max(down1[u] + w, up[u] + w, d[u] + w)
                yield reroot(v, u)
        yield

    dfs(0, -1)
    reroot(0, -1)
    print(*[max(a, b) for a, b in zip(up, down1)], sep='\n')


#     715    ms
def solve3():
    n, = RI()
    g = [[] for _ in range(n)]

    for _ in range(n - 1):
        u, v, w = RI()
        u -= 1
        v -= 1
        g[u].append((v, w))
        g[v].append((u, w))
    d = RILST()
    down1, down2, up = [0] * n, [0] * n, [0] * n
    order = []
    q = deque([0])
    fas = [-1] * n
    while q:
        u = q.popleft()
        order.append(u)
        for v, w in g[u]:
            if v != fas[u]:
                fas[v] = u
                q.append(v)

    for u in order[::-1]:
        for v, w in g[u]:
            if v == fas[u]:
                continue
            s = max(down1[v], d[v]) + w
            if s > down1[u]:
                down2[u] = down1[u]
                down1[u] = s
            elif s > down2[u]:
                down2[u] = s
    for u in order:
        for v, w in g[u]:
            if v != fas[u]:
                if down1[u] == down1[v] + w or down1[u] == d[v] + w:
                    up[v] = max(down2[u] + w, up[u] + w, d[u] + w)
                else:
                    up[v] = max(down1[u] + w, up[u] + w, d[u] + w)

    print(*[max(a, b) for a, b in zip(up, down1)], sep='\n')


if __name__ == '__main__':
    solve()

5. max版换根dp求去掉一个叶子的值。

链接: 2538. 最大价值和与最小价值和的差值

  • 这题是周赛T4,我当时用树形DP写了一大堆两次dfs做出来了。
    手写换根
class Solution:
    def maxOutput(self, n: int, edges: List[List[int]], price: List[int]) -> int:
        g = [[] for _ in range(n)]
        for u,v in edges:
            g[u].append(v)
            g[v].append(u)
            # print(u,v)
        ans = 0

        f = [[0,0,0] for _ in range(n)]  # f[i][0/1/2]代表:i向下走最大路径和,向下走次大路径和,向上走最大路径和;答案一定在向下或向上走的路径中
        def dfs1(u,fa):  # 更新向下走的最大/次大路径和
            f[u][0] = p = price[u]
            for v in g[u]:
                if v != fa:
                    dfs1(v,u)
                    x = f[v][0]+p
                    if f[u][0]<x:
                        f[u][1] = f[u][0]
                        f[u][0] = x
                    elif f[u][1] < x:
                        f[u][1] = x 
        
        def dfs2(u,fa):
            for v in g[u]:
                if v != fa:
                    p = price[v]
                    if f[u][0] == f[v][0] + price[u]:
                        f[v][2] = max(f[u][2],f[u][1]) + p
                    else:
                        f[v][2] = max(f[u][2],f[u][0]) + p 
                    dfs2(v,u)
        dfs1(0,-1)
        dfs2(0,-1)

        return max(max(a-price[i],c-price[i]) for i,(a,_,c) in enumerate(f))

套模板

from typing import List, Tuple, Optional
from collections import defaultdict, Counter, deque

MOD = int(1e9 + 7)
INF = int(1e20)

# 给你一个 n 个节点的无向无根图,节点编号为 0 到 n - 1 。给你一个整数 n 和一个长度为 n - 1 的二维整数数组 edges ,其中 edges[i] = [ai, bi] 表示树中节点 ai 和 bi 之间有一条边。

# 每个节点都有一个价值。给你一个整数数组 price ,其中 price[i] 是第 i 个节点的价值。

# 一条路径的 价值和 是这条路径上所有节点的价值之和。

# 你可以选择树中任意一个节点作为根节点 root 。选择 root 为根的 开销 是以 root 为起点的所有路径中,价值和 最大的一条路径与最小的一条路径的差值。

# 请你返回所有节点作为根节点的选择中,最大 的 开销 为多少。


from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")

E = Callable[[int], T]
"""identify element of op, and answer of leaf"""

Op = Callable[[T, T], T]
"""merge value of child node"""

Composition = Callable[[T, int, int, int], T]
"""return value from child node to parent node"""


class Rerooting(Generic[T]):
    __slots__ = ("g", "_n", "_decrement", "_root", "_parent", "_order")

    def __init__(self, n: int, decrement: int = 0, edges=None):
        """
        n: 节点个数
        decrement: 节点id可能需要偏移 (1-indexed则-1, 0-indexed则0)
        """
        self.g = g = [[] for _ in range(n)]
        self._n = n
        self._decrement = decrement
        self._root = None  # 一开始的根
        if edges:
            for u, v in edges:
                u -= decrement
                v -= decrement
                g[u].append(v)
                g[v].append(u)

    def add_edge(self, u: int, v: int):
        """
        无向树加边
        """
        u -= self._decrement
        v -= self._decrement
        self.g[u].append(v)
        self.g[v].append(u)

    def rerooting(
            self, e: E["T"], op: Op["T"], composition: Composition["T"], root=0
    ) -> List["T"]:
        """
        - e: 初始化每个节点的价值
          (root) -> res
          mergeの単位元
          例:求最长路径 e=0

        - op: 两个子树答案如何组合或取舍
          (childRes1,childRes2) -> newRes
          例:求最长路径 return max(childRes1,childRes2)

        - composition: 知道子子树答案和节点值,如何更新子树答案
          (from_res,fa,u,use_fa) -> new_res
          use_fa: 0表示用u更新fa的dp1,1表示用fa更新u的dp2
          例:最长路径return from_res+1

        - root: 可能要设置初始根,默认是0
        <概要> 换根DP,用线性时间获取以每个节点为根整颗树的情况。
        注意最终返回的dp[u]代表以u为根时,u的所有子树的最优情况(不包括u节点本身),因此如果要整颗子树情况,还要再额外计算。
        1. 记录dp1,dp2。其中:
            dp1[u] 代表 以u为根的子树,它的孩子子树的最优值,即u节点本身不参与计算。注意,和我们一般定义的f[u]代表以u为根的子树2情况不同。
            dp2[v] 代表 除了v以外,它的兄弟子树的最优值。依然注意,v不参与,同时u也不参与(u是v的父节点)。
            建议画图理解。
        2. dp2[v]的含义后边将进行一次变动,变更为v的兄弟、u的父过来的路径,merge上u节点本身最后得出来的值。即v以父亲为邻居向外延伸的最优值(不含v,但含父)。
        3. 同时dp1[u]的含义更新为目标的含义:以u为根,u的子节点们所在子树的最优情况。
        4. 这样dp1,dp2将分别代表u的向下子树的最优,u除了向下子树以外的最优(一定从父节点来,但父节点可能从兄弟来或祖宗来)
        <步骤>
        1. 先从任意root出发(一般是0),获取bfs层序。这里是为了方便dp,或者直接dfs树形DP其实也是可以的,但可能会爆栈。
        2. 自底向上dp,用自身子树情况更新dp1,除自己外的兄弟子树情况更新dp2。
        3. 自顶向下dp,变更dp2和dp1的含义。这时对于u来说存在三种子树(强烈建议画图观察):
            ① u本身的子树,它们的最优解已经存在于之前的dp1[u]。
            ② u的兄弟子树+fa,它们的最优解=composition(dp2[u],fa,u,use_fa=1)。
            ③ 连接到fa的最优子树+fa,最优解=composition(dp2[fa],fa,u,use_fa=1)。
                注意这里的dp2含义已变更,由于我们是自顶向下计算,因此dp2[fa]已更新。
                ②和③可以写一起来更新dp2[u]

        計算量 O(|V|) (Vは頂点数)
        参照 https://qiita.com/keymoon/items/2a52f1b0fb7ef67fb89e
        """
        # step1
        root -= self._decrement
        assert 0 <= root < self._n
        self._root = root
        g = self.g
        _fas = self._parent = [-1] * self._n  # 记录每个节点的父节点
        _order = self._order = [root]  # bfs记录遍历层序,便于后续dp
        q = deque([root])
        while q:
            u = q.popleft()
            for v in g[u]:
                if v == _fas[u]:
                    continue
                _fas[v] = u
                _order.append(v)
                q.append(v)

        # step2
        dp1 = [e(i) for i in range(self._n)]  # !子树部分的dp值,假设u是当前子树的根,vs是第一层儿子(它的非父邻居),则dp1[u]=op(dp1(vs))
        dp2 = [e(i) for i in
               range(self._n)]  # !非子树部分的dp值,假设u是当前子树的根,vs={v1,v2..vi..}是第一层儿子(它的非父邻居),则dp2[vi]=op(dp1(vs-vi)),即他的兄弟们
        for u in _order[::-1]:  # 从下往上拓扑序dp
            res = e(u)
            for v in g[u]:
                if _fas[u] == v:
                    continue
                dp2[v] = res
                res = op(res, composition(dp1[v], u, v, 0))  # op从下往上更新dp1
            # 由于最大可能在后边,因此还得倒序来一遍
            res = e(u)
            for v in g[u][::-1]:
                if _fas[u] == v:
                    continue
                dp2[v] = op(res, dp2[v])
                res = op(res, composition(dp1[v], u, v, 0))
            dp1[u] = res

        # step3 自顶向下计算每个节点作为根时的dp1,dp2的含义变更为:dp2[u]为u的兄弟+父。这样对v来说dp1[u] = op(dp1[fa],dp1[u])
        for u in _order[1:]:  #
            fa = _fas[u]
            dp2[u] = composition(
                op(dp2[u], dp2[fa]), fa, u, 1
            )  # op从上往下更新dp2
            dp1[u] = op(dp1[u], dp2[u])
        return dp1


class Solution:
    def maxOutput(self, n: int, edges: List[List[int]], price: List[int]) -> int:
        def e(root: int) -> int:
            # mergeの単位元
            # 例:最も遠い点までの距離を求める場合 e=0
            return 0

        def op(child_res1: int, child_res2: int) -> int:
            # 如何组合/取舍两个子树的答案
            # 例:求最长路径 return max(childRes1,childRes2)
            return max(child_res1, child_res2)

        def composition(from_res: int, fa: int, u: int, use_fa: int = 0) -> int:
            # 知道子树的每个子树和节点值,如何更新子树答案;
            # 例子:求最长路径 return from_res+1
            if use_fa == 0:  # cur -> parent
                return from_res + price[u]
            return from_res + price[fa]

        R = Rerooting(n, edges=edges)
        # for u, v in edges:
        #     R.add_edge(u, v)
        res = R.rerooting(e, max, composition)
        return max(res)

6. sum版换根dp求每个节点作为根猜对多少边。

链接: 2581. 统计可能的树根数目

  • 周赛T4,由于不会sum版换根翻车。
  • 题意就是求每个节点作为根猜对多少边。
    手写换根
class Solution:
    def rootCount(self, edges: List[List[int]], guesses: List[List[int]], k: int) -> int:
        n = len(edges)+ 1
        g = [[] for _ in range(n)]
        for u,v in edges:
            g[u].append(v)
            g[v].append(u)
        s = set(tuple(x) for x in guesses)
        f = [0]*n
        def dfs(u,fa):
            for v in g[u]:
                if v != fa:
                    if (u,v) in s:
                        f[0] += 1
                    dfs(v,u)
        
        def reroot(u,fa):
            for v in g[u]:
                if v != fa:
                    f[v] = f[u] + int((v,u) in s) - int((u,v) in s)
                    reroot(v,u)
        dfs(0,-1)
        reroot(0,-1)
        return sum(x >= k for x in f)

套模板


from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")
E = Callable[[int], T]
"""identify element of op, and answer of leaf"""
Op = Callable[[T, T], T]
"""merge value of child node"""
Composition = Callable[[T, int, int, int], T]
"""return value from child node to parent node"""



class Rerooting(Generic[T]):
    __slots__ = ("g", "_n", "_decrement", "_root", "_parent", "_order")

    def __init__(self, n: int, decrement: int = 0, edges=None):
        """
        n: 节点个数
        decrement: 节点id可能需要偏移 (1-indexed则-1, 0-indexed则0)
        """
        self.g = g = [[] for _ in range(n)]
        self._n = n
        self._decrement = decrement
        self._root = None  # 一开始的根
        if edges:
            for u, v in edges:
                u -= decrement
                v -= decrement
                g[u].append(v)
                g[v].append(u)

    def add_edge(self, u: int, v: int):
        """
        无向树加边
        """
        u -= self._decrement
        v -= self._decrement
        self.g[u].append(v)
        self.g[v].append(u)

    def rerooting(
            self, e: E["T"], op: Op["T"], composition: Composition["T"], root=0
    ) -> List["T"]:
        """
        - e: 初始化每个节点的价值
          (root) -> res
          mergeの単位元
          例:求最长路径 e=0

        - op: 两个子树答案如何组合或取舍
          (childRes1,childRes2) -> newRes
          例:求最长路径 return max(childRes1,childRes2)

        - composition: 知道子子树答案和节点值,如何更新子树答案
          (from_res,fa,u,use_fa) -> new_res
          use_fa: 0表示用u更新fa的dp1,1表示用fa更新u的dp2
          例:最长路径return from_res+1

        - root: 可能要设置初始根,默认是0
        <概要> 换根DP模板,用线性时间获取以每个节点为根整颗树的情况。
        注意最终返回的dp[u]代表以u为根时,u的所有子树的最优情况(不包括u节点本身),因此如果要整颗子树情况,还要再额外计算。
        1. 记录dp1,dp2。其中:
            dp1[u] 代表 以u为根的子树,它的孩子子树的最优值,即u节点本身不参与计算。注意,和我们一般定义的f[u]代表以u为根的子树2情况不同。
            dp2[v] 代表 除了v以外,它的兄弟子树的最优值。依然注意,v不参与,同时u也不参与(u是v的父节点)。
            建议画图理解。
        2. dp2[v]的含义后边将进行一次变动,变更为v的兄弟、u的父过来的路径,merge上u节点本身最后得出来的值。即v以父亲为邻居向外延伸的最优值(不含v,但含父)。
        3. 同时dp1[u]的含义更新为目标的含义:以u为根,u的子节点们所在子树的最优情况。
        4. 这样dp1,dp2将分别代表u的向下子树的最优,u除了向下子树以外的最优(一定从父节点来,但父节点可能从兄弟来或祖宗来)
        <步骤>
        1. 先从任意root出发(一般是0),获取bfs层序。这里是为了方便dp,或者直接dfs树形DP其实也是可以的,但可能会爆栈。
        2. 自底向上dp,用自身子树情况更新dp1,除自己外的兄弟子树情况更新dp2。
        3. 自顶向下dp,变更dp2和dp1的含义。这时对于u来说存在三种子树(强烈建议画图观察):
            ① u本身的子树,它们的最优解已经存在于之前的dp1[u]。
            ② u的兄弟子树+fa,它们的最优解=composition(dp2[u],fa,u,use_fa=1)。
            ③ 连接到fa的最优子树+fa,最优解=composition(dp2[fa],fa,u,use_fa=1)。
                注意这里的dp2含义已变更,由于我们是自顶向下计算,因此dp2[fa]已更新。
                ②和③可以写一起来更新dp2[u]

        計算量 O(|V|) (Vは頂点数)
        参照 https://qiita.com/keymoon/items/2a52f1b0fb7ef67fb89e
        """
        # step1
        root -= self._decrement
        assert 0 <= root < self._n
        self._root = root
        g = self.g
        _fas = self._parent = [-1] * self._n  # 记录每个节点的父节点
        _order = self._order = [root]  # bfs记录遍历层序,便于后续dp
        q = deque([root])
        while q:
            u = q.popleft()
            for v in g[u]:
                if v == _fas[u]:
                    continue
                _fas[v] = u
                _order.append(v)
                q.append(v)

        # step2
        dp1 = [e(i) for i in range(self._n)]  # !子树部分的dp值,假设u是当前子树的根,vs是第一层儿子(它的非父邻居),则dp1[u]=op(dp1(vs))
        dp2 = [e(i) for i in
               range(
                   self._n)]  # !非子树部分的dp值,假设u是当前子树的根,vs={v1,v2..vi..}是第一层儿子(它的非父邻居),则dp2[vi]=op(dp1(vs-vi)),即他的兄弟们

        for u in _order[::-1]:  # 从下往上拓扑序dp
            res = e(u)
            for v in g[u]:
                if _fas[u] == v:
                    continue
                dp2[v] = res
                res = op(res, composition(dp1[v], u, v, 0))  # op从下往上更新dp1
            # 由于最大可能在后边,因此还得倒序来一遍
            res = e(u)
            for v in g[u][::-1]:
                if _fas[u] == v:
                    continue
                dp2[v] = op(res, dp2[v])
                res = op(res, composition(dp1[v], u, v, 0))
            dp1[u] = res

        # step3 自顶向下计算每个节点作为根时的dp1,dp2的含义变更为:dp2[u]为u的兄弟+父。这样对v来说dp1[u] = op(dp1[fa],dp1[u])

        for u in _order[1:]:
            fa = _fas[u]
            dp2[u] = composition(
                op(dp2[u], dp2[fa]), fa, u, 1
            )  # op从上往下更新dp2
            dp1[u] = op(dp1[u], dp2[u])

        return dp1
        
class Solution:
    def rootCount(self, edges: List[List[int]], guesses: List[List[int]], k: int) -> int:
        n = len(edges) + 1
        
        s = set(tuple(x) for x in guesses)
        tree = Rerooting(n)
        for u,v in edges:
            tree.add_edge(u,v)
        def e(root: int) -> int:
            # 转移时单个点不管相邻子树的贡献
            # 例:最も遠い点までの距離を求める場合 e=0
            return 0

        def op(child_res1: int, child_res2: int) -> int:
            # 如何组合/取舍两个子树的答案
            # 例:求最长路径 return max(childRes1,childRes2)
            return child_res1 + child_res2

        def composition(from_res: int, fa: int, u: int, use_fa: int = 0) -> int:
            # 知道子树的每个子树和节点值,如何更新子树答案;
            # 例子:求最长路径 return from_res+1
            if use_fa == 0:  # cur -> parent
                return from_res + int((fa,u) in s)
            return from_res + int((u,fa) in s)

        res = tree.rerooting(e, op, composition)
        
            
        return sum(res[i]>=k for i in range(n))

7. sum版换根dp求每个节点作为根猜错多少遍(有多少反边)。

链接: cf219 D. Choosing Capital for Treeland

  • 跟上一题类似,但是问的是反边。
# Problem: D. Choosing Capital for Treeland
# Contest: Codeforces - Codeforces Round 135 (Div. 2)
# URL: https://codeforces.com/problemset/problem/219/D
# Memory Limit: 256 MB
# Time Limit: 3000 ms

import sys
from collections import *
from types import GeneratorType

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda *x: sys.stderr.write(f'{str(x)}\n')

MOD = 10 ** 9 + 7
PROBLEM = """
https://codeforces.com/problemset/problem/219/D

输入 n(2≤n≤2e5) 和 n-1 条边 v w,表示一条 v->w 的有向边。(节点编号从 1 开始)
保证输入构成一棵树。

定义 f(x) 表示以 x 为根时,要让 x 能够到达任意点,需要反向的边的数量。
输出 min(f(x)),以及所有等于 min(f(x)) 的节点编号(按升序输出)。
输入
3
2 1
2 3
输出
0
2 

输入
4
1 4
2 4
3 4
输出
2
1 2 3 
"""
"""换根DP,类似双周赛T4
先求以0为根时,反边数量
然后求以其它为根时反边数量。
一对邻居分别作根时,状态的差别只跟这条边有关,因此可以按次序求出每个节点为根时的状态。
想象以0为根时状态已求出,0-1有边,揪住1向上提,让1做根。
发现0的其它子树状态不变,1的其它子树状态也不变,只有0-1这条边变了。
以此类推。
"""

def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrappedfunc


#     2526  ms
def solve1():
    n, = RI()
    g = [[] for _ in range(n)]
    s = set()
    for _ in range(n - 1):
        u, v = RI()
        u -= 1
        v -= 1
        g[u].append(v)
        g[v].append(u)
        s.add((u, v))
    f = [0] * n

    @bootstrap
    def dfs(u, fa):
        for v in g[u]:
            if v == fa: continue
            if (v, u) in s:
                f[0] += 1
            yield dfs(v, u)
        yield

    @bootstrap
    def reroot(u, fa):
        for v in g[u]:
            if v == fa: continue
            f[v] = f[u] - int((v, u) in s) + int((u, v) in s)
            yield reroot(v, u)
        yield

    dfs(0, -1)
    reroot(0, -1)
    mn = min(f)
    ans = [i + 1 for i, v in enumerate(f) if v == mn]
    print(mn)
    print(*ans)


#   1526    ms
def solve2():
    n, = RI()
    g = [[] for _ in range(n)]
    s = set()
    for _ in range(n - 1):
        u, v = RI()
        u -= 1
        v -= 1
        g[u].append(v)
        g[v].append(u)
        s.add((u, v))
    f = [0] * n
    fas = [-1] * n
    order = []
    q = deque([0])
    while q:
        u = q.popleft()
        order.append(u)
        for v in g[u]:
            if v == fas[u]: continue
            fas[v] = u
            q.append(v)

    for u in order[::-1]:
        for v in g[u]:
            if v == fas[u]: continue
            f[u] += f[v] + int((v, u) in s)
    for u in order:
        for v in g[u]:
            if v == fas[u]: continue
            f[v] = f[u] + int((u, v) in s) - int((v, u) in s)
    # print(f)
    mn = min(f)
    ans = [i + 1 for i, v in enumerate(f) if v == mn]
    print(mn)
    print(*ans)


#    1402   ms
def solve3():
    n, = RI()
    g = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = RI()
        u -= 1
        v -= 1
        g[u].append((v, 1))  # 邻居和方向
        g[v].append((u, -1))  # 反边
    f = [0] * n
    fas = [-1] * n
    order = []
    q = deque([0])
    while q:
        u = q.popleft()
        order.append(u)
        for v, _ in g[u]:
            if v == fas[u]: continue
            fas[v] = u
            q.append(v)

    for u in order[::-1]:
        for v, d in g[u]:
            if v == fas[u]: continue
            # f[u] += f[v] + (d < 0)  # 如果是反边则+1
            f[u] += f[v] + ((-d + 1) >> 1)  # 如果是反边则+1 1402
    for u in order:
        for v, d in g[u]:
            if v == fas[u]: continue
            # f[v] = f[u] + (d > 0) - (d < 0)  # uv是正边的话,根从u->v则数量+1,反边则-1
            f[v] = f[u] + d  # uv是正边的话,根从u->v则数量+1,反边则-1
    # print(f)
    mn = min(f)
    ans = [i + 1 for i, v in enumerate(f) if v == mn]
    print(mn)
    print(*ans)


#   1714    ms
def solve():
    from typing import Callable, Generic, List, TypeVar

    T = TypeVar("T")
    E = Callable[[int], T]
    """identify element of op, and answer of leaf"""
    Op = Callable[[T, T], T]
    """merge value of child node"""
    Composition = Callable[[T, int, int, int], T]
    """return value from child node to parent node"""

    class Rerooting(Generic[T]):
        __slots__ = ("g", "_n", "_decrement", "_root", "_parent", "_order")

        def __init__(self, n: int, decrement: int = 0, edges=None):
            """
            n: 节点个数
            decrement: 节点id可能需要偏移 (1-indexed则-1, 0-indexed则0)
            """
            self.g = g = [[] for _ in range(n)]
            self._n = n
            self._decrement = decrement
            self._root = None  # 一开始的根
            if edges:
                for u, v in edges:
                    u -= decrement
                    v -= decrement
                    g[u].append(v)
                    g[v].append(u)

        def add_edge(self, u: int, v: int):
            """
            无向树加边
            """
            u -= self._decrement
            v -= self._decrement
            self.g[u].append(v)
            self.g[v].append(u)

        def rerooting(
                self, e: E["T"], op: Op["T"], composition: Composition["T"], root=0
        ) -> List["T"]:
            """
            - e: 初始化每个节点的价值
              (root) -> res
              mergeの単位元
              例:求最长路径 e=0

            - op: 两个子树答案如何组合或取舍
              (childRes1,childRes2) -> newRes
              例:求最长路径 return max(childRes1,childRes2)

            - composition: 知道子子树答案和节点值,如何更新子树答案
              (from_res,fa,u,use_fa) -> new_res
              use_fa: 0表示用u更新fa的dp1,1表示用fa更新u的dp2
              例:最长路径return from_res+1

            - root: 可能要设置初始根,默认是0
            <概要> 换根DP模板,用线性时间获取以每个节点为根整颗树的情况。
            注意最终返回的dp[u]代表以u为根时,u的所有子树的最优情况(不包括u节点本身),因此如果要整颗子树情况,还要再额外计算。
            1. 记录dp1,dp2。其中:
                dp1[u] 代表 以u为根的子树,它的孩子子树的最优值,即u节点本身不参与计算。注意,和我们一般定义的f[u]代表以u为根的子树2情况不同。
                dp2[v] 代表 除了v以外,它的兄弟子树的最优值。依然注意,v不参与,同时u也不参与(u是v的父节点)。
                建议画图理解。
            2. dp2[v]的含义后边将进行一次变动,变更为v的兄弟、u的父过来的路径,merge上u节点本身最后得出来的值。即v以父亲为邻居向外延伸的最优值(不含v,但含父)。
            3. 同时dp1[u]的含义更新为目标的含义:以u为根,u的子节点们所在子树的最优情况。
            4. 这样dp1,dp2将分别代表u的向下子树的最优,u除了向下子树以外的最优(一定从父节点来,但父节点可能从兄弟来或祖宗来)
            <步骤>
            1. 先从任意root出发(一般是0),获取bfs层序。这里是为了方便dp,或者直接dfs树形DP其实也是可以的,但可能会爆栈。
            2. 自底向上dp,用自身子树情况更新dp1,除自己外的兄弟子树情况更新dp2。
            3. 自顶向下dp,变更dp2和dp1的含义。这时对于u来说存在三种子树(强烈建议画图观察):
                ① u本身的子树,它们的最优解已经存在于之前的dp1[u]。
                ② u的兄弟子树+fa,它们的最优解=composition(dp2[u],fa,u,use_fa=1)。
                ③ 连接到fa的最优子树+fa,最优解=composition(dp2[fa],fa,u,use_fa=1)。
                    注意这里的dp2含义已变更,由于我们是自顶向下计算,因此dp2[fa]已更新。
                    ②和③可以写一起来更新dp2[u]

            計算量 O(|V|) (Vは頂点数)
            参照 https://qiita.com/keymoon/items/2a52f1b0fb7ef67fb89e
            """
            # step1
            root -= self._decrement
            assert 0 <= root < self._n
            self._root = root
            g = self.g
            _fas = self._parent = [-1] * self._n  # 记录每个节点的父节点
            _order = self._order = [root]  # bfs记录遍历层序,便于后续dp
            q = deque([root])
            while q:
                u = q.popleft()
                for v in g[u]:
                    if v == _fas[u]:
                        continue
                    _fas[v] = u
                    _order.append(v)
                    q.append(v)

            # step2
            dp1 = [e(i) for i in range(self._n)]  # !子树部分的dp值,假设u是当前子树的根,vs是第一层儿子(它的非父邻居),则dp1[u]=op(dp1(vs))
            dp2 = [e(i) for i in
                   range(
                       self._n)]  # !非子树部分的dp值,假设u是当前子树的根,vs={v1,v2..vi..}是第一层儿子(它的非父邻居),则dp2[vi]=op(dp1(vs-vi)),即他的兄弟们

            for u in _order[::-1]:  # 从下往上拓扑序dp
                res = e(u)
                for v in g[u]:
                    if _fas[u] == v:
                        continue
                    dp2[v] = res
                    res = op(res, composition(dp1[v], u, v, 0))  # op从下往上更新dp1
                # 由于最大可能在后边,因此还得倒序来一遍
                res = e(u)
                for v in g[u][::-1]:
                    if _fas[u] == v:
                        continue
                    dp2[v] = op(res, dp2[v])
                    res = op(res, composition(dp1[v], u, v, 0))
                dp1[u] = res

            # step3 自顶向下计算每个节点作为根时的dp1,dp2的含义变更为:dp2[u]为u的兄弟+父。这样对v来说dp1[u] = op(dp1[fa],dp1[u])

            for u in _order[1:]:
                fa = _fas[u]
                dp2[u] = composition(
                    op(dp2[u], dp2[fa]), fa, u, 1
                )  # op从上往下更新dp2
                dp1[u] = op(dp1[u], dp2[u])

            return dp1

    n, = RI()
    r = Rerooting(n)
    s = set()
    for _ in range(n - 1):
        u, v = RI()
        u -= 1
        v -= 1
        s.add((u, v))
        r.add_edge(u, v)

    def e(root: int) -> int:
        # 转移时单个点不管相邻子树的贡献
        # 例:最も遠い点までの距離を求める場合 e=0
        return 0

    def op(child_res1: int, child_res2: int) -> int:
        # 如何组合/取舍两个子树的答案
        # 例:求最长路径 return max(childRes1,childRes2)
        return child_res1 + child_res2

    def composition(from_res: int, fa: int, u: int, use_fa: int = 0) -> int:
        # 知道子树的每个子树和节点值,如何更新子树答案;
        # 例子:求最长路径 return from_res+1
        if use_fa == 0:  # cur -> parent 用子节点更新父节点
            return from_res + int((u, fa) in s)  # 计算反边数量
        return from_res + int((fa, u) in s)  # 反过来把子节点当父的话,u->fa才是正边。

    f = r.rerooting(e, op, composition)

    # print(f)
    mn = min(f)
    ans = [i + 1 for i, v in enumerate(f) if v == mn]
    print(mn)
    print(*ans)


if __name__ == '__main__':
    solve()

三、其他

四、更多例题

五、参考链接

  • 6
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值