题意:给定一个长度为n,m的只包含’0’~'9’的字符串s[n],t[m],问s[n]中有多少个子序列的大小(不含前导0)比t[m]大。
本题要分类讨论,无分类讨论的话,会复杂很多。
1.s[n]中子序列长度大于m的情况,此时可以枚举开头位置和 子序列长度,套用组合公式计算;
2.s[n]中子序列长度等于m,且大于t[m]的情况,可以用DP解决。
如果能想到分类讨论,这道题就简单很多了,不过DP的时候还需注意一下状态的定义。
状态:dp[i,j]表示s[i ~ n]中长度为j的且大于t[j ~ m]的子序列的个数。
状态转移:i ← n ~ 1, j ← m ~ 1
①j == m: dp[i,m]=dp[i+1,m]+(s[i]>t[m]);
②j ← m-1 ~ 1:
dp[i,j]=dp[i+1,j];
if(s[i]>t[j]) dp[i,j]+=c[n-i,m-j];
else if(s[i] == t[j]) dp[i,j]+=dp[i+1,j+1].
DP循环是从后面开始的,一开始我是从前面开始,但是WA了,看了几个别人的AC代码+标程后,再改从后面开始。其实本题涉及到组合数的情形,从后面开始也会比较简单一些。
代码如下:
#include<cstdio>
#include<iostream>
using namespace std;
typedef long long ll;
const int maxn=3e3+6;
const int mod=998244353;
ll dp[maxn][maxn],c[maxn][maxn];
char s[maxn],t[maxn];
int main()
{
int T;
scanf("%d",&T);
for(int i=0;i<=3000;++i)
c[i][0]=c[i][i]=1;
for(int i=2;i<=3000;++i)
for(int j=1;j<i;++j)
c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
while(T--)
{
int n,m;
scanf("%d%d",&n,&m);
scanf("%s%s",s+1,t+1);
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j)
dp[i][j]=0;
dp[n][m]= t[m]<s[n]? 1:0;
for(int i=n-1;i>=1;--i)
dp[i][m]=dp[i+1][m]+(s[i]>t[m]);
for(int j=m-1;j>=1;--j)
{
for(int i=n-m+j;i>=1;--i)
{
dp[i][j]=dp[i+1][j]%mod;
if(s[i]>t[j])
dp[i][j]=(dp[i][j]+c[n-i][m-j])%mod;
else if(s[i]==t[j])
dp[i][j]=(dp[i][j]+dp[i+1][j+1])%mod;
}
}
ll ans=0;
/*for(int j=m;j>=1;j--)
{
for(int i=n;i>=1;i--)
cout<<dp[i][j]<<' ';cout<<endl;
}for(int j=0;j<m;j++)
{
for(int i=0;i<n;i++)
cout<<dp[i][j].b<<' ';cout<<endl;
}*/
for(int i=1;i<=n-m;++i)
{
if(s[i]=='0') continue;
for(int j=m;j<=n-i;++j)
ans=(ans+c[n-i][j])%mod;
}//cout<<ans<<' '<<dp[1][1]<<endl;
printf("%lld\n",(ans+dp[1][1])%mod);
}
return 0;
}