OBST
最优二叉排序树(OBST):在BST的前提下,每个节点有被搜索的概率,然后要找到一个BST,使得带权搜索路径最小
$$K\left[1\cdots n\right]=\langle k_1,\cdots,k_n\rangle\quad(k_1<k_2<\cdots<k_n)$$
其中 $k_i$ 是BST上的节点(key)
$$D\left[0\cdots n\right]=\langle d_0,\cdots,d_n\rangle$$
其中
$$d_0<k_1$$
$$d_n>k_n$$
$$k_i<d_i<k_{i+1}\quad(i=1,2,\cdots,n-1)$$
$d_i$ 表示的是搜索失败的节点(虚节点)
用 $p_i$ 表示 $k_i$ 被搜索到的概率,$q_i$ 表示 $d_i$ 被搜索到的概率,则
$$\sum_{i=1}^{n}p_i+\sum_{i=0}^{n}q_i=1$$
带权路径
$$E_{T}=\sum_{i=1}^{n}depth_{T}(k_i)\,p_i+\sum_{i=0}^{n}depth_{T}(d_i)\,q_i$$
其中根节点的深度为1
我们的目标就是让这个带权路径最小
令 $e\left[i,j\right]$ 表示包含 $K\left[i,j\right]$ 的OBST的带权路径,
$w\left[i,j\right]=\sum_{l=i}^{j}p_l+\sum_{l=i-1}^{j}q_l$,
$T_{i}^{j}$ 表示包含 $k_i$ 到 $k_j$ 的树,
$r\left[i,j\right]$ 表示 $T_i^j$ 的OBST根节点
则
(1) $e\left[i,i-1\right]=q_{i-1}$(此时树中只有虚节点 $d_{i-1}$,其深度为1)
(2)当 $i\le j$ 时,枚举根节点 $k_r$:左右子树中每个节点的深度都比它在子树内的深度多1,而根节点自身的深度为1,于是
$$\begin{aligned}
e\left[i,j\right]&=\min\limits_{i\le r \le j}\Big(\sum_{l=i}^{r-1}\big(depth_{T_{i}^{r-1}}(k_l)+1\big)p_l
+\sum_{l=r+1}^{j}\big(depth_{T_{r+1}^{j}}(k_l)+1\big)p_l\\
&\quad\quad+\sum_{l=i-1}^{r-1}\big(depth_{T_{i}^{r-1}}(d_l)+1\big)q_l
+\sum_{l=r}^{j}\big(depth_{T_{r+1}^{j}}(d_l)+1\big)q_l
+p_r\Big)\\
&=\min\limits_{i\le r \le j}\Big(\sum_{l=i}^{r-1}depth_{T_{i}^{r-1}}(k_l)\,p_l
+\sum_{l=i-1}^{r-1}depth_{T_{i}^{r-1}}(d_l)\,q_l\\
&\quad\quad+\sum_{l=r+1}^{j}depth_{T_{r+1}^{j}}(k_l)\,p_l
+\sum_{l=r}^{j}depth_{T_{r+1}^{j}}(d_l)\,q_l\\
&\quad\quad+\sum_{l=i}^{j}p_l+\sum_{l=i-1}^{j}q_l\Big)\\
&=\min\limits_{i\le r \le j}\big(e\left[i,r-1\right]+e\left[r+1,j\right]+w\left[i,j\right]\big)
\end{aligned}$$
所以
$$e\left[i,j\right]=\begin{cases}
q_{i-1},&j=i-1\\
\min\limits_{i\le r \le j}\big(e\left[i,r-1\right]+e\left[r+1,j\right]\big)+w\left[i,j\right],&i\le j
\end{cases}$$
可以看出,如果根节点是固定的,那么左右子树都是OBST,也就是说,想找到 $i$ 到 $j$ 的OBST,可以枚举根节点,算左右两边的OBST
于是dp方程就有了
#include <iostream>
#include <cstring>
#include <cfloat>
#include <vector>
using namespace std;
const int N = 1000;
// Table capacity. With n keys the tables are indexed from 0 to n + 1 in both
// dimensions, so they need (n + 2) x (n + 2) cells: n may be at most N - 2.
double e[N][N]; // e[i][j]: minimal weighted path length of a BST over keys k_i..k_j
double w[N][N]; // w[i][j]: p[i] + ... + p[j] + q[i-1] + ... + q[j]
int r[N][N];    // r[i][j]: index of the root chosen for the optimal tree over k_i..k_j
/**
 * @description: Optimal binary search tree (OBST), plain dynamic programming.
 *               Fills the global tables e, w and r for every key range [i, j]
 *               with 1 <= i <= j <= n, plus the empty ranges e[i][i-1] and
 *               w[i][i-1]. Every candidate root of every range is tried, so
 *               the innermost statement runs n(n+1)(n+2)/6 times.
 *               If n is outside [0, N - 2] the tables cannot hold the result;
 *               the function then returns without touching them.
 * @param {double} *p weights of successful searches; p[1..n] are read
 * @param {double} *q weights of unsuccessful searches; q[0..n] are read
 * @param {int} n number of keys
 */
void build_obst(double *p, double *q, int n)
{
    // The loops below touch e[n + 1][n]; refuse sizes the tables cannot hold
    // instead of writing past the end of the global arrays.
    if (n < 0 || n > N - 2)
    {
        return;
    }

    // Empty ranges: the tree consists of the single dummy node d_{i-1}.
    for (int i = 1; i <= n + 1; ++i)
    {
        e[i][i - 1] = q[i - 1];
        w[i][i - 1] = q[i - 1];
    }

    // Process ranges by increasing length so both sub-ranges are already solved.
    for (int len = 1; len <= n; ++len)
    {
        for (int i = 1; i + len - 1 <= n; ++i)
        {
            const int j = i + len - 1;
            w[i][j] = w[i][j - 1] + p[j] + q[j];

            double best = DBL_MAX;
            for (int root = i; root <= j; ++root)
            {
                const double cost = e[i][root - 1] + e[root + 1][j];
                // Strict comparison keeps the smallest root index on ties.
                if (cost < best)
                {
                    best = cost;
                    r[i][j] = root;
                }
            }
            // Hanging both subtrees under the root pushes every node one level
            // deeper, which adds the total weight of the range once.
            e[i][j] = best + w[i][j];
        }
    }
}
/**
 * Reads n, then p[1..n], then q[0..n] from standard input, runs build_obst
 * and prints the e, w and r tables (rows and columns 0..n+1), separated by
 * blank lines. Exits with status 1 when the input is malformed or n does not
 * fit the fixed-size tables.
 */
int main()
{
    /**
    Sample input:
    7
    0.04 0.06 0.08 0.02 0.10 0.12 0.14
    0.06 0.06 0.06 0.06 0.05 0.05 0.05 0.05
    */
    double p[N];
    double q[N];
    int n;
    // p, q and the global tables are indexed up to n + 1, so n must not exceed
    // N - 2; an unchecked value would overrun the arrays.
    if (scanf("%d", &n) != 1 || n < 0 || n > N - 2)
    {
        fprintf(stderr, "invalid n: expected an integer in [0, %d]\n", N - 2);
        return 1;
    }
    for (int i = 1; i <= n; ++i)
    {
        if (scanf("%lf", &p[i]) != 1)
        {
            fprintf(stderr, "failed to read p[%d]\n", i);
            return 1;
        }
    }
    for (int i = 0; i <= n; ++i)
    {
        if (scanf("%lf", &q[i]) != 1)
        {
            fprintf(stderr, "failed to read q[%d]\n", i);
            return 1;
        }
    }
    build_obst(p, q, n);
    // Table e
    for (int i = 0; i <= n + 1; ++i)
    {
        for (int j = 0; j <= n + 1; ++j)
        {
            printf("%lf ", e[i][j]);
        }
        printf("\n");
    }
    printf("\n");
    // Table w
    for (int i = 0; i <= n + 1; ++i)
    {
        for (int j = 0; j <= n + 1; ++j)
        {
            printf("%lf ", w[i][j]);
        }
        printf("\n");
    }
    printf("\n");
    // Table r
    for (int i = 0; i <= n + 1; ++i)
    {
        for (int j = 0; j <= n + 1; ++j)
        {
            printf("%d ", r[i][j]);
        }
        printf("\n");
    }
    return 0;
}
时间复杂度
$$\begin{aligned}
&\quad \sum_{l=1}^{n}\sum_{i=1}^{n-l+1}\sum_{r=i}^{i+l-1}1\\
&=\sum_{l=1}^{n}\sum_{i=1}^{n-l+1}l\\
&=\sum_{l=1}^{n}(n-l+1)l\\
&=\sum_{l=1}^{n}\big((n+1)l-l^2\big)\\
&=\frac{(n+1)(1+n)n}{2}-\frac{n(n+1)(2n+1)}{6}\\
&=\frac{n(n+1)(n+2)}{6}
\end{aligned}$$
所以是 $\Theta(n^3)$
空间复杂度 $O(n^2)$
优化
最优根节点满足单调性(Knuth, 1971):
$$r[i][j-1]\le r[i][j] \le r[i+1][j]$$
因此枚举 $r[i][j]$ 时只需在 $r[i][j-1]$ 到 $r[i+1][j]$ 之间查找。
证明:
再说(待补充)
#include <iostream>
#include <cstring>
#include <cfloat>
#include <vector>
using namespace std;
const int N = 1000;
// Table capacity. With n keys the tables are indexed from 0 to n + 1 in both
// dimensions, so they need (n + 2) x (n + 2) cells: n may be at most N - 2.
double e[N][N]; // e[i][j]: minimal weighted path length of a BST over keys k_i..k_j
double w[N][N]; // w[i][j]: p[i] + ... + p[j] + q[i-1] + ... + q[j]
int r[N][N];    // r[i][j]: index of the root chosen for the optimal tree over k_i..k_j
/**
 * @description: Optimal binary search tree (OBST), dynamic programming with
 *               the root search restricted to r[i][j-1] .. r[i+1][j].
 *               Fills the global tables e, w and r for every key range [i, j]
 *               with 1 <= i <= j <= n, plus the empty ranges e[i][i-1] and
 *               w[i][i-1].
 *               The restricted search relies on the root monotonicity
 *               r[i][j-1] <= r[i][j] <= r[i+1][j]; NOTE(review): this is
 *               assumed to hold for the supplied weights (it is the classical
 *               result for non-negative p and q) and is not verified here.
 *               If n is outside [0, N - 2] the tables cannot hold the result;
 *               the function then returns without touching them.
 * @param {double} *p weights of successful searches; p[1..n] are read
 * @param {double} *q weights of unsuccessful searches; q[0..n] are read
 * @param {int} n number of keys
 */
void build_obst(double *p, double *q, int n)
{
    // The loops below touch e[n + 1][n]; refuse sizes the tables cannot hold
    // instead of writing past the end of the global arrays.
    if (n < 0 || n > N - 2)
    {
        return;
    }

    // Empty ranges: the tree consists of the single dummy node d_{i-1}.
    for (int i = 1; i <= n + 1; ++i)
    {
        e[i][i - 1] = q[i - 1];
        w[i][i - 1] = q[i - 1];
    }

    // Process ranges by increasing length so that r[i][j-1] and r[i+1][j]
    // (both one key shorter) are known before r[i][j] is searched.
    for (int len = 1; len <= n; ++len)
    {
        for (int i = 1; i + len - 1 <= n; ++i)
        {
            const int j = i + len - 1;
            w[i][j] = w[i][j - 1] + p[j] + q[j];

            if (i == j)
            {
                // A single key is necessarily the root; there are no
                // shorter-range bounds to consult.
                e[i][j] = e[i][j - 1] + e[j + 1][j] + w[i][j];
                r[i][j] = i;
                continue;
            }

            const int first_root = r[i][j - 1];
            const int last_root = r[i + 1][j];
            double best = DBL_MAX;
            for (int root = first_root; root <= last_root; ++root)
            {
                const double cost = e[i][root - 1] + e[root + 1][j];
                // Strict comparison keeps the smallest root index on ties.
                if (cost < best)
                {
                    best = cost;
                    r[i][j] = root;
                }
            }
            // Hanging both subtrees under the root pushes every node one level
            // deeper, which adds the total weight of the range once.
            e[i][j] = best + w[i][j];
        }
    }
}
/**
 * Reads n, then p[1..n], then q[0..n] from standard input, runs build_obst
 * and prints the e, w and r tables (rows and columns 0..n+1), separated by
 * blank lines. Exits with status 1 when the input is malformed or n does not
 * fit the fixed-size tables.
 */
int main()
{
    /**
    Sample input:
    7
    0.04 0.06 0.08 0.02 0.10 0.12 0.14
    0.06 0.06 0.06 0.06 0.05 0.05 0.05 0.05
    */
    double p[N];
    double q[N];
    int n;
    // p, q and the global tables are indexed up to n + 1, so n must not exceed
    // N - 2; an unchecked value would overrun the arrays.
    if (scanf("%d", &n) != 1 || n < 0 || n > N - 2)
    {
        fprintf(stderr, "invalid n: expected an integer in [0, %d]\n", N - 2);
        return 1;
    }
    for (int i = 1; i <= n; ++i)
    {
        if (scanf("%lf", &p[i]) != 1)
        {
            fprintf(stderr, "failed to read p[%d]\n", i);
            return 1;
        }
    }
    for (int i = 0; i <= n; ++i)
    {
        if (scanf("%lf", &q[i]) != 1)
        {
            fprintf(stderr, "failed to read q[%d]\n", i);
            return 1;
        }
    }
    build_obst(p, q, n);
    // Table e
    for (int i = 0; i <= n + 1; ++i)
    {
        for (int j = 0; j <= n + 1; ++j)
        {
            printf("%lf ", e[i][j]);
        }
        printf("\n");
    }
    printf("\n");
    // Table w
    for (int i = 0; i <= n + 1; ++i)
    {
        for (int j = 0; j <= n + 1; ++j)
        {
            printf("%lf ", w[i][j]);
        }
        printf("\n");
    }
    printf("\n");
    // Table r
    for (int i = 0; i <= n + 1; ++i)
    {
        for (int j = 0; j <= n + 1; ++j)
        {
            printf("%d ", r[i][j]);
        }
        printf("\n");
    }
    return 0;
}
时间复杂度(其中 $j=i+l-1$)
$$\begin{aligned}
T&=\sum_{l=2}^{n}\sum_{i=1}^{n-l+1}\sum_{k=r[i][j-1]}^{r[i+1][j]}1+\sum_{l=1}^{1}\sum_{i=1}^{n-l+1}1\\
&=n+\sum_{l=2}^{n}\sum_{i=1}^{n-l+1}\big(r[i+1][j]-r[i][j-1]+1\big)\\
&=n+\sum_{l=2}^{n}(n-l+1)+\sum_{l=2}^{n}\sum_{i=1}^{n-l+1}\big(r[i+1][j]-r[i][j-1]\big)\\
&=n+(n+1)(n-1)-\frac{(n-1)(2+n)}{2}+\sum_{l=2}^{n}\sum_{i=1}^{n-l+1}\big(r[i+1][j]-r[i][j-1]\big)\\
&=\frac{n(n+1)}{2}+\sum_{l=2}^{n}\sum_{i=1}^{n-l+1}\big(r[i+1][j]-r[i][j-1]\big)
\end{aligned}$$
考察 $\sum_{i=1}^{n-l+1}\big(r[i+1][j]-r[i][j-1]\big)$ 的每一项(固定 $l$,$j=i+l-1$):
$$\begin{aligned}
&r[2][l]-r[1][l-1]+\\
&r[3][l+1]-r[2][l]+\\
&r[4][l+2]-r[3][l+1]+\\
&\quad\vdots\\
&r[n-l+2][n]-r[n-l+1][n-1]
\end{aligned}$$
相邻两行的项两两相消,所以
$$\sum_{i=1}^{n-l+1}\big(r[i+1][j]-r[i][j-1]\big)=r[n-l+2][n]-r[1][l-1]$$
所以
$$T=\frac{n(n+1)}{2}+\sum_{l=2}^{n}\big(r[n-l+2][n]-r[1][l-1]\big)$$
显然(因为 $r[n-l+2][n]\le n$ 且 $r[1][l-1]\ge 1$)
$$T\le \frac{n(n+1)}{2}+\sum_{l=2}^{n}(n-1)=\frac{3n^2-3n+2}{2}$$
根据
$$r[i][j-1]\le r[i][j] \le r[i+1][j]$$
有
$$r[i+1][j]-r[i][j-1] \ge 0$$
所以
$$T=\frac{n(n+1)}{2}+\sum_{l=2}^{n}\sum_{i=1}^{n-l+1}\big(r[i+1][j]-r[i][j-1]\big)\ge \frac{n(n+1)}{2}+0$$
所以时间复杂度 $\Theta(n^2)$
空间复杂度 $O(n^2)$