【题目链接】
【思路要点】
- 考虑枚举最后剩下的一种球是哪一种球。
- 令 s u m = ∑ i = 1 N a i sum=\sum_{i=1}^{N}a_i sum=∑i=1Nai,问题被转化为了今有 a i a_i ai个黑球和 s u m − a i sum-a_i sum−ai个白球,在最后剩下黑球的情况下期望的操作步数。
- 设 f i f_i fi表示当前有 i i i个黑球和 s u m − a i sum-a_i sum−ai个白球,在最后剩下黑球的情况下期望的操作步数。
- 在这个定义下,有 f s u m = 0 f_{sum}=0 fsum=0,因为我们仅考虑最后剩下黑球的情况,所以同样有 f 0 = 0 f_{0}=0 f0=0。
- 令 p = i ∗ ( s u m − i ) s u m ∗ ( s u m − 1 ) p=\frac{i*(sum-i)}{sum*(sum-1)} p=sum∗(sum−1)i∗(sum−i), p p p表示的是在进行一次操作后黑球多或少一个的概率,它们显然是相等的。
- 我们可以将这个问题转化为在序列上随机游走的问题:一个长度为 s u m + 1 sum+1 sum+1的序列,位置编号为0至 s u m sum sum,我们在点 i ( 0 < i < s u m ) i(0<i<sum) i(0<i<sum)处向前和向后的概率是相等的,在序列的两头会停下来。
- 有结论:在点 i i i处走到点 s u m sum sum的概率为 i s u m \frac{i}{sum} sumi,走到点的概率0为 s u m − i s u m \frac{sum-i}{sum} sumsum−i。
- 证明较为简单:不妨设点 i i i处走到点 s u m sum sum的概率为 p i p_i pi,有 p 0 = 0 , p s u m = 1 , p i = p i − 1 + p i + 1 2 ( 0 < i < s u m ) p_0=0,p_{sum}=1,p_i=\frac{p_{i-1}+p_{i+1}}{2}(0<i<sum) p0=0,psum=1,pi=2pi−1+pi+1(0<i<sum),因此有 p i + 1 − p i = p i − p i − 1 ( 0 < i < s u m ) p_{i+1}-p_i=p_i-p_{i-1}(0<i<sum) pi+1−pi=pi−pi−1(0<i<sum),即 p p p是一个等差数列,则易证上述结论。
- 有了这个结论,我们就可以列出 f i f_i fi的关系式了。
- 有 f i = p ∗ f i − 1 + p ∗ f i + 1 + ( 1 − 2 p ) ∗ f i + i s u m ( 0 < i < s u m ) f_i=p*f_{i-1}+p*f_{i+1}+(1-2p)*f_i+\frac{i}{sum}(0<i<sum) fi=p∗fi−1+p∗fi+1+(1−2p)∗fi+sumi(0<i<sum),注意最后加的不是1,而是点 i i i处走到点 s u m sum sum的概率,因为若走到点0处,就不满足最后剩下黑球的前提了,所以不作统计。
- 这就显然有了一种 O ( ( ∑ a i ) 3 ) O((\sum a_i)^3) O((∑ai)3)的高斯消元的做法。
- 我们发现只要得到了 f 1 f_1 f1,我们就能轻松地解出 f f f的前 a i a_i ai项,回答问题。
- 通过高斯消元的做法打表,我们发现 f 1 = ( s u m − 1 ) 2 s u m f_1=\frac{(sum-1)^2}{sum} f1=sum(sum−1)2(此处笔者并不会证明)。
- 由此解出 f f f的前 a i a_i ai项,答案即为 ∑ i = 1 N f a i \sum_{i=1}^{N}f_{a_i} ∑i=1Nfai。
- 时间复杂度 O ( N + M a x { a i } ) O(N+Max\{a_i\}) O(N+Max{ai})。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 100005; const int P = 1e9 + 7; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } int a[MAXN], f[MAXN]; int power(int x, int y) { if (y == 0) return 1; int tmp = power(x, y / 2); if (y % 2 == 0) return 1ll * tmp * tmp % P; else return 1ll * tmp * tmp % P * x % P; } int main() { int n; read(n); int sum = 0, Max = 0; for (int i = 1; i <= n; i++) read(a[i]), sum += a[i], chkmax(Max, a[i]); f[1] = (sum - 1ll) * (sum - 1ll) % P * power(sum, P - 2) % P; for (int i = 1; i < Max; i++) f[i + 1] = (2ll * f[i] - f[i - 1] - (sum - 1ll) * power(sum - i, P - 2) % P + 2 * P) % P; int ans = 0; for (int i = 1; i <= n; i++) ans = (ans + f[a[i]]) % P; writeln(ans); return 0; }