Keras源码解析 _map_graph_network

_map_graph_network

验证网络拓扑并收集其层、节点、层depth、节点depth信息。

  • 该函数包含子函数build_map来真正构建层图的map,通过递归调用逐步构建layer_indicesnodes_in_decreasing_depthnetwork_nodes
  • network_nodes是“图可见的节点”,是指只跟当前network相关的节点
  • finished_nodes保存遍历过的节点,nodes_in_progress保存正在遍历中的节点,防止出现环
  • nodes_in_decreasing_depth自输入到输出保存节点的顺序,列表越往后对应的depth越小
  • depth生成时,从后面的节点开始遍历,即对应上面的decreasing,当前node与outbound_layer的depth相同

具体的部分可以看注释,depth(深度)用于对Node和Layer进行描述。按照depth的顺序,获得经过排序的self.layers_by_depth和self.nodes_by_depth,防止有向图中出现环。depth主要应用在network网络的run_internal_graph,该函数构建模型图,并根据输入返回输出。

def _map_graph_network(inputs, outputs):
    
    # 层图包含的节点的集合
    # 层中的节点信息并不一定都是当前图的,这一点可以从上述Node源码分析可以得出
    network_nodes = set()  # ids of all nodes relevant to the Network
    nodes_depths = {}  # dict {node: depth value}
    layers_depths = {}  # dict {layer: depth value}
    layer_indices = {}  # dict {layer: index in traversal}
    nodes_in_decreasing_depth = []
	# 生成层图的映射
    def build_map(tensor,
                  finished_nodes,
                  nodes_in_progress,
                  layer,
                  node_index,
                  tensor_index):
        # 获取层对应的节点,前面提到一个层可能有多个节点
        node = layer._inbound_nodes[node_index]

        # 说明是环 报错
        if node in nodes_in_progress:
            raise ValueError('The tensor ' + str(tensor) + ' at layer "' +
                             layer.name + '" is part of a cycle.')

        # Don't repeat work for shared subgraphs
        if node in finished_nodes:
            return
		# 拼接层和node_index作为唯一标识
        node_key = _make_node_key(layer.name, node_index)
        network_nodes.add(node_key)

        # 存储遍历的layer顺序
        if layer not in layer_indices:
            layer_indices[layer] = len(layer_indices)
		# 表明正在遍历,防止后面遍历为一个环,无限递归
        nodes_in_progress.add(node)
		# 可能多个输入层
        for i in range(len(node.inbound_layers)):
            x = node.input_tensors[i]
            layer = node.inbound_layers[i]
            node_index = node.node_indices[i]
            tensor_index = node.tensor_indices[i]
            build_map(x, finished_nodes, nodes_in_progress, layer,
                      node_index, tensor_index)

        finished_nodes.add(node)
        nodes_in_progress.remove(node)
        nodes_in_decreasing_depth.append(node)

    finished_nodes = set()
    nodes_in_progress = set()
    for x in outputs:
        layer, node_index, tensor_index = x._keras_history
        build_map(x, finished_nodes, nodes_in_progress,
                  layer=layer,
                  node_index=node_index,
                  tensor_index=tensor_index)
	# 注意nodes_in_decreasing_depth添加node的顺序是从input开始的,当然需要考虑到不同的输出和某一层不同的输入
	# 这里reverse,反转了列表,所以变成从后面的节点到前面的节点
    for node in reversed(nodes_in_decreasing_depth):
        # 默认如果没有outbound nodes则depth为0
        depth = nodes_depths.setdefault(node, 0)

        # Update the depth of the corresponding layer
        previous_depth = layers_depths.get(node.outbound_layer, 0)
        # 如果outbound_layer有一个更高的depth,则使用这个更高的depth
        # 这点用于有不同输入depth等级的共享层
        depth = max(depth, previous_depth)
        layers_depths[node.outbound_layer] = depth
        nodes_depths[node] = depth

        # 更新inbound的节点depth,node的depth是它连接的所有层的最大depth
        for i in range(len(node.inbound_layers)):
            inbound_layer = node.inbound_layers[i]
            node_index = node.node_indices[i]
            inbound_node = inbound_layer._inbound_nodes[node_index]
            previous_depth = nodes_depths.get(inbound_node, 0)
            nodes_depths[inbound_node] = max(depth + 1, previous_depth)

    # 构建一个字典 {depth: list of nodes with this depth}
    nodes_by_depth = {}
    for node, depth in nodes_depths.items():
        if depth not in nodes_by_depth:
            nodes_by_depth[depth] = []
        nodes_by_depth[depth].append(node)

    #构建一个字典 {depth: list of layers with this depth}
    layers_by_depth = {}
    for layer, depth in layers_depths.items():
        if depth not in layers_by_depth:
            layers_by_depth[depth] = []
        layers_by_depth[depth].append(layer)

    # Get sorted list of layer depths.
    depth_keys = list(layers_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Set self.layers and self._layers_by_depth.
    layers = []
    for depth in depth_keys:
        layers_for_depth = layers_by_depth[depth]
        # Network.layers needs to have a deterministic order:
        # here we order them by traversal order.
        layers_for_depth.sort(key=lambda x: layer_indices[x])
        layers.extend(layers_for_depth)

    # Get sorted list of node depths.
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Check that all tensors required are computable.
    # computable_tensors: all tensors in the graph
    # that can be computed from the inputs provided.
    computable_tensors = []
    for x in inputs:
        computable_tensors.append(x)

    layers_with_complete_input = []  # To provide a better error msg.
    for depth in depth_keys:
        for node in nodes_by_depth[depth]:
            layer = node.outbound_layer
            if layer:
                for x in node.input_tensors:
                    if id(x) not in [id(ct) for ct in computable_tensors]:
                        raise ValueError('Graph disconnected: '
                                         'cannot obtain value for tensor ' +
                                         str(x) + ' at layer "' +
                                         layer.name + '". '
                                         'The following previous layers '
                                         'were accessed without issue: ' +
                                         str(layers_with_complete_input))
                for x in node.output_tensors:
                    computable_tensors.append(x)
                layers_with_complete_input.append(layer.name)

    # Ensure name unicity, which will be crucial for serialization
    # (since serialized nodes refer to layers by their name).
    all_names = [layer.name for layer in layers]
    for name in all_names:
        if all_names.count(name) != 1:
            raise ValueError('The name "' + name + '" is used ' +
                             str(all_names.count(name)) +
                             ' times in the model. '
                             'All layer names should be unique.')
    return network_nodes, nodes_by_depth, layers, layers_by_depth

参考
Keras部分源码赏析

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值