题目链接
首先把
k
k
k 设为
n
+
k
2
\frac{n+k}{2}
2n+k
拿到手第一想法是直接求刚好有
k
k
k 个大于的情况,设为
f
i
f_i
fi
单纯的组合计数貌似比较难搞。。。
考虑有至少
k
k
k 个大于的情况,设为
g
i
g_i
gi
似乎两者有关联?
很明显
g
i
=
∑
j
=
i
n
f
j
g_i = \displaystyle\sum_{j=i}^{n}f_j
gi=j=i∑nfj 嘛
如果能求出
g
i
g_i
gi 也就可以求出
f
i
f_i
fi 啦
那我怎么求
g
i
g_i
gi 呢?
现在比较好求的是在有
n
n
n 个元素的
a
a
a 组中选
k
k
k 个大于出来
可以列出DP方程:
d
p
[
i
]
[
j
]
=
d
p
[
i
−
1
]
[
j
]
+
d
p
[
i
−
1
]
[
j
−
1
]
×
(
t
[
i
]
−
j
+
1
)
dp[i][j]=dp[i-1][j]+dp[i-1][j-1]\times(t[i]-j+1)
dp[i][j]=dp[i−1][j]+dp[i−1][j−1]×(t[i]−j+1)
(
t
[
i
]
t[i]
t[i] 表示
a
a
a 的第
i
i
i 个元素比多少个
b
b
b 的元素大)
剩下的
n
−
k
n-k
n−k 个元素怎么办?
不能确定的话,要不要都弄上去算了?
重新定义一次
g
i
=
d
p
[
n
]
[
i
]
×
(
n
−
i
)
!
g_i=dp[n][i]\times(n-i)!
gi=dp[n][i]×(n−i)!
这是个什么东西?
部分确定部分不确定的,有什么用?
不是要求至少有
k
k
k 个大于的吗,这有关联吗?
的确没有
但是这个
g
g
g 直接和最终要求的
f
f
f 有关联!
显然这个
g
i
=
∑
j
=
i
n
c
j
f
j
g_i=\displaystyle\sum_{j=i}^{n}c_jf_j
gi=j=i∑ncjfj
这里有个常数,因为实际上
g
i
g_i
gi还算了很多重复的东西
对于一个
f
j
f_j
fj,它在
g
i
g_i
gi 中出现了不止一次,为什么呢?
g
i
g_i
gi 实际上是确定了
i
i
i 个大于号后,全排了剩下的
剩下的肯定还会产生一些大于号
对于
f
j
f_j
fj,它计录的是刚好有
j
j
j 个大于号
但是在
g
i
g_i
gi 中,只确定了
j
j
j 个中的
i
i
i 个,剩下的
j
−
i
j-i
j−i 个来自于全排列部分
这样一来,在
g
i
g_i
gi 中
f
j
f_j
fj 出现了恰好
(
j
i
)
j\choose{i}
(ij) 次
所以
g
i
=
∑
j
=
i
n
(
j
i
)
f
j
g_i=\displaystyle\sum_{j=i}^{n}{j\choose{i}}f_j
gi=j=i∑n(ij)fj
只剩最后一步了,由二项式反演公式
f
i
=
∑
j
=
i
n
(
−
1
)
j
−
i
(
j
i
)
g
j
f_i=\displaystyle\sum_{j=i}^{n}(-1)^{j-i}{j\choose{i}}g_j
fi=j=i∑n(−1)j−i(ij)gj
就结束啦
(证明方法可以参考这里)
/*
Created 2019-1-6
"已经没有什么好害怕的了"
*/
#include <bits/stdc++.h>
using namespace std;
const int N = 2000 + 5;
const int mod = 1e9 + 9;
int n, k;
int f[N][N], ans;
int a[N], b[N], t[N];
int fac[N], inv[N];
int power(int a, int n) {
int b = 1;
while (n) {
if (n & 1) {
b = 1LL * b * a % mod;
}
a = 1LL * a * a % mod;
n >>= 1;
}
return b;
}
int C(int n, int k) {
return 1LL * fac[n] * inv[k] % mod * inv[n-k] % mod;
}
int main() {
freopen("read.in", "r", stdin);
scanf("%d %d", &n, &k);
if ((n+k) & 1) {
puts("0");
return 0;
}
k = (n + k) / 2;
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
sort(a+1, a+1+n);
for (int i = 1; i <= n; i++) {
scanf("%d", &b[i]);
}
sort(b+1, b+1+n);
for (int i = 1; i <= n; i++) {
t[i] = t[i-1];
while (t[i] < n && a[i] > b[t[i]+1]) {
t[i]++;
}
}
f[0][0] = 1;
for (int i = 1; i <= n; i++) {
for (int j = 0; j <= n; j++) {
f[i][j] = f[i-1][j];
if (j && t[i] - j + 1) {
f[i][j] = (f[i][j] + 1LL * f[i-1][j-1] * (t[i] - j + 1) % mod) % mod;
}
}
}
fac[0] = 1;
for (int i = 1; i <= n; i++) {
fac[i] = 1LL * fac[i-1] * i % mod;
}
inv[n] = power(fac[n], mod-2);
for (int i = n; i; i--) {
inv[i-1] = 1LL * inv[i] * i % mod;
}
for (int i = 1; i <= n; i++) {
f[n][i] = 1LL * f[n][i] * fac[n-i] % mod;
}
for (int i = k; i <= n; i++) {
ans = (ans + 1LL * ((i - k) & 1 ? -1 : 1) * C(i, k) * f[n][i]) % mod;
}
ans = (ans + mod) % mod;
printf("%d\n", ans);
return 0;
}