题目背景
WD整日沉浸在积木中,无法自拔……
题目描述
WD想买nnn块积木,商场中每块积木的高度都是1,俯视图为正方形(边长不一定相同)。由于一些特殊原因,商家会给每个积木随机一个大小并标号,发给WD。
接下来WD会把相同大小的积木放在一层,并把所有层从大到小堆起来。WD希望知道所有不同的堆法中层数的期望。两种堆法不同当且仅当某个积木在两种堆法中处于不同的层中,由于WD只关心积木的相对大小,因此所有堆法等概率出现,而不是随机的大小等概率(可以看样例理解)。输出结果mod 998244353即可。
(如果还是不能够理解题意,请看样例)
输入输出格式
输入格式:
第一行一个数T,表示询问个数。
接下来T行每行一个数n,表示WD希望使用n块积木。
输出格式:
共T行,每行一个数表示答案mod 998244353。
解析:
不妨设f(n)为n块积木的堆法
s(n)为n块积木所有堆法的层数和
上述式子只要用递推思想推一下就行了
然后我们发现这是个卷积,然后 我们是知道的,这时候我们想想能不能直接用分治NTT呢?
但是由于 是一个关于n,i的式子貌似不是固定的?
那怎么办呢?
因
所以我们把式子两边除以n!
然后
s(n)同理
那么设
那么我们发现这就可以直接两遍分治NTT就行了
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e5;
const ll mod=998244353;
ll f[N+10],a[N<<2],b[N<<2],g[N+10],s[N+10],inv[N+10],fac[N+10];
int rev[N<<2];
int T,n,len;
ll poww(ll x,ll y)
{
ll ans=1;
for (;y;y>>=1,x=(x*x)%mod) if (y&1) ans=(ans*x)%mod;
return ans;
}
void NTT(ll *a,int len,int t)
{
for (int i=0;i<len;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
for (int i=1;i<len;i<<=1)
{
int s=(i<<1);
ll wn=poww(3,(mod-1)/s);
if (t==-1) wn=poww(wn,mod-2);
for (int j=0;j<len;j+=s)
{
ll w=1;
for (int k=j;k<j+i;k++)
{
ll x=a[k]; ll y=(a[k+i]*w)%mod;
a[k]=(x+y)%mod; a[k+i]=(x-y+mod)%mod;
w=(w*wn)%mod;
}
}
}
if (t==-1) {
ll w=(poww(len,mod-2));
for (int i=0;i<len;i++) a[i]=(a[i]*w)%mod;
}
}
void lalala(int l,int mid,int r)//这太丑了
{
int sum=0;
for (len=1;len<=(r-l)*2;len<<=1) sum++;
for (int i=0;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(sum-1));
for (int i=0;i<len;i++) a[i]=0,b[i]=0;
for (int i=0;i<=mid-l;i++) a[i]=f[i+l];
for (int i=0;i<=r-l;i++) b[i]=g[i];
NTT(a,len,1); NTT(b,len,1);
for (int i=0;i<len;i++) a[i]=(a[i]*b[i])%mod;
NTT(a,len,-1);
for (int i=mid+1-l;i<=r-l;i++) f[i+l]=(f[i+l]+a[i])%mod;
}
void solve(int l,int r)
{
// cout << l << ' ' << r << endl;
if (l==r) return;
int mid=(l+r)/2;
solve(l,mid);
lalala(l,mid,r);
solve(mid+1,r);
}
void init()
{
fac[0]=1;
for (int i=1;i<=N;i++) fac[i]=(fac[i-1]*(ll)i)%mod;
inv[0]=1; inv[1]=1;
for (int i=2;i<=N;i++) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
for (int i=2;i<=N;i++) inv[i]=(inv[i-1]*inv[i])%mod;
for (int i=1;i<=N;i++) g[i]=inv[i]; g[0]=1;
f[0]=1;
// cout << inv[2] << endl;
solve(0,N);
for (int i=0;i<=N;i++) s[i]=f[i];
f[0]=0;
solve(0,N);
// cout << f[1]%mod << endl;
for (int i=1;i<=N;i++) f[i]=(f[i]*poww(s[i],mod-2))%mod;
}
int main()
{
scanf("%d",&T);
init();
while (T--) {
scanf("%d",&n);
printf("%lld\n",f[n]);
}
}