前言
区间dp,顾名思义,是解决一类区间问题的动态规划。通常用来 f [ l ] [ r ] f[l][r] f[l][r]来表示区间 [ L , R ] [L, R] [L,R]上的最优解。主要难在对于状态的转移。区间dp有迭代和递归两种写法,而递归写法就是一个记忆化搜索。
- 通用模板
for (int len = 1; len <= n; len ++) { 枚举长度
for (int l = 1; l + len - 1 <= n; ++ l) { 枚举区间
int r = l + len - 1;
for (int k = l; k < r; k ++) { 枚举断点(状态转移)
f[l][r] = max(f[l][r], f[l][k] + f[k + 1][r] + val) value
}
}
}
例题
例题【1】石子合并
链接:洛谷1775
题意:有
N
≤
300
N \le 300
N≤300个石子,排成一排,其编号为
1
,
2
,
3
,
⋯
,
N
1,2,3,\cdots,N
1,2,3,⋯,N。每堆石子有一定的质量
m
i
≤
1000
m_i \le 1000
mi≤1000 。
现在要将这
N
N
N 堆石子合并成为一堆。每次只能合并相邻的两堆,合并的代价为这两堆石子的质量之和。
试找出一种合理的方法,使总的代价最小,并输出最小代价。
Solution:
N
N
N堆石子,经过
N
−
1
N-1
N−1次合并之后会变成一堆,我们无法得知是先合并哪两堆。不妨我们反过来思考,假设合并了
N
−
2
N-2
N−2次,在合并最后一次的时候,一定是在一个位置
k
k
k,使得
w
1
→
k
+
w
k
+
1
→
N
w_{1 \to k} + w_{k +1 \to N}
w1→k+wk+1→N值最小(最大),一次类推,在倒数第二次合并的时候分别在
[
1
,
k
]
[1, k]
[1,k] 和
[
k
+
1
,
n
]
[k + 1, n]
[k+1,n]继续找位置
k
′
k'
k′,直到区间长度为1。
因此,按照上诉思路可以得出解法一:记忆化搜索。
code:
int dp(int l, int r) {
if(l >= r) return 0;
int &res = f[l][r];
if(res != inf) return res;
for (int k = l; k < r; k ++) {
int v = dp(l, k) + dp(k + 1, r) + s[r] - s[l - 1];
Min(res, v);
}
return res;
}
根据上面的思路,根据区间dp模板。对于此题,有区间dp解法
定义:
f
[
i
]
[
j
]
f[i][j]
f[i][j]表示为区间
[
i
,
j
]
[i, j]
[i,j]合并的最小值。
状态转移:
f
[
i
]
[
j
]
=
0
i
=
j
f
[
i
]
[
j
]
=
min
k
=
l
r
−
1
(
f
[
l
]
[
k
]
+
f
[
k
+
1
]
[
r
]
)
+
∑
h
=
l
r
w
h
i
≠
j
f[i][j] = 0 \quad i = j \\ \qquad \qquad f[i][j] = \min_{k=l}^{r - 1} (f[l][k] + f[k + 1][r]) + \sum_{h=l}^rw_h \quad i \neq j
f[i][j]=0i=jf[i][j]=mink=lr−1(f[l][k]+f[k+1][r])+∑h=lrwhi=j
(Q:为什么
k
∈
[
l
,
r
)
k \in [l, r)
k∈[l,r),寻找断点,将区间分为两个子集分别为
[
l
,
k
]
[l, k]
[l,k]和
[
k
+
1
,
r
]
[k + 1, r]
[k+1,r]。所以
k
k
k要满足
k
+
1
≤
r
k + 1 \le r
k+1≤r
code :
memset(f, 0x3f, sizeof f);
cin >> n;
for (int i = 1; i <= n; i ++ ) cin >> w[i], s[i] = s[i - 1] + w[i];
for (int len = 1; len <= n; len ++) {
for (int l = 1; l + len - 1 <= n; l ++ ) {
int r = len + l - 1;
if(l == r) f[l][r] = 0;
for (int k = l; k < r; ++ k)
Min(f[l][r], f[l][k] + f[k + 1][r] + s[r] - s[l - 1]);
}
}
cout << f[1][n] << "\n";
时间复杂度: O ( n 3 ) \mathcal{O(n^3)} O(n3) (可通过 n ≤ 300 n \le 300 n≤300的题目,通过四边形优化能将复杂度优化到 O ( n 2 ) \mathcal{O(n^2)} O(n2)另外对于石子合并问题,有 O ( n l o g n ) \mathcal{O(nlogn)} O(nlogn)算法P5569, GarsiaWachs算法
例题【2】回文子序列
链接:leetcode最长回文子序列
题意: 给你一个字符串 s ,找出其中最长的回文子序列
- 子序列定义为:不改变剩余字符顺序的情况下,删除某些字符或者不删除任何字符形成的一个序列。
Solution:
解法一:
定义:
f
[
i
]
[
j
]
f[i][j]
f[i][j]表示区间
[
i
,
j
]
[i, j]
[i,j]的最长回文子序列。
状态转移:
f
[
i
]
[
j
]
=
1
i
=
j
f
[
i
]
[
j
]
=
f
[
i
+
1
]
[
j
−
1
]
+
2
s
[
i
]
=
s
[
j
]
f
[
i
]
[
j
]
=
m
a
x
(
f
[
i
+
1
]
[
j
]
,
f
[
i
]
[
j
−
1
]
s
[
i
]
≠
s
[
j
]
f[i][j] = 1 \quad i = j \\ \qquad \qquad f[i][j] = f[i + 1][j -1] + 2 \quad s[i] = s [j] \\ \qquad \qquad f[i][j] = max(f[i + 1][j], f[i][j - 1] \quad s[i] \neq s[j]
f[i][j]=1i=jf[i][j]=f[i+1][j−1]+2s[i]=s[j]f[i][j]=max(f[i+1][j],f[i][j−1]s[i]=s[j]
Code:
const int N = 1010;
int f[N][N];
class Solution {
public:
int longestPalindromeSubseq(string s) {
int n = s.size();
for (int i = 1; i < n; i ++)
for (int j = 1; j < n; j ++)
f[i][j] = 0;
for (int len = 1; len <= n; len ++) {
for (int l = 0; l + len - 1 < n; l ++) {
int r = l + len - 1;
if(l == r) f[l][r] = 1;
else {
if(s[l] == s[r]) f[l][r] = f[l + 1][r - 1] + 2;
else f[l][r] = max(f[l + 1][r], f[l][r - 1]);
}
}
}
return f[0][n - 1];
}
};
另: 解法二:将字符串反转得到字符串 t t t, 此时的字符串 s s s和 t t t的最长公共子序列就是字符串 s s s的最长回文子序列。
这里直接给出代码:
class Solution {
public:
int longestPalindromeSubseq(string s) {
int n = s.size();
string t = s; reverse(t.begin(), t.end());
vector<vector<int>> dp(n + 1, vector<int>(n + 1, 0));
for (int i = 1; i <= n; i ++) {
for (int j = 1; j <= n; j ++) {
if(s[i - 1] == t[j - 1]) dp[i][j] = dp[i - 1][j - 1] + 1;
else dp[i][j] = max(dp[i - 1][j], dp[i][j - 1]);
}
}
return dp[n][n];
}
};
最后,如果题目要求的是最长回文子串,区间dp不在适用(切记!),有关算法可以有字符串哈希,马拉车…)
时间复杂度分析:
O
(
n
2
)
\mathcal{O(n^2)}
O(n2)
小结
- 状态设计: 通常定义 f [ i ] [ j ] f[i][j] f[i][j]表示区间 [ i , j ] [i, j] [i,j]的最优解。有的时候,单纯的二维无法满足需求,需要第三维记录其他状态。
- 状态转移:一定都是有小区间往大区间转移
- 时间复杂度:一般为 O ( n 3 ) \mathcal{O(n^3)} O(n3)或者 O ( n 2 ) \mathcal{O(n^2)} O(n2)
例题【3】环形石子合并(破环成链)
链接:P1880
题意:同例题【1】,只不过这里的石子是环形的。
Solution:
对于链状的石子合并能够利用区间dp解决,那对于环状的,就想办法变成链!
怎么变呢?只需要把
[
1
,
n
]
[1,n]
[1,n]复制到
[
n
+
1
,
2
∗
n
]
[n+1,2*n]
[n+1,2∗n],然后对于这个2*n的链进行区间dp。最后找长度为n的链的最优解。
Code:
#include <bits/stdc++.h>
using namespace std;
const int N = 420, INF = 0x3f3f3f3f;
int n;
int w[N], s[N];
int f[N][N], g[N][N];
int main() {
cin >> n;
for(int i = 1; i <= n; ++ i) {
cin >> w[i];
w[i + n] = w[i];
}
for(int i = 1; i <= n * 2; ++ i) s[i] = s[i - 1] + w[i];
memset(g, 0x3f, sizeof g);
for(int len = 1; len <= n; ++ len) {
for(int l = 1; l + len - 1 <= 2 * n; ++ l) {
int r = l + len - 1;
if(l == r) f[l][r] = g[l][r] = 0;
else {
for(int k = l; k + 1 <= r; ++ k ) {
f[l][r] = max(f[l][r], f[l][k] + f[k + 1][r] + s[r] - s[l - 1]);
g[l][r] = min(g[l][r], g[l][k] + g[k + 1][r] + s[r] - s[l - 1]);
}
}
}
}
int maxv = -INF, minv = INF;
for(int i = 1; i <= n; ++ i) {
maxv = max(maxv, f[i][i + n - 1]);
minv = min(minv, g[i][i + n - 1]);
}
cout << minv << "\n" << maxv << "\n";
}
例题【4】多边形
链接:Acwing285
题意:
Solution:
很容易发现这个就是石子合并的变形,价值不在只有相加,还有相乘。因为这里有负数,这里就涉及到了最大值如何的来,其实就四个答案取
max
\max
max即可(
m
x
(
+
/
∗
)
m
i
,
m
x
(
+
/
∗
)
m
x
,
m
i
(
+
/
∗
)
m
i
,
m
i
(
+
/
∗
)
m
x
mx (+/*) mi, mx (+/*) mx, mi (+/*) mi , mi (+/*) mx
mx(+/∗)mi,mx(+/∗)mx,mi(+/∗)mi,mi(+/∗)mx).
定义:
f
[
i
]
[
j
]
[
k
]
f[i][j][k]
f[i][j][k]表示区间
[
i
,
j
]
[i, j]
[i,j]合并后的最大值/最小值(
k
∈
[
1
,
2
]
k \in [1, 2]
k∈[1,2])
代码实现:
#include <bits/stdc++.h>
#define ALL(a) (a).begin(), (a).end()
using namespace std;
using LL = long long;
typedef pair<int, int> PII;
template < typename T> inline void Max(T &a, T b) { if(a < b) a = b; }
template < typename T> inline void Min(T &a, T b) { if(a > b) a = b; }
constexpr int N = 110, inf = 1E9;
int n, a[N], f[N][N][2];
char op[N];
int get(int a, int b, char op) {
if(op == 't') return a + b;
return a * b;
}
int main() {
cin.tie(nullptr) -> sync_with_stdio(false);
cin >> n;
for (int i = 1;i <= n; i ++ ) {
cin >> op[i] >> a[i];
op[i + n] = op[i];
a[i + n] = a[i];
}
for (int i = 1; i <= 2 * n; i ++ ) {
for (int j = 1; j <= 2 * n; j ++ ) {
if(i == j) f[i][j][0] = f[i][j][1] = a[i];
else {
f[i][j][0] = -inf;
f[i][j][1] = inf;
}
}
}
for (int len = 2; len <= 2 * n; len ++ ) {
for (int l = 1; l + len - 1 <= 2 * n; l ++ ) {
int r = l + len - 1;
for (int k = l; k < r; k ++ ) {
for (int i = 0; i < 2; i ++ ) {
for (int j = 0 ; j < 2; j ++ ) {
f[l][r][0] = max(f[l][r][0], get(f[l][k][i], f[k + 1][r][j], op[k + 1]));
f[l][r][1] = min(f[l][r][1], get(f[l][k][i], f[k + 1][r][j], op[k + 1]));
}
}
}
}
}
int ans = -inf, idx = 0;
vector<int> res;
for (int i = 1; i <= n; i ++ ) {
int t = f[i][i + n - 1][0];
if(t > ans) {
res.clear();
res.push_back(i);
ans = t;
} else if(t == ans) res.push_back(i);
}
cout << ans << "\n";
for (auto x : res) cout << x << " ";
return 0;
}
【例题5】String painter
链接:hdu_2476(hdu活了!!!,好耶!)
题意: 给出两个长度一样的字符串
s
1
s1
s1 和
s
2
s2
s2,求将
s
1
s1
s1转换成
s
2
s2
s2的最小步数,每次可以选择一段子串转换成其他字符。
Solution:
先将问题分解为两个:1. 将空串变成
s
2
s2
s2,2. 再将
s
1
s1
s1转换成
s
2
s2
s2。
对于一:很显然的区间dp
定义:
f
[
i
]
[
j
]
f[i][j]
f[i][j]表示空船变成区间
[
i
,
j
]
[i, j]
[i,j]的最小操作数
状态转移:
f
[
i
]
[
j
]
=
m
i
n
(
f
[
i
]
[
j
−
1
]
,
f
[
i
+
1
]
[
j
]
)
s
[
i
]
=
s
[
j
]
f
[
i
]
[
j
]
=
min
k
=
i
j
−
1
(
f
[
i
]
[
j
]
,
f
[
i
]
[
k
]
+
f
[
k
+
1
]
[
j
]
s
[
i
]
≠
s
[
j
]
f[i][j] = min(f[i][j - 1], f[i + 1][j]) \quad s[i] = s[j] \\ \qquad \qquad f[i][j] = \min_{k=i}^{j - 1}(f[i][j], f[i][k] + f[k + 1][j] \quad s[i] \neq s[j]
f[i][j]=min(f[i][j−1],f[i+1][j])s[i]=s[j]f[i][j]=mink=ij−1(f[i][j],f[i][k]+f[k+1][j]s[i]=s[j]
对于二:普通dp
定义:
f
[
i
]
f[i]
f[i]表示区间
[
0
,
i
]
[0, i]
[0,i]将
s
1
s1
s1转换成
s
2
s2
s2的最小操作
转移:
f
[
i
]
=
min
k
=
0
i
(
f
[
i
]
,
f
[
k
]
+
d
p
[
k
+
1
]
[
j
]
f[i] = \min_{k=0}^i(f[i], f[k] + dp[k + 1][j]
f[i]=mink=0i(f[i],f[k]+dp[k+1][j]
Code:
#include <bits/stdc++.h>
#define ALL(a) (a).begin(), (a).end()
using namespace std;
using LL = long long;
typedef pair<int, int> PII;
template < typename T> inline void Max(T &a, T b) { if(a < b) a = b; }
template < typename T> inline void Min(T &a, T b) { if(a > b) a = b; }
const int N = 101;
int dp[N][N], f[N];
char s1[N], s2[N];
int main() {
cin.tie(nullptr) -> sync_with_stdio(false);
while (cin >> s1 >> s2) {
memset(dp, 0x3f, sizeof dp);
int n = strlen(s1);
for (int i = 0; i < n; i ++) dp[i][i] = 1;
for (int len = 1; len <= n; len ++) {
for (int l = 0; l + len - 1 < n; l ++) {
int r = l + len - 1;
if(len == 1) dp[l][r] = 1;
else {
if(s2[l] == s2[r]) dp[l][r] = min(dp[l][r - 1], dp[l + 1][r]);
for (int k = l; k < r; k ++)
dp[l][r] = min(dp[l][r], dp[l][k] + dp[k + 1][r]);
}
}
}
for (int i = 0; i < n; i ++) {
f[i] = dp[0][i];
if(s1[i] == s2[i]) {
if(i) f[i] = f[i - 1];
else f[i] = 0;
}
for (int k = 0; k < i; k ++) {
f[i] = min(f[i], f[k] + dp[k + 1][i]);
}
}
cout << f[n - 1] << "\n";
}
return (0-0);
}