这周只做了两题,一题就要调三天……
多项式多点求值
分治:要求 $f(x)$ 在 $x_l, x_{l+1}, \cdots, x_r$ 处的值,构造 $p(x) = \prod_{i=l}^{mid}(x - x_i)$,设 $f(x)$ 对 $p(x)$ 取模后为 $g(x)$。由于 $f(x) = q(x)p(x) + g(x)$ 且 $p(x_i) = 0$,对任意 $i \in [l, mid]$ 均有 $f(x_i) = g(x_i)$;$i \in [mid+1, r]$ 时同理(改对 $\prod_{i=mid+1}^{r}(x - x_i)$ 取模),递归求解即可
$p(x)$ 可用分治FFT预处理出来
显然 $g(x)$ 的次数小于 $p(x)$ 的次数(约为当前区间长度的一半),故每次递归使问题规模减半,由主定理得复杂度为 $O(n\log^2 n)$
但常数极大…
#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
// Sizes and NTT parameters.
//   maxn: capacity of a[] (points are stored in a[1..m]) and base size of the product tree P[].
//   maxm: capacity of every NTT buffer; it must be >= the largest transform length N.
//         Assuming n, m < 6.5e4 (input limits -- TODO confirm), the largest N used is 2^17 = 131072.
//   mod = 998244353 = 119 * 2^23 + 1 is NTT-friendly and g = 3 is a primitive root modulo it.
const int maxn = 6.5e4,maxm = 1.32e5,mod = 998244353,g = 3; // 2^17 = 131072 <= maxm
// n: degree of f; m: number of evaluation points; N: current transform length (set by init()).
// rev: bit-reversal permutation for NTT(); F: coefficients of f; a: the evaluation points.
// T: scratch for get_Inv(); T1/T2: scratch for pre(), solve() and main(); Q/R: scratch for Mod().
int n,m,N,a[maxn],rev[maxm],F[maxm],T[maxm],T1[maxm],T2[maxm],Q[maxm],R[maxm];
// P[x]: coefficients (lowest degree first) of prod (x - a[i]) over the range of segment-tree node x; built by pre().
vector <int> P[4 * maxn];
// Reads the next non-negative decimal integer from stdin, skipping any non-digit characters.
// Stops (returning 0) if end of input is reached before a digit instead of spinning forever;
// `c` is an int so it can hold and compare against EOF reliably.
int read(){
    int x = 0;
    int c = getchar();
    while(c != EOF && (c < '0' || c > '9')) c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar();
    return x;
}
// Binary exponentiation: returns x^k modulo `mod` (k >= 0).
int qpow(int x,int k){
    long long result = 1;
    for(long long base = x; k > 0; k >>= 1){
        if(k & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Sets N to the smallest power of two strictly greater than n and fills rev[0..N-1]
// with the bit-reversal permutation on log2(N) bits (used by NTT()).
void init(int n){
    N = 1;
    while(N <= n) N <<= 1;
    // The top bit of an index is N >> 1 (== 1 << (log2(N) - 1) when N > 1). Using N >> 1
    // avoids an undefined shift by -1 when n == 0 (N == 1), which Mod() reaches via Mul(..., 0) when n == m.
    for(int i = 0; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (N >> 1) : 0);
}
// In-place iterative number-theoretic transform of F[0..n-1]; n must be a power of two
// and rev[] must have been prepared by init() for this n.
// p == 1: forward transform. p == -1: inverse transform, done as a forward transform,
// scaling by 1/n, then reversing entries 1..n-1 (equivalent to using inverse roots).
void NTT(int *F,int n,int p){
    for(int i = 0; i < n; i ++)
        if(i < rev[i]) swap(F[i],F[rev[i]]);
    for(int half = 1; half < n; half <<= 1){
        const int len = half << 1;
        const int step = qpow(g,(mod - 1) / len);   // primitive len-th root of unity
        for(int base = 0; base < n; base += len){
            int w = 1;
            for(int k = 0; k < half; k ++){
                int &lo = F[base + k],&hi = F[base + k + half];
                const int u = lo,v = 1ll * w * hi % mod;
                lo = (u + v) % mod;
                hi = (u - v + mod) % mod;
                w = 1ll * w * step % mod;
            }
        }
    }
    if(p == -1){
        const int inv_n = qpow(n,mod - 2);
        for(int i = 0; i < n; i ++) F[i] = 1ll * F[i] * inv_n % mod;
        for(int i = 1,j = n - 1; i < j; i ++,j --) swap(F[i],F[j]);
    }
}
// Polynomial product in place: F <- F * G via NTT. The transform length is the smallest
// power of two > 2n, so both operands must be zero-padded up to it. G is overwritten
// (left in transformed form). When p == 0 the product is truncated to degree n.
void Mul(int *F,int *G,int n,int p = 0){
    init(n << 1);
    NTT(F,N,1);
    NTT(G,N,1);
    for(int *pf = F,*pg = G; pf != F + N; ++ pf,++ pg) *pf = 1ll * *pf * *pg % mod;
    NTT(F,N,-1);
    if(p != 0) return;
    for(int i = n + 1; i < N; i ++) F[i] = 0;
}
// Power-series inverse: G <- F^{-1} mod x^n, by Newton iteration G_new = G * (2 - F * G).
// Assumes F[0] is invertible mod `mod`, and that G is zero-filled up to the transform length
// (NTT() reads G[0..N-1]) -- TODO confirm at call sites. On return G[0..n-1] holds the inverse,
// G[n..N-1] is cleared, and the global scratch T is cleared.
void get_Inv(int n,int *F,int *G){
if(n == 1){
G[0] = qpow(F[0],mod - 2); // base case: inverse of the constant term (Fermat's little theorem)
return;
}
get_Inv((n + 1) >> 1,F,G); // G <- F^{-1} mod x^{ceil(n/2)}
init(n << 1); // N > 2n, large enough for the product G * G * F without wrap-around
for(int i = 0; i < n; i ++) T[i] = F[i]; // only the first n terms of F matter mod x^n
NTT(T,N,1),NTT(G,N,1);
for(int i = 0; i < N; i ++) G[i] = 1ll * (2 - 1ll * G[i] * T[i] % mod + mod) % mod * G[i] % mod; // pointwise G * (2 - G * T)
NTT(G,N,-1);
for(int i = n; i < N; i ++) G[i] = 0; // truncate to mod x^n
for(int i = 0; i < N; i ++) T[i] = 0; // leave scratch clean for the next call
}
// Polynomial remainder in place: F <- F mod G, where deg F = n, deg G = m and n >= m
// (G is expected to be monic so rev(G) is invertible).
// Uses the reversal trick: rev(quotient) = rev(F) * rev(G)^{-1} mod x^{n-m+1}; then
// rev(F - quotient * G) = rev(F) - rev(quotient) * rev(G). On return F[0..m-1] holds the
// remainder and F[m..n] is cleared. The global scratch arrays Q and R are cleared.
void Mod(int *F,int *G,int n,int m){
    // R <- rev(G) truncated to n - m + 1 terms. rev(G) has only m + 1 coefficients, so when
    // n - m > m the higher terms are zero; the guard avoids reading G at a negative index.
    for(int i = 0; i <= n - m; i ++) R[i] = (i <= m ? G[m - i] : 0);
    for(int i = n - m + 1; i <= m; i ++) R[i] = 0;
    get_Inv(n - m + 1,R,Q);                       // Q <- rev(G)^{-1} mod x^{n-m+1}
    for(int i = 0; i <= n - m; i ++) R[i] = F[n - i];
    for(int i = n - m + 1; i <= n; i ++) R[i] = 0;
    Mul(Q,R,n - m);                               // Q <- rev(quotient), degree <= n - m
    for(int i = 0; i <= m; i ++) R[i] = G[m - i];
    for(int i = m + 1; i < N; i ++) R[i] = 0;
    Mul(Q,R,max(m,n - m),1);                      // Q <- rev(quotient * G), degree <= n
    for(int i = 0; i <= n; i ++) Q[i] = (F[n - i] - Q[i] + mod) % mod;   // rev(F - quotient * G)
    for(int i = 0; i < m; i ++) F[i] = Q[n - i];
    for(int i = m; i <= n; i ++) F[i] = 0;
    for(int i = 0; i < N; i ++) R[i] = 0,Q[i] = 0;
}
// Builds the product tree over a[l..r]: P[x] receives the coefficients (lowest degree first)
// of prod_{i=l}^{r} (x - a[i]), a monic polynomial of degree r - l + 1. Children are 2x and 2x+1.
// T1/T2 are used as scratch and cleared afterwards.
void pre(int x,int l,int r){
    if(l < r){
        const int mid = (l + r) / 2,lc = x << 1,rc = x << 1 | 1;
        pre(lc,l,mid);
        pre(rc,mid + 1,r);
        const int dl = mid - l + 1,dr = r - mid;   // degrees of the two child products
        for(int i = 0; i <= dl; i ++) T1[i] = P[lc][i];
        for(int i = 0; i <= dr; i ++) T2[i] = P[rc][i];
        Mul(T1,T2,dl,1);
        P[x].insert(P[x].end(),T1,T1 + dl + dr + 1);
        for(int i = 0; i < N; i ++) T1[i] = T2[i] = 0;
        return;
    }
    P[x].push_back(mod - a[l]);   // constant term -a[l], stored as mod - a[l]
    P[x].push_back(1);
}
void solve(int x,int l,int r,int *F){
if(l == r){
printf("%d\n",F[0]);
return;
}
int mid = (l + r) / 2,T[r - l + 10],ls = x << 1,rs = x << 1 | 1;
for(int i = 0; i <= r - l; i ++) T[i] = F[i];
for(int i = 0; i <= mid - l + 1; i ++) T1[i] = P[ls][i];
Mod(T,T1,r - l,mid - l + 1);
for(int i = 0; i <= mid - l + 1; i ++) T1[i] = 0;
solve(ls,l,mid,T);
for(int i = 0; i <= r - l; i ++) T[i] = F[i];
for(int i = 0; i <= r - mid; i ++) T1[i] = P[rs][i];
Mod(T,T1,r - l,r - mid);
for(int i = 0; i <= r - mid; i ++) T1[i] = 0;
solve(rs,mid + 1,r,T);
}
// Reads n, m, the n + 1 coefficients of f and the m points a[1..m], then prints f(a[i])
// for each point, one per line.
int main(){
    n = read();
    m = read();
    if(m == 0) return 0;
    for(int i = 0; i <= n; i ++) F[i] = read();
    for(int i = 1; i <= m; i ++) a[i] = read();
    pre(1,1,m);   // product tree: P[x] = prod (x - a[i]) over each node's range
    if(m <= n){
        // Reduce f modulo prod_{i=1}^{m} (x - a[i]) so its degree drops below m.
        for(int i = 0; i <= m; i ++) T1[i] = P[1][i];
        Mod(F,T1,n,m);
        for(int i = 0; i <= m; i ++) T1[i] = 0;
    }
    solve(1,1,m,F);
    return 0;
}
多项式快速插值
由拉格朗日插值公式,有
$$f(x) = \sum_{i=1}^n\bigg(y_i\prod_{j\ne i}\frac{x - x_j}{x_i - x_j}\bigg) = \sum_{i=1}^n\bigg(\frac{y_i}{\prod_{j\ne i}(x_i - x_j)}\prod_{j\ne i}(x - x_j)\bigg)$$
首先考虑式子中的系数 $\frac{y_i}{\prod_{j\ne i}(x_i - x_j)}$ 如何快速计算,也就是对于每个 $i \in [1,n]$,要求出 $\prod_{j\ne i}(x_i - x_j)$ 的值:
首先设
$$g(x) = \prod_{i=1}^n(x - x_i)$$
则有 $\prod_{j\ne i}(x_i - x_j) = \lim_{x \to x_i}\frac{g(x)}{x - x_i}$(直接代入 $x = x_i$ 会得到 $\frac{0}{0}$,所以要取极限)
根据洛必达法则,易得 $\lim_{x \to x_i}\frac{g(x)}{x - x_i} = g'(x_i)$
那么问题就变成求 $g'(x)$ 在 $x_1, x_2, \cdots, x_n$ 处的值,直接套上面的多点求值即可
($g(x)$ 与多点求值中要预处理的式子是一样的,所以可以先算出来,多点求值时就不用再做一遍)
下面回到正题,考虑如何计算 f ( x ) f(x) f(x)
用分治的思想,设 $f_{l,r}(x) = \sum_{i=l}^r\bigg(\frac{y_i}{\prod_{j\ne i}(x_i - x_j)}\prod_{l \leqslant j \leqslant r,\, j\ne i}(x - x_j)\bigg)$
这里要注意 $f_{l,r}(x)$ 并不代表第 $l$ 到 $r$ 个点所确定的 $r - l$ 次函数,因为系数的分母中 $j$ 的取值仍然是 $1$ 到 $n$,所以这只是为了分治而构造的函数,没有实际意义
则有
$$
\begin{aligned}
f_{l,r}(x) &= \sum_{i=l}^r\bigg(\frac{y_i}{g'(x_i)}\prod_{l \leqslant j \leqslant r,\, j\ne i}(x - x_j)\bigg) \\
&= \sum_{i=l}^{mid}\bigg(\frac{y_i}{g'(x_i)}\prod_{l \leqslant j \leqslant r,\, j\ne i}(x - x_j)\bigg) + \sum_{i=mid+1}^r\bigg(\frac{y_i}{g'(x_i)}\prod_{l \leqslant j \leqslant r,\, j\ne i}(x - x_j)\bigg) \\
&= \bigg(\prod_{i=mid+1}^r(x - x_i)\bigg)\bigg(\sum_{i=l}^{mid}\frac{y_i}{g'(x_i)}\prod_{l \leqslant j \leqslant mid,\, j\ne i}(x - x_j)\bigg) \\
&\quad + \bigg(\prod_{i=l}^{mid}(x - x_i)\bigg)\bigg(\sum_{i=mid+1}^r\frac{y_i}{g'(x_i)}\prod_{mid+1 \leqslant j \leqslant r,\, j\ne i}(x - x_j)\bigg) \\
&= \bigg(\prod_{i=mid+1}^r(x - x_i)\bigg)f_{l,mid}(x) + \bigg(\prod_{i=l}^{mid}(x - x_i)\bigg)f_{mid+1,r}(x)
\end{aligned}
$$
可以直接分治计算,边界为 $f_{i,i}(x) = \frac{y_i}{g'(x_i)}$,由主定理得复杂度为 $O(n\log^2 n)$,常数比上面更大,而且代码更长…
非常巧合的是(也可能是发明人故意构造的),分治中的"系数多项式" $\prod_{i=l}^r(x - x_i)$ 也是在多点求值的预处理(也即分治计算 $g(x)$)时计算过的,不需要另外计算
细节看代码
//为了卡常,有些地方写得有点奇怪,两个最主要的优化用注释标出来了
#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
// Sizes and NTT parameters.
//   maxn: capacity for the points x[1..n], y[1..n] and the coefficient arrays F, G, H (assumes n <= 1e5 -- TODO confirm).
//   maxm: capacity of the NTT buffers and twiddle tables; main() calls init(n << 1), so with n <= 1e5
//         the table length N is 2^18 = 262144 <= maxm.
//   mod = 998244353 is NTT-friendly and g = 3 is a primitive root modulo it.
const int maxn = 1e5 + 50,maxm = 2.63e5,mod = 998244353,g = 3;
// n: number of points (x[i], y[i]), i = 1..n; the global m appears unused (functions take their own m).
// F: output coefficients of the interpolating polynomial; G: coefficients of g'(x); H[i] = g'(x_i).
// rev: bit-reversal permutation; T: scratch for get_Inv(); T1/T2: scratch for pre()/solve(); Q/R: scratch for Mod().
// W[k][j] / inv_W[k][j]: j-th power of g^((mod-1)/2^k) and of its inverse; inv[len]: 1/len for powers of two.
int n,m,N,x[maxn],y[maxn],F[maxn],G[maxn],H[maxn],rev[maxm],T[maxm],T1[maxm],T2[maxm],Q[maxm],R[maxm],W[20][maxm],inv_W[20][maxm],inv[maxm];
// P[x]: coefficients (lowest degree first) of prod (x - x_i) over the range of segment-tree node x; built by pre().
vector <int> P[4 * maxn];
// Reads the next non-negative decimal integer from stdin, skipping any non-digit characters.
// Stops (returning 0) if end of input is reached before a digit instead of spinning forever;
// `c` is an int so it can hold and compare against EOF reliably.
int read(){
    int x = 0;
    int c = getchar();
    while(c != EOF && (c < '0' || c > '9')) c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar();
    return x;
}
// Modular addition for operands already in [0, mod). Replacing `%` with a conditional
// subtraction of the modulus in additions/subtractions was reported by the author to save
// at least 800 ms per test case.
inline int add(int x,int y){
    const int s = x + y;
    return s >= mod ? s - mod : s;
}
// Modular subtraction for operands in [0, mod): returns x - y wrapped back into [0, mod).
inline int dec(int x,int y){
    const int d = x - y;
    return d < 0 ? d + mod : d;
}
// Binary exponentiation: returns x^k modulo `mod` (k >= 0).
int qpow(int x,int k){
    long long result = 1,base = x;
    while(k > 0){
        if(k & 1) result = result * base % mod;
        base = base * base % mod;
        k >>= 1;
    }
    return result;
}
// Sets N to the smallest power of two strictly greater than n and fills rev[0..N-1]
// with the bit-reversal permutation on log2(N) bits (used by NTT()).
void init(int n){
    N = 1;
    while(N <= n) N <<= 1;
    // The top bit of an index is N >> 1 (== 1 << (log2(N) - 1) when N > 1). Using N >> 1
    // avoids an undefined shift by -1 when n == 0 (N == 1).
    for(int i = 0; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (N >> 1) : 0);
}
// In-place iterative number-theoretic transform of F[0..n-1]; n must be a power of two with
// rev[] prepared by init(). p == 1: forward (roots W); otherwise inverse roots inv_W, and
// p == -1 additionally scales by 1/n.
// Twiddle factors (W / inv_W) and 1/n (inv[]) are precomputed in main(); recomputing them
// here with qpow would cost almost an extra log factor.
void NTT(int *F,int n,int p){
    for(int i = 0; i < n; i ++) if(i < rev[i]) swap(F[i],F[rev[i]]);
    int level = 0;
    for(int half = 1; half < n; half <<= 1){
        ++ level;
        const int *tw = (p == 1 ? W[level] : inv_W[level]);   // powers of the 2^level-th root
        for(int base = 0; base < n; base += half << 1){
            for(int k = 0; k < half; k ++){
                const int u = F[base + k],v = 1ll * tw[k] * F[base + k + half] % mod;
                F[base + k] = add(u,v);
                F[base + k + half] = dec(u,v);
            }
        }
    }
    if(p == -1) for(int i = 0; i < n; i ++) F[i] = 1ll * F[i] * inv[n] % mod;
}
// Polynomial product in place: F <- F * G via NTT. The transform length is the smallest
// power of two > 2n, so both operands must be zero-padded up to it. G is overwritten
// (left in transformed form). When p == 0 the product is truncated to degree n.
void Mul(int *F,int *G,int n,int p = 0){
    init(n << 1);
    NTT(F,N,1);
    NTT(G,N,1);
    for(int *pf = F,*pg = G; pf != F + N; ++ pf,++ pg) *pf = 1ll * *pf * *pg % mod;
    NTT(F,N,-1);
    if(p == 0){
        for(int i = n + 1; i < N; i ++) F[i] = 0;
    }
}
// Power-series inverse: G <- F^{-1} mod x^n, by Newton iteration G_new = G * (2 - F * G).
// Assumes F[0] is invertible mod `mod`, and that G is zero-filled up to the transform length
// (NTT() reads G[0..N-1]) -- TODO confirm at call sites. On return G[0..n-1] holds the inverse,
// G[n..N-1] is cleared, and the global scratch T is cleared.
void get_Inv(int n,int *F,int *G){
if(n == 1){
G[0] = qpow(F[0],mod - 2); // base case: inverse of the constant term (Fermat's little theorem)
return;
}
get_Inv((n + 1) >> 1,F,G); // G <- F^{-1} mod x^{ceil(n/2)}
init(n << 1); // N > 2n, large enough for the product G * G * F without wrap-around
for(int i = 0; i < n; i ++) T[i] = F[i]; // only the first n terms of F matter mod x^n
NTT(T,N,1),NTT(G,N,1);
for(int i = 0; i < N; i ++) G[i] = 1ll * dec(2,1ll * G[i] * T[i] % mod) * G[i] % mod; // pointwise G * (2 - G * T)
NTT(G,N,-1);
for(int i = n; i < N; i ++) G[i] = 0; // truncate to mod x^n
for(int i = 0; i < N; i ++) T[i] = 0; // leave scratch clean for the next call
}
// Polynomial remainder: H <- F mod G, where deg F = n, deg G = m and n >= m
// (G is expected to be monic so rev(G) is invertible). H may alias F.
// G is taken by const reference: passing the product-tree vector by value copied it on every call.
// Uses the reversal trick: rev(quotient) = rev(F) * rev(G)^{-1} mod x^{n-m+1}, and the remainder is
// F - quotient * G. On return H[0..m-1] holds the remainder and H[m..n] is cleared; scratch Q and R are cleared.
void Mod(int *F,const vector <int> &G,int n,int m,int *H){
    // R <- rev(G) truncated to n - m + 1 terms. rev(G) has only m + 1 coefficients, so when
    // n - m > m the higher terms are zero; the guard avoids indexing G out of range.
    for(int i = 0; i <= n - m; i ++) R[i] = (i <= m ? G[m - i] : 0);
    for(int i = n - m + 1; i <= m; i ++) R[i] = 0;
    get_Inv(n - m + 1,R,Q);                       // Q <- rev(G)^{-1} mod x^{n-m+1}
    for(int i = 0; i <= n - m; i ++) R[i] = F[n - i];
    for(int i = n - m + 1; i <= n; i ++) R[i] = 0;   // make the zero padding explicit
    Mul(Q,R,n - m);                               // Q <- rev(quotient), degree <= n - m
    for(int i = 0; i <= m; i ++) R[i] = G[m - i];
    for(int i = m + 1; i < N; i ++) R[i] = 0;
    Mul(Q,R,max(m,n - m),1);                      // Q <- rev(quotient * G), degree <= n
    for(int i = 0; i < m; i ++) H[i] = dec(F[i],Q[n - i]);   // remainder = F - quotient * G
    for(int i = m; i <= n; i ++) H[i] = 0;
    for(int i = 0; i < N; i ++) R[i] = 0,Q[i] = 0;
}
// Builds the product tree over a[l..r]: P[x] receives the coefficients (lowest degree first)
// of prod_{i=l}^{r} (x - a[i]), a monic polynomial of degree r - l + 1. Children are 2x and 2x+1.
// T1/T2 are used as scratch and cleared afterwards.
void pre(int x,int l,int r,int *a){
    if(l < r){
        const int mid = (l + r) / 2,lc = x << 1,rc = x << 1 | 1;
        pre(lc,l,mid,a);
        pre(rc,mid + 1,r,a);
        const int dl = mid - l + 1,dr = r - mid;   // degrees of the two child products
        for(int i = 0; i <= dl; i ++) T1[i] = P[lc][i];
        for(int i = 0; i <= dr; i ++) T2[i] = P[rc][i];
        Mul(T1,T2,dl,1);
        P[x].insert(P[x].end(),T1,T1 + dl + dr + 1);
        for(int i = 0; i < N; i ++) T1[i] = T2[i] = 0;
        return;
    }
    P[x].push_back(mod - a[l]);   // constant term -a[l], stored as mod - a[l]
    P[x].push_back(1);
}
// Multipoint-evaluation recursion: writes G[i] = F(a[i]) for i in [l, r], where F has degree <= r - l.
// Small ranges (r - l <= 256) are evaluated directly with Horner's rule. Larger ranges reduce F
// modulo the product over each half (P[ls] / P[rs]) and recurse.
// The remainder buffer is a std::vector rather than a variable-length array (VLAs are a
// non-standard extension in C++ and put O(r - l) ints on the stack per frame).
void calc(int x,int l,int r,int *a,int *F,int *G){
    if(r - l <= 256){
        for(int i = l; i <= r; i ++){
            int s = 0;
            for(int j = r - l; j >= 0; j --) s = add(1ll * s * a[i] % mod,F[j]);   // Horner step
            G[i] = s;
        }
        return;
    }
    const int mid = (l + r) / 2,ls = x << 1,rs = x << 1 | 1;
    vector<int> rem(r - l + 10);
    Mod(F,P[ls],r - l,mid - l + 1,rem.data());
    calc(ls,l,mid,a,rem.data(),G);
    Mod(F,P[rs],r - l,r - mid,rem.data());
    calc(rs,mid + 1,r,a,rem.data(),G);
}
// Multipoint evaluation: G[i] = F(a[i]) for i = 1..m, where F has degree n.
// Requires the product tree P[] over a[1..m] to be built beforehand (main() calls pre()).
// When n >= m, F is first reduced in place modulo P[1] = prod (x - a[i]), so F is modified.
void Eva(int *F,int *a,int n,int m,int *G){
    if(n < m){
        calc(1,1,m,a,F,G);
        return;
    }
    Mod(F,P[1],n,m,F);
    calc(1,1,m,a,F,G);
}
// Interpolation combine step: writes F = f_{l,r}(x) (coefficients lowest degree first), where
// f_{l,l} = y_l / g'(x_l) (H[l] holds g'(x_l)) and
// f_{l,r} = P[rs] * f_{l,mid} + P[ls] * f_{mid+1,r}, with P[] the product tree built by pre().
// The child buffers are std::vectors (zero-initialized) rather than variable-length arrays:
// VLAs are non-standard in C++, and two of them per frame put O(r - l) ints on the stack.
// Their size leaves room for the NTTs that Mul() runs on them (length < 2 * (r - l + 10)).
void solve(int x,int l,int r,int *F){
    if(l == r){
        F[0] = 1ll * y[l] * qpow(H[l],mod - 2) % mod;
        return;
    }
    const int mid = (l + r) / 2,ls = x << 1,rs = x << 1 | 1;
    vector<int> Fl(2 * (r - l + 10)),Fr(2 * (r - l + 10));
    solve(ls,l,mid,Fl.data());
    solve(rs,mid + 1,r,Fr.data());
    for(int i = r - mid; i >= 0; i --) T1[i] = P[rs][i];
    for(int i = mid - l + 1; i >= 0; i --) T2[i] = P[ls][i];
    Mul(T1,Fl.data(),mid - l + 1,1);   // P[rs] * f_{l,mid}
    Mul(T2,Fr.data(),mid - l + 1,1);   // P[ls] * f_{mid+1,r}
    for(int i = r - l; i >= 0; i --) F[i] = add(T1[i],T2[i]);
    for(int i = 0; i < N; i ++) T1[i] = T2[i] = 0;
}
// Reads n points (x_i, y_i) and prints the n coefficients (lowest degree first) of the
// unique polynomial of degree < n through them, via Lagrange interpolation:
// f(x) = sum_i y_i / g'(x_i) * prod_{j != i} (x - x_j), with g(x) = prod_i (x - x_i).
int main(){
    n = read();
    init(n << 1);   // largest transform length used; the twiddle tables below are sized by it
    for(int i = 0; i <= 19; i ++){
        // W[i][j] = w^j and inv_W[i][j] = w^{-j}, where w = g^((mod-1)/2^i) is a primitive 2^i-th root of unity
        W[i][0] = inv_W[i][0] = 1;
        W[i][1] = qpow(g,(mod - 1) / (1 << i)),inv_W[i][1] = qpow(W[i][1],mod - 2);
        for(int j = 2; j < N; j ++) W[i][j] = 1ll * W[i][j - 1] * W[i][1] % mod,inv_W[i][j] = 1ll * inv_W[i][j - 1] * inv_W[i][1] % mod;
    }
    for(int i = 1; i <= N; i <<= 1) inv[i] = qpow(i,mod - 2);   // 1/len for every power-of-two length
    for(int i = 1; i <= n; i ++) x[i] = read(),y[i] = read();
    pre(1,1,n,x);   // P[1] = g(x) = prod (x - x_i)
    // G <- g'(x)
    for(int i = 0; i <= n; i ++) G[i] = P[1][i];
    for(int i = 0; i < n; i ++) G[i] = 1ll * G[i + 1] * (i + 1) % mod;
    G[n] = 0;
    // H[i] = g'(x_i) = prod_{j != i} (x_i - x_j). The polynomial to evaluate is G (= g'), not F:
    // F is still all zero here (it is the output buffer filled by solve()), so evaluating F made
    // every H[i] zero and the program printed only zeros.
    Eva(G,x,n - 1,n,H);
    solve(1,1,n,F);
    for(int i = 0; i < n; i ++) printf("%d ",F[i]);
    printf("\n");
    return 0;
}