给一个长度为 n ( n ≤ 100 ) n (n\leq 100) n(n≤100) 的 0 / 1 0/1 0/1 串,进行 k ( k ≤ 1 0 9 ) k (k \leq 10^9) k(k≤109) 次操作,每次操作选择两个位置 i , j ( 1 ≤ i < j ≤ n ) i,j (1 \leq i < j \leq n) i,j(1≤i<j≤n),交换 i , j i,j i,j 上的数,求 k k k 次操作后,该 0 / 1 0/1 0/1 串变成非降序列的概率,答案对 1 0 9 + 7 10^9+7 109+7 取模。
sol
好题,概率 dp
。
设有 m m m 个 0 0 0,那么题意就是让 a [ 1 , m ] a[1,m] a[1,m] 均为 0 0 0, a [ m + 1 , n ] a[m+1,n] a[m+1,n] 均为 1 1 1。
令 f i , j f_{i,j} fi,j 表示 i i i 个操作后,前 m m m 个数中有 j j j 个 0 0 0 的方案数,答案即为 f k , m ∑ i = 0 m f k , i \frac{f_{k,m}}{\sum\limits_{i=0}^{m}f_{k,i}} i=0∑mfk,ifk,m,边界: f 0 , p = 1 f_{0,p}=1 f0,p=1, p p p 为原序列前 m m m 个数中 0 0 0 的个数。
对于 f i , j f_{i,j} fi,j,考虑它是如何转移来的:
- 之前有 j − 1 j-1 j−1 个 0 0 0,第 i i i 次交换换来一个 0 0 0,由于前面 1 1 1 的个数与后面 0 0 0 的个数均为 m − j + 1 m-j+1 m−j+1,顾方案数为 f i − 1 , j − 1 × ( m − j + 1 ) 2 f_{i-1,j-1}\times (m-j+1)^2 fi−1,j−1×(m−j+1)2。
- 之前有 j + 1 j + 1 j+1 个 0 0 0,第 i i i 次交换换走一个 0 0 0,由于前面有 j + 1 j+1 j+1 个 0 0 0,后面有 n − m − ( m − j − 1 ) = n − 2 m + j + 1 n-m-(m-j-1)=n-2m+j+1 n−m−(m−j−1)=n−2m+j+1 个 1 1 1,顾方案数为 f i − 1 , j + 1 × ( j + 1 ) ( n − 2 m + j + 1 ) f_{i-1,j+1}\times (j+1)(n-2m+j+1) fi−1,j+1×(j+1)(n−2m+j+1)。
- 之前本来就有 j j j 个 0 0 0,第 i i i 次操作没换走也没换来,四种情况:前面交换,后面交换,前后交换 0 0 0,前后交换 1 1 1,则方案数为 C m 2 + C n − m 2 + j ( m − j ) + ( m − j ) ( n − 2 m + j ) C_{m}^{2}+C_{n-m}^{2}+j(m-j)+(m-j)(n-2m+j) Cm2+Cn−m2+j(m−j)+(m−j)(n−2m+j)。
到这里差点结束了,总结: f i , j = f i − 1 , j − 1 × ( m − j + 1 ) 2 + f i − 1 , j + 1 × ( j + 1 ) ( n − 2 m + j + 1 ) + C m 2 + C n − m 2 + j ( m − j ) + ( m − j ) ( n − 2 m + j ) f_{i,j}=f_{i-1,j-1}\times (m-j+1)^2+f_{i-1,j+1}\times (j+1)(n-2m+j+1)+C_{m}^{2}+C_{n-m}^{2}+j(m-j)+(m-j)(n-2m+j) fi,j=fi−1,j−1×(m−j+1)2+fi−1,j+1×(j+1)(n−2m+j+1)+Cm2+Cn−m2+j(m−j)+(m−j)(n−2m+j)。
考虑到
k
≤
1
0
9
k\leq 10^9
k≤109,经验告诉我们直接上矩阵快速幂,毕竟这转移无需判断。
[
c
0
b
1
0
0
⋯
0
a
0
c
1
b
2
0
⋯
0
0
a
1
c
2
b
3
⋯
0
0
0
a
2
c
3
⋯
0
0
0
0
a
3
⋯
0
⋯
⋯
⋯
⋯
⋯
⋯
0
0
0
0
⋯
c
m
]
×
[
f
i
−
1
,
0
f
i
−
1
,
1
f
i
−
1
,
2
f
i
−
1
,
3
f
i
−
1
,
4
⋯
f
i
−
1
,
m
]
=
[
f
i
,
0
f
i
,
1
f
i
,
2
f
i
,
3
f
i
,
4
⋯
f
i
,
m
]
\begin{bmatrix}c_0& b_1& 0 & 0&\cdots & 0\\ a_0 & c_1 & b_2 & 0&\cdots & 0\\ 0 & a_1 & c_2 & b_3 & \cdots & 0\\ 0&0&a_2&c_3&\cdots&0\\ 0&0&0&a_3&\cdots&0\\ \cdots&\cdots&\cdots&\cdots&\cdots&\cdots\\ 0&0&0&0&\cdots&c_m\end{bmatrix} \times \begin{bmatrix}f_{i-1,0}\\f_{i-1,1}\\f_{i-1,2}\\f_{i-1,3}\\f_{i-1,4}\\ \cdots \\f_{i-1,m}\end{bmatrix}=\begin{bmatrix}f_{i,0}\\f_{i,1}\\f_{i,2}\\f_{i,3}\\f_{i,4}\\ \cdots \\f_{i,m}\end{bmatrix}
⎣⎢⎢⎢⎢⎢⎢⎢⎢⎡c0a0000⋯0b1c1a100⋯00b2c2a20⋯000b3c3a3⋯0⋯⋯⋯⋯⋯⋯⋯00000⋯cm⎦⎥⎥⎥⎥⎥⎥⎥⎥⎤×⎣⎢⎢⎢⎢⎢⎢⎢⎢⎡fi−1,0fi−1,1fi−1,2fi−1,3fi−1,4⋯fi−1,m⎦⎥⎥⎥⎥⎥⎥⎥⎥⎤=⎣⎢⎢⎢⎢⎢⎢⎢⎢⎡fi,0fi,1fi,2fi,3fi,4⋯fi,m⎦⎥⎥⎥⎥⎥⎥⎥⎥⎤
其中
a
i
=
(
m
−
i
)
2
a_i=(m-i)^2
ai=(m−i)2,
b
i
=
i
(
n
−
2
m
+
i
)
b_i=i(n-2m+i)
bi=i(n−2m+i),
c
i
=
C
m
2
+
C
n
−
m
2
+
i
(
m
−
i
)
+
(
m
−
i
)
(
n
−
2
m
+
i
)
c_i=C_{m}^{2}+C_{n-m}^{2}+i(m-i)+(m-i)(n-2m+i)
ci=Cm2+Cn−m2+i(m−i)+(m−i)(n−2m+i)。
时间复杂度 O ( n 3 log k ) \mathcal O(n^3\log k) O(n3logk)。
#include <bits/stdc++.h>
using namespace std;
#define int long long
inline int read() {
int x = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
x = (x << 3) + (x << 1) + (c ^ 48);
c = getchar();
}
return x * f;
}
const int _ = 107, p = 1e9 + 7;
int n, k, m, t, Ans, A[_], a[_], b[_], c[_];
struct matrix {
int a[_][_];
void init() {
memset(a, 0, sizeof a);
}
matrix mul(matrix x) {
matrix res;
res.init();
for (int i = 0; i <= m; ++i)
for (int j = 0; j <= m; ++j)
for (int k = 0; k <= m; ++k)
res.a[i][j] = (res.a[i][j] + a[i][k] * x.a[k][j] % p) % p;
return res;
}
} ans, base;
void Qpow(int k) {
while (k) {
if (k & 1)
ans = ans.mul(base);
k >>= 1, base = base.mul(base);
}
}
inline int qpow(int x, int y) {
int res = 1;
while (y) {
if (y & 1)
res = res * x % p;
y >>= 1, x = x * x % p;
}
return res;
}
signed main() {
n = read(), k = read();
for (int i = 1; i <= n; ++i) A[i] = read();
for (int i = 1; i <= n; ++i) m += (A[i] == 0);
for (int i = 1; i <= m; ++i) t += (A[i] == 0);
for (int i = 0; i <= m; ++i) {
a[i] = (m - i) * (m - i) % p;
b[i] = i * (n - 2 * m + i) % p;
c[i] = (m * (m - 1) * qpow(2, p - 2) % p + (n - m) * (n - m - 1) * qpow(2, p - 2) % p + (m - i) * (n - 2 * m + 2 * i) % p) % p;
}
ans.init();
for (int i = 0; i <= m; ++i) ans.a[i][i] = 1;
base.init();
base.a[0][0] = c[0], base.a[0][1] = b[1];
base.a[m][m - 1] = a[m - 1], base.a[m][m] = c[m];
for (int i = 1; i <= m - 1; ++i) {
base.a[i][i - 1] = a[i - 1];
base.a[i][i] = c[i];
base.a[i][i + 1] = b[i + 1];
}
Qpow(k);
base.init();
base.a[t][0] = 1;
ans = ans.mul(base);
for (int i = 0; i <= m; ++i)
Ans = (Ans + ans.a[i][0]) % p;
printf("%lld", ans.a[m][0] * qpow(Ans, p - 2) % p);
return 0;
}