九、基础算法精讲:动态规划二

一、状态机DP

1.1 买卖股票的最佳时机 II(不限制交易次数)

Leetcode 122

解法一:记忆化搜索

class Solution:
    def maxProfit(self, prices: List[int]) -> int:
        n = len(prices)

        @cache
        def dfs(i: int, hold: bool)->int:
            if i < 0:
                return -inf if hold else 0
            if hold:
                return max(dfs(i - 1, True), dfs(i - 1, False) - prices[i])
            return max(dfs(i - 1, False), dfs(i - 1, True) + prices[i])
        
        return dfs(n - 1, False)
class Solution {
public:
    int maxProfit(vector<int>& prices) {
        int n = prices.size(), cache[n][2];
        memset(cache, -1, sizeof cache);
        function<int(int, bool)> dfs = [&] (int i, bool hold)->int {
            if (i < 0) return hold ? INT_MIN : 0;
            int &res = cache[i][hold];
            if (res != -1) return res;
            if (hold) {
                res = max(dfs(i - 1, true), dfs(i - 1, false) - prices[i]);
                return res;
            }
            res = max(dfs(i - 1, false), dfs(i - 1, true) + prices[i]);
            return res;
        };
        return dfs(n - 1, false);
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( n ) O(n) O(n)

解法二:递推

class Solution:
    def maxProfit(self, prices: List[int]) -> int:
        n = len(prices)
        f = [[0] * 2 for _ in range(n + 1)]
        f[0][1] = -inf
        for i, p in enumerate(prices):
            f[i + 1][0] = max(f[i][0], f[i][1] + p)
            f[i + 1][1] = max(f[i][1], f[i][0] - p)
        return f[n][0]
class Solution {
public:
    int maxProfit(vector<int>& prices) {
        int n = prices.size(), f[n + 1][2];
        memset(f, 0, sizeof f);
        f[0][1] = INT_MIN;
        for (int i = 0; i < n; i ++ ) {
            f[i + 1][0] = max(f[i][0], f[i][1] + prices[i]);
            f[i + 1][1] = max(f[i][1], f[i][0] - prices[i]);
        }
        return f[n][0];
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( n ) O(n) O(n)

解法三:两个变量进行空间优化

class Solution:
    def maxProfit(self, prices: List[int]) -> int:
        n = len(prices)
        f0, f1 = 0, -inf
        for p in prices:
            new_f0 = max(f0, f1 + p)
            f1 = max(f1, f0 - p)
            f0 = new_f0
        return f0
class Solution {
public:
    int maxProfit(vector<int>& prices) {
        int n = prices.size();
        int f0 = 0, f1 = INT_MIN;
        for (int p: prices) {
            int new_f0 = max(f0, f1 + p);
            f1 = max(f1, f0 - p);
            f0 = new_f0;
        }
        return f0;
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( 1 ) O(1) O(1)

1.2 买卖股票的最佳时机含冷冻期

Leetcode 309

解法一:记忆化搜索

class Solution:
    def maxProfit(self, prices: List[int]) -> int:
        n = len(prices)

        @cache
        def dfs(i: int, hold: bool)->int:
            if i < 0:
                return -inf if hold else 0
            if hold:
                return max(dfs(i - 1, True), dfs(i - 2, False) - prices[i])
            return max(dfs(i - 1, False), dfs(i - 1, True) + prices[i])
        
        return dfs(n - 1, False)
class Solution {
public:
    int maxProfit(vector<int>& prices) {
        int n = prices.size(), cache[n][2];
        memset(cache, -1, sizeof cache);
        function<int(int, bool)> dfs = [&] (int i, bool hold)->int {
            if (i < 0) return hold ? INT_MIN : 0;
            int &res = cache[i][hold];
            if (res != -1) return res;
            if (hold) {
                res = max(dfs(i - 1, true), dfs(i - 2, false) - prices[i]);
                return res;
            }
            res = max(dfs(i - 1, false), dfs(i - 1, true) + prices[i]);
            return res;
        };
        return dfs(n - 1, false);
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( n ) O(n) O(n)

解法二:递推

class Solution:
    def maxProfit(self, prices: List[int]) -> int:
        n = len(prices)
        f = [[0] * 2 for _ in range(n + 2)]
        f[1][1] = -inf
        for i, p in enumerate(prices):
            f[i + 2][0] = max(f[i + 1][0], f[i + 1][1] + p)
            f[i + 2][1] = max(f[i + 1][1], f[i][0] - p)
        return f[-1][0]
class Solution {
public:
    int maxProfit(vector<int>& prices) {
        int n = prices.size(), f[n + 2][2];
        memset(f, 0, sizeof f);
        f[1][1] = INT_MIN;
        for (int i = 0; i < n; i ++ ) {
            f[i + 2][0] = max(f[i + 1][0], f[i + 1][1] + prices[i]);
            f[i + 2][1] = max(f[i + 1][1], f[i][0] - prices[i]);
        }
        return f[n + 1][0];
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( n ) O(n) O(n)

解法三:空间优化

class Solution:
    def maxProfit(self, prices: List[int]) -> int:
        pre0, f0, f1 = 0, 0, -inf
        for p in prices:
            pre0, f0, f1 = f0, max(f0, f1 + p), max(f1, pre0 - p)
        return f0
class Solution {
public:
    int maxProfit(vector<int>& prices) {
        int pre0 = 0, f0 = 0, f1 = INT_MIN;
        for (int p: prices) {
            int new_f0 = max(f0, f1 + p);
            f1 = max(f1, pre0 - p);
            pre0 = f0;
            f0 = new_f0;
        }
        return f0;
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( 1 ) O(1) O(1)

1.3 买卖股票的最佳时机 IV

Leetcode 188

解法一:记忆化搜索

class Solution:
    def maxProfit(self, k: int, prices: List[int]) -> int:
        n = len(prices)

        @cache
        def dfs(i: int, j: int, hold: bool)->int:
            if j < 0: return -inf
            if i < 0: return -inf if hold else 0
            if hold: return max(dfs(i - 1, j, True), dfs(i - 1, j, False) - prices[i])
            return max(dfs(i - 1, j, False), dfs(i - 1, j - 1, True) + prices[i])
        
        return dfs(n - 1, k, False)
class Solution {
public:
    int maxProfit(int k, vector<int>& prices) {
        int n = prices.size(), cache[n][k + 1][2];
        memset(cache, -1, sizeof cache);
        function<int(int, int, bool)> dfs = [&] (int i, int j, bool hold) -> int{
            if (j < 0) return INT_MIN;
            if (i < 0) return hold ? INT_MIN / 2: 0;
            int &res = cache[i][j][hold];
            if (res != -1) return res;
            if (hold) {
                res = max(dfs(i - 1, j, true), dfs(i - 1, j, false) - prices[i]);
                return res;
            }
            res = max(dfs(i - 1, j, false), dfs(i - 1, j - 1, true) + prices[i]);
            return res;
        };
        return dfs(n - 1, k, false);
    }
};
  • 时间复杂度: O ( n k ) O(nk) O(nk),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( n k ) O(nk) O(nk)

解法二:递推

class Solution:
    def maxProfit(self, k: int, prices: List[int]) -> int:
        n = len(prices)
        f = [[[-inf] * 2 for _ in range(k + 2)] for _ in range(n + 1)]
        for j in range(1, k + 2):
            f[0][j][0] = 0
        for i, p in enumerate(prices):
            for j in range(1, k + 2):
                f[i + 1][j][0] = max(f[i][j][0], f[i][j][1] + p)
                f[i + 1][j][1] = max(f[i][j][1], f[i][j - 1][0] - p)
        return f[-1][-1][0]
class Solution {
public:
    int maxProfit(int k, vector<int>& prices) {
        int n = prices.size(), f[n + 1][k + 2][2];
        memset(f, -0x3f, sizeof f);
        for (int j = 1; j <= k + 1; j ++ ) f[0][j][0] = 0;
        for (int i = 0; i < n; i ++ )
            for (int j = 1; j <= k + 1; j ++ ) {
                f[i + 1][j][0] = max(f[i][j][0], f[i][j][1] + prices[i]);
                f[i + 1][j][1] = max(f[i][j][1], f[i][j - 1][0] - prices[i]);
            }
        return f[n][k + 1][0];
    }
};
  • 时间复杂度: O ( n k ) O(nk) O(nk),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( n k ) O(nk) O(nk)

解法三:空间复杂度优化

class Solution:
    def maxProfit(self, k: int, prices: List[int]) -> int:
        n = len(prices)
        f = [[-inf] * 2 for _ in range(k + 2)]
        for j in range(1, k + 2):
            f[j][0] = 0
        for p in prices:
            for j in range(k + 1, 0, -1):
                f[j][1] = max(f[j][1], f[j - 1][0] - p)
                f[j][0] = max(f[j][0], f[j][1] + p)
        return f[-1][0]
class Solution {
public:
    int maxProfit(int k, vector<int>& prices) {
        int n = prices.size(), f[k + 2][2];
        memset(f, -0x3f, sizeof f);
        for (int j = 1; j <= k + 1; j ++ ) f[j][0] = 0;
        for (int p: prices)
            for (int j = 1; j <= k + 1; j ++ ) {
                f[j][0] = max(f[j][0], f[j][1] + p);
                f[j][1] = max(f[j][1], f[j - 1][0] - p);
            }
        return f[k + 1][0];
    }
};
  • 时间复杂度: O ( n k ) O(nk) O(nk),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( k ) O(k) O(k)

1.4 买卖股票的最佳时机含手续费

Leetcode 714

解法一:记忆化搜索

class Solution:
    def maxProfit(self, prices: List[int], fee: int) -> int:
        n = len(prices)

        @cache
        def dfs(i: int, hold: bool) -> int:
            if i < 0: return -inf if hold else 0
            if hold: 
                return max(dfs(i - 1, True), dfs(i - 1, False) - prices[i])
            return max(dfs(i - 1, False), dfs(i - 1, True) + prices[i] - fee)
        
        return dfs(n - 1, False)
class Solution {
public:
    int maxProfit(vector<int>& prices, int fee) {
        int n = prices.size(), cache[n][2];
        memset(cache, -1, sizeof cache);
        function<int(int, bool)> dfs = [&] (int i, bool hold) -> int {
            if (i < 0) return hold ? INT_MIN / 2 : 0;
            int &res = cache[i][hold];
            if (res != -1) return res;
            if (hold) {
                res = max(dfs(i - 1, true), dfs(i - 1, false) - prices[i]);
                return res;
            }
            res = max(dfs(i - 1, false), dfs(i - 1, true) + prices[i] - fee);
            return res;
        };
        return dfs(n - 1, false);
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( n ) O(n) O(n)

解法二:递推

class Solution:
    def maxProfit(self, prices: List[int], fee: int) -> int:
        n = len(prices)
        f = [[0] * 2 for _ in range(n + 1)]
        f[0][1] = -inf
        for i, p in enumerate(prices):
            f[i + 1][0] = max(f[i][0], f[i][1] + p - fee)
            f[i + 1][1] = max(f[i][1], f[i][0] - p)
        return f[n][0]
class Solution {
public:
    int maxProfit(vector<int>& prices, int fee) {
        int n = prices.size(), f[n + 1][2];
        memset(f, 0, sizeof f);
        f[0][1] = INT_MIN / 2;
        for (int i = 0; i < n; i ++ ) {
            f[i + 1][0] = max(f[i][0], f[i][1] + prices[i] - fee);
            f[i + 1][1] = max(f[i][1], f[i][0] - prices[i]);
        }
        return f[n][0];
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( n ) O(n) O(n)

解法三:空间优化

class Solution:
    def maxProfit(self, prices: List[int], fee: int) -> int:
        f0, f1 = 0, -inf
        for p in prices:
            f0, f1 = max(f0, f1 + p - fee), max(f1, f0 - p)
        return f0
class Solution {
public:
    int maxProfit(vector<int>& prices, int fee) {
        int f0 = 0, f1 = INT_MIN / 2;
        for (int p: prices) {
            int new_f0 = max(f0, f1 + p - fee);
            f1 = max(f1, f0 - p);
            f0 = new_f0;
        }
        return f0;
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n p r i c e s prices prices 的长度
  • 空间复杂度: O ( 1 ) O(1) O(1)

二、区间DP

  • 线性DP:在前面学习的DP中,都是在数组的前缀或者后缀上进行转移的,这类DP通常叫做线性DP。
  • 区间DP:将问题规模缩小到数组中间的区间上,而不仅仅是前缀或者后缀。

2.1 最长回文子序列

Leetcode 516

解法一:转换为LCS问题

因为回文子序列从左往右或者从右往左都是一样的,因此可以求解序列 s s s 与其反转后的序列 S ‘ S` S 的最长公共子序列(LCS),即为最长公共子序列。

解法二:直接计算,选或者不选

记忆化搜索

class Solution:
    def longestPalindromeSubseq(self, s: str) -> int:
        @cache
        def dfs(i: int, j: int)->int:
            if i > j: return 0
            if i == j: return 1
            if s[i] == s[j]:
                return dfs(i + 1, j - 1) + 2
            return max(dfs(i + 1, j), dfs(i, j - 1))
        return dfs(0, len(s) - 1)
class Solution {
public:
    int longestPalindromeSubseq(string s) {
        int n = s.length(), cahce[n][n];
        memset(cahce, -1, sizeof cahce);
        function<int(int, int)> dfs = [&](int i, int j)->int {
            if (i > j) return 0;
            if (i == j) return 1;
            int &res = cahce[i][j];
            if (res != -1) return res;
            if (s[i] == s[j]) {
                res = dfs(i + 1, j - 1) + 2;
                return res;
            } 
            res = max(dfs(i + 1, j), dfs(i, j - 1));
            return res;
        };
        return dfs(0, n - 1);
    }
};
  • 时间复杂度: O ( n 2 ) O(n^2) O(n2),其中 n n n s s s 的长度。动态规划的时间复杂度 = 状态个数 × \times × 单个状态的转移个数。本题中状态个数等于 O ( n 2 ) O(n^2) O(n2),而单个状态的转移个数为 O ( 1 ) O(1) O(1),因此时间复杂度为 O ( n 2 ) O(n^2) O(n2)
  • 空间复杂度: O ( n 2 ) O(n^2) O(n2)

递推写法

class Solution:
    def longestPalindromeSubseq(self, s: str) -> int:
        n = len(s)
        f = [[0] * n for _ in range(n)]
        for i in range(n - 1, -1, -1):
            f[i][i] = 1
            for j in range(i + 1, n):
                if s[i] == s[j]:
                    f[i][j] = f[i + 1][j - 1] + 2
                else:
                    f[i][j] = max(f[i + 1][j], f[i][j - 1])
        return f[0][-1]
class Solution {
public:
    int longestPalindromeSubseq(string s) {
        int n = s.length(), f[n][n];
        memset(f, 0, sizeof f);
        for (int i = n - 1; i >= 0; i -- ) {
            f[i][i] = 1;
            for (int j = i + 1; j < n; j ++ )
                f[i][j] = s[i] == s[j] ? f[i + 1][j - 1] + 2 : max(f[i + 1][j], f[i][j - 1]);
        }
        return f[0][n - 1];
    }
};
  • 时间复杂度: O ( n 2 ) O(n^2) O(n2),其中 n n n s s s 的长度。动态规划的时间复杂度 = 状态个数 × \times × 单个状态的转移个数。本题中状态个数等于 O ( n 2 ) O(n^2) O(n2),而单个状态的转移个数为 O ( 1 ) O(1) O(1),因此时间复杂度为 O ( n 2 ) O(n^2) O(n2)
  • 空间复杂度: O ( n 2 ) O(n^2) O(n2)

空间优化

class Solution:
    def longestPalindromeSubseq(self, s: str) -> int:
        n = len(s)
        f = [0] * n
        for i in range(n - 1, -1, -1):
            f[i] = 1
            pre = 0
            for j in range(i + 1, n):
                tmp = f[j]
                f[j] = pre + 2 if s[i] == s[j] else max(f[j], f[j - 1])
                pre = tmp
        return f[-1]
class Solution {
public:
    int longestPalindromeSubseq(string s) {
        int n = s.length(), f[n];
        memset(f, 0, sizeof f);
        for (int i = n - 1; i >= 0; i -- ) {
            f[i] = 1;
            int pre = 0;
            for (int j = i + 1; j < n; j ++ ) {
                int tmp = f[j];
                f[j] = s[i] == s[j] ? pre + 2 : max(f[j], f[j - 1]);
                pre = tmp;
            }
        }
        return f[n - 1];
    }
};
  • 时间复杂度: O ( n 2 ) O(n^2) O(n2),其中 n n n s s s 的长度
  • 空间复杂度: O ( n ) O(n) O(n)

2.2 多边形三角剖分的最低得分

Leetcode 1039

解法一:记忆化搜索

class Solution:
    def minScoreTriangulation(self, values: List[int]) -> int:
        @cache
        def dfs(i: int, j: int) -> int:
            if i + 1 == j: return 0
            return min(dfs(i, k) + dfs(k, j) + values[i] * values[j] * values[k]
                        for k in range(i + 1, j))
        return dfs(0, len(values) - 1)
class Solution {
public:
    int minScoreTriangulation(vector<int>& values) {
        int n = values.size(), cache[n][n];
        memset(cache, -1, sizeof cache);
        function<int(int, int)> dfs = [&](int i, int j) -> int {
            if (i + 1 == j) return 0;
            int &res = cache[i][j];
            if (res != -1) return res;
            res = INT_MAX;
            for (int k = i + 1; k < j; k ++ )
                res = min(res, dfs(i, k) + dfs(k, j) + values[i] * values[j] * values[k]);
            return res;
        };
        return dfs(0, n - 1);
    }
};

  • 时间复杂度: O ( n 3 ) O(n^3) O(n3),其中 n n n v a l u e s values values 的长度。动态规划的时间复杂度 = 状态个数 × \times × 单个状态的计算时间。本题中状态个数等于 O ( n 2 ) O(n^2) O(n2),单个状态的计算时间为 O ( n ) O(n) O(n),因此时间复杂度为 O ( n 3 ) O(n^3) O(n3)
  • 空间复杂度: O ( n 2 ) O(n^2) O(n2)。有 O ( n 2 ) O(n^2) O(n2) 个状态。

解法二:递推

class Solution:
    def minScoreTriangulation(self, values: List[int]) -> int:
        n = len(values)
        f = [[0] * n for _ in range(n)]
        for i in range(n - 3, -1, -1):
            for j in range(i + 2, n):
                f[i][j] = min(f[i][k] + f[k][j] + values[i] * values[j] * values[k] 
                                for k in range(i + 1, j))
        return f[0][-1]
class Solution {
public:
    int minScoreTriangulation(vector<int>& values) {
        int n = values.size(), f[n][n];
        memset(f, 0, sizeof f);
        for (int i = n - 3; i >= 0; i -- )
            for (int j = i + 2; j < n; j ++ ) {
                f[i][j] = INT_MAX;
                for (int k = i + 1; k < j; k ++ ) 
                    f[i][j] = min(f[i][j], f[i][k] + f[k][j] + values[i] * values[j] * values[k]);
            }
        return f[0][n - 1];
    }
};
  • 时间复杂度: O ( n 3 ) O(n^3) O(n3),其中 n n n v a l u e s values values 的长度。动态规划的时间复杂度 = 状态个数 × \times × 单个状态的计算时间。本题中状态个数等于 O ( n 2 ) O(n^2) O(n2),单个状态的计算时间为 O ( n ) O(n) O(n),因此时间复杂度为 O ( n 3 ) O(n^3) O(n3)
  • 空间复杂度: O ( n 2 ) O(n^2) O(n2)。有 O ( n 2 ) O(n^2) O(n2) 个状态。

2.3 由子序列构造的最长回文串的长度

Leetcode 1771

class Solution:
    def longestPalindrome(self, word1: str, word2: str) -> int:
        s = word1 + word2
        ans, n = 0, len(s)
        f = [[0] * n for _ in range(n)]
        for i in range(n - 1, -1, -1):
            f[i][i] = 1
            for j in range(i + 1, n):
                if s[i] == s[j]:
                    f[i][j] = f[i + 1][j - 1] + 2
                    if i < len(word1) <= j:
                        ans = max(ans, f[i][j])
                else:
                    f[i][j] = max(f[i + 1][j], f[i][j - 1])
        return ans
class Solution {
public:
    int longestPalindrome(string word1, string word2) {
        string s = word1 + word2;
        int ans = 0, n = s.length(), f[n][n];
        memset(f, 0, sizeof f);
        for (int i = n - 1; i >= 0; i -- ) {
            f[i][i] = 1;
            for (int j = i + 1; j < n; j ++ ) 
                if (s[i] == s[j]) {
                    f[i][j] = f[i + 1][j - 1] + 2;
                    if (i < word1.length() && j >= word1.length())
                        ans = max(ans, f[i][j]);
                } 
                else 
                    f[i][j] = max(f[i + 1][j], f[i][j - 1]); 
        }
        return ans;
    }
};

2.4 合并石头的最低成本

Leetcode 1000

class Solution:
    def mergeStones(self, stones: List[int], k: int) -> int:
        n = len(stones)
        if (n - 1) % (k - 1): return -1
        s = list(accumulate(stones, initial=0))

        @cache
        def dfs(i: int, j: int, p: int)->int:
            if p == 1:
                return 0 if i == j else dfs(i, j, k) + s[j + 1] - s[i]
            return min(dfs(i, m, 1) + dfs(m + 1, j, p - 1) for m in range(i, j, k - 1))
        
        return dfs(0, n - 1, 1)
class Solution {
public:
    int mergeStones(vector<int>& stones, int k) {
        int n = stones.size();
        if ((n - 1) % (k - 1)) return -1;
        int s[n + 1];
        s[0] = 0;
        for (int i = 0; i < n; i ++ )
            s[i + 1] = s[i] + stones[i];

        int cache[n][n][k + 1];
        memset(cache, -1, sizeof cache);
        function<int(int, int, int)> dfs = [&](int i, int j, int p) -> int {
            int &res = cache[i][j][p];
            if (res != -1) return res;
            if (p == 1)
                return res = i == j ? 0 : dfs(i, j, k) + s[j + 1] - s[i];
            res = INT_MAX;
            for (int m = i; m < j; m += k - 1)
                res = min(res, dfs(i, m, 1) + dfs(m + 1, j, p - 1));
            return res;
        };
        return dfs(0, n - 1, 1);
    }
};
  • 时间复杂度: O ( n 3 ) O(n^3) O(n3),其中 n n n s t o n e s stones stones 的长度。动态规划的时间复杂度 = 状态个数 × \times × 单个状态的计算时间。这里状态个数为 O ( n 2 k ) O(n^2k) O(n2k),单个状态的计算时间为 O ( n k ) O(\frac{n}{k}) O(kn),因此时间复杂度为 O ( n 3 ) O(n^3) O(n3)
  • 空间复杂度: O ( n 2 k ) O(n^2k) O(n2k)

优化

待写

递推

待写

三、树形DP——直径系列

3.1 二叉树的直径

Leetcode 543

class Solution:
    def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
        ans = 0
        
        def dfs(node: Optional[TreeNode])->int:
            if node is None: return -1
            l_len = dfs(node.left)
            r_len = dfs(node.right)
            nonlocal ans
            ans = max(ans, l_len + r_len + 2)
            return max(l_len, r_len) + 1

        dfs(root)
        return ans
class Solution {
public:
    int diameterOfBinaryTree(TreeNode *root) {
        int ans = 0;
        function<int(TreeNode*)> dfs = [&](TreeNode *node) -> int {
            if (node == nullptr)
                return -1; // 下面 +1 后,对于叶子节点就刚好是 0
            int l_len = dfs(node->left); // 左子树最大链长+1
            int r_len = dfs(node->right); // 右子树最大链长+1
            ans = max(ans, l_len + r_len + 2); // 两条链拼成路径
            return max(l_len, r_len) + 1; // 当前子树最大链长
        };
        dfs(root);
        return ans;
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n 为二叉树的节点个数。
  • 空间复杂度: O ( n ) O(n) O(n)。最坏情况下,二叉树退化成一条链,递归需要 O ( n ) O(n) O(n) 的栈空间。

3.2 二叉树中的最大路径和

Leetcode 124

class Solution:
    def maxPathSum(self, root: Optional[TreeNode]) -> int:
        ans = -inf

        def dfs(node: Optional[TreeNode])->int:
            if node is None: return 0
            l_val = dfs(node.left)
            r_val = dfs(node.right)
            nonlocal ans
            ans = max(ans, l_val + r_val + node.val)
            return max(max(l_val, r_val) + node.val, 0)

        dfs(root)
        return ans
class Solution {
public:
    int maxPathSum(TreeNode* root) {
        int ans = INT_MIN;
        function<int(TreeNode*)> dfs = [&](TreeNode* node)->int {
            if (node == nullptr) return 0;
            int l_val = dfs(node->left);
            int r_val = dfs(node->right);
            ans = max(ans, l_val + r_val + node->val);
            return max(max(l_val, r_val) + node->val, 0);
        };
        dfs(root);
        return ans;
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n 为二叉树的节点个数。
  • 空间复杂度: O ( n ) O(n) O(n)。最坏情况下,二叉树退化成一条链,递归需要 O ( n ) O(n) O(n) 的栈空间。

3.3 相邻字符不同的最长路径(普通树的直径)

Leetcode 2246

class Solution:
    def longestPath(self, parent: List[int], s: str) -> int:
        n = len(parent)
        g = [[] for _ in range(n)]
        for i in range(1, n):
            g[parent[i]].append(i)

        ans = 0
        def dfs(x: int) -> int:
            nonlocal ans
            max_len = 0
            for y in g[x]:
                son_len = dfs(y) + 1
                if s[y] != s[x]:
                    ans = max(ans, max_len + son_len)
                    max_len = max(max_len, son_len)
            return max_len
        
        dfs(0)
        return ans + 1
class Solution {
public:
    int longestPath(vector<int>& parent, string s) {
        int n = parent.size();
        vector<vector<int>> g(n);
        for (int i = 1; i < n; i ++ )
            g[parent[i]].push_back(i);
        int ans = 0;
        function<int(int)> dfs = [&] (int x) -> int {
            int maxLen = 0;
            for (int y: g[x]) {
                int son_len = dfs(y) + 1;
                if (s[y] != s[x]) {
                    ans = max(ans, maxLen + son_len);
                    maxLen = max(maxLen, son_len);
                }
            }
            return maxLen;
        };
        dfs(0);
        return ans + 1;
    }
};
  • 时间复杂度: O ( n ) O(n) O(n)
  • 空间复杂度: O ( n ) O(n) O(n)

3.4 最长同值路径

Leetcode 687

class Solution:
    def longestUnivaluePath(self, root: Optional[TreeNode]) -> int:
        ans = 0
        
        def dfs(node: Optional[TreeNode])->int:
            if node is None: return -1
            l_len = dfs(node.left) + 1
            r_len = dfs(node.right) + 1
            if node.left and node.left.val != node.val: l_len = 0
            if node.right and node.right.val != node.val: r_len = 0
            nonlocal ans
            ans = max(ans, l_len + r_len)
            return max(l_len, r_len)
        
        dfs(root)
        return ans
class Solution {
public:
    int longestUnivaluePath(TreeNode* root) {
        int ans = 0;
        function<int(TreeNode*)> dfs = [&] (TreeNode *node) -> int {
            if (node == nullptr) return -1;
            int l_len = dfs(node->left) + 1;
            int r_len = dfs(node->right) + 1;
            if (node->left && node->left->val != node->val) l_len = 0;
            if (node->right && node->right->val != node->val) r_len = 0;
            ans = max(ans, l_len + r_len);
            return max(l_len, r_len);
        };
        dfs(root);
        return ans;
    }
};
  • 时间复杂度: O ( n ) O(n) O(n)
  • 空间复杂度: O ( n ) O(n) O(n)

3.5 统计子树中城市之间最大距离

Leetcode 1617

class Solution:
    def countSubgraphsForEachDiameter(self, n: int, edges: List[List[int]]) -> List[int]:
        # 建树
        g = [[] for _ in range(n)]
        for x, y in edges:
            g[x - 1].append(y - 1)
            g[y - 1].append(x - 1)

        ans = [0] * (n - 1)

        # 标记是否选择该节点
        in_set = [False] * n  
        
        def f(i: int) -> None:
            if i == n:
                vis = [False] * n
                diameter = 0  # 直径
                for v, b in enumerate(in_set):
                    if not b: continue

                    # 求树的直径
                    def dfs(x: int) -> int:
                        nonlocal diameter
                        vis[x] = True
                        max_len = 0
                        for y in g[x]:
                            if not vis[y] and in_set[y]:
                                ml = dfs(y) + 1
                                diameter = max(diameter, max_len + ml)
                                max_len = max(max_len, ml)
                        return max_len
                    
                    dfs(v)
                    break
                if diameter and vis == in_set:
                    ans[diameter - 1] += 1
                return
            
            # 不选择节点i
            f(i + 1)

            # 选择节点i
            in_set[i] = True
            f(i + 1)
            in_set[i] = False
    
        f(0)

        return ans
class Solution {
public:
    vector<int> countSubgraphsForEachDiameter(int n, vector<vector<int>>& edges) {
        vector<vector<int>> g(n);
        for (auto &e: edges) {
            int x=  e[0] - 1, y = e[1] - 1;
            g[x].push_back(y);
            g[y].push_back(x);
        }

        vector<int> ans(n - 1), in_set(n), vis(n);
        int diameter = 0;

        function<int(int)> dfs = [&](int x) -> int {
            vis[x] = true;
            int max_len = 0;
            for (int y: g[x]) 
                if (!vis[y] && in_set[y]) {
                    int ml = dfs(y) + 1;
                    diameter = max(diameter, max_len + ml);
                    max_len = max(max_len, ml);
                }
            return max_len;
        };

        function<void(int)> f = [&](int i) {
            if (i == n) {
                for (int v = 0; v < n; v ++ )
                    if (in_set[v]) {
                        fill(vis.begin(), vis.end(), 0);
                        diameter = 0;
                        dfs(v);
                        break;
                    }
                if (diameter && vis == in_set) ++ ans[diameter - 1];
                return;
            }

            f(i + 1);

            in_set[i] = true;
            f(i + 1);
            in_set[i] = false;
        };

        f(0);
        return ans;
    }
};
  • 时间复杂度: O ( n 2 n ) O(n2^n) O(n2n) O ( 2 n ) O(2^n) O(2n) 枚举子集, O ( n ) O(n) O(n) 求直径,所以时间复杂度为 O ( n 2 n ) O(n2^n) O(n2n)
  • 空间复杂度: O ( n ) O(n) O(n)

二进制优化

class Solution:
    def countSubgraphsForEachDiameter(self, n: int, edges: List[List[int]]) -> List[int]:
        g = [[] for _ in range(n)]
        for x, y in edges:
            g[x - 1].append(y - 1)
            g[y - 1].append(x - 1)

        ans = [0] * (n - 1)
        for mask in range(3, 1 << n):
            if (mask & (mask - 1)) == 0:
                continue;
            vis = diameter = 0
            def dfs(x: int) -> int:
                nonlocal vis, diameter
                vis |= 1 << x
                max_len = 0
                for y in g[x]:
                    if (vis >> y & 1) == 0 and mask >> y & 1:
                        ml = dfs(y) + 1
                        diameter = max(diameter, max_len + ml)
                        max_len = max(max_len, ml)
                return max_len
            dfs(mask.bit_length() - 1)
            if vis == mask:
                ans[diameter - 1] += 1
        return ans
class Solution {
public:
    vector<int> countSubgraphsForEachDiameter(int n, vector<vector<int>> &edges) {
        vector<vector<int>> g(n);
        for (auto &e : edges) {
            int x = e[0] - 1, y = e[1] - 1; // 编号改为从 0 开始
            g[x].push_back(y);
            g[y].push_back(x); // 建树
        }

        int dis[n][n]; memset(dis, 0, sizeof(dis));
        function<void(int, int, int)> dfs = [&](int i, int x, int fa) {
            for (int y : g[x])
                if (y != fa) {
                    dis[i][y] = dis[i][x] + 1; // 自顶向下
                    dfs(i, y, x);
                }
        };
        for (int i = 0; i < n; ++i)
            dfs(i, i, -1); // 计算 i 到其余点的距离

        function<int(int, int, int, int, int)> dfs2 = [&](int i, int j, int d, int x, int fa) {
            // 能递归到这,说明 x 可以选
            int cnt = 1; // 选 x
            for (int y : g[x])
                if (y != fa &&
                   (dis[i][y] < d || dis[i][y] == d && y > j) &&
                   (dis[j][y] < d || dis[j][y] == d && y > i)) // 满足这些条件就可以选
                    cnt *= dfs2(i, j, d, y, x); // 每棵子树互相独立,采用乘法原理
            if (dis[i][x] + dis[j][x] > d)  // x 是可选点
                ++cnt; // 不选 x
            return cnt;
        };
        vector<int> ans(n - 1);
        for (int i = 0; i < n; ++i)
            for (int j = i + 1; j < n; ++j)
                ans[dis[i][j] - 1] += dfs2(i, j, dis[i][j], i, -1);
        return ans;
    }
};
  • 时间复杂度: O ( n 2 n ) O(n2^n) O(n2n) O ( 2 n ) O(2^n) O(2n) 枚举子集, O ( n ) O(n) O(n) 求直径,所以时间复杂度为 O ( n 2 n ) O(n2^n) O(n2n)
  • 空间复杂度: O ( n ) O(n) O(n)

3.6 最大价值和与最小价值和的差值

Leetcode 2538

class Solution:
    def maxOutput(self, n: int, edges: List[List[int]], price: List[int]) -> int:
        g = [[] for _ in range(n)]
        for x, y in edges:
            g[x].append(y)
            g[y].append(x)
        
        ans = 0
        def dfs(x: int, fa: int) -> (int, int):
            nonlocal ans
            max_s1 = p = price[x]
            max_s2 = 0
            for y in g[x]:
                if y == fa: continue
                s1, s2 = dfs(y, x)

                ans = max(ans, max_s1 + s2, max_s2 + s1)
                max_s1 = max(max_s1, s1 + p)
                max_s2 = max(max_s2, s2 + p)

            return max_s1, max_s2
        
        dfs(0, -1)

        return ans
class Solution {
public:
    long long maxOutput(int n, vector<vector<int>>& edges, vector<int>& price) {
        vector<vector<int>> g(n);
        for (auto &e: edges) {
            int x = e[0], y = e[1];
            g[x].push_back(y);
            g[y].push_back(x);
        }

        long ans = 0;
        function<pair<long, long>(int, int)> dfs = [&] (int x, int fa) -> pair<long, long> {
            long p = price[x], max_s1 = p, max_s2 = 0;
            for (int y: g[x]) 
                if (y != fa) {
                    auto [s1, s2] = dfs(y, x);
                    ans = max(ans, max(max_s1 + s2, max_s2 + s1));
                    max_s1 = max(max_s1, s1 + p);
                    max_s2 = max(max_s2, s2 + p);
                }
            return {max_s1, max_s2};
        };
        
        dfs(0, -1);
        return ans;
    }
};
  • 时间复杂度: O ( n ) O(n) O(n)
  • 空间复杂度: O ( n ) O(n) O(n)

四、树形DP——最大独立集

在图论中,独立集(Independent Set)是一种顶点的集合,其中任意两个顶点都不相邻,也就是说,集合中的顶点之间没有边相连。换句话说,独立集是一组顶点,其中没有两个顶点通过一条边相连。

4.1 打家劫舍 III(二叉树上)

Leetcode 337

class Solution:
    def rob(self, root: Optional[TreeNode]) -> int:
        def dfs(node: Optional[TreeNode])->(int, int):
            if node is None: return 0, 0
            l_rob, l_not_rob = dfs(node.left)
            r_rob, r_not_rob = dfs(node.right)
            rob = l_not_rob + r_not_rob + node.val  # 选
            not_rob = max(l_rob, l_not_rob) + max(r_rob, r_not_rob)  # 不选
            return rob, not_rob
        return max(dfs(root))
class Solution {
    pair<int, int> dfs(TreeNode* node) {
        if (node == nullptr) return {0, 0};
        auto [l_rob, l_not_rob] = dfs(node->left);
        auto [r_rob, r_not_rob] = dfs(node->right);
        int rob = l_not_rob + r_not_rob + node->val;
        int not_rob = max(l_rob, l_not_rob) + max(r_rob, r_not_rob);
        return {rob, not_rob};
    }

public:
    int rob(TreeNode* root) {
        auto [root_rob, root_not_rob] = dfs(root);
        return max(root_rob, root_not_rob);
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n 为二叉树的节点个数。每个节点都会递归恰好一次。
  • 空间复杂度: O ( n ) O(n) O(n)。最坏情况下,二叉树是一条链,递归需要 O ( n ) O(n) O(n) 的栈空间。

4.2 没有上司的舞会(普通树上)

AcWing 285

#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;

const int N = 6010;

int n;
int happy[N];
int h[N], e[N], ne[N], idx;
int f[N][2];
bool has_father[N];

void add(int a, int b) {
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}

void dfs(int u) {
    f[u][1] = happy[u];
    for (int i = h[u]; i != -1; i = ne[i]) {
        int j = e[i];
        dfs(j);
        
        f[u][0] += max(f[j][0], f[j][1]);
        f[u][1] += f[j][0];
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i ++ ) scanf("%d", happy + i);
    memset(h, -1, sizeof h);
    for (int i = 0; i < n - 1; i ++ ) {
        int a, b;
        scanf("%d%d", &a, &b);
        has_father[a] = true;
        add(b, a);
    }
    
    int root = 1;
    while (has_father[root]) root ++ ;
    
    dfs(root);
    
    printf("%d\n", max(f[root][0], f[root][1]));
    
    return 0;
}

4.3 T秒之后青蛙的位置

Leetcode 1377

要注意几个问题:

  • 如何处理浮点数计算精度问题?先计算分母乘机,最后将结果取倒即可。
  • 为什么写【自底向上】而不是【自顶向下】?自底向上是先找到问题后在返回答案,自顶向下是不断往下找到答案后返回结果,因此自顶向下会造成一些不必要的计算。
  • 如何解决特判 n=1 的情况?在节点 n=1 中添加一个父节点 n=0
  • 对于时间 T,如何在 DFS 中减少一个变量的引入?让 leftTt 减少到 0 而不是从 0 增加到 t,这样的代码只需要和 0 比较,而不需要和 t 比较。
class Solution:
    def frogPosition(self, n: int, edges: List[List[int]], t: int, target: int) -> float:
        g = [[] for _ in range(n + 1)]
        g[1] = [0]
        for x, y in edges:
            g[x].append(y)
            g[y].append(x)
        
        ans = 0
        def dfs(x: int, fa: int, left_t: int, prod: int) -> True:
            """
            x: 当前遍历的节点
            fa: 当前节点的父节点
            left_t: 剩余时间
            prod: 概率
            return: 用于判断dfs是否结束
            """

            # t 秒后必须在 target(恰好到达,或者 target 是叶子停在原地)
            if x == target and (left_t == 0 or len(g[x]) == 1):
                nonlocal ans
                ans = 1 / prod
                return True

            if x == target or left_t == 0: return False

            for y in g[x]:
                if y != fa and dfs(y, x, left_t - 1, prod * (len(g[x]) - 1)):
                    return True

            return False

        dfs(1, 0, t, 1)

        return ans
class Solution {
public:
    double frogPosition(int n, vector<vector<int>>& edges, int t, int target) {
        vector<vector<int>> g(n + 1);
        g[1] = {0};
        for (auto &e: edges) {
            int x = e[0], y = e[1];
            g[x].push_back(y);
            g[y].push_back(x);
        }

        double ans = 0;
        function<bool(int, int, int, long long)> dfs = [&] (int x, int fa, int left_t, long long prob) -> bool {
            if (x == target && (left_t == 0 || g[x].size() == 1)) {
                ans = 1.0 / prob;
                return true;
            }
            if (x == target || left_t == 0) return false;
            for (int y: g[x])
                if (y != fa && dfs(y, x, left_t - 1, prob * (g[x].size() - 1)))
                    return true;
            return false;
        };

        dfs(1, 0, t, 1);
        return ans;
    }
};

4.4 最小化旅行的价格总和

Leetcode 2646

class Solution:
    def minimumTotalPrice(self, n: int, edges: List[List[int]], price: List[int], trips: List[List[int]]) -> int:
        g = [[] for _ in range(n)]
        for x, y in edges:
            g[x].append(y)
            g[y].append(x)
        
        cnt = [0] * n
        for start, end in trips:
            def dfs(x: int, fa: int) -> bool:
                if x == end:
                    cnt[x] += 1
                    return True
                for y in g[x]:
                    if y != fa and dfs(y, x):
                        cnt[x] += 1
                        return True
                return False
            dfs(start, -1)
        
        def dfs(x: int, fa: int) -> (int, int):
            not_have = price[x] * cnt[x]
            halve = not_have // 2
            for y in g[x]:
                if y != fa:
                    nh, h = dfs(y ,x)
                    not_have += min(nh, h)
                    halve += nh
            return not_have, halve
        
        return min(dfs(0, -1))
class Solution {
public:
    int minimumTotalPrice(int n, vector<vector<int>>& edges, vector<int>& price, vector<vector<int>>& trips) {
        vector<vector<int>> g(n);
        for (auto &e: edges) {
            int x = e[0], y = e[1];
            g[x].push_back(y);
            g[y].push_back(x);
        }

        int cnt[n]; memset(cnt, 0, sizeof cnt);
        for (auto &t: trips) {
            int end = t[1];
            function<bool(int, int)> dfs = [&](int x, int fa) -> bool {
                if (x == end) {
                    ++ cnt[x];
                    return true;
                }
                for (int y: g[x]) 
                    if (y != fa && dfs(y, x)) {
                        ++ cnt[x];
                        return true;
                    }
                return false;
            };

            dfs(t[0], -1);
        }

        function<pair<int, int>(int, int)> dfs = [&](int x, int fa) -> pair<int ,int> {
            int not_halve = price[x] * cnt[x];
            int halve = not_halve / 2;
            for (int y: g[x])
                if (y != fa) {
                    auto [nh, h] = dfs(y, x);
                    not_halve += min(nh, h);
                    halve += nh;
                }
            return {not_halve, halve};
        };

        auto [nh, h] = dfs(0, -1);

        return min(nh, h);
    }
};
  • 时间复杂度: O ( n m ) O(nm) O(nm),其中 m m m t r i p s trips trips 的长度
  • 空间复杂度: O ( n ) O(n) O(n)

五、树形DP——最小支配集

5.1 监控二叉树

Leetcode 968

class Solution:
    def minCameraCover(self, root: Optional[TreeNode]) -> int:
        def dfs(node):
            if node is None:
                return inf, 0, 0
            l_choose, l_by_fa, l_by_children = dfs(node.left)
            r_choose, r_by_fa, r_by_children = dfs(node.right)
            choose = min(l_choose, l_by_fa) + min(r_choose, r_by_fa) + 1
            by_fa = min(l_choose, l_by_children) + min(r_choose, r_by_children)
            by_children = min(l_choose + r_by_children, l_by_children + r_choose, l_choose + r_choose)
            return choose, by_fa, by_children
        
        choose, _, by_children = dfs(root)

        return min(choose, by_children)
class Solution {
    tuple<int, int, int> dfs(TreeNode *node) {
        if (node == nullptr) return {INT_MAX / 2, 0, 0};
        auto [l_choose, l_by_fa, l_by_children] = dfs(node->left);
        auto [r_choose, r_by_fa, r_by_children] = dfs(node->right);
        int choose = min(l_choose, l_by_fa) + min(r_choose, r_by_fa) + 1;
        int by_fa = min(l_choose, l_by_children) + min(r_choose, r_by_children);
        int by_children = min({l_choose + r_by_children, l_by_children + r_choose, l_choose + r_choose});
        return {choose, by_fa, by_children};
    }

public:
    int minCameraCover(TreeNode* root) {
        auto [choose, _, by_children] = dfs(root);
        return min(choose, by_children);
    }
};
  • 时间复杂度: O ( n ) O(n) O(n),其中 n n n 为二叉树的节点个数。每个节点都会递归恰好一次。
  • 空间复杂度: O ( n ) O(n) O(n),最坏情况下,二叉树是一条链,递归需要 O ( n ) O(n) O(n) 的栈空间。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值