题目地址:
https://leetcode.com/problems/sort-an-array/
对一个数组进行排序。
法1:快速排序。
import java.util.ArrayList;
import java.util.List;
public class Solution {
public List<Integer> sortArray(int[] nums) {
List<Integer> res = new ArrayList<>();
quickSort(nums, 0, nums.length - 1);
for (int num : nums) {
res.add(num);
}
return res;
}
private void quickSort(int[] nums, int l, int r) {
if (l >= r) {
return;
}
int mid = partition(nums, l, r);
quickSort(nums, l, mid - 1);
quickSort(nums, mid + 1, r);
}
private int partition(int[] nums, int l, int r) {
swap(nums, l, l + ((r - l) >> 1));
int pivot = nums[l];
while (l < r) {
while (l < r && pivot <= nums[r]) {
r--;
}
nums[l] = nums[r];
while (l < r && nums[l] <= pivot) {
l++;
}
nums[r] = nums[l];
}
nums[l] = pivot;
return l;
}
private void swap(int[] nums, int i, int j) {
int tmp = nums[i];
nums[i] = nums[j];
nums[j] = tmp;
}
}
平均时间复杂度 O ( n log n ) O(n\log n) O(nlogn),平均空间复杂度 O ( log n ) O(\log n) O(logn)。快速排序有很多种不同的写法。这里的代码是比较经典的那种快速排序,设置pivot之后每次移动单边的指针,找到不满足条件的数后就覆盖掉另外一边,直到两根指针相遇。优点是可以返回出partition的位置,有助于解决第 k k k大元素这种问题。
算法正确性证明:
先证明partition方法的正确性。只需证明partition方法结束后,
1、while循环一定会结束;
2、数组的
[
l
,
r
]
[l,r]
[l,r]这个区间里,处于
l
l
l位置的那个
p
i
v
o
t
pivot
pivot,左边的数都小于等于它,
p
i
v
o
t
pivot
pivot右边的数都大于等于它;
3、这个区间里的数字并未发生改变(也就是原先是哪些数字,完了之后还是那些数字);
首先证明1成立。只需要证明 r − l r-l r−l在每次循环结束后一定会严格下降即可。如果执行循环第一句的时候 p i v o t ≤ n u m s [ r ] pivot\le nums[r] pivot≤nums[r],则 l − r l-r l−r立即增加了 1 1 1,此后 l − r l-r l−r最多不降,该次循环结束后, l − r l-r l−r至少会减少 1 1 1,结论成立;否则, p i v o t > n u m s [ r ] pivot> nums[r] pivot>nums[r],那么 n u m s [ r ] nums[r] nums[r]会把 n u m s [ l ] nums[l] nums[l]覆盖掉,之后一定有 l l l至少增加 1 1 1,所以仍然有 l − r l-r l−r至少会减少 1 1 1,结论也成立。所以循环一定会结束。
再证明2成立。反证法。如果循环结束后, p i v o t pivot pivot左边有个数 n 1 n_1 n1大于 p i v o t pivot pivot,那么由于循环一定会结束(这个才能保证左右指针会在pivot所在位置相遇),所以结束前左指针 l l l一定经过过该数,当 l l l停在该数的位置时,下面的一次循环里, r r r指针一定停在一个小于 p i v o t pivot pivot的数 n 2 n_2 n2上,并且 n 2 n_2 n2一定在 n 1 n_1 n1右边,此时 n 2 n_2 n2会把 n 1 n_1 n1覆盖掉,与原假设矛盾。
接下来证明3成立。先证明每次循环后,如果 l < r l<r l<r仍然成立,那么必然有这样的情形:数组中少了一个 p i v o t pivot pivot数(意思是如果原数组有 k k k个数等于 p i v o t pivot pivot,现在会少掉一个),并且多一个 n u m s [ l ] nums[l] nums[l],且此时有 n u m s [ l ] = n u m s [ r ] nums[l]=nums[r] nums[l]=nums[r]。用数学归纳法,这一点显然可以证明。接下来考虑两个指针相遇的情形。如果相遇之前,是 r r r指针主动撞上 l l l指针,那么最后一次循环两次赋值,实际上都是在自己给自己赋值。循环结束后,多出来的那个等于 n u m s [ l ] nums[l] nums[l]的值被 p i v o t pivot pivot覆盖,数组数字还原为当初的情形,结论成立。如果相遇之前,是 l l l指针主动撞上 r r r指针,那么最后一次循环两次赋值中的后一次赋值,实际上也是在自己给自己赋值。循环结束后,多出来的那个等于 n u m s [ l ] nums[l] nums[l]的值被 p i v o t pivot pivot覆盖,数组数字还原为当初的情形,结论也成立。
综上所述,partition方法正确,且返回了 p i v o t pivot pivot所在的坐标。接下来quickSort的正确性可以由数学归纳法轻松得到,这里就省略了。
算法复杂度证明:
时间复杂度:假设每次选择的
p
i
v
o
t
pivot
pivot最后移动到的位置满足均匀概率分布,则有递推方程:
T
(
n
)
=
n
+
1
n
[
(
T
(
0
)
+
T
(
n
−
1
)
)
+
(
T
(
1
)
+
T
(
n
−
2
)
)
+
.
.
.
+
(
T
(
n
−
1
)
+
T
(
0
)
)
]
T(n)=n+\frac{1}{n}[(T(0)+T(n-1))+(T(1)+T(n-2))+...+(T(n-1)+T(0))]
T(n)=n+n1[(T(0)+T(n−1))+(T(1)+T(n−2))+...+(T(n−1)+T(0))]整理得:
n
T
(
n
)
=
n
2
+
2
∑
i
=
0
n
−
1
T
(
i
)
nT(n)=n^2+2\sum_{i=0}^{n-1}T(i)
nT(n)=n2+2i=0∑n−1T(i)换变量得:
(
n
−
1
)
T
(
n
−
1
)
=
(
n
−
1
)
2
+
2
∑
i
=
0
n
−
2
T
(
i
)
(n-1)T(n-1)=(n-1)^2+2\sum_{i=0}^{n-2}T(i)
(n−1)T(n−1)=(n−1)2+2i=0∑n−2T(i)两式相减得:
n
T
(
n
)
−
(
n
−
1
)
T
(
n
−
1
)
=
2
n
−
1
+
2
T
(
n
−
1
)
n
T
(
n
)
−
(
n
+
1
)
T
(
n
−
1
)
=
2
n
−
1
T
(
n
)
n
+
1
−
T
(
n
−
1
)
n
=
3
n
+
1
−
1
n
nT(n)-(n-1)T(n-1)=2n-1+2T(n-1)\\nT(n)-(n+1)T(n-1)=2n-1\\\frac{T(n)}{n+1}-\frac{T(n-1)}{n}=\frac{3}{n+1}-\frac{1}{n}
nT(n)−(n−1)T(n−1)=2n−1+2T(n−1)nT(n)−(n+1)T(n−1)=2n−1n+1T(n)−nT(n−1)=n+13−n1接下来只需要对
n
n
n等于
1
,
2
,
.
.
.
,
n
1,2,...,n
1,2,...,n的情况累加起来即可,最后得到
T
(
n
)
=
O
(
n
log
n
)
T(n)=O(n\log n)
T(n)=O(nlogn)
空间复杂度:也假设每次选择的
p
i
v
o
t
pivot
pivot最后移动到的位置满足均匀概率分布。首先由数学期望的性质,必然有
T
(
m
)
T(m)
T(m)是单调增的。接下来有递推公式:
T
(
n
)
=
2
n
(
T
(
n
−
1
)
+
T
(
n
−
2
)
+
.
.
.
+
T
(
n
2
)
)
T(n)=\frac{2}{n}(T(n-1)+T(n-2)+...+T(\frac{n}{2}))
T(n)=n2(T(n−1)+T(n−2)+...+T(2n))
所以有
n
T
(
n
)
−
(
n
−
1
)
T
(
n
−
1
)
=
2
T
(
n
−
1
)
T
(
n
)
n
+
1
−
T
(
n
−
1
)
n
=
1
n
−
1
−
1
n
nT(n)-(n-1)T(n-1)=2T(n-1)\\\frac{T(n)}{n+1}-\frac{T(n-1)}{n}=\frac{1}{n-1}-\frac{1}{n}
nT(n)−(n−1)T(n−1)=2T(n−1)n+1T(n)−nT(n−1)=n−11−n1所以
T
(
n
)
=
O
(
1
)
+
T
(
n
2
)
T(n)=O(1)+T(\frac{n}{2})
T(n)=O(1)+T(2n)一路递推下去得
T
(
n
)
=
O
(
log
n
)
T(n)=O(\log n)
T(n)=O(logn)所以平均空间复杂度是
O
(
log
n
)
O(\log n)
O(logn)。
法2:归并排序,递归版本。
import java.util.ArrayList;
import java.util.List;
public class Solution {
public List<Integer> sortArray(int[] nums) {
List<Integer> res = new ArrayList<>();
mergeSort(nums, 0, nums.length - 1, new int[nums.length]);
for (int num : nums) {
res.add(num);
}
return res;
}
private void mergeSort(int[] nums, int l, int r, int[] tmp) {
if (l >= r) {
return;
}
int mid = l + ((r - l) >> 1);
mergeSort(nums, l, mid, tmp);
mergeSort(nums, mid + 1, r, tmp);
merge(nums, l, mid, r, tmp);
}
private void merge(int[] nums, int l, int mid, int r, int[] tmp) {
int i = l, j = mid + 1, index = 0;
while (i <= mid && j <= r) {
if (nums[i] <= nums[j]) {
tmp[index++] = nums[i++];
} else {
tmp[index++] = nums[j++];
}
}
while (i <= mid) {
tmp[index++] = nums[i++];
}
while (j <= r) {
tmp[index++] = nums[j++];
}
index = 0;
for (int k = l; k <= r; k++) {
nums[k] = tmp[index++];
}
}
}
时间复杂度
O
(
n
log
n
)
O(n\log n)
O(nlogn),空间
O
(
n
)
O(n)
O(n),其中堆空间
O
(
n
)
O(n)
O(n),栈空间
O
(
log
n
)
O(\log n)
O(logn)。
算法正确性证明和复杂度证明都很简单,这里省略。
法3:归并排序,非递归版本。
import java.util.ArrayList;
import java.util.List;
public class Solution {
public List<Integer> sortArray(int[] nums) {
mergeSort(nums);
List<Integer> res = new ArrayList<>();
for (int num : nums) {
res.add(num);
}
return res;
}
private void mergeSort(int[] nums) {
int[] tmp = new int[nums.length];
// i模拟步长,两倍速度增长。当步长大于等于区间长度了就停下来
for (int i = 1; i < nums.length; i *= 2) {
// j表示归并的第一个区间首元素下标
for (int j = 0; j + i < nums.length; j += i * 2) {
int index = 0;
// l表示归并的两个区间中第一个区间首元素下标,r表示第二个区间首下标
int l = j, r = j + i;
// 这里要注意r不能溢出去,也就是说归并时第二个区间的长度有可能比步长小
while (l < j + i && r < j + 2 * i && r < nums.length) {
if (nums[l] <= nums[r]) {
tmp[index++] = nums[l++];
} else {
tmp[index++] = nums[r++];
}
}
while (l < j + i) {
tmp[index++] = nums[l++];
}
while (r < j + 2 * i && r < nums.length) {
tmp[index++] = nums[r++];
}
// 归并完两个区间后,要赋值回原数组
index = 0;
for (int k = j; k < r; k++) {
nums[k] = tmp[index++];
}
}
}
}
}
时间复杂度
O
(
n
log
n
)
O(n\log n)
O(nlogn),空间
O
(
n
)
O(n)
O(n),没有额外栈空间消耗。
算法正确性和复杂度证明都很显然。非递归归并排序完全就是模仿归并的过程,只不过步长是手动模拟的而已,排序过程其实十分显然。
法4:堆排序。先堆化整个数组,形成一个最大堆,然后将堆顶和数组倒数第一个数字交换,接着下滤堆顶元素,再将堆顶和数组倒数第二个数字交换,接着再下滤堆顶元素,这样一直操作下去即可。
import java.util.ArrayList;
import java.util.List;
public class Solution {
public List<Integer> sortArray(int[] nums) {
heapSort(nums);
List<Integer> res = new ArrayList<>();
for (int num : nums) {
res.add(num);
}
return res;
}
private void heapSort(int[] nums) {
// 先对数组堆化
heapify(nums, nums.length);
// 然后将堆顶与堆中最后一个元素交换,接着对堆顶下滤,然后缩小堆的规模
for (int i = nums.length - 1; i > 0; i--) {
swap(nums, 0, i);
percolateDown(nums, i, 0);
}
}
// 这里的n指的是堆的size。size为n的二叉堆中,
// 最后一个有孩子的节点下标是(n-2)/2 = (n>>1)-1,从这个下标开始做下滤操作
private void heapify(int[] nums, int n) {
for (int i = (n >> 1) - 1; i >= 0; i--) {
percolateDown(nums, n, i);
}
}
// 这个函数的作用是,在size为n的堆中,对下标为i的元素下滤
private void percolateDown(int[] nums, int n, int i) {
while ((i << 1) + 1 < n) {
int child = (i << 1) + 1;
if ((i << 1) + 2 < n && nums[(i << 1) + 2] > nums[child]) {
child++;
}
if (nums[i] >= nums[child]) {
break;
}
swap(nums, i, child);
i = child;
}
}
private void swap(int[] nums, int i, int j) {
int tmp = nums[i];
nums[i] = nums[j];
nums[j] = tmp;
}
}
时间复杂度 O ( n log n ) O(n\log n) O(nlogn),空间 O ( 1 ) O(1) O(1)。