未完待续
下面讨论为什么要用分治,首先看分治的复杂度。回来加图!!!
还有些时间说下分治FFT的复杂度:
首先NTT和FFT复杂度都是
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)然后我们需要知道每次分治的长度与次数:
n
l
o
g
(
n
)
+
2
×
n
2
l
o
g
(
n
2
)
+
3
×
n
3
l
o
g
(
n
3
)
+
.
.
.
.
.
+
n
∗
n
n
l
o
g
(
n
n
)
nlog(n)+2\times\frac{n}{2}log(\frac{n}{2})+3\times\frac{n}{3}log(\frac{n}{3})+.....+n*\frac{n}{n}log(\frac{n}{n})
nlog(n)+2×2nlog(2n)+3×3nlog(3n)+.....+n∗nnlog(nn)
然后等于:
n
l
o
g
(
n
)
+
n
∗
l
o
g
(
n
2
)
+
.
.
.
.
.
+
n
∗
l
o
g
(
n
n
)
nlog(n)+n*log(\frac{n}{2})+.....+n*log(\frac{n}{n})
nlog(n)+n∗log(2n)+.....+n∗log(nn)
后面的log()比较碍事,我们都认为是log(n)就好了实际小于。
然后划一下一共
l
o
g
(
n
)
log(n)
log(n)个
n
l
o
g
(
n
)
nlog(n)
nlog(n)相加也就是
n
l
o
g
2
(
n
)
nlog^2(n)
nlog2(n)
这就是分治FFT的复杂度。
然后怎么得到的。
我们转化成 G x = − 1 F 0 ∑ i = 1 x G x − i F i G_x=-\frac{1}{F_0}\sum_{i=1}^{x}G_{x-i}F_{i} Gx=−F01i=1∑xGx−iFi
然后我们每次枚举的区间是[l,r)我们分成[l,mid),[mid,r),那么我们分治完能够得到[l,mid)D的
G
i
G_i
Gi值,那么
G
l
∗
F
r
−
l
G_l*F_{r-l}
Gl∗Fr−l不就是
G
r
G_r
Gr其中一项吗?我们将其加进去不就好了。
也就是
G
r
=
−
1
F
0
(
G
0
F
r
+
G
1
F
r
−
1
+
G
2
F
r
−
2
+
.
.
.
.
.
.
G
r
−
1
F
1
)
)
G_r=-\frac{1}{F_0}(G_0F_r+G_1F_{r-1}+G_2F_{r-2}+......G_{r-1}F_{1}))
Gr=−F01(G0Fr+G1Fr−1+G2Fr−2+......Gr−1F1))
每一项加进去一次。
例题:P4238 【模板】多项式乘法逆
链接
分治FFT
///苟利国家生死以,岂因祸福避趋之。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 5e6 + 7;
const double PI = acos(-1);
const int p = 998244353, G = 3, Gi = 332748118;
int n, m;
int limit = 1;
ll res, ans[N];
int l;
int r[N];
ll a[N], g[N],f[N];
ll qpow(ll a, ll b)
{
ll ans = 1;
while(b)
{
if(b & 1) ans = ans * a % p;
a = a * a % p;
b >>= 1;
}
return ans;
}
void NTT(ll *A, int type)
{
for(int i = 0; i < limit; i++)
{
if(i < r[i]) swap(A[i], A[r[i]]); ///保证不在换过去
}
///从底层开始合并:
for(int mid = 1; mid < limit; mid = mid * 2)
{
///待合并区间长度的一半,最开始是长度为1的合并,类似倍增的思想
ll wn = qpow(G, (p - 1) / (mid * 2));
if(type==-1) wn = qpow(wn, p - 2);
for(int len = mid * 2, pos = 0; pos < limit; pos += len)
{
///len是区间的长度,pos是当前的位置
ll w = 1;
for(int k = 0; k < mid; k++, w = w * wn % p)
{
///只扫左半部分,蝴蝶变换得到有半部分。
int x = A[pos + k]; //左半部分
int y = w * A[pos + mid + k] % p; //有半部分
A[pos + k] = (x + y) % p;
A[pos + mid + k] = (x - y + p) % p;
}
}
}
if(type == 1) return ;
ll inlimit = qpow(limit, p - 2);
for(int i = 0; i < limit; i++)
{
A[i] = (A[i] * inlimit) % p;
///除以我们推出的N。
}
}
void init(int lm){
for(int i = 0; i < lm; i++)
{
r[i] = (r[i >> 1] >> 1) | ((i & 1)? (lm >>1):0);
}
}
void solve(int l,int r){
if( l + 1 >= r) return ;
int mid = (l+r)/2;
solve(l,mid);
int len=r-l;
init(len);
for(int i=1;i<len;i++) g[i] = a[i];
for(int i=l;i<mid;i++) f[i-l] = ans[i];
for(int i=mid;i<r;i++) f[i-l] = 0;
limit=len;
NTT(f,1);
NTT(g,1);
for(int i=0;i<len;i++) f[i]=f[i]*g[i]%p;
NTT(f,-1);
int inv = qpow((-a[0]%p+p)%p,p-2);///-1/a[0]
for(int i=mid ; i<r;i++) ans[i] = (ans[i]+f[i-l]*inv%p)%p;
solve(mid,r);
}
int main()
{
cin >> n;
for(int i = 0; i < n; i++) cin >> a[i], a[i] = (a[i] + p) % p;
while(limit < n) limit <<= 1;
ans[0]=qpow(a[0],p-2);
solve(0,limit);
for(int i = 0; i < n; i++) cout << ans[i] << " ";
return 0;
}
倍增优化
///苟利国家生死以,岂因祸福避趋之。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 5e6 + 7;
const double PI = acos(-1);
const int p = 998244353, G = 3, Gi = 332748118;
int n, m;
int limit = 1;
ll res, ans[N];
int l;
int r[N];
ll a[N], g[N],f[N];
ll qpow(ll a, ll b)
{
ll ans = 1;
while(b)
{
if(b & 1) ans = ans * a % p;
a = a * a % p;
b >>= 1;
}
return ans;
}
void NTT(ll *A, int type)
{
for(int i = 0; i < limit; i++)
{
if(i < r[i]) swap(A[i], A[r[i]]); ///保证不在换过去
}
///从底层开始合并:
for(int mid = 1; mid < limit; mid = mid * 2)
{
///待合并区间长度的一半,最开始是长度为1的合并,类似倍增的思想
ll wn = qpow(G, (p - 1) / (mid * 2));
if(type==-1) wn = qpow(wn, p - 2);
for(int len = mid * 2, pos = 0; pos < limit; pos += len)
{
///len是区间的长度,pos是当前的位置
ll w = 1;
for(int k = 0; k < mid; k++, w = w * wn % p)
{
///只扫左半部分,蝴蝶变换得到有半部分。
int x = A[pos + k]; //左半部分
int y = w * A[pos + mid + k] % p; //有半部分
A[pos + k] = (x + y) % p;
A[pos + mid + k] = (x - y + p) % p;
}
}
}
if(type == 1) return ;
ll inlimit = qpow(limit, p - 2);
for(int i = 0; i < limit; i++)
{
A[i] = (A[i] * inlimit) % p;
///除以我们推出的N。
}
}
void init(int lm){
for(int i = 0; i < lm; i++)
{
r[i] = (r[i >> 1] >> 1) | ((i & 1)? (lm >>1):0);
}
}
void solve(int num){
if(num==1) {
ans[0]=qpow(a[0],p-2);
return ;
}
solve((num+1)>>1);///上取整
limit=1;
while(limit<(num<<1)) limit<<=1;
init(limit);
for(int i=0;i<limit;i++){
f[i]=(i<num?a[i]:0);
g[i]=ans[i];
}
NTT(f,1);NTT(g,1);
for(int i=0;i<limit;i++) ans[i]=(2-f[i]*g[i]%p+p)%p*g[i]%p;
NTT(ans,-1);
for(int i=num;i<limit;i++) ans[i]=0;
}
int main()
{
cin >> n;
for(int i = 0; i < n; i++) cin >> a[i], a[i] = (a[i] + p) % p;
while(limit < n) limit <<= 1;
solve(n);
for(int i = 0; i < n; i++) cout << ans[i] << " ";
return 0;
}