UOJ#86:mx的组合数 (Lucas定理+原根+NTT+高精度)

题目传送门:http://uoj.ac/problem/86


题目分析:高精度写死人系列,我写了一个晚上才写完QAQ。

一开始拿到这题没什么头绪,然后从部分分开始想。上数学课的时候忽然间发现40分的部分分就是个暴力枚举+Lucas定理。根据:

Cmn=CmpnpCmmodpnmodp

直接枚举m[L,R],然后递归到m<p,n<p时退出即可。

然后我们发现这个递归大概展开logp(R)层,而且这很像个数位DP。于是我们不妨对原问题差分,用函数Solve(k,n)求出当m[0,k]时,Cmna(modp)(0<=a<p)的答案。很明显可以先调用Solve(kp1,np),将其答案与Cxnmodp(0<=x<p)构成的数组进行合并,然后单独处理C(kp,np)Cxnmodp(0<=x<=kmodp)的部分。我们发现前者中两个数组的合并是下标乘积的形式,而p又是个质数,所以可以转化为原根的幂然后做NTT。由于不是很好处理模p等于0的情况,可以先算出模p不为0的情况,最后再用R-L+1减去。最后的复杂度是plog(p)logp(R)

总之这是个十分套路的题目,然而要写高精度所以比较烦。而且我一开始还写错了几个地方:一是[0,R]中有R+1个数,我以为是R个数QAQ;二是递归的退出条件不一定是m<p,n<p,还有可能在别的一些地方……;三是我把对ans[0]的减法放在了if (l.num[0])里面,这样l=0就炸了。

随手写一发,我的code居然排到了rk2?!


CODE:

#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;

const int maxn=1000000;
const int maxl=50;
const long long M=998244353;
const long long g=3;
typedef long long LL;

LL A[maxn];
LL B[maxn];

int Rev[maxn];
int N,Lg;

struct Big_Int
{
    LL num[maxl];
    void Down() { while ( !num[ num[0] ] && num[0] ) num[0]--; }
} n,n1,l,r;
char s[maxl];

#define P pair<Big_Int,long long>
#define MP(x,y) make_pair(x,y)

LL fac[maxn];
LL nfac[maxn];

int id[maxn];
int rid[maxn];

LL ans[maxn];
LL p,pg;

LL Pow(LL x,LL y,LL z)
{
    if (!y) return 1LL;
    LL temp=Pow(x,y>>1,z);
    temp=temp*temp%z;
    if (y&1) temp=temp*x%z;
    return temp;
}

LL Get(LL x)
{
    x--;
    LL mz=(long long)floor( sqrt( (double)x )+0.5 );
    for (LL y=2; y<=x; y++)
    {
        bool sol=true;
        for (LL z=2; z<=mz; z++)
            if (x%z==0)
            {
                if ( Pow(y,z,p)==1LL ) sol=false;
                if ( Pow(y,x/z,p)==1LL ) sol=false;
                if (!sol) break;
            }
        if (sol) return y;
    }
}

void Read(Big_Int &x)
{
    scanf("%s",s);
    int len=strlen(s);
    x.num[0]=len;
    for (int i=0; i<len; i++) x.num[len-i]=s[i]-'0';
    x.Down();
}

P Div(Big_Int x,LL y)
{
    for (int i=x.num[0]; i>=2; i--)
    {
        x.num[i-1]+=(x.num[i]%y*10LL);
        x.num[i]/=y;
    }
    LL z=x.num[1]%y;
    x.num[1]/=y;
    x.Down();
    return ( MP(x,z) );
}

LL Change(Big_Int x)
{
    LL y=0;
    for (int i=x.num[0]; i>=1; i--)
    {
        y=y*10+x.num[i];
        if (y>=p) return -1LL;
    }
    return y;
}

LL CI(int x,int y)
{
    if (y>x) return 0;
    LL v=fac[x];
    v=v*nfac[y]%p;
    v=v*nfac[x-y]%p;
    return v;
}

void Dec(Big_Int &x)
{
    x.num[1]--;
    int y=1;
    while (x.num[y]<0) x.num[y]+=10,x.num[++y]--;
    x.Down();
}

void DFT(LL *a,int f)
{
    for (int i=0; i<N; i++)
        if (i<Rev[i]) swap(a[i],a[ Rev[i] ]);

    for (int len=2; len<=N; len<<=1)
    {
        int mid=(len>>1);
        LL e=Pow(g,(M-1)/len,M);
        if (f==-1) e=Pow(e,M-2,M);

        for (LL *p=a; p!=a+N; p+=len)
        {
            LL wn=1;
            for (int i=0; i<mid; i++)
            {
                LL temp=wn*p[mid+i]%M;
                p[mid+i]=(p[i]-temp+M)%M;
                p[i]=(p[i]+temp)%M;
                wn=wn*e%M;
            }
        }
    }
}

void NTT()
{
    DFT(A,1);
    DFT(B,1);
    for (int i=0; i<N; i++) A[i]=A[i]*B[i]%M;
    DFT(A,-1);

    LL inv=Pow(N,M-2,M);
    for (int i=0; i<N; i++) A[i]=A[i]*inv%M;
}

LL C(Big_Int k,Big_Int m)
{
    LL ck=Change(k);
    LL cm=Change(m);
    if ( ck>=0 && cm>=0 ) return CI(ck,cm);

    P x=Div(k,p);
    P y=Div(m,p);
    LL v=C(x.first,y.first);
    v=v*CI(x.second,y.second)%p;
    return v;
}

void Solve(Big_Int k,Big_Int m)
{
    LL ck=Change(k);
    LL cm=Change(m);
    if ( ck>=0 && cm>=0 )
    {
        for (int i=0; i<N; i++) A[i]=0;
        //this is not the only exit way,don't clear A[] here!!!
        for (int i=cm; i<=ck; i++) A[ id[ CI(i,cm) ] ]++;
        return;
    }

    P x=Div(k,p);
    P y=Div(m,p);
    if (x.first.num[0])
    {
        Big_Int z=x.first;
        Dec(z);
        Solve(z,y.first);
        for (int i=0; i<N; i++) B[i]=0;
        for (int i=y.second; i<p; i++) B[ id[ CI(i,y.second) ] ]++;
        NTT();
        for (int i=p-1; i<N; i++) A[i%(p-1)]=(A[i%(p-1)]+A[i])%M,A[i]=0;
    }

    LL v=C(x.first,y.first);
    if (v) for (int i=y.second; i<=x.second; i++)
    {
        LL &q=A[ id[ CI(i,y.second)*v%p ] ];
        q=(q+1LL)%M;
    }
}

int main()
{
    //freopen("86.in","r",stdin);
    //freopen("86.out","w",stdout);

    scanf("%lld",&p);

    pg=Get(p);
    LL v=1;
    for (int i=0; i<p-1; i++)
    {
        id[v]=i;
        rid[i]=v;
        v=v*pg%p;
    }

    fac[0]=1;
    for (LL i=1; i<p; i++) fac[i]=fac[i-1]*i%p;
    for (int i=0; i<p; i++) nfac[i]=Pow(fac[i],p-2LL,p);

    N=1,Lg=0;
    while (N<=2*p+2) N<<=1,Lg++;
    for (int i=0; i<N; i++)
        for (int j=0; j<Lg; j++)
            if (i&(1<<j)) Rev[i]|=(1<<(Lg-j-1));

    Read(n);
    n1=n;
    Read(l);
    Read(r);

    P x=Div(r,M);
    ans[0]=x.second;
    ans[0]=(ans[0]+1LL)%M; //[0,r] have r+1 numbers!!!
    Solve(r,n);
    for (int i=0; i<p-1; i++) ans[ rid[i] ]=A[i];

    if (l.num[0])
    {
        Dec(l);
        x=Div(l,M);
        ans[0]=(ans[0]-x.second+M)%M;
        ans[0]=(ans[0]-1LL+M)%M;
        for (int i=0; i<N; i++) A[i]=0; //clear A[] here!!!
        Solve(l,n1);

        for (int i=0; i<p-1; i++)
        {
            int v=rid[i];
            ans[v]=(ans[v]-A[i]+M)%M;
        }
    }

    for (int i=1; i<p; i++) ans[0]=(ans[0]-ans[i]+M)%M; //write outside the "if"!!!
    for (int i=0; i<p; i++) printf("%lld\n",ans[i]);

    return 0;
}
发布了160 篇原创文章 · 获赞 76 · 访问量 10万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 技术黑板 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览