题目链接
题目大意
给定长度为 n − 1 n - 1 n−1 的数组 a , b a, \ b a, b。
现在定义如下三个长度均为 n − 1 n - 1 n−1 的数组 x , y , z x, \ y, \ z x, y, z:
x 1 = a 1 a 1 + b 1 , y 1 = b 1 a 1 + b 1 , z 1 = 0 − − − − − − − − − − − − − − − − − − − − − − − − − x i = a i 2 ( a i + b i ) 2 , y i = 2 ⋅ a i ⋅ b i ( a i + b i ) 2 , z i = b i 2 ( a i + b i ) 2 , i ∈ [ 2 , n ) x_1 = \frac{a_1}{a_1 + b_1}, \ y_1 = \frac{b_1}{a_1 + b_1}, \ z_1 = 0 \\ ------------------------- \\ x_i = \frac{a_i^2}{(a_i+b_i)^2}, \ y_i = \frac{2 \cdot a_i \cdot b_i}{(a_i+b_i)^2}, \ z_i = \frac{b_i^2}{(a_i+b_i)^2}, \ i \in [2, \ n) x1=a1+b1a1, y1=a1+b1b1, z1=0−−−−−−−−−−−−−−−−−−−−−−−−−xi=(ai+bi)2ai2, yi=(ai+bi)22⋅ai⋅bi, zi=(ai+bi)2bi2, i∈[2, n)
显然
x
,
y
,
z
x, \ y, \ z
x, y, z 满足
x
i
+
y
i
+
z
i
=
1
,
∀
i
∈
[
1
,
n
)
x_i + y_i + z_i = 1, \ \forall i \in [1, \ n)
xi+yi+zi=1, ∀i∈[1, n)
一开始我们在 1 1 1。每次执行以下过程算一步。
设现在到了 i i i,若 i = n i = n i=n 则立即停止以下过程,否则继续进行:
- 有 x i x_i xi 的概率到 i + 1 i + 1 i+1
- 有 y i y_i yi 的概率留在 i i i
- 有 z i z_i zi 的概率回到 i − 1 i - 1 i−1
求从 1 1 1 到达 n n n 的期望步数 + 1 +1 +1。
Solution
首先期望 d p dp dp 的常见套路就是反着设和反着更新状态,最后答案为 E 1 E_1 E1。
因此我们设 E i E_i Ei 表示从 i i i 出发到 n n n 的期望步数,那么显然有 E n = 0 E_n = 0 En=0,而答案就是 E 1 + 1 E_1 + 1 E1+1。
考虑状态转移方程( 1 ⩽ i < n 1 \leqslant i < n 1⩽i<n):
E i = x i ( E i + 1 + 1 ) + y i ( E i + 1 ) + z i ( E i − 1 + 1 ) E_i = x_i(E_{i + 1} + 1) + y_i(E_i + 1) + z_i(E_{i - 1} + 1) Ei=xi(Ei+1+1)+yi(Ei+1)+zi(Ei−1+1)
为表达简洁,下面令 y i : = y i − 1 y_i := y_i - 1 yi:=yi−1,化简得到
z i E i − 1 + y i E i + x i E i + 1 = − 1 , i ∈ [ 1 , n ) z_iE_{i - 1}+ y_iE_i + x_iE_{i +1} = -1, \ i \in [1, \ n) ziEi−1+yiEi+xiEi+1=−1, i∈[1, n)
将这 n − 1 n - 1 n−1 行写成一个 ( n − 1 ) × n (n - 1) \times n (n−1)×n 的矩阵(在 i = n − 1 i = n - 1 i=n−1 时 E n = 0 E_n = 0 En=0,所以无需写出 x n − 1 x_{n - 1} xn−1,这样就不用写第 n n n 列):
[ y 1 x 1 − 1 z 2 y 2 x 2 − 1 z 3 y 3 x 3 − 1 ⋮ ⋮ ⋮ ⋱ ⋮ ⋮ − 1 z n − 2 y n − 2 x n − 2 − 1 z n − 1 y n − 1 − 1 ] \left[ \begin{array}{lccccr|c} {y_1} & {x_1} & {} & {} & {} & {} & {-1} \\ {z_2} & {y_2} & {x_2} & {} & {} & {} & {-1} \\ {} & {z_3} & {y_3} & {x_3} & {} & {} & {-1} \\ {\vdots} & {\vdots} & {\vdots} & {\ddots} & {\vdots} & {\vdots} & {-1} \\ {} & {} & {} & z_{n - 2} & y_{n - 2} & x_{n - 2} & {-1} \\ {} & {} & {} & {} & z_{n - 1}& y_{n - 1} & {-1} \end{array} \right] y1z2⋮x1y2z3⋮x2y3⋮x3⋱zn−2⋮yn−2zn−1⋮xn−2yn−1−1−1−1−1−1−1
这其实就是 ( n − 1 ) × ( n − 1 ) (n - 1) \times (n - 1) (n−1)×(n−1) 的系数矩阵的增广阵,现在我们来求解这个 三对角矩阵。
仍然用 x x x 记录对角线上方一侧的系数,同时新增一个长度为 n − 1 n - 1 n−1 的数组(或者叫向量) s s s 表示增广的那一列的值,初始值为 ( − 1 , − 1 , ⋯ , − 1 ) T (-1, \ -1, \ \cdots, \ -1)^{T} (−1, −1, ⋯, −1)T。
一开始 i = 1 i = 1 i=1。现在循环执行以下操作,直至 i = n − 1 i = n - 1 i=n−1。
x i = x i y i s i = s i y i y i + 1 = y i + 1 − z i + 1 ⋅ x i , w h e n i + 1 < n s i + 1 = s i + 1 − z i + 1 ⋅ s i , w h e n i + 1 < n i = i + 1 \begin{align*} x_i &= \frac{x_i}{y_i} \\ s_i &= \frac{s_i}{y_i} \\ y_{i + 1} &= y_{i + 1} - z_{i + 1} \cdot x_i, \ when \ \ i + 1 < n \\ s_{i + 1} &= s_{i + 1} - z_{i + 1} \cdot s_i, \ when \ \ i + 1 < n \\ i &= i + 1 \\ \end{align*} xisiyi+1si+1i=yixi=yisi=yi+1−zi+1⋅xi, when i+1<n=si+1−zi+1⋅si, when i+1<n=i+1
在每一步操作中,首先是将对角线上的 y i y_i yi 归一化,对整行除以 y i y_i yi(但是 y i y_i yi 在之后没用了,所以可以不用管 y i y_i yi,只除其他的);接着是对第 i + 1 i + 1 i+1 行的 z i + 1 z_{i + 1} zi+1 进行消去,但是由于 z z z 不参与其他运算,所以也没必要把 z i + 1 z_{i+1} zi+1 归零。
这样最后我们得到了一个上三角矩阵,加右侧的一列增广。
接下来令 i = n − 2 i = n - 2 i=n−2,循环执行以下操作,直至 i = 0 i = 0 i=0。
s i = s i − s i + 1 ⋅ x i i = i − 1 \begin{align*} s_i &= s_i - s_{i + 1} \cdot x_i \\ i &= i - 1 \end{align*} sii=si−si+1⋅xi=i−1
这一步就是为了继续消去上三角阵对角线上面一侧的系数。
最终我们得到的 s s s 就是 E E E。所以答案就是 s 1 + 1 s_1 + 1 s1+1。
时间复杂度 O ( n l o g P ) O(nlogP) O(nlogP)
- P P P 为模数,有 l o g log log 是因为有除法
C++ Code
- 实现的时候我的下标从 0 0 0 开始
#include <bits/stdc++.h>
using i64 = int64_t;
using u64 = uint64_t;
using f64 = double_t;
using i128 = __int128_t;
template<class T>
constexpr T power(T a, i64 b) {
T res = 1;
for (; b; b /= 2, a *= a) {
if (b % 2) {
res *= a;
}
}
return res;
}
template<int P>
struct MInt {
int x;
constexpr MInt() : x{} {}
constexpr MInt(i64 x) : x{norm(x % getMod())} {}
static int Mod;
constexpr static int getMod() {
if (P > 0) {
return P;
} else {
return Mod;
}
}
constexpr static void setMod(int Mod_) {
Mod = Mod_;
}
constexpr int norm(int x) const {
if (x < 0) {
x += getMod();
}
if (x >= getMod()) {
x -= getMod();
}
return x;
}
constexpr int val() const {
return x;
}
explicit constexpr operator int() const {
return x;
}
constexpr MInt operator-() const {
MInt res;
res.x = norm(getMod() - x);
return res;
}
constexpr MInt inv() const {
assert(x != 0);
return power(*this, getMod() - 2);
}
constexpr MInt &operator*=(MInt rhs) & {
x = 1LL * x * rhs.x % getMod();
return *this;
}
constexpr MInt &operator+=(MInt rhs) & {
x = norm(x + rhs.x);
return *this;
}
constexpr MInt &operator-=(MInt rhs) & {
x = norm(x - rhs.x);
return *this;
}
constexpr MInt &operator/=(MInt rhs) & {
return *this *= rhs.inv();
}
friend constexpr MInt operator*(MInt lhs, MInt rhs) {
MInt res = lhs;
res *= rhs;
return res;
}
friend constexpr MInt operator+(MInt lhs, MInt rhs) {
MInt res = lhs;
res += rhs;
return res;
}
friend constexpr MInt operator-(MInt lhs, MInt rhs) {
MInt res = lhs;
res -= rhs;
return res;
}
friend constexpr MInt operator/(MInt lhs, MInt rhs) {
MInt res = lhs;
res /= rhs;
return res;
}
friend constexpr std::istream &operator>>(std::istream &is, MInt &a) {
i64 v;
is >> v;
a = MInt(v);
return is;
}
friend constexpr std::ostream &operator<<(std::ostream &os, const MInt &a) {
return os << a.val();
}
friend constexpr bool operator==(MInt lhs, MInt rhs) {
return lhs.val() == rhs.val();
}
friend constexpr bool operator!=(MInt lhs, MInt rhs) {
return lhs.val() != rhs.val();
}
};
template<>
int MInt<0>::Mod = 998244353;
template<int V, int P>
constexpr MInt<P> CInv = MInt<P>(V).inv();
constexpr int P = 1000000007;
using Z = MInt<P>;
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout << std::fixed << std::setprecision(12);
int n;
std::cin >> n;
std::vector<int> a(n - 1);
for (int i = 0; i < n - 1; i++) {
std::cin >> a[i];
}
std::vector<int> b(n - 1);
for (int i = 0; i < n - 1; i++) {
std::cin >> b[i];
}
std::vector<Z> x(n - 1);
std::vector<Z> y(n - 1);
std::vector<Z> z(n - 1);
x[0] = Z(a[0]) / (a[0] + b[0]);
y[0] = Z(b[0]) / (a[0] + b[0]) - 1;
for (int i = 1; i < n - 1; i++) {
Z dn = Z(a[i] + b[i]) * (a[i] + b[i]);
x[i] = Z(a[i]) * a[i] / dn;
y[i] = 2 * Z(a[i]) * b[i] / dn - 1;
z[i] = Z(b[i]) * b[i] / dn;
}
std::vector<Z> s(n - 1, -1);
for (int i = 0; i < n - 1; i++) {
s[i] /= y[i];
x[i] /= y[i];
if (i + 1 < n - 1) {
s[i + 1] -= z[i + 1] * s[i];
y[i + 1] -= z[i + 1] * x[i];
}
}
for (int i = n - 3; i >= 0; i--) {
s[i] -= s[i + 1] * x[i];
}
std::cout << s[0] + 1 << "\n";
return 0;
}