这个算法是对O(N^3)进行了一个优化,如果O(N^3)你不会,OK,点这里
思路:
我们的转移方程与O(N^3)一模一样,但是还有可以优化的空间
先把O(N^3)的方程列出来
for (int i = 1; i <= n; i++){
for (int j = 1; j <= n; j++){
if(a[i] == b[j]){
for (int k = 0; k < j; k++){
/*
这个位置不从 k = 1开始的原因是
k = 0的情况就是前面没有构成的情况,不能省略
*/
if(b[k] < b[j])dp[i][j] = max(dp[i][j], dp[i - 1][k] + 1);
}
}
else dp[i][j] = dp[i - 1][j];
}
}
观察发现,当内层循环 j 的时候,外层的 i 没有发生变化,但是遇到a[ i ] == b[ j ]时,需要枚举 j 前面最大的dp[ i - 1] [ k ],因此可以用一个变量临时的记录前面的最大的dp[ i - 1] [ j ],下次用到的时候直接计算,计算过程中一边更新就可以省略掉一个循环(即使下面需要用到b[ k ] < b[ j ]但是不要忘了,a[ i ] == b[ j ])
定义变量 val 初始为0,如果一开始b[ 0 ]就满足b [ 0 ] < a[ i ],val 的初始值就应该是dp[ i - 1][ 0 ](前i - 1个以0结尾的最大长度)
遇到a[ i ] == b[ j ]时,判断当前b[ j ]是否大于 val,如果大于,dp[ i ] [ j ] = val + 1,否则dp[ i ] [ j ] = dp[ i - 1] [ j ];
val 其实是 j 前面小于 b[ j ](即a[ i ])最大的一个,如果满足条件 b[ j ] < a[ i ],则更新 val = max(val, dp[i - 1[ j ]);
状态转移代码:
for (int i = 1; i <= n; i++){
int val = 0;
if(b[0] < a[i])val = dp[i - 1][0];
for (int j = 1; j <= n; j++){
if(a[i] == b[j])dp[i][j] = val + 1;
else dp[i][j] = dp[i - 1][j];
if(b[j] < a[i])val = max(val, dp[i - 1][j]);
}
}
完整代码
#include <bits/stdc++.h>
using namespace std;
const int N = 3010;
int a[N], b[N], dp[N][N], n;
template <class T>
bool read(T & a){
a = 0;
int flag = 0;
char ch;
if((ch = getchar()) == '-'){
flag = 1;
}
else if(ch >= '0' && ch <= '9'){
a = a * 10 + ch - '0';
}
while ((ch = getchar()) >= '0' && ch <= '9'){
a = a * 10 + ch - '0';
}
if(flag)a = -a;
return true;
}
template <class T, class ... R>
bool read(T & a, R & ... b){
if(!read(a))return 0;
read(b...);
}
template <class T>
bool out(T a){
if(a < 0)putchar('-');
if(a >= 10)out(a / 10);
putchar(a % 10 + '0');
return true;
}
template <class T, class ... R>
bool out(T a, R ... b){
if(!out(a))return 0;
out(b...);
}
int main()
{
read(n);
for (int i = 1; i <= n; i++)read(a[i]);
for (int i = 1; i <= n; i++)read(b[i]);
for (int i = 1; i <= n; i++){
int val = 0;
if(b[0] < a[i])val = dp[i - 1][0];
for (int j = 1; j <= n; j++){
if(a[i] == b[j])dp[i][j] = val + 1;
else dp[i][j] = dp[i - 1][j];
if(b[j] < a[i])val = max(val, dp[i - 1][j]);
}
}
int res = 0;
for (int i = 1; i <= n; i++)res = max(res, dp[n][i]);
printf("%d\n", res);
return 0;
}