简略题意:给出等长的AB串,问有多少A串的重排列得到的串字典序大于A串,小于B串。
dp[i][0/1][0/1]代表当前处理到第i个位置,之前的选择有没有贴在上边界,有没有贴在下边界,考虑贴近上边界,下边界,或者不贴近边界,得到若干转移。
唯一需要注意的地方就是对于直接转移到dp[i][0][0]的状态,直接计算有重复元素的重排列即可,可以通过直接处理除数扣掉26的常数。
#define others
#ifdef poj
#include <iostream>
#include <cstring>
#include <cmath>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <string>
#endif // poj
#ifdef others
#include <bits/stdc++.h>
#endif // others
//#define file
#define all(x) x.begin(), x.end()
using namespace std;
const double eps = 1e-8;
int dcmp(double x) { if(fabs(x)<=eps) return 0; return (x>0)?1:-1;};
typedef long long LL;
/*
题意:
问a字符串有多少种排列使得得到的字符串c的字典序大于a,小于b。
问方案数。
*/
namespace solver {
const LL mod = 1e9+7;
LL n;
char a[1100000], b[1100000];
LL v[27], s[27];
LL fac[1100000], dp[1100000][2][2], Inv[1100000];
void init() {
fac[0] = 1;
for(LL i = 1; i < 1100000; i++)
fac[i] = fac[i-1] * i % mod;
memset(dp, -1, sizeof dp);
}
LL Pow(LL a, LL b) {
LL res = 1;
while(b) {
if(b & 1) {
res *= a;
res %= mod;
}
b >>= 1;
a *= a;a %= mod;
}
return res;
}
LL dfs(LL i, LL up, LL down) {
if(~dp[i][up][down]) return dp[i][up][down];
if(i > n) return 0;
LL ans = 0;
LL sum = n - i + 1;
if(up && down && v[b[i]] && a[i] == b[i]) {
v[a[i]] --;
ans += dfs(i+1, 1, 1);
if(ans >= mod) ans %= mod;
v[a[i]] ++;
return dp[i][up][down] = ans;
}
if(up && v[a[i]]) {
v[a[i]] --;
ans += dfs(i+1, 1, 0);
if(ans >= mod) ans %= mod;
v[a[i]] ++;
}
if(down && v[b[i]]) {
v[b[i]] --;
ans += dfs(i+1, 0, 1);
if(ans >= mod) ans %= mod;
v[b[i]] ++;
}
LL tmp = 1;
for(int k = 1; k <= 26; k++) {
tmp *= fac[v[k]];
tmp %= mod;
}
LL tmp_inv = Pow(tmp, mod-2);
int head = (down)?b[i]+1:1, tail = (up)?a[i]-1:26;
for(int j = head; j <= tail; j++) {
if(v[j] == 0) continue;
v[j]--;
ans += fac[sum - 1] * tmp_inv % mod * (v[j]+1)%mod;
if(ans >= mod) ans %= mod;
v[j]++;
}
return dp[i][up][down] = ans;
}
void solve() {
init();
scanf("%s%s", b + 1, a + 1);
n = strlen(a + 1);
int k = 0;
for(LL i = 1; i <= n; i++) {
b[i] = b[i] - 'a' + 1;
if(v[b[i]] == 0) k++;
v[b[i]]++;
}
for(int i = 1; i <= n; i++)
a[i] = a[i] - 'a' + 1;
printf("%lld\n", dfs(1, 1, 1));
}
}
int main() {
#ifdef file
freopen("gangsters.in", "r", stdin);
freopen("gangsters.out", "w", stdout);
#endif // file
solver::solve();
return 0;
}