前言
在快速排序中,一个比较核心的操作是partition
,就是选中一个元素作为枢轴pivot
,然后执行partition
操作将数组元素分为三个部分<= pivot
、=pivot
和>=pivot
,partition
返回pivot
的最终下标,然后再递归pivot
左右两部分数组。
问题:传统的partition
每次只能确定一个元素的位置。这样存在的问题是,如果有很多重复的元素,显然是做了很多重复的比较工作。
解决:荷兰国旗问题就是用于解决这个问题的,它的partition
操作将数组分成三个部分<pivot
、=pivot
和>pivot
。每次partition
返回=pivot
部分的边界,然后递归<pivot
和>pivot
部分即可。这样一次就可以确定多个元素的位置。
1、首先描述一下问题
荷兰国旗是由红白蓝三种颜色拼接组成:
现在有若干个红、白、蓝三种颜色的球随机排列成一条直线。现在我们的任务是把这些球按照红、白、蓝排序。
2、具体操作
对应到我们的数组partition
就是将红白蓝类比为< pivot
、=pivot
和> pivot
。
假设给定一个数组,arr = [3,3,3,2,6,8,3,9,2,4,3]
初始状态下,定义变量L
表示< pivot
区域的最右位置,变量R
表示> pivot
区域的第一个位置,cur
遍历数组,直到cur
遇到R
。(如下图)
选择数组的最后一个元素作为枢轴pivot
,因此最后一个位置可以当它不存在,因为初始时R = arr.length - 1
。
cur
遍历数组,遇到的元素分为三种情况
arr[cur] == pivot
,此时只需要跳过,L
和R
都不需要更新。arr[cur] < pivot
,此时情况如下图
此时需要将L
的边界向右扩充1
,具体操作是将L + 1
位置的元素和cur
位置的元素交换,然后L++
,cur++
arr[cur] > pivot
,此时情况如下图
此时需要将R
的边界向左扩充1
,具体操作是将R - 1
位置的元素和cur
位置的元素交换,然后R--
,cur
不变(因为交换后cur
位置的元素原本是cur
之后的元素,是没有遍历过的)
- 遍历结束后,如下图
cur == R
,遍历结束,处理一下枢轴pivot
,很简单就是将R
位置的元素和pivot
位置的元素交换。至此,可以知道= pivot
区域的左右边界,然后只需要递归< pivot
和> pivot
部分即可完成快速排序。
代码:
package class_01;
public class NetherlandsFlag {
public static int[] partition(int[] arr, int l, int r, int p) {
int less = l - 1;
int more = r + 1;
while (l < more) {
if (arr[l] < p) {
swap(arr, ++less, l++);
} else if (arr[l] > p) {
swap(arr, --more, l);
} else {
l++;
}
}
return new int[] { less + 1, more - 1 };
}
// for test
public static void swap(int[] arr, int i, int j) {
int tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
}
// for test
public static int[] generateArray() {
int[] arr = new int[10];
for (int i = 0; i < arr.length; i++) {
arr[i] = (int) (Math.random() * 3);
}
return arr;
}
// for test
public static void printArray(int[] arr) {
if (arr == null) {
return;
}
for (int i = 0; i < arr.length; i++) {
System.out.print(arr[i] + " ");
}
System.out.println();
}
public static void main(String[] args) {
int[] test = generateArray();
printArray(test);
int[] res = partition(test, 0, test.length - 1, 1);
printArray(test);
System.out.println(res[0]);
System.out.println(res[1]);
}
}