写在前边:如果大佬发现了我姿势有问题,请帮我提出建议,十分感谢。
题意:给出一个有n(<=2000)个数字的序列 a(ai <=2000) 再给出一个有m(m<=2000)个数字的序列 b (bi<=2000) ,定义波浪序列为:x1<x2>x3<x4……(注意第一次必须是上升,不能是下降,也就是说第一项必须是波谷)。现在要求找到一个严格单调递增的序列 f:f1,f2,……fk。以及相对应的严格单调递增的序列g:g1,g2,……gk。(k>=1)使得每个a_fi = b_gi,同时满足a_f1,a_f2,a_f3……a_fk为波浪序列。求不同的fg映射有多少种选取方式。
题解:稍微翻译一下题意:a,b中分别从前向后选取k个数字。然后相对应的 a 中选择的每个位置的数字要和 b 中选择的对应位次的数字相同。(当然如果a数组出现过x,而b没有出现过x,显然x不可能被选取),而 f 、g 则是相对应的下标。要满足选取出来的这个数字序列是一个波浪序列。显然波浪序列中的数字分成两种:波峰和波谷。
总体来说,这个题就是a、b数组之间的匹配问题,同时满足是一个波浪序列。
显然的,我们构造一个三维的dp数组:dp[ i ][ j ][ k ](1<=i<=n , 1<=j<=m , k = 0 or 1),dp[ i ][ j ][ 0 ]表示让 ai 和 bj 匹配,并且这个数字做为波浪序列的最后一项,且为波谷,得到的方案总数,对应的dp[ i ][ j ][ 1 ]表示做波峰的那种情况。
转移方程:dp[ i ][ j ][ 0 ] = ∑dp[ k ][ l ][ 1 ](1<=k<i , 1<=l<j , a_k>ai)。意思是如果让ai做波谷,那么枚举前面那个波峰的可能位置,在那个波峰后边加上ai就可以成为一个新的波浪序列。所以需要二维枚举前一个波峰的位置。同时,任何一个匹配成功的位置 都可以作为第一项,也就是一个波谷。
对应的:dp[ i ][ j ][ 1 ] = ∑dp[ k ][ l ][ 0 ](1<=k<i , 1<=l<j , a_k<a_i),就是枚举前面一个波谷的位置。同时注意到第一项不可能是波峰,所以这个转移方程是唯一的,不需要做上边的讨论。
这是一个n^4 的复杂度。接下来做一些优化:
显然,必须有a_i=b_j 的时候,才能计算dp[ i ][ j ] 否则是没有意义的。而数字最大是2000。于是可以简单的进行如下操作:
vector< int > nums [2001];
nums[ a[ i ] ].push_back( i )。
这样我们枚举b_j 的时候,直接找到nums[ b_j ]就知道需要计算哪些a_i了。
这可以让最外层n^2的枚举变成m+n,因为每个a,b都只被遍历一次。
这个优化不疼不痒,我们继续优化。
我们考虑同类的a,也就是a_i = a_(i-x)这样a序列中值相等的ai。
在计算 dp[ i-x ][ j ][ 0 ] 以及dp[ i ][ j ][ 0 ]的时候,根据上边的方程,我们要分别计算
∑dp[ k ][ l ][ 1 ] (1<=k<i-x , 1<=l<j , a_k<a_i-x)。
∑dp[ k ][ l ][ 1 ] (1<=k<i , 1<=l<j , a_k<a_i)。
那么显然我们第二次计算的时候,其实把第一次的东西又计算了一遍,于是我们得到:dp[ i ][ j ][ 0 ] = dp[ pre[ i] ][ j ][ 0 ] +∑dp[ k ][ l ][ 1 ] (pre[ i ]<=k<i , 1<=l<j , a_k<a_i)。
那么我们这下子只需要计算∑dp[ k ][ l ][ 1 ] (pre[ i ]<=k<i , 1<=l<j , a_k<a_i)。我们让第二维做外层的循环,那么这个∑就是之前得到的所有的dp[ X ][ 1..j-1][ 1 ]在PRE[ I ]<=X<I这段区间上,满足a_X<a_i条件的和。就是一个带条件的区间查询。显然可以想到用线段树做优化。
再考虑一下+1的问题,我们前边说过,波谷可以作为第一项。我们观察pre[ i ],如果pre[ i ] = 0说明这个ai是第一次出现,那么可以让他做波谷 于是在dp[ 0 ]上+1,而如果pre[ i ]!=0,那么我们上边的式子表明 dp[ i ][ 0 ] = dp[ pre[ i ] ][ 0 ] +delta;而dp[ pre [ i ] ][ 0 ]包含了一种方案是pre[ i ]位置的a做第一项,而这种方案对于dp[ i ][ 0 ]是非法的,那么这个1可以看作是让i位置做第一项,所以此时就不给dp[ i ][ 0 ]加一了。
而dp[ i ][ j ][ 1 ]的方程完全类似,就不写了。
tree1是在sumdp1上建立的线段树,sumdp1[ i ] 是所有计算过的dp[ i ][ 1..j-1 ][ 1 ]的和。
tree0同理。
算法流程:
for(j : 1……m)
for(i:nums[ b[ j ] ])
dp[ i ][ j ][ 0 ] = dp[ pre[ i ] ][ j ][ 0 ] + tree1.getGreaterSum(pre[ i ]+1,i-1,a[ i ]) (对tree1 的pre[ i ]+1到i-1的区间中满足a[ x ]比a[ i ]大的位置求和)。
if(pre[ i ] ==0)dp[ i ][ j ][ 0 ]++;
dp[ i ][ j ][ 1 ] = dp[ pre[ i ] ][ j ][ 1 ] + tree0.getLessSum(pre[ i ]+1,i-1,a[ i ]) (对tree0的pre[ i ]+1搭配i-1的区间中满足a[ x ]比a[ i ]小的位置求和)。
endFor
for(i:nums[ b[ j ] ])
tree1.add(i,dp[ i ][ j ][ 1 ])
tree0.add(i,dp[ i ][ j ][ 0 ])
endFor
endFor
注意内存必须要用两个for循环吗,因为在dp到 j阶段的时候,用到的必须是1……j-1的和,所以必须 j 所有状态都计算完成,才能去更新tree。
剩下的线段树操作不赘述。不懂可以看代码。还不懂就只能复习一下线段树姿势了。
复杂度估计:nlogn(build)+(m+n)logn。总的来说是O(Knlogn) K可能有8-10左右。
Code:
#include<bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
const int MAX = 2005;
const long long SKY = 998244353;
long long dp[MAX][MAX][2];
int la [MAX];
int pre[MAX];
int a[MAX],b[MAX];
int m,n,t;
vector<int> nums[MAX];
inline int read(){
char ch = getchar();
int re = 0;
while (ch>='0'&&ch<='9'){
re = re*10+ch-'0';
ch = getchar();
}
return re;
}
struct Seg_Tree{
int value[MAX<<2],lazy[MAX<<2],biggest[MAX<<2],smallest[MAX<<2];
void build(int x,int l,int r){
if (l==r){
biggest[x] =smallest[x]= a[l];
return;
}else{
int Mid = (l+r)>>1;
build (x<<1,l,Mid);
build (x<<1|1,Mid+1,r);
smallest[x] = min(smallest[x<<1],smallest[x<<1|1]);
biggest[x] = max(biggest[x<<1],biggest[x<<1|1]);
}
}
void clear(){
memset(value,0,sizeof(value));
memset(lazy,0,sizeof(lazy));
memset(biggest,0,sizeof(biggest));
memset(smallest,INF,sizeof(smallest));
}
void Down(int x){
int lc = x<<1,rc = x<<1|1;
value[lc]+=lazy[x];value[lc]%=SKY;
value[rc]+=lazy[x];value[rc]%=SKY;
lazy[lc]+=lazy[x];lazy[lc]%=SKY;
lazy[rc]+=lazy[x];lazy[rc]%=SKY;
lazy[x]= 0;
}
void Up(int x){
value[x]=(value[x<<1]+value[x<<1|1])%SKY;
}
void Update1(int x,int l,int r,int index,int delta){
// cout<<"Update1"<<x<<" "<<l<<" "<<r<<" "<<index<<" "<<delta<<" Value:"<<value[x]<<endl;
if (l==r){
value[x]+=delta;
value[x]%=SKY;
return;
}
if (lazy[x]){
Down(x);
}
int Mid = (l+r)>>1;
if (index<=Mid){
Update1(x<<1,l,Mid,index,delta);
}else{
Update1(x<<1|1,Mid+1,r,index,delta);
}
Up(x);
}
int getLess(int x,int l,int r,int L,int R,int limit){
// cout<<"getLessd"<<x<<" "<<l<<" "<<r<<" "<<L<<" "<<R<<" "<<limit<<" Value:"<<value[x]<<" Big:"<<biggest[x]<<" Small:"<<smallest[x]<<endl;
if (l>R||r<L){
return 0;
}else{
if (L<=l&&r<=R){
if (biggest[x]<limit){
return value[x];
}else if (smallest[x]>=limit){
return 0;
}else{
int Mid = (l+r)>>1;
return (getLess(x<<1,l,Mid,L,R,limit)+getLess(x<<1|1,Mid+1,r,L,R,limit))%SKY;
}
}else{
int Mid = (l+r)>>1;
return (getLess(x<<1,l,Mid,L,R,limit)+getLess(x<<1|1,Mid+1,r,L,R,limit))%SKY;
}
}
}
int getLessSum(int l,int r,int limit){
if (l>r){
return 0;
}
return getLess(1,0,n,l,r,limit);
}
int getGreater(int x,int l,int r,int L,int R,int limit){
if (l>R||r<L){
return 0;
}else{
if (L<=l&&r<=R){
if (smallest[x]>limit){
return value[x];
}else if (biggest[x]<=limit){
return 0;
}else{
int Mid = (l+r)>>1;
return (getGreater(x<<1,l,Mid,L,R,limit)+getGreater(x<<1|1,Mid+1,r,L,R,limit))%SKY;
}
}else{
int Mid = (l+r)>>1;
return (getGreater(x<<1,l,Mid,L,R,limit)+getGreater(x<<1|1,Mid+1,r,L,R,limit))%SKY;
}
}
}
int getGreaterSum(int l,int r,int limit){
if (l>r){
return 0;
}
return getGreater(1,0,n,l,r,limit);
}
}tree0,tree1;
void input(){
n = read();m = read();
for (int i = 1;i<=n;i++){
a[i] = read();
pre[i] = la[a[i]];
la[a[i]] = i;
nums[a[i]].push_back(i);
}
for (int i =1;i<=m;i++){
b[i] = read();
}
tree0.build(1,0,n);
tree1.build(1,0,n);
}
void init(){
memset(la,0,sizeof(la));
tree0.clear();
tree1.clear();
for (int i = 0;i<=2000;i++){
nums[i].clear();
}
}
void work(){
long long ans = 0;
for (int j = 1;j<=m;j++){
for (vector<int>::iterator it = nums[b[j]].begin();it!=nums[b[j]].end();it++){
int i = *it;
dp[i][j][1] = (dp[pre[i]][j][1]+tree0.getLessSum(pre[i]+1,i-1,a[i]))%SKY;
if (pre[i]!=0)
dp[i][j][0]= (dp[pre[i]][j][0]+tree1.getGreaterSum(pre[i]+1,i-1,a[i]))%SKY;
else
dp[i][j][0] = tree1.getGreaterSum(pre[i]+1,i-1,a[i])+1;
}
for (vector<int>::iterator it = nums[b[j]].begin();it!=nums[b[j]].end();it++){
int i = *it;
tree1.Update1(1,0,n,i,dp[i][j][1]);
tree0.Update1(1,0,n,i,dp[i][j][0]);
}
}
ans = tree1.getLessSum(1,n,10000)+tree0.getLessSum(1,n,10000);
ans%=SKY;
printf("%I64d\n",ans);
}
int main(){
t = read();
while (t--){
init();
input();
work();
}
}