解析:
首先有个结论:对于任意i,j(i<j)我们发现i-a[i]+1..ai与j-a[j]+1..aj要么是包含关系,要么就是没有交集
既然如此,我们就可以把他所给定的信息转化为树状结构
我们发现问题转换为求fi表示对于1,2,3....i+1求出有多少个排列使得在排除1的情况下使得不存在一个子区间(不能有1)满足他是连续的(该定义原题中已经给出了)
显然上述的1可以替换成i+1,两者等价
难点:
那么我们考虑一下对于f[i]来说加入i+1会有什么情况
1.如果原来关于i的排列是满足上述条件的,那么我们发现只要我们不把i+1加在i的两边就符合条件,否则不符合条件
2.如果原来关于i的排列不满足上述条件,那么该序列中的极大连续真子区间不会超过1个,所以我们可以枚举这个区间的大小,不妨设它为j,首先该区间的数的选择情况为i-j-1,然后就是子问题,此贡献就是f[j],然后缩点之后还有贡献为f[i-j],总的贡献为(i-j-1)*f[j]*f[i-j],但是这里要注意,就是这个j是有范围的:2<=j<=i-2
综上所述
对于这个我们只需要分治NTT就行了
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll mod=998244353;
const int N=2e5+10;
ll a[N],b[N],f[N];
int lim[N],rev[N];
int T,n,len,k;
ll ans;
ll ksm(ll x,ll y)
{
ll ans=1;
for (;y;y>>=1,x=(x*x)%mod) if (y&1) ans=(ans*x)%mod;
return ans;
}
void NTT(ll *a,int len,int t)
{
for (int i=0;i<len;i++) if (i<rev[i]) swap(a[i],a[rev[i]]);
for (int i=1;i<len;i<<=1)
{
int s=(i<<1);
ll wn=ksm(3,(mod-1)/s);
if (t==-1) wn=ksm(wn,mod-2);
for (int j=0;j<len;j+=s) {
ll w=1;
for (int k=j;k<j+i;k++) {
ll x=a[k]; ll y=(a[k+i]*w)%mod;
a[k]=(x+y)%mod; a[k+i]=(x-y+mod)%mod;
w=(w*wn)%mod;
}
}
}
if (t==-1) {
ll w=ksm(len,mod-2);
for (int i=0;i<len;i++) a[i]=(a[i]*w)%mod;
}
}
void FFT()
{
for (int i=0;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
NTT(a,len,1); NTT(b,len,1);
for (int i=0;i<len;i++) a[i]=(a[i]*b[i])%mod;
NTT(a,len,-1);
}
void lalala(int l,int mid,int r)
{
for (int i=0;i<=mid-l;i++) a[i]=f[i+l]*(i+l-1);
for (int i=0;i<=mid-l;i++) b[i]=(ll)f[i+l]%mod;
if (l==1) a[0]=b[0]=0;
for (k=0,len=1;len<=(mid-l)*2;len<<=1) k++;
for (int i=mid-l+1;i<len;i++) a[i]=0;
for (int i=mid-l+1;i<len;i++) b[i]=0;
FFT();
for (int i=mid+1;i<=r;i++) if (i-l*2>=0) f[i]=(f[i]+a[i-l*2])%mod;
}
void lalala_l(int l,int mid,int r)
{
a[0]=a[1]=0;
for (int i=2;i<=min(l-1,r-l);i++) a[i]=f[i];
for (int i=0;i<=mid-l;i++) b[i]=f[i+l];
for (k=0,len=1;len<=mid-l+min(l-1,r-l);len<<=1) k++;
for (int i=min(l-1,r-l)+1;i<len;i++) a[i]=0;
for (int i=mid-l+1;i<len;i++) b[i]=0;
FFT();
for (int i=mid+1;i<=r;i++) f[i]=(f[i]+(ll)a[i-l]*(i-2)%mod)%mod;
}
void init(int l,int r)
{
// cout << l << ' ' << r << endl;
if (l==r) {
if (l==1) f[l]=2;
else
f[l]=((ll)f[l-1]*(ll)(l-1)%mod+f[l])%mod;
}
else {
int mid=(l+r)/2;
init(l,mid);
lalala(l,mid,r);
if (l>1) lalala_l(l,mid,r);
init(mid+1,r);
}
}
void solve(int l,int r)
{
if (lim[r]!=r-l+1) {ans=0; return;}
if (l==r) return;
int sum=0,s=r-1;
while (s>=l)
{
if (s-lim[s]+1<l) {ans=0; return;}
solve(s-lim[s]+1,s);
s-=lim[s]; sum++;
}
ans=(ans*f[sum])%mod;
}
int main()
{
scanf("%d%d",&T,&n);
f[0]=1;
init(1,n);
// for (int i=1;i<=n;i++) printf("%lld\n",f[i]);
while (T--) {
ans=1;
for (int i=1;i<=n;i++) scanf("%d",&lim[i]);
solve(1,n);
printf("%lld\n",ans);
}
}