传送门:HDU-6078
题意:有2个序列A和B,要从A,B中选子序列出来组成“小-大-小”这样的序列,且A,B对应的位置要相等,问有多少种选取方法
题解:dp+树状数组
设f[x][y][k]为当前A数组枚举到第x个,B数组枚举到第y个,起伏状态为k(0/1)时的方案数,考虑到用普通的dp转移会达到O(n^4),可以用二维树状数组进行维护,由于第一维具有递增的特性,因此只要维护第二维的下标和值,设C[i][j]为下标小于等于i,值小于等于j的前缀和
f[x][y][0]=t[1].sum(y-1,b[y]-1)
f[x][y][1]=t[0].sum(y-1,mx)-t[0].sum(y-1,b[y])
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<stdlib.h>
#include<math.h>
#include<string.h>
#include<set>
#include<vector>
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define first x
#define second y
#define eps 1e-5
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int inf = 0x3f3f3f3f;
const int MX = 2e3 + 5;
const LL mod = 998244353;
int a[MX], b[MX], vis[MX], n, m;
struct BTree {
LL A[MX][MX];
int mx;
void init() {
mx = 0;
memset(A, 0, sizeof(A));
}
inline int lowbit(int x) {
return x & (-x);
}
void add(int x, int y, int d) {
for (int i = x; i <= mx; i += lowbit(i))
for (int j = y; j <= mx; j += lowbit(j))
A[i][j] += d;
}
LL sum(int x, int y) {
LL ret = 0;
for (int i = x; i; i -= lowbit(i))
for (int j = y; j; j -= lowbit(j))
ret = (ret + A[i][j]) % mod;
return ret;
}
} t[2];
void pre_solve() {
int sz = 0;
memset(vis, 0, sizeof(vis));
for (int i = 1; i <= n; i++) vis[a[i]] = 1;
for (int i = 1; i <= m; i++) if (vis[b[i]]) b[++sz] = b[i];
m = sz;
sz = 0;
memset(vis, 0, sizeof(vis));
for (int i = 1; i <= m; i++) vis[b[i]] = 1;
for (int i = 1; i <= n; i++) if (vis[a[i]]) a[++sz] = a[i];
n = sz;
t[0].init();
t[1].init();
t[0].mx = t[1].mx = m + 1;
for (int i = 1; i <= n; i++) t[0].mx = max(t[0].mx, a[i] + 1);
t[1].mx = t[0].mx;
for (int i = n; i > 0; i--) a[i + 1] = a[i];
for (int i = m; i > 0; i--) b[i + 1] = b[i];
n++; m++;
}
LL f[MX][MX][2];
int main() {
//freopen("in.txt", "r", stdin);
int T;
scanf("%d", &T);
while (T--) {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i <= m; i++) scanf("%d", &b[i]);
pre_solve();
t[1].add(1, t[1].mx, 1);
LL ans = 0;
for (int i = 2; i <= n; i++) {
for (int j = 2; j <= m; j++) {
if (a[i] != b[j]) continue;
f[i][j][0] = (t[1].sum(j - 1, t[1].mx) - t[1].sum(j - 1, b[j]) + mod) % mod;
f[i][j][1] = t[0].sum(j - 1, b[j] - 1);
ans = (ans + f[i][j][0] + f[i][j][1]) % mod;
t[0].add(j, b[j], f[i][j][0]);
t[1].add(j, b[j], f[i][j][1]);
}
}
printf("%lld\n", ans);
}
return 0;
}