题目链接:http://www.51nod.com/contest/problem.html#!problemId=1387
51nod的一道题目,高中生出的一道鬼畜的题目。外面套了一层模型,打表找规律就能发现本质就是让你求N!%P,其中P是一个费马素数。很明显了,就是让你分治NTT。然而这并不是一个很简单的算法(虽然在高中生似乎人人都会),学习了一天这篇:http://picks.logdown.com/posts/245545-binomial-coefficient-modulo-prime
终于学会了多项式多点求值(当时看gxx_slide就没看懂)
然后这题就是将N!分成
N−−√
块,每块就是
f(x)=∏i=1N√(x+i)
我们要求的就是
∏i=0N√−1f(i∗N−−√)
计算f(x)的系数我们就分治,两个子式计算完用ntt乘起来就可以。
后面涉及到多点求值,多点求值的思想很简单,我们发现在计算 f(xi) 的时候,把原来的多项式模去一个 (x−xi) 后对结果没有影响,因为把 xi 带入被模掉的式子里面算出来肯定是0.于是我们就可以这样分治的去计算:要计算
f(xi)(i=1,2,3...n)
另
g1(x)=f(x) mod ∏i=1n/2(x+i)
g2(x)=f(x) mod ∏i=n/2+1n(x+i)
就是说算前一半值我们用g1(x)去算,算后一半值用g2(x)去算,复杂度很好算,每次多项式的阶减半,要求的点的数目也减半,假设多项式的阶和要求的点是同阶的,那么就有
T(n)=2∗T(n/2)+O(nlogn)
,
T(n)=O(nlog2n)
然而由于多项式取模常数巨大,因此这道题需要在规模比较小的时候改用暴力才能通过。这道题我还是感觉很难写,实现也用了线段树实现,不知道有没有其他优秀的姿势。。。。
最上面的那几个函数是用来求原根的
#include<bits/stdc++.h>
using namespace std;
typedef long long Int;
#define ls l,mid,x<<1
#define rs mid+1,r,x<<1|1
const int Maxn=262146<<1;
int n,p,g,m,ans;
vector<int>pri;
bool isp[1000020];
vector<int>yinzi;
int a[Maxn],b[Maxn],tmp1[Maxn],tmp2[Maxn],tmp3[Maxn],tmp[Maxn];
int rep[Maxn];
int powmod(int x,int y)
{
int ret=1;
while(y){if(y&1)ret=ret*(Int)x%p;y>>=1;x=x*(Int)x%p;}
return ret;
}
void exgcd(int a,int b,int& x,int& y)
{
b?(exgcd(b,a%b,y,x),y-=a/b*x):(x=1,y=0);
}
int getinv(int a)
{
int x,y;
exgcd(a,p,x,y);
return x<0?x+p:x;
}
void getp()
{
for(int i=2;i<=100000;i++)
{
if(!isp[i])
{
pri.push_back(i);
for(int j=i+i;j<=100000;j+=i)
isp[j]=1;
}
}
int tp=p-1;
for(int i=0;pri[i]*(Int)pri[i]<=tp;i++)
{
if(tp%pri[i]==0)
{
yinzi.push_back(pri[i]);
while(tp%pri[i]==0)tp/=pri[i];
}
}
if(tp!=1)yinzi.push_back(tp);
}
bool check(int x)
{
if(powmod(x,(p-1)/2)==1)return 0;
for(int i=0;i<yinzi.size();i++)
{
if(powmod(x,(p-1)/yinzi[i])==1)return 0;
}
return 1;
}
int getg()
{
for(int i=2;;i++)
{
if(check(i))return i;
}
return -1;
}
vector<int>V[Maxn<<2],xx[Maxn<<2];
void rev(int *a,int n)
{
int i,j,k;
for(i=1,j=n>>1;i<n-1;i++)
{
if(i<j)swap(a[i],a[j]);
for(k=n>>1;j>=k;j-=k,k>>=1);j+=k;
}
}
void dft(int *a,int n,int flag=1)
{
rev(a,n);
for(int m=2;m<=n;m<<=1)
{
int wm=powmod(g,(p-1)/m);
if(flag<0)wm=getinv(wm);
for(int k=0;k<n;k+=m)
{
Int w=1;
for(int j=k;j<k+(m>>1);j++,w=w*wm%p)
{
Int u=a[j],v=a[j+(m>>1)]*(Int)w%p;
a[j]=(u+v)%p;
a[j+(m>>1)]=(u-v+p)%p;
}
}
}
}
void mul(int *a,int *b,int n)
{
dft(a,n);dft(b,n);
for(int i=0;i<n;i++)a[i]=a[i]*(Int)b[i]%p;
dft(a,n,-1);
int revn=getinv(n);
for(int i=0;i<n;i++)a[i]=a[i]*(Int)revn%p;
}
void GetInv( int *a, int *a0, int t )
{
if(t==1){a0[0]=getinv(a[0]);return;}
GetInv(a,a0,(t+1)>>1);
int N=1;
while(N<(t<<1)+3)N<<=1;
int inv_k=getinv(N);
for(int i=0;i<N;i++)tmp[i]=i<t?a[i]:0;
dft(a0,N);dft(tmp,N);
for(int i=0;i<N;i++)
tmp[i]=(2-tmp[i]*(Int)a0[i]%p+p)%p;
for(int i=0;i<N;i++)
a0[i]=a0[i]*(Int)tmp[i]%p;
dft(a0,N,-1);
for(int i=0;i<t;i++)a0[i]=a0[i]*(Int)inv_k%p;
for(int i=t;i<N;i++)a0[i]=0;
}
void module(int *a,int *b,int n,int m)//最高次数为n-1,m-1,Q(x)最高次数为n-m,R(x)最高次数为m-2
{
for(int i=0;i<n;i++)tmp1[i]=a[n-1-i];
for(int i=0;i<m;i++)tmp2[i]=b[m-1-i];
int N=1;
while(N<n+m)N<<=1;
for(int i=n-m+1;i<N;i++)tmp1[i]=tmp2[i]=0;
for(int i=0;i<N;i++)tmp3[i]=0;
GetInv(tmp2,tmp3,n-m+1);//zhuyi
for(int i=n-m+1;i<N;i++)tmp3[i]=0;
mul(tmp3,tmp1,N);
for(int i=n-m+1;i<N;i++)tmp3[i]=0;
reverse(tmp3,tmp3+n-m+1);
for(int i=m;i<N;i++)b[i]=0;
mul(tmp3,b,N);
for(int i=0;i<m-1;i++)a[i]=(a[i]-(Int)tmp3[i]+p)%p;
}
void build(int l,int r,int x)
{
if(l==r)
{
V[x].push_back((p-(l-1)*m)%p);V[x].push_back(1);
xx[x].push_back(l);xx[x].push_back(1);
return;
}
int mid=(l+r)>>1;
build(ls);build(rs);
int len=r-l+1;
if(len<=1500)
{
V[x].assign(len+1,0);
xx[x].assign(len+1,0);
for(int i=0;i<V[x<<1].size();i++)
for(int j=0;j<V[x<<1|1].size();j++)
{
V[x][i+j]=(V[x][i+j]+(Int)V[x<<1][i]*V[x<<1|1][j])%p;
xx[x][i+j]=(xx[x][i+j]+(Int)xx[x<<1][i]*xx[x<<1|1][j])%p;
}
return ;
}
int N=1;
while(N<=len+3)N<<=1;
for(int i=0;i<N;i++)
{
if(i<V[x<<1].size())
{
a[i]=V[x<<1][i];
tmp1[i]=xx[x<<1][i];
}
else a[i]=tmp1[i]=0;
if(i<V[x<<1|1].size())
{
b[i]=V[x<<1|1][i];
tmp3[i]=xx[x<<1|1][i];
}
else b[i]=tmp3[i]=0;
}
mul(a,b,N);
mul(tmp1,tmp3,N);
for(int i=0;i<=len;i++)
{
V[x].push_back(a[i]);
xx[x].push_back(tmp1[i]);
}
}
void solve3(int l,int r,int x)
{
if(r<=l+1500)
{
for(int i=l;i<=r;i++)
{
int tx=(i-1)*m;
int tp=0;
for(int j=V[x].size()-1;j>=0;j--)
{
tp=(tp*(Int)tx+V[x][j])%p;
}
ans=ans*(Int)tp%p;
}
//rep[l]=((l-1)*m*(Int)V[x][1]%p+V[x][0])%p;
return;
}
int mid=(l+r)>>1;
int len=r-l+1;
int N=1;while(N<=len)N<<=1;
for(int i=0;i<N;i++)
{
if(i<=len)a[i]=V[x][i];
else a[i]=0;
if(i<V[x<<1|1].size())b[i]=V[x<<1|1][i];
else b[i]=0;
}
module(a,b,len+1,V[x<<1|1].size());
for(int i=0;i<V[x<<1|1].size();i++)V[x<<1|1][i]=i==V[x<<1|1].size()-1?0:a[i];
for(int i=0;i<N;i++)
{
if(i<=len)a[i]=V[x][i];
else a[i]=0;
if(i<V[x<<1].size())b[i]=V[x<<1][i];
else b[i]=0;
}
module(a,b,len+1,V[x<<1].size());
for(int i=0;i<V[x<<1].size();i++)V[x<<1][i]=i==V[x<<1].size()-1?0:a[i];
solve3(ls);solve3(rs);
}
void solve2()
{
build(1,m,1);
for(int i=0;i<=m;i++)V[1][i]=xx[1][i];
solve3(1,m,1);
}
int main()
{
scanf("%d%d",&n,&p);
/* int tpans=1;
for(int i=1;i<=n;i++)
tpans=tpans*(Int)i%p;
if(n&1)tpans=tpans*(Int)(p+1)/2%p;
printf("realans=%d\n",tpans);*/
getp();
g=getg();
m=sqrt(n+0.5);
ans=1;
for(int i=m*m+1;i<=n;i++)
ans=ans*(Int)i%p;
solve2();
if(n&1)ans=ans*(Int)(p+1)/2%p;
printf("%d\n",ans);
}