UOJ#86：mx的组合数 （Lucas定理+原根+NTT+高精度）

Lucas's theorem (p prime, $C_m^n=\binom{m}{n}$): $C_m^n \equiv C_{\lfloor m/p \rfloor}^{\lfloor n/p \rfloor} \cdot C_{m \bmod p}^{n \bmod p} \pmod{p}$

CODE：

#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;

typedef long long LL;

// Capacities of the global tables.
const int maxn=1000000;   // per-residue tables and NTT buffers
const int maxl=50;        // Big_Int slots: num[0] is the length, num[1..49] the digits
// NTT modulus 998244353 = 119 * 2^23 + 1 with primitive root 3;
// every reported count is taken modulo M.
const LL M=998244353;
const LL g=3;

LL A[maxn],B[maxn];   // convolution operands of NTT(); A also carries Solve()'s result
int Rev[maxn];        // bit-reversal permutation of [0, N)
int N,Lg;             // NTT length (a power of two) and log2(N)

// Non-negative decimal big integer.
// num[0] holds the digit count and num[1..num[0]] the digits, least
// significant first; the value zero has num[0] == 0.
struct Big_Int
{
    LL num[maxl];

    // Trim leading (most significant) zero digits so num[0] is the true length.
    void Down()
    {
        while (num[0]!=0 && num[num[0]]==0)
            --num[0];
    }
} n,n1,l,r;
char s[maxl];   // scratch buffer holding one decimal token while it is parsed

// P is the (quotient, remainder) pair returned by Div(); MP builds one.
#define P pair<Big_Int,long long>
#define MP(x,y) make_pair(x,y)

LL fac[maxn],nfac[maxn];   // fac[i] = i! mod p, nfac[i] = (i!)^(-1) mod p, for 0 <= i < p
int id[maxn],rid[maxn];    // discrete-log tables w.r.t. pg: id[pg^i mod p] = i, rid[i] = pg^i mod p
LL ans[maxn];              // ans[x]: count of k in [l, r] with C(k, n) mod p == x, reduced mod M
LL p,pg;                   // the prime modulus and a primitive root modulo p

// Modular exponentiation: returns x^y mod z for y >= 0 (exactly 1 when y == 0).
// Left-to-right square-and-multiply over the bits of y, starting at the most
// significant set bit. Intermediate products are acc*acc and acc*x with
// acc < z, so x and z must stay below about 3e9 to avoid overflow.
LL Pow(LL x,LL y,LL z)
{
    if (y==0) return 1LL;

    int bit=62;                          // y >= 0, so the sign bit is clear
    while (((y>>bit)&1LL)==0) --bit;     // locate the highest set bit of y

    LL acc=1;
    for (; bit>=0; --bit)
    {
        acc=acc*acc%z;
        if ((y>>bit)&1LL) acc=acc*x%z;
    }
    return acc;
}

// Returns the smallest primitive root modulo the prime x: the smallest y whose
// powers y^0 .. y^(x-2) cover every non-zero residue mod x.
// A candidate y >= 2 is accepted iff y^d != 1 (mod x) for every proper divisor
// d > 1 of x-1 (d == 1 needs no check because y >= 2).
// For x == 2 the multiplicative group is {1}, generated by 1, so 1 is returned.
// Returns -1 only if no candidate qualifies, which cannot happen for prime x.
LL Get(LL x)
{
    const LL mod=x;      // reduce modulo the argument itself, not the global p
    const LL order=x-1;  // size of the multiplicative group modulo a prime

    // The candidate loop below is empty for x == 2; previously control fell
    // off the end of this non-void function (undefined behavior).
    if (order<=1) return 1LL;

    for (LL y=2; y<=order; y++)
    {
        bool primitive=true;
        // Each proper divisor d of `order` pairs with order/d, and one of the two
        // is <= sqrt(order), so scanning d*d <= order covers all of them.
        for (LL d=2; primitive && d*d<=order; d++)
        {
            if (order%d!=0) continue;
            if (Pow(y,d,mod)==1LL || Pow(y,order/d,mod)==1LL) primitive=false;
        }
        if (primitive) return y;
    }
    return -1LL;   // unreachable for prime x
}

{
scanf("%s",s);
int len=strlen(s);
x.num[0]=len;
for (int i=0; i<len; i++) x.num[len-i]=s[i]-'0';
x.Down();
}

// Long division of the big integer x by the machine integer y (y > 0).
// Returns (floor(x / y), x mod y). Proceeds from the most significant digit,
// pushing each partial remainder (times 10) into the next lower digit.
P Div(Big_Int x,LL y)
{
    int pos=x.num[0];
    while (pos>=2)
    {
        const LL rem=x.num[pos]%y;
        x.num[pos-1]+=rem*10LL;   // carry the remainder one decimal place down
        x.num[pos]/=y;            // quotient digit at this position
        --pos;
    }
    const LL lastRem=x.num[1]%y;
    x.num[1]/=y;
    x.Down();                     // the quotient may have leading zeros
    return MP(x,lastRem);
}

// Converts the big integer x to a machine integer when x < p.
// Returns -1 as soon as the running value reaches p, so huge inputs never overflow.
LL Change(Big_Int x)
{
    LL value=0;
    int pos=x.num[0];
    while (pos>=1)
    {
        value=value*10+x.num[pos--];
        if (value>=p) return -1LL;
    }
    return value;
}

// Small binomial coefficient C(x, y) = x! / (y! (x-y)!) mod p for 0 <= y, x < p,
// read from the precomputed factorial tables; 0 when y > x.
LL CI(int x,int y)
{
    if (x<y) return 0;
    return fac[x]*nfac[y]%p*nfac[x-y]%p;
}

// Decrements the big integer x by one in place (x must be > 0). The borrow
// propagates through trailing zero digits; Down() trims a new leading zero.
void Dec(Big_Int &x)
{
    int pos=1;
    --x.num[pos];
    for (; x.num[pos]<0; ++pos)
    {
        x.num[pos]+=10;     // borrow ten from the next digit
        --x.num[pos+1];
    }
    x.Down();
}

// In-place iterative number-theoretic transform of a[0..N) modulo M.
// f == 1 is the forward transform; f == -1 uses inverse roots and leaves the
// result unscaled (the caller divides by N).
void DFT(LL *a,int f)
{
    // Bit-reversal permutation so the butterflies can run in place.
    for (int i=0; i<N; i++)
        if (i<Rev[i]) swap(a[i],a[Rev[i]]);

    for (int half=1; (half<<1)<=N; half<<=1)
    {
        const int len=half<<1;
        LL root=Pow(g,(M-1)/len,M);       // len-th root of unity modulo M
        if (f==-1) root=Pow(root,M-2,M);  // inverse transform: inverse root

        for (int start=0; start<N; start+=len)
        {
            LL w=1;
            for (int j=0; j<half; j++)
            {
                LL &lo=a[start+j];
                LL &hi=a[start+j+half];
                const LL t=w*hi%M;
                hi=(lo-t+M)%M;
                lo=(lo+t)%M;
                w=w*root%M;
            }
        }
    }
}

void NTT()
{
DFT(A,1);
DFT(B,1);
for (int i=0; i<N; i++) A[i]=A[i]*B[i]%M;
DFT(A,-1);

LL inv=Pow(N,M-2,M);
for (int i=0; i<N; i++) A[i]=A[i]*inv%M;
}

// Binomial coefficient "choose m from k" modulo p for big integers k, m,
// via Lucas's theorem: peel off the lowest base-p digit of both numbers until
// both fit below p, multiplying the per-digit small binomials together.
LL C(Big_Int k,Big_Int m)
{
    LL prod=1;
    while (true)
    {
        const LL smallK=Change(k);
        const LL smallM=Change(m);
        if (smallK>=0 && smallM>=0) return prod*CI(smallK,smallM)%p;

        P dk=Div(k,p);
        P dm=Div(m,p);
        prod=prod*CI(dk.second,dm.second)%p;
        k=dk.first;
        m=dm.first;
    }
}

// Discrete-log distribution of C(k', m) mod p over k' in [0, k].
//
// k and m are big integers. On return, for every e in [0, p-2]:
//   A[e] = number of k' in [0, k] with C(k', m) mod p == pg^e mod p  (mod M),
// where C(k', m) means "choose m from k'". Values k' with C(k', m) mod p == 0
// are not recorded in A[]; main() recovers them from the total count.
//
// Precondition: A[] must be all zero on entry. When k < p but m >= p, this
// function returns without clearing or writing A[] (every C(k', m) is 0 then).
// Side effects: overwrites B[], which NTT() then transforms in place.
void Solve(Big_Int k,Big_Int m)
{
LL ck=Change(k);
LL cm=Change(m);
// Base case: k and m are both below p, so count C(i, m) directly for i in [m, k]
// (C(i, m) is 0 for i < m and non-zero mod the prime p otherwise).
if ( ck>=0 && cm>=0 )
{
for (int i=0; i<N; i++) A[i]=0;
// Clearing here alone is not sufficient: the k < p, m >= p path below returns
// without reaching this branch, so callers must still zero A[] themselves.
for (int i=cm; i<=ck; i++) A[ id[ CI(i,cm) ] ]++;
return;
}

// Split off the lowest base-p digit: k = x.first*p + x.second, m = y.first*p + y.second.
P x=Div(k,p);
P y=Div(m,p);
if (x.first.num[0])
{
// Full blocks: k' = q*p + r with q in [0, k/p - 1] and r in [0, p-1].
// By Lucas, C(k', m) == C(q, m/p) * C(r, m%p) (mod p). Multiplying residues adds
// their discrete logs, so the result is the (mod p-1) convolution of the
// q-distribution (recursive call, left in A) and the r-distribution (B).
Big_Int z=x.first;
Dec(z);
Solve(z,y.first);
for (int i=0; i<N; i++) B[i]=0;
for (int i=y.second; i<p; i++) B[ id[ CI(i,y.second) ] ]++;
NTT();
// NTT yields a linear convolution; exponents live in Z/(p-1), so fold
// indices >= p-1 back onto [0, p-2].
for (int i=p-1; i<N; i++) A[i%(p-1)]=(A[i%(p-1)]+A[i])%M,A[i]=0;
}

// Last (partial) block: q = k/p and r in [m%p, k%p], with C(q, m/p) = v fixed.
LL v=C(x.first,y.first);
if (v) for (int i=y.second; i<=x.second; i++)
{
LL &q=A[ id[ CI(i,y.second)*v%p ] ];
q=(q+1LL)%M;
}
}

int main()
{
//freopen("86.in","r",stdin);
//freopen("86.out","w",stdout);

scanf("%lld",&p);

pg=Get(p);
LL v=1;
for (int i=0; i<p-1; i++)
{
id[v]=i;
rid[i]=v;
v=v*pg%p;
}

fac[0]=1;
for (LL i=1; i<p; i++) fac[i]=fac[i-1]*i%p;
for (int i=0; i<p; i++) nfac[i]=Pow(fac[i],p-2LL,p);

N=1,Lg=0;
while (N<=2*p+2) N<<=1,Lg++;
for (int i=0; i<N; i++)
for (int j=0; j<Lg; j++)
if (i&(1<<j)) Rev[i]|=(1<<(Lg-j-1));

n1=n;

P x=Div(r,M);
ans[0]=x.second;
ans[0]=(ans[0]+1LL)%M; //[0,r] have r+1 numbers!!!
Solve(r,n);
for (int i=0; i<p-1; i++) ans[ rid[i] ]=A[i];

if (l.num[0])
{
Dec(l);
x=Div(l,M);
ans[0]=(ans[0]-x.second+M)%M;
ans[0]=(ans[0]-1LL+M)%M;
for (int i=0; i<N; i++) A[i]=0; //clear A[] here!!!
Solve(l,n1);

for (int i=0; i<p-1; i++)
{
int v=rid[i];
ans[v]=(ans[v]-A[i]+M)%M;
}
}

for (int i=1; i<p; i++) ans[0]=(ans[0]-ans[i]+M)%M; //write outside the "if"!!!
for (int i=0; i<p; i++) printf("%lld\n",ans[i]);

return 0;
}

©️2019 CSDN 皮肤主题: 技术黑板 设计师: CSDN官方博客