题目链接:
https://ac.nowcoder.com/acm/contest/5667/G
思路与官方题解相同,只是描述方式不同。
设 d p i , j dp_{i,j} dpi,j表示子序列 [ A i , A i + m − j ] [A_i,A_{i+m-j}] [Ai,Ai+m−j]与子序列 [ B j , B m ] [B_j,B_m] [Bj,Bm]是否匹配,即 ∀ k ∈ [ 0 , m − j ] , A i + k ≥ B j + k \forall k \in [0,m-j],A_{i+k}\ge B_{j+k} ∀k∈[0,m−j],Ai+k≥Bj+k匹配则为 1 1 1,不匹配则为 0 0 0。可以想到答案就是 j = 1 j=1 j=1时子序列 [ A i , A i + m − 1 ] [A_i,A_{i+m-1}] [Ai,Ai+m−1]与子序列 [ B 1 , B m ] [B_1,B_m] [B1,Bm]匹配的总数,即 a n s = ∑ i = 1 n − m + 1 d p i , 1 ans=\sum_{i=1}^{n-m+1}{dp_{i,1}} ans=i=1∑n−m+1dpi,1
考虑 d p dp dp的转移方程 d p i , j = d p i + 1 , j + 1 & ( A i ≥ B j ) dp_{i,j}=dp_{i+1,j+1}\&(A_i\ge B_j) dpi,j=dpi+1,j+1&(Ai≥Bj)即若想子序列 [ A i , A i + m − j ] [A_i,A_{i+m-j}] [Ai,Ai+m−j]与子序列 [ B j , B m ] [B_j,B_m] [Bj,Bm]匹配,则需要子序列 [ A i + 1 , A i + m − j ] [A_{i+1},A_{i+m-j}] [Ai+1,Ai+m−j]与子序列 [ B j + 1 , B m ] [B_{j+1},B_m] [Bj+1,Bm]匹配,且 A i ≥ B j A_i\ge B_j Ai≥Bj
可以想到,这个转移的计算复杂度是 O ( n m ) O(nm) O(nm),约 6 e 9 6e9 6e9,超时,也超空间。
由于所有 d p dp dp值都是0或1,可以采用bitset优化,这样复杂度就变为 O ( n m w ) O(\frac{nm}{w}) O(wnm), w = 32 或 64 w=32或64 w=32或64,处于可以接受的范围。
设
b
i
t
d
p
i
[
j
]
=
d
p
i
,
j
bitdp_i[j]=dp_{i,j}
bitdpi[j]=dpi,j,这样有关
j
j
j的操作,都可以用bitset优化。可以用右移将
j
j
j向前移动一位,使得
b
i
t
d
p
i
+
1
[
j
+
1
]
bitdp_{i+1}[j+1]
bitdpi+1[j+1]转移到
b
i
t
d
p
i
[
j
]
bitdp_i[j]
bitdpi[j]。对于
A
i
≥
B
j
A_i\ge B_j
Ai≥Bj的判断,可以预处理一个长度为
m
m
m的bitset表
S
i
S_i
Si,
S
i
[
j
]
=
(
A
i
≥
B
j
)
S_i[j]=(A_i\ge B_j)
Si[j]=(Ai≥Bj),这样,转移式就优化为
b
i
t
d
p
i
=
(
b
i
t
d
p
i
+
1
>
>
1
∣
(
1
<
<
m
)
)
&
S
i
bitdp_i=(bitdp_{i+1}>>1|(1<<m))\&S_i
bitdpi=(bitdpi+1>>1∣(1<<m))&Si
计算
S
i
S_i
Si时,把
m
m
m个
B
j
B_j
Bj排序后,那么每个
A
i
A_i
Ai只会插在这些
B
j
B_j
Bj中间或两边,因此最多只有
m
+
1
m+1
m+1个
S
i
S_i
Si(从全是
0
0
0的情况,到全是
1
1
1的情况)。第
j
j
j种
S
i
S_i
Si,就是在前一种
S
i
S_i
Si的基础上在第
j
j
j大的
B
j
B_j
Bj位置上从
0
0
0变为
1
1
1。这样就可以在
O
(
m
2
w
)
O(\frac{m^2}{w})
O(wm2)的复杂度内预处理完
S
i
S_i
Si。
代码如下(不是我写的):
#include<bits/stdc++.h>
using namespace std;
const int MAXN=2e5+10,MAXM=4e4+10;
int n,m,ans,a[MAXN],b[MAXM],pos[MAXN];
bitset<MAXM> s[MAXM],cur,w;
bool cmp(int x,int y){return b[x]<b[y];}
int main()
{
scanf("%d%d",&n,&m);w[m]=1;
for(int i=1;i<=n;i++) scanf("%d",a+i);
for(int i=1;i<=m;i++) scanf("%d",b+i),pos[i]=i;
sort(pos+1,pos+1+m,cmp);
sort(b+1,b+1+m);
for(int i=1;i<=m;i++) s[i]=s[i-1],s[i][pos[i]]=1;
for(int i=n;i>=1;i--)
{
int d=upper_bound(b+1,b+1+m,a[i])-b-1;
cur=(((cur>>1)|w)&s[d]);
if(cur[1]) ans++;
}
printf("%d\n",ans);
}