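LinkedIn 开源项目 Bobo 中 ListMerger 类的分析:
该类实现的是多路归并(k-way merge)算法。构造函数的参数 Iterator<T>[] iterators(另有一个接收 List<Iterator<T>> 的重载)是一组迭代器,每个迭代器可以看作一个 list,且每个 list 都已按同一个 comparator 排好序。归并时每次从各路当前元素中取出 comparator 意义下"最小"的那个(lessThan 定义为 comparator.compare(v1, v2) < 0,堆顶即最小元素),因此各路必须按该 comparator 升序排列。搜索场景中要先取得分最高的结果,只需让 comparator 把高分排在前面——此时按得分看,每一路就是一个降序的 list。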
public class ListMerger
{
/**
 * Lazily merges several individually sorted iterators into one sorted iteration (k-way merge).
 *
 * <p>One {@code IteratorNode} per non-empty source is kept in a priority queue, ordered by
 * comparing each node's current value with {@code comparator}. {@link #next()} returns the
 * value held by the node at the top of the queue and then advances that source.
 *
 * <p>NOTE(review): this relies on the Lucene-style {@code PriorityQueue} contract, where
 * {@code top()} is the least element according to {@code lessThan}. Since {@code lessThan} is
 * {@code comparator.compare(v1, v2) < 0}, the merged output is ascending per
 * {@code comparator}. Each source must therefore already be ascending per the same comparator,
 * e.g. a "higher score first" comparator for score-descending lists.
 *
 * <p>{@link #remove()} is not supported.
 *
 * @param <T> element type
 */
public static class MergedIterator<T> implements Iterator<T>
{
    /** One merge input: a source iterator plus its current (not yet returned) value. */
    private class IteratorNode
    {
        public final Iterator<T> _iterator;
        public T _curVal;

        public IteratorNode(Iterator<T> iterator)
        {
            _iterator = iterator;
            _curVal = null;
        }

        /**
         * Pulls the next element from the source into {@code _curVal}.
         *
         * @return {@code true} if an element was available; otherwise {@code _curVal} is set
         *         to {@code null} and {@code false} is returned
         */
        public boolean fetch()
        {
            if (_iterator.hasNext())
            {
                _curVal = _iterator.next();
                return true;
            }
            _curVal = null;
            return false;
        }
    }

    private final PriorityQueue _queue;

    /**
     * Creates an empty merger whose queue is initialized for {@code length} sources.
     *
     * @param length     number of sources; passed to the queue's {@code initialize}
     * @param comparator element ordering, applied to the current value of each source
     */
    private MergedIterator(final int length, final Comparator<T> comparator)
    {
        _queue = new PriorityQueue()
        {
            {
                this.initialize(length);
            }

            @SuppressWarnings("unchecked")
            @Override
            protected boolean lessThan(Object o1, Object o2)
            {
                // Compare the current head element of two sources.
                T v1 = ((IteratorNode) o1)._curVal;
                T v2 = ((IteratorNode) o2)._curVal;
                return comparator.compare(v1, v2) < 0;
            }
        };
    }

    /**
     * @param iterators  the sorted sources to merge
     * @param comparator ordering shared by the sources and the merged output
     */
    public MergedIterator(final List<Iterator<T>> iterators, final Comparator<T> comparator)
    {
        // One queue slot per source.
        this(iterators.size(), comparator);
        for (Iterator<T> iterator : iterators)
        {
            enqueue(iterator);
        }
    }

    /**
     * @param iterators  the sorted sources to merge
     * @param comparator ordering shared by the sources and the merged output
     */
    public MergedIterator(final Iterator<T>[] iterators, final Comparator<T> comparator)
    {
        // One queue slot per source.
        this(iterators.length, comparator);
        for (Iterator<T> iterator : iterators)
        {
            enqueue(iterator);
        }
    }

    /** Primes {@code iterator} with its first element and queues it; empty sources are skipped. */
    private void enqueue(Iterator<T> iterator)
    {
        IteratorNode node = new IteratorNode(iterator);
        if (node.fetch())
        {
            _queue.add(node);
        }
    }

    /** @return {@code true} while at least one source still holds an unreturned element */
    public boolean hasNext()
    {
        return _queue.size() > 0;
    }

    /**
     * Returns the current value of the top source and advances that source.
     *
     * @return the next merged element
     * @throws java.util.NoSuchElementException if every source is exhausted
     */
    @SuppressWarnings("unchecked")
    public T next()
    {
        if (!hasNext())
        {
            // Iterator contract: report exhaustion explicitly instead of dereferencing
            // whatever top() yields for an empty queue.
            throw new java.util.NoSuchElementException();
        }
        IteratorNode ctx = (IteratorNode) _queue.top();
        T val = ctx._curVal;
        if (ctx.fetch())
        {
            // The source has more data and its head value changed, so restore queue order.
            _queue.updateTop();
        }
        else
        {
            // The source is exhausted: drop it from the queue.
            _queue.pop();
        }
        return val;
    }

    /** Unsupported: always throws {@link UnsupportedOperationException}. */
    public void remove()
    {
        throw new UnsupportedOperationException();
    }
}
/** Prevents instantiation; the merge entry points are all static. */
private ListMerger()
{
    // intentionally empty
}
/**
 * Returns a lazy iterator that merges the given sorted sources.
 *
 * @param iterators  the sources, each sorted consistently with {@code comparator}
 * @param comparator ordering shared by the sources and the merged output
 * @return a merging iterator whose {@code remove()} is unsupported
 */
public static <T> Iterator<T> mergeLists(final Iterator<T>[] iterators, final Comparator<T> comparator)
{
    final Iterator<T> merged = new MergedIterator<T>(iterators, comparator);
    return merged;
}
/**
 * Returns a lazy iterator that merges the given sorted sources.
 *
 * @param iterators  the sources, each sorted consistently with {@code comparator}
 * @param comparator ordering shared by the sources and the merged output
 * @return a merging iterator whose {@code remove()} is unsupported
 */
public static <T> Iterator<T> mergeLists(final List<Iterator<T>> iterators, final Comparator<T> comparator)
{
    final Iterator<T> merged = new MergedIterator<T>(iterators, comparator);
    return merged;
}
/**
 * Merges the sorted sources and returns one page of the merged sequence: the first
 * {@code offset} elements are skipped and up to {@code count} following elements are collected.
 *
 * @param offset     number of merged elements to skip
 * @param count      maximum number of elements to return
 * @param iterators  the sources, each sorted consistently with {@code comparator}
 * @param comparator ordering shared by the sources and the merged output
 * @return the requested page as a new list
 */
public static <T> ArrayList<T> mergeLists(int offset, int count, Iterator<T>[] iterators, Comparator<T> comparator)
{
    Iterator<T> merged = new MergedIterator<T>(iterators, comparator);
    return mergeLists(offset, count, merged);
}
/**
 * Merges the sorted sources and returns one page of the merged sequence: the first
 * {@code offset} elements are skipped and up to {@code count} following elements are collected.
 *
 * @param offset     number of merged elements to skip
 * @param count      maximum number of elements to return
 * @param iterators  the sources, each sorted consistently with {@code comparator}
 * @param comparator ordering shared by the sources and the merged output
 * @return the requested page as a new list
 */
public static <T> ArrayList<T> mergeLists(int offset, int count, List<Iterator<T>> iterators, Comparator<T> comparator)
{
    Iterator<T> merged = new MergedIterator<T>(iterators, comparator);
    return mergeLists(offset, count, merged);
}
/**
 * Extracts one page from an already-merged iterator: skips the first {@code offset} elements,
 * then collects up to {@code count} elements. Used to page through merged search results.
 *
 * @param offset     number of leading elements to skip (values {@code <= 0} skip nothing)
 * @param count      maximum number of elements to return (values {@code <= 0} yield an empty list)
 * @param mergedIter source of merged elements; it is advanced by this call
 * @return a new list with at most {@code count} elements, fewer if the iterator runs out
 */
private static <T> ArrayList<T> mergeLists(int offset, int count, Iterator<T> mergedIter)
{
    ArrayList<T> mergedList = new ArrayList<T>();
    // Nothing requested: return without consuming the iterator (also covers negative counts).
    if (count <= 0)
    {
        return mergedList;
    }
    // Skip the first `offset` elements (earlier pages).
    for (int c = 0; c < offset && mergedIter.hasNext(); c++)
    {
        mergedIter.next();
    }
    for (int c = 0; c < count && mergedIter.hasNext(); c++)
    {
        mergedList.add(mergedIter.next());
    }
    return mergedList;
}