# 洛谷P4491：[HAOI2018]染色（容斥+ntt）

$H=\min\left(\left\lfloor\frac{n}{s}\right\rfloor,m\right)$，即恰好出现 $s$ 次的颜色数的上界（下文求和上限 $N$ 即为 $H$）。

$ans=\sum_{i=0}^{N}w[i]\binom{m}{i}\binom{n}{is}\frac{(is)!}{(s!)^{i}}\sum_{j=0}^{N-i}(-1)^{j}\binom{m-i}{j}\binom{n-is}{js}\frac{(js)!}{(s!)^{j}}(m-i-j)^{n-is-js}$

把 $j$ 换成 $j+i$（令新的 $j$ 表示原来的 $i+j$），并把它拉到最外层枚举，得到：

$ans=\sum_{j=0}^{N}\frac{m!\,n!}{(m-j)!\,(n-js)!}\left(\frac{1}{s!}\right)^{j}(m-j)^{n-js}\sum_{i=0}^{j}\frac{w[i]}{i!}\cdot\frac{(-1)^{j-i}}{(j-i)!}$

其中内层和式是 $\frac{w[i]}{i!}$ 与 $\frac{(-1)^{k}}{k!}$ 的卷积，可用 NTT（模数 $1004535809$，原根 $3$）一次求出所有 $j$ 的值。

#include <iostream>
#include <fstream>
#include <algorithm>
#include <cmath>
#include <ctime>
#include <cstdio>
#include <cstdlib>
#include <cstring>

using namespace std;

// Fill / copy helpers (not referenced elsewhere in this file).
#define mmst(a, b) memset(a, b, sizeof(a))
#define mmcp(a, b) memcpy(a, b, sizeof(b))

using LL = long long;

// Array capacities.
//   N  - bounds w[], aa[], bb[] and the bit-reversal table rev[].
//   BN - bounds the factorial / inverse tables; it must exceed every index
//        main() reads from them (the input n, the input m, the NTT length).
const int N = 1000500;
const int BN = 10000010;

// NTT-friendly prime: p - 1 = 479 * 2^21. ntt() uses 3 as the generator.
const LL p = 1004535809;

int n;          // current NTT length (power of two), set by init()
int rev[N];     // bit-reversal permutation for length n

LL nn;          // input n
LL njc = 1;     // not referenced elsewhere in this file
LL H;           // summation bound, min(nn / s, m)
LL m;           // input m
LL s;           // input s

LL jc[BN];      // jc[i]  = i! mod p
LL I[BN];       // I[i]   = i^{-1} mod p
LL Ijc[BN];     // Ijc[i] = (i!)^{-1} mod p

LL w[N];        // input weights w[0..m]
LL aa[N];       // convolution operand / result: w[i] / i!
LL bb[N];       // convolution operand: (-1)^i / i!

LL ans;         // accumulated answer mod p

// Modular exponentiation by square-and-multiply: returns a^b mod p.
//   a - base; reduced into [0, p) up front so every product below stays
//       under p^2 (no 64-bit overflow) and negative bases give a proper
//       residue instead of a negative value.
//   b - exponent; must be non-negative. For negative b the loop would not
//       terminate on typical arithmetic-shift implementations.
// cheng(x, 0) == 1 for every x, including x == 0.
LL cheng(LL a, LL b)
{
    a %= p;
    if (a < 0)
        a += p;

    LL res = 1;
    for (; b; b >>= 1, a = a * a % p)
        if (b & 1)
            res = res * a % p;
    return res;
}

// Prepares an NTT of the global length n.
// Sets n to the smallest power of two strictly greater than lim, and fills
// rev[0..n-1] with the bit-reversal permutation on log2(n) bits.
// The caller must ensure the resulting n does not exceed N (the size of rev).
void init(int lim)
{
    int bits = 0;   // log2(n)
    n = 1;
    while (n <= lim)
        ++bits, n <<= 1;

    rev[0] = 0;
    // i >= 1 implies n >= 2, so bits >= 1 and the shift count is valid.
    // The original code shifted by -1 (undefined) when lim <= 0.
    for (int i = 1; i < n; i++)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));
}

// In-place iterative number-theoretic transform of a[0..n-1] modulo p.
// n and rev[] must have been prepared by init(); entries should lie in [0, p).
//   a   - coefficient array with at least n entries.
//   ops - non-zero: forward transform, with level roots 3^((p-1)/len).
//         zero: inverse transform, with level roots 3^(p-1-(p-1)/len),
//         followed by scaling every entry by I[n] = n^{-1} mod p.
void ntt(LL *a, int ops)
{
    // Bit-reversed reordering so the butterflies below can run in place.
    for (int i = 0; i < n; ++i)
    {
        if (rev[i] > i)
            swap(a[rev[i]], a[i]);
    }

    for (int half = 1; half < n; half <<= 1)
    {
        const int len = half << 1;
        const LL unit = ops ? cheng(3, (p - 1) / len)
                            : cheng(3, p - 1 - (p - 1) / len);

        for (int base = 0; base < n; base += len)
        {
            LL *lo = a + base;   // first half of this butterfly group
            LL *hi = lo + half;  // second half of this butterfly group
            LL root = 1;
            for (int k = 0; k < half; ++k)
            {
                const LL t = hi[k] * root % p;
                hi[k] = (lo[k] - t + p) % p;
                lo[k] = (lo[k] + t) % p;
                root = root * unit % p;
            }
        }
    }

    if (ops == 0)
    {
        const LL invLen = I[n];
        for (LL *it = a; it != a + n; ++it)
            *it = *it * invLen % p;
    }
}

int main()
{
cin>>nn>>m>>s;

if(s==0)
H=m;
else
H=min(nn/s,m);

for(int i=0;i<=m;i++)
scanf("%lld",&w[i]);

I[1]=jc[0]=Ijc[0]=1;

for(int i=2;i<BN;i++)
I[i]=I[p%i]*(p-p/i)%p;

for(int i=1;i<BN;i++)
jc[i]=jc[i-1]*i%p,Ijc[i]=Ijc[i-1]*I[i]%p;

for(int i=0;i<=H;i++)
{
aa[i]=w[i]*Ijc[i]%p;
if(i&1)
bb[i]=(p-Ijc[i])%p;
else
bb[i]=Ijc[i];
}

init(H+H+5);
ntt(aa,1);
ntt(bb,1);
for(int i=0;i<n;i++)
aa[i]=aa[i]*bb[i]%p;
ntt(aa,0);

for(int j=0;j<=H;j++)
ans=(ans+Ijc[m-j]*Ijc[nn-j*s]%p*cheng(Ijc[s],j)%p*cheng(m-j,nn-j*s)%p*aa[j])%p;

ans=ans*jc[nn]%p*jc[m]%p;

cout<<ans<<endl;

return 0;
}


• 广告
• 抄袭
• 版权
• 政治
• 色情
• 无意义
• 其他

120