传送门:HDU6078
题意:给出两个序列A和B,让你找出两组等长下标序列f和g,使得对于每个i,Afi == Bgi ,并且Afi序列为波浪序列。 问能找出多少种这样的下标序列。
波浪序列定义: a1<a2>a3<a4>a5<a6
思路:先贴上官方题解:
设fi,j,k表示仅考虑a[1..i]与b[1..j],选择的两个子序列结尾分别是ai和bj,且上升下降状态是k 时的方案数,则fi,j,k=∑fx,y,1−k,其中x<i,y<j。暴力转移的时间复杂度为O(n4),不能接受。
考虑将枚举决策点x,y的过程也DP掉。设gi,y,k表示从某个fx,y,k作为决策点出发,当前要更新的是i的方案数,hi,j,k表示从某个fx,y,k作为决策点出发,已经经历了g的枚举,当前要更新的是j的方案数。转移则是要么开始更新,要么将i或者j继续枚举到i+1以及j+1。因为每次只有一个变量在动,因此另一个变量恰好可以表示上一个位置的值,可以很方便地判断是否满足上升和下降。
官方题解只看懂了暴力之前的部分。。于是百度了一发dalao们的做法,发现还是挺简洁易懂的:dp[i][j][0]表示以a[i]和b[j]为公共序列结尾且为波谷的情况总和。
dp[i][j][1]则表示波峰的情况总和。
sum[i][j][0]表示∑(dp[k][j][0] | 1<=k<=j-1)。 //目前以b[j]为波谷结尾的‘总’匹配数
sum[i][j][1]则表示∑(dp[k][j][1] | 1<=k<=j-1)。 //目前以b[j]为波峰结尾的‘总’匹配数
那么对于每个a[i],只有存在j使得b[j]==a[i]时,
dp[i][j][0]等于∑(sum[i-1][k][1] | 1<=k<=j-1&&b[k]>a[i])+1,//代码中用cnt1动态的求这部分和
dp[i][j][1]等于∑(sum[i-1][k][0] | 1<=k<=j-1&&b[k]<=a[i]-1). //代码中用cnt0动态的求这部分和
以上转自:点击打开链接
总的来说就是用一个类似前缀和的sum数组优化掉了一个n的复杂度(内层枚举1...i的过程),又用两个变量动态求sum数组的和又优化掉了一个n的复杂度(内层枚举1...j的过程),让总体复杂度从n^4变成了n^2.
因为总的dp过程只与i和i-1相关,因此所有数组都可以优化掉一维。
代码:
#include<bits/stdc++.h>
#define ll long long
#define pb push_back
#define fi first
#define se second
#define pi acos(-1)
#define inf 0x3f3f3f3f
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define rep(i,x,n) for(int i=x;i<n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
using namespace std;
typedef pair<int,int>P;
const int MAXN = 2010;
const int mod = 998244353;
int gcd(int a,int b){return b?gcd(b,a%b):a;}
int a[MAXN], b[MAXN];
int dp[MAXN][2];
int sum[MAXN][2];
int main()
{
int T;
cin >> T;
while(T--)
{
int n, m;
cin >> n >> m;
for(int i = 1; i <= n; i++)
scanf("%d", a + i);
for(int i = 1; i <= m; i++)
scanf("%d", b + i);
memset(dp, 0, sizeof(dp));
memset(sum, 0, sizeof(sum));
int ans = 0;
for(int i = 1; i <= n; i++)
{
int cnt0 = 0, cnt1 = 1;//第一个数字只能为波谷,说明第‘0’个数字为波峰
for(int j = 1; j <= m; j++)
{
dp[j][0] = dp[j][1] = 0;
if(a[i] == b[j])
{
dp[j][0] = cnt1;
dp[j][1] = cnt0;
}
else if(a[i] > b[j])
cnt0 = (cnt0 + sum[j][0]) % mod;
else
cnt1 = (cnt1 + sum[j][1]) % mod;
sum[j][0] = (sum[j][0] + dp[j][0]) % mod;
sum[j][1] = (sum[j][1] + dp[j][1]) % mod;
}
}
for(int j = 1; j <= m; j++)
{
ans = (ans + sum[j][0]) % mod;
ans = (ans + sum[j][1]) % mod;
}
cout << ans << endl;
}
return 0;
}