题意:
给你一个由0和1组成的序列。每次操作等概率地随机选取一对下标 $i<j$,若 $a_i > a_j$ 则交换这两个元素。求使序列变为非递减所需操作次数的期望(对 998244353 取模)。
思路:
$cnt_0$:统计序列中0的个数。那么最终有序的序列必须满足前 $cnt_0$ 个位置都是0,
$cnt_1$:统计初始时前 $cnt_0$ 个位置中有多少个1。
$dp[k]$:前 $cnt_0$ 个位置中已有 $k$ 个0时,距离序列有序还需要的期望操作次数;
$x = cnt_0 - k$:此时前 $cnt_0$ 个位置中1的个数,它也等于后 $n - cnt_0$ 个位置中0的个数。
边界:$dp[cnt_0] = 0$(前 $cnt_0$ 个位置全为0,序列已经有序)。
$p$:一次操作恰好把前 $cnt_0$ 个位置中的某个1与后面的某个0交换(使 $k$ 增加1)的概率。这样的下标对共有 $x \cdot x$ 个,而所有下标对共有 $C_n^2$ 个,因此
$$p = \frac{x \cdot x}{C_{n}^{2}}$$
其余的操作都不会改变 $k$。
综上:
$$dp[k] = 1 + dp[k]\cdot(1-p) + dp[k+1]\cdot p$$
(做一次操作后,以概率 $p$ 转移到状态 $k+1$,否则停留在状态 $k$。)
移项得
$$dp[k]\cdot p = 1 + dp[k+1]\cdot p$$
$$dp[k] = \frac{1}{p} + dp[k+1]$$
初始状态下前 $cnt_0$ 个位置中有 $cnt_0 - cnt_1$ 个0,所以
$$ans = dp[cnt_0 - cnt_1] = \sum_{x=1}^{cnt_1} \frac{C_n^2}{x^2}$$
AC
package com.hgs.codeforces.contest.div2.contest829.e;
/**
* @author youtsuha
* @version 1.0
* Create by 2022/10/23 15:36
*/
import java.util.*;
import java.io.*;
/**
 * Entry point: for each test case, reads a 0/1 array and prints the expected number of
 * random-pair operations needed to sort it, modulo {@link #mod}.
 */
public class Main {
    static FastScanner cin;
    static PrintWriter cout;
    /** Factorials: {@code frac[i] = i! mod mod}; filled by {@link #init()}. */
    static long[] frac;
    /** Inverse factorials: {@code invFrac[i] = (i!)^(-1) mod mod}; filled by {@link #init()}. */
    static long[] invFrac;
    /** Prime modulus used for all modular arithmetic. */
    static long mod = 998244353;

    /** Creates the stdin reader / stdout writer and precomputes the factorial tables. */
    private static void init() throws IOException {
        cin = new FastScanner(System.in);
        cout = new PrintWriter(System.out);
        int n = (int) (2e5 + 10);
        frac = new long[n];
        invFrac = new long[n];
        frac[0] = 1;
        for (int i = 1; i < n; i++) {
            frac[i] = frac[i - 1] * i % mod;
        }
        // One Fermat inverse for the largest factorial, then walk down:
        // (i!)^-1 = ((i+1)!)^-1 * (i+1).
        invFrac[n - 1] = inverse(frac[n - 1], mod);
        for (int i = n - 2; i >= 0; i--) {
            invFrac[i] = invFrac[i + 1] * (i + 1) % mod;
        }
    }

    /** Flushes and closes the output writer. */
    private static void close() {
        cout.close();
    }

    /** Returns how many of the first {@code len} elements of {@code a} equal {@code v}. */
    static int arrayCount(int[] a, int len, int v) {
        int ans = 0;
        for (int i = 0; i < len; i++) {
            if (a[i] == v) {
                ans++;
            }
        }
        return ans;
    }

    /**
     * Returns {@code a^k mod p} by binary exponentiation.
     *
     * <p>{@code a} is reduced into {@code [0, p)} first so that {@code a * a} cannot overflow
     * for moduli below about 3e9.
     */
    static long qpow(long a, long k, long p) {
        a %= p;
        if (a < 0) {
            a += p;
        }
        long res = 1;
        while (k > 0) {
            if ((k & 1) == 1) {
                res = res * a % p;
            }
            a = a * a % p;
            k >>= 1;
        }
        return res;
    }

    /**
     * Returns the binomial coefficient "b choose a" modulo {@link #mod}. Note the argument
     * order: {@code a} is the subset size and {@code b} the set size. {@code b} must be below
     * the precomputed table size; returns 0 when {@code a < 0} or {@code a > b}.
     */
    static long C(int a, int b) {
        if (a < 0 || a > b) {
            return 0;
        }
        return frac[b] * invFrac[a] % mod * invFrac[b - a] % mod;
    }

    /**
     * Returns the modular inverse of {@code p} via Fermat's little theorem.
     * {@code mod} must be prime; returns 0 if {@code mod} divides {@code p}.
     */
    static long inverse(long p, long mod) {
        return qpow(p, mod - 2, mod);
    }

    /**
     * Solves one test case.
     *
     * <p>Let cnt0 be the number of zeros. The array is sorted once its first cnt0 slots are all
     * 0. While x ones remain in that prefix, exactly x zeros lie outside it, so one random
     * pair fixes a misplaced 1 with probability p = x^2 / C(n,2). The wait for that is
     * geometric with mean C(n,2) / x^2. The answer is therefore
     * sum_{x=1}^{cnt1} C(n,2) / x^2, where cnt1 is the initial number of ones in the prefix.
     */
    private static void sol() throws IOException {
        int n = cin.nextInt();
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = cin.nextInt();
        }
        int cnt0 = arrayCount(a, n, 0);
        int cnt1 = arrayCount(a, cnt0, 1);
        // C(n,2) computed directly. It is loop-invariant, and it does not depend on the
        // factorial table size, so any n works.
        long pairs = (long) n * (n - 1) / 2 % mod;
        long ans = 0;
        for (long x = 1; x <= cnt1; x++) {
            ans = (ans + pairs * inverse(x * x % mod, mod)) % mod;
        }
        cout.println(ans);
    }

    /** Reads the number of test cases, solves each, then flushes the output. */
    public static void main(String[] args) throws IOException {
        init();
        int tt = cin.nextInt();
        while (tt-- > 0) {
            sol();
        }
        close();
    }
}
/**
 * Whitespace-delimited token reader over a buffered character stream, for fast input
 * parsing.
 */
class FastScanner {
    BufferedReader br;
    StringTokenizer st = new StringTokenizer("");

    /** Reads tokens from the given byte stream, decoded with the platform default charset. */
    public FastScanner(InputStream s) {
        br = new BufferedReader(new InputStreamReader(s));
    }

    /** Reads tokens from the file at path {@code s}. */
    public FastScanner(String s) throws FileNotFoundException {
        br = new BufferedReader(new FileReader(new File(s)));
    }

    /**
     * Returns the next whitespace-delimited token, skipping empty or blank lines.
     *
     * @throws EOFException if the input ends before another token is found
     * @throws IOException if reading the underlying stream fails
     */
    public String next() throws IOException {
        while (!st.hasMoreTokens()) {
            // Read errors propagate to the caller. The old approach of printing and retrying
            // could loop forever on a persistent I/O failure.
            String line = br.readLine();
            if (line == null) {
                throw new EOFException("No more tokens in input");
            }
            st = new StringTokenizer(line);
        }
        return st.nextToken();
    }

    /**
     * Returns the next token parsed as an {@code int}.
     *
     * @throws NumberFormatException if the token is not a valid int
     */
    public int nextInt() throws IOException {
        return Integer.parseInt(next());
    }

    /**
     * Returns the next token parsed as a {@code long}.
     *
     * @throws NumberFormatException if the token is not a valid long
     */
    public long nextLong() throws IOException {
        return Long.parseLong(next());
    }

    /**
     * Returns the next token parsed as a {@code double}.
     *
     * @throws NumberFormatException if the token is not a valid double
     */
    public double nextDouble() throws IOException {
        return Double.parseDouble(next());
    }
}