前言
这里是线性做法。
在题解里几句话说清楚的性质愣是推了一上午。
too vegetable
解析
考虑怎样的排列是不合法的。
一个排列如果不合法,也就是在某次交换时其中一个元素距离目标的距离没有减少反而增大了,那么以后这个数一定会再换回来,也就是这个数会反复横跳。
考虑怎样的数会反复横跳,不难发现,会反复横跳,也就等价于左边有比自己大的元素,右边有比自己小的元素,也就是存在长度为三的递减子序列。
所以可以抽象出合法的充要条件:不存在长度为三的递减子序列,也就等价于排列可以拆分为两个递增序列。
这咋算啊?
打一下表,发现没有字典序的时候,答案就是卡特兰数。
为什么呢?
尝试往上嗯套。设
m
x
i
=
max
j
=
1
i
a
j
mx_i=\max_{j=1}^ia_j
mxi=maxj=1iaj,那么一个排列就可以理解为所有的
(
m
x
i
,
i
)
(mx_i,i)
(mxi,i) 的点顺次连接的路径,不难发现它和卡特兰数所谓“
(
0
,
0
)
−
>
(
n
,
n
)
(0,0)->(n,n)
(0,0)−>(n,n) 且不超过对角线上方” 的路径是双射的。
那么本题就好办了,暴力枚举第一个比给出排列大的位置,那么此时必然需要更新
m
x
i
mx_i
mxi,设
f
(
(
a
,
b
)
−
>
(
c
,
d
)
)
f((a,b)->(c,d))
f((a,b)−>(c,d)) 是从
(
a
,
b
)
(a,b)
(a,b) 走到
(
c
,
d
)
(c,d)
(c,d) 且不超过对角线的方案数(可以通过翻折容斥
O
(
1
)
O(1)
O(1) 求解),那么这里的方案数就是
∑
j
=
m
x
i
+
1
n
f
(
(
j
,
i
)
−
>
(
n
,
n
)
)
=
f
(
(
m
x
i
+
1
,
i
−
1
)
−
>
(
n
,
n
)
)
\sum_{j=mx_i+1}^nf((j,i)->(n,n))=f((mx_i+1,i-1)->(n,n))
∑j=mxi+1nf((j,i)−>(n,n))=f((mxi+1,i−1)−>(n,n))。一直到给出排列的前缀一定无法拆分为两个递增序列是退出。
总复杂度
O
(
n
)
O(n)
O(n)。
代码
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define ok debug("line: %d\n",__LINE__)
inline ll read(){
ll x(0),f(1);char c=getchar();
while(!isdigit(c)) {if(c=='-')f=-1;c=getchar();}
while(isdigit(c)) {x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
bool mem1;
const int N=2e6+100;
const int inf=1e9+100;
const int mod=998244353;
const bool Flag=0;
#define add(x,y) ((((x)+=(y))>=mod)&&((x)-=mod))
inline ll ksm(ll x,ll k){
ll res(1);
while(k){
if(k&1) res=res*x%mod;
x=x*x%mod;
k>>=1;
}
return res;
}
int n;
int a[N];
ll jc[N],ni[N];
void init(int n){
jc[0]=1;
for(int i=1;i<=n;i++) jc[i]=jc[i-1]*i%mod;
ni[n]=ksm(jc[n],mod-2);
for(int i=n-1;i>=0;i--) ni[i]=ni[i+1]*(i+1)%mod;
}
inline int C(int n,int m){
return n<m||m<0?0:jc[n]*ni[m]%mod*ni[n-m]%mod;
}
inline int walk(int i,int j){
return (C(n-i+n-j,n-i)+mod-C(n-1-i+n+1-j,n-1-i))%mod;
}
bool vis[N];
void work(){
n=read();
for(int i=1;i<=n;i++) a[i]=read(),vis[i]=0;
int mx(0),sec(1);
int ans(0);
for(int i=1;i<=n;i++){
vis[a[i]]=1;
if(a[i]>mx) mx=a[i];
else if(a[i]!=sec){
add(ans,walk(mx+1,i-1));
break;
}
while(vis[sec]) ++sec;
add(ans,walk(mx+1,i-1));
}
printf("%lld\n",ans);
}
bool mem2;
signed main(){
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
#endif
init(2e6);
int T=read();
while(T--) work();
return 0;
}