51nod1059 算10万阶乘

链接:51nod1059

给个数n求n阶乘准确值,n最大10万,会大整数知道100阶乘就很多了,普通大数乘法运算,1秒1万已经极限了,4秒求10万阶乘不可能,所以要考虑优化大数乘法,对于位数比较多的乘法用NTT或者FFT(事实上之后测试以及其他高手数学论证,FFT浮点数误差在大数乘法里可以忽略(代码里用+0.1转换),且FFT比NTT快一点,FFT浮点数运算,但是FTT总运算次数少抵消了浮点数运算的劣势),NTT和FFT求大数乘法的理论另看其他文章这里不讲了。

这还不够,如果从1开始乘到10万,即便NTT也过不了,因为NTT必须是2的整数次方,不够的会补零,使得每次乘法要NTT的数组变得很长,因此从1开始乘到10万不行,正确方法是尽量让位数相近的大数乘,把10万分成两半:1-5万,5万-10万之后再对他们再分,类似二分,比如10!=(1*2 )*(3*4 )*(5*6 )*(7*8 )*(9*10 )=(2 *12)*(30*56 )*90=40320*90=362880

这么做看似也是进行了n次乘法,但是ntt时候避免了:补零过多而使得大数的数组长度变得很长。时间复杂度大大优化

但是到这里还有很多细节注意,在于NTT的写法,时间常数不可以太大,对于经常访问的数,比如大素数P,应该加const,能快很多,否则是不过的(常量访问不需要寻址,而且对%运算有优化)。

我这里预处理了NTT需要得数组,也能加速,最后结果2秒多就过了

对于熟悉NTT的人知道,NTT的大素数P取值应当大于:len*HEX^2 len是要被NTT数组长度,HEX是数组里数最大值,体现在大数乘法里就是你压的位数,len是结果的长度,10万阶乘有450000多位,因此如果P=998244353,最多压2位因为100*100*450000=4.5*10^9,刚好小于P。想多压就得把P取得很大。

用python的可以试试py自带的计算阶乘函数,超快,我这个有很多不必要的地址赋值清楚和拷贝操作,为了封装大整数用的,事实上使用FFT,在我这个基础上优化不必要的地址赋值清楚和拷贝操作,输入输出优化也加上,就可以达到500ms左右的时间,c++就得全都自己手撕了,本来是带的以前写好的大数模板,但多余代码去掉了

此题用极限优化的FFT是最好解法,极限FFT能做到500ms过,至于怎么极限优化FFT,既要有数学上的简化,也有对于内部实现时,临时变量的取舍。51nod可以看别人代码,过了的可以抄袭一份时间排名靠前自己研究。当然,有的人把FFT的logn次运算拆开写了(为了节省循环的那点时间),这样代码量太多而且繁琐,时间仅仅优化一点点,我觉得得不偿失。

#include<bits/stdc++.h>
//#include<windows.h>
using namespace std;
#define ll long long
#define inf 1e-5
const int inv2=500000004;
const int INF=2147483647;
const int MAX=100010;
const int mod=1e9+7;
namespace NTT{//FFT准备较多,避免名称混淆,定义个命名空间,以后避免重名的麻烦
    const int P=998244353;
    int saveN=-1;
    ll powg[MAX*8];
    int seq[MAX*8];
    const int p2[30]={1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072,262144,524288,1048576,2097152};
    const int grn[30]={1,998244352,911660635,372528824,929031873,452798380,922799308,781712469,476477967,166035806,258648936,584193783,63912897,350007156,666702199,968855178,629671588,24514907,996173970,363395222,565042129,733596141};
    const int grninv[30]={1,998244352,86583718,509520358,337190230,87557064,609441965,135236158,304459705,685443576,381598368,335559352,129292727,358024708,814576206,708402881,283043518,3707709,121392023,704923114,950391366,428961804};
    void rader(int *x,int N){//必须是2的整数幂
        int i,j,k;
        for(i=0;i<N;i++){
            x[i]=i;
        }
        for(j=0,i=0;i<N-1;i++){
            if(i<j){
               swap(x[i],x[j]);
            }
            k=N/2;
            while(j>=k){
                j-=k;
                k/=2;
            }
            j+=k;
        }
    }
    ll quickmi(ll a,ll b,ll P){
        ll ans=1;
        for(a%=P;b;b>>=1,a=a*a%P)
            if(b&1)
                ans=ans*a%P;
        return ans;
    }
    ll changeN(ll N){//把N变成2的整数次方,并且生成序列
        ll i;
        for(i=0;p2[i]<N;i++);
        N=p2[i];
        if(saveN!=N){//避免重复计算倒序,如果和上次一样就用原来的
            saveN=N;
            rader(seq,N);
        }
        return N;
    }
    //x长度必须是2的整数倍,x数组一定要长度>=N,否则可能出现越界,kind是-1表示逆变换,up新数组长度,xlen是原本数组长度
    ll* ntt(const ll *x,int N,int kind,int up=MAX,int xlen=MAX){
        int i,j,n,m,bit,now;
        ll d,inv,b,G,Gn;
        ll *temp,*F;
        F=new ll[up];
        for(i=0;i<N;i++)
            F[i]=seq[i]<xlen?x[seq[i]]:0;
        for(n=2,bit=1;n<=N;n*=2,bit++){
            temp=new ll[up];
            m=n/2;
            powg[0]=1;
            G=kind==1?grn[bit]:grninv[bit];
            for(i=1;i<m;i++){//提前预处理加速
                powg[i]=powg[i-1]*G%P;
            }
            for(i=0;i<N;i+=n){
                for(j=0;j<m;j++){
                    b=F[i+j];
                    d=F[i+j+m]*powg[j]%P;
                    temp[i+j]=(b+d)%P;
                    temp[i+j+m]=(b-d+P)%P;//利用公式少做乘法加速
                }
            }
            delete []F;
            F=temp;
        }
        if(kind==-1){
            ll invN=quickmi(N,P-2,P);
            for(i=0;i<N;i++){
                F[i]=F[i]*invN%P;
            }
        }
        return F;
    }
};

#define BIT 2//压位
ll p10[]={1,10,100,1000,10000,100000,1000000,10000000,100000000,1000000000,10000000000,100000000000};
const ll HEX=p10[BIT];
struct BigInteger{
    int sign;
    ll *number;
    int size=0;
    BigInteger(int up=MAX){
        number=new ll[up];
    }
    BigInteger(char*s,int up=MAX){
        init(s,up);
    }
    BigInteger(ll s,int up=MAX){
        init(s,up);
    }
    ~BigInteger(){
        delete []number;
    }
    void init(char*s,int up=MAX){//s可以有符号,1代表正数,-1代表负数
        number=new ll[up];
        ll c=0;
        int i,j=0,len;
        len=strlen(s);
        sign=1;
        for(i=len-1;i>=0;i--){
            if(s[i]>47&&s[i]<58){
                if(j<BIT){
                    c+=p10[j++]*(s[i]-'0');
                }else{
                    number_push(c);
                    j=0;
                    c=p10[j++]*(s[i]-'0');
                }
            }else if(s[i]=='-'){
                sign=-1;
            }
        }
        number_push(c);
        if(size==1&&number[0]==0){
            sign=0;
        }
    }
    void init(ll s,int up=MAX){
        number=new ll[up];
        sign=0;
        if(s<0){
            sign=-1;
            s=-s;
        }else if(s>0){
            sign=1;
        }
        while(s){
            number_push(s%HEX);
            s/=HEX;
        }
    }
    void number_push(ll a){
        number[size++]=a;
    }

    BigInteger* nttmul(BigInteger *b){//NTT实现
        BigInteger *c=new BigInteger();
        c->sign=sign*b->sign;
        ll i,len,N,next=0,up;
        len=size+b->size-1;
        N=NTT::changeN(len);
        up=N*2;
        ll *F1=NTT::ntt(number,N,1,up,size);
        ll *F2=NTT::ntt(b->number,N,1,up,b->size);
        ll *F3=new ll[up];
        for(i=0;i<N;i++){
            F3[i]=F1[i]*F2[i]%NTT::P;
        }
        ll* x=NTT::ntt(F3,N,-1,up,N);
        for(i=0;i<len;i++){
            x[i]+=next;
            next=x[i]/HEX;
            x[i]=x[i]%HEX;
        }
        if(next){
            x[len++]=next;
        }
        for(;!x[len-1];len--);
        memset(x+len,0,sizeof(ll)*(up-len));//末尾清0;
        delete []c->number;
        c->number=x;
        c->size=len;
        delete []F1;
        delete []F2;
        delete []F3;
        return c;
    }
    BigInteger* mul(BigInteger *b){
        BigInteger *c=new BigInteger(size+b->size);
        c->sign=sign*b->sign;
        ll i,j,h,k,next=0;
        for(i=0;i<size;i++){
            for(j=0;j<b->size;j++){
                h=j+i;
                k=number[i]*b->number[j]+next;
                if(h<c->size){
                    k+=c->number[h];
                    c->number[h]=k%HEX;
                }else{
                    c->number_push(k%HEX);
                }
                next=k/HEX;
            }
            for(;next;c->number_push(next%HEX),next/=HEX);
        }
        return c;
    }

    void println(){//打印格式长度等于BIT
        if(sign==0){
            printf("0\n");
            return;
        }
        if(sign==-1)
            printf("-");
        printf("%llu",number[size-1]);
        for(int i=size-2;i>=0;i--){
            printf("%02llu",number[i]);
        }printf("\n");
    }
    char *toChar(ll a){
        char *c=new char[BIT];
        memset(c,'0',sizeof(char)*BIT);
        for(int i=BIT-1;i>=0;i--){
            c[i]=(char)(a%10+'0');
            a/=10;
        }
        return c;
    }
    void myprintln(){//打印格式长度等于BIT
        ll i,j,k;
        printf("%lld",number[size-1]);
        for(i=0,k=number[size-1];k>0;i++,k/=10);
        int len=i;
        for(int i=size-2;i>=0;i--){
            if(len+BIT<1000){
                printf("%02lld",number[i]);
                len+=BIT;
            }else{
                char *c=toChar(number[i]);
                for(j=0;j<BIT;j++){
                    printf("%c",c[j]);
                    if(j+len==999)
                        printf("\n");
                }
                len=len+BIT-1000;
                delete []c;
            }
        }
    }
};

BigInteger *temp;
BigInteger *arr[MAX];
int main(int argc,char *argv[]){
    //freopen("in.txt","r",stdin); //输入重定向,输入数据将从in.txt文件中读取
    //freopen("数据1059/out.txt","w",stdout); //输出重定向,输出数据将保存在out.txt文件中
    //srand(time(NULL));//有的OJ不能加这句话
    ll i,j,k,n;
    scanf("%lld",&n);
    //n=100000;
    for(i=1;i<=n;i++){
        arr[i]=new BigInteger(i,16);
    }
    for(j=1;j<n;j=k){
        k=j*2;
        for(i=1;i+j<=n;i+=k){
            temp=arr[i];
            if(arr[i]->size<32&&arr[i+j]->size<32){
                arr[i]=arr[i]->mul(arr[i+j]);
            }else{
                arr[i]=arr[i]->nttmul(arr[i+j]);
            }
            delete temp;
            delete arr[i+j];
        }
    }
    arr[1]->myprintln();
return 0;
}

 

 

 

 

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值