Write an efficient algorithm that searches for a value in an m x n matrix. This matrix has the following properties:
- Integers in each row are sorted from left to right.
- The first integer of each row is greater than the last integer of the previous row.
For example, consider the following matrix:
[
  [1,   3,  5,  7],
  [10, 11, 16, 20],
  [23, 30, 34, 50]
]
Given target = 3, return true.
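The idea is to first run a binary search on the first column to find the only row that could contain the target (the last row whose first element is <= target), and then run a second binary search within that row to find the element. If it is not found, return false.
The two searches take O(log m + log n) time. Note that the implementation below first copies the first column into a separate array, which adds O(m) time and space; binary searching on matrix[mid][0] directly would avoid that copy.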
public class Solution {
    public boolean searchMatrix(int[][] matrix, int target) {
        // An empty matrix (or one with empty rows) cannot contain the target.
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0)
            return false;
        int m = matrix.length;
        int n = matrix[0].length;
        // Copy the first column so it can be binary searched like a 1-D array.
        int[] firstColumn = new int[m];
        for (int i = 0; i < m; i++) {
            firstColumn[i] = matrix[i][0];
        }
        // Index of the row that might contain the target, or -1.
        int row = columnSearch(firstColumn, 0, m - 1, target);
        if (row == -1)
            return false;
        // Index of the target within that row, or -1.
        int col = rowSearch(matrix[row], 0, n - 1, target);
        return col != -1;
    }

    // Binary search on the first column for the row that might contain the target,
    // i.e. the last index i with array[i] <= target.
    // Returns -1 if the target is smaller than every element.
    int columnSearch(int[] array, int low, int high, int target) {
        if (low > high)
            return -1;
        int mid = low + (high - low) / 2;
        // Target is larger than every first element in range: only the last row can hold it.
        if (array[high] < target)
            return high;
        if (low == high) {
            if (array[mid] <= target)
                return mid;
            else
                return -1;
        } else {
            // If the range has more than one element and array[mid] <= target < array[mid + 1],
            // then the target can only be in row matrix[mid].
            if (array[mid] <= target && array[mid + 1] > target)
                return mid;
        }
        if (array[mid] < target)
            return columnSearch(array, mid + 1, high, target);
        else
            return columnSearch(array, low, mid - 1, target);
    }

    // Standard binary search on a sorted array.
    // Returns the index of the target, or -1 if it is not found.
    int rowSearch(int[] array, int low, int high, int target) {
        if (low > high)
            return -1;
        int mid = low + (high - low) / 2;
        if (array[mid] == target)
            return mid;
        if (array[mid] < target)
            return rowSearch(array, mid + 1, high, target);
        else
            return rowSearch(array, low, mid - 1, target);
    }
}
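The O(m) copy of the first column can be avoided by binary searching on matrix[mid][0] directly. Below is a minimal iterative sketch of the same two-step idea with that change; the class name RowThenColumnSearch is hypothetical, and the sketch assumes every row has the same non-zero length. Step 1 finds the last row whose first element is <= target; step 2 is an ordinary binary search inside that row.

```java
// Hypothetical sketch: two binary searches without copying the first column.
class RowThenColumnSearch {
    // Returns true if target occurs in matrix. Assumes the two properties from the
    // problem statement hold and that all rows have the same non-zero length.
    static boolean searchMatrix(int[][] matrix, int target) {
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0) {
            return false;
        }
        // Step 1: last row whose first element is <= target (reads matrix[mid][0] in place).
        int lo = 0, hi = matrix.length - 1, row = -1;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            if (matrix[mid][0] <= target) {
                row = mid;      // candidate row; keep looking for a later one
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        if (row == -1) {
            return false;       // target is smaller than every element
        }
        // Step 2: ordinary binary search inside the candidate row.
        int[] r = matrix[row];
        lo = 0;
        hi = r.length - 1;
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;
            if (r[mid] == target) {
                return true;
            } else if (r[mid] < target) {
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        int[][] matrix = { {1, 3, 5, 7}, {10, 11, 16, 20}, {23, 30, 34, 50} };
        System.out.println(searchMatrix(matrix, 3));   // true
        System.out.println(searchMatrix(matrix, 13));  // false
    }
}
```

This keeps the O(log m + log n) bound with O(1) extra space, since nothing is copied.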
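Alternatively, because of the two properties, reading the matrix row by row gives one large sorted array of length m * n, so a single binary search is enough: flat index k corresponds to matrix[k / n][k % n].
The complexity is O(log(mn)) = O(log m + log n), the same bound as the two binary searches above.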
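To make the row-major view concrete, the hypothetical helper below (class FlattenedIndexDemo, for illustration only) lists the example matrix in flat-index order using the same mapping k -> (k / n, k % n) that the search relies on. The output is sorted, which is what makes a single binary search valid; the search itself never builds this array.

```java
import java.util.Arrays;

// Hypothetical illustration: enumerate a row-major matrix as one flat array.
class FlattenedIndexDemo {
    // Returns the elements of matrix in flat-index order 0 .. m*n-1.
    static int[] flatten(int[][] matrix) {
        int m = matrix.length;
        int n = matrix[0].length;
        int[] flat = new int[m * n];
        for (int k = 0; k < m * n; k++) {
            // Flat index k lives at row k / n, column k % n.
            flat[k] = matrix[k / n][k % n];
        }
        return flat;
    }

    public static void main(String[] args) {
        int[][] matrix = { {1, 3, 5, 7}, {10, 11, 16, 20}, {23, 30, 34, 50} };
        System.out.println(Arrays.toString(flatten(matrix)));
        // [1, 3, 5, 7, 10, 11, 16, 20, 23, 30, 34, 50]
    }
}
```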
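public boolean searchMatrix(int[][] matrix, int target) {
    // An empty matrix (or one with empty rows) cannot contain the target.
    if (matrix == null || matrix.length == 0 || matrix[0].length == 0)
        return false;
    int m = matrix.length;
    int n = matrix[0].length;
    // Binary search over the flattened index range [0, m * n - 1].
    int start = 0;
    int end = m * n - 1;
    while (start <= end) {
        int mid = start + (end - start) / 2;
        // Map the flat index back to (row, column).
        int midX = mid / n;
        int midY = mid % n;
        if (matrix[midX][midY] == target)
            return true;
        else if (matrix[midX][midY] > target)
            end = mid - 1;
        else
            start = mid + 1;
    }
    return false;
}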
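One way to gain confidence in either version is to compare it against a plain linear scan on many small matrices that satisfy the two properties. The sketch below does this for the single-binary-search version; the class SearchMatrixCheck, the random step sizes, the matrix sizes, and the fixed seed are illustrative assumptions, not part of the original solution.

```java
import java.util.Random;

// Hypothetical harness: cross-check the single binary search against a linear scan.
class SearchMatrixCheck {
    // Single binary search over the flattened index range [0, m*n - 1].
    static boolean searchMatrix(int[][] matrix, int target) {
        if (matrix.length == 0 || matrix[0].length == 0) return false;
        int n = matrix[0].length;
        int start = 0, end = matrix.length * n - 1;
        while (start <= end) {
            int mid = start + (end - start) / 2;
            int value = matrix[mid / n][mid % n];
            if (value == target) return true;
            if (value > target) end = mid - 1;
            else start = mid + 1;
        }
        return false;
    }

    // Reference answer: O(mn) scan of every element.
    static boolean linearSearch(int[][] matrix, int target) {
        for (int[] row : matrix)
            for (int v : row)
                if (v == target) return true;
        return false;
    }

    // Builds a random m x n matrix whose row-major order is strictly increasing,
    // so both properties from the problem statement hold.
    static int[][] randomMatrix(Random rnd, int m, int n) {
        int[][] matrix = new int[m][n];
        int value = rnd.nextInt(5);
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                matrix[i][j] = value;
                value += 1 + rnd.nextInt(3);   // step of 1..3 keeps values strictly increasing
            }
        }
        return matrix;
    }

    public static void main(String[] args) {
        Random rnd = new Random(42);
        for (int trial = 0; trial < 1000; trial++) {
            int[][] matrix = randomMatrix(rnd, 1 + rnd.nextInt(5), 1 + rnd.nextInt(5));
            int target = rnd.nextInt(80) - 5;   // covers values below, inside, and above the matrix
            if (searchMatrix(matrix, target) != linearSearch(matrix, target)) {
                throw new AssertionError("mismatch at trial " + trial);
            }
        }
        System.out.println("all trials agree");
    }
}
```

Running main should print "all trials agree"; any disagreement throws an AssertionError naming the trial.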