多项式快速幂
给出一个多项式 F ( x ) F(x) F(x) 以及常数 k k k,求出 G ( x ) ≡ F k ( x ) ( m o d x n ) G(x)\equiv F^k(x)\pmod {x^n} G(x)≡Fk(x)(modxn)。
之前有学到 ln \ln ln 和 exp \exp exp,就可以用来处理这样的指数问题。
两边同时取 ln \ln ln 得到: ln G ( x ) ≡ ln ( F ( x ) ) k \ln G(x)\equiv\ln(F(x))^k lnG(x)≡ln(F(x))k,根据对数函数的性质,得到 ln G ( x ) ≡ k ln F ( x ) \ln G(x)\equiv k\ln F(x) lnG(x)≡klnF(x)。
然后两边再同时取 exp \exp exp,得到 G ( x ) ≡ e k ln F ( x ) G(x)\equiv e^{k\ln F(x)} G(x)≡eklnF(x),接下来拉一拉 ln \ln ln 和 exp \exp exp 的板子就好了。
发现上面有这样一条柿子: ln G ( x ) ≡ k ln F ( x ) \ln G(x)\equiv k\ln F(x) lnG(x)≡klnF(x),而我们的数据运算都是在模 998244353 998244353 998244353 下进行的,也就是说,让 k k k 对 998244353 998244353 998244353 取模不影响答案,那么就不需要高精度什么的了。
对于洛谷上的加强版,不保证 a 0 = 1 a_0=1 a0=1,是做不了多项式 ln \ln ln 的,要分类讨论一下:
- a 0 > 1 a_0>1 a0>1,那么设 p = a 0 p=a_0 p=a0,有 F k ( x ) ≡ ( F ( x ) × i n v ( p ) ) k × p k F^k(x)\equiv (F(x)\times inv(p))^k\times p^k Fk(x)≡(F(x)×inv(p))k×pk,对于 F ( x ) × i n v ( p ) F(x)\times inv(p) F(x)×inv(p),它的 0 0 0 次项一定是 1 1 1,就可以做 ln \ln ln 了,求完之后再乘 p k p^k pk 即可。
- a 0 = 0 a_0=0 a0=0,那么我们向右找到第一个系数不为 0 0 0 的项,假设是 x z x^z xz,那么让 F ( x ) F(x) F(x) 除以 x z x^z xz 即系数左移 z z z 位,然后像上面那样搞就好,搞完之后再乘 ( x z ) k (x^z)^k (xz)k 即可。
加强版代码如下(由于用了大量vector,需要开O2才能过):
#include <cstdio>
#include <cstring>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;
#define maxn 800010
#define mod 998244353
#define bin(x) (1<<(x))
int n,k,k_,length;char s[maxn];
void read(int &x)
{
x=0;char ch=getchar();while(ch<'0'||ch>'9')ch=getchar();
while(ch>='0'&&ch<='9'){x=(10ll*x+ch-'0')%mod;ch=getchar();}
}
int inv[maxn];
int ksm(int x,int y){int re=1;for(;(y&1?re=1ll*re*x%mod:0),y;y>>=1,x=1ll*x*x%mod);return re;}
#define INV(x) ksm(x,mod-2)
struct NTT{
vector<int> w[30];NTT(){
inv[1]=1;for(int i=2;i<=maxn-10;i++)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
for(int i=1,wn;i<=19;i++){
w[i].resize(bin(i));w[i][0]=1;wn=ksm(3,(mod-1)/bin(i));
for(int j=1;j<bin(i-1);j++)w[i][j]=1ll*w[i][j-1]*wn%mod;
}
}
int limit,r[maxn];void dft(int *f,int lg,int type=0)
{
limit=bin(lg);if(type)reverse(f+1,f+limit);
for(int i=1;i<limit;i++){r[i]=(r[i>>1]>>1)|((i&1)<<(lg-1));if(i<r[i])swap(f[i],f[r[i]]);}
for(int mid=1,Lg=1;mid<limit;mid<<=1,Lg++)for(int j=0;j<limit;j+=(mid<<1))for(int i=0;i<mid;i++)
{int t=1ll*f[j+i+mid]*w[Lg][i]%mod;f[j+i+mid]=(f[j+i]-t+mod)%mod;f[j+i]=(f[j+i]+t)%mod;}
}
}ntt;
int A[maxn],B[maxn],C[maxn],M;
struct POLY{
vector<int> a;int len;void rs(int N){a.resize(len=N);}POLY(){rs(M);};
int &operator [](int x){return a[x];}
friend POLY operator *(POLY A_,int x){for(int i=0;i<A_.len;i++)A_[i]=1ll*A_[i]*x%mod;return A_;}
void dft(int *A_,int lg,int ln){for(int i=0;i<bin(lg);i++)A_[i]=(i<min(len,ln)?a[i]:0);ntt.dft(A_,lg);}
void idft(int *A_,int lg,int ln){ntt.dft(A_,lg,1);rs(ln);for(int i=0;i<ln;i++)a[i]=1ll*A_[i]*inv[bin(lg)]%mod;}
POLY Mul(POLY b,int ln=M){
int lg=ceil(log2(2*ln-1));dft(A,lg,ln);b.dft(B,lg,ln);
for(int i=0;i<bin(lg);i++)B[i]=1ll*A[i]*B[i]%mod;b.idft(B,lg,ln);return b;
}
}F,G;
void getinv(POLY &f,POLY &g,int ln=M)
{
if(ln==1){g.rs(1);g[0]=INV(f[0]);return;}getinv(f,g,(ln+1)>>1);
int lg=ceil(log2(ln*2-1));f.dft(A,lg,ln);g.dft(B,lg,ln);
for(int i=0;i<bin(lg);i++)B[i]=1ll*(2-1ll*A[i]*B[i]%mod+mod)%mod*B[i]%mod;g.idft(B,lg,ln);
}
POLY getinv(POLY &f,int ln=M){POLY g;getinv(f,g,ln);return g;};
POLY Jifen(POLY f){f.rs(f.len+1);for(int i=f.len-1;i>0;i--)f[i]=1ll*f[i-1]*inv[i]%mod;f[0]=0;return f;}
POLY Dao(POLY f){for(int i=0;i<f.len-1;i++)f[i]=1ll*f[i+1]*(i+1)%mod;f.rs(f.len-1);return f;}
POLY getln(POLY f,int ln=M){return Jifen(Dao(f).Mul(getinv(f,ln),ln-1));}
void getexp(POLY &f,POLY &g,int ln=M)
{
if(ln==1){g.rs(1);g[0]=1;return;}getexp(f,g,(ln+1)>>1);
POLY p=getln(g,ln);for(int i=0;i<ln;i++)p[i]=(f[i]-p[i]+mod)%mod;p[0]++;g=p.Mul(g,ln);
}
POLY getexp(POLY f,int ln){POLY g;getexp(f,g,ln);return g;}
POLY getksm(POLY f,int k,int k_,int ln=M)
{
int st=0;while(!f[st]&&st<ln)st++;for(int i=0;i<ln;i++)f[i]=(i+st<ln?f[i+st]:0);
int f_0=f[0];f_0=ksm(f_0,k_);f=getexp(getln(f*INV(f[0]),ln)*k,ln)*f_0;
if(st)if(length<=9&&1ll*st*k<ln)for(int i=n-1;i>=0;i--)f[i]=(i>=st*k?f[i-st*k]:0);else f.rs(0),f.rs(ln); return f;
}
int main()
{
scanf("%d %s",&n,s);M=n;F.rs(n);
for(int i=0;i<n;i++)scanf("%d",&F[i]);
length=strlen(s);for(int i=0;i<length;i++)k=(10ll*k+s[i]-'0')%mod,k_=(10ll*k_+s[i]-'0')%(mod-1);
G=getksm(F,k,k_,n);for(int i=0;i<n;i++)printf("%d ",G[i]);
}