P3803 【模板】多项式乘法(FFT)

\(\color{#0066ff}{题目描述}\)

给定一个n次多项式F(x),和一个m次多项式G(x)。

请求出F(x)和G(x)的卷积。

\(\color{#0066ff}{输入格式}\)

第一行2个正整数n,m。

接下来一行n+1个数字,从低到高表示F(x)的系数。

接下来一行m+1个数字,从低到高表示G(x)的系数

\(\color{#0066ff}{输出格式}\)

一行n+m+1个数字,从低到高表示F(x)∗G(x)的系数。

\(\color{#0066ff}{输入样例}\)

1 2
1 2
1 2 1

\(\color{#0066ff}{输出样例}\)

1 4 5 2

\(\color{#0066ff}{数据范围与提示}\)

保证输入中的系数大于等于 0 且小于等于9。

对于100%的数据:\(n, m \leq {10}^6\) , 共计20个数据点,2s。

数据有一定梯度。

空间限制:256MB

\(\color{#0066ff}{题解}\)

对于两个多项式相乘,显然暴力是\(O(n^2)\)
如何优化呢?
我们知道,n+1个点可以唯一确定一个n次多项式
那么我们对于\(A(x)*B(x)=C(x)\)
可以拆成这样

\(\left\{\begin{matrix}(a_1,b_1) \\ (a_2,b_2) \\(a_3,b_3) \\ . \\ . \\ . \\(a_{n+1},b_{n+1)}\end{matrix}\right\}\ \ *\ \left\{\begin{matrix}(c_1,d_1) \\ (c_2,d_2) \\(c_3,d_3) \\ . \\ . \\ . \\(c_{n+1},d_{n+1)}\end{matrix}\right\}\ \ =\ \ \left\{\begin{matrix}(a_1*c_1,b_1*d_1) \\ (a_2*c_2,b_2*d_2) \\(a_3*c_3,b_3*d_3) \\ . \\ . \\ . \\(a_{n+1}*c_{n+1},b_{n+1}*d_{n+1})\end{matrix}\right\}\)

上面这个点值表达式的运算显然是\(O(n)\)
我们现在要把A和B转成点值表达式,然后乘过去,最后再转换回来
现在考虑转成点值表达式
比如\(A(x)=a_0+a_1x+a_2x^2\)
那么点值表达式就是
\((x_1,a_0+a_1x_1+a_2x_1^2)(x_2,a_0+a_1x_2+a_2x_2^2)(x_3,x_0+a_1x_3+a_2x_3^2)\)
但是带入求y用秦九韶是\(O(n)\)的,而且要带入n个点,所以就\(O(n^2)\)
现在考虑优化这个过程
引入一个东西,单位复数根
定义\(\omega_n^n=1,\omega_n\)有n个
对于复数坐标系的两个点(可以理解为两个向量)
两个单位向量A,B,模长为\(a=b=1\),角度为\(\alpha,\beta\)
欧拉公式:\(e^{ix}=\cos x+i*\sin x\)
所以二者相乘即为\(ae^{i\alpha}*be^{i\beta}=abe^{i(\alpha+\beta)}=Ae^{i\Theta}\)
也就是说,复数相乘集合意义,模长相乘,极角相加
因此,\(w_n^n=1\),则\(n\Theta\%2\pi=0\)
\(\Theta=\frac{2\pi}{n}\)
定义\(\omega_0=(1,0),\omega_1=\omega_n^1\)
\(\omega_n^k=\omega_0*\omega_1^k\)
而且有\(\omega_n^{k+\frac{n}{2}}=-\omega_n^k\),因为关于原点对称
还有\(\omega_{2n}^{2k}=\omega_n^k\),n份取k份等同于2n份取2k份
回到多项式\(A(x)=a_0+a_1x+a_2x^2+...+a_{n-1}x^{n-1}\)
我们硬凑一下,使得\(A(x)\)共有\(2^k\)次方项(不足后面补系数0)
然后开始分治,把\(A(x)\)拆成\(A_0(x)\)\(A_1(x)\)
\(A_0(x)=a_0+a_2x+a_4x^2+...+a_{n-2}x^{\frac{n-2}{2}}\)
\(A_1(x)=a_1+a_3x+a_5x^2+...+a_{n-1}x^{\frac{n-2}{2}}\)
不难发现\(A(x)=A_0(x^2)+x*A_1(x^2)\)
我们把\(\omega_n^k\)带入
\(A(\omega_n^k)=A_0(\omega_n^{2k})+\omega_n^k*A_1(\omega_n^{2k})=A_0(\omega_{\frac{n}{2}}^k)+\omega_n^k*A_1(\omega_{\frac{n}{2}}^k)\)
\(A(\omega_n^{k+\frac{n}{2}})=A_0(-\omega_n^{k})-\omega_n^k*A_1(-\omega_n^{k})=A_0(\omega_{\frac{n}{2}}^k)-\omega_n^k*A_1(\omega_{\frac{n}{2}}^k)\)
两式子差距只在正负号!
所以求出了\(A_0,A_1\),相加即为第一个,相减即为第二个
也就是说,对于当前序列,我们求出了一半,另一半也出来了
那么我们去找一找对应关系
\(\{a_0\} ||||| \{a_1\} ||||| \{a_2\} ||||| \{a_3\} ||||| \{a_4\} ||||| \{a_5\} ||||| \{a_6\} ||||| \{a_7\}\)
\(\{000\} ||| \{001\} ||| \{010\} ||| \{011\} ||| \{100\} ||| \{101\} ||| \{110\} ||| \{111\}\)
\(|||||||||||\{a_0,a_2,a_4,a_6\} ||||||||||||||||||||||||| \{a_1,a_3,a_5,a_7\}||||||||||||\)
\(||||||\{a_0,a_4\} |||||||||| \{a_2,a_6\} ||||||||||||| \{a_1,a_5\} |||||||||| \{a_3,a_7\} |||||\)
\(\{a_0\} ||||| \{a_4\} ||||| \{a_2\} ||||| \{a_6\} ||||| \{a_1\} ||||| \{a_5\} ||||| \{a_3\} ||||| \{a_7\}\)
\(\{000\} ||| \{100\} ||| \{010\} ||| \{110\} ||| \{001\} ||| \{101\} ||| \{011\} ||| \{111\}\)
卧槽,这是二进制翻转啊
我们用\(r_i\)代表数i翻转后是几
\(r[i]=r[i>>1]>>1|(i\&1)*(len<<1)\)
要求当前位置的翻转后的数,把当前最后一位删去
那么当前的数的r已经求出来了
将这个数翻转(注意,二进制位数固定,比如长度为8,那么\(00000001\to 10000000\)),再<<1,
这样现在的数除了最高位其他都是当前的翻转数了
这时只要考虑原来最后一位是不是1即可
通过FFT,我们求出了A(x)的点值表达式
B(x)同理,然\(A(x)*B(x)=C(x)\)\(O(n)\)
现在的问题是怎么转回系数表达式
定理

\(\left\{\begin{matrix} y_0\\y_1\\y_2\\.\\.\\.\\y_{n-1}\end{matrix}\right\}=\left\{\begin{matrix} 1 & 1 & 1 & ...& 1 \\1 & \omega_n & \omega_n^2 & ... & \omega_n^{n-1}\\1 & \omega_n^2 & \omega_n^4 & ... & \omega_n^{2(n-1)}\\. & . & . & . & .\\. & . & . & . & .\\. & . & . & . & .\\1 & \omega_n^{n-1} & \omega_n^{2(n-1)} & ... & \omega_n^{(n-1)(n-1)}\end{matrix}\right\}*\left\{\begin{matrix} a_0\\a_1\\a_2\\.\\.\\.\\a_{n-1}\end{matrix}\right\}\)

左面是点值表达式,右面是系数表达式
现在已知\(Y=W*A\)
\(A=Y*W^{-1}\)
于是要矩阵求逆
不过这个矩阵又有个定理
它的逆矩阵为:指数全取负,再所有数/n之后的矩阵
那么就简单了
#include<bits/stdc++.h>
#define LL long long
LL in() {
    char ch; LL x = 0, f = 1;
    while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
    for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
    return x * f;
}
const double pi = acos(-1);
const int maxn = 3e6 + 41;
struct node {
    double x, y;
    node(double x = 0, double y = 0): x(x), y(y) {}
    friend node operator + (const node &a, const node &b) { return node(a.x + b.x, a.y + b.y); }
    friend node operator - (const node &a, const node &b) { return node(a.x - b.x, a.y - b.y); }
    friend node operator * (const node &a, const node &b) { return node(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); }
    friend node operator / (const node &a, const double &b) { return node(a.x / b, a.y / b); }
}A[maxn], B[maxn], C[maxn];
int len, n, m, r[maxn];
void FFT(node *D, int flag) {
    for(int i = 0; i < len; i++) if(i < r[i]) std::swap(D[i], D[r[i]]);
    for(int l = 1; l < len; l <<= 1) {
        node w0(cos(pi / l), flag * sin(pi / l));
        for(int i = 0; i < len; i += (l << 1)) {
            node w(1, 0), *a0 = D + i, *a1 = D + i + l;
            for(int k = 0; k < l; k++, a0++, a1++, w = w * w0) {
                node tmp = *a1 * w;
                *a1 = *a0 - tmp;
                *a0 = *a0 + tmp;
            }
        }
    }
    if(!(~flag)) for(int i = 0; i < len; i++) D[i] = D[i] / len;
}
int main() {
    n = in(), m = in();
    for(len = 1; len <= n + m; len <<= 1);
    for(int i = 0; i <= n; i++) A[i] = in();
    for(int i = 0; i <= m; i++) B[i] = in();
    for(int i = 1; i < len; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) * (len >> 1));
    FFT(A, 1), FFT(B, 1);
    for(int i = 0; i < len; i++) C[i] = A[i] * B[i];
    FFT(C, -1);
    for(int i = 0; i <= n + m; i++) printf("%d%c", (int)round(C[i].x), i == n + m? '\n' : ' ');
    return 0;
}

NTT模数写法

#include<bits/stdc++.h>
#define LL long long
LL in() {
    char ch; LL x = 0, f = 1;
    while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
    for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
    return x * f;
}
using std::vector;
const int mod = 998244353;
const int maxn = 3e6 + 10;
LL ksm(LL x, LL y) {
    LL re = 1LL;
    while(y) {
        if(y & 1) re = re * x % mod;
        x = x * x % mod;
        y >>= 1;
    }   
    return re;
}
        
void FNTT(vector<int> &A, int len, int flag) { 
    A.resize(len);
    int *r = new int[maxn];
    r[0] = 0;
    for(int i = 0; i < len; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) * (len >> 1));
    for(int i = 0; i < len; i++) if(i < r[i]) std::swap(A[i], A[r[i]]);
    for(int l = 1; l < len; l <<= 1) {
        int w0 = ksm(3, (mod - 1) / (l << 1));
        for(int i = 0; i < len; i += (l << 1)) {
            int w = 1, a0 = i, a1 = l + i;
            for(int k = 0; k < l; k++, a0++, a1++, w = 1LL * w * w0 % mod) {
                int tmp = 1LL * A[a1] * w % mod;
                A[a1] = ((A[a0] - tmp) % mod + mod) % mod;
                A[a0] = (A[a0] + tmp) % mod;
            }
        }
    }
    if(!(~flag)) {
        std::reverse(A.begin() + 1, A.end());
        int inv = ksm(len, mod - 2);
        for(int i = 0; i < len; i++) A[i] = 1LL * inv * A[i] % mod;
    }
    delete []r;
}

vector<int> operator * (vector<int> A, vector<int> B) {
    int tot = A.size() + B.size() - 1;
    int len = 1;
    while(len <= tot) len <<= 1;
    FNTT(A, len, 1);
    FNTT(B, len, 1);
    vector<int> ans;
    ans.resize(len);
    for(int i = 0; i < len; i++) ans[i] = 1LL * A[i] * B[i] % mod;
    FNTT(ans, len, -1);
    ans.resize(tot);
    return ans;
}
signed main() {
    int n = in(), m = in();
    vector<int> A, B, C;
    for(int i = 0; i <= n; i++) A.push_back(in());
    for(int i = 0; i <= m; i++) B.push_back(in());
    C = A * B;
    for(int i = 0; i <= n + m; i++) printf("%d%c", C[i], i == n + m? '\n' : ' ');
    return 0;
}

转载于:https://www.cnblogs.com/olinr/p/10089226.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值