题目:
http://www.lydsy.com/JudgeOnline/problem.php?id=1919
题意:
给出两个长度为
n
的整数序列
对于两个长度为
求
a∗b∗b∗⋯∗b
每一位模
(n+1)
的值,其中有
C
个
n≤5⋅105,a[i],b[i],C≤109
题解:
由原根的性质可知,长度为
n
的FFT即可支持
先考虑如何快速计算长度为
当
n=2k
时,FFT每次是将序列一分为二,然后利用分治的技巧来进行合并。
因此当
n=2k1⋅3k2⋅5k3⋅7k4
时,FFT每次可能将序列一分为
p(p=2,3,5,7)
,合并时的式子需要重新推导。
不妨设是将
p
个长度为
由于分裂时将模
p
意义相同的部分放在了一起,所以对于合并后的多项式
拆分的 p 个多项式分别为
故有
于是可以 O(p) 合并出每个点的值,而这样的分治层数是 O(∑4i=1ki)=O(logn) ,每层的复杂度是 O(pn)=O(7n) ,因此整体的复杂度是 O(nlogn) 。
上述方法也可非递归实现,在分裂过程中注意每段之间互不影响,在合并过程中注意存储方式即可,笔者的做法就是迭代的做法。
再考虑解决精度问题,由同余关系的性质,可以使得每次计算相乘时的值域降低到
O(n2)
,但需要将单位复根映射到模意义下的剩余系中。
由于
(n+1)
是质数,
φ(n+1)=n
,所以在模
(n+1)
意义下存在原根
g
,使得
由于模
(n+1)
意义下原根的数量为
φ(n)=n∏piisprime,pi|n1−1pi
,而
n
的质因子大小不超过10,所以期望检查
上述做法基于
n
是10-smooth number
,即Cooley–Tukey FFT algorithm,而对于更强性质的
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const int maxn = 500001;
int n, m, mod, tot, p[maxn], pw[maxn], a[maxn], b[maxn];
int mod_pow(int x, int k)
{
int ret = 1;
for( ; k > 0; k >>= 1, x = (LL)x * x % mod)
if(k & 1)
ret = (LL)ret * x % mod;
return ret;
}
void NTT(int x[maxn], int flag)
{
// go deeper
static int y[maxn] = {};
int *cur = x, *nxt = y;
for(int i = tot - 1, delta = n / p[i]; i > 0; --i, delta /= p[i], swap(cur, nxt))
for(int j = 0, *np = nxt; j < n; j += delta * p[i])
for(int k = 0; k < p[i]; ++k)
for(int l = 0, *cp = cur + j + k; l < delta; ++l, ++np, cp += p[i])
*np = *cp;
// recursion
for(int i = 0, clen = 1, nlen = p[i]; i < tot; ++i, clen = nlen, nlen *= p[i], swap(cur, nxt))
for(int j = 0, k = 0, ww = 1, delta = 0; j < n; ++j, k = k + 1 < clen ? k + 1 : 0, ww = (LL)ww * pw[i] % mod, delta = delta + nlen > j ? delta : delta + nlen)
{
nxt[j] = 0;
for(int t = 0, www = 1; t < nlen; t += clen, www = (LL)www * ww % mod)
nxt[j] = (nxt[j] + (LL)www * cur[delta + t + k]) % mod;
}
if(flag == -1)
{
reverse(cur + 1, cur + n);
for(int i = 0; i < n; ++i)
cur[i] = (LL)cur[i] * n % mod; // n * n mod (n + 1) = 1
}
if(cur != x)
memcpy(x, cur, n * sizeof(int));
}
int main()
{
int tmp;
scanf("%d%d", &n, &m);
mod = n + 1;
tmp = n;
m = (m - 1) % n + 1;
for(int i = 2; i * i <= tmp; ++i)
for( ; tmp % i == 0; tmp /= i, p[tot++] = i);
if(tmp > 1)
p[tot++] = tmp;
for(int ori = 2; ; ++ori)
{
bool flag = 1;
for(int i = 0; i < tot && flag; ++i)
if(!i || p[i - 1] != p[i])
flag &= mod_pow(ori, n / p[i]) != 1;
if(flag)
{
pw[tot - 1] = ori;
for(int i = tot - 2; i >= 0; --i)
pw[i] = mod_pow(pw[i + 1], p[i + 1]);
break;
}
}
for(int i = 0; i < n; ++i)
scanf("%d", a + i);
NTT(a, 1);
for(int i = 0; i < n; ++i)
scanf("%d", b + i);
NTT(b, 1);
for(int i = 0; i < n; ++i)
a[i] = (LL)a[i] * mod_pow(b[i], m) % mod;
NTT(a, -1);
for(int i = 0; i < n; ++i)
printf("%d\n", a[i]);
return 0;
}