链接:51nod1059
给个数n求n阶乘准确值,n最大10万,会大整数知道100阶乘就很多了,普通大数乘法运算,1秒1万已经极限了,4秒求10万阶乘不可能,所以要考虑优化大数乘法,对于位数比较多的乘法用NTT或者FFT(事实上之后测试以及其他高手数学论证,FFT浮点数误差在大数乘法里可以忽略(代码里用+0.1转换),且FFT比NTT快一点,FFT浮点数运算,但是FTT总运算次数少抵消了浮点数运算的劣势),NTT和FFT求大数乘法的理论另看其他文章这里不讲了。
这还不够,如果从1开始乘到10万,即便NTT也过不了,因为NTT必须是2的整数次方,不够的会补零,使得每次乘法要NTT的数组变得很长,因此从1开始乘到10万不行,正确方法是尽量让位数相近的大数乘,把10万分成两半:1-5万,5万-10万之后再对他们再分,类似二分,比如10!=(1*2 )*(3*4 )*(5*6 )*(7*8 )*(9*10 )=(2 *12)*(30*56 )*90=40320*90=362880
这么做看似也是进行了n次乘法,但是ntt时候避免了:补零过多而使得大数的数组长度变得很长。时间复杂度大大优化
但是到这里还有很多细节注意,在于NTT的写法,时间常数不可以太大,对于经常访问的数,比如大素数P,应该加const,能快很多,否则是不过的(常量访问不需要寻址,而且对%运算有优化)。
我这里预处理了NTT需要得数组,也能加速,最后结果2秒多就过了
对于熟悉NTT的人知道,NTT的大素数P取值应当大于:len*HEX^2 len是要被NTT数组长度,HEX是数组里数最大值,体现在大数乘法里就是你压的位数,len是结果的长度,10万阶乘有450000多位,因此如果P=998244353,最多压2位因为100*100*450000=4.5*10^9,刚好小于P。想多压就得把P取得很大。
用python的可以试试py自带的计算阶乘函数,超快,我这个有很多不必要的地址赋值清楚和拷贝操作,为了封装大整数用的,事实上使用FFT,在我这个基础上优化不必要的地址赋值清楚和拷贝操作,输入输出优化也加上,就可以达到500ms左右的时间,c++就得全都自己手撕了,本来是带的以前写好的大数模板,但多余代码去掉了
此题用极限优化的FFT是最好解法,极限FFT能做到500ms过,至于怎么极限优化FFT,既要有数学上的简化,也有对于内部实现时,临时变量的取舍。51nod可以看别人代码,过了的可以抄袭一份时间排名靠前自己研究。当然,有的人把FFT的logn次运算拆开写了(为了节省循环的那点时间),这样代码量太多而且繁琐,时间仅仅优化一点点,我觉得得不偿失。
#include<bits/stdc++.h>
//#include<windows.h>
using namespace std;
#define ll long long
#define inf 1e-5
const int inv2=500000004;
const int INF=2147483647;
const int MAX=100010;
const int mod=1e9+7;
namespace NTT{//FFT准备较多,避免名称混淆,定义个命名空间,以后避免重名的麻烦
const int P=998244353;
int saveN=-1;
ll powg[MAX*8];
int seq[MAX*8];
const int p2[30]={1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072,262144,524288,1048576,2097152};
const int grn[30]={1,998244352,911660635,372528824,929031873,452798380,922799308,781712469,476477967,166035806,258648936,584193783,63912897,350007156,666702199,968855178,629671588,24514907,996173970,363395222,565042129,733596141};
const int grninv[30]={1,998244352,86583718,509520358,337190230,87557064,609441965,135236158,304459705,685443576,381598368,335559352,129292727,358024708,814576206,708402881,283043518,3707709,121392023,704923114,950391366,428961804};
void rader(int *x,int N){//必须是2的整数幂
int i,j,k;
for(i=0;i<N;i++){
x[i]=i;
}
for(j=0,i=0;i<N-1;i++){
if(i<j){
swap(x[i],x[j]);
}
k=N/2;
while(j>=k){
j-=k;
k/=2;
}
j+=k;
}
}
ll quickmi(ll a,ll b,ll P){
ll ans=1;
for(a%=P;b;b>>=1,a=a*a%P)
if(b&1)
ans=ans*a%P;
return ans;
}
ll changeN(ll N){//把N变成2的整数次方,并且生成序列
ll i;
for(i=0;p2[i]<N;i++);
N=p2[i];
if(saveN!=N){//避免重复计算倒序,如果和上次一样就用原来的
saveN=N;
rader(seq,N);
}
return N;
}
//x长度必须是2的整数倍,x数组一定要长度>=N,否则可能出现越界,kind是-1表示逆变换,up新数组长度,xlen是原本数组长度
ll* ntt(const ll *x,int N,int kind,int up=MAX,int xlen=MAX){
int i,j,n,m,bit,now;
ll d,inv,b,G,Gn;
ll *temp,*F;
F=new ll[up];
for(i=0;i<N;i++)
F[i]=seq[i]<xlen?x[seq[i]]:0;
for(n=2,bit=1;n<=N;n*=2,bit++){
temp=new ll[up];
m=n/2;
powg[0]=1;
G=kind==1?grn[bit]:grninv[bit];
for(i=1;i<m;i++){//提前预处理加速
powg[i]=powg[i-1]*G%P;
}
for(i=0;i<N;i+=n){
for(j=0;j<m;j++){
b=F[i+j];
d=F[i+j+m]*powg[j]%P;
temp[i+j]=(b+d)%P;
temp[i+j+m]=(b-d+P)%P;//利用公式少做乘法加速
}
}
delete []F;
F=temp;
}
if(kind==-1){
ll invN=quickmi(N,P-2,P);
for(i=0;i<N;i++){
F[i]=F[i]*invN%P;
}
}
return F;
}
};
#define BIT 2//压位
ll p10[]={1,10,100,1000,10000,100000,1000000,10000000,100000000,1000000000,10000000000,100000000000};
const ll HEX=p10[BIT];
struct BigInteger{
int sign;
ll *number;
int size=0;
BigInteger(int up=MAX){
number=new ll[up];
}
BigInteger(char*s,int up=MAX){
init(s,up);
}
BigInteger(ll s,int up=MAX){
init(s,up);
}
~BigInteger(){
delete []number;
}
void init(char*s,int up=MAX){//s可以有符号,1代表正数,-1代表负数
number=new ll[up];
ll c=0;
int i,j=0,len;
len=strlen(s);
sign=1;
for(i=len-1;i>=0;i--){
if(s[i]>47&&s[i]<58){
if(j<BIT){
c+=p10[j++]*(s[i]-'0');
}else{
number_push(c);
j=0;
c=p10[j++]*(s[i]-'0');
}
}else if(s[i]=='-'){
sign=-1;
}
}
number_push(c);
if(size==1&&number[0]==0){
sign=0;
}
}
void init(ll s,int up=MAX){
number=new ll[up];
sign=0;
if(s<0){
sign=-1;
s=-s;
}else if(s>0){
sign=1;
}
while(s){
number_push(s%HEX);
s/=HEX;
}
}
void number_push(ll a){
number[size++]=a;
}
BigInteger* nttmul(BigInteger *b){//NTT实现
BigInteger *c=new BigInteger();
c->sign=sign*b->sign;
ll i,len,N,next=0,up;
len=size+b->size-1;
N=NTT::changeN(len);
up=N*2;
ll *F1=NTT::ntt(number,N,1,up,size);
ll *F2=NTT::ntt(b->number,N,1,up,b->size);
ll *F3=new ll[up];
for(i=0;i<N;i++){
F3[i]=F1[i]*F2[i]%NTT::P;
}
ll* x=NTT::ntt(F3,N,-1,up,N);
for(i=0;i<len;i++){
x[i]+=next;
next=x[i]/HEX;
x[i]=x[i]%HEX;
}
if(next){
x[len++]=next;
}
for(;!x[len-1];len--);
memset(x+len,0,sizeof(ll)*(up-len));//末尾清0;
delete []c->number;
c->number=x;
c->size=len;
delete []F1;
delete []F2;
delete []F3;
return c;
}
BigInteger* mul(BigInteger *b){
BigInteger *c=new BigInteger(size+b->size);
c->sign=sign*b->sign;
ll i,j,h,k,next=0;
for(i=0;i<size;i++){
for(j=0;j<b->size;j++){
h=j+i;
k=number[i]*b->number[j]+next;
if(h<c->size){
k+=c->number[h];
c->number[h]=k%HEX;
}else{
c->number_push(k%HEX);
}
next=k/HEX;
}
for(;next;c->number_push(next%HEX),next/=HEX);
}
return c;
}
void println(){//打印格式长度等于BIT
if(sign==0){
printf("0\n");
return;
}
if(sign==-1)
printf("-");
printf("%llu",number[size-1]);
for(int i=size-2;i>=0;i--){
printf("%02llu",number[i]);
}printf("\n");
}
char *toChar(ll a){
char *c=new char[BIT];
memset(c,'0',sizeof(char)*BIT);
for(int i=BIT-1;i>=0;i--){
c[i]=(char)(a%10+'0');
a/=10;
}
return c;
}
void myprintln(){//打印格式长度等于BIT
ll i,j,k;
printf("%lld",number[size-1]);
for(i=0,k=number[size-1];k>0;i++,k/=10);
int len=i;
for(int i=size-2;i>=0;i--){
if(len+BIT<1000){
printf("%02lld",number[i]);
len+=BIT;
}else{
char *c=toChar(number[i]);
for(j=0;j<BIT;j++){
printf("%c",c[j]);
if(j+len==999)
printf("\n");
}
len=len+BIT-1000;
delete []c;
}
}
}
};
BigInteger *temp;
BigInteger *arr[MAX];
int main(int argc,char *argv[]){
//freopen("in.txt","r",stdin); //输入重定向,输入数据将从in.txt文件中读取
//freopen("数据1059/out.txt","w",stdout); //输出重定向,输出数据将保存在out.txt文件中
//srand(time(NULL));//有的OJ不能加这句话
ll i,j,k,n;
scanf("%lld",&n);
//n=100000;
for(i=1;i<=n;i++){
arr[i]=new BigInteger(i,16);
}
for(j=1;j<n;j=k){
k=j*2;
for(i=1;i+j<=n;i+=k){
temp=arr[i];
if(arr[i]->size<32&&arr[i+j]->size<32){
arr[i]=arr[i]->mul(arr[i+j]);
}else{
arr[i]=arr[i]->nttmul(arr[i+j]);
}
delete temp;
delete arr[i+j];
}
}
arr[1]->myprintln();
return 0;
}