一、状态机DP
1.1 买卖股票的最佳时机 II(不限制交易次数)
解法一:记忆化搜索
class Solution:
def maxProfit(self, prices: List[int]) -> int:
n = len(prices)
@cache
def dfs(i: int, hold: bool)->int:
if i < 0:
return -inf if hold else 0
if hold:
return max(dfs(i - 1, True), dfs(i - 1, False) - prices[i])
return max(dfs(i - 1, False), dfs(i - 1, True) + prices[i])
return dfs(n - 1, False)
class Solution {
public:
int maxProfit(vector<int>& prices) {
int n = prices.size(), cache[n][2];
memset(cache, -1, sizeof cache);
function<int(int, bool)> dfs = [&] (int i, bool hold)->int {
if (i < 0) return hold ? INT_MIN : 0;
int &res = cache[i][hold];
if (res != -1) return res;
if (hold) {
res = max(dfs(i - 1, true), dfs(i - 1, false) - prices[i]);
return res;
}
res = max(dfs(i - 1, false), dfs(i - 1, true) + prices[i]);
return res;
};
return dfs(n - 1, false);
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( n ) O(n) O(n)
解法二:递推
class Solution:
def maxProfit(self, prices: List[int]) -> int:
n = len(prices)
f = [[0] * 2 for _ in range(n + 1)]
f[0][1] = -inf
for i, p in enumerate(prices):
f[i + 1][0] = max(f[i][0], f[i][1] + p)
f[i + 1][1] = max(f[i][1], f[i][0] - p)
return f[n][0]
class Solution {
public:
int maxProfit(vector<int>& prices) {
int n = prices.size(), f[n + 1][2];
memset(f, 0, sizeof f);
f[0][1] = INT_MIN;
for (int i = 0; i < n; i ++ ) {
f[i + 1][0] = max(f[i][0], f[i][1] + prices[i]);
f[i + 1][1] = max(f[i][1], f[i][0] - prices[i]);
}
return f[n][0];
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( n ) O(n) O(n)
解法三:两个变量进行空间优化
class Solution:
def maxProfit(self, prices: List[int]) -> int:
n = len(prices)
f0, f1 = 0, -inf
for p in prices:
new_f0 = max(f0, f1 + p)
f1 = max(f1, f0 - p)
f0 = new_f0
return f0
class Solution {
public:
int maxProfit(vector<int>& prices) {
int n = prices.size();
int f0 = 0, f1 = INT_MIN;
for (int p: prices) {
int new_f0 = max(f0, f1 + p);
f1 = max(f1, f0 - p);
f0 = new_f0;
}
return f0;
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( 1 ) O(1) O(1)
1.2 买卖股票的最佳时机含冷冻期
解法一:记忆化搜索
class Solution:
def maxProfit(self, prices: List[int]) -> int:
n = len(prices)
@cache
def dfs(i: int, hold: bool)->int:
if i < 0:
return -inf if hold else 0
if hold:
return max(dfs(i - 1, True), dfs(i - 2, False) - prices[i])
return max(dfs(i - 1, False), dfs(i - 1, True) + prices[i])
return dfs(n - 1, False)
class Solution {
public:
int maxProfit(vector<int>& prices) {
int n = prices.size(), cache[n][2];
memset(cache, -1, sizeof cache);
function<int(int, bool)> dfs = [&] (int i, bool hold)->int {
if (i < 0) return hold ? INT_MIN : 0;
int &res = cache[i][hold];
if (res != -1) return res;
if (hold) {
res = max(dfs(i - 1, true), dfs(i - 2, false) - prices[i]);
return res;
}
res = max(dfs(i - 1, false), dfs(i - 1, true) + prices[i]);
return res;
};
return dfs(n - 1, false);
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( n ) O(n) O(n)
解法二:递推
class Solution:
def maxProfit(self, prices: List[int]) -> int:
n = len(prices)
f = [[0] * 2 for _ in range(n + 2)]
f[1][1] = -inf
for i, p in enumerate(prices):
f[i + 2][0] = max(f[i + 1][0], f[i + 1][1] + p)
f[i + 2][1] = max(f[i + 1][1], f[i][0] - p)
return f[-1][0]
class Solution {
public:
int maxProfit(vector<int>& prices) {
int n = prices.size(), f[n + 2][2];
memset(f, 0, sizeof f);
f[1][1] = INT_MIN;
for (int i = 0; i < n; i ++ ) {
f[i + 2][0] = max(f[i + 1][0], f[i + 1][1] + prices[i]);
f[i + 2][1] = max(f[i + 1][1], f[i][0] - prices[i]);
}
return f[n + 1][0];
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( n ) O(n) O(n)
解法三:空间优化
class Solution:
def maxProfit(self, prices: List[int]) -> int:
pre0, f0, f1 = 0, 0, -inf
for p in prices:
pre0, f0, f1 = f0, max(f0, f1 + p), max(f1, pre0 - p)
return f0
class Solution {
public:
int maxProfit(vector<int>& prices) {
int pre0 = 0, f0 = 0, f1 = INT_MIN;
for (int p: prices) {
int new_f0 = max(f0, f1 + p);
f1 = max(f1, pre0 - p);
pre0 = f0;
f0 = new_f0;
}
return f0;
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( 1 ) O(1) O(1)
1.3 买卖股票的最佳时机 IV
解法一:记忆化搜索
class Solution:
def maxProfit(self, k: int, prices: List[int]) -> int:
n = len(prices)
@cache
def dfs(i: int, j: int, hold: bool)->int:
if j < 0: return -inf
if i < 0: return -inf if hold else 0
if hold: return max(dfs(i - 1, j, True), dfs(i - 1, j, False) - prices[i])
return max(dfs(i - 1, j, False), dfs(i - 1, j - 1, True) + prices[i])
return dfs(n - 1, k, False)
class Solution {
public:
int maxProfit(int k, vector<int>& prices) {
int n = prices.size(), cache[n][k + 1][2];
memset(cache, -1, sizeof cache);
function<int(int, int, bool)> dfs = [&] (int i, int j, bool hold) -> int{
if (j < 0) return INT_MIN;
if (i < 0) return hold ? INT_MIN / 2: 0;
int &res = cache[i][j][hold];
if (res != -1) return res;
if (hold) {
res = max(dfs(i - 1, j, true), dfs(i - 1, j, false) - prices[i]);
return res;
}
res = max(dfs(i - 1, j, false), dfs(i - 1, j - 1, true) + prices[i]);
return res;
};
return dfs(n - 1, k, false);
}
};
- 时间复杂度: O ( n k ) O(nk) O(nk),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( n k ) O(nk) O(nk)
解法二:递推
class Solution:
def maxProfit(self, k: int, prices: List[int]) -> int:
n = len(prices)
f = [[[-inf] * 2 for _ in range(k + 2)] for _ in range(n + 1)]
for j in range(1, k + 2):
f[0][j][0] = 0
for i, p in enumerate(prices):
for j in range(1, k + 2):
f[i + 1][j][0] = max(f[i][j][0], f[i][j][1] + p)
f[i + 1][j][1] = max(f[i][j][1], f[i][j - 1][0] - p)
return f[-1][-1][0]
class Solution {
public:
int maxProfit(int k, vector<int>& prices) {
int n = prices.size(), f[n + 1][k + 2][2];
memset(f, -0x3f, sizeof f);
for (int j = 1; j <= k + 1; j ++ ) f[0][j][0] = 0;
for (int i = 0; i < n; i ++ )
for (int j = 1; j <= k + 1; j ++ ) {
f[i + 1][j][0] = max(f[i][j][0], f[i][j][1] + prices[i]);
f[i + 1][j][1] = max(f[i][j][1], f[i][j - 1][0] - prices[i]);
}
return f[n][k + 1][0];
}
};
- 时间复杂度: O ( n k ) O(nk) O(nk),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( n k ) O(nk) O(nk)
解法三:空间复杂度优化
class Solution:
def maxProfit(self, k: int, prices: List[int]) -> int:
n = len(prices)
f = [[-inf] * 2 for _ in range(k + 2)]
for j in range(1, k + 2):
f[j][0] = 0
for p in prices:
for j in range(k + 1, 0, -1):
f[j][1] = max(f[j][1], f[j - 1][0] - p)
f[j][0] = max(f[j][0], f[j][1] + p)
return f[-1][0]
class Solution {
public:
int maxProfit(int k, vector<int>& prices) {
int n = prices.size(), f[k + 2][2];
memset(f, -0x3f, sizeof f);
for (int j = 1; j <= k + 1; j ++ ) f[j][0] = 0;
for (int p: prices)
for (int j = 1; j <= k + 1; j ++ ) {
f[j][0] = max(f[j][0], f[j][1] + p);
f[j][1] = max(f[j][1], f[j - 1][0] - p);
}
return f[k + 1][0];
}
};
- 时间复杂度: O ( n k ) O(nk) O(nk),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( k ) O(k) O(k)
1.4 买卖股票的最佳时机含手续费
解法一:记忆化搜索
class Solution:
def maxProfit(self, prices: List[int], fee: int) -> int:
n = len(prices)
@cache
def dfs(i: int, hold: bool) -> int:
if i < 0: return -inf if hold else 0
if hold:
return max(dfs(i - 1, True), dfs(i - 1, False) - prices[i])
return max(dfs(i - 1, False), dfs(i - 1, True) + prices[i] - fee)
return dfs(n - 1, False)
class Solution {
public:
int maxProfit(vector<int>& prices, int fee) {
int n = prices.size(), cache[n][2];
memset(cache, -1, sizeof cache);
function<int(int, bool)> dfs = [&] (int i, bool hold) -> int {
if (i < 0) return hold ? INT_MIN / 2 : 0;
int &res = cache[i][hold];
if (res != -1) return res;
if (hold) {
res = max(dfs(i - 1, true), dfs(i - 1, false) - prices[i]);
return res;
}
res = max(dfs(i - 1, false), dfs(i - 1, true) + prices[i] - fee);
return res;
};
return dfs(n - 1, false);
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( n ) O(n) O(n)
解法二:递推
class Solution:
def maxProfit(self, prices: List[int], fee: int) -> int:
n = len(prices)
f = [[0] * 2 for _ in range(n + 1)]
f[0][1] = -inf
for i, p in enumerate(prices):
f[i + 1][0] = max(f[i][0], f[i][1] + p - fee)
f[i + 1][1] = max(f[i][1], f[i][0] - p)
return f[n][0]
class Solution {
public:
int maxProfit(vector<int>& prices, int fee) {
int n = prices.size(), f[n + 1][2];
memset(f, 0, sizeof f);
f[0][1] = INT_MIN / 2;
for (int i = 0; i < n; i ++ ) {
f[i + 1][0] = max(f[i][0], f[i][1] + prices[i] - fee);
f[i + 1][1] = max(f[i][1], f[i][0] - prices[i]);
}
return f[n][0];
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( n ) O(n) O(n)
解法三:空间优化
class Solution:
def maxProfit(self, prices: List[int], fee: int) -> int:
f0, f1 = 0, -inf
for p in prices:
f0, f1 = max(f0, f1 + p - fee), max(f1, f0 - p)
return f0
class Solution {
public:
int maxProfit(vector<int>& prices, int fee) {
int f0 = 0, f1 = INT_MIN / 2;
for (int p: prices) {
int new_f0 = max(f0, f1 + p - fee);
f1 = max(f1, f0 - p);
f0 = new_f0;
}
return f0;
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 是 p r i c e s prices prices 的长度
- 空间复杂度: O ( 1 ) O(1) O(1)
二、区间DP
- 线性DP:在前面学习的DP中,都是在数组的前缀或者后缀上进行转移的,这类DP通常叫做线性DP。
- 区间DP:将问题规模缩小到数组中间的区间上,而不仅仅是前缀或者后缀。
2.1 最长回文子序列
解法一:转换为LCS问题
因为回文子序列从左往右或者从右往左都是一样的,因此可以求解序列 s s s 与其反转后的序列 S ‘ S` S‘ 的最长公共子序列(LCS),即为最长公共子序列。
解法二:直接计算,选或者不选
记忆化搜索
class Solution:
def longestPalindromeSubseq(self, s: str) -> int:
@cache
def dfs(i: int, j: int)->int:
if i > j: return 0
if i == j: return 1
if s[i] == s[j]:
return dfs(i + 1, j - 1) + 2
return max(dfs(i + 1, j), dfs(i, j - 1))
return dfs(0, len(s) - 1)
class Solution {
public:
int longestPalindromeSubseq(string s) {
int n = s.length(), cahce[n][n];
memset(cahce, -1, sizeof cahce);
function<int(int, int)> dfs = [&](int i, int j)->int {
if (i > j) return 0;
if (i == j) return 1;
int &res = cahce[i][j];
if (res != -1) return res;
if (s[i] == s[j]) {
res = dfs(i + 1, j - 1) + 2;
return res;
}
res = max(dfs(i + 1, j), dfs(i, j - 1));
return res;
};
return dfs(0, n - 1);
}
};
- 时间复杂度: O ( n 2 ) O(n^2) O(n2),其中 n n n 为 s s s 的长度。动态规划的时间复杂度 = 状态个数 × \times × 单个状态的转移个数。本题中状态个数等于 O ( n 2 ) O(n^2) O(n2),而单个状态的转移个数为 O ( 1 ) O(1) O(1),因此时间复杂度为 O ( n 2 ) O(n^2) O(n2)
- 空间复杂度: O ( n 2 ) O(n^2) O(n2)
递推写法
class Solution:
def longestPalindromeSubseq(self, s: str) -> int:
n = len(s)
f = [[0] * n for _ in range(n)]
for i in range(n - 1, -1, -1):
f[i][i] = 1
for j in range(i + 1, n):
if s[i] == s[j]:
f[i][j] = f[i + 1][j - 1] + 2
else:
f[i][j] = max(f[i + 1][j], f[i][j - 1])
return f[0][-1]
class Solution {
public:
int longestPalindromeSubseq(string s) {
int n = s.length(), f[n][n];
memset(f, 0, sizeof f);
for (int i = n - 1; i >= 0; i -- ) {
f[i][i] = 1;
for (int j = i + 1; j < n; j ++ )
f[i][j] = s[i] == s[j] ? f[i + 1][j - 1] + 2 : max(f[i + 1][j], f[i][j - 1]);
}
return f[0][n - 1];
}
};
- 时间复杂度: O ( n 2 ) O(n^2) O(n2),其中 n n n 为 s s s 的长度。动态规划的时间复杂度 = 状态个数 × \times × 单个状态的转移个数。本题中状态个数等于 O ( n 2 ) O(n^2) O(n2),而单个状态的转移个数为 O ( 1 ) O(1) O(1),因此时间复杂度为 O ( n 2 ) O(n^2) O(n2)
- 空间复杂度: O ( n 2 ) O(n^2) O(n2)
空间优化
class Solution:
def longestPalindromeSubseq(self, s: str) -> int:
n = len(s)
f = [0] * n
for i in range(n - 1, -1, -1):
f[i] = 1
pre = 0
for j in range(i + 1, n):
tmp = f[j]
f[j] = pre + 2 if s[i] == s[j] else max(f[j], f[j - 1])
pre = tmp
return f[-1]
class Solution {
public:
int longestPalindromeSubseq(string s) {
int n = s.length(), f[n];
memset(f, 0, sizeof f);
for (int i = n - 1; i >= 0; i -- ) {
f[i] = 1;
int pre = 0;
for (int j = i + 1; j < n; j ++ ) {
int tmp = f[j];
f[j] = s[i] == s[j] ? pre + 2 : max(f[j], f[j - 1]);
pre = tmp;
}
}
return f[n - 1];
}
};
- 时间复杂度: O ( n 2 ) O(n^2) O(n2),其中 n n n 为 s s s 的长度
- 空间复杂度: O ( n ) O(n) O(n)
2.2 多边形三角剖分的最低得分
解法一:记忆化搜索
class Solution:
def minScoreTriangulation(self, values: List[int]) -> int:
@cache
def dfs(i: int, j: int) -> int:
if i + 1 == j: return 0
return min(dfs(i, k) + dfs(k, j) + values[i] * values[j] * values[k]
for k in range(i + 1, j))
return dfs(0, len(values) - 1)
class Solution {
public:
int minScoreTriangulation(vector<int>& values) {
int n = values.size(), cache[n][n];
memset(cache, -1, sizeof cache);
function<int(int, int)> dfs = [&](int i, int j) -> int {
if (i + 1 == j) return 0;
int &res = cache[i][j];
if (res != -1) return res;
res = INT_MAX;
for (int k = i + 1; k < j; k ++ )
res = min(res, dfs(i, k) + dfs(k, j) + values[i] * values[j] * values[k]);
return res;
};
return dfs(0, n - 1);
}
};
- 时间复杂度: O ( n 3 ) O(n^3) O(n3),其中 n n n 为 v a l u e s values values 的长度。动态规划的时间复杂度 = 状态个数 × \times × 单个状态的计算时间。本题中状态个数等于 O ( n 2 ) O(n^2) O(n2),单个状态的计算时间为 O ( n ) O(n) O(n),因此时间复杂度为 O ( n 3 ) O(n^3) O(n3)。
- 空间复杂度: O ( n 2 ) O(n^2) O(n2)。有 O ( n 2 ) O(n^2) O(n2) 个状态。
解法二:递推
class Solution:
def minScoreTriangulation(self, values: List[int]) -> int:
n = len(values)
f = [[0] * n for _ in range(n)]
for i in range(n - 3, -1, -1):
for j in range(i + 2, n):
f[i][j] = min(f[i][k] + f[k][j] + values[i] * values[j] * values[k]
for k in range(i + 1, j))
return f[0][-1]
class Solution {
public:
int minScoreTriangulation(vector<int>& values) {
int n = values.size(), f[n][n];
memset(f, 0, sizeof f);
for (int i = n - 3; i >= 0; i -- )
for (int j = i + 2; j < n; j ++ ) {
f[i][j] = INT_MAX;
for (int k = i + 1; k < j; k ++ )
f[i][j] = min(f[i][j], f[i][k] + f[k][j] + values[i] * values[j] * values[k]);
}
return f[0][n - 1];
}
};
- 时间复杂度: O ( n 3 ) O(n^3) O(n3),其中 n n n 为 v a l u e s values values 的长度。动态规划的时间复杂度 = 状态个数 × \times × 单个状态的计算时间。本题中状态个数等于 O ( n 2 ) O(n^2) O(n2),单个状态的计算时间为 O ( n ) O(n) O(n),因此时间复杂度为 O ( n 3 ) O(n^3) O(n3)。
- 空间复杂度: O ( n 2 ) O(n^2) O(n2)。有 O ( n 2 ) O(n^2) O(n2) 个状态。
2.3 由子序列构造的最长回文串的长度
class Solution:
def longestPalindrome(self, word1: str, word2: str) -> int:
s = word1 + word2
ans, n = 0, len(s)
f = [[0] * n for _ in range(n)]
for i in range(n - 1, -1, -1):
f[i][i] = 1
for j in range(i + 1, n):
if s[i] == s[j]:
f[i][j] = f[i + 1][j - 1] + 2
if i < len(word1) <= j:
ans = max(ans, f[i][j])
else:
f[i][j] = max(f[i + 1][j], f[i][j - 1])
return ans
class Solution {
public:
int longestPalindrome(string word1, string word2) {
string s = word1 + word2;
int ans = 0, n = s.length(), f[n][n];
memset(f, 0, sizeof f);
for (int i = n - 1; i >= 0; i -- ) {
f[i][i] = 1;
for (int j = i + 1; j < n; j ++ )
if (s[i] == s[j]) {
f[i][j] = f[i + 1][j - 1] + 2;
if (i < word1.length() && j >= word1.length())
ans = max(ans, f[i][j]);
}
else
f[i][j] = max(f[i + 1][j], f[i][j - 1]);
}
return ans;
}
};
2.4 合并石头的最低成本
class Solution:
def mergeStones(self, stones: List[int], k: int) -> int:
n = len(stones)
if (n - 1) % (k - 1): return -1
s = list(accumulate(stones, initial=0))
@cache
def dfs(i: int, j: int, p: int)->int:
if p == 1:
return 0 if i == j else dfs(i, j, k) + s[j + 1] - s[i]
return min(dfs(i, m, 1) + dfs(m + 1, j, p - 1) for m in range(i, j, k - 1))
return dfs(0, n - 1, 1)
class Solution {
public:
int mergeStones(vector<int>& stones, int k) {
int n = stones.size();
if ((n - 1) % (k - 1)) return -1;
int s[n + 1];
s[0] = 0;
for (int i = 0; i < n; i ++ )
s[i + 1] = s[i] + stones[i];
int cache[n][n][k + 1];
memset(cache, -1, sizeof cache);
function<int(int, int, int)> dfs = [&](int i, int j, int p) -> int {
int &res = cache[i][j][p];
if (res != -1) return res;
if (p == 1)
return res = i == j ? 0 : dfs(i, j, k) + s[j + 1] - s[i];
res = INT_MAX;
for (int m = i; m < j; m += k - 1)
res = min(res, dfs(i, m, 1) + dfs(m + 1, j, p - 1));
return res;
};
return dfs(0, n - 1, 1);
}
};
- 时间复杂度: O ( n 3 ) O(n^3) O(n3),其中 n n n 为 s t o n e s stones stones 的长度。动态规划的时间复杂度 = 状态个数 × \times × 单个状态的计算时间。这里状态个数为 O ( n 2 k ) O(n^2k) O(n2k),单个状态的计算时间为 O ( n k ) O(\frac{n}{k}) O(kn),因此时间复杂度为 O ( n 3 ) O(n^3) O(n3)。
- 空间复杂度: O ( n 2 k ) O(n^2k) O(n2k)
优化
待写
递推
待写
三、树形DP——直径系列
3.1 二叉树的直径
class Solution:
def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
ans = 0
def dfs(node: Optional[TreeNode])->int:
if node is None: return -1
l_len = dfs(node.left)
r_len = dfs(node.right)
nonlocal ans
ans = max(ans, l_len + r_len + 2)
return max(l_len, r_len) + 1
dfs(root)
return ans
class Solution {
public:
int diameterOfBinaryTree(TreeNode *root) {
int ans = 0;
function<int(TreeNode*)> dfs = [&](TreeNode *node) -> int {
if (node == nullptr)
return -1; // 下面 +1 后,对于叶子节点就刚好是 0
int l_len = dfs(node->left); // 左子树最大链长+1
int r_len = dfs(node->right); // 右子树最大链长+1
ans = max(ans, l_len + r_len + 2); // 两条链拼成路径
return max(l_len, r_len) + 1; // 当前子树最大链长
};
dfs(root);
return ans;
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 为二叉树的节点个数。
- 空间复杂度: O ( n ) O(n) O(n)。最坏情况下,二叉树退化成一条链,递归需要 O ( n ) O(n) O(n) 的栈空间。
3.2 二叉树中的最大路径和
class Solution:
def maxPathSum(self, root: Optional[TreeNode]) -> int:
ans = -inf
def dfs(node: Optional[TreeNode])->int:
if node is None: return 0
l_val = dfs(node.left)
r_val = dfs(node.right)
nonlocal ans
ans = max(ans, l_val + r_val + node.val)
return max(max(l_val, r_val) + node.val, 0)
dfs(root)
return ans
class Solution {
public:
int maxPathSum(TreeNode* root) {
int ans = INT_MIN;
function<int(TreeNode*)> dfs = [&](TreeNode* node)->int {
if (node == nullptr) return 0;
int l_val = dfs(node->left);
int r_val = dfs(node->right);
ans = max(ans, l_val + r_val + node->val);
return max(max(l_val, r_val) + node->val, 0);
};
dfs(root);
return ans;
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 为二叉树的节点个数。
- 空间复杂度: O ( n ) O(n) O(n)。最坏情况下,二叉树退化成一条链,递归需要 O ( n ) O(n) O(n) 的栈空间。
3.3 相邻字符不同的最长路径(普通树的直径)
class Solution:
def longestPath(self, parent: List[int], s: str) -> int:
n = len(parent)
g = [[] for _ in range(n)]
for i in range(1, n):
g[parent[i]].append(i)
ans = 0
def dfs(x: int) -> int:
nonlocal ans
max_len = 0
for y in g[x]:
son_len = dfs(y) + 1
if s[y] != s[x]:
ans = max(ans, max_len + son_len)
max_len = max(max_len, son_len)
return max_len
dfs(0)
return ans + 1
class Solution {
public:
int longestPath(vector<int>& parent, string s) {
int n = parent.size();
vector<vector<int>> g(n);
for (int i = 1; i < n; i ++ )
g[parent[i]].push_back(i);
int ans = 0;
function<int(int)> dfs = [&] (int x) -> int {
int maxLen = 0;
for (int y: g[x]) {
int son_len = dfs(y) + 1;
if (s[y] != s[x]) {
ans = max(ans, maxLen + son_len);
maxLen = max(maxLen, son_len);
}
}
return maxLen;
};
dfs(0);
return ans + 1;
}
};
- 时间复杂度: O ( n ) O(n) O(n)
- 空间复杂度: O ( n ) O(n) O(n)
3.4 最长同值路径
class Solution:
def longestUnivaluePath(self, root: Optional[TreeNode]) -> int:
ans = 0
def dfs(node: Optional[TreeNode])->int:
if node is None: return -1
l_len = dfs(node.left) + 1
r_len = dfs(node.right) + 1
if node.left and node.left.val != node.val: l_len = 0
if node.right and node.right.val != node.val: r_len = 0
nonlocal ans
ans = max(ans, l_len + r_len)
return max(l_len, r_len)
dfs(root)
return ans
class Solution {
public:
int longestUnivaluePath(TreeNode* root) {
int ans = 0;
function<int(TreeNode*)> dfs = [&] (TreeNode *node) -> int {
if (node == nullptr) return -1;
int l_len = dfs(node->left) + 1;
int r_len = dfs(node->right) + 1;
if (node->left && node->left->val != node->val) l_len = 0;
if (node->right && node->right->val != node->val) r_len = 0;
ans = max(ans, l_len + r_len);
return max(l_len, r_len);
};
dfs(root);
return ans;
}
};
- 时间复杂度: O ( n ) O(n) O(n)
- 空间复杂度: O ( n ) O(n) O(n)
3.5 统计子树中城市之间最大距离
class Solution:
def countSubgraphsForEachDiameter(self, n: int, edges: List[List[int]]) -> List[int]:
# 建树
g = [[] for _ in range(n)]
for x, y in edges:
g[x - 1].append(y - 1)
g[y - 1].append(x - 1)
ans = [0] * (n - 1)
# 标记是否选择该节点
in_set = [False] * n
def f(i: int) -> None:
if i == n:
vis = [False] * n
diameter = 0 # 直径
for v, b in enumerate(in_set):
if not b: continue
# 求树的直径
def dfs(x: int) -> int:
nonlocal diameter
vis[x] = True
max_len = 0
for y in g[x]:
if not vis[y] and in_set[y]:
ml = dfs(y) + 1
diameter = max(diameter, max_len + ml)
max_len = max(max_len, ml)
return max_len
dfs(v)
break
if diameter and vis == in_set:
ans[diameter - 1] += 1
return
# 不选择节点i
f(i + 1)
# 选择节点i
in_set[i] = True
f(i + 1)
in_set[i] = False
f(0)
return ans
class Solution {
public:
vector<int> countSubgraphsForEachDiameter(int n, vector<vector<int>>& edges) {
vector<vector<int>> g(n);
for (auto &e: edges) {
int x= e[0] - 1, y = e[1] - 1;
g[x].push_back(y);
g[y].push_back(x);
}
vector<int> ans(n - 1), in_set(n), vis(n);
int diameter = 0;
function<int(int)> dfs = [&](int x) -> int {
vis[x] = true;
int max_len = 0;
for (int y: g[x])
if (!vis[y] && in_set[y]) {
int ml = dfs(y) + 1;
diameter = max(diameter, max_len + ml);
max_len = max(max_len, ml);
}
return max_len;
};
function<void(int)> f = [&](int i) {
if (i == n) {
for (int v = 0; v < n; v ++ )
if (in_set[v]) {
fill(vis.begin(), vis.end(), 0);
diameter = 0;
dfs(v);
break;
}
if (diameter && vis == in_set) ++ ans[diameter - 1];
return;
}
f(i + 1);
in_set[i] = true;
f(i + 1);
in_set[i] = false;
};
f(0);
return ans;
}
};
- 时间复杂度: O ( n 2 n ) O(n2^n) O(n2n)。 O ( 2 n ) O(2^n) O(2n) 枚举子集, O ( n ) O(n) O(n) 求直径,所以时间复杂度为 O ( n 2 n ) O(n2^n) O(n2n)
- 空间复杂度: O ( n ) O(n) O(n)
二进制优化
class Solution:
def countSubgraphsForEachDiameter(self, n: int, edges: List[List[int]]) -> List[int]:
g = [[] for _ in range(n)]
for x, y in edges:
g[x - 1].append(y - 1)
g[y - 1].append(x - 1)
ans = [0] * (n - 1)
for mask in range(3, 1 << n):
if (mask & (mask - 1)) == 0:
continue;
vis = diameter = 0
def dfs(x: int) -> int:
nonlocal vis, diameter
vis |= 1 << x
max_len = 0
for y in g[x]:
if (vis >> y & 1) == 0 and mask >> y & 1:
ml = dfs(y) + 1
diameter = max(diameter, max_len + ml)
max_len = max(max_len, ml)
return max_len
dfs(mask.bit_length() - 1)
if vis == mask:
ans[diameter - 1] += 1
return ans
class Solution {
public:
vector<int> countSubgraphsForEachDiameter(int n, vector<vector<int>> &edges) {
vector<vector<int>> g(n);
for (auto &e : edges) {
int x = e[0] - 1, y = e[1] - 1; // 编号改为从 0 开始
g[x].push_back(y);
g[y].push_back(x); // 建树
}
int dis[n][n]; memset(dis, 0, sizeof(dis));
function<void(int, int, int)> dfs = [&](int i, int x, int fa) {
for (int y : g[x])
if (y != fa) {
dis[i][y] = dis[i][x] + 1; // 自顶向下
dfs(i, y, x);
}
};
for (int i = 0; i < n; ++i)
dfs(i, i, -1); // 计算 i 到其余点的距离
function<int(int, int, int, int, int)> dfs2 = [&](int i, int j, int d, int x, int fa) {
// 能递归到这,说明 x 可以选
int cnt = 1; // 选 x
for (int y : g[x])
if (y != fa &&
(dis[i][y] < d || dis[i][y] == d && y > j) &&
(dis[j][y] < d || dis[j][y] == d && y > i)) // 满足这些条件就可以选
cnt *= dfs2(i, j, d, y, x); // 每棵子树互相独立,采用乘法原理
if (dis[i][x] + dis[j][x] > d) // x 是可选点
++cnt; // 不选 x
return cnt;
};
vector<int> ans(n - 1);
for (int i = 0; i < n; ++i)
for (int j = i + 1; j < n; ++j)
ans[dis[i][j] - 1] += dfs2(i, j, dis[i][j], i, -1);
return ans;
}
};
- 时间复杂度: O ( n 2 n ) O(n2^n) O(n2n)。 O ( 2 n ) O(2^n) O(2n) 枚举子集, O ( n ) O(n) O(n) 求直径,所以时间复杂度为 O ( n 2 n ) O(n2^n) O(n2n)
- 空间复杂度: O ( n ) O(n) O(n)
3.6 最大价值和与最小价值和的差值
class Solution:
def maxOutput(self, n: int, edges: List[List[int]], price: List[int]) -> int:
g = [[] for _ in range(n)]
for x, y in edges:
g[x].append(y)
g[y].append(x)
ans = 0
def dfs(x: int, fa: int) -> (int, int):
nonlocal ans
max_s1 = p = price[x]
max_s2 = 0
for y in g[x]:
if y == fa: continue
s1, s2 = dfs(y, x)
ans = max(ans, max_s1 + s2, max_s2 + s1)
max_s1 = max(max_s1, s1 + p)
max_s2 = max(max_s2, s2 + p)
return max_s1, max_s2
dfs(0, -1)
return ans
class Solution {
public:
long long maxOutput(int n, vector<vector<int>>& edges, vector<int>& price) {
vector<vector<int>> g(n);
for (auto &e: edges) {
int x = e[0], y = e[1];
g[x].push_back(y);
g[y].push_back(x);
}
long ans = 0;
function<pair<long, long>(int, int)> dfs = [&] (int x, int fa) -> pair<long, long> {
long p = price[x], max_s1 = p, max_s2 = 0;
for (int y: g[x])
if (y != fa) {
auto [s1, s2] = dfs(y, x);
ans = max(ans, max(max_s1 + s2, max_s2 + s1));
max_s1 = max(max_s1, s1 + p);
max_s2 = max(max_s2, s2 + p);
}
return {max_s1, max_s2};
};
dfs(0, -1);
return ans;
}
};
- 时间复杂度: O ( n ) O(n) O(n)
- 空间复杂度: O ( n ) O(n) O(n)
四、树形DP——最大独立集
在图论中,独立集(Independent Set)是一种顶点的集合,其中任意两个顶点都不相邻,也就是说,集合中的顶点之间没有边相连。换句话说,独立集是一组顶点,其中没有两个顶点通过一条边相连。
4.1 打家劫舍 III(二叉树上)
class Solution:
def rob(self, root: Optional[TreeNode]) -> int:
def dfs(node: Optional[TreeNode])->(int, int):
if node is None: return 0, 0
l_rob, l_not_rob = dfs(node.left)
r_rob, r_not_rob = dfs(node.right)
rob = l_not_rob + r_not_rob + node.val # 选
not_rob = max(l_rob, l_not_rob) + max(r_rob, r_not_rob) # 不选
return rob, not_rob
return max(dfs(root))
class Solution {
pair<int, int> dfs(TreeNode* node) {
if (node == nullptr) return {0, 0};
auto [l_rob, l_not_rob] = dfs(node->left);
auto [r_rob, r_not_rob] = dfs(node->right);
int rob = l_not_rob + r_not_rob + node->val;
int not_rob = max(l_rob, l_not_rob) + max(r_rob, r_not_rob);
return {rob, not_rob};
}
public:
int rob(TreeNode* root) {
auto [root_rob, root_not_rob] = dfs(root);
return max(root_rob, root_not_rob);
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 为二叉树的节点个数。每个节点都会递归恰好一次。
- 空间复杂度: O ( n ) O(n) O(n)。最坏情况下,二叉树是一条链,递归需要 O ( n ) O(n) O(n) 的栈空间。
4.2 没有上司的舞会(普通树上)
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
const int N = 6010;
int n;
int happy[N];
int h[N], e[N], ne[N], idx;
int f[N][2];
bool has_father[N];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
void dfs(int u) {
f[u][1] = happy[u];
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
dfs(j);
f[u][0] += max(f[j][0], f[j][1]);
f[u][1] += f[j][0];
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i ++ ) scanf("%d", happy + i);
memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i ++ ) {
int a, b;
scanf("%d%d", &a, &b);
has_father[a] = true;
add(b, a);
}
int root = 1;
while (has_father[root]) root ++ ;
dfs(root);
printf("%d\n", max(f[root][0], f[root][1]));
return 0;
}
4.3 T秒之后青蛙的位置
要注意几个问题:
- 如何处理浮点数计算精度问题?先计算分母乘机,最后将结果取倒即可。
- 为什么写【自底向上】而不是【自顶向下】?自底向上是先找到问题后在返回答案,自顶向下是不断往下找到答案后返回结果,因此自顶向下会造成一些不必要的计算。
- 如何解决特判
n=1
的情况?在节点n=1
中添加一个父节点n=0
。 - 对于时间
T
,如何在 DFS 中减少一个变量的引入?让leftT
从t
减少到0
而不是从0
增加到t
,这样的代码只需要和0
比较,而不需要和t
比较。
class Solution:
def frogPosition(self, n: int, edges: List[List[int]], t: int, target: int) -> float:
g = [[] for _ in range(n + 1)]
g[1] = [0]
for x, y in edges:
g[x].append(y)
g[y].append(x)
ans = 0
def dfs(x: int, fa: int, left_t: int, prod: int) -> True:
"""
x: 当前遍历的节点
fa: 当前节点的父节点
left_t: 剩余时间
prod: 概率
return: 用于判断dfs是否结束
"""
# t 秒后必须在 target(恰好到达,或者 target 是叶子停在原地)
if x == target and (left_t == 0 or len(g[x]) == 1):
nonlocal ans
ans = 1 / prod
return True
if x == target or left_t == 0: return False
for y in g[x]:
if y != fa and dfs(y, x, left_t - 1, prod * (len(g[x]) - 1)):
return True
return False
dfs(1, 0, t, 1)
return ans
class Solution {
public:
double frogPosition(int n, vector<vector<int>>& edges, int t, int target) {
vector<vector<int>> g(n + 1);
g[1] = {0};
for (auto &e: edges) {
int x = e[0], y = e[1];
g[x].push_back(y);
g[y].push_back(x);
}
double ans = 0;
function<bool(int, int, int, long long)> dfs = [&] (int x, int fa, int left_t, long long prob) -> bool {
if (x == target && (left_t == 0 || g[x].size() == 1)) {
ans = 1.0 / prob;
return true;
}
if (x == target || left_t == 0) return false;
for (int y: g[x])
if (y != fa && dfs(y, x, left_t - 1, prob * (g[x].size() - 1)))
return true;
return false;
};
dfs(1, 0, t, 1);
return ans;
}
};
4.4 最小化旅行的价格总和
class Solution:
def minimumTotalPrice(self, n: int, edges: List[List[int]], price: List[int], trips: List[List[int]]) -> int:
g = [[] for _ in range(n)]
for x, y in edges:
g[x].append(y)
g[y].append(x)
cnt = [0] * n
for start, end in trips:
def dfs(x: int, fa: int) -> bool:
if x == end:
cnt[x] += 1
return True
for y in g[x]:
if y != fa and dfs(y, x):
cnt[x] += 1
return True
return False
dfs(start, -1)
def dfs(x: int, fa: int) -> (int, int):
not_have = price[x] * cnt[x]
halve = not_have // 2
for y in g[x]:
if y != fa:
nh, h = dfs(y ,x)
not_have += min(nh, h)
halve += nh
return not_have, halve
return min(dfs(0, -1))
class Solution {
public:
int minimumTotalPrice(int n, vector<vector<int>>& edges, vector<int>& price, vector<vector<int>>& trips) {
vector<vector<int>> g(n);
for (auto &e: edges) {
int x = e[0], y = e[1];
g[x].push_back(y);
g[y].push_back(x);
}
int cnt[n]; memset(cnt, 0, sizeof cnt);
for (auto &t: trips) {
int end = t[1];
function<bool(int, int)> dfs = [&](int x, int fa) -> bool {
if (x == end) {
++ cnt[x];
return true;
}
for (int y: g[x])
if (y != fa && dfs(y, x)) {
++ cnt[x];
return true;
}
return false;
};
dfs(t[0], -1);
}
function<pair<int, int>(int, int)> dfs = [&](int x, int fa) -> pair<int ,int> {
int not_halve = price[x] * cnt[x];
int halve = not_halve / 2;
for (int y: g[x])
if (y != fa) {
auto [nh, h] = dfs(y, x);
not_halve += min(nh, h);
halve += nh;
}
return {not_halve, halve};
};
auto [nh, h] = dfs(0, -1);
return min(nh, h);
}
};
- 时间复杂度: O ( n m ) O(nm) O(nm),其中 m m m 为 t r i p s trips trips 的长度
- 空间复杂度: O ( n ) O(n) O(n)
五、树形DP——最小支配集
5.1 监控二叉树
class Solution:
def minCameraCover(self, root: Optional[TreeNode]) -> int:
def dfs(node):
if node is None:
return inf, 0, 0
l_choose, l_by_fa, l_by_children = dfs(node.left)
r_choose, r_by_fa, r_by_children = dfs(node.right)
choose = min(l_choose, l_by_fa) + min(r_choose, r_by_fa) + 1
by_fa = min(l_choose, l_by_children) + min(r_choose, r_by_children)
by_children = min(l_choose + r_by_children, l_by_children + r_choose, l_choose + r_choose)
return choose, by_fa, by_children
choose, _, by_children = dfs(root)
return min(choose, by_children)
class Solution {
tuple<int, int, int> dfs(TreeNode *node) {
if (node == nullptr) return {INT_MAX / 2, 0, 0};
auto [l_choose, l_by_fa, l_by_children] = dfs(node->left);
auto [r_choose, r_by_fa, r_by_children] = dfs(node->right);
int choose = min(l_choose, l_by_fa) + min(r_choose, r_by_fa) + 1;
int by_fa = min(l_choose, l_by_children) + min(r_choose, r_by_children);
int by_children = min({l_choose + r_by_children, l_by_children + r_choose, l_choose + r_choose});
return {choose, by_fa, by_children};
}
public:
int minCameraCover(TreeNode* root) {
auto [choose, _, by_children] = dfs(root);
return min(choose, by_children);
}
};
- 时间复杂度: O ( n ) O(n) O(n),其中 n n n 为二叉树的节点个数。每个节点都会递归恰好一次。
- 空间复杂度: O ( n ) O(n) O(n),最坏情况下,二叉树是一条链,递归需要 O ( n ) O(n) O(n) 的栈空间。