优化一----原地堆排序
前一篇博客我们都需要开辟一个新的数组 来进行堆的存放,下面将讲述原地堆排序。
在前面讲到,堆是存放在一个数组中的,如果我们不想开辟新空间,在原来数组上依然可以实现堆排序,不过索引位置就要从0开始了。
新的计算公式如上图,新的公式可以通过上图归纳出来:
知道子节点索引为i, 求父节点索引: parent(i) = (i-1) / 2
知道父节点索引为i,求左右子节点的索引位置:left child (i) = 2*i +1;
left child (i) = 2*i +2;
在前一篇博客提到的PrintMaxHeap中增加方法:
public static void testMaxHeap3() {
PrintMaxHeap maxHeap = new PrintMaxHeap(100);
Integer[] arr = randomArray(12);
maxHeap.heapSort(arr);
System.out.println("排序后:");
for (Integer integer : arr) {
System.out.print(integer+" ");
}
}
/**
* 此方法只是 简单的 验证原地堆排序 功能
* @param arr
*/
public void heapSort(Integer[] arr) {
int n = arr.length;
//子节点索引为i, 求父节点索引: parent(i) = (i-1) / 2
//最后一个子节点i=n-1, 则父节点= (n-1 -1) /2
for(int i=(n-1 -1)/2; i>=0; i--) {
//从最后一个元素开始,找到最后一个父节点,进行Heapify堆化
shiftDown2(arr, n, i);//从当前父节点i开始,进行shiftDown2排序,总的数据有n个
}
//堆化完成后,进行堆排序, 让数组从小到大
for(int i=n - 1; i>0; i--) {
swap(arr, 0, i);//将最大元素交换到最后一个位置,相当于从堆中移除最大值
shiftDown2(arr, i, 0);//从当前父节点0开始,进行shiftDown2排序,总的数据有i个
}
}
在MaxHeap类中增加方法:
/**
* 原地堆排序
* @param arr 数据arr
* @param n 数据的个数
* @param i 当前要向下移动的父节点, 即当前索引位置
*/
public void shiftDown2(Item[] arr, int n, int i) {
while(i*2+1 < n) {//说明有左孩子
int j = i*2+1;//左孩子的脚标位置
if(j+1 < n && arr[j+1].compareTo(arr[j]) > 0) {
//如果有右孩子,并且,右孩子比左孩子的值大
j = j+1;
}
if(arr[i].compareTo(arr[j]) > 0) {
break;//如果父节点大于 最大的子节点,则已经是新的最大二叉堆
}
swap2(arr, i, j);//否则,则交换父节点和最大子节点的位置,继续向下比较
i = j;//更新这个值的最新位置
}
}
然后在PrintMaxHeap 的main方法中调用testMaxHeap3() 进行验证, 结果如下,成功进行了排序:
[89, 31, 44, 58, 27, 27, 90, 13, 97, 9, 40, 54]
排序后:
9 13 27 27 31 40 44 54 58 89 90 97
优化二----索引堆
2、索引堆
在前面讲到的堆结构中,在堆排序的过程中,需要不断交换元素,如果对象比较复杂,则每次交换也是有一定代价的。
而如果交换的是int型这样既基本数据类型,那就要快很多,因此 我们可以引入一个int数组,用来记录数据真正的位置,如下如所示。
在上面两个图中,原本对于数据的排序,变成了对数据索引的排序, 62在原堆中排第一位,而在索引堆中,由于62是最后插入的一个元素,在数据数组data中,依然是最后一位,但是在index索引数组中,排第一位, 而index[10] = 62; 在堆结构中,10这个索引值也替换了62这个原本的数据。
2、代码实现
明白了基本原理,我们需要在原来堆排序的代码上进行一定修改。
2.1 增加一个数组记录数据的实际位置
/**
* 索引堆
*
*/
public class IndexMaxHeap<Item extends Comparable> {
protected Item[] data;//保存数据
protected int[] indexes;//保存索引
protected int count;//存储数据的数量
private int capcity;//总的存储空间
public IndexMaxHeap(int capcity){
this.capcity = capcity;
indexes = new int[capcity + 1];
//java不允许直接创建泛型数组,否则编译报错,所以通过强转来通过编译
//第0个位置不存储,从第一个位置开始存储,所以空一个位置
data= (Item[]) new Comparable[capcity+1];
count = 0;
}
...
2.2 插入数据和向上移动shiftUp()
/**
* @param i 插入数据的位置(如果不指定位置,默认插入空位置,也可以,不过代码需要修改更多)
* @param item 插入的数据
*
* 注:此方法是指最简单的验证索引堆, 并不是一个完美的方法。此处假设数据是依次添加到数组中,并没有考虑
* 动态删除数据以后,再添加的情况,如果要考虑动态删除再添加,则还需要维护一个数组,这个数组用来记录data数组中
* 还有哪些空位置可以加入数据,然后把数据插入对应的空位置。
* 因为删除操作,肯定会使data数组中的数据不是连贯的,中间肯定有一些位置是没有填数据的。
*/
public void insert(int i, Item item) {
if(count == capcity) {
return;//已经填满了,无法再填入数据了
}
i = i+1;
if(i < 1 && i > capcity) {
return; //脚标越界
}
data[i] = item;//插入指定位置
indexes[count+1] = i;
count ++;
ShiftUp(count);//对索引进行shiftUp操作
}
/**
* k描述的是索引的位置,要拿到对应的数据,就应该修改为data[indexes[k]]
*/
private void ShiftUp(int k) {
while(k > 1 && data[ indexes[k/2] ].compareTo(data[ indexes[k] ]) < 0) {//比较真正数据的大小
swap( k, k/2);//交换索引值的位置
k /= 2;
}
}
/**
* 交换索引数组中两个索引值的位置
*/
private void swap(int i, int j) {
int temp = indexes[i];
indexes[i] = indexes[j];
indexes[j] = temp;
}
2.3 取出最大数据和向下移动shiftDown()
/**
* 取出最大元素
*/
public Item extractMax() {
if(count <= 0) {
return null;
}
Item item = data[indexes[1]];//堆中第一个元素就是最大值
data[indexes[1]] = null;//将这个位置的元素取出后,重置为null
swap(1, count);//将最后一个元素的索引放到第一个位置
count --;//总的数量减一
shiftDown(1);//重新调整整个堆,成为一个新的最大二叉堆
return item;
}
/**
* 将索引向下移动,直到找到合适的位置
*/
private void shiftDown(int i) {
while(i*2 <= count) {//说明有左孩子
int j = i*2;//左孩子的脚标位置
if(j+1 <= count && data[ indexes[j+1] ].compareTo(data[ indexes[j] ]) > 0) {
//如果有右孩子,并且,右孩子比左孩子的值大
j = j+1;
}
if(data[indexes[i]].compareTo(data[indexes[j]]) > 0) {
break;//如果父节点大于 最大的子节点,则已经是新的最大二叉堆
}
swap(i, j);//否则,则交换父节点和最大子节点的位置,继续向下比较
i = j;//更新这个索引的最新位置
}
}
3、测试数据
在PrintIndexMaxHeap类中去测试数据,结果如下:
The max heap size is: 12
Data in the max heap(注意这不是排序,只是数据在堆中的位置):
78 25 39 21 34 48 51 10 78 30 44 39
78
/ \
78 51
/ \ / \
34 44 39 48
/ \ / \ / \ / \
10 21 25 30 39
max=78
The max heap size is: 11
Data in the max heap(注意这不是排序,只是数据在堆中的位置):
null 25 39 21 34 48 51 10 78 30 44
78
/ \
44 51
/ \ / \
34 39 39 48
/ \ / \ / \ / \
10 21 25 30
排序结果: 78 51 48 44 39 39 34 30 25 21 10
数据验证成功。
完整代码:
IndexMaxHeap类:
/**
* 索引堆
*
*/
public class IndexMaxHeap<Item extends Comparable> {
protected Item[] data;//保存数据
protected int[] indexes;//保存索引
protected int count;//存储数据的数量
private int capcity;//总的存储空间
public IndexMaxHeap(int capcity){
this.capcity = capcity;
indexes = new int[capcity + 1];
//java不允许直接创建泛型数组,否则编译报错,所以通过强转来通过编译
//第0个位置不存储,从第一个位置开始存储,所以空一个位置
data= (Item[]) new Comparable[capcity+1];
count = 0;
}
public IndexMaxHeap(Item[] arr) {
this.capcity = arr.length;
indexes = new int[capcity + 1];
data = (Item[]) new Comparable[capcity + 1];
System.arraycopy(arr, 0, data, 1, arr.length);
count = arr.length;
for(int i = count/2; i >= 1; i --) {
shiftDown(i);
}
}
public int size() {
return count;
}
public boolean isEmpty() {
return count == 0;
}
public void destory() {
data = null;
indexes = null;
count = 0;
}
/**
* @param i 插入数据的位置(如果不指定位置,默认插入空位置,也可以,不过代码需要修改更多)
* @param item 插入的数据
*
* 注:此方法是指最简单的验证索引堆, 并不是一个完美的方法。此处假设数据是依次添加到数组中,并没有考虑
* 动态删除数据以后,再添加的情况,如果要考虑动态删除再添加,则还需要维护一个数组,这个数组用来记录data数组中
* 还有哪些空位置可以加入数据,然后把数据插入对应的空位置。
* 因为删除操作,肯定会使data数组中的数据不是连贯的,中间肯定有一些位置是没有填数据的。
*/
public void insert(int i, Item item) {
if(count == capcity) {
return;//已经填满了,无法再填入数据了
}
i = i+1;
if(i < 1 && i > capcity) {
return; //脚标越界
}
data[i] = item;//插入指定位置
indexes[count+1] = i;
count ++;
ShiftUp(count);//对索引进行shiftUp操作
}
/**
* 取出最大元素
*/
public Item extractMax() {
if(count <= 0) {
return null;
}
Item item = data[indexes[1]];//堆中第一个元素就是最大值
data[indexes[1]] = null;//将这个位置的元素取出后,重置为null
swap(1, count);//将最后一个元素的索引放到第一个位置
count --;//总的数量减一
shiftDown(1);//重新调整整个堆,成为一个新的最大二叉堆
return item;
}
/**
* 原地堆排序 用到
*/
public void shiftDown2(Item[] arr, int n, int k) {
}
/**
* 将索引向下移动,直到找到合适的位置
*/
private void shiftDown(int i) {
while(i*2 <= count) {//说明有左孩子
int j = i*2;//左孩子的脚标位置
if(j+1 <= count && data[ indexes[j+1] ].compareTo(data[ indexes[j] ]) > 0) {
//如果有右孩子,并且,右孩子比左孩子的值大
j = j+1;
}
if(data[indexes[i]].compareTo(data[indexes[j]]) > 0) {
break;//如果父节点大于 最大的子节点,则已经是新的最大二叉堆
}
swap(i, j);//否则,则交换父节点和最大子节点的位置,继续向下比较
i = j;//更新这个索引的最新位置
}
}
/**
* k描述的是索引的位置,要拿到对应的数据,就应该修改为data[indexes[k]]
*/
private void ShiftUp(int k) {
while(k > 1 && data[ indexes[k/2] ].compareTo(data[ indexes[k] ]) < 0) {//比较真正数据的大小
swap( k, k/2);//交换索引值的位置
k /= 2;
}
}
/**
* 交换索引数组中两个索引值的位置
*/
private void swap(int i, int j) {
int temp = indexes[i];
indexes[i] = indexes[j];
indexes[j] = temp;
}
}
PrintIndexMaxHeap类:
import java.util.Arrays;
import java.util.Random;
public class PrintIndexMaxHeap extends IndexMaxHeap<Comparable<Integer>> {
public PrintIndexMaxHeap(int capacity) {
super(capacity);
}
public PrintIndexMaxHeap(Comparable[] arr) {
super(arr);
}
// 测试 PrintableMaxHeap
public static void main(String[] args) {
testMaxHeap1() ;
//testMaxHeap2() ;
}
public static void testMaxHeap1() {
PrintIndexMaxHeap maxHeap = new PrintIndexMaxHeap(100);
int N = 12; // 堆中元素个数
int M = 90; // 堆中元素取值范围[0, M)
for( int i = 0 ; i < N ; i ++ )
maxHeap.insert(i, new Integer((int)(Math.random() * M)) );
maxHeap.treePrint();
Comparable i = maxHeap.extractMax();
System.out.println("max="+i);
maxHeap.treePrint();
maxHeap.heapSort();
}
public static void testMaxHeap2() {
PrintIndexMaxHeap maxHeap = new PrintIndexMaxHeap(randomArray(12));
maxHeap.treePrint();
Comparable i = maxHeap.extractMax();
System.out.println("max="+i);
maxHeap.treePrint();
}
// 以树状打印整个堆结构
public void treePrint(){
if( size() >= 100 ){
System.out.println("This print function can only work for less than 100 integer");
return;
}
System.out.println("The max heap size is: " + size());
System.out.println("Data in the max heap(注意这不是排序,只是数据在堆中的位置): ");
for( int i = 1 ; i <= size() ; i ++ ){
// 我们的print函数要求堆中的所有整数在[0, 100)的范围内
assert (Integer)data[indexes[i]] >= 0 && (Integer)data[indexes[i]] < 100;
System.out.print(data[i] + " ");
}
System.out.println();
System.out.println();
int n = size();
int maxLevel = 0;
int numberPerLevel = 1;
while( n > 0 ){
maxLevel += 1;
n -= numberPerLevel;
numberPerLevel *= 2;
}
int maxLevelNumber = (int)Math.pow(2, maxLevel-1);
int curTreeMaxLevelNumber = maxLevelNumber;
int index = 1;
for( int level = 0 ; level < maxLevel ; level ++ ){
String line1 = new String(new char[maxLevelNumber*3-1]).replace('\0', ' ');
int curLevelNumber = Math.min(count-(int)Math.pow(2,level)+1,(int)Math.pow(2,level));
boolean isLeft = true;
for( int indexCurLevel = 0 ; indexCurLevel < curLevelNumber ; index ++ , indexCurLevel ++ ){
line1 = putNumberInLine( (Integer)data[indexes[index]] , line1 , indexCurLevel , curTreeMaxLevelNumber*3-1 , isLeft );
isLeft = !isLeft;
}
System.out.println(line1);
if( level == maxLevel - 1 )
break;
String line2 = new String(new char[maxLevelNumber*3-1]).replace('\0', ' ');
for( int indexCurLevel = 0 ; indexCurLevel < curLevelNumber ; indexCurLevel ++ )
line2 = putBranchInLine( line2 , indexCurLevel , curTreeMaxLevelNumber*3-1 );
System.out.println(line2);
curTreeMaxLevelNumber /= 2;
}
}
private String putNumberInLine( Integer num, String line, int indexCurLevel, int curTreeWidth, boolean isLeft){
int subTreeWidth = (curTreeWidth - 1) / 2;
int offset = indexCurLevel * (curTreeWidth+1) + subTreeWidth;
assert offset + 1 < line.length();
if( num >= 10 )
line = line.substring(0, offset+0) + num.toString()
+ line.substring(offset+2);
else{
if( isLeft)
line = line.substring(0, offset+0) + num.toString()
+ line.substring(offset+1);
else
line = line.substring(0, offset+1) + num.toString()
+ line.substring(offset+2);
}
return line;
}
private String putBranchInLine( String line, int indexCurLevel, int curTreeWidth){
int subTreeWidth = (curTreeWidth - 1) / 2;
int subSubTreeWidth = (subTreeWidth - 1) / 2;
int offsetLeft = indexCurLevel * (curTreeWidth+1) + subSubTreeWidth;
assert offsetLeft + 1 < line.length();
int offsetRight = indexCurLevel * (curTreeWidth+1) + subTreeWidth + 1 + subSubTreeWidth;
assert offsetRight < line.length();
line = line.substring(0, offsetLeft+1) + "/" + line.substring(offsetLeft+2);
line = line.substring(0, offsetRight) + "\\" + line.substring(offsetRight+1);
return line;
}
public static Integer[] randomArray(int length) {
Integer[] arr = new Integer[length];
Random random = new Random();
for(int i=0; i<length; i++) {
arr[i] = random.nextInt(100);
}
//自动生成随机数组,先进行一次原始数据打印
System.out.println(Arrays.toString(arr));
return arr;
}
public void heapSort() {
System.out.print("排序结果: ");
while(size() > 0) {
System.out.print(extractMax());
System.out.print( " ");
}
System.out.println();
}
}