Description
Nowcoder (牛客网) 2018 campus recruitment exam problem: 无聊的牛牛和羊羊 ("The Bored Niuniu and Yangyang")
Solving Ideas
Recurrence (递推), computed bottom-up.
Let $f(n, m)$ denote the expected number of steps until everyone is bored, starting from $n$ bored and $m$ not-bored individuals.
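From state $f(n, m)$, each step picks two of the $n + m$ individuals uniformly at random. The next state is $f(n, m)$ (both picked are already bored), $f(n+1, m-1)$ (exactly one picked is not bored), or $f(n+2, m-2)$ (both picked are not bored), with probabilities $p_1$, $p_2$, $p_3$ respectively. Counting the step just taken: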
$$f(n, m) = p_1\,(f(n, m) + 1) + p_2\,(f(n+1, m-1) + 1) + p_3\,(f(n+2, m-2) + 1)$$
- $p_1 = C(n, 2) / C(m+n, 2)$
- $p_2 = C(n, 1) \cdot C(m, 1) / C(m+n, 2)$
- $p_3 = C(m, 2) / C(m+n, 2)$
- $p_1 + p_2 + p_3 = 1$
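As a quick numeric check of the transition probabilities, the following Java sketch evaluates $p_1$, $p_2$, $p_3$ straight from the binomial-coefficient formulas. The class and method names (`TransitionProbabilities`, `choose2`, `probabilities`) are hypothetical and exist only for illustration.

```java
// Hypothetical helper: transition probabilities p1, p2, p3 for state (n, m).
final class TransitionProbabilities {
    // C(k, 2) = k(k - 1) / 2: number of unordered pairs among k individuals
    static double choose2(int k) {
        return k * (k - 1) / 2.0;
    }

    /** Returns {p1, p2, p3} for n bored and m not-bored individuals (n + m >= 2). */
    static double[] probabilities(int n, int m) {
        double pairs = choose2(n + m);       // C(m + n, 2)
        double p1 = choose2(n) / pairs;      // both picked are bored
        double p2 = (double) n * m / pairs;  // C(n, 1) * C(m, 1): one of each kind
        double p3 = choose2(m) / pairs;      // both picked are not bored
        return new double[] {p1, p2, p3};
    }

    public static void main(String[] args) {
        double[] p = probabilities(3, 2);
        System.out.printf("p1=%.2f p2=%.2f p3=%.2f sum=%.2f%n",
                p[0], p[1], p[2], p[0] + p[1] + p[2]);
    }
}
```

For $n = 3$, $m = 2$ there are $C(5, 2) = 10$ pairs, so $p_1 = 3/10$, $p_2 = 6/10$, $p_3 = 1/10$, which sum to 1.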
Simplifying (using $p_1 + p_2 + p_3 = 1$) gives:
$$(1 - p_1)\,f(n, m) = 1 + p_2\,f(n+1, m-1) + p_3\,f(n+2, m-2)$$
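The simplification is a short rearrangement: distribute the $+1$ terms, collapse them with $p_1 + p_2 + p_3 = 1$, then move the $p_1 f(n, m)$ term to the left-hand side.

```latex
\begin{aligned}
f(n,m) &= p_1 f(n,m) + p_2 f(n+1,m-1) + p_3 f(n+2,m-2) + (p_1 + p_2 + p_3) \\
       &= p_1 f(n,m) + p_2 f(n+1,m-1) + p_3 f(n+2,m-2) + 1 \\
\Longrightarrow \quad (1 - p_1)\, f(n,m) &= 1 + p_2 f(n+1,m-1) + p_3 f(n+2,m-2)
\end{aligned}
```

Dividing by $1 - p_1$ is safe whenever $m \ge 1$: then $C(n, 2) < C(n+m, 2)$, so $p_1 < 1$.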
Since $1 \le n, m \le 50$, we have $n + m - 1 \ge 1$, so $(n + m - 1) \neq 0$ and $C(m+n, 2) = (n+m)(n+m-1)/2$ is nonzero. The probabilities therefore become:
- $p_1 = n(n-1) / ((n+m)(n+m-1))$
- $p_2 = 2mn / ((n+m)(n+m-1))$
- $p_3 = m(m-1) / ((n+m)(n+m-1))$
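Over the common denominator $(n+m)(n+m-1)$ the three numerators add up exactly: $n(n-1) + 2mn + m(m-1) = (n+m)^2 - (n+m) = (n+m)(n+m-1)$. The hypothetical `ClosedFormCheck` class below is a sketch that confirms this identity in exact integer arithmetic over the whole constraint range, and also that $p_1 < 1$ there.

```java
// Hypothetical check: the closed-form numerators over the common denominator
// (n + m)(n + m - 1) sum to that denominator, so p1 + p2 + p3 = 1 exactly.
final class ClosedFormCheck {
    /** Numerators of p1, p2, p3: n(n - 1), 2mn, m(m - 1). */
    static long[] numerators(long n, long m) {
        return new long[] {n * (n - 1), 2 * m * n, m * (m - 1)};
    }

    /** Common denominator (n + m)(n + m - 1), nonzero whenever n + m >= 2. */
    static long denominator(long n, long m) {
        return (n + m) * (n + m - 1);
    }

    /** True if the denominator is nonzero, the numerators sum to it, and p1 < 1. */
    static boolean valid(long n, long m) {
        long[] num = numerators(n, m);
        long den = denominator(n, m);
        return den != 0 && num[0] + num[1] + num[2] == den && num[0] < den;
    }

    public static void main(String[] args) {
        boolean ok = true;
        // Exhaustive over the problem's constraint range 1 <= n, m <= 50
        for (long n = 1; n <= 50; n++) {
            for (long m = 1; m <= 50; m++) {
                ok &= valid(n, m);
            }
        }
        System.out.println(ok ? "identity holds" : "identity fails");
    }
}
```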
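Initial states, writing $N = n + m$ for the total from the input (every transition keeps the total fixed):
- $f(N, 0) = 0$
- $f(N - 1, 1) = N / 2$: with a single not-bored individual left, $p_3 = 0$ and $p_2 = 2/N$, so the expected number of steps until it is picked is $1/p_2 = N/2$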
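The base case can also be read off the simplified recurrence: with one not-bored individual, $p_3 = 0$ and the next state has $f = 0$, so $(1 - p_1) f = 1$ and $1 - p_1 = p_2$. The sketch below (hypothetical class `BaseCase`, method `oneLeft`) evaluates it that way for a given total $N \ge 2$.

```java
// Hypothetical helper: the base case with one not-bored individual, from the recurrence.
final class BaseCase {
    /** Expected steps when exactly one of `total` individuals is not bored (total >= 2). */
    static double oneLeft(int total) {
        double pairs = total * (total - 1) / 2.0; // C(N, 2)
        double p2 = (total - 1) / pairs;          // the not-bored one paired with any of the N - 1 bored
        // p3 = 0 and the next state has f = 0, so (1 - p1) f = 1 with 1 - p1 = p2
        return 1.0 / p2;
    }

    public static void main(String[] args) {
        for (int total : new int[] {2, 5, 100}) {
            System.out.printf("N=%d: f=%.1f, N/2=%.1f%n", total, oneLeft(total), total / 2.0);
        }
    }
}
```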
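Because $n + m$ never changes, a state is determined by the number of not-bored individuals alone, so the solution stores $f$ in a one-dimensional array indexed by that number.
Time complexity: $O(m)$
Space complexity: $O(m)$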
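Each value depends only on the two values for one and two fewer not-bored individuals, so the $O(m)$ array can be replaced by two variables. The sketch below is a hypothetical variant of the solution's loop (the class name `RollingDp` is illustrative), not the author's code: time stays $O(m)$ while extra space drops to $O(1)$.

```java
// Hypothetical O(1)-space variant of the solution's loop (same recurrence, same order).
final class RollingDp {
    /** Expected steps until all n + m are bored, starting with m not bored (n + m >= 2 assumed). */
    static double expectedSteps(int n, int m) {
        if (m == 0) {
            return 0;                           // nobody left to make bored
        }
        int total = n + m;
        double s = total * (total - 1) / 2.0;   // C(total, 2)
        double prev2 = 0;                       // value for i - 2 not-bored individuals
        double prev1 = total / 2.0;             // value for i - 1 not-bored individuals (base case i = 1)
        for (int i = 2; i <= m; i++) {
            double bored = total - i;
            double p1 = bored * (bored - 1) / 2.0 / s;
            double p2 = bored * i / s;
            double p3 = i * (i - 1) / 2.0 / s;
            double cur = (1 + p2 * prev1 + p3 * prev2) / (1 - p1);
            prev2 = prev1;                      // slide the two-value window forward
            prev1 = cur;
        }
        return prev1;
    }

    public static void main(String[] args) {
        System.out.printf("%.1f%n", expectedSteps(1, 3));
    }
}
```

For $n = 1$, $m = 3$ (total 4) the loop produces $14/5 = 2.8$ for two not-bored individuals and then $1 + 0.5 \cdot 2.8 + 0.5 \cdot 2 = 3.4$.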
Solution
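import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;

/**
 * @author wylu
 */
public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        // Input: n (bored) and m (not bored), separated by whitespace
        String[] strs = br.readLine().trim().split("\\s+");
        int n = Integer.parseInt(strs[0]), m = Integer.parseInt(strs[1]);
        int total = n + m; // the total never changes between states

        // f[i] = expected steps when i individuals are not bored (total - i are bored)
        double[] f = new double[Math.max(m + 1, 2)]; // at least 2 slots so f[1] is always valid
        f[0] = 0;
        f[1] = total / 2.0;
        double s = total * (total - 1) / 2.0; // C(total, 2): number of possible pairs
        for (int i = 2; i <= m; i++) {
            double tmp = total - i;                // number of bored individuals
            double p1 = tmp * (tmp - 1) / 2.0 / s; // both picked are bored
            double p2 = tmp * i / s;               // one bored, one not bored
            double p3 = i * (i - 1) / 2.0 / s;     // both picked are not bored
            f[i] = (1 + p2 * f[i - 1] + p3 * f[i - 2]) / (1 - p1);
        }
        System.out.println(String.format("%.1f", f[m]));
    }
}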
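To sanity-check the recurrence empirically, the following hypothetical Monte Carlo sketch simulates the process implied by the transition probabilities: each step picks two distinct individuals uniformly at random, and every picked individual becomes bored. That model is an assumption read off from $p_1$, $p_2$, $p_3$, not a quote of the official statement; `recurrence`, `simulate`, the seed `42`, and the trial count are illustrative choices.

```java
import java.util.Random;

// Hypothetical Monte Carlo cross-check. Assumed model: each step picks two distinct
// individuals uniformly at random, and every picked individual becomes bored.
final class MonteCarloCheck {
    /** Same recurrence as the solution, returning f for the input (n, m) with m >= 1. */
    static double recurrence(int n, int m) {
        int total = n + m;
        double[] f = new double[m + 1];
        f[1] = total / 2.0;
        double s = total * (total - 1) / 2.0;
        for (int i = 2; i <= m; i++) {
            double bored = total - i;
            double p1 = bored * (bored - 1) / 2.0 / s;
            double p2 = bored * i / s;
            double p3 = i * (i - 1) / 2.0 / s;
            f[i] = (1 + p2 * f[i - 1] + p3 * f[i - 2]) / (1 - p1);
        }
        return f[m];
    }

    /** Average number of steps over `trials` simulated runs, using a fixed seed. */
    static double simulate(int n, int m, int trials, long seed) {
        Random rnd = new Random(seed);
        int total = n + m;
        long steps = 0;
        for (int t = 0; t < trials; t++) {
            boolean[] bored = new boolean[total];
            for (int k = 0; k < n; k++) {
                bored[k] = true;                // the first n individuals start bored
            }
            int notBored = m;
            while (notBored > 0) {
                int a = rnd.nextInt(total);
                int b = rnd.nextInt(total - 1);
                if (b >= a) {
                    b++;                        // skip a, so {a, b} is a uniform pair of distinct individuals
                }
                if (!bored[a]) { bored[a] = true; notBored--; }
                if (!bored[b]) { bored[b] = true; notBored--; }
                steps++;
            }
        }
        return (double) steps / trials;
    }

    public static void main(String[] args) {
        int n = 3, m = 7;
        System.out.printf("recurrence=%.3f simulation=%.3f%n",
                recurrence(n, m), simulate(n, m, 200_000, 42L));
    }
}
```

With 200,000 trials the simulated average should land close to the recurrence's value; a large gap would point to a mismatch between the assumed model and the recurrence.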