[题解] Double Strings 动态规划+组合数学
题目链接
题意大概就是:
有两个字符串
A
,
B
A,B
A,B,要从中选出两个长度相同的子序列,满足:
这两个子序列前面
k
(
k
>
=
1
)
k(k>=1)
k(k>=1)个字符相同,并且第
k
+
1
k+1
k+1个字符满足
a
k
+
1
<
b
k
+
1
a_{k+1}<b_{k+1}
ak+1<bk+1,后面的字符任意,求所有这样的子序列个数。
首先对于前面
k
k
k个字符相同,可以用
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j]表示
A
A
A的前
i
i
i个、
B
B
B的前
j
j
j个字符的相同子序列个数。考虑最后两个字符,如果
a
i
≠
b
j
a_i\ne b_j
ai=bj,那么根据容斥原理,所有子序列中有不以
a
i
a_i
ai结尾的(即
d
p
[
i
−
1
]
[
j
]
dp[i-1][j]
dp[i−1][j])、不以
b
j
b_j
bj结尾的(即
d
p
[
i
]
[
j
−
1
]
dp[i][j-1]
dp[i][j−1]),以及不以
a
i
,
b
j
a_i,b_j
ai,bj结尾的(即
d
p
[
i
−
1
]
[
j
−
1
]
dp[i-1][j-1]
dp[i−1][j−1])
可以得到:
d
p
[
i
]
[
j
]
=
d
p
[
i
−
1
]
[
j
]
+
d
p
[
i
]
[
j
−
1
]
−
d
p
[
i
−
1
]
[
j
−
1
]
dp[i][j] = dp[i-1][j]+dp[i][j-1]-dp[i-1][j-1]
dp[i][j]=dp[i−1][j]+dp[i][j−1]−dp[i−1][j−1]
对于相等的,那么多出一种情况,此时要加上
d
p
[
i
−
1
]
[
j
−
1
]
dp[i-1][j-1]
dp[i−1][j−1]。
其次是对于后面的任意子序列,应该怎么算?
设去掉
a
k
+
1
,
b
k
+
1
a_{k+1},b_{k+1}
ak+1,bk+1后,
A
A
A还剩
n
n
n个字符,
B
B
B还剩
m
m
m个字符,(不妨设
n
≤
m
n\le m
n≤m),那么总方案数为:
∑
i
=
0
n
(
n
i
)
⋅
(
m
i
)
=
∑
i
=
0
n
(
n
n
−
i
)
⋅
(
m
i
)
=
(
n
+
m
n
)
\sum_{i=0}^{n} \binom{n}{i}\cdot \binom{m}{i} =\sum_{i=0}^{n} \binom{n}{n-i}\cdot \binom{m}{i}=\binom{n+m}{n}
i=0∑n(in)⋅(im)=i=0∑n(n−in)⋅(im)=(nn+m)
证明非常显然。
由此我们可以得到一个推论: ∑ i = 0 n ( n i ) 2 = ( 2 n n ) \sum_{i=0}^{n}\binom{n}{i}^2=\binom{2n}{n} i=0∑n(in)2=(n2n)
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 5e3 + 10;
const ll mod = 1e9 + 7;
ll dp[maxn][maxn];
char a[maxn],b[maxn];
int l1, l2;
ll fac[2*maxn],invfac[2*maxn],inv[2*maxn];
void init(){
fac[0] = 1;
for(int i = 1; i < 2*maxn; i++) fac[i] = i*fac[i-1]%mod;
inv[1] = 1;
invfac[0] = invfac[1] = 1;
for(int i = 2; i < 2*maxn; i++){
inv[i] = (mod - mod/i)*inv[mod%i]%mod;
invfac[i] = invfac[i-1]*inv[i]%mod;
}
}
inline ll C(ll a, ll b){
return ((fac[a]*invfac[b])%mod)*invfac[a-b]%mod;
}
int main()
{
init();
scanf("%s%s",a+1,b+1);
l1 = strlen(a+1);
l2 = strlen(b+1);
for(int i = 0; i <= l1; i++) dp[i][0] = 1;
for(int i = 0; i <= l2; i++) dp[0][i] = 1;
ll ans = 0;
for(int i = 1; i <= l1; i++){
for(int j = 1; j <= l2; j++){
dp[i][j] = (dp[i-1][j] + dp[i][j-1] - dp[i-1][j-1])%mod;
if(a[i] == b[j]) dp[i][j] = (dp[i][j] + dp[i-1][j-1])%mod;
if(a[i] < b[j]) ans = (ans + dp[i-1][j-1]*C(l1-i+l2-j,l1-i)%mod)%mod;
}
}
printf("%lld\n",(ans%mod+mod)%mod);
return 0;
}