keras源码分析之Input

上一篇博客我们分析了Layer类的源码,本文将会分析下Input的源码。

所有的Function api 都需要定义一个InputInputInputLayer的实例化对象,InputLayer继承于Layer层,并重写了init方法和get_config方法,InputLayer多了一个额外的参数sparse,该参数的意思是输入值placeholder是否是稀疏的,我们先详细看下InputLayerinit方法

def __init__(self, input_shape=None, batch_size=None,
          batch_input_shape=None,
          dtype=None, input_tensor=None, sparse=False, name=None):
	 if not name:
	     prefix = 'input'
	     name = prefix + '_' + str(K.get_uid(prefix))
	 super(InputLayer, self).__init__(dtype=dtype, name=name)
	
	 self.trainable = False
	 self.built = True
	 self.sparse = sparse
	 self.supports_masking = True
	
	 if input_shape and batch_input_shape:
	     raise ValueError('Only provide the input_shape OR '
	                      'batch_input_shape argument to '
	                      'InputLayer, not both at the same time.')
	 if input_tensor is not None and batch_input_shape is None:
	     # If input_tensor is set, and batch_input_shape is not set:
	     # Attempt automatic input shape inference.
	     try:
	         batch_input_shape = K.int_shape(input_tensor)
	     except TypeError:
	         if not input_shape and not batch_input_shape:
	             raise ValueError('InputLayer was provided '
	                              'an input_tensor argument, '
	                              'but its input shape cannot be '
	                              'automatically inferred. '
	                              'You should pass an input_shape or '
	                              'batch_input_shape argument.')
	 if not batch_input_shape:
	     if not input_shape:
	         raise ValueError('An Input layer should be passed either '
	                          'a `batch_input_shape` or an `input_shape`.')
	     else:
	         batch_input_shape = (batch_size,) + tuple(input_shape)
	 else:
	     batch_input_shape = tuple(batch_input_shape)
	
	 if not dtype:
	     if input_tensor is None:
	         dtype = K.floatx()
	     else:
	         dtype = K.dtype(input_tensor)
	
	 self.batch_input_shape = batch_input_shape
	 self.dtype = dtype

赋值几个参数,确定batch_size,input_shape,dtype

	if input_tensor is None:
		self.is_placeholder = True
		input_tensor = K.placeholder(shape=batch_input_shape,
		                            dtype=dtype,
		                            sparse=self.sparse,
		                            name=self.name)
	else:
		self.is_placeholder = False
		input_tensor._keras_shape = batch_input_shape
	# Create an input node to add to self.outbound_node
	# and set output_tensors' _keras_history.
	input_tensor._uses_learning_phase = False
	input_tensor._keras_history = (self, 0, 0)
	Node(self,
		inbound_layers=[],
		node_indices=[],
		tensor_indices=[],
		input_tensors=[input_tensor],
		output_tensors=[input_tensor],
		input_masks=[None],
		output_masks=[None],
		input_shapes=[batch_input_shape],
		output_shapes=[batch_input_shape])

然后判断下input_tensor是否是空的,如果是空的就定义一个placeholder,然后实例化Node,注意Node对象的第一个参数是self,也就是说outbound_layer是自身,output_tensors传递的是input_tensor,同理output_shapes传递的参数也是batch_input_shape,也就是说输入层是第一层,Node没有需要连接的前一层。

接下来我们来看Input方法

def Input(shape=None, batch_shape=None,
          name=None, dtype=None, sparse=False,
          tensor=None):

    if not batch_shape and tensor is None:
        assert shape is not None, ('Please provide to Input either a `shape`'
                                   ' or a `batch_shape` argument. Note that '
                                   '`shape` does not include the batch '
                                   'dimension.')
    if shape is not None and not batch_shape:
        batch_shape = (None,) + tuple(shape)
    if not dtype:
        dtype = K.floatx()
    input_layer = InputLayer(batch_input_shape=batch_shape,
                             name=name, dtype=dtype,
                             sparse=sparse,
                             input_tensor=tensor)
    # Return tensor including _keras_shape and _keras_history.
    # Note that in this case train_output and test_output are the same pointer.
    outputs = input_layer._inbound_nodes[0].output_tensors
    return unpack_singleton(outputs)

代码不多,首先判断下是否有shape参数,没有就抛出异常,如果没有定义batch参数,就会自动拼接一下None表示可变维度,dtype参数如果没有则定义为float类型, 然后实例化InputLayer对象,从InputLayer对象的inbound_nodes中找打第一个node,然后从该node中取出output_tensors并返回

最后再以一个例子来回顾下整个过程

inputs = Input(shape=(100))

只有shape参数,则batch_shape = (None,100),dtype设置为float,实例化一个InputLayer对象,因为input_tensor参数是None,所以创建一个placeholder,Tensor("input_1:0", shape=(?, 100), dtype=float32)。然后实例化Node对象,把placeholder作为input_tensoroutput_tensors,最后从InputLayer对象的_inbound_nodes的第一个Node中返回output_tensors

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值