4381 翻转树边(树形dp)

1. 问题描述: 

给定一个 n 个节点的树。节点编号为 1∼n。树中的 n − 1 条边均为单向边。现在,我们需要选取一个节点作为中心点,并希望从中心点出发可以到达其他所有节点。但是,由于树中的边均为单向边,所以在选定中心点后,可能无法从中心点出发到达其他所有节点。为此,我们需要翻转一些边的方向,从而使得所选中心点可以到达其他所有节点。我们希望选定中心点后,所需翻转方向的边的数量尽可能少。请你确定哪些点可以选定为中心点,并输出所需的最少翻转边数量。

输入格式

第一行包含整数 n,接下来 n − 1 行,每行包含两个整数 a,b,表示存在一条从 a 到 b 的单向边。

输出格式

第一行输出一个整数,表示所需的最少翻转边数量。第二行以升序顺序输出所有可选中心点(即所需翻转边数量最少的中心点)的编号。

数据范围

前三个测试点满足 2 ≤ n ≤ 5;
所有测试点满足 2 ≤ n ≤ 2 × 10 ^ 5,1 ≤ a,b ≤ n,a ≠ b;

输入样例1:

3
2 1
2 3

输出样例1:

0
2

输入样例2:

4
1 4
2 4
3 4

输出样例2:

2
1 2 3
来源:https://www.acwing.com/problem/content/description/4384/

2. 思路分析:

由题目可知我们需要将时间复杂度控制在 O(nlogn) 以内,如果直接使用 dfs 求解,对于每一个点需要求解一遍 dfs 那么时间复杂度为 O(n ^ 2) 所以肯定是会超时的;但是我们仔细分析题目可以知道这是经典的两个方向的树形 dp 问题,由于是两个方向的树形 dp,所以对于每一个节点来说可以分为两大类:第一类是节点往上走的代价,第二类是节点往下走的代价,我们可以定义两个数组 down,up,其中 down[i] 表示节点 i 往下走的代价,up[i] 表示节点 i 往上走的代价,当我们求解出 down 和 up 的值之后,枚举每一个节点那么就可以求解出答案,所以关键是求解 down 和 up 的值;对于 down 来说还是比较好求解的,直接 dfs 求解的过程中,由子节点的信息更新父节点的信息,所以关键是 up 的求解,我们可以画图理解一下如何求解 up:可以发现求解 up[u] 的时候可以直接计算出来,up[u] 的值需要加上 fa 往上走的代价 up[fa],fa 往下走的代价减去 u 往下走的代价 down[u],加上 u->fa 的权重 w',减去 fa->u的权重 w,由于计算 up[u] 的时候需要知道父节点 up[fa] 的信息所以需要在递归前更新 up[u] 的信息,当计算好了 up 和 down 之后枚举每一个节点计算出答案即可。

3. 代码如下:

go:

package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

// 边数是点数的两倍
const N, M = 200010, N * 2

var (
	idx         int
    // 使用数组模拟邻接表的方式来存储
	h, down, up [N]int
	e, w, ne    [M]int
)

func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

func add(a int, b int, c int) {
	e[idx] = b
	w[idx] = c
	ne[idx] = h[a]
	h[a] = idx
	idx += 1
}

func dfs_down(u int, from int) {
	for i := h[u]; i != -1; i = ne[i] {
		if i == (from ^ 1) {
			continue
		}
		j := e[i]
		dfs_down(j, i)
		down[u] += down[j] + w[i]
	}
}


// from 传递的是边的编号
func dfs_up(u int, from int) {
	if from != -1 {
		fa := e[from^1]
		up[u] = up[fa] + down[fa] - down[u] - w[from] + w[from^1]
	}
	for i := h[u]; i != -1; i = ne[i] {
		if i == (from ^ 1) {
			continue
		}
		j := e[i]
		dfs_up(j, i)
	}
}

func run(r io.Reader, w io.Writer) {
	in := bufio.NewReader(r)
	out := bufio.NewWriter(w)
	defer out.Flush()
	var (
		n, a, b int
	)
	fmt.Fscan(in, &n)
	for i := 0; i < N; i++ {
		h[i] = -1
	}
	for i := 0; i < n-1; i++ {
		fmt.Fscan(in, &a, &b)
		add(a, b, 0)
		add(b, a, 1)
	}
	dfs_down(1, -1)
	dfs_up(1, -1)
	res := 2 * n
	for i := 1; i <= n; i++ {
		res = min(res, down[i]+up[i])
	}
	fmt.Fprintln(out, res)
	for i := 1; i <= n; i++ {
		if down[i]+up[i] == res {
			fmt.Fprint(out, i, " ")
		}
	}
}

func main() {
	run(os.Stdin, os.Stdout)
}

数组模拟邻接表这样在添加边的时候:0-1,2-3... 是一对,通过边的编号那么就可以通过 w [no] 获取边的权重,而且在递归的递归的时候可以通过判断是否是双向边决定是否往下递归(这样就可以在不使用 python 语言字典的情况下通过边的信息知道边的权重):

# 例如 1->2, 1->3, 1->4:
# 1. 
e[0] = 2    ne[0] = -1    h[1] = 0    idx = 1
e[1] = 3    ne[1] = 0    h[1] = 1    idx = 2
e[2] = 4    ne[2] = 1    h[1] = 2    idx = 3

# 例如 1->2 2->1
e[0] = 2    ne[0] = -1    h[1] = 0    idx = 1
e[1] = 1    ne[1] = -1    h[2] = 1    idx = 2
# e[x]中x为边的编号, e[x]为有向边的终点编号, h[x]x也为点的编号, h[x]为点x对应的边的编号 

python (最后一个数据超时),python 一般对于 10 ^ 5 规模的递归数据都会堆栈溢出或者超时的问题:

import sys
from typing import List


class Solution:
    # 从节点u往下递归求解down列表的值
    def dfs_down(self, u: int, fa: int, g: List[dict], down: List[int]):
        for v, w in g[u].items():
            if v == fa: continue
            # 由子节点更新父节点的信息
            self.dfs_down(v, u, g, down)
            down[u] += down[v] + w
    
    # 求解up列表的时候dfs_up传递的fa是节点编号是因为g在存储边的信息的时候每一个g[i]都是一个字典, 这样就可以通过编号知道边的权重
    def dfs_up(self, u: int, fa: int, g: List[dict], down: List[int], up: List[int]):
        # 先求解父节点然后再求解子节点(求解往上走的时候需要知道边的权重所以在存储数据的时候g中的每一个元素为字典这样可以通过节点编号知道边的权重)
        if fa != -1:
            # 根节点没有父节点所以需要判断是否等于-1
            up[u] = up[fa] + down[fa] - down[u] - g[fa][u] + g[u][fa]
        for v, w in g[u].items():
            if v == fa: continue
            self.dfs_up(v, u, g, up, down)

    def process(self):
        n = int(input())
        # 存储有向边
        g = [dict() for i in range(n + 10)]
        for i in range(n - 1):
            a, b = map(int, input().split())
            # 注意添加的是双向边, a->b 有边说明权重为0, 否则为1
            g[a][b] = 0
            g[b][a] = 1
        down, up = [0] * (n + 10), [0] * (n + 10)
        self.dfs_down(1, -1, g, down)
        self.dfs_up(1, -1, g, down, up)
        res = n * 2
        for i in range(1, n + 1):
            res = min(res, down[i] + up[i])
        print(res)
        for i in range(1, n + 1):
            # 枚举每一个节点判断是否等于res, 如果是res说明满足条件
            if down[i] + up[i] == res:
                print(i, end=" ")


if __name__ == '__main__':
    sys.setrecursionlimit(10 ** 5)
    Solution().process()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值