文章目录
一、二分搜索的简介
二分搜索用于在一个有序数组中查找一个符合某些条件的值,如果找得到,则返回该值的下标或迭代器。
二分搜索的核心是不断根据区间中点元素(下文称“中位数”)与特定值比较的结果调整搜索区间,每次调整后的区间长度不超过调整前区间长度的一半。
下文以升序数组为例进行阐述,降序数组同理。
二、二分搜索的基本问题
在升序数组 A 中,查找一个满足特定条件的数组成员 x,若查找成功则返回 x 的下标或迭代器,查找失败则返回 -1 或迭代器末端。一般有 4 种特定条件,描述如下。
- 数组 A 中的每个数只出现一次,x 等于特定数字 target。
- 数组 A 中的数可能重复出现,x 是第一个等于特定数字 target 的数。
- 数组 A 中的数可能重复出现,x 是第一个大于等于特定数字 target 的数。
- 数组 A 中的数可能重复出现,x 是第一个大于特定数字 target 的数。
第 2、3、4 三个问题有两个关键点。一是当中位数等于特定数字 target 时区间要如何调整:是舍弃左半部分(向右收缩),还是舍弃右半部分(向左收缩)。二是在搜索结束时,需要再将最终位置 left_index 处的元素与 target 比较一次,并确认 left_index 没有越出原始搜索区间,以确定最终返回值。在下文的具体程序演示中会以注释标记这两个关键点。
三、二分搜索的程序模板
(一) 迭代版
/*
 * Iterative binary-search skeleton over the half-open range
 * [left_index, right_index) of an ascending array.
 * This is a template, not a finished function: both placeholder comments
 * below must be replaced with real code for the concrete problem. As written,
 * the else branch changes neither bound (so the loop would never terminate
 * once array[mid_index] == target) and control can fall off the end of this
 * non-void function without returning a value.
 */
int
BinarySearchIteratively(const vector<int>& array, int left_index, int right_index,
int target) {
while (left_index < right_index) {
// Lower midpoint, computed without forming left_index + right_index.
int mid_index = left_index + (right_index - left_index) / 2;
if (array[mid_index] < target) {
// array[mid_index] is too small: continue in [mid_index + 1, right_index).
left_index = mid_index + 1;
} else if (target < array[mid_index]) {
// array[mid_index] is too large: continue in [left_index, mid_index).
right_index = mid_index;
} else {
/* Placeholder (array[mid_index] == target): return mid_index, or keep narrowing the range, depending on the problem */
}
}
/* Placeholder (range is now empty, left_index == right_index): return -1 or left_index, depending on the problem */
}
(二) 递归版
/*
 * Recursive binary-search skeleton over the half-open range
 * [left_index, right_index) of an ascending array.
 * This is a template, not a finished function: both placeholder comments
 * below must be replaced with real code (including a return statement) for
 * the concrete problem; as written, control can fall off the end of this
 * non-void function without returning a value.
 */
int
BinarySearchRecursively(const vector<int>& array, int left_index, int right_index,
int target) {
if (left_index == right_index) {
/* Placeholder (empty range, search finished): return -1 or left_index, depending on the problem */
}
// Lower midpoint, computed without forming left_index + right_index.
int mid_index = left_index + (right_index - left_index) / 2;
if (array[mid_index] < target) {
// array[mid_index] is too small: recurse into [mid_index + 1, right_index).
return BinarySearchRecursively(array, mid_index + 1, right_index, target);
} else if (target < array[mid_index]) {
// array[mid_index] is too large: recurse into [left_index, mid_index).
return BinarySearchRecursively(array, left_index, mid_index, target);
} else {
/* Placeholder (array[mid_index] == target): return mid_index, or keep narrowing the range, depending on the problem */
}
}
四、二分搜索的细节
(一) 求两整数平均值
二分搜索的每一轮比较之前总需要求出当前区间的中位数,中位数的下标就是区间左右端点下标的平均数向下取整。直接用 (left + right) / 2 计算时,left + right 可能超出 int 的表示范围而溢出。以下四种写法都避免了直接计算 left + right,从而降低了溢出的风险。
// Four ways to compute floor((left + right) / 2) without forming left + right.
// Assumes 0 <= left <= right, as holds for array indices.
// 1) Add half of the distance to the left end.
int mid_1 = left + (right - left) / 2;
// 2) Halve each operand separately, then add 1 back if both dropped a low bit of 1.
int mid_2 = (left >> 1) + (right >> 1) + (((left & 1) + (right & 1)) >> 1);
// 3) Bits common to both, plus half of the bits where they differ
//    (uses left + right == 2 * (left & right) + (left ^ right)).
int mid_3 = (left & right) + ((left ^ right) >> 1);
// 4) Same as 1) with the halving written as a shift; the inner parentheses are
//    required because + binds tighter than >>.
int mid_4 = left + ((right - left) >> 1);
`>> 1` 对非负整数等同于 `/ 2`(对负数两者的取整方向不同),但要注意加号的运算优先级高于移位运算符,需要加上括号保证运算顺序合理。
(二) “左闭右开”的区间表示
笔者在程序模板中采用的区间表示法就是“左闭右开”,即 [left_index, right_index)。在 C++ 的标准库中,用一对迭代器(或指针)表示区间的参数都遵循“左闭右开”原则,示例如下。
// All ten slots are initialized explicitly. An omitted initializer would be
// zero-filled and would silently become the smallest element after sorting.
int array[10] = {10, 1, 23, 14, 35, 26, 47, 18, 62, 79};
// Half-open range [array, array + 10): array + 10 is the one-past-the-end address.
sort(array, array + 10);
在 `sort` 函数的实参中,`array` 指向首元素 10,`array + 10` 指向数组末元素之后的位置(即尾后位置),该位置本身不属于排序区间。因此,为了和 C++ 标准库的区间表示法统一,建议读者也采用“左闭右开”原则。
若迭代或递归的原始区间为 [a, b),则在最后一轮循环或最后一层递归时,搜索区间会收缩为某个空区间 [k, k)(a ≤ k ≤ b)。因此,二分搜索进行到最后会有 `left_index` 等于 `right_index`。
(三) 最后一步防止越界访问
因为二分搜索进行到最后会有 `left_index` 等于 `right_index`,而 `left_index` 可能等于原始搜索区间的右端点(对整个数组搜索时即数组长度 `array.size()`),所以如果此时需要访问 `array[left_index]`,必须先判断 `left_index` 是否仍在原始搜索区间之内。
五、二分搜索的详尽程序
(一) 升序无重复数,查找等于 target 的 A[i]
以下程序已通过 LeetCode 704. Binary Search 所有样例测试。
/**
 * LeetCode 704 (Binary Search): the array is ascending with distinct values;
 * return the index of target, or -1 if it is not present.
 */
class Solution {
private:
    // Iterative version over the half-open range [left_index, right_index).
    // nums is taken by const reference so the vector is not copied on every call.
    static int SearchIteratively(const vector<int>& nums, int left_index, int right_index,
                                 int target) {
        while (left_index < right_index) {
            int mid_index = left_index + (right_index - left_index) / 2;
            if (target < nums[mid_index]) {
                right_index = mid_index;        // a match, if any, lies left of mid
            } else if (nums[mid_index] < target) {
                left_index = mid_index + 1;     // a match, if any, lies right of mid
            } else {
                return mid_index;               // values are distinct: this is the one
            }
        }
        return -1;  // range became empty without a match
    }

    // Recursive version over the half-open range [left_index, right_index).
    static int SearchRecursively(const vector<int>& nums, int left_index, int right_index,
                                 int target) {
        if (left_index == right_index) {
            return -1;  // empty range: target is absent
        }
        int mid_index = left_index + (right_index - left_index) / 2;
        if (nums[mid_index] < target) {
            return SearchRecursively(nums, mid_index + 1, right_index, target);
        } else if (target < nums[mid_index]) {
            return SearchRecursively(nums, left_index, mid_index, target);
        } else {
            return mid_index;
        }
    }

public:
    int search(vector<int>& nums, int target) {
        // Convert the size explicitly instead of relying on an implicit size_t -> int narrowing.
        const int size = static_cast<int>(nums.size());
        // return SearchIteratively(nums, 0, size, target);
        return SearchRecursively(nums, 0, size, target);
    }
};
(二) 升序有重复数,查找第一个等于 target 的 A[i]
以下程序可直接复制粘贴后编译运行。
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
/**
 * Ascending array that may contain duplicates: find the FIRST element equal
 * to target (iterative version).
 *
 * @param array        ascending-sorted array
 * @param left_index   inclusive start of the search range
 * @param right_index  exclusive end of the search range [left_index, right_index)
 * @param target       value to look for
 * @return index of the first element equal to target inside the range, or -1.
 */
int
SearchFirstEqualIteratively(const vector<int>& array, int left_index, int right_index,
                            int target) {
    // Remember the caller's exclusive end. After the loop left_index may equal it,
    // and that slot lies outside the searched range (for a whole-array search it
    // is array.size()), so it must never be reported.
    const int end_index = right_index;
    while (left_index < right_index) {
        int mid_index = left_index + (right_index - left_index) / 2;
        if (target < array[mid_index]) {
            right_index = mid_index;
        } else if (array[mid_index] < target) {
            left_index = mid_index + 1;
        } else {  // Key point 1: on a match, keep looking left for an earlier one.
            right_index = mid_index;
        }
    }
    // Key point 2: left_index is the first position in the range whose element is
    // not less than target; it is the answer only if it is inside the range and equal.
    if (left_index < end_index && array[left_index] == target) {
        return left_index;
    }
    return -1;
}
/**
 * Recursive worker for SearchFirstEqualRecursively. end_index is the caller's
 * original exclusive bound, threaded through so the final check never reports
 * a position outside the range the caller asked about.
 */
static int
SearchFirstEqualRecursivelyImpl(const vector<int>& array, int left_index, int right_index,
                                int target, int end_index) {
    if (left_index == right_index) {
        // Key point 2: left_index is the first position whose element is not less
        // than target; accept it only if it is inside the range and equal.
        if (left_index < end_index && array[left_index] == target) {
            return left_index;
        }
        return -1;
    }
    int mid_index = left_index + (right_index - left_index) / 2;
    if (array[mid_index] < target) {
        return SearchFirstEqualRecursivelyImpl(array, mid_index + 1, right_index, target, end_index);
    } else if (target < array[mid_index]) {
        return SearchFirstEqualRecursivelyImpl(array, left_index, mid_index, target, end_index);
    } else {  // Key point 1: on a match, keep looking left for an earlier one.
        return SearchFirstEqualRecursivelyImpl(array, left_index, mid_index, target, end_index);
    }
}

/**
 * Ascending array that may contain duplicates: find the FIRST element equal
 * to target (recursive version).
 *
 * @param array        ascending-sorted array
 * @param left_index   inclusive start of the search range
 * @param right_index  exclusive end of the search range [left_index, right_index)
 * @param target       value to look for
 * @return index of the first element equal to target inside the range, or -1.
 */
int
SearchFirstEqualRecursively(const vector<int>& array, int left_index, int right_index,
                            int target) {
    return SearchFirstEqualRecursivelyImpl(array, left_index, right_index, target, right_index);
}
/**
 * Demo driver: prints the sample array, then runs the same queries through the
 * iterative and the recursive "first equal" searches and prints each result.
 */
int
main(void) {
    // Sorted sample input; 2 and 6 are deliberately duplicated.
    vector<int> arr = { 1, 2, 2, 3, 6, 6, 6, 8, 12, 14 };
    // Queries: an absent middle value, the first/last elements, duplicated
    // values, and a value beyond the largest element.
    const int queries[] = { 10, 1, 2, 6, 14, 100 };
    const int size = static_cast<int>(arr.size());

    cout << "array:";
    for (const int value : arr) {
        cout << ' ' << value;
    }
    cout << endl;

    cout << "---------------\nTest Iteration:" << endl;
    for (const int target : queries) {
        const int index = SearchFirstEqualIteratively(arr, 0, size, target);
        cout << "target = " << target << ": index = " << index << endl;
    }

    cout << "---------------\nTest Recursion:" << endl;
    for (const int target : queries) {
        const int index = SearchFirstEqualRecursively(arr, 0, size, target);
        cout << "target = " << target << ": index = " << index << endl;
    }
    return 0;
}
(三) 升序有重复数,查找第一个大于等于 target 的 A[i]
以下程序可直接复制粘贴后编译运行。
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
/**
 * Ascending array that may contain duplicates: find the FIRST element greater
 * than or equal to target (iterative version; same result as std::lower_bound).
 *
 * @param array        ascending-sorted array
 * @param left_index   inclusive start of the search range
 * @param right_index  exclusive end of the search range [left_index, right_index)
 * @param target       value to compare against
 * @return index of the first element >= target inside the range, or -1.
 */
int
SearchFirstGreaterOrEqualIteratively(const vector<int>& array, int left_index, int right_index,
                                     int target) {
    // Remember the caller's exclusive end. After the loop left_index may equal it,
    // and that slot lies outside the searched range, so it must never be reported.
    const int end_index = right_index;
    while (left_index < right_index) {
        int mid_index = left_index + (right_index - left_index) / 2;
        if (target < array[mid_index]) {
            right_index = mid_index;
        } else if (array[mid_index] < target) {
            left_index = mid_index + 1;
        } else {  // Key point 1: an equal element qualifies, but an earlier one may too.
            right_index = mid_index;
        }
    }
    // Key point 2: accept left_index only if it is still inside the caller's range.
    if (left_index < end_index && array[left_index] >= target) {
        return left_index;
    }
    return -1;
}
/**
 * Recursive worker for SearchFirstGreaterOrEqualRecursively. end_index is the
 * caller's original exclusive bound, threaded through so the final check never
 * reports a position outside the range the caller asked about.
 */
static int
SearchFirstGreaterOrEqualRecursivelyImpl(const vector<int>& array, int left_index,
                                         int right_index, int target, int end_index) {
    if (left_index == right_index) {
        // Key point 2: accept left_index only if it is still inside the caller's range.
        if (left_index < end_index && array[left_index] >= target) {
            return left_index;
        }
        return -1;
    }
    int mid_index = left_index + (right_index - left_index) / 2;
    if (array[mid_index] < target) {
        return SearchFirstGreaterOrEqualRecursivelyImpl(array, mid_index + 1, right_index,
                                                        target, end_index);
    } else if (target < array[mid_index]) {
        return SearchFirstGreaterOrEqualRecursivelyImpl(array, left_index, mid_index,
                                                        target, end_index);
    } else {  // Key point 1: an equal element qualifies, but an earlier one may too.
        return SearchFirstGreaterOrEqualRecursivelyImpl(array, left_index, mid_index,
                                                        target, end_index);
    }
}

/**
 * Ascending array that may contain duplicates: find the FIRST element greater
 * than or equal to target (recursive version).
 *
 * @param array        ascending-sorted array
 * @param left_index   inclusive start of the search range
 * @param right_index  exclusive end of the search range [left_index, right_index)
 * @param target       value to compare against
 * @return index of the first element >= target inside the range, or -1.
 */
int
SearchFirstGreaterOrEqualRecursively(const vector<int>& array, int left_index, int right_index,
                                     int target) {
    return SearchFirstGreaterOrEqualRecursivelyImpl(array, left_index, right_index, target,
                                                    right_index);
}
/**
 * Demo driver: prints the sample array, then runs the same queries through the
 * iterative and the recursive "first greater-or-equal" searches.
 */
int
main(void) {
    // Sorted sample input; 2 and 6 are deliberately duplicated.
    vector<int> arr = { 1, 2, 2, 3, 6, 6, 6, 8, 12, 14 };
    // Queries: an absent middle value, the first/last elements, duplicated
    // values, and a value beyond the largest element.
    const int queries[] = { 10, 1, 2, 6, 14, 100 };
    const int size = static_cast<int>(arr.size());

    cout << "array:";
    for (const int value : arr) {
        cout << ' ' << value;
    }
    cout << endl;

    cout << "---------------\nTest Iteration:" << endl;
    for (const int target : queries) {
        const int index = SearchFirstGreaterOrEqualIteratively(arr, 0, size, target);
        cout << "target = " << target << ": index = " << index << endl;
    }

    cout << "---------------\nTest Recursion:" << endl;
    for (const int target : queries) {
        const int index = SearchFirstGreaterOrEqualRecursively(arr, 0, size, target);
        cout << "target = " << target << ": index = " << index << endl;
    }
    return 0;
}
(四) 升序有重复数,查找第一个大于 target 的 A[i]
以下程序可直接复制粘贴后编译运行。
#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;
/**
 * Ascending array that may contain duplicates: find the FIRST element strictly
 * greater than target (iterative version; same result as std::upper_bound).
 *
 * @param array        ascending-sorted array
 * @param left_index   inclusive start of the search range
 * @param right_index  exclusive end of the search range [left_index, right_index)
 * @param target       value to compare against
 * @return index of the first element > target inside the range, or -1.
 */
int
SearchFirstGreaterIteratively(const vector<int>& array, int left_index, int right_index,
                              int target) {
    // Remember the caller's exclusive end. After the loop left_index may equal it,
    // and that slot lies outside the searched range, so it must never be reported.
    const int end_index = right_index;
    while (left_index < right_index) {
        int mid_index = left_index + (right_index - left_index) / 2;
        if (target < array[mid_index]) {
            right_index = mid_index;
        } else if (array[mid_index] < target) {
            left_index = mid_index + 1;
        } else {  // Key point 1: an equal element does not qualify; the answer is further right.
            left_index = mid_index + 1;
        }
    }
    // Key point 2: accept left_index only if it is still inside the caller's range.
    if (left_index < end_index && array[left_index] > target) {
        return left_index;
    }
    return -1;
}
/**
 * Recursive worker for SearchFirstGreaterRecursively. end_index is the caller's
 * original exclusive bound, threaded through so the final check never reports
 * a position outside the range the caller asked about.
 */
static int
SearchFirstGreaterRecursivelyImpl(const vector<int>& array, int left_index, int right_index,
                                  int target, int end_index) {
    if (left_index == right_index) {
        // Key point 2: accept left_index only if it is still inside the caller's range.
        if (left_index < end_index && array[left_index] > target) {
            return left_index;
        }
        return -1;
    }
    int mid_index = left_index + (right_index - left_index) / 2;
    if (array[mid_index] < target) {
        return SearchFirstGreaterRecursivelyImpl(array, mid_index + 1, right_index, target,
                                                 end_index);
    } else if (target < array[mid_index]) {
        return SearchFirstGreaterRecursivelyImpl(array, left_index, mid_index, target,
                                                 end_index);
    } else {  // Key point 1: an equal element does not qualify; the answer is further right.
        return SearchFirstGreaterRecursivelyImpl(array, mid_index + 1, right_index, target,
                                                 end_index);
    }
}

/**
 * Ascending array that may contain duplicates: find the FIRST element strictly
 * greater than target (recursive version).
 *
 * @param array        ascending-sorted array
 * @param left_index   inclusive start of the search range
 * @param right_index  exclusive end of the search range [left_index, right_index)
 * @param target       value to compare against
 * @return index of the first element > target inside the range, or -1.
 */
int
SearchFirstGreaterRecursively(const vector<int>& array, int left_index, int right_index,
                              int target) {
    return SearchFirstGreaterRecursivelyImpl(array, left_index, right_index, target, right_index);
}
/**
 * Demo driver: prints the sample array, then runs the same queries through the
 * iterative and the recursive "first strictly greater" searches.
 */
int
main(void) {
    // Sorted sample input; 2 and 6 are deliberately duplicated.
    vector<int> arr = { 1, 2, 2, 3, 6, 6, 6, 8, 12, 14 };
    // Queries: an absent middle value, the first/last elements, duplicated
    // values, and a value beyond the largest element.
    const int queries[] = { 10, 1, 2, 6, 14, 100 };
    const int size = static_cast<int>(arr.size());

    cout << "array:";
    for (const int value : arr) {
        cout << ' ' << value;
    }
    cout << endl;

    cout << "---------------\nTest Iteration:" << endl;
    for (const int target : queries) {
        const int index = SearchFirstGreaterIteratively(arr, 0, size, target);
        cout << "target = " << target << ": index = " << index << endl;
    }

    cout << "---------------\nTest Recursion:" << endl;
    for (const int target : queries) {
        const int index = SearchFirstGreaterRecursively(arr, 0, size, target);
        cout << "target = " << target << ": index = " << index << endl;
    }
    return 0;
}