【题目链接】
【思路要点】
- 考虑 subtask1 s u b t a s k 1 ,我们很容易可以得到一个动态规划的解法。
- 注意到行与行之间转移的卷积本质,我们可以用FFT快速计算出DP数组的某一行,可以通过 subtask2 s u b t a s k 2 。
- 原题中 N N 非常大,我们不可能求得DP数组的第行。
- 考虑多项式 x(x+1)(x+2)(x+3)…(x+p−1) x ( x + 1 ) ( x + 2 ) ( x + 3 ) … ( x + p − 1 ) ,在模质数 p p 意义下,应当等于。因为我们
打表发现这两个多项式拥有 p p 个相同的根()。- 因此令 A=x(x+1)(x+2)(x+3)…(x+p−1) A = x ( x + 1 ) ( x + 2 ) ( x + 3 ) … ( x + p − 1 ) ,令 a=⌊Np⌋,b=N−ap a = ⌊ N p ⌋ , b = N − a p 。
- 有 x(x+1)(x+2)(x+3)…(x+N)≡Aa∗x(x+1)(x+2)(x+3)...(x+b) (Mod p) x ( x + 1 ) ( x + 2 ) ( x + 3 ) … ( x + N ) ≡ A a ∗ x ( x + 1 ) ( x + 2 ) ( x + 3 ) . . . ( x + b ) ( M o d p ) 。
- 乘号后面的部分可以用分治FFT计算得到,并且其次数不超过 p−1 p − 1 。
- 乘号前面的部分相邻两个系数非零的项次数差为 p−1 p − 1 。
- 特殊处理 b=p−1 b = p − 1 的情况,对于其余部分,我们可以分别解决乘号前后的子问题,将它们的答案相乘得到答案。
- 乘号前面的部分相当于在询问所有 (ai)(0≤i≤a) ( a i ) ( 0 ≤ i ≤ a ) 中,有多少不是 p p 的倍数。
- 我们可以用Lucas定理解决这个问题:即。
- 不难发现 (ai) ( a i ) 不是 p p 的倍数当且仅当的 p p 进制表示每一位都不超过的 p p 进制表示的对应位。
- 进制转换即可。
- 时间复杂度。
【代码】
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 262144;
const int MAXLOG = 30;
const int P = 1e9 + 7;
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
template <typename T> void chkmin(T &x, T y) {x = min(x, y); }
template <typename T> void read(T &x) {
x = 0; int f = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
x *= f;
}
template <typename T> void write(T x) {
if (x < 0) x = -x, putchar('-');
if (x > 9) write(x / 10);
putchar(x % 10 + '0');
}
template <typename T> void writeln(T x) {
write(x);
puts("");
}
namespace FFT {
const int MAXN = 262144;
const long double pi = acosl(-1);
struct point {long double x, y; };
point operator + (point a, point b) {return (point) {a.x + b.x, a.y + b.y}; }
point operator - (point a, point b) {return (point) {a.x - b.x, a.y - b.y}; }
point operator * (point a, point b) {return (point) {a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x}; }
point operator / (point a, long double x) {return (point) {a.x / x, a.y / x}; }
int N, Log, home[MAXN];
point tmp[MAXN];
void FFTinit() {
for (int i = 0; i < N; i++) {
int tmp = i, ans = 0;
for (int j = 1; j <= Log; j++) {
ans <<= 1;
ans += tmp & 1;
tmp >>= 1;
}
home[i] = ans;
}
}
void FFT(point *a, int mode) {
for (int i = 0; i < N; i++)
if (home[i] < i) swap(a[i], a[home[i]]);
for (int len = 2; len <= N; len <<= 1) {
point delta = (point) {cosl(2 * pi / len * mode), sinl(2 * pi / len * mode)};
for (int i = 0; i < N; i += len) {
point now = (point) {1, 0};
for (int j = i, k = i + len / 2; k < i + len; j++, k++) {
point tmp = a[j];
point tnp = a[k] * now;
a[j] = tmp + tnp;
a[k] = tmp - tnp;
now = now * delta;
}
}
}
if (mode == -1) {
for (int i = 0; i < N; i++)
a[i] = a[i] / (4 * N);
}
}
void times(int *a, int *b, int *c, int limit, int p) {
N = 1, Log = 0;
while (N < 2 * limit) {
N <<= 1;
Log++;
}
for (int i = 0; i < limit; i++)
tmp[i] = (point) {(long double) (a[i] + b[i]), (long double) (a[i] - b[i])};
for (int i = limit; i < N; i++)
tmp[i] = (point) {0, 0};
FFTinit();
FFT(tmp, 1);
for (int i = 0; i < N; i++)
tmp[i] = tmp[i] * tmp[i];
FFT(tmp, -1);
for (int i = 0; i < N; i++)
c[i] = (long long) (tmp[i].x + 0.5) % p;
}
}
char s[MAXN];
int a[MAXN], f[MAXLOG][MAXN];
int n, len, p, bits[MAXN];
int modulo() {
int r = 0;
for (int i = len; i >= 1; i--)
r = (r * 10 + a[i]) % p;
return r;
}
void divide() {
int r = 0;
for (int i = len; i >= 1; i--) {
r = r * 10 + a[i];
a[i] = r / p;
r %= p;
}
while (len && a[len] == 0) len--;
}
void work(int l, int r, int depth) {
int len = r - l + 1;
for (int i = 0; i <= 2 * len; i++)
f[depth][i] = 0;
if (l == r) {
f[depth][0] = 1;
if (l != 0) f[depth][1] = l;
return;
}
int mid = (l + r) / 2;
work(l, mid, depth);
work(mid + 1, r, depth + 1);
int lim = max(mid - l + 1, r - mid) + 1;
FFT :: times(f[depth], f[depth + 1], f[depth], lim, p);
}
int main() {
int T; read(T);
while (T--) {
scanf("\n%s", s + 1); read(p);
len = strlen(s + 1);
reverse(s + 1, s + len + 1);
for (int i = 1; i <= len; i++)
a[i] = s[i] - '0';
n = 0;
while (len != 0) {
bits[++n] = modulo();
divide();
}
bits[n + 1] = 0;
if (bits[1] == p - 1) {
bits[1] = 0;
for (int i = 2; true; i++) {
chkmax(n, i);
if (++bits[i] == p) bits[i] = 0;
else break;
}
}
int ans = 1;
for (int i = 2; i <= n; i++)
ans = ans * (bits[i] + 1ll) % P;
work(0, bits[1], 0);
int tans = 0;
for (int i = 0; i <= bits[1]; i++)
if (f[0][i]) tans++;
writeln(1ll * tans * ans % P);
}
return 0;
}