In the previous post we analyzed the source code of the Layer class. This post walks through the source code of Input.
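Every model built with the Functional API needs an Input. Input is a function that instantiates an InputLayer and returns its output tensor. InputLayer inherits from Layer and overrides the __init__ and get_config methods. InputLayer also takes an extra argument, sparse, which indicates whether the input placeholder is sparse. Let's first look at InputLayer's __init__ method in detail: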
    def __init__(self, input_shape=None, batch_size=None,
                 batch_input_shape=None,
                 dtype=None, input_tensor=None, sparse=False, name=None):
        if not name:
            prefix = 'input'
            name = prefix + '_' + str(K.get_uid(prefix))
        super(InputLayer, self).__init__(dtype=dtype, name=name)
        self.trainable = False
        self.built = True
        self.sparse = sparse
        self.supports_masking = True
        if input_shape and batch_input_shape:
            raise ValueError('Only provide the input_shape OR '
                             'batch_input_shape argument to '
                             'InputLayer, not both at the same time.')
        if input_tensor is not None and batch_input_shape is None:
            # If input_tensor is set, and batch_input_shape is not set:
            # Attempt automatic input shape inference.
            try:
                batch_input_shape = K.int_shape(input_tensor)
            except TypeError:
                if not input_shape and not batch_input_shape:
                    raise ValueError('InputLayer was provided '
                                     'an input_tensor argument, '
                                     'but its input shape cannot be '
                                     'automatically inferred. '
                                     'You should pass an input_shape or '
                                     'batch_input_shape argument.')
        if not batch_input_shape:
            if not input_shape:
                raise ValueError('An Input layer should be passed either '
                                 'a `batch_input_shape` or an `input_shape`.')
            else:
                batch_input_shape = (batch_size,) + tuple(input_shape)
        else:
            batch_input_shape = tuple(batch_input_shape)
        if not dtype:
            if input_tensor is None:
                dtype = K.floatx()
            else:
                dtype = K.dtype(input_tensor)
        self.batch_input_shape = batch_input_shape
        self.dtype = dtype
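The code above sets a few attributes and settles batch_input_shape (built from batch_size and input_shape when it is not given directly) and dtype.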
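The shape branching in this first half of __init__ can be isolated as a small pure-Python sketch. Note that `resolve_batch_input_shape` is a hypothetical helper written for illustration, not a Keras function; it mirrors only the branching on `input_shape`, `batch_size`, and `batch_input_shape`, and leaves out the path that infers the shape from `input_tensor`.

```python
def resolve_batch_input_shape(input_shape=None, batch_size=None,
                              batch_input_shape=None):
    """Hypothetical helper mirroring InputLayer's shape branching."""
    if input_shape and batch_input_shape:
        raise ValueError('Provide input_shape OR batch_input_shape, not both.')
    if not batch_input_shape:
        if not input_shape:
            raise ValueError('Either batch_input_shape or input_shape is required.')
        # Prepend the batch dimension; None means "any batch size".
        return (batch_size,) + tuple(input_shape)
    # A list is accepted too; it is normalized to a tuple.
    return tuple(batch_input_shape)


print(resolve_batch_input_shape(input_shape=(100,)))                 # (None, 100)
print(resolve_batch_input_shape(input_shape=(100,), batch_size=32))  # (32, 100)
print(resolve_batch_input_shape(batch_input_shape=[None, 28, 28]))   # (None, 28, 28)
```

The point to notice is that input_shape never includes the batch dimension, while batch_input_shape always does.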
        if input_tensor is None:
            self.is_placeholder = True
            input_tensor = K.placeholder(shape=batch_input_shape,
                                         dtype=dtype,
                                         sparse=self.sparse,
                                         name=self.name)
        else:
            self.is_placeholder = False
            input_tensor._keras_shape = batch_input_shape
        # Create an input node to add to self.outbound_node
        # and set output_tensors' _keras_history.
        input_tensor._uses_learning_phase = False
        input_tensor._keras_history = (self, 0, 0)
        Node(self,
             inbound_layers=[],
             node_indices=[],
             tensor_indices=[],
             input_tensors=[input_tensor],
             output_tensors=[input_tensor],
             input_masks=[None],
             output_masks=[None],
             input_shapes=[batch_input_shape],
             output_shapes=[batch_input_shape])
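Next, the code checks whether input_tensor is None and, if so, creates a placeholder. It then instantiates a Node. Note that the first argument to Node is self, which means outbound_layer is the layer itself. Both input_tensors and output_tensors are passed input_tensor, and likewise both input_shapes and output_shapes are passed batch_input_shape. In other words, the input layer is the first layer, so its Node has no previous layer to connect to.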
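To make the Node bookkeeping concrete, here is a minimal sketch using toy classes. `ToyNode` and `ToyInputLayer` are hypothetical stand-ins, not Keras classes; they keep only the fields discussed above and assume that a node registers itself in the `_inbound_nodes` of the layer that produced it and in the `_outbound_nodes` of each upstream layer.

```python
class ToyNode:
    """Simplified stand-in for a Keras Node: records one layer call."""

    def __init__(self, outbound_layer, inbound_layers,
                 input_tensors, output_tensors):
        self.outbound_layer = outbound_layer
        self.inbound_layers = inbound_layers
        self.input_tensors = input_tensors
        self.output_tensors = output_tensors
        # Register on the layer that produced the outputs.
        outbound_layer._inbound_nodes.append(self)
        # Register on every upstream layer (none for an input layer).
        for layer in inbound_layers:
            layer._outbound_nodes.append(self)


class ToyInputLayer:
    """Simplified stand-in for InputLayer."""

    def __init__(self, tensor):
        self._inbound_nodes = []
        self._outbound_nodes = []
        # No upstream layers; the same tensor is both input and output.
        ToyNode(self, inbound_layers=[],
                input_tensors=[tensor], output_tensors=[tensor])


layer = ToyInputLayer(tensor='placeholder')
node = layer._inbound_nodes[0]
print(node.outbound_layer is layer)                      # True
print(node.inbound_layers)                               # []
print(node.input_tensors[0] is node.output_tensors[0])   # True
```

Because inbound_layers is empty, nothing is added to any other layer's outbound list, which is what marks this node as the start of the graph.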
Next, let's look at the Input function:
    def Input(shape=None, batch_shape=None,
              name=None, dtype=None, sparse=False,
              tensor=None):
        if not batch_shape and tensor is None:
            assert shape is not None, ('Please provide to Input either a `shape`'
                                       ' or a `batch_shape` argument. Note that '
                                       '`shape` does not include the batch '
                                       'dimension.')
        if shape is not None and not batch_shape:
            batch_shape = (None,) + tuple(shape)
        if not dtype:
            dtype = K.floatx()
        input_layer = InputLayer(batch_input_shape=batch_shape,
                                 name=name, dtype=dtype,
                                 sparse=sparse,
                                 input_tensor=tensor)
        # Return tensor including _keras_shape and _keras_history.
        # Note that in this case train_output and test_output are the same pointer.
        outputs = input_layer._inbound_nodes[0].output_tensors
        return unpack_singleton(outputs)
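There isn't much code here. It first checks that at least one of shape, batch_shape, or tensor was provided, and fails an assertion otherwise. If batch_shape is not given, it prepends None to shape to mark the batch dimension as variable. If dtype is not given, it defaults to K.floatx(). It then instantiates an InputLayer, takes the first node from the layer's _inbound_nodes, and returns that node's output_tensors, unpacked to a single tensor by unpack_singleton.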
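The argument handling described above can be tried without Keras. The two functions below are hypothetical re-implementations for illustration only: `derive_batch_shape` copies the shape logic of Input, and `unpack_singleton_sketch` assumes that unpack_singleton returns the sole element of a one-element list and the list itself otherwise.

```python
def derive_batch_shape(shape=None, batch_shape=None, tensor=None):
    """Hypothetical helper copying the shape handling in Input."""
    if not batch_shape and tensor is None:
        assert shape is not None, 'Provide either shape or batch_shape.'
    if shape is not None and not batch_shape:
        # shape excludes the batch dimension, so prepend None.
        batch_shape = (None,) + tuple(shape)
    return batch_shape


def unpack_singleton_sketch(outputs):
    """Assumed behavior: a one-element list collapses to its element."""
    return outputs[0] if len(outputs) == 1 else outputs


print(derive_batch_shape(shape=(28, 28, 3)))       # (None, 28, 28, 3)
print(derive_batch_shape(batch_shape=(32, 100)))   # (32, 100)
print(unpack_singleton_sketch(['tensor_a']))       # tensor_a
```

This is why a call to Input gives back a single tensor rather than a list: the input node holds exactly one output tensor.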
Finally, let's review the whole process with an example:
    inputs = Input(shape=(100,))
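Only the shape argument is given, so batch_shape = (None, 100) and dtype is set to K.floatx() (float32 by default). An InputLayer is instantiated, and because its input_tensor argument is None, a placeholder is created: Tensor("input_1:0", shape=(?, 100), dtype=float32). A Node is then instantiated with the placeholder as both input_tensors and output_tensors. Finally, Input returns the output_tensors of the first Node in the InputLayer's _inbound_nodes.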
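One detail worth checking in an example like this: Input passes shape through tuple(shape), so shape has to be an iterable such as (100,). Writing (100) gives the plain integer 100, because parentheses alone do not make a tuple. The sketch below uses only built-in Python to show the difference; the variable names are chosen for illustration.

```python
shape_ok = (100,)   # one-element tuple
shape_bad = (100)   # parentheses only: this is the int 100

# Same expression Input uses to build batch_shape.
batch_shape = (None,) + tuple(shape_ok)
print(batch_shape)  # (None, 100)

error = None
try:
    (None,) + tuple(shape_bad)
except TypeError as exc:
    # tuple() cannot iterate over an int.
    error = exc
print(type(error).__name__)  # TypeError
```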