题目链接: http://acm.hdu.edu.cn/showproblem.php?pid=6093
题意: 定义一个好数满足存在一个d(d
≥
2), 使其在d进制的表示下, 每一位正好组成一个0 ~ d-1的排列, 要求区间[l, r]上有多少个好数(mod 998244353)。
(多组数据,
t≤20
,
1≤l≤r≤105000
)
思路:首先询问区间[l, r]可以拆成两个[1, n]的前缀询问。 对于一个d进制下的好数K, 满足 nn−1<K<nn , 由于 nn<(n+1)n , 所以每个好数仅对应唯一一种d进制表示。考虑询问一段前缀区间[1, n], 若n用d进制表示后位数大于d, 则d进制的所有好数均包含在内, 即 n!−(n−1)! 。 若n用d进制表示后位数等于d,考虑数位dp可以求出小于等于n的d进制下的好数。
由于大整数进制转换的复杂度限制, 直接枚举d会TLE, 即使是二分d也无法避免。 要找到临界点, 可以用对数估计这个临界点, 即一个最大的d, 满足 (d−1)log(d)≤log(n) , 这个d也是关于n单调的, 可以预处理每个10^k的临界d。 临界点之前的答案可以用前缀和预处理得到。
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#define ll long long
const int N = 5010;
const int mo = 998244353;
using namespace std;
char str[N]; int len;
ll ans, ni[N], sum[N];
int num[3][N], t[3], L[N];
bool check(int d){
for (int i = 0; i < len; i ++) num[0][i] = str[i]; t[0] = len;
int cur = 1, pre = 0; t[2] = 0;
while (t[pre]){
int now = 0;
for (int i = t[pre] - 1; i >= 0; i --){
now = now * 10 + num[pre][i];
num[cur][i] = now / d; now %= d;
}
t[cur] = t[pre];
while (!num[cur][t[cur] - 1] && t[cur] > 0) t[cur] --;
num[2][t[2] ++] = now;
swap(cur, pre);
if (t[2] > d) return 1;
}
if (t[2] == d) return 1;
return 0;
}
int tmp[N], n;
void copy(){
n = t[2];
for (int i = 0; i < n; i ++) tmp[i] = num[2][i];
}
int rem[N]; bool used[N];
ll dp(int d){
if (n > d) return ni[d] - ni[d - 1];
int tot = d;
for (int i = d - 1; i >= 0; i --)
rem[i] = i;
memset(used, 0, sizeof(used));
ll ret = 0; bool ok = 1;
for (int i = d - 1; i >= 0; i --){
(ret += 1LL * (rem[tmp[i]] - (i == d - 1 && !used[0])) * ni[tot - 1]) %= mo ;
if (used[tmp[i]]) {ok = 0; break;}
tot --; used[tmp[i]] = 1;
for (int j = tmp[i] + 1; j <= d - 1; j ++) rem[j] --;
}
return ret + ok;
}
ll solve(){
if (len == 1 && str[0] == 0) return 0;
if (len == 1 && str[0] == 1) return 0;
int l = L[len - 1], d;
while (check(l)) d = l, copy(), l ++;
return (sum[d - 1] + dp(d)) % mo;
}
int main(){
ni[0] = 1;
for (int i = 1; i < N; i ++) ni[i] = 1LL * ni[i - 1] * i % mo;
sum[1] = 0;
for (int i = 2; i < N; i ++) sum[i] = (sum[i - 1] + ni[i] - ni[i - 1]) % mo;
L[0] = 2;
L[1] = 3;
for (int i = 2; i < N; i ++){
L[i] = L[i - 1];
while ((L[i]) * log(L[i] + 1) < i * log(10)) L[i] ++;
}
int t;
for (scanf("%d\n", &t); t --; ){
ans = 0;
scanf("%s\n", str); len = strlen(str);
for (int i = 0; i < len - i - 1; i ++)
swap(str[i], str[len - i - 1]);
for (int i = 0; i < len; i ++) str[i] -= '0';
str[0] --;
for (int i = 0; i < len; i ++)
if (str[i] < 0) str[i] += 10, str[i + 1] --;
else break;
while (len > 1 && !str[len - 1]) len --;
ans -= solve();
scanf("%s\n", str); len = strlen(str);
for (int i = 0; i < len - i - 1; i ++)
swap(str[i], str[len - i - 1]);
for (int i = 0; i < len; i ++) str[i] -= '0';
ans += solve();
printf("%lld\n", ((ans % mo) + mo) % mo);
}
return 0;
}