题目描述
给定一个 n n 次多项式 和一个 m m 次多项式 ,请求出多项式 Q(x) Q ( x ) , R(x) R ( x ) ,满足以下条件:
Q(x)
Q
(
x
)
次数为
n−m
n
−
m
,
R(x)
R
(
x
)
次数小于
m
m
所有的运算在模
998244353
998244353
意义下进行。
输入输出格式
输入格式:
第一行两个整数
n
n
,,意义如上。
第二行
n+1
n
+
1
个整数,从低到高表示
F(x)
F
(
x
)
的各个系数。
第三行
m+1
m
+
1
个整数,从低到高表示
G(x)
G
(
x
)
的各个系数。
输出格式:
第一行
n−m+1
n
−
m
+
1
个整数,从低到高表示
Q(x)
Q
(
x
)
的各个系数。
第二行
m
m
个整数,从低到高表示 的各个系数。
如果
R(x)
R
(
x
)
不足
m−1
m
−
1
次,多余的项系数补
0
0
。
输入输出样例
输入样例#1:
5 1
1 9 2 6 0 8
1 7
输出样例#1:
237340659 335104102 649004347 448191342 855638018
760903695
说明
对于所有数据,,给出的系数均属于
[0,998244353)∩Z
[
0
,
998244353
)
∩
Z
。
分析:
我们设
degA=n
d
e
g
A
=
n
,
degB=m
d
e
g
B
=
m
,
degC=n−m+1
d
e
g
C
=
n
−
m
+
1
,
degD<m
d
e
g
D
<
m
。
设一个
n
n
次多项式的变换为
xnA(1x)
x
n
A
(
1
x
)
,其实就是把
A
A
翻转。
因为,有
两边同时乘 xn x n ,
那么如果我们在 mod xn−m+1 m o d x n − m + 1 意义下跑,那么 D D 就没掉了,即
一个多项式求逆后就可以得到 B′−1(x) B ′ − 1 ( x ) ,卷上 A′(x) A ′ ( x ) 就可以得到 C′(x) C ′ ( x ) 。
D(x)=A(x)−B(x)∗C(x) D ( x ) = A ( x ) − B ( x ) ∗ C ( x ) 。就完成了。
代码:
// luogu-judger-enable-o2
#include <iostream>
#include <cmath>
#include <cstdio>
#include <algorithm>
#define LL long long
const LL G=3;
const int maxn=3e5+7;
const LL mod=998244353;
using namespace std;
int n,m,len;
LL f[maxn],g[maxn],inv[maxn],c[maxn],r[maxn],x[maxn],y[maxn],w[maxn];
LL power(LL x,LL y)
{
if (y==1) return x;
LL c=power(x,y/2);
c=(c*c)%mod;
if (y%2) c=(c*x)%mod;
return c;
}
void ntt(LL *a,LL f)
{
for (LL i=0;i<len;i++)
{
if (i<r[i]) swap(a[i],a[r[i]]);
}
w[0]=1;
for (LL i=2;i<=len;i*=2)
{
LL wn;
if (f==1) wn=power(G,(mod-1)/i);
else wn=power(G,(mod-1)-(mod-1)/i);
for (LL j=i/2;j>=0;j-=2) w[j]=w[j/2];
for (LL j=1;j<i/2;j+=2) w[j]=(w[j-1]*wn)%mod;
for (LL j=0;j<len;j+=i)
{
for (LL k=0;k<i/2;k++)
{
LL u=a[j+k],v=(a[j+k+i/2]*w[k])%mod;
a[j+k]=(u+v)%mod;
a[j+k+i/2]=(u-v+mod)%mod;
}
}
}
if (f==-1)
{
LL inv=power(len,mod-2);
for (LL i=0;i<len;i++) a[i]=(a[i]*inv)%mod;
}
}
void init(LL len)
{
LL k=trunc(log(len+0.5)/log(2));
for (LL i=0;i<len;i++)
{
r[i]=(r[i>>1]>>1)|((i&1)<<(k-1));
}
}
void NTT(LL *a,LL *b,LL *c,LL n,LL m)
{
len=1;
while (len<=(n+m)) len*=2;
init(len);
for (LL i=0;i<len;i++)
{
if (i<n) x[i]=a[i]; else x[i]=0;
if (i<m) y[i]=b[i]; else y[i]=0;
}
ntt(x,1); ntt(y,1);
for (LL i=0;i<len;i++) c[i]=x[i]*y[i]%mod;
ntt(c,-1);
}
void getinv(LL *a,LL *b,int deg)
{
if (deg==1)
{
b[0]=power(a[0],mod-2);
return;
}
LL d=(deg+1)/2;
getinv(a,b,d);
NTT(a,b,c,m+1,d);
c[0]=(2+mod-c[0])%mod;
for (int i=1;i<=m+d;i++) c[i]=(mod-c[i])%mod;
NTT(c,b,b,m+d+1,d);
for (int i=deg;i<len;i++) b[i]=0;
}
int main()
{
scanf("%d%d",&n,&m);
for (int i=0;i<=n;i++) scanf("%lld",&f[i]);
for (int i=0;i<=m;i++) scanf("%lld",&g[i]);
reverse(f,f+n+1);
reverse(g,g+m+1);
getinv(g,inv,n-m+1);
NTT(f,inv,c,n+1,n-m+1);
reverse(c,c+n-m+1);
for (int i=0;i<=n-m;i++) printf("%lld ",c[i]);
printf("\n");
reverse(f,f+n+1);
reverse(g,g+m+1);
NTT(g,c,c,m+1,n-m+1);
for (int i=0;i<m;i++) printf("%lld ",(f[i]-c[i]+mod)%mod);
}