[Flink Source Code] How the StreamGraph Is Built

The StreamGraph Generation Process

As mentioned earlier, the StreamGraph is ultimately produced by the StreamGraphGenerator class.
Its constructor only instantiates some fields; a few setter calls then apply configuration such as the state backend and the savepoint path. The actual StreamGraph is built by the subsequent generate method (getStreamGraphGenerator(transformations).generate()).
Before we walk through the StreamGraphGenerator.generate() source, there is a question you have probably been wondering about all along: how do the operators on a DataStream end up in a List<Transformation<?>>? Let's settle that first.
Take the flatMap operator as an example and look at the DataStream source:

DataStream.java

public <R> SingleOutputStreamOperator<R> flatMap(
            FlatMapFunction<T, R> flatMapper, TypeInformation<R> outputType) {
    return transform("Flat Map", outputType, new StreamFlatMap<>(clean(flatMapper)));
}

public <R> SingleOutputStreamOperator<R> transform(
            String operatorName,
            TypeInformation<R> outTypeInfo,
            OneInputStreamOperatorFactory<T, R> operatorFactory) {

    return doTransform(operatorName, outTypeInfo, operatorFactory);
}

protected <R> SingleOutputStreamOperator<R> doTransform(
            String operatorName,
            TypeInformation<R> outTypeInfo,
            StreamOperatorFactory<R> operatorFactory) {

    // read the output type of the input Transform to coax out errors about MissingTypeInfo
    transformation.getOutputType();

    OneInputTransformation<T, R> resultTransform =
            new OneInputTransformation<>(
                    this.transformation,
                    operatorName,
                    operatorFactory,
                    outTypeInfo,
                    environment.getParallelism());

    @SuppressWarnings({"unchecked", "rawtypes"})
    SingleOutputStreamOperator<R> returnStream =
            new SingleOutputStreamOperator(environment, resultTransform);

    getExecutionEnvironment().addOperator(resultTransform);

    return returnStream;
}

Tracing the call chain straight through, the last line is what we were after: getExecutionEnvironment().addOperator(resultTransform). getExecutionEnvironment returns the current execution environment, a StreamExecutionEnvironment object.
So let's go back to the StreamExecutionEnvironment class and find the addOperator method:

StreamExecutionEnvironment.java

public void addOperator(Transformation<?> transformation) {
    Preconditions.checkNotNull(transformation, "transformation must not be null.");
    this.transformations.add(transformation);
}

Now it all makes sense!

To summarize:

  • The flatMap call wraps the user-defined FlatMapFunction in a StreamFlatMap operator
  • The StreamFlatMap operator is then wrapped in a OneInputTransformation
  • That transformation is stored in the env
  • When env.execute is called, the transformation list is traversed to build the StreamGraph

The layered structure is illustrated below:
[Figure: layering diagram (image omitted)]
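
To see the whole handoff end to end, here is a minimal runnable sketch (assuming a standard flink-streaming-java dependency; the pipeline and class name are illustrative). It builds a tiny pipeline, then asks the environment for the StreamGraph that StreamGraphGenerator produces and prints its JSON plan:

import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.Collector;

public class TransformationDemo {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        env.fromElements("hello flink", "stream graph")
                // flatMap wraps the user function in a StreamFlatMap operator,
                // wraps that in a OneInputTransformation, and registers it with env
                .flatMap((FlatMapFunction<String, String>) (line, out) -> {
                    for (String word : line.split(" ")) {
                        out.collect(word);
                    }
                })
                .returns(String.class)
                .print();

        // getStreamGraph() runs StreamGraphGenerator over the accumulated
        // transformation list and returns the resulting StreamGraph
        System.out.println(env.getStreamGraph().getStreamingPlanAsJSON());
    }
}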

Now back to the main topic: the generate() method.

public StreamGraph generate() {
    // 1. Create and configure the StreamGraph
    streamGraph = new StreamGraph(executionConfig, checkpointConfig, savepointRestoreSettings);
    streamGraph.setEnableCheckpointsAfterTasksFinish(
            configuration.get(
                    ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH));
    shouldExecuteInBatchMode = shouldExecuteInBatchMode();
    configureStreamGraph(streamGraph);

    // Used to record transformations that have already been translated
    alreadyTransformed = new IdentityHashMap<>();

    // 2. Traverse all transformations and translate them into the graph
    for (Transformation<?> transformation : transformations) {
        transform(transformation);
    }

    streamGraph.setSlotSharingGroupResource(slotSharingGroupResources);

    setFineGrainedGlobalStreamExchangeMode(streamGraph);

    for (StreamNode node : streamGraph.getStreamNodes()) {
        if (node.getInEdges().stream().anyMatch(this::shouldDisableUnalignedCheckpointing)) {
            for (StreamEdge edge : node.getInEdges()) {
                edge.setSupportsUnalignedCheckpoints(false);
            }
        }
    }

    final StreamGraph builtStreamGraph = streamGraph;

    alreadyTransformed.clear();
    alreadyTransformed = null;
    streamGraph = null;

    return builtStreamGraph;
}

The method has two main steps:

  1. Create and configure the StreamGraph
  2. Traverse all Transformations and translate each one

The core work happens in transform(transformation):

private Collection<Integer> transform(Transformation<?> transform) {
    // 1. If this transformation has already been translated, return its ids
    // directly; this check matters because the graph may contain loops
    if (alreadyTransformed.containsKey(transform)) {
        return alreadyTransformed.get(transform);
    }

    LOG.debug("Transforming " + transform);

    // 2. Set the max parallelism
    if (transform.getMaxParallelism() <= 0) {

        // if the max parallelism hasn't been set, then first use the job wide max parallelism
        // from the ExecutionConfig.
        int globalMaxParallelismFromConfig = executionConfig.getMaxParallelism();
        if (globalMaxParallelismFromConfig > 0) {
            transform.setMaxParallelism(globalMaxParallelismFromConfig);
        }
    }

    // 3. Set the slot sharing group
    transform
            .getSlotSharingGroup()
            .ifPresent(
                    slotSharingGroup -> {
                        final ResourceSpec resourceSpec =
                                SlotSharingGroupUtils.extractResourceSpec(slotSharingGroup);
                        if (!resourceSpec.equals(ResourceSpec.UNKNOWN)) {
                            slotSharingGroupResources.compute(
                                    slotSharingGroup.getName(),
                                    (name, profile) -> {
                                        if (profile == null) {
                                            return ResourceProfile.fromResourceSpec(
                                                    resourceSpec, MemorySize.ZERO);
                                        } else if (!ResourceProfile.fromResourceSpec(
                                                        resourceSpec, MemorySize.ZERO)
                                                .equals(profile)) {
                                            throw new IllegalArgumentException(
                                                    "The slot sharing group "
                                                            + slotSharingGroup.getName()
                                                            + " has been configured with two different resource spec.");
                                        } else {
                                            return profile;
                                        }
                                    });
                        }
                    });

    // call at least once to trigger exceptions about MissingTypeInfo
    // 4. Query the output type; a MissingTypeInfo surfaces as an exception here
    transform.getOutputType();

    // 5. Look up and invoke the concrete translator
    @SuppressWarnings("unchecked")
    final TransformationTranslator<?, Transformation<?>> translator =
            (TransformationTranslator<?, Transformation<?>>)
                    translatorMap.get(transform.getClass());

    Collection<Integer> transformedIds;
    if (translator != null) {
        transformedIds = translate(translator, transform);
    } else {
        transformedIds = legacyTransform(transform);
    }

    // need this check because the iterate transformation adds itself before
    // transforming the feedback edges
    // 6. Record the translated transform in alreadyTransformed
    if (!alreadyTransformed.containsKey(transform)) {
        alreadyTransformed.put(transform, transformedIds);
    }

    return transformedIds;
}

The steps are:

  1. If a transformation has already been translated, return its ids directly; this check is needed because the graph may contain loops
  2. Set the max parallelism
  3. Set the SlotSharingGroup
  4. Query the output type, which throws an exception if the type could not be inferred (MissingTypeInfo)
  5. Look up the concrete translator; the built-in translators are kept in translatorMap
  6. Record the translated transformation in alreadyTransformed (see the IdentityHashMap sketch below)
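
A note on step 1: alreadyTransformed is an IdentityHashMap, so "already transformed" means the same object reference, not equals() equality. A minimal plain-Java sketch of the difference:

import java.util.IdentityHashMap;
import java.util.Map;

public class IdentityMapDemo {
    public static void main(String[] args) {
        Map<String, Integer> byIdentity = new IdentityHashMap<>();
        String a = new String("map");
        String b = new String("map"); // equals(a), but a different object

        byIdentity.put(a, 1);
        byIdentity.put(b, 2);

        // Both entries survive: keys are compared with ==, not equals().
        // StreamGraphGenerator relies on this so that two distinct
        // Transformation objects are never conflated during deduplication.
        System.out.println(byIdentity.size()); // 2
    }
}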

The contents of translatorMap are built in a static initializer block:

static {
    @SuppressWarnings("rawtypes")
    Map<Class<? extends Transformation>, TransformationTranslator<?, ? extends Transformation>>
            tmp = new HashMap<>();
    tmp.put(OneInputTransformation.class, new OneInputTransformationTranslator<>());
    tmp.put(TwoInputTransformation.class, new TwoInputTransformationTranslator<>());
    tmp.put(MultipleInputTransformation.class, new MultiInputTransformationTranslator<>());
    tmp.put(KeyedMultipleInputTransformation.class, new MultiInputTransformationTranslator<>());
    tmp.put(SourceTransformation.class, new SourceTransformationTranslator<>());
    tmp.put(SinkTransformation.class, new SinkTransformationTranslator<>());
    tmp.put(LegacySinkTransformation.class, new LegacySinkTransformationTranslator<>());
    tmp.put(LegacySourceTransformation.class, new LegacySourceTransformationTranslator<>());
    tmp.put(UnionTransformation.class, new UnionTransformationTranslator<>());
    tmp.put(PartitionTransformation.class, new PartitionTransformationTranslator<>());
    tmp.put(SideOutputTransformation.class, new SideOutputTransformationTranslator<>());
    tmp.put(ReduceTransformation.class, new ReduceTransformationTranslator<>());
    tmp.put(
            TimestampsAndWatermarksTransformation.class,
            new TimestampsAndWatermarksTransformationTranslator<>());
    tmp.put(BroadcastStateTransformation.class, new BroadcastStateTransformationTranslator<>());
    tmp.put(
            KeyedBroadcastStateTransformation.class,
            new KeyedBroadcastStateTransformationTranslator<>());
    tmp.put(CacheTransformation.class, new CacheTransformationTranslator<>());
    translatorMap = Collections.unmodifiableMap(tmp);
}
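
The shape here is class-keyed dispatch: a handler is looked up by the element's concrete runtime class (an exact match, not an instanceof walk), with a legacy fallback when no entry exists. A minimal sketch with illustrative types:

import java.util.HashMap;
import java.util.Map;

public class DispatchSketch {
    interface Translator { void translate(Object t); }

    static final Map<Class<?>, Translator> TRANSLATORS = new HashMap<>();
    static {
        TRANSLATORS.put(String.class, t -> System.out.println("string: " + t));
        TRANSLATORS.put(Integer.class, t -> System.out.println("int: " + t));
    }

    static void transform(Object t) {
        // exact runtime-class lookup, mirroring translatorMap.get(transform.getClass())
        Translator translator = TRANSLATORS.get(t.getClass());
        if (translator != null) {
            translator.translate(t);
        } else {
            System.out.println("legacy path for " + t.getClass().getSimpleName());
        }
    }

    public static void main(String[] args) {
        transform("hello"); // string: hello
        transform(42);      // int: 42
        transform(3.14);    // legacy path for Double
    }
}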

The logic that invokes the concrete translator looks like this:

private Collection<Integer> translate(
            final TransformationTranslator<?, Transformation<?>> translator,
            final Transformation<?> transform) {
    checkNotNull(translator);
    checkNotNull(transform);

    final List<Collection<Integer>> allInputIds = getParentInputIds(transform.getInputs());

    // the recursive call might have already transformed this
    if (alreadyTransformed.containsKey(transform)) {
        return alreadyTransformed.get(transform);
    }

    final String slotSharingGroup =
            determineSlotSharingGroup(
                    transform.getSlotSharingGroup().isPresent()
                            ? transform.getSlotSharingGroup().get().getName()
                            : null,
                    allInputIds.stream()
                            .flatMap(Collection::stream)
                            .collect(Collectors.toList()));

    final TransformationTranslator.Context context =
            new ContextImpl(this, streamGraph, slotSharingGroup, configuration);

    return shouldExecuteInBatchMode
            ? translator.translateForBatch(transform, context)
            : translator.translateForStreaming(transform, context);
}

The most interesting call here is getParentInputIds(transform.getInputs()). Looking at its source, something neat happens:

private List<Collection<Integer>> getParentInputIds(
            @Nullable final Collection<Transformation<?>> parentTransformations) {
    final List<Collection<Integer>> allInputIds = new ArrayList<>();
    if (parentTransformations == null) {
        return allInputIds;
    }

    for (Transformation<?> transformation : parentTransformations) {
        allInputIds.add(transform(transformation));
    }
    return allInputIds;
}

As you can see, the method loops over the parent transformations and calls transform on each one, recursing. The recursion bottoms out when a transformation has no inputs, i.e. at a source. In other words, translating any transformation first walks all the way up its input chain, so the StreamGraph is built from the sources downward. A minimal sketch of this pattern follows.
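
Here is a self-contained sketch of that pattern, with an illustrative Node type standing in for Transformation: memoized depth-first recursion that assigns ids to parents before children, exactly once per node even when nodes share an upstream:

import java.util.*;

public class DagTraversalSketch {
    static class Node {
        final String name;
        final List<Node> inputs;
        Node(String name, Node... inputs) {
            this.name = name;
            this.inputs = Arrays.asList(inputs);
        }
    }

    private final Map<Node, Integer> alreadyTransformed = new IdentityHashMap<>();
    private int nextId = 0;

    int transform(Node node) {
        if (alreadyTransformed.containsKey(node)) {
            return alreadyTransformed.get(node); // shared upstream: translate once
        }
        for (Node input : node.inputs) {
            transform(input); // recurse first: parents get ids before children
        }
        int id = nextId++;
        alreadyTransformed.put(node, id);
        System.out.println(node.name + " -> id " + id);
        return id;
    }

    public static void main(String[] args) {
        Node source = new Node("source");
        Node flatMap = new Node("flatMap", source);
        Node sink = new Node("sink", flatMap);
        new DagTraversalSketch().transform(sink); // prints source, flatMap, sink
    }
}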

Now take the earlier flatMap operator as an example. From the source above we know flatMap is handled by OneInputTransformationTranslator.
With the flatMap transformation as the argument, the call chain is: transform -> translate -> getParentInputIds -> iterate over flatMap's inputs and call transform on each, until a transformation with no inputs is reached.
That brings us to the source (here a Collection Source). Since recent Flink versions unify stream and batch execution, the translation branches into two variants at this point.
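
Which branch runs is decided by shouldExecuteInBatchMode, which considers the configured execution mode and source boundedness. A minimal sketch of forcing the batch branch (assumes Flink 1.12+, where setRuntimeExecutionMode was introduced):

import org.apache.flink.api.common.RuntimeExecutionMode;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class ExecutionModeDemo {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        // With BATCH set (and only bounded sources), shouldExecuteInBatchMode
        // evaluates to true and translate() dispatches to translateForBatch
        env.setRuntimeExecutionMode(RuntimeExecutionMode.BATCH);

        env.fromElements("a", "b", "c").print();
        env.execute("batch-mode-demo");
    }
}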
Here we only follow the streaming path, the translateForStreaming method:

SimpleTransformationTranslator.java

public final Collection<Integer> translateForStreaming(
            final T transformation, final Context context) {
    checkNotNull(transformation);
    checkNotNull(context);

    final Collection<Integer> transformedIds =
            translateForStreamingInternal(transformation, context);
    configure(transformation, context);

    return transformedIds;
}
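
Note the shape: translateForStreaming is final, fixes the order (translate the internal step, then configure), and leaves only the internal step to subclasses, the classic template-method pattern. A minimal, non-Flink sketch:

public class TemplateMethodSketch {

    abstract static class TranslatorTemplate<T> {
        public final int translateForStreaming(T transformation) {
            int id = translateForStreamingInternal(transformation); // subclass hook
            configure(transformation);                              // shared epilogue
            return id;
        }

        protected abstract int translateForStreamingInternal(T transformation);

        private void configure(T transformation) {
            System.out.println("configuring " + transformation);
        }
    }

    static class PrintTranslator extends TranslatorTemplate<String> {
        @Override
        protected int translateForStreamingInternal(String transformation) {
            System.out.println("translating " + transformation);
            return transformation.hashCode();
        }
    }

    public static void main(String[] args) {
        new PrintTranslator().translateForStreaming("flatMap");
    }
}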

For the source, this dispatches to the translateForStreamingInternal method of LegacySourceTransformationTranslator:

protected Collection<Integer> translateForStreamingInternal(
            final LegacySourceTransformation<OUT> transformation, final Context context) {
    return translateInternal(transformation, context);
}

private Collection<Integer> translateInternal(
        final LegacySourceTransformation<OUT> transformation, final Context context) {
    checkNotNull(transformation);
    checkNotNull(context);

    final StreamGraph streamGraph = context.getStreamGraph();
    final String slotSharingGroup = context.getSlotSharingGroup();
    final int transformationId = transformation.getId();
    final ExecutionConfig executionConfig = streamGraph.getExecutionConfig();
    
    // Add the source operator
    streamGraph.addLegacySource(
            transformationId,
            slotSharingGroup,
            transformation.getCoLocationGroupKey(),
            transformation.getOperatorFactory(),
            null,
            transformation.getOutputType(),
            "Source: " + transformation.getName());

    if (transformation.getOperatorFactory() instanceof InputFormatOperatorFactory) {
        streamGraph.setInputFormat(
                transformationId,
                ((InputFormatOperatorFactory<OUT>) transformation.getOperatorFactory())
                        .getInputFormat());
    }
    
    // Set the parallelism
    final int parallelism =
            transformation.getParallelism() != ExecutionConfig.PARALLELISM_DEFAULT
                    ? transformation.getParallelism()
                    : executionConfig.getParallelism();
    streamGraph.setParallelism(transformationId, parallelism);
    streamGraph.setMaxParallelism(transformationId, transformation.getMaxParallelism());

    return Collections.singleton(transformationId);
}

This method:

  1. Adds the source operator to the StreamGraph
  2. Sets the parallelism and max parallelism on the StreamGraph

Here are addLegacySource and the addOperator chain it calls:

public <IN, OUT> void addLegacySource(
        Integer vertexID,
        @Nullable String slotSharingGroup,
        @Nullable String coLocationGroup,
        StreamOperatorFactory<OUT> operatorFactory,
        TypeInformation<IN> inTypeInfo,
        TypeInformation<OUT> outTypeInfo,
        String operatorName) {
    addOperator(
            vertexID,
            slotSharingGroup,
            coLocationGroup,
            operatorFactory,
            inTypeInfo,
            outTypeInfo,
            operatorName);
    sources.add(vertexID);
}

public <IN, OUT> void addOperator(
        Integer vertexID,
        @Nullable String slotSharingGroup,
        @Nullable String coLocationGroup,
        StreamOperatorFactory<OUT> operatorFactory,
        TypeInformation<IN> inTypeInfo,
        TypeInformation<OUT> outTypeInfo,
        String operatorName) {
    Class<? extends TaskInvokable> invokableClass =
            operatorFactory.isStreamSource()
                    ? SourceStreamTask.class
                    : OneInputStreamTask.class;
    addOperator(
            vertexID,
            slotSharingGroup,
            coLocationGroup,
            operatorFactory,
            inTypeInfo,
            outTypeInfo,
            operatorName,
            invokableClass);
}

private <IN, OUT> void addOperator(
        Integer vertexID,
        @Nullable String slotSharingGroup,
        @Nullable String coLocationGroup,
        StreamOperatorFactory<OUT> operatorFactory,
        TypeInformation<IN> inTypeInfo,
        TypeInformation<OUT> outTypeInfo,
        String operatorName,
        Class<? extends TaskInvokable> invokableClass) {

    // Add the StreamNode: create the node and put it into the Map<Integer, StreamNode>
    addNode(
            vertexID,
            slotSharingGroup,
            coLocationGroup,
            invokableClass,
            operatorFactory,
            operatorName);
    // Configure the input/output serializers for this transformation
    setSerializers(vertexID, createSerializer(inTypeInfo), null, createSerializer(outTypeInfo));

    // Set the outputType
    if (operatorFactory.isOutputTypeConfigurable() && outTypeInfo != null) {
        // sets the output type which must be know at StreamGraph creation time
        operatorFactory.setOutputType(outTypeInfo, executionConfig);
    }

    // Set the inputType
    if (operatorFactory.isInputTypeConfigurable()) {
        operatorFactory.setInputType(inTypeInfo, executionConfig);
    }

    if (LOG.isDebugEnabled()) {
        LOG.debug("Vertex: {}", vertexID);
    }
}

Back in SimpleTransformationTranslator.translateForStreaming, the next step is configure(transformation, context):

private void configure(final T transformation, final Context context) {
    final StreamGraph streamGraph = context.getStreamGraph();
    final int transformationId = transformation.getId();

    StreamGraphUtils.configureBufferTimeout(
            streamGraph, transformationId, transformation, context.getDefaultBufferTimeout());
    
    // Set the operator uid
    if (transformation.getUid() != null) {
        streamGraph.setTransformationUID(transformationId, transformation.getUid());
    }
    if (transformation.getUserProvidedNodeHash() != null) {
        streamGraph.setTransformationUserHash(
                transformationId, transformation.getUserProvidedNodeHash());
    }

    StreamGraphUtils.validateTransformationUid(streamGraph, transformation);
    
    // Set the resources (after the uid validation above)
    if (transformation.getMinResources() != null
            && transformation.getPreferredResources() != null) {
        streamGraph.setResources(
                transformationId,
                transformation.getMinResources(),
                transformation.getPreferredResources());
    }

    final StreamNode streamNode = streamGraph.getStreamNode(transformationId);
    if (streamNode != null) {
        validateUseCaseWeightsNotConflict(
                streamNode.getManagedMemoryOperatorScopeUseCaseWeights(),
                transformation.getManagedMemoryOperatorScopeUseCaseWeights());
        streamNode.setManagedMemoryUseCaseWeights(
                transformation.getManagedMemoryOperatorScopeUseCaseWeights(),
                transformation.getManagedMemorySlotScopeUseCases());
        if (null != transformation.getDescription()) {
            streamNode.setOperatorDescription(transformation.getDescription());
        }
    }
}

This sets the uid, the user-provided node hash, resources, and so on. For context, these values come from the user-facing DataStream API; the uid, name, and parallelism set there are what configure() later reads off the transformation. A minimal sketch (the strings and figures are illustrative):
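
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.Collector;

public class UidDemo {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        env.fromElements("to be or not to be")
                .flatMap((FlatMapFunction<String, String>) (line, out) -> {
                    for (String word : line.split(" ")) {
                        out.collect(word);
                    }
                })
                .returns(String.class)
                .uid("line-splitter") // stored on the transformation, read back in configure()
                .name("Split lines")  // display name of the resulting StreamNode
                .setParallelism(2)    // per-transformation parallelism override
                .print();

        env.execute("uid-demo");
    }
}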
That completes the translation of the source. One important piece is still outstanding, though: when we discussed addOperator, it called addNode to add the StreamNode. Here is its source:

protected StreamNode addNode(
            Integer vertexID,
            @Nullable String slotSharingGroup,
            @Nullable String coLocationGroup,
            Class<? extends TaskInvokable> vertexClass,
            StreamOperatorFactory<?> operatorFactory,
            String operatorName) {

    if (streamNodes.containsKey(vertexID)) {
        throw new RuntimeException("Duplicate vertexID " + vertexID);
    }

    StreamNode vertex =
            new StreamNode(
                    vertexID,
                    slotSharingGroup,
                    coLocationGroup,
                    operatorFactory,
                    operatorName,
                    vertexClass);

    streamNodes.put(vertexID, vertex);

    return vertex;
}

A graph consists of nodes and edges. The translate machinery above determined the parent-child relationships between transformations; addNode is where a node is actually materialized as a StreamNode and stored in the Map<Integer, StreamNode>. StreamNode itself is a plain entity class, so its source is omitted.
Finally, we need to understand how these StreamNodes get connected by edges to form the graph.
To answer that, we return to the flatMap implementation.
We know that DataStream wraps the StreamFlatMap operator in a OneInputTransformation; the translation of that transformation into the StreamGraph is done by AbstractOneInputTransformationTranslator. Let's look at its source:

AbstractOneInputTransformationTranslator.java

protected Collection<Integer> translateInternal(
        final Transformation<OUT> transformation,
        final StreamOperatorFactory<OUT> operatorFactory,
        final TypeInformation<IN> inputType,
        @Nullable final KeySelector<IN, ?> stateKeySelector,
        @Nullable final TypeInformation<?> stateKeyType,
        final Context context) {
    checkNotNull(transformation);
    checkNotNull(operatorFactory);
    checkNotNull(inputType);
    checkNotNull(context);

    final StreamGraph streamGraph = context.getStreamGraph();
    final String slotSharingGroup = context.getSlotSharingGroup();
    final int transformationId = transformation.getId();
    final ExecutionConfig executionConfig = streamGraph.getExecutionConfig();

    // Add the StreamNode
    streamGraph.addOperator(
            transformationId,
            slotSharingGroup,
            transformation.getCoLocationGroupKey(),
            operatorFactory,
            inputType,
            transformation.getOutputType(),
            transformation.getName());

    if (stateKeySelector != null) {
        TypeSerializer<?> keySerializer = stateKeyType.createSerializer(executionConfig);
        streamGraph.setOneInputStateKey(transformationId, stateKeySelector, keySerializer);
    }

    int parallelism =
            transformation.getParallelism() != ExecutionConfig.PARALLELISM_DEFAULT
                    ? transformation.getParallelism()
                    : executionConfig.getParallelism();
    streamGraph.setParallelism(transformationId, parallelism);
    streamGraph.setMaxParallelism(transformationId, transformation.getMaxParallelism());

    final List<Transformation<?>> parentTransformations = transformation.getInputs();
    checkState(
            parentTransformations.size() == 1,
            "Expected exactly one input transformation but found "
                    + parentTransformations.size());

    // Add the StreamEdges
    for (Integer inputId : context.getStreamNodeIds(parentTransformations.get(0))) {
        streamGraph.addEdge(inputId, transformationId, 0);
    }

    return Collections.singleton(transformationId);
}

Like the source translator's translateInternal above, this first adds a StreamNode via streamGraph.addOperator, and then connects the node to its upstream via streamGraph.addEdge(inputId, transformationId, 0).
Now the picture is complete: for each operator, DataStream's underlying transformation is translated into a StreamNode and wired into the StreamGraph with edges; once every transformation has been processed, the StreamGraph is fully built. A small inspection sketch follows.
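
To make the result tangible, here is a minimal sketch that builds a two-operator pipeline and prints each StreamNode with its incoming edges (the pipeline is illustrative; the getters used are public on StreamGraph and StreamNode):

import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.graph.StreamNode;

public class StreamGraphInspection {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        env.fromElements(1, 2, 3)
                .map(i -> i * 2)
                .returns(Integer.class)
                .print();

        StreamGraph graph = env.getStreamGraph();
        for (StreamNode node : graph.getStreamNodes()) {
            // one line per node: id, name, parallelism, and its incoming edges
            System.out.println(node.getId() + " " + node.getOperatorName()
                    + " (parallelism=" + node.getParallelism() + ")"
                    + " inEdges=" + node.getInEdges());
        }
    }
}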

To summarize:

  • Flink generates the StreamGraph via the generate method of StreamGraphGenerator
  • generate traverses the Transformation list and recursively translates each entry, discovering the parent-child relationships between operators
  • Bottoming out at the source, it adds the source, parallelism, and StreamNode entities to the DAG, building up the graph

A Worked Example

The following simple streaming program reads lines from a source, splits them into words, filters, and prints; it also contains a logical transformation: a random-partitioning shuffle. Let's trace how it turns into a StreamGraph.

DataStream<String> text = env.socketTextStream(hostName, port);
text.flatMap(new LineSplitter()).shuffle().filter(new HelloFilter()).print();
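
LineSplitter and HelloFilter are not shown in the original; hypothetical implementations, given only so the example is self-contained, might look like this:

import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.util.Collector;

// Hypothetical user functions for the example above (names from the snippet, logic assumed)
class LineSplitter implements FlatMapFunction<String, String> {
    @Override
    public void flatMap(String line, Collector<String> out) {
        for (String word : line.split("\\s+")) {
            out.collect(word);
        }
    }
}

class HelloFilter implements FilterFunction<String> {
    @Override
    public boolean filter(String word) {
        return word.startsWith("Hello");
    }
}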

First, a tree of transformations is built up in the env, stored as a List<Transformation<?>>. The structure looks like this:

[Figure: transformation tree (image omitted)]

The * symbol denotes the input pointer to the upstream transformation, which is what links the transformations into a tree. Then StreamGraphGenerator's generate() method turns this into a StreamGraph, recursing bottom-up through each transformation, so the processing order is Source -> FlatMap -> Shuffle -> Filter -> Sink.

[Figure: resulting StreamGraph (image omitted)]

The processing flow:

  1. Source is handled first, producing the Source StreamNode
  2. FlatMap is handled next, producing the FlatMap StreamNode and a StreamEdge connecting Source to FlatMap. Because the two ends have different parallelism, the edge uses Rebalance partitioning
  3. Shuffle is handled next; since it is a logical transformation, no actual node is created, and the partitioner is stashed in virtualPartitionNodes
  4. When Filter is handled, the Filter StreamNode is created. Its upstream is the shuffle, so the generator looks through it to FlatMap, creates a StreamEdge connecting FlatMap to Filter, and records the ShufflePartitioner on that StreamEdge
  5. Finally Sink is handled: its StreamNode is created and a StreamEdge connects it to Filter. Since the two ends have the same parallelism, Forward partitioning is chosen

A sketch that prints the partitioner chosen for each edge follows.
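
To verify those choices, here is a sketch that dumps each edge's partitioner from the built graph (it reuses the hypothetical LineSplitter and HelloFilter above; getOutEdges and getPartitioner are public getters on StreamNode and StreamEdge):

import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.graph.StreamEdge;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.graph.StreamNode;

public class PartitionerInspection {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        env.socketTextStream("localhost", 9999)
                .flatMap(new LineSplitter())
                .shuffle()
                .filter(new HelloFilter())
                .print();

        StreamGraph graph = env.getStreamGraph();
        for (StreamNode node : graph.getStreamNodes()) {
            for (StreamEdge edge : node.getOutEdges()) {
                // prints FORWARD, REBALANCE, or SHUFFLE depending on the edge
                System.out.println(edge.getSourceId() + " -> " + edge.getTargetId()
                        + " via " + edge.getPartitioner());
            }
        }
    }
}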

Finally, here is the resulting StreamGraph as visualized in the web UI:
[Figure: StreamGraph rendered in the Flink web UI (image omitted)]
