Solution
首先,连续段只会包含而不会相交,而且每个连续段向第一个包含它的连续段连边,就会形成一个树的结构。这个如果无法理解可以看LCA今年营员交流。
然后设当序列为
1
,
1
,
1...
n
1,1,1...n
1,1,1...n的时候的答案为
f
n
f_n
fn,每个点的儿子个数为
b
i
b_i
bi,那么
a
n
s
=
Π
i
=
1
n
f
b
i
+
1
ans=\Pi_{i=1}^n f_{b_i+1}
ans=Πi=1nfbi+1
这个的话就是你可以把每个儿子的连续段缩成一个点,这样方案数是不会变且一一对应的,这个思想在后面推式子的时候很重要。
那么怎么递推呢?考虑加入最大值的两种情况:
1、原来的排列已经没有非平凡的连续段,也就是长度大于
1
1
1小于整段长度的连续段,那么只要插入位置不在原来的最大值两边即可,方案数为
f
n
−
1
×
(
n
−
2
)
f_{n-1}\times(n-2)
fn−1×(n−2)。
2、插入的最大值把原来的一个连续段分成了两段,而且这个连续段被分开的两段都不能有连续段,设原来的连续段的长度为
i
i
i,那么方案数就是一个长为
i
+
1
i+1
i+1的连续段去掉最大值,然后把
n
n
n插到这个最大值的位置的方案数,为
f
i
+
1
f_{i+1}
fi+1。然后把这
i
+
1
i+1
i+1个点缩成一个,那么现在一共有
n
−
(
i
+
1
)
+
1
n-(i+1)+1
n−(i+1)+1个点,方案数为
f
n
−
i
f_{n-i}
fn−i。最后考虑
i
+
1
i+1
i+1个数的值域,因为原来的最大值即
n
−
1
n-1
n−1不能出现在其中,所以共有
n
−
i
−
2
n-i-2
n−i−2种取值。所以最后的递推式即为
f
n
=
(
n
−
2
)
f
n
−
1
+
∑
i
=
2
n
−
3
f
i
+
1
f
n
−
i
(
n
−
i
−
2
)
f_n=(n-2)f_{n-1}+\sum_{i=2}^{n-3}f_{i+1}f_{n-i}(n-i-2)
fn=(n−2)fn−1+i=2∑n−3fi+1fn−i(n−i−2)
这个可以用分治NTT优化,似乎看上去要比一般的分治FFT要难一点,其实想清楚后还是不难的。
Code
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define pa pair<int,int>
const int Maxn=50010;
const int inf=2147483647;
const int mod=998244353,gn=3;
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*f;
}
void upd(int&x,int y){x+=y;if(x>=mod)x-=mod;}
int Pow(int x,int y)
{
if(!y)return 1;
int t=Pow(x,y>>1),re=(LL)t*t%mod;
if(y&1)re=(LL)re*x%mod;
return re;
}
int rev[Maxn<<2],n,a[Maxn],b[Maxn],sta[Maxn],top,A[Maxn<<2],B[Maxn<<2],f[Maxn];
void ntt(int *a,int n,int o)
{
for(int i=0;i<n;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1)
{
int wn;
if(o==1)wn=Pow(gn,(mod-1)/(i<<1));
else wn=Pow(gn,mod-1-(mod-1)/(i<<1));
for(int j=0;j<n;j+=(i<<1))
{
int w=1;
for(int k=0;k<i;k++)
{
int t=(LL)a[i+j+k]*w%mod;w=(LL)w*wn%mod;
a[i+j+k]=(a[j+k]-t+mod)%mod;
a[j+k]=(a[j+k]+t)%mod;
}
}
}
if(o==-1)
{
int inv=Pow(n,mod-2);
for(int i=0;i<n;i++)a[i]=(LL)a[i]*inv%mod;
}
}
void solve(int l,int r)
{
if(l==r){upd(f[l],(LL)f[l-1]*(l-2)%mod);return;}
int mid=l+r>>1;
solve(l,mid);
int N=1;while(N<=((r-l+1)<<1))N<<=1;
rev[0]=0;for(int i=1;i<N;i++)rev[i]=((rev[i>>1]>>1)|((i&1)*(N>>1)));
for(int i=0;i<N;i++)A[i]=B[i]=0;
for(int i=l;i<=mid;i++)A[i-l+1]=f[i];
for(int i=3;i<=min(r+1-l,l-1);i++)B[i]=(LL)f[i]*(i-2)%mod;
ntt(A,N,1),ntt(B,N,1);
for(int i=0;i<N;i++)A[i]=(LL)A[i]*B[i]%mod;
ntt(A,N,-1);
for(int i=max(mid+1,5);i<=r;i++)upd(f[i],A[i+1-(l-1)]);
for(int i=0;i<N;i++)A[i]=B[i]=0;
for(int i=l;i<=min(mid,r-2);i++)A[i-l+1]=(LL)f[i]*(i-2)%mod;
for(int i=3;i<=min(r+1-l,r-2);i++)B[i-2]=f[i];
ntt(A,N,1),ntt(B,N,1);
for(int i=0;i<N;i++)A[i]=(LL)A[i]*B[i]%mod;
ntt(A,N,-1);
for(int i=max(mid+1,5);i<=r;i++)upd(f[i],A[i-l]);
solve(mid+1,r);
}
int main()
{
int T=read();n=read();
memset(f,0,sizeof(f));
f[1]=1,f[2]=2;solve(3,n);
while(T--)
{
for(int i=1;i<=n;i++)a[i]=read(),b[i]=0;
bool flag=false;
if(a[n]!=n){puts("0");continue;}
top=0;
for(int i=1;i<=n;i++)
{
while(top&&i-a[i]<=sta[top]-a[sta[top]])top--,b[i]++;
if(top&&sta[top]>=i-a[i]+1){puts("0");flag=true;break;}
sta[++top]=i;
}
if(flag)continue;
int ans=1;
for(int i=1;i<=n;i++)ans=(LL)ans*f[b[i]+1]%mod;
printf("%d\n",ans);
}
}