Task的输入和输出
Task的输出
Task产出的每一个ResultPartition都有一个关联的ResultPartitionWriter,同时也都有一个独立的LocalBufferPool负责提供写入数据所需的buffer。ResultPartition实现了ResultPartitionWriter接口:
// Excerpt of ResultPartition: the fields and the constructor that creates the subpartitions.
// Some fields assigned in the constructor are not declared in this excerpt.
public class ResultPartition implements ResultPartitionWriter, BufferPoolOwner {
/** Type of this partition. Defines the concrete subpartition implementation to use. */
private final ResultPartitionType partitionType;
// A ResultPartition is made up of ResultSubpartitions.
// The number of ResultSubpartitions is determined by the number of downstream consumer tasks and the DistributionPattern.
// For example, with FORWARD there is only one downstream consumer; with SHUFFLE the number of downstream
// consumers equals the parallelism of the downstream operator.
/** The subpartitions of this partition. At least one. */
private final ResultSubpartition[] subpartitions;
// The ResultPartitionManager manages all ResultPartitions of the current TaskManager.
private final ResultPartitionManager partitionManager;
// Callback used to notify that this ResultPartition has data available for consumption.
private final ResultPartitionConsumableNotifier partitionConsumableNotifier;
private BufferPool bufferPool;
// Whether a message must be sent to schedule or update the consumers once data is produced
// (in streaming mode the schedule mode is ScheduleMode.EAGER, so no notification is needed).
private final boolean sendScheduleOrUpdateConsumersMessage;
// Whether the consumers have already been notified.
private boolean hasNotifiedPipelinedConsumers;
public ResultPartition(
String owningTaskName,
TaskActions taskActions, // actions on the owning task
JobID jobId,
ResultPartitionID partitionId,
ResultPartitionType partitionType,
int numberOfSubpartitions,
int numTargetKeyGroups,
ResultPartitionManager partitionManager,
ResultPartitionConsumableNotifier partitionConsumableNotifier,
IOManager ioManager,
boolean sendScheduleOrUpdateConsumersMessage) {
this.owningTaskName = checkNotNull(owningTaskName);
this.taskActions = checkNotNull(taskActions);
this.jobId = checkNotNull(jobId);
this.partitionId = checkNotNull(partitionId);
this.partitionType = checkNotNull(partitionType);
this.subpartitions = new ResultSubpartition[numberOfSubpartitions];
this.numTargetKeyGroups = numTargetKeyGroups;
this.partitionManager = checkNotNull(partitionManager);
this.partitionConsumableNotifier = checkNotNull(partitionConsumableNotifier);
this.sendScheduleOrUpdateConsumersMessage = sendScheduleOrUpdateConsumersMessage;
// Create the subpartitions.
switch (partitionType) {
case BLOCKING: // Batch mode: SpillableSubpartition, which writes results to disk when buffers run short
for (int i = 0; i < subpartitions.length; i++) {
subpartitions[i] = new SpillableSubpartition(i, this, ioManager);
}
break;
case PIPELINED: // Streaming mode: PipelinedSubpartition
case PIPELINED_BOUNDED:
for (int i = 0; i < subpartitions.length; i++) {
subpartitions[i] = new PipelinedSubpartition(i, this);
}
break;
default:
throw new IllegalArgumentException("Unsupported result partition type.");
}
// Initially, partitions should be consumed once before release.
pin();
LOG.debug("{}: Initialized {}", owningTaskName, this);
}
}
Task在启动的时候会向NetworkEnvironment进行注册,这里会为每一个ResultPartition分配LocalBufferPool:
// Excerpt of NetworkEnvironment: registering a task and setting up the buffer pool of each produced partition.
class NetworkEnvironment {
private final NetworkBufferPool networkBufferPool;
private final ConnectionManager connectionManager;
private final ResultPartitionManager resultPartitionManager;
private final TaskEventDispatcher taskEventDispatcher;
// Registers a Task; buffer pools have to be assigned to the inputs and outputs of this Task.
public void registerTask(Task task) throws IOException {
final ResultPartition[] producedPartitions = task.getProducedPartitions();
synchronized (lock) {
if (isShutdown) {
throw new IllegalStateException("NetworkEnvironment is shut down");
}
for (final ResultPartition partition : producedPartitions) {
setupPartition(partition); // outputs
}
// Setup the buffer pool for each buffer reader (assign the corresponding buffer pool)
final SingleInputGate[] inputGates = task.getAllInputGates();
for (SingleInputGate gate : inputGates) {
setupInputGate(gate);
}
}
}
// Creates a buffer pool for the given partition, hands it to the partition, and registers the
// partition with the ResultPartitionManager and the TaskEventDispatcher.
public void setupPartition(ResultPartition partition) throws IOException {
BufferPool bufferPool = null;
try {
// If the partition type is not bounded, the maximum size of the buffer pool is unlimited (Integer.MAX_VALUE);
// otherwise it is sub-partitions * networkBuffersPerChannel (taskmanager.network.memory.buffers-per-channel)
// plus extraNetworkBuffersPerGate.
int maxNumberOfMemorySegments = partition.getPartitionType().isBounded() ?
partition.getNumberOfSubpartitions() * networkBuffersPerChannel +
extraNetworkBuffersPerGate : Integer.MAX_VALUE;
// If the partition type is back pressure-free, we register with the buffer pool for
// callbacks to release memory.
// Create a LocalBufferPool; the minimum number of requested MemorySegments equals the number of sub-partitions.
// Without back pressure the partition has to handle buffer release itself (mainly in batch mode).
bufferPool = networkBufferPool.createBufferPool(partition.getNumberOfSubpartitions(),
maxNumberOfMemorySegments,
partition.getPartitionType().hasBackPressure() ? Optional.empty() : Optional.of(partition));
// Assign the LocalBufferPool to this ResultPartition.
partition.registerBufferPool(bufferPool);
// Register with the ResultPartitionManager.
resultPartitionManager.registerResultPartition(partition);
} catch (Throwable t) {
// ...... (error handling elided in this excerpt)
}
taskEventDispatcher.registerPartition(partition.getPartitionId());
}
}
Task通过RecordWriter将结果写入ResultPartition中。RecordWriter是对ResultPartitionWriter的一层封装,并负责将记录对象序列化到buffer中。先来看一下RecordWriter的成员变量和构造函数:
// Excerpt of RecordWriter: member fields and constructor.
class RecordWriter {
// The underlying ResultPartition.
private final ResultPartitionWriter targetPartition;
// Decides which channel (i.e. sub-partition) a record is written to.
private final ChannelSelector<T> channelSelector;
// Number of channels, i.e. the number of sub-partitions.
private final int numberOfChannels;
// Channel indices used for broadcasting records (filled with every channel index in the constructor).
private final int[] broadcastChannels;
// Serializer.
private final RecordSerializer<T> serializer;
// One slot per channel, holding the BufferBuilder that channel currently writes into.
private final Optional<BufferBuilder>[] bufferBuilders;
// Periodically forces a flush of the output buffers.
private final Optional<OutputFlusher> outputFlusher;
RecordWriter(ResultPartitionWriter writer, ChannelSelector<T> channelSelector, long timeout, String taskName) {
this.targetPartition = writer;
this.channelSelector = channelSelector;
this.numberOfChannels = writer.getNumberOfSubpartitions();
this.channelSelector.setup(numberOfChannels);
// Serializer that can spread a single record across multiple buffers.
this.serializer = new SpanningRecordSerializer<T>();
this.bufferBuilders = new Optional[numberOfChannels];
this.broadcastChannels = new int[numberOfChannels];
for (int i = 0; i < numberOfChannels; i++) {
broadcastChannels[i] = i;
bufferBuilders[i] = Optional.empty();
}
checkArgument(timeout >= -1);
// A timeout of 0 means every write is followed by a flush.
this.flushAlways = (timeout == 0);
// No flusher thread is created for timeout -1 or 0.
if (timeout == -1 || timeout == 0) {
outputFlusher = Optional.empty();
} else {
// Create a thread that periodically flushes the output buffers according to the timeout.
String threadName = taskName == null ?
DEFAULT_OUTPUT_FLUSH_THREAD_NAME :
DEFAULT_OUTPUT_FLUSH_THREAD_NAME + " for " + taskName;
outputFlusher = Optional.of(new OutputFlusher(threadName, timeout));
outputFlusher.get().start();
}
}
}
当Task通过RecordWriter输出一条记录时,主要流程为:
- 通过ChannelSelector确定写入的目标channel
- 使用RecordSerializer对记录进行序列化
- 向ResultPartition请求BufferBuilder,用于写入序列化结果(通过targetPartition.getBufferProvider(),即该ResultPartition对应的LocalBufferPool申请buffer,没有可用buffer时会阻塞)
- 向ResultPartition添加BufferConsumer,其主要用于下游Task任务对上述写入本地Buffer中的数据进行读取操作
BufferConsumer主要是对memorySegment数据的一层封装,其操作实质还是对memorySegment的操作;BufferConsumer可以用来被PipelinedSubpartitionView来触发其内部的BufferAvailabilityListener的消费通知,其可以通知下游Task任务对该ResultPartition写入Buffer的结果读取:
- 在local模式下,BufferAvailabilityListener对应于LocalInputChannel,其会通知对应的LocalInputChannel该buffer数据准备就绪,可以进行本地local读取;
- 在Remote模式下,BufferAvailabilityListener对应于CreditBasedSequenceNumberingViewReader,其会向Netty ChannelPipeline中发送ViewReader事件去通知下游的RemoteInputChannel去读取该Buffer数据;
代码如下:
// Excerpt of RecordWriter: the path a record takes from emit() into the buffers of the target channels.
class RecordWriter {
    /**
     * Emits a record to the channels chosen by the channel selector.
     *
     * @param record the record to serialize and write
     */
    public void emit(T record) throws IOException, InterruptedException {
        // The channel selector determines the target channels; the channel count is this
        // writer's numberOfChannels field (the number of sub-partitions).
        emit(record, channelSelector.selectChannels(record, numberOfChannels));
    }

    /**
     * Serializes the record once and copies the serialized bytes to every target channel.
     *
     * @param record the record to write
     * @param targetChannels indices of the channels (sub-partitions) to write to
     */
    private void emit(T record, int[] targetChannels) throws IOException, InterruptedException {
        serializer.serializeRecord(record); // serialize

        boolean pruneAfterCopying = false;
        for (int channel : targetChannels) {
            if (copyFromSerializerToTargetChannel(channel)) { // write the serialized result into a buffer
                pruneAfterCopying = true;
            }
        }

        // Make sure we don't hold onto the large intermediate serialization buffer for too long
        if (pruneAfterCopying) {
            // Drop the buffer used for serialization (the temporary byte[] written during
            // serialization) to reduce memory usage.
            serializer.prune();
        }
    }

    /**
     * Copies the serialized record into the buffer(s) of the given channel.
     *
     * @param targetChannel index of the channel to write to
     * @return true if a buffer became full exactly when the record was completely written,
     *         which signals the caller to prune the serializer
     */
    private boolean copyFromSerializerToTargetChannel(int targetChannel) throws IOException, InterruptedException {
        // We should reset the initial position of the intermediate serialization buffer before
        // copying, so the serialization results can be copied to multiple target buffers.
        serializer.reset();

        boolean pruneTriggered = false;
        BufferBuilder bufferBuilder = getBufferBuilder(targetChannel);
        SerializationResult result = serializer.copyToBufferBuilder(bufferBuilder);
        while (result.isFullBuffer()) { // the buffer is full: finish the BufferBuilder
            numBytesOut.inc(bufferBuilder.finish());
            numBuffersOut.inc();

            // If this was a full record, we are done. Not breaking out of the loop at this point
            // will lead to another buffer request before breaking out (that would not be a
            // problem per se, but it can lead to stalls in the pipeline).
            if (result.isFullRecord()) { // the current record has been written completely as well
                pruneTriggered = true;
                bufferBuilders[targetChannel] = Optional.empty();
                break;
            }

            // The current record is not fully written yet: request a new buffer and continue.
            bufferBuilder = requestNewBufferBuilder(targetChannel);
            result = serializer.copyToBufferBuilder(bufferBuilder);
        }
        checkState(!serializer.hasSerializedData(), "All data should be written at once");

        if (flushAlways) { // force a flush of the result
            targetPartition.flush(targetChannel);
        }
        return pruneTriggered;
    }

    /**
     * Returns the BufferBuilder currently held for the channel, requesting a new one if none is present.
     */
    private BufferBuilder getBufferBuilder(int targetChannel) throws IOException, InterruptedException {
        if (bufferBuilders[targetChannel].isPresent()) {
            return bufferBuilders[targetChannel].get();
        } else {
            return requestNewBufferBuilder(targetChannel);
        }
    }

    /**
     * Requests a new BufferBuilder for writing data; blocks if no buffer is currently available.
     */
    private BufferBuilder requestNewBufferBuilder(int targetChannel) throws IOException, InterruptedException {
        // A new builder may only replace an absent or already finished one.
        checkState(!bufferBuilders[targetChannel].isPresent() || bufferBuilders[targetChannel].get().isFinished());

        // Request the BufferBuilder from the partition's buffer provider (the LocalBufferPool).
        BufferBuilder bufferBuilder = targetPartition.getBufferProvider().requestBufferBuilderBlocking();
        bufferBuilders[targetChannel] = Optional.of(bufferBuilder);
        // Add a BufferConsumer that is used to read the data written into the MemorySegment.
        targetPartition.addBufferConsumer(bufferBuilder.createBufferConsumer(), targetChannel);
        return bufferBuilder;
    }
}
向ResultPartition添加一个BufferConsumer,ResultPartition会将其转交给对应的ResultSubpartition:
class ResultPartition implement ResultPartitionWriter {
//向指定的 subpartition 添加一个 buffer
public void addBufferConsumer(BufferConsumer bufferConsumer, int subpartitionIndex) throws IOException {
checkNotNull(bufferConsumer);
ResultSubpartition subpartition;
try {
checkInProduceState();
subpartition = subpartitions[subpartitionIndex];
}
catch (Exception ex) {
bufferConsumer.close();
throw ex;
}
//添加 BufferConsumer,说明已经有数据生成了
if (subpartition.add(bufferConsumer)) {
notifyPipelinedConsumers();
}
}
/**
* Notifies pipelined consumers of this result partition once.
*/
private void notifyPipelinedConsumers() {
//对于 Streaming 模式的任务,由于调度模式为 EAGER,所有的 task 都已经部署了,下面的通知不会触发 (flink默认调度模式为EAGER)
if (sendScheduleOrUpdateConsumersMessage && !hasNotifiedPipelinedConsumers && partitionType.isPipelined()) {
//对于 PIPELINE 类型的 ResultPartition,在第一条记录产生时,
//会告知 JobMaster 当前 ResultPartition 可被消费,这会触发下游消费者 Task 的部署
partitionConsumableNotifier.notifyPartitionConsumable(jobId, partitionId, taskActions);
hasNotifiedPipelinedConsumers = true;
}
}
}
前面已经看到,根据ResultPartitionType的不同,ResultSubpartition的实现类也不同。对于Streaming模式,使用的是PipelinedSubpartition:
// Excerpt of ResultSubpartition: the fields shared by the subpartition implementations.
public abstract class ResultSubpartition {
/** The index of the subpartition at the parent partition. */
protected final int index;
/** The parent partition this subpartition belongs to. */
protected final ResultPartition parent;
/** All buffers of this subpartition. Access to the buffers is synchronized on this object. */
// Queue of all buffers that have accumulated in this subpartition.
protected final ArrayDeque<BufferConsumer> buffers = new ArrayDeque<>();
/** The number of non-event buffers currently in this subpartition. */
// Number of buffers that have accumulated in this subpartition.
@GuardedBy("buffers")
private int buffersInBacklog;
}
// Excerpt of PipelinedSubpartition: queues BufferConsumers and notifies the read view when data becomes available.
class PipelinedSubpartition extends ResultSubpartition {
// Used to consume the buffers that have been written.
private PipelinedSubpartitionView readView;
// index is the index of this sub-partition.
PipelinedSubpartition(int index, ResultPartition parent) {
super(index, parent);
}
@Override
public boolean add(BufferConsumer bufferConsumer) {
return add(bufferConsumer, false);
}
// Adds a new BufferConsumer.
// The finish parameter means that the whole subpartition is finished.
// Returns false (after closing the consumer) if the subpartition is already finished or released.
private boolean add(BufferConsumer bufferConsumer, boolean finish) {
checkNotNull(bufferConsumer);
final boolean notifyDataAvailable;
synchronized (buffers) {
if (isFinished || isReleased) {
bufferConsumer.close();
return false;
}
// Add the bufferConsumer and update the stats
buffers.add(bufferConsumer);
updateStatistics(bufferConsumer);
// Update the backlog count; only buffers increase buffersInBacklog by one, events do not.
increaseBuffersInBacklog(bufferConsumer);
notifyDataAvailable = shouldNotifyDataAvailable() || finish;
isFinished |= finish;
}
// The notification is issued after leaving the synchronized block.
if (notifyDataAvailable) {
// Notify that data can be consumed.
notifyDataAvailable();
}
return true;
}
// Notify only when the first buffer has become finished.
private boolean shouldNotifyDataAvailable() {
// Notify only when we added first finished buffer.
return readView != null && !flushRequested && getNumberOfFinishedBuffers() == 1;
}
// Notifies the readView that data is available.
private void notifyDataAvailable() {
if (readView != null) {
readView.notifyDataAvailable();
}
}
@Override
public void flush() {
final boolean notifyDataAvailable;
synchronized (buffers) {
if (buffers.isEmpty()) {
return;
}
// if there is more then 1 buffer, we already notified the reader
// (at the latest when adding the second buffer)
notifyDataAvailable = !flushRequested && buffers.size() == 1;
flushRequested = true;
}
if (notifyDataAvailable) {
notifyDataAvailable();
}
}
}
在强制进行flush的时候,也会发出数据可用的通知。这是因为,假如产出的数据记录较少无法完整地填充一个MemorySegment,那么ResultSubpartition可能会一直处于不可被消费的状态。而为了保证产出的记录能够及时被消费,就需要及时进行flush,从而确保下游能更及时地处理数据。在RecordWriter中有一个OutputFlusher会定时触发flush,间隔可以通过DataStream.setBufferTimeout()来控制。
写入的Buffer最终被保存在ResultSubpartition中维护的一个队列中,如果需要消费这些Buffer,就需要依赖ResultSubpartitionView。当需要消费一个ResultSubpartition的结果时,需要创建一个ResultSubpartitionView对象,并关联到ResultSubpartition中;当数据可以被消费时,会通过对应的回调接口告知ResultSubpartitionView:
/**
* A view to consume a {@link ResultSubpartition} instance.
*/
public interface ResultSubpartitionView {
/**
* Returns the next {@link Buffer} instance of this queue iterator.
*
* <p>If there is currently no instance available, it will return <code>null</code>.
* This might happen for example when a pipelined queue producer is slower
* than the consumer or a spilled queue needs to read in more data.
*
* <p><strong>Important</strong>: The consumer has to make sure that each
* buffer instance will eventually be recycled with {@link Buffer#recycleBuffer()}
* after it has been consumed.
*/
@Nullable
BufferAndBacklog getNextBuffer() throws IOException, InterruptedException;
// Notifies the downstream input channel that data of this ResultSubpartition is available for consumption.
void notifyDataAvailable();
// Signals that consumption of the ResultSubpartition has completed.
void notifySubpartitionConsumed() throws IOException;
boolean nextBufferIsEvent();
// ........ (remaining methods elided in this excerpt)
}
// View over a PipelinedSubpartition: reads are delegated to the parent subpartition and
// availability notifications are forwarded to the registered BufferAvailabilityListener.
class PipelinedSubpartitionView implements ResultSubpartitionView {
/** The subpartition this view belongs to. */
private final PipelinedSubpartition parent;
private final BufferAvailabilityListener availabilityListener;
/** Flag indicating whether this view has been released. */
private final AtomicBoolean isReleased;
PipelinedSubpartitionView(PipelinedSubpartition parent, BufferAvailabilityListener listener) {
this.parent = checkNotNull(parent);
this.availabilityListener = checkNotNull(listener);
this.isReleased = new AtomicBoolean();
}
@Nullable
@Override
public BufferAndBacklog getNextBuffer() {
return parent.pollBuffer();
}
@Override
public void notifyDataAvailable() {
// Callback that notifies the downstream input channel that data of this ResultSubpartition is available.
// Its implementations are: local: LocalInputChannel
// remote: CreditBasedSequenceNumberingViewReader
availabilityListener.notifyDataAvailable();
}
@Override
public void notifySubpartitionConsumed() {
releaseAllResources();
}
@Override
public void releaseAllResources() {
// compareAndSet ensures the parent is notified only on the first release.
if (isReleased.compareAndSet(false, true)) {
// The view doesn't hold any resources and the parent cannot be restarted. Therefore,
// it's OK to notify about consumption as well.
parent.onConsumedSubpartition();
}
}
@Override
public boolean isReleased() {
return isReleased.get() || parent.isReleased();
}
@Override
public boolean nextBufferIsEvent() {
return parent.nextBufferIsEvent();
}
@Override
public boolean isAvailable() {
return parent.isAvailable();
}
}
当需要创建一个ResultSubpartition的消费者时,需要通过ResultPartitionManager来创建。ResultPartitionManager会管理当前TaskManager上所有Task的ResultPartition。
// Excerpt of ResultPartitionManager: registering partitions and creating subpartition views for consumers.
class ResultPartitionManager implements ResultPartitionProvider {
// Holds all registered ResultPartitions, using the hash table with multi-level mapping provided by Guava.
public final Table<ExecutionAttemptID, IntermediateResultPartitionID, ResultPartition>
registeredPartitions = HashBasedTable.create();
// When a Task registers with the NetworkEnvironment, all of its ResultPartitions are registered one by one.
public void registerResultPartition(ResultPartition partition) throws IOException {
synchronized (registeredPartitions) {
checkState(!isShutdown, "Result partition manager already shut down.");
ResultPartitionID partitionId = partition.getPartitionId();
ResultPartition previous = registeredPartitions.put(
partitionId.getProducerId(), partitionId.getPartitionId(), partition);
// put returns the previously mapped value, so a non-null result means a duplicate registration.
if (previous != null) {
throw new IllegalStateException("Result partition already registered.");
}
LOG.debug("Registered {}.", partition);
}
}
// Creates a ResultSubpartitionView on the specified ResultSubpartition, used to consume its data.
@Override
public ResultSubpartitionView createSubpartitionView(
ResultPartitionID partitionId,
int subpartitionIndex,
BufferAvailabilityListener availabilityListener) throws IOException {
synchronized (registeredPartitions) {
final ResultPartition partition = registeredPartitions.get(partitionId.getProducerId(),
partitionId.getPartitionId());
// NOTE(review): the lookup result is not null-checked in this excerpt; if nothing is registered under
// these ids the call below would throw a NullPointerException. Confirm whether a check was elided.
// Create the ResultSubpartitionView, which can be seen as the consumer of the ResultSubpartition.
return partition.createSubpartitionView(subpartitionIndex, availabilityListener);
}
}
}
Task的输入
前面已经介绍过,Task的输入被抽象为InputGate,而InputGate则由InputChannel组成,InputChannel和该Task需要消费的ResultSubpartition是一一对应的。
// Abstraction of a task's input, from which buffers and events are read.
public interface InputGate extends AutoCloseable {
int getNumberOfInputChannels();
String getOwningTaskName();
boolean isFinished();
// Requests consumption of the ResultPartitions.
void requestPartitions() throws IOException, InterruptedException;
/**
* Blocking call waiting for next {@link BufferOrEvent}.
* This is the blocking variant.
* @return {@code Optional.empty()} if {@link #isFinished()} returns true.
*/
Optional<BufferOrEvent> getNextBufferOrEvent() throws IOException, InterruptedException;
/**
* Poll the {@link BufferOrEvent}.
* This is the non-blocking variant.
* @return {@code Optional.empty()} if there is no data to return or if {@link #isFinished()} returns true.
*/
Optional<BufferOrEvent> pollNextBufferOrEvent() throws IOException, InterruptedException;
void sendTaskEvent(TaskEvent event) throws IOException;
void registerListener(InputGateListener listener);
int getPageSize();
}
Task通过循环调用InputGate.getNextBufferOrEvent方法阻塞地从Channel中获取输入数据,并将获取的数据交给它所封装的算子进行处理,这构成了一个Task的基本运行逻辑。InputGate有两个具体的实现,分别为SingleInputGate和UnionInputGate,UnionInputGate由多个SingleInputGate联合构成。
// Excerpt of SingleInputGate: a queue of input channels that have data, consumed by getNextBufferOrEvent
// and filled by the channels' "non-empty" notifications.
class SingleInputGate {
// All InputChannels contained in this InputGate.
private final Map<IntermediateResultPartitionID, InputChannel> inputChannels;
/** Channels, which notified this input gate about available data. */
// Queue of InputChannels; each of them has data available for consumption.
private final ArrayDeque<InputChannel> inputChannelsWithData = new ArrayDeque<>();
/**
* Buffer pool for incoming buffers. Incoming data from remote channels is copied to buffers
* from this pool.
*/
// Buffer pool used for receiving input.
private BufferPool bufferPool;
/** Global network buffer pool to request and recycle exclusive buffers (only for credit-based). */
// The global network buffer pool.
private NetworkBufferPool networkBufferPool;
/** Registered listener to forward buffer notifications to. */
private volatile InputGateListener inputGateListener;
// Returns the next buffer or event. When blocking is true the call waits until a channel has data;
// otherwise it returns Optional.empty() as soon as no channel is queued.
private Optional<BufferOrEvent> getNextBufferOrEvent(boolean blocking) throws IOException, InterruptedException {
if (hasReceivedAllEndOfPartitionEvents) {
return Optional.empty();
}
if (isReleased) {
throw new IllegalStateException("Released");
}
// First try to request the partitions via inputChannel.requestSubpartition(); the actual request is only
// executed once (for LocalInputChannel and RemoteInputChannel respectively).
requestPartitions();
InputChannel currentChannel;
boolean moreAvailable;
Optional<BufferAndAvailability> result = Optional.empty();
do {
synchronized (inputChannelsWithData) {
// Take a channel that has data from the inputChannelsWithData queue: the classic producer-consumer pattern.
while (inputChannelsWithData.size() == 0) {
if (isReleased) {
throw new IllegalStateException("Released");
}
if (blocking) {
inputChannelsWithData.wait(); // block and wait
}
else {
return Optional.empty();
}
}
currentChannel = inputChannelsWithData.remove();
enqueuedInputChannelsWithData.clear(currentChannel.getChannelIndex());
// Whether more data is available.
moreAvailable = !inputChannelsWithData.isEmpty();
}
result = currentChannel.getNextBuffer();
} while (!result.isPresent());
// this channel was now removed from the non-empty channels queue
// we re-add it in case it has more data, because in that case no "non-empty" notification
// will come for that channel
if (result.get().moreAvailable()) {
// If this channel has more data, put it back into the queue.
queueChannel(currentChannel);
moreAvailable = true;
}
final Buffer buffer = result.get().buffer();
if (buffer.isBuffer()) {
return Optional.of(new BufferOrEvent(buffer, currentChannel.getChannelIndex(), moreAvailable));
}
else {
final AbstractEvent event = EventSerializer.fromBuffer(buffer, getClass().getClassLoader());
// For an EndOfPartitionEvent: once all InputChannels have received this event,
// hasReceivedAllEndOfPartitionEvents is set to true and no more data can be obtained afterwards.
if (event.getClass() == EndOfPartitionEvent.class) {
channelsWithEndOfPartitionEvents.set(currentChannel.getChannelIndex());
if (channelsWithEndOfPartitionEvents.cardinality() == numberOfInputChannels) {
// Because of race condition between:
// 1. releasing inputChannelsWithData lock in this method and reaching this place
// 2. empty data notification that re-enqueues a channel
// we can end up with moreAvailable flag set to true, while we expect no more data.
checkState(!moreAvailable || !pollNextBufferOrEvent().isPresent());
moreAvailable = false;
hasReceivedAllEndOfPartitionEvents = true;
}
currentChannel.notifySubpartitionConsumed();
currentChannel.releaseAllResources();
}
return Optional.of(new BufferOrEvent(event, currentChannel.getChannelIndex(), moreAvailable));
}
}
// Callback invoked when an InputChannel has data.
void notifyChannelNonEmpty(InputChannel channel) {
queueChannel(checkNotNull(channel));
}
// Adds a new channel to the queue.
private void queueChannel(InputChannel channel) {
int availableChannels;
synchronized (inputChannelsWithData) {
// Check whether this channel is already in the queue.
if (enqueuedInputChannelsWithData.get(channel.getChannelIndex())) {
return;
}
availableChannels = inputChannelsWithData.size();
// Add it to the queue.
inputChannelsWithData.add(channel);
enqueuedInputChannelsWithData.set(channel.getChannelIndex());
if (availableChannels == 0) {
// If the queue was empty before this channel was added, wake up the waiting threads.
inputChannelsWithData.notifyAll();
}
}
if (availableChannels == 0) {
// If the queue was empty before this channel was added, notify the InputGateListener
// that this InputGate now has data. This happens outside the synchronized block.
InputGateListener listener = inputGateListener;
if (listener != null) {
listener.notifyInputGateNonEmpty(this);
}
}
}
// Requests the partitions.
@Override
public void requestPartitions() throws IOException, InterruptedException {
synchronized (requestLock) {
// Only request once.
if (!requestedPartitionsFlag) {
if (isReleased) {
throw new IllegalStateException("Already released.");
}
// Sanity checks
if (numberOfInputChannels != inputChannels.size()) {
throw new IllegalStateException("Bug in input gate setup logic: mismatch between " +
"number of total input channels and the currently set number of input " +
"channels.");
}
for (InputChannel inputChannel : inputChannels.values()) {
// Every channel requests its corresponding subpartition.
inputChannel.requestSubpartition(consumedSubpartitionIndex);
}
}
requestedPartitionsFlag = true;
}
}
}
SingleInputGate的逻辑还比较清晰,它通过内部维护的一个队列形成一个生产者-消费者的模型,当InputChannel中有数据时就加入到队列中,在需要获取数据时从队列中取出一个channel,获取channel中的数据。
UnionInputGate是多个SingleInputGate联合组成,它的内部有一个inputGatesWithData队列:
// Excerpt of UnionInputGate: unions several input gates and serves data from whichever gate has reported data.
public class UnionInputGate implements InputGate, InputGateListener {
/** The input gates to union. */
private final InputGate[] inputGates;
/** Gates, which notified this input gate about available data. */
private final ArrayDeque<InputGate> inputGatesWithData = new ArrayDeque<>();
@Override
public Optional<BufferOrEvent> getNextBufferOrEvent() throws IOException, InterruptedException {
if (inputGatesWithRemainingData.isEmpty()) {
return Optional.empty();
}
// Make sure to request the partitions, if they have not been requested before.
requestPartitions();
InputGateWithData inputGateWithData = waitAndGetNextInputGate();
InputGate inputGate = inputGateWithData.inputGate;
BufferOrEvent bufferOrEvent = inputGateWithData.bufferOrEvent;
if (bufferOrEvent.moreAvailable()) {
// This InputGate has more data, so put it back into the queue.
// this buffer or event was now removed from the non-empty gates queue
// we re-add it in case it has more data, because in that case no "non-empty" notification
// will come for that gate
queueInputGate(inputGate);
}
// An EndOfPartitionEvent from a finished gate removes that gate from the set of gates with remaining data.
if (bufferOrEvent.isEvent()
&& bufferOrEvent.getEvent().getClass() == EndOfPartitionEvent.class
&& inputGate.isFinished()) {
checkState(!bufferOrEvent.moreAvailable());
if (!inputGatesWithRemainingData.remove(inputGate)) {
throw new IllegalStateException("Couldn't find input gate in set of remaining " +
"input gates.");
}
}
// Set the channel index to identify the input channel (across all unioned input gates)
final int channelIndexOffset = inputGateToIndexOffsetMap.get(inputGate);
bufferOrEvent.setChannelIndex(channelIndexOffset + bufferOrEvent.getChannelIndex());
bufferOrEvent.setMoreAvailable(bufferOrEvent.moreAvailable() || inputGateWithData.moreInputGatesAvailable);
return Optional.of(bufferOrEvent);
}
// Blocks until some queued input gate actually yields a buffer or event, then returns both.
private InputGateWithData waitAndGetNextInputGate() throws IOException, InterruptedException {
while (true) {
InputGate inputGate;
boolean moreInputGatesAvailable;
synchronized (inputGatesWithData) {
// Wait on the inputGatesWithData queue: the classic producer-consumer model.
while (inputGatesWithData.size() == 0) {
inputGatesWithData.wait();
}
inputGate = inputGatesWithData.remove();
enqueuedInputGatesWithData.remove(inputGate);
moreInputGatesAvailable = enqueuedInputGatesWithData.size() > 0;
}
// In case of inputGatesWithData being inaccurate do not block on an empty inputGate, but just poll the data.
Optional<BufferOrEvent> bufferOrEvent = inputGate.pollNextBufferOrEvent();
if (bufferOrEvent.isPresent()) {
return new InputGateWithData(inputGate, bufferOrEvent.get(), moreInputGatesAvailable);
}
}
}
@Override
public void notifyInputGateNonEmpty(InputGate inputGate) {
queueInputGate(checkNotNull(inputGate));
}
}
InputGate相当于是对InputChannel的一层封装,实际数据的获取还是要依赖于InputChannel。
// Excerpt of InputChannel: the base class of the channels through which a SingleInputGate obtains its data.
public abstract class InputChannel {
protected final int channelIndex;
// The ID of the ResultPartition to consume.
protected final ResultPartitionID partitionId;
protected final SingleInputGate inputGate;
// Callback that tells the InputGate that this channel has data.
protected void notifyChannelNonEmpty() {
inputGate.notifyChannelNonEmpty(this);
}
// Requests the ResultSubpartition.
abstract void requestSubpartition(int subpartitionIndex) throws IOException, InterruptedException;
abstract Optional<BufferAndAvailability> getNextBuffer() throws IOException, InterruptedException;
abstract void sendTaskEvent(TaskEvent event) throws IOException;
abstract void notifySubpartitionConsumed() throws IOException;
abstract void releaseAllResources() throws IOException;
}
InputChannel的基本逻辑也比较简单,它的生命周期按照requestSubpartition(int subpartitionIndex)、getNextBuffer()和releaseAllResources()这样的顺序进行。
根据InputChannel消费的ResultPartition的位置,InputChannel有LocalInputChannel和RemoteInputChannel两种不同的实现,分别对应本地和远程数据交换。InputChannel还有一个实现类是UnknownInputChannel,相当于是还未确定ResultPartition位置的情况下的占位符,最终还是会更新为LocalInputChannel或是RemoteInputChannel。