题目链接:【BZOJ】3992 [SDOI2015]序列统计
题目大意：给定一个集合 $S$，其元素均为小于 $M$ 的非负整数（$M$ 为质数）。用 $S$ 中的数（可重复使用）组成长度为 $N$ 的数列，求所有数的乘积模 $M$ 等于 $x$ 的数列个数，答案对 $1004535809$ 取模。
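题目分析：利用原根的性质对下标做变换。设 $g$ 为 $M$ 的原根，则任意 $a \in [1, M-1]$ 都可唯一写成 $a = g^{k} \bmod M$（$0 \le k < M-1$），乘法 $g^{i} \cdot g^{j} = g^{(i+j) \bmod (M-1)}$ 就变成了指数上的加法。于是令 $f_k = [\,g^{k} \bmod M \in S\,]$，答案即 $f$ 在模 $z^{M-1}-1$ 意义下（长度为 $M-1$ 的循环卷积）的 $N$ 次幂中 $z^{\log_g x}$ 的系数，用 NTT 配合快速幂即可求出。（$0$ 没有离散对数，不参与上述变换。）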
#include <bits/stdc++.h>
using namespace std ;
typedef long long LL ;
// pii and ULL are not used by the visible code below.
typedef pair < int , int > pii ;
typedef unsigned long long ULL ;
// Byte-fill an entire fixed-size array (only ever used with 0 here).
#define clr( a , x ) memset ( a , x , sizeof a )
// Capacity of the polynomial arrays. It must be >= the NTT length pow() picks
// (the smallest power of two >= 2 * mod), which holds as long as mod <= 16384.
const int MAXN = 50000 ;
// NTT-friendly prime 479 * 2^21 + 1. FFT() uses 3 as its primitive root, and all
// sequence counts are reported modulo this value.
const int Mod = 1004535809 ;
// res: accumulated answer polynomial; x: base polynomial that pow() keeps squaring.
// In both, index i stands for the residue g^i (mod `mod`).
LL res[MAXN] , x[MAXN] ;
// mod: the prime M; p = mod - 1 (cyclic convolution length); g: primitive root of mod;
// S[0..top): distinct prime factors of mod - 1, filled by preprocess(). Ten slots
// suffice for any int mod (at most 9 distinct primes fit below 2^31).
int mod , p , g , S[10] , top ;
// n: exponent handed to pow() (the sequence length); m: number of set elements read;
// X: target product residue.
int n , m , X ;
// vis[v] == 1 iff v appears in the input set.
int vis[MAXN] ;
// gp[i] = g^i mod `mod` (table used to map exponents back to residues).
int gp[MAXN] ;
// Modular exponentiation by repeated squaring: returns x^n mod `mod`.
// x is reduced up front, so x * x never exceeds 64 bits even for a huge base.
// The accumulator starts at 1 % mod, so pm(x, 0, 1) yields 0 rather than 1.
// Assumes x >= 0, n >= 0 and 1 <= mod <= 2^31 - 1.
int pm ( long long x , int n , int mod ) {
	x %= mod ;
	long long res = 1 % mod ;
	while ( n ) {
		if ( n & 1 ) res = res * x % mod ;
		x = x * x % mod ;
		n >>= 1 ;
	}
	return res ;
}
// Finds the smallest primitive root g of the prime `mod`.
// It first stores the distinct prime factors of mod - 1 in S[0..top), in ascending
// order. g is then the smallest i >= 2 with i^((mod-1)/q) != 1 (mod `mod`) for every
// such factor q.
// Assumes `mod` is prime; otherwise the search below may never terminate.
void preprocess () {
	int rest = mod - 1 ; // named to avoid shadowing the global sequence length n
	top = 0 ;
	// Trial division only up to sqrt(rest); any remainder > 1 afterwards is prime.
	// `i <= rest / i` is the overflow-safe form of `i * i <= rest`.
	for ( int i = 2 ; i <= rest / i ; ++ i ) {
		if ( rest % i == 0 ) {
			S[top ++] = i ;
			while ( rest % i == 0 ) rest /= i ;
		}
	}
	if ( rest > 1 ) S[top ++] = rest ;
	for ( int i = 2 ; ; ++ i ) {
		bool ok = true ;
		for ( int j = 0 ; j < top && ok ; ++ j ) {
			if ( pm ( i , ( mod - 1 ) / S[j] , mod ) == 1 ) ok = false ;
		}
		if ( ok ) {
			g = i ;
			return ;
		}
	}
}
// In-place iterative number-theoretic transform of y[0..n) modulo Mod, with 3 as the
// primitive root.
// f != 0 selects the forward transform; f == 0 the inverse. The inverse is left
// unscaled: callers multiply by n^-1 themselves. n is expected to be a power of two.
void FFT ( LL y[] , int n , int f ) {
	// Bit-reversal permutation: rev is i with its low log2(n) bits mirrored.
	for ( int i = 1 ; i < n ; ++ i ) {
		int rev = 0 ;
		int rest = i ;
		for ( int bit = n >> 1 ; bit > 0 ; bit >>= 1 ) {
			rev = ( rev << 1 ) | ( rest & 1 ) ;
			rest >>= 1 ;
		}
		if ( i < rev ) swap ( y[i] , y[rev] ) ;
	}
	// Butterfly passes over blocks of size len = 2 * half.
	for ( int half = 1 ; ( half << 1 ) <= n ; half <<= 1 ) {
		const int len = half << 1 ;
		LL root = pm ( 3 , ( Mod - 1 ) / len , Mod ) ;
		if ( f == 0 ) root = pm ( root , Mod - 2 , Mod ) ; // inverse root for the inverse transform
		for ( int start = 0 ; start < n ; start += len ) {
			LL w = 1 ;
			for ( int lo = start , hi = start + half ; lo < start + half ; ++ lo , ++ hi ) {
				const LL t = w * y[hi] % Mod ;
				y[hi] = ( y[lo] - t + Mod ) % Mod ;
				y[lo] = ( y[lo] + t ) % Mod ;
				w = w * root % Mod ;
			}
		}
	}
}
void pow ( int k ) {
int n = 1 ;
while ( n < mod + mod ) n <<= 1 ;
int nv = pm ( n , Mod - 2 , Mod ) ;
while ( k ) {
FFT ( x , n , 1 ) ;
if ( k & 1 ) {
FFT ( res , n , 1 ) ;
for ( int i = 0 ; i < n ; ++ i ) {
res[i] = res[i] * x[i] % Mod ;
}
FFT ( res , n , 0 ) ;
for ( int i = 0 ; i < n ; ++ i ) {
res[i] = res[i] * nv % Mod ;
}
for ( int i = p ; i < n ; ++ i ) {
res[i % p] = ( res[i % p] + res[i] ) % Mod ;
res[i] = 0 ;
}
}
for ( int i = 0 ; i < n ; ++ i ) {
x[i] = x[i] * x[i] % Mod ;
}
FFT ( x , n , 0 ) ;
for ( int i = 0 ; i < n ; ++ i ) {
x[i] = x[i] * nv % Mod ;
}
for ( int i = p ; i < n ; ++ i ) {
x[i % p] = ( x[i % p] + x[i] ) % Mod ;
x[i] = 0 ;
}
k >>= 1 ;
}
}
void solve () {
preprocess () ;
p = mod - 1 ;
clr ( vis , 0 ) ;
clr ( res , 0 ) ;
clr ( x , 0 ) ;
res[0] = 1 ;
gp[0] = 1 ;
for ( int i = 1 ; i < mod ; ++ i ) {
gp[i] = gp[i - 1] * g % mod ;
}
for ( int i = 0 , v ; i < m ; ++ i ) {
scanf ( "%d" , &v ) ;
vis[v] = 1 ;
}
for ( int i = 0 ; i < p ; ++ i ) {
x[i] = vis[gp[i]] ;
}
pow ( n ) ;
for ( int i = 0 ; i < p ; ++ i ) {
if ( gp[i] == X ) {
printf ( "%lld\n" , res[i] ) ;
return ;
}
}
}
// Processes test cases until the input runs out.
// The loop requires all four header fields (== 4), not just "not EOF". On a
// malformed line scanf returns 0 every time, so `~scanf(...)` would spin forever
// calling solve() with stale values.
int main () {
	while ( scanf ( "%d%d%d%d" , &n , &mod , &X , &m ) == 4 ) solve () ;
	return 0 ;
}
压缩后代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
// Byte-fill an entire fixed-size array (only ever used with 0 here).
#define clr(a,x) memset(a,x,sizeof a)
// MAXN must cover the NTT length (smallest power of two >= 2*mod), i.e. mod <= 16384.
// Mod = 479*2^21+1 is NTT-friendly (FFT uses 3 as primitive root); answers are taken mod it.
const int MAXN=50000,Mod=1004535809;
// res / x: answer and base polynomials; index i stands for the residue g^i (mod `mod`).
LL res[MAXN],x[MAXN];
// mod: prime M; p = mod-1; g: primitive root; S: distinct prime factors of mod-1.
// top: not read by the functions below (preprocess() keeps its own local count).
// n: exponent passed to pow() (sequence length); m: number of set elements read;
// X: target residue; vis[v]: 1 iff v is in the set; gp[i] = g^i mod `mod`.
int mod,p,g,S[10],top,n,m,X,vis[MAXN],gp[MAXN];
// Returns res * x^n mod `mod` (res defaults to 1), by repeated squaring.
// Both operands are reduced first: x*x then stays below 2^63 for any 64-bit base,
// and the n == 0 result is canonical (pm(x,0,1) == 0).
// Assumes x, res >= 0 and 1 <= mod <= 2^31-1.
int pm(long long x,int n,int mod,long long res=1){
	for(x%=mod,res%=mod;n;x=x*x%mod,n>>=1)if(n&1)res=res*x%mod;
	return res;
}
void preprocess(){
int n=mod-1,top=0,ok;
for(int i=2;i<=n;++i)if(n%i==0){
S[top++]=i;
while(n%i==0)n/=i;
}
if(n>1)S[top++]=n;
for(g=2;ok=1;++g){
for(int j=0;j<top;++j)if(pm(g,(mod-1)/S[j],mod)==1)ok=0;
if(ok)break;
}
}
// In-place iterative NTT of y[0..n) modulo Mod, with 3 as the primitive root.
// f != 0 is forward, f == 0 inverse; the inverse is unscaled, and calc() applies n^-1.
// n is expected to be a power of two.
// Returns 0 only so the call can sit inside `k&&FFT(...)`. The previous version
// fell off the end of an int function, which is undefined behavior in C++.
int FFT(LL y[],int n,int f){
	for(int i=1,j,k,t;i<n;++i){
		for(j=0,k=n>>1,t=i;k;k>>=1,t>>=1)j=j<<1|(t&1);
		if(i<j)swap(y[i],y[j]);
	}
	for(int s=2,ds=1,k;s<=n;ds=s,s<<=1){
		LL wn=pm(3,(Mod-1)/s,Mod),w,t;
		if(!f)wn=pm(wn,Mod-2,Mod);
		// `w=1` in the condition resets the twiddle factor at the start of each block.
		for(k=0;w=1,k<n;k+=s)for(int i=k;i<k+ds;++i,w=w*wn%Mod){
			y[i+ds]=(y[i]-(t=w*y[i+ds]%Mod)+Mod)%Mod;
			y[i]=(y[i]+t)%Mod;
		}
	}
	return 0;
}
// Multiplies two NTT-domain arrays of length n pointwise into x. y may alias x, which
// is how squaring is done.
// It then transforms x back to coefficients and scales by nv (the inverse of n mod
// Mod, as computed by pow()). Finally every coefficient at index >= p is wrapped
// onto index mod p, giving a cyclic convolution of length p.
void calc(LL x[],LL y[],int nv,int n){
	for(int idx=0;idx<n;++idx){
		const LL prod=x[idx]*y[idx];
		x[idx]=prod%Mod;
	}
	FFT(x,n,0);
	for(int idx=0;idx<n;++idx){
		x[idx]=x[idx]*nv%Mod;
	}
	for(int src=p;src<n;++src){
		const int dst=src%p;
		x[dst]=(x[dst]+x[src])%Mod;
		x[src]=0;
	}
}
// Multiplies res by x^k under length-p cyclic convolution (solve() seeds res[0] = 1).
// The transform length len is the smallest power of two >= 2*mod.
// x is kept in the NTT domain between rounds. Each round, when the low bit of k is
// set, x is folded into res; then x is squared, and re-transformed only if bits of
// k remain.
void pow(int k){
	int len=1;
	while(len<mod+mod)len<<=1;
	const int nv=pm(len,Mod-2,Mod);
	FFT(x,len,1);
	while(k){
		if(k&1){
			FFT(res,len,1);
			calc(res,x,nv,len);
		}
		calc(x,x,nv,len);
		k>>=1;
		if(k)FFT(x,len,1);
	}
}
void solve(){
preprocess();
p=mod-1;
clr(vis,0),clr(res,0),clr(x,0);
res[0]=gp[0]=1;
for(int i=1;i<mod;++i)gp[i]=gp[i-1]*g%mod;
for(int i=0,v;i<m;vis[v]=1,++i)scanf("%d",&v);
for(int i=0;i<p;++i)x[i]=vis[gp[i]];
pow(n);
for(int i=0;i<p;++i)if(gp[i]==X){
printf("%lld\n",res[i]);
return;
}
}
// Runs test cases while a complete 4-field header can be read.
// `~scanf(...)` only stops at EOF. A malformed line makes scanf return 0 forever,
// and solve() would keep running on stale values.
int main(){
	while(scanf("%d%d%d%d",&n,&mod,&X,&m)==4)solve();
	return 0;
}