###【题目背景】
最近,小 S 对冒泡排序产生了浓厚的兴趣。为了问题简单,小 S 只研究对 1到 n的排列的冒泡排序。
下面是对冒泡排序的算法描述。
输入:一个长度为 n 的排列 p[1…n]
输出:p 排序后的结果。
for i = 1 to n do
for j = 1 to n - 1 do
if(p[i] > p[i + 1])
交换 p[i] 与 p[i + 1] 的值
冒泡排序的交换次数被定义为交换过程的执行次数。可以证明交换次数的一个下
界是
1
2
∑
i
=
1
n
∣
p
i
−
i
∣
\frac{1}{2}\sum_{i=1}^n|{p_i-i}|
21∑i=1n∣pi−i∣,其中
p
i
p_i
pi是排列 p 中第 i 个位置的数字。如果你对证明感兴趣,可
以看提示。
###【题目描述】
小 S 开始专注于研究长度为 n 的排列中,满足交换次数 =
1
2
∑
i
=
1
n
∣
p
i
−
i
∣
\frac{1}{2}\sum_{i=1}^n|{p_i-i}|
21∑i=1n∣pi−i∣ 的排列
(在后文中,为了方便,我们把所有这样的排列叫“好”的排列)。他进一步想,这样的
排列到底多不多?它们分布的密不密集?
小 S 想要对于一个给定的长度为 n 的排列 q,计算字典序严格大于 q 的“好”的
排列个数。但是他不会做,于是求助于你,希望你帮他解决这个问题,考虑到答案可能
会很大,因此只需输出答案对 998244353 取模的结果。
###【输入格式】
从文件 inverse.in 中读入数据。
输入第一行包含一个正整数 T,表示数据组数。
对于每组数据,第一行有一个正整数 n, 保证 n ≤ 6 × 10 5 。
接下来一行会输入 n 个正整数,对应于题目描述中的 q i ,保证输入的是一个 1 到
n 的排列。
###【输出格式】
输出到文件 inverse.out 中。
输出共 T 行,每行一个整数。
对于每组数据,输出一个整数,表示字典序严格大于 q 的“好”的排列个数对
998244353 取模的结果。
###【样例 1 输入】
1
3
1 3 2
###【样例 1 输出】
3
###【样例 1 解释】
字典序比 1 3 2 大的排列中,除了 3 2 1 以外都是“好”的排列,故答案为 3。
###数据范围
T
≤
5
,
N
≤
600
,
000
T\le5,N\le600,000
T≤5,N≤600,000
##解题过程
一看到这道题就想着猜结论啊……于是我打了个表,发现如果输入恰好是1~n时,答案就是卡特兰数-1.
那其它的怎么办呢?我就把所有的符合要求的情况打印出来,发现——序列中不存在长度大于2的下降子序列!!
这是真的吗?反正我不断随机了好多数据,都满足这个规律,然后就开始想怎么做吧……
常规dp思路肯定是考虑从前往后把数字一个一个加进去,但是我感觉这样似乎不太方便处理字典序(如果有dalao用的这个思路请指教),但是这里面有一些思想还是值得借鉴的。我们考虑加入到了第i个数字,前面i-1个数字中的最大值为a,当前剩下的数字中最小值为b,我们会发现,要么当前插入b,要么当前插入的数字要大于a,否则一定不满足题意。我们称当前大于a的数字和b组成的集合为“可行数字集合”,记为
S
i
S_i
Si。
根据常规的排列字典序做法,我们可以考虑这样做:f[i][j]表示填到第
i
i
i个数字,并且希望在当前填入当前可行数字集合中第j大的数字的方法数。显然的,如果
∣
S
i
∣
<
n
−
i
|S_i| < n-i
∣Si∣<n−i,必然有
∣
S
i
+
1
∣
=
j
|S_{i+1}|=j
∣Si+1∣=j,因为如果
j
<
∣
S
i
∣
j<|S_i|
j<∣Si∣,那么比他小的有j-1个,再加上可以填入最小值,总共是j个;否则,大于a的j-1个数字都可用,且大于b小于a的数字必然还有剩余,这中间也可以再取出一个可行性数字。因此这种情况下有
f
[
i
]
[
j
]
=
∑
k
=
1
j
f
[
i
−
1
]
[
k
]
=
f
[
i
]
[
j
−
1
]
+
f
[
i
−
1
]
[
j
]
,
其
中
f
[
1
]
[
1
]
=
1
f[i][j]=\sum_{k=1}^jf[i-1][k]=f[i][j-1]+f[i-1][j],其中f[1][1]=1
f[i][j]=k=1∑jf[i−1][k]=f[i][j−1]+f[i−1][j],其中f[1][1]=1
再考虑
∣
S
i
∣
=
n
−
i
|S_i|=n-i
∣Si∣=n−i的情况。这种情况下,若
j
<
∣
S
i
∣
j<|S_i|
j<∣Si∣,那么接下来仍然有j个数可以填,但如果
j
=
∣
S
i
∣
j=|S_i|
j=∣Si∣,接下来就只剩j-1个数字可用了,需要特殊处理。综上,我们会发现f的总递推式:
f
[
i
]
[
j
]
=
{
f
[
i
−
1
]
[
j
]
+
f
[
i
]
[
j
−
1
]
j
<
i
f
[
i
]
[
j
−
1
]
j
=
i
f[i][j]=\left\{ \begin{matrix} f[i-1][j]+f[i][j-1] & j<i\\ f[i][j-1] & j=i \end{matrix} \right.
f[i][j]={f[i−1][j]+f[i][j−1]f[i][j−1]j<ij=i
接下来考虑怎么统计结果。对于它输入的排列,我们先考虑第一位大于输入的情况。比如假设输入数列为p,那么大于第一位数字的答案数量应该就是:
∑
i
=
1
n
−
p
i
f
[
n
]
[
i
]
\sum_{i=1}^{n-p_i}f[n][i]
∑i=1n−pif[n][i],再考虑第一位等于
p
i
p_i
pi的情况。以此类推即可求出最终解。但是有一个特殊的地方,如果遍历到当前序列某个前缀已经不满足要求,必须直接break,以防多算。
举个栗子吧,比如输入的序列是2 4 3 1 5,我们先预处理出f数组:
f
=
1
1
1
1
2
2
1
3
5
5
1
4
9
14
14
f=\begin{matrix} 1\\ 1 & 1\\ 1 & 2 & 2\\ 1 & 3 & 5 & 5\\ 1 & 4 & 9 & 14 & 14 \end{matrix}
f=11111123425951414
对于第一位,我们可以选3,4,5,那
f
[
5
]
[
1
]
+
f
[
5
]
[
2
]
+
f
[
5
]
[
3
]
=
14
f[5][1]+f[5][2]+f[5][3]=14
f[5][1]+f[5][2]+f[5][3]=14;接下来第一位确定为2,再看第二位,可以选择5,答案可以再加上
f
[
4
]
[
1
]
f[4][1]
f[4][1];再看第三位,可以填入5(4已经被使用过),答案加上
f
[
3
]
[
1
]
f[3][1]
f[3][1];但是此时发现前三位是2,4,3了,后面两个位置中必然有一个是1,与题意不符,于是break。答案为16.
这样做的话就可以得到一个
O
(
n
2
)
O(n^2)
O(n2)的dp做法,80分到手了,再加上卡特兰数的规律有84分的好成绩。于是我放弃了继续想这道题,去看了T3,放掉了A掉这道题的机会……早知道这道题多想一会儿了/手动无奈。
##正解
考虑到我们最终只会使用f数组某一行的一个前缀和,于是我们考虑如何快速求出这个前缀和。根据定义,有
∑
i
=
1
m
f
[
n
]
[
i
]
=
f
[
n
+
1
]
[
m
]
\sum_{i=1}^mf[n][i]=f[n+1][m]
i=1∑mf[n][i]=f[n+1][m]
因此我们只要能够快速求出f数组中的某个值即可。观察它的递推式,f[n][m]的值可以看做是从坐标(1,1)走到(n,m),每次可以向右或向上走一格,且不能越过(碰到不算)y=x这条直线的方案数。这是很经典的问题,根据折线定理,
f
[
n
]
[
m
]
=
C
n
+
m
−
2
n
−
1
−
C
n
+
m
−
2
n
f[n][m]=C_{n+m-2}^{n-1}-C_{n+m-2}^n
f[n][m]=Cn+m−2n−1−Cn+m−2n
于是我们预处理一下阶乘和逆元,就可以在O(1)的时间内查询出前缀和,整道题就可以在
O
(
n
)
O(n)
O(n)的时间内解决。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1200005, mod = 998244353;
int num[maxn], vis[maxn], T, n;
ll fact[maxn], rev[maxn];
ll modpow(ll a, int b){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
ll get_num(int i, int j){
ll C1 = fact[i + j - 2] * rev[i - 1] % mod * rev[j - 1] % mod;
ll C2 = j > 1 ? fact[i + j - 2] * rev[j - 2] % mod * rev[i] % mod : 0;
return C1 - C2 < 0 ? C1 - C2 + mod : C1 - C2;
}
const int maxr = 10000000;
char str[maxr]; int rpos;
char readc(){
if(!rpos) fread(str, 1, maxr, stdin);
char c = str[rpos++];
if(rpos == maxr) rpos = 0;
return c;
}
int read(){
int x; char c;
while((c = readc()) < '0' || c > '9');
x = c - '0';
while((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
return x;
}
int main(){
T = read();
fact[0] = 1;
for(int i = 1; i < maxn; i++)
fact[i] = fact[i - 1] * i % mod;
rev[maxn - 1] = modpow(fact[maxn - 1], mod - 2);
for(int i = maxn - 2; i >= 0; i--)
rev[i] = rev[i + 1] * (i + 1) % mod;
while(T--){
n = read();
for(int i = 1; i <= n; i++){
num[i] = read();
vis[i] = 0;
}
int mx = 0, mn = 1, res = 0;
for(int i = 1; i <= n; i++){
while(vis[mn]) ++mn;
int id = n - i + 1, cnt = n - max(mx, num[i]);
if(cnt > 0) res += get_num(id + 1, cnt);
if(res >= mod) res -= mod;
vis[num[i]] = 1;
if(num[i] < mx && num[i] > mn) break;
if(num[i] > mx) mx = num[i];
}
printf("%d\n", res);
}
return 0;
}