1. ShuffleMapTask的runTask()方法
/**
 * Runs this shuffle map task: rebuilds the (RDD, ShuffleDependency) pair from the broadcast
 * task binary, writes the records of this task's partition through a ShuffleWriter obtained
 * from the shuffle manager, and returns the status produced by stopping the writer.
 *
 * If anything inside the write phase throws an Exception, the writer (when one was obtained)
 * is stopped with `success = false` and the original exception is rethrown; a failure while
 * stopping is only logged at debug level.
 *
 * @param context the task context handed to the serializer-independent parts of the task
 * @return the value unwrapped from `writer.stop(success = true)`
 */
override def runTask(context: TaskContext): MapStatus = {
  // Rebuild the RDD and its shuffle dependency from the broadcast bytes, timing the step.
  val startedAt = System.currentTimeMillis()
  val closureSer = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, dep) = closureSer.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - startedAt

  metrics = Some(context.taskMetrics)
  // Stays null until the shuffle manager hands back a writer, so the failure path can tell
  // whether there is anything to stop.
  var writer: ShuffleWriter[Any, Any] = null
  try {
    writer = SparkEnv.get.shuffleManager
      .getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
    val records =
      rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]
    writer.write(records)
    // The try block is an expression; its last value is the method's result.
    writer.stop(success = true).get
  } catch {
    case failure: Exception =>
      try {
        if (writer != null) {
          writer.stop(success = false)
        }
      } catch {
        case stopFailure: Exception =>
          log.debug("Could not stop writer", stopFailure)
      }
      // Always surface the original failure, not the one raised while stopping.
      throw failure
  }
}
首先得到shuffleManager。shuffleManager分为三种:SortShuffleManager、HashShuffleManager、UnsafeShuffleManager,这里我们重点关注UnsafeShuffleManager。得到shuffleManager后,再通过getWriter()拿到UnsafeShuffleWriter,然后调用UnsafeShuffleWriter的write()方法将数据写入shuffle文件。
2. UnsafeShuffleWriter的write()方法
/**
 * Feeds every record from {@code records} into the sorter and then calls
 * {@code closeAndWriteOutput()}.
 *
 * If either step fails to complete, {@code sorter.cleanupAfterError()} is invoked before the
 * failure propagates to the caller.
 *
 * @param records the records to write
 * @throws IOException if inserting a record or writing the output fails
 */
public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
  // Assume failure until the whole sequence has run to completion.
  boolean failed = true;
  try {
    for (; records.hasNext(); ) {
      insertRecordIntoSorter(records.next());
    }
    closeAndWriteOutput();
    failed = false;
  } finally {
    if (failed) {
      sorter.cleanupAfterError();
    }
  }
}
write()方法调用insertRecordIntoSorter()方法。
/**
 * Serializes one key/value record and hands the serialized bytes, together with the
 * record's partition id, to the sorter.
 *
 * @param record the key/value pair to insert
 * @throws IOException declared for the serialization and sorter-insert calls
 */
void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
  final K recordKey = record._1();
  // The target partition is derived from the key alone.
  final int targetPartition = partitioner.getPartition(recordKey);

  // Reuse the buffer: clear it, then write key and value through the serialization stream.
  // NOTE(review): presumably serOutputStream is backed by serBuffer - confirm at its construction.
  serBuffer.reset();
  serOutputStream.writeKey(recordKey, OBJECT_CLASS_TAG);
  serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
  serOutputStream.flush();

  final int numSerializedBytes = serBuffer.size();
  assert (numSerializedBytes > 0);

  sorter.insertRecord(
    serBuffer.getBuf(),
    PlatformDependent.BYTE_ARRAY_OFFSET,
    numSerializedBytes,
    targetPartition);
}
先将数据序列化,insertRecord()方法将其插入到UnsafeShuffleExternalSorter中。
3. UnsafeShuffleExternalSorter 的insertRecord()方法
/**
 * Copies a serialized record into the current data page, prefixed by its 4-byte length,
 * and registers the record's encoded address and partition id with the pointer sorter.
 *
 * @param recordBaseObject base object of the memory holding the record bytes
 * @param recordBaseOffset offset of the record bytes relative to {@code recordBaseObject}
 * @param lengthInBytes number of record bytes to copy
 * @param partitionId partition the record belongs to
 * @throws IOException if space for the record cannot be allocated
 */
public void insertRecord(
    Object recordBaseObject,
    long recordBaseOffset,
    int lengthInBytes,
    int partitionId) throws IOException {
  // Each record is stored as [4-byte length][payload].
  final int bytesNeeded = lengthInBytes + 4;
  if (!haveSpaceForRecord(bytesNeeded)) {
    allocateSpaceForRecord(bytesNeeded);
  }

  // Capture the address of the record start (i.e. of the length prefix) before advancing.
  final long encodedAddress =
    memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
  final Object pageBase = currentPage.getBaseObject();

  // Length prefix.
  PlatformDependent.UNSAFE.putInt(pageBase, currentPagePosition, lengthInBytes);
  currentPagePosition += 4;
  freeSpaceInCurrentPage -= 4;

  // Payload.
  PlatformDependent.copyMemory(
    recordBaseObject, recordBaseOffset, pageBase, currentPagePosition, lengthInBytes);
  currentPagePosition += lengthInBytes;
  freeSpaceInCurrentPage -= lengthInBytes;

  sorter.insertRecord(encodedAddress, partitionId);
}
4. UnsafeShuffleExternalSorter内部的sorter(UnsafeShuffleSorter)的insertRecord()方法
/**
 * Appends one record pointer to the pointer array, growing the array first when it has no
 * room for another entry.
 *
 * @param recordPointer encoded address of the record
 * @param partitionId partition the record belongs to
 * @throws IllegalStateException if the pointer array is full and already has
 *         {@code Integer.MAX_VALUE} elements
 */
public void insertRecord(long recordPointer, int partitionId) {
  if (!hasSpaceForAnotherRecord()) {
    if (pointerArray.length == Integer.MAX_VALUE) {
      throw new IllegalStateException("Sort pointer array has reached maximum size");
    }
    expandPointerArray();
  }
  // Pointer and partition id are packed into a single long entry.
  final long packedEntry = PackedRecordPointer.packPointer(recordPointer, partitionId);
  pointerArray[pointerArrayInsertPosition] = packedEntry;
  pointerArrayInsertPosition++;
}
PackedRecordPointer对象用一个64bit的long型变量来记录数据信息:
[24 bit partition number][13 bit memory page number][27 bit offset in page]。这些信息用于对数据进行排序。
当当前页没有足够空间存放记录时,前面的insertRecord()方法会调用下面的allocateSpaceForRecord()方法来分配空间:
/**
 * Makes room for a record that needs {@code requiredSpace} bytes.
 *
 * Two independent resources are checked in turn:
 * <ol>
 *   <li>the sort pointer array: if it has no room for another record, memory equal to twice
 *       its current usage is requested; on success the array is expanded and the old
 *       allocation released, otherwise the partial grant is released and the sorter spills;</li>
 *   <li>the current data page: if the record does not fit in the remaining space, a new page
 *       of {@code PAGE_SIZE} bytes is acquired (spilling once if the first attempt fails) and
 *       becomes the current page.</li>
 * </ol>
 *
 * @param requiredSpace number of bytes needed in a data page
 * @throws IOException if {@code requiredSpace} exceeds {@code PAGE_SIZE}, or if a page cannot
 *         be acquired even after spilling
 */
private void allocateSpaceForRecord(int requiredSpace) throws IOException {
  if (!sorter.hasSpaceForAnotherRecord()) {
    logger.debug("Attempting to expand sort pointer array");
    final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
    final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
    if (memoryAcquired < memoryToGrowPointerArray) {
      // Could not get the full amount: give back the partial grant and spill instead.
      shuffleMemoryManager.release(memoryAcquired);
      spill();
    } else {
      sorter.expandPointerArray();
      // The old array's memory is released once the expanded array is in place.
      shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
    }
  }
  if (requiredSpace > freeSpaceInCurrentPage) {
    // This branch is taken when the record does NOT fit, so the message must say "greater".
    logger.trace("Required space {} is greater than free space in current page ({})",
      requiredSpace, freeSpaceInCurrentPage);
    // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
    // without using the free space at the end of the current page. We should also do this for
    // BytesToBytesMap.
    if (requiredSpace > PAGE_SIZE) {
      throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
        PAGE_SIZE + ")");
    } else {
      final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
      if (memoryAcquired < PAGE_SIZE) {
        // Not enough memory for a whole page: return the partial grant, spill, and retry once.
        shuffleMemoryManager.release(memoryAcquired);
        spill();
        final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
        if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
          shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
          throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
        }
      }
      // Switch to a fresh page and reset the per-page bookkeeping.
      currentPage = memoryManager.allocatePage(PAGE_SIZE);
      currentPagePosition = currentPage.getBaseOffset();
      freeSpaceInCurrentPage = PAGE_SIZE;
      allocatedPages.add(currentPage);
    }
  }
}