题意:给一个字符串s,一个字符串t,求s全部字符都选的不同组合中字典序大于s小于t的字符串有多少种。
首先,我们可以将问题转化为小于t的方案数减小于s的方案数,那么我们模仿数位dp,可以得到dfs(pos)表示考虑当前字符串到pos位的方案数,设当前字符串为s,那么有两种转移,
1.填一个比s[pos]小的字符,那么后面填任何字符都随意了,那么答案就是比s[pos]小的字符个数*后面剩下位置的长度的阶乘,
2.填一个与s[pos]一样大的字符,那么后面填什么仍然与s的下一位是否填的一样大有关,那么递归实现即可。
最后由于求的是不同组合的个数,对于同样的字符i共有cnt[i]的阶乘种不同方案,所以最后要将答案除以cnt[i]的阶乘。
下附AC代码。
#include<iostream>
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define maxn 1000005
using namespace std;
typedef long long ll;
const ll mod=1e9+7;
int n,m;
ll fac[maxn];
char s[maxn],t[maxn];
int cnt1[500],cnt2[500];
ll quickpow(ll p,ll k)
{
ll ans=1;
while(k)
{
if(k&1)
ans=(ans*p)%mod;
p=(p*p)%mod;
k>>=1;
}
return ans;
}
ll getsum(int now)
{
ll cnt=0;
for(int i='a';i<now;i++)
cnt+=cnt1[i];
return cnt;
}
ll dfs1(ll now)
{
if(now==m+1)
return 0;
ll ans=0;
if(cnt1[t[now]])
{
cnt1[t[now]]--;
ll temp=dfs1(now+1);
cnt1[t[now]]++;
ans=(ans+(cnt1[t[now]]*temp))%mod;
}
ans=(ans+getsum(t[now])*fac[m-now])%mod;
return ans;
}
ll dfs2(ll now)
{
if(now==n+1)
return 0;
ll ans=0;
if(cnt1[s[now]])
{
cnt1[s[now]]--;
ll temp=dfs2(now+1);
cnt1[s[now]]++;
ans=(ans+(cnt1[s[now]]*temp))%mod;
}
ans=(ans+getsum(s[now])*fac[n-now])%mod;
return ans;
}
int main()
{
scanf("%s",s+1);n=strlen(s+1);
scanf("%s",t+1);m=strlen(t+1);
for(int i=1;i<=n;i++) cnt1[s[i]]++;
for(int i=1;i<=m;i++) cnt2[t[i]]++;
fac[0]=1;
for(int i=1;i<=n;i++) fac[i]=(fac[i-1]*i)%mod;
ll ans=dfs1(1)-dfs2(1);
for(int i='a';i<='z';i++)
if(cnt1[i])
ans=(ans*quickpow(fac[cnt1[i]],mod-2)%mod);
printf("%I64d\n",(ans-1+mod)%mod);
}