Preface
The most elegant thing about Keras is its fit function: automatic validation, flexible callbacks, simple settings for batch_size and epochs, and so on. Compared with raw TensorFlow, where you write your own validation code and your own loop to train for multiple epochs, it is far simpler. So how is the powerful fit function actually implemented? In this article we will explore the internals together.
Code Analysis
First, fit validates batch_size by calling a helper function:
batch_size = self._validate_or_infer_batch_size(
    batch_size, steps_per_epoch, x)
This helper reads the shape of the first layer; if that shape is not empty, it takes shape[0]. As we saw in an earlier source-code walkthrough, when you define an Input layer in Keras you normally omit the first dimension of the shape; that dimension is in fact the batch size. So the helper is fetching the statically declared batch size, which is usually None unless the batch_shape argument was set. In other words, static_batch_size in the code below is usually None. If it is not None, batch_size takes the value from batch_shape; if it is None, the batch_size the user passed to fit is returned; and if that was not set either, the default of 32 is returned.
def _validate_or_infer_batch_size(self, batch_size, steps, x):
    if batch_size is not None and is_generator_or_sequence(x):
        raise ValueError('The `batch_size` argument must not be specified when'
                         ' using a generator or Sequence as an input.')
    layers = super(Model, self).layers  # Avoids the override in Sequential.
    if layers:
        first_layer = layers[0]
        static_batch_size = get_static_batch_size(first_layer)
        if static_batch_size is not None:
            # Check `batch_size` argument is consistent with InputLayer.
            if batch_size is not None and batch_size != static_batch_size:
                raise ValueError('The `batch_size` argument value {} is '
                                 'incompatible with the specified batch '
                                 'size of your Input Layer: {}'
                                 .format(batch_size, static_batch_size))
            # Set inferred batch size from the InputLayer.
            if steps is None:
                batch_size = static_batch_size
    if batch_size is None and steps is None:
        # Backwards compatibility
        batch_size = 32
    return batch_size
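The precedence just described can be distilled into a small standalone sketch. Note that infer_batch_size is a hypothetical helper written for illustration, not the actual Keras code; it only mirrors the decision order above:

```python
# Hypothetical helper mirroring the decision order: a static batch size
# declared on the InputLayer wins, then the batch_size passed to fit,
# then the legacy default of 32 (unless step-wise training is used).
def infer_batch_size(batch_size, steps, static_batch_size):
    if static_batch_size is not None:
        if batch_size is not None and batch_size != static_batch_size:
            raise ValueError('batch_size {} is incompatible with the '
                             'InputLayer batch size {}'
                             .format(batch_size, static_batch_size))
        if steps is None:
            batch_size = static_batch_size
    if batch_size is None and steps is None:
        batch_size = 32  # backwards-compatible default
    return batch_size
```

For example, infer_batch_size(None, None, None) falls through to 32, while infer_batch_size(None, None, 16) picks up the batch size declared on the input layer.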
Next there are two paths. fit first checks whether the input x is a generator; if so, it delegates to fit_generator, which is much like fit apart from the extra step of pulling data from the generator. The analysis below focuses on fit.
First, the input data is standardized into the form used during training. This mainly involves type checks on the inputs and verifying that the model has an optimizer and has been compiled.
x, y, sample_weights = self._standardize_user_data(
    x, y,
    sample_weight=sample_weight,
    class_weight=class_weight,
    batch_size=batch_size)
Next, fit decides whether validation is needed. There are three cases.
- If the validation_data argument was given, validation will be performed. The validation inputs are standardized with the same _standardize_user_data method, and then fit checks whether the model uses a dynamic learning phase, i.e. whether it behaves differently in training and testing (as dropout does). If so, a [0.] is appended to the validation inputs; we will see why below.
# Prepare validation data.
do_validation = False
if validation_data:
    do_validation = True
    if len(validation_data) == 2:
        val_x, val_y = validation_data
        val_sample_weight = None
    elif len(validation_data) == 3:
        val_x, val_y, val_sample_weight = validation_data
    else:
        raise ValueError('When passing validation_data, '
                         'it must contain 2 (x_val, y_val) '
                         'or 3 (x_val, y_val, val_sample_weights) '
                         'items, however it contains %d items' %
                         len(validation_data))
    val_x, val_y, val_sample_weights = self._standardize_user_data(
        val_x, val_y,
        sample_weight=val_sample_weight,
        batch_size=batch_size)
    if self._uses_dynamic_learning_phase():
        val_inputs = val_x + val_y + val_sample_weights + [0.]
    else:
        val_inputs = val_x + val_y + val_sample_weights
- The second case: the user supplied no validation set but gave a fraction, validation_split, a value in the interval (0, 1) meaning that this fraction of the x and y passed to fit is held out as the validation set. x and y are sliced at that point, and the learning phase is checked as before.
elif validation_split and 0. < validation_split < 1.:
    if any(K.is_tensor(t) for t in x):
        raise ValueError(
            'If your data is in the form of symbolic tensors, '
            'you cannot use `validation_split`.')
    do_validation = True
    if hasattr(x[0], 'shape'):
        split_at = int(int(x[0].shape[0]) * (1. - validation_split))
    else:
        split_at = int(len(x[0]) * (1. - validation_split))
    x, val_x = (slice_arrays(x, 0, split_at),
                slice_arrays(x, split_at))
    y, val_y = (slice_arrays(y, 0, split_at),
                slice_arrays(y, split_at))
    sample_weights, val_sample_weights = (
        slice_arrays(sample_weights, 0, split_at),
        slice_arrays(sample_weights, split_at))
    if self._uses_dynamic_learning_phase():
        val_inputs = val_x + val_y + val_sample_weights + [0.]
    else:
        val_inputs = val_x + val_y + val_sample_weights
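The slicing arithmetic is easy to check with a toy sketch. split_for_validation is a made-up name; it reproduces only the split_at computation above, for plain Python lists:

```python
# Illustrative only: hold out the trailing validation_split fraction,
# exactly as the split_at computation above does in the simple case.
def split_for_validation(x, y, validation_split):
    if not 0. < validation_split < 1.:
        raise ValueError('validation_split must be in the interval (0, 1).')
    split_at = int(len(x) * (1. - validation_split))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])
```

With 10 samples and validation_split=0.2, split_at is 8, so the last 2 samples become the validation set. Note that the split happens before any shuffling, so it is always the tail of the data as passed in that gets held out.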
- The third case is much simpler: if only the validation_steps argument was set, the validation input is just [0.].
elif validation_steps:
    do_validation = True
    if self._uses_dynamic_learning_phase():
        val_inputs = [0.]
Next the training inputs are prepared. Again the learning phase is checked; this time a [1.] is appended, because the trailing scalar is fed to K.learning_phase() (0 means test mode, 1 means train mode), which is also why the validation inputs above got a [0.]. Then fit_function is assigned.
# Prepare input arrays and training function.
if self._uses_dynamic_learning_phase():
    fit_inputs = x + y + sample_weights + [1.]
else:
    fit_inputs = x + y + sample_weights
self._make_train_function()
fit_function = self.train_function
Let's look at _make_train_function in detail; it is an important function. It first checks whether self.train_function is None (on the first call it will be), then assembles the inputs, opens a TensorFlow name scope, and calls self.optimizer.get_updates(). That function takes the trainable weights and the loss and returns the gradient-update operations. Finally K.function is called; the session setup and the feed_dict parameters are all encapsulated in it.
def _make_train_function(self):
    if not hasattr(self, 'train_function'):
        raise RuntimeError('You must compile your model before using it.')
    self._check_trainable_weights_consistency()
    if self.train_function is None:
        inputs = (self._feed_inputs +
                  self._feed_targets +
                  self._feed_sample_weights)
        if self._uses_dynamic_learning_phase():
            inputs += [K.learning_phase()]
        with K.name_scope('training'):
            with K.name_scope(self.optimizer.__class__.__name__):
                training_updates = self.optimizer.get_updates(
                    params=self._collected_trainable_weights,
                    loss=self.total_loss)
            updates = (self.updates +
                       training_updates +
                       self.metrics_updates)
            # Gets loss and metrics. Updates weights at each call.
            self.train_function = K.function(
                inputs,
                [self.total_loss] + self.metrics_tensors,
                updates=updates,
                name='train_function',
                **self._function_kwargs)
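To make the role of updates= concrete, here is a toy, TensorFlow-free sketch of the same pattern: the returned callable computes an output (the loss) and, as a side effect of every call, applies the weight updates. All names here are invented for the illustration:

```python
# Toy version of the "function with updates" pattern: each call returns
# the loss and mutates the weights in place, just as the ops passed via
# `updates=` run on every session call.
def make_train_function(weights, grad_fn, lr=0.1):
    def train_function(x, y):
        loss, grads = grad_fn(weights, x, y)
        for i, g in enumerate(grads):
            weights[i] -= lr * g  # the "update op" side effect
        return loss
    return train_function
```

Calling the returned function repeatedly drives the loss down without the caller ever touching the weights directly, which is exactly the feel of Keras's train_function.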
Now let's look at K.function in detail. Internally it constructs a Function object:
def function(inputs, outputs, updates=None, **kwargs):
    if kwargs:
        for key in kwargs:
            session_has_key = has_arg(tf.Session.run, key, True)
            function_has_key = has_arg(Function.__init__, key, True)
            if not (session_has_key or function_has_key):
                raise ValueError('Invalid argument "%s" passed to K.function '
                                 'with TensorFlow backend' % key)
    return Function(inputs, outputs, updates=updates, **kwargs)
The constructor of this class pulls feed_dict, run_options (GPU configuration), and other settings out of session_kwargs:
self.feed_dict = session_kwargs.pop('feed_dict', {})
# additional operations
self.fetches = session_kwargs.pop('fetches', [])
if not isinstance(self.fetches, list):
    self.fetches = [self.fetches]
# The main use case of `fetches` being passed to a model is the ability
# to run custom updates
# (since the outputs of fetches are never returned).
# This requires us to wrap fetches in `identity` ops.
self.fetches = [tf.identity(x) for x in self.fetches]
# self.session_kwargs is used for _legacy_call
self.session_kwargs = session_kwargs.copy()
self.run_options = session_kwargs.pop('options', None)
self.run_metadata = session_kwargs.pop('run_metadata', None)
if session_kwargs:
    raise ValueError('Some keys in session_kwargs are not '
                     'supported at this '
                     'time: %s', session_kwargs.keys())
The Function class also has a __call__ method (yes, the call method used all over Keras). It ultimately invokes _call(), which obtains the session, collects the feed_dict values gathered earlier, builds a session callable in _make_callable, and invokes that callable to actually run the graph.
def _call(self, inputs):
    if not isinstance(inputs, (list, tuple)):
        raise TypeError('`inputs` should be a list or tuple.')
    session = get_session()
    feed_arrays = []
    array_vals = []
    feed_symbols = []
    symbol_vals = []
    if self.feed_dict:
        for key in sorted(self.feed_dict.keys()):
            array_vals.append(
                np.asarray(self.feed_dict[key],
                           dtype=tf.as_dtype(key.dtype).as_numpy_dtype))
    # Refresh callable if anything has changed.
    if (self._callable_fn is None or
            feed_arrays != self._feed_arrays or
            symbol_vals != self._symbol_vals or
            feed_symbols != self._feed_symbols or
            session != self._session):
        self._make_callable(feed_arrays,
                            feed_symbols,
                            symbol_vals,
                            session)
    if self.run_metadata:
        fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
    else:
        fetched = self._callable_fn(*array_vals)
    return fetched[:len(self.outputs)]
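The interesting part of _call is the caching: the session callable is expensive to build, so it is rebuilt only when the feed signature or the session changes. A minimal sketch of that pattern, with invented names and a trivial stand-in for the real TensorFlow callable:

```python
# Rebuild-on-change caching, as in `_call` above. The real code compares
# feed arrays, feed symbols and the session; here a single "signature"
# stands in for all of them.
class CachedRunner:
    def __init__(self):
        self._callable_fn = None
        self._signature = None
        self.rebuilds = 0  # exposed only to make the caching visible

    def _make_callable(self, signature):
        # Stand-in for building an expensive tf.Session callable.
        self.rebuilds += 1
        self._callable_fn = lambda *vals: sum(vals)
        self._signature = signature

    def __call__(self, signature, values):
        if self._callable_fn is None or signature != self._signature:
            self._make_callable(signature)
        return self._callable_fn(*values)
```

Repeated calls with the same signature reuse the cached callable; only a changed signature triggers a rebuild.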
So where does this __call__ method actually get invoked? Back at the outermost level of fit: once fit_function has been assigned, val_function is assigned in the same way.
# Prepare display labels.
out_labels = self.metrics_names
if do_validation:
    self._make_test_function()
    val_function = self.test_function
    callback_metrics = copy.copy(out_labels) + [
        'val_' + n for n in out_labels]
else:
    callback_metrics = copy.copy(out_labels)
    val_function = None
    val_inputs = []
Finally, fit returns the value of fit_loop(), and this is where __call__ eventually gets invoked.
return training_arrays.fit_loop(self, fit_function, fit_inputs,
                                out_labels=out_labels,
                                batch_size=batch_size,
                                epochs=epochs,
                                verbose=verbose,
                                callbacks=callbacks,
                                val_function=val_function,
                                val_inputs=val_inputs,
                                shuffle=shuffle,
                                callback_metrics=callback_metrics,
                                initial_epoch=initial_epoch,
                                steps_per_epoch=steps_per_epoch,
                                validation_steps=validation_steps,
                                validation_freq=validation_freq)
Next let's look at the fit_loop function. It starts by validating the validation settings so they can be used later:
do_validation = False
if val_function and val_inputs:
    do_validation = True
    if (verbose and fit_inputs and
            hasattr(fit_inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
        print('Train on %d samples, validate on %d samples' %
              (fit_inputs[0].shape[0], val_inputs[0].shape[0]))
if validation_steps:
    do_validation = True
    if steps_per_epoch is None:
        raise ValueError('Can only use `validation_steps` '
                         'when doing step-wise '
                         'training, i.e. `steps_per_epoch` '
                         'must be set.')
elif do_validation:
    if steps_per_epoch:
        raise ValueError('Must specify `validation_steps` '
                         'to perform validation '
                         'when doing step-wise training.')
Then the number of training samples is computed and turned into an index array:
num_train_samples = check_num_samples(fit_inputs,
                                      batch_size=batch_size,
                                      steps=steps_per_epoch,
                                      steps_name='steps_per_epoch')
if num_train_samples is not None:
    index_array = np.arange(num_train_samples)
Next come the callback preparations: the History object and the loggers are created and combined with the user-supplied callbacks into a CallbackList. One line here is important: callbacks._call_begin_hook('train') invokes every callback's on_train_begin method in turn; that is, it fires the callbacks' training-start hooks.
model.history = cbks.History()
_callbacks = [cbks.BaseLogger(
    stateful_metrics=model.stateful_metric_names)]
if verbose:
    if steps_per_epoch is not None:
        count_mode = 'steps'
    else:
        count_mode = 'samples'
    _callbacks.append(
        cbks.ProgbarLogger(
            count_mode,
            stateful_metrics=model.stateful_metric_names))
_callbacks += (callbacks or []) + [model.history]
callbacks = cbks.CallbackList(_callbacks)
out_labels = out_labels or []

# it's possible to callback a different model than itself
# (used by Sequential models)
callback_model = model._get_callback_model()
callbacks.set_model(callback_model)
callbacks.set_params({
    'batch_size': batch_size,
    'epochs': epochs,
    'steps': steps_per_epoch,
    'samples': num_train_samples,
    'verbose': verbose,
    'do_validation': do_validation,
    'metrics': callback_metrics or [],
})
callbacks._call_begin_hook('train')
callbacks.model.stop_training = False
for cbk in callbacks:
    cbk.validation_data = val_inputs
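The dispatch idea is worth isolating. The sketch below is a stripped-down imitation of the pattern (a handful of hooks and a fan-out list), not the real Keras Callback API, which has many more hooks:

```python
# Minimal callback fan-out: the training loop fires named hooks and the
# CallbackList forwards each one to every registered callback.
class Callback:
    def on_train_begin(self, logs=None): pass
    def on_epoch_begin(self, epoch, logs=None): pass
    def on_epoch_end(self, epoch, logs=None): pass
    def on_train_end(self, logs=None): pass

class CallbackList:
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def on_train_begin(self, logs=None):
        for cb in self.callbacks:
            cb.on_train_begin(logs)

    def on_epoch_begin(self, epoch, logs=None):
        for cb in self.callbacks:
            cb.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch, logs=None):
        for cb in self.callbacks:
            cb.on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        for cb in self.callbacks:
            cb.on_train_end(logs)

# A History-like callback: it just records the per-epoch logs.
class History(Callback):
    def on_train_begin(self, logs=None):
        self.history = []

    def on_epoch_end(self, epoch, logs=None):
        self.history.append(dict(logs or {}))
```

Because History is just another callback in the list, the training loop needs no special code to build the object that fit eventually returns.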
Everything above is preparation; now the actual training begins. I'll break the code into smaller pieces here.
First is the loop over epochs, which calls every callback's on_epoch_begin method:
for epoch in range(initial_epoch, epochs):
    # Reset stateful metrics
    for m in model.stateful_metric_functions:
        m.reset_states()
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
Next comes a branch on whether the steps_per_epoch argument was passed. If it is not None, the loop iterates directly over the steps. For each step it fires every callback's on_train_batch_begin hook, then calls fit_function, that is, Function's __call__, so the session we defined earlier finally starts to run. The loss values are stored in the batch logs, which are then passed to on_train_batch_end; once that hook returns, the time taken by the batch is known, and if it is too long a warning is printed telling the user that a callback is doing something too expensive. Finally, test_loop is called to evaluate the validation set; we won't expand on it here.
if steps_per_epoch is not None:
    for step_index in range(steps_per_epoch):
        batch_logs = {'batch': step_index, 'size': 1}
        callbacks._call_batch_hook('train', 'begin', step_index, batch_logs)
        outs = fit_function(fit_inputs)
        outs = to_list(outs)
        for l, o in zip(out_labels, outs):
            batch_logs[l] = o
        callbacks._call_batch_hook('train', 'end', step_index, batch_logs)
        if callback_model.stop_training:
            break

    if do_validation and should_run_validation(validation_freq, epoch):
        val_outs = test_loop(model, val_function, val_inputs,
                             steps=validation_steps,
                             callbacks=callbacks,
                             verbose=0)
        val_outs = to_list(val_outs)
        # Same labels assumed.
        for l, o in zip(out_labels, val_outs):
            epoch_logs['val_' + l] = o
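Condensed to its skeleton, the step-wise branch is just hooks bracketing each call to the train function, with an early exit on stop_training. In this sketch train_fn, hooks and stop_flag are stand-ins for fit_function, the CallbackList and the stop_training flag:

```python
# Skeleton of the step-wise loop: begin hook, train step, end hook,
# then an early exit if some callback requested stopping.
def run_steps(train_fn, steps_per_epoch, hooks, stop_flag):
    logs_per_step = []
    for step_index in range(steps_per_epoch):
        batch_logs = {'batch': step_index, 'size': 1}
        hooks('begin', step_index, batch_logs)
        batch_logs['loss'] = train_fn()
        hooks('end', step_index, batch_logs)
        logs_per_step.append(batch_logs)
        if stop_flag():
            break
    return logs_per_step
```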
The other branch handles the case where no steps were given. The overall code is similar to the branch above; the main difference is how data is fetched. The branch above has steps_per_epoch defined, so batches can be taken directly by step; this branch first shuffles the data, then calls make_batches, which uses the number of samples and batch_size to compute the start and end index of each batch. Those two values are used to slice out the training data, and the same on_train_batch_begin -> fit_function -> on_train_batch_end sequence runs. After the last batch of the epoch, the validation set is evaluated once.
else:
    if shuffle == 'batch':
        index_array = batch_shuffle(index_array, batch_size)
    elif shuffle:
        np.random.shuffle(index_array)

    batches = make_batches(num_train_samples, batch_size)
    for batch_index, (batch_start, batch_end) in enumerate(batches):
        batch_ids = index_array[batch_start:batch_end]
        try:
            if isinstance(fit_inputs[-1], float):
                # Do not slice the training phase flag.
                ins_batch = slice_arrays(
                    fit_inputs[:-1], batch_ids) + [fit_inputs[-1]]
            else:
                ins_batch = slice_arrays(fit_inputs, batch_ids)
        except TypeError:
            raise TypeError('TypeError while preparing batch. '
                            'If using HDF5 input data, '
                            'pass shuffle="batch".')
        batch_logs = {'batch': batch_index, 'size': len(batch_ids)}
        callbacks._call_batch_hook('train', 'begin', batch_index, batch_logs)
        for i in indices_for_conversion_to_dense:
            ins_batch[i] = ins_batch[i].toarray()

        outs = fit_function(ins_batch)
        outs = to_list(outs)
        for l, o in zip(out_labels, outs):
            batch_logs[l] = o
        callbacks._call_batch_hook('train', 'end', batch_index, batch_logs)
        if callbacks.model.stop_training:
            break

        if batch_index == len(batches) - 1:  # Last batch.
            if do_validation and should_run_validation(validation_freq, epoch):
                val_outs = test_loop(model, val_function, val_inputs,
                                     batch_size=batch_size,
                                     callbacks=callbacks,
                                     verbose=0)
                val_outs = to_list(val_outs)
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o
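make_batches itself is simple; the sketch below should reproduce its behaviour (pairs of start and end indices covering all samples, with a short final batch when batch_size does not divide the sample count):

```python
# Compute (start, end) index pairs for each batch; the last pair may be
# shorter when the sample count is not a multiple of batch_size.
def make_batches(size, batch_size):
    num_batches = (size + batch_size - 1) // batch_size  # ceiling division
    return [(i * batch_size, min(size, (i + 1) * batch_size))
            for i in range(num_batches)]
```

For example, make_batches(10, 4) yields [(0, 4), (4, 8), (8, 10)]; index_array[batch_start:batch_end] then selects the (possibly shuffled) sample ids for that batch.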
Finally the on_epoch_end and on_train_end hooks are executed, and the history is returned.
    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
        break
callbacks._call_end_hook('train')
return model.history
Summary
To understand Keras's fit function, I think there are two key points. First, Keras wraps everything, including logging and other bookkeeping, in callback objects and dispatches them uniformly, which keeps the code elegant. Second, fit_function runs the session through a __call__ method, hiding the session so cleverly that you never notice it exists, and the overall training flow contains no redundant code. The seemingly simple fit function packs in a remarkable amount of design wisdom; the Keras authors deserve real admiration, and my thanks for giving us such a simple and efficient API. Respect.