Keras Source Code Analysis: fit

Preface

The most elegant part of Keras is arguably its fit function: automatic validation, flexible callbacks, and simple knobs such as batch_size and epochs. Compared with raw TensorFlow, where you write your own validation code and your own loop to train across multiple epochs, it is far simpler. So how is all this functionality actually implemented? This article walks through the underlying code.

Code Analysis

First, fit validates the batch_size, delegating to a helper function:

batch_size = self._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x)

This helper inspects the first layer's shape. As noted in an earlier source-code analysis, when you define an Input layer's shape in Keras you normally omit the first dimension, which actually represents the batch size; what the helper reads here is that statically defined batch size. Unless the batch_shape argument was set when the layer was created, it is usually absent, so static_batch_size in the code below is usually None. If it is not None, batch_size is taken from it (after checking consistency with any batch_size argument the user passed). If it is None, the helper returns the batch_size the user passed to fit, and if that was not set either, it falls back to the default of 32.

def _validate_or_infer_batch_size(self, batch_size, steps, x):
 
    if batch_size is not None and is_generator_or_sequence(x):
        raise ValueError('The `batch_size` argument must not be specified when'
                         ' using a generator or Sequence as an input.')

    layers = super(Model, self).layers  # Avoids the override in Sequential.
    if layers:
        first_layer = layers[0]
        static_batch_size = get_static_batch_size(first_layer)
        if static_batch_size is not None:

            # Check `batch_size` argument is consistent with InputLayer.
            if batch_size is not None and batch_size != static_batch_size:
                raise ValueError('The `batch_size` argument value {} is '
                                 'incompatible with the specified batch '
                                 'size of your Input Layer: {}'
                                 .format(batch_size, static_batch_size))

            # Set inferred batch size from the InputLayer.
            if steps is None:
                batch_size = static_batch_size

    if batch_size is None and steps is None:
        # Backwards compatibility
        batch_size = 32
    return batch_size
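
To make these rules concrete, here is a minimal sketch (a hypothetical toy model, not from the Keras source) showing both the inferred and the default batch size:

import numpy as np
from keras.models import Sequential
from keras.layers import InputLayer, Dense

model = Sequential([
    InputLayer(batch_input_shape=(16, 8)),  # static batch size: 16
    Dense(1),
])
model.compile(optimizer='sgd', loss='mse')

x = np.random.rand(64, 8)
y = np.random.rand(64, 1)

model.fit(x, y, epochs=1)           # batch_size inferred as 16 from the InputLayer
# model.fit(x, y, batch_size=32)    # would raise ValueError: incompatible with 16

# Without a static batch shape and without a batch_size argument,
# fit falls back to the default of 32.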

Next the code branches: it first checks whether the input x is a generator (or Sequence); if so, it dispatches to fit_generator. fit_generator is much like fit, the main extra step being pulling data out of the generator; the analysis below focuses on fit.

First, the input data is standardized into the form used during training. This covers things like checking input types, verifying that an optimizer exists, and verifying that the model has been compiled:

x, y, sample_weights = self._standardize_user_data(
            x, y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size)
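
As a rough illustration (reusing the toy model from the sketch above; the shapes are my own assumptions, not from the source), the standardized data comes back as lists of numpy arrays, one entry per model input/output:

x, y, sample_weights = model._standardize_user_data(
    np.random.rand(64, 8), np.random.rand(64, 1), batch_size=16)
# x -> [array of shape (64, 8)]
# y -> [array of shape (64, 1)]
# sample_weights -> [array of shape (64,)], all ones since no sample_weight was given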

Next, fit decides whether validation is needed. There are three cases.

  • If the validation_data argument is set, validation will run. The validation inputs are standardized with the same _standardize_user_data method, and then the code checks whether there is a dynamic learning phase, i.e. whether behavior must differ between training and testing (think dropout). If so, a [0.] is appended to the validation inputs; why will become clear shortly.
# Prepare validation data.
do_validation = False
if validation_data:
    do_validation = True
    if len(validation_data) == 2:
        val_x, val_y = validation_data
        val_sample_weight = None
    elif len(validation_data) == 3:
        val_x, val_y, val_sample_weight = validation_data
    else:
        raise ValueError('When passing validation_data, '
                         'it must contain 2 (x_val, y_val) '
                         'or 3 (x_val, y_val, val_sample_weights) '
                         'items, however it contains %d items' %
                         len(validation_data))

    val_x, val_y, val_sample_weights = self._standardize_user_data(
        val_x, val_y,
        sample_weight=val_sample_weight,
        batch_size=batch_size)
    if self._uses_dynamic_learning_phase():
        val_inputs = val_x + val_y + val_sample_weights + [0.]
    else:
        val_inputs = val_x + val_y + val_sample_weights
  • In the second case the user passed no validation set but gave a validation fraction, validation_split, meaning: carve a validation_split-sized slice out of the supplied x and y. The value must lie in the open interval (0, 1); x and y are split at the corresponding point (see the slicing sketch after this list), and the learning phase is checked as before.
elif validation_split and 0. < validation_split < 1.:
    if any(K.is_tensor(t) for t in x):
        raise ValueError(
            'If your data is in the form of symbolic tensors, '
            'you cannot use `validation_split`.')
    do_validation = True
    if hasattr(x[0], 'shape'):
        split_at = int(int(x[0].shape[0]) * (1. - validation_split))
    else:
        split_at = int(len(x[0]) * (1. - validation_split))
    x, val_x = (slice_arrays(x, 0, split_at),
                slice_arrays(x, split_at))
    y, val_y = (slice_arrays(y, 0, split_at),
                slice_arrays(y, split_at))
    sample_weights, val_sample_weights = (
        slice_arrays(sample_weights, 0, split_at),
        slice_arrays(sample_weights, split_at))
    if self._uses_dynamic_learning_phase():
        val_inputs = val_x + val_y + val_sample_weights + [0.]
    else:
        val_inputs = val_x + val_y + val_sample_weights
  • The third case is much simpler: if only the validation_steps parameter is set, the validation inputs consist of nothing but the [0.] learning-phase flag.
elif validation_steps:
    do_validation = True
    if self._uses_dynamic_learning_phase():
        val_inputs = [0.]
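
The split in the second case is a plain tail slice, which is worth spelling out because the data is not shuffled before splitting; a quick sketch:

import numpy as np

x = np.arange(10).reshape(10, 1)
validation_split = 0.2
split_at = int(x.shape[0] * (1. - validation_split))   # 8
x_train, x_val = x[:split_at], x[split_at:]
# x_train: first 8 samples, x_val: last 2 samples (always the tail of the data)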

Next comes preparation of the training inputs: the same learning-phase check, after which the training function is built and assigned to fit_function.

# Prepare input arrays and training function.
if self._uses_dynamic_learning_phase():
    fit_inputs = x + y + sample_weights + [1.]
else:
    fit_inputs = x + y + sample_weights
self._make_train_function()
fit_function = self.train_function
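
This is also the place to settle the promised explanation of the trailing [0.]/[1.]: when the learning phase is dynamic, K.learning_phase() is appended to the function inputs (see _make_train_function below), and the flag is fed into that tensor, 1. for training and 0. for testing. A minimal sketch of the effect, using the classic Dropout pattern from the Keras FAQ:

import numpy as np
from keras import backend as K
from keras.layers import Input, Dropout

inp = Input(shape=(4,))
out = Dropout(0.5)(inp)                      # active only during training
f = K.function([inp, K.learning_phase()], [out])

ones = np.ones((1, 4))
print(f([ones, 1])[0])   # phase 1 (training): some entries dropped, rest rescaled
print(f([ones, 0])[0])   # phase 0 (testing): identity, all ones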

Let's look at _make_train_function in detail, because it matters. It first verifies the model has been compiled (the train_function attribute only exists after compile), then checks whether self.train_function is None, which it is on the first call. It assembles the inputs, appending K.learning_phase() when the phase is dynamic (exactly the tensor the trailing 1./0. flags feed). It then opens a TensorFlow name scope and calls self.optimizer.get_updates(), which takes the trainable weights and the loss and returns the update ops that compute and apply the gradients. Finally it calls K.function, inside which the session handling and the feed_dict plumbing live.

def _make_train_function(self):
    if not hasattr(self, 'train_function'):
        raise RuntimeError('You must compile your model before using it.')
    self._check_trainable_weights_consistency()
    if self.train_function is None:
        inputs = (self._feed_inputs +
                  self._feed_targets +
                  self._feed_sample_weights)
        if self._uses_dynamic_learning_phase():
            inputs += [K.learning_phase()]

        with K.name_scope('training'):
            with K.name_scope(self.optimizer.__class__.__name__):
                training_updates = self.optimizer.get_updates(
                    params=self._collected_trainable_weights,
                    loss=self.total_loss)
            updates = (self.updates +
                       training_updates +
                       self.metrics_updates)
            # Gets loss and metrics. Updates weights at each call.
            self.train_function = K.function(
                inputs,
                [self.total_loss] + self.metrics_tensors,
                updates=updates,
                name='train_function',
                **self._function_kwargs)

Now let's look at K.function itself; internally it just constructs an instance of the class Function:

def function(inputs, outputs, updates=None, **kwargs):
    if kwargs:
        for key in kwargs:
            session_has_key = has_arg(tf.Session.run, key, True)
            function_has_key = has_arg(Function.__init__, key, True)
            if not (session_has_key or function_has_key):
                raise ValueError('Invalid argument "%s" passed to K.function '
                                 'with TensorFlow backend' % key)
    return Function(inputs, outputs, updates=updates, **kwargs)
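
As a quick aside, here is a minimal standalone use of K.function (a toy sketch, assuming the TensorFlow 1.x backend of Keras 2.x):

import numpy as np
from keras import backend as K

x = K.placeholder(shape=(None, 3))
double = K.function([x], [2 * x])       # builds a Function; __call__ runs the session
print(double([np.ones((2, 3))])[0])     # -> [[2. 2. 2.], [2. 2. 2.]]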

The constructor of this class pops the feed_dict argument, the session run options (options / run_metadata, e.g. for tracing a session run), and related settings out of session_kwargs:

self.feed_dict = session_kwargs.pop('feed_dict', {})
# additional operations
self.fetches = session_kwargs.pop('fetches', [])
if not isinstance(self.fetches, list):
    self.fetches = [self.fetches]
# The main use case of `fetches` being passed to a model is the ability
# to run custom updates
# (since the outputs of fetches are never returned).
# This requires us to wrap fetches in `identity` ops.
self.fetches = [tf.identity(x) for x in self.fetches]
# self.session_kwargs is used for _legacy_call
self.session_kwargs = session_kwargs.copy()
self.run_options = session_kwargs.pop('options', None)
self.run_metadata = session_kwargs.pop('run_metadata', None)
if session_kwargs:
    raise ValueError('Some keys in session_kwargs are not '
                     'supported at this '
                     'time: %s', session_kwargs.keys())
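
To see these session kwargs in action, here is a hedged sketch (again assuming the TF 1.x backend) that threads a custom op through fetches; its output is never returned, it simply runs alongside each call:

import numpy as np
import tensorflow as tf
from keras import backend as K

x = K.placeholder(shape=(None, 2))
counter = tf.Variable(0, dtype=tf.int64)
increment = tf.assign_add(counter, 1)     # custom update op, result not returned

f = K.function([x], [x * 2], fetches=[increment])
K.get_session().run(tf.global_variables_initializer())
f([np.ones((1, 2))])
print(K.get_session().run(counter))       # -> 1: the fetch ran with the call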

The Function class also has a __call__ method (yes, the most used call in the Keras codebase), which ultimately delegates to _call(). Inside _call(), the session is obtained, the stored feed_dict values are gathered, and a session callable is (re)built via _make_callable whenever anything has changed; invoking that callable (self._callable_fn) is what actually runs the graph.

def _call(self, inputs):
    if not isinstance(inputs, (list, tuple)):
        raise TypeError('`inputs` should be a list or tuple.')

    session = get_session()
    feed_arrays = []
    array_vals = []
    feed_symbols = []
    symbol_vals = []
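    # (elided here: the loop over zip(self.inputs, inputs) that sorts each
    #  value into feed_arrays/array_vals or feed_symbols/symbol_vals)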
    
    if self.feed_dict:
        for key in sorted(self.feed_dict.keys()):
            array_vals.append(
                np.asarray(self.feed_dict[key],
                           dtype=tf.as_dtype(key.dtype).as_numpy_dtype))

    # Refresh callable if anything has changed.
    if (self._callable_fn is None or
            feed_arrays != self._feed_arrays or
            symbol_vals != self._symbol_vals or
            feed_symbols != self._feed_symbols or
            session != self._session):
        self._make_callable(feed_arrays,
                            feed_symbols,
                            symbol_vals,
                            session)
    if self.run_metadata:
        fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
    else:
        fetched = self._callable_fn(*array_vals)
    return fetched[:len(self.outputs)]

So where is this __call__ actually invoked? Back at the outermost level of fit: once fit_function is assigned, val_function is of course set up next.

# Prepare display labels.
out_labels = self.metrics_names

if do_validation:
    self._make_test_function()
    val_function = self.test_function
    callback_metrics = copy.copy(out_labels) + [
        'val_' + n for n in out_labels]
else:
    callback_metrics = copy.copy(out_labels)
    val_function = None
    val_inputs = []

Finally, fit returns the result of fit_loop(), and it is inside fit_loop that __call__ ends up being invoked.

return training_arrays.fit_loop(self, fit_function, fit_inputs,
                                out_labels=out_labels,
                                batch_size=batch_size,
                                epochs=epochs,
                                verbose=verbose,
                                callbacks=callbacks,
                                val_function=val_function,
                                val_inputs=val_inputs,
                                shuffle=shuffle,
                                callback_metrics=callback_metrics,
                                initial_epoch=initial_epoch,
                                steps_per_epoch=steps_per_epoch,
                                validation_steps=validation_steps,
                                validation_freq=validation_freq)

Now for the fit_loop method. It first sanity-checks the validation configuration so it can be relied on later:

    do_validation = False
    if val_function and val_inputs:
        do_validation = True
        if (verbose and fit_inputs and
           hasattr(fit_inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
            print('Train on %d samples, validate on %d samples' %
                  (fit_inputs[0].shape[0], val_inputs[0].shape[0]))
    if validation_steps:
        do_validation = True
        if steps_per_epoch is None:
            raise ValueError('Can only use `validation_steps` '
                             'when doing step-wise '
                             'training, i.e. `steps_per_epoch` '
                             'must be set.')
    elif do_validation:
        if steps_per_epoch:
            raise ValueError('Must specify `validation_steps` '
                             'to perform validation '
                             'when doing step-wise training.')

It then computes the number of training samples and turns them into an index array:

    num_train_samples = check_num_samples(fit_inputs,
                                          batch_size=batch_size,
                                          steps=steps_per_epoch,
                                          steps_name='steps_per_epoch')
    if num_train_samples is not None:
        index_array = np.arange(num_train_samples)

Next come the callback preparations: the History and logging objects are created and combined with the user-defined callbacks into a single callback list. One line here is important: callbacks._call_begin_hook('train'), which uniformly invokes the on_train_begin method of every callback, i.e. fires the training-begin hooks.

    model.history = cbks.History()
    _callbacks = [cbks.BaseLogger(
        stateful_metrics=model.stateful_metric_names)]
    if verbose:
        if steps_per_epoch is not None:
            count_mode = 'steps'
        else:
            count_mode = 'samples'
        _callbacks.append(
            cbks.ProgbarLogger(
                count_mode,
                stateful_metrics=model.stateful_metric_names))
    _callbacks += (callbacks or []) + [model.history]
    callbacks = cbks.CallbackList(_callbacks)
    out_labels = out_labels or []

    # it's possible to callback a different model than itself
    # (used by Sequential models)
    callback_model = model._get_callback_model()

    callbacks.set_model(callback_model)
    callbacks.set_params({
        'batch_size': batch_size,
        'epochs': epochs,
        'steps': steps_per_epoch,
        'samples': num_train_samples,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics or [],
    })
    callbacks._call_begin_hook('train')
    callbacks.model.stop_training = False
    for cbk in callbacks:
        cbk.validation_data = val_inputs
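
Since everything is funneled through these hooks, writing a custom callback is just a matter of overriding them; a minimal sketch:

from keras.callbacks import Callback

class VerboseHooks(Callback):
    def on_train_begin(self, logs=None):
        # fired by callbacks._call_begin_hook('train') above
        print('train begin; params:', self.params)

    def on_epoch_begin(self, epoch, logs=None):
        print('epoch %d begin' % epoch)

    def on_batch_end(self, batch, logs=None):
        # logs carries the per-batch loss/metrics filled in by fit_loop
        print('batch %d loss: %.4f' % (batch, logs.get('loss', 0.)))

# model.fit(x, y, callbacks=[VerboseHooks()])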

Everything above is pre-training preparation; now the training proper begins. I'll break the code into smaller pieces here.

First, loop over the epochs and fire every callback's on_epoch_begin method:

    for epoch in range(initial_epoch, epochs):
        # Reset stateful metrics
        for m in model.stateful_metric_functions:
            m.reset_states()
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}

Next comes a branch on whether the steps_per_epoch parameter was passed. If it is not None, the loop iterates directly over the steps, fires every callback's on_train_batch_begin hook, and then calls fit_function, i.e. executes Function's __call__(), so the session we set up earlier now actually runs. The resulting loss values are stored in the batch logs, which are passed to on_train_batch_end; once that hook returns, the time a batch took is known, and if it is excessive a warning is printed telling the user that a callback is doing something too expensive. Finally, if should_run_validation says this epoch is due (sketched after the block below), test_loop is called to evaluate the validation set; we won't expand on it here.

        if steps_per_epoch is not None:
            for step_index in range(steps_per_epoch):
                batch_logs = {'batch': step_index, 'size': 1}
                callbacks._call_batch_hook('train', 'begin', step_index, batch_logs)
                outs = fit_function(fit_inputs)

                outs = to_list(outs)
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks._call_batch_hook('train', 'end', step_index, batch_logs)
                if callback_model.stop_training:
                    break

            if do_validation and should_run_validation(validation_freq, epoch):
                val_outs = test_loop(model, val_function, val_inputs,
                                     steps=validation_steps,
                                     callbacks=callbacks,
                                     verbose=0)
                val_outs = to_list(val_outs)
                # Same labels assumed.
                for l, o in zip(out_labels, val_outs):
                    epoch_logs['val_' + l] = o
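
Whether validation runs on a given epoch is decided by should_run_validation. A sketch of its semantics, mirroring the Keras helper (validation_freq may be an int or a list of 1-indexed epoch numbers):

def should_run_validation(validation_freq, epoch):
    one_indexed = epoch + 1             # `epoch` in the loop above is 0-indexed
    if isinstance(validation_freq, int):
        return one_indexed % validation_freq == 0
    return one_indexed in validation_freq

# validation_freq=2      -> validate after epochs 2, 4, 6, ...
# validation_freq=[1, 5] -> validate only after epochs 1 and 5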

The other branch covers the case where no step count was defined. The code is broadly similar to the branch above; the main difference is how data is fetched. The branch above could iterate directly because the steps were given, whereas this branch first shuffles the data, then calls make_batches (sketched after the block below) to derive each batch's start and end indices from the sample count and batch_size, slices out the training data with those indices, and runs on_train_batch_begin -> fit_function -> on_train_batch_end. After the epoch's last batch, the validation set is evaluated once.

        else:
            if shuffle == 'batch':
                index_array = batch_shuffle(index_array, batch_size)
            elif shuffle:
                np.random.shuffle(index_array)

            batches = make_batches(num_train_samples, batch_size)
            for batch_index, (batch_start, batch_end) in enumerate(batches):
                batch_ids = index_array[batch_start:batch_end]
                try:
                    if isinstance(fit_inputs[-1], float):
                        # Do not slice the training phase flag.
                        ins_batch = slice_arrays(
                            fit_inputs[:-1], batch_ids) + [fit_inputs[-1]]
                    else:
                        ins_batch = slice_arrays(fit_inputs, batch_ids)
                except TypeError:
                    raise TypeError('TypeError while preparing batch. '
                                    'If using HDF5 input data, '
                                    'pass shuffle="batch".')
                batch_logs = {'batch': batch_index, 'size': len(batch_ids)}
                callbacks._call_batch_hook('train', 'begin', batch_index, batch_logs)
                for i in indices_for_conversion_to_dense:
                    ins_batch[i] = ins_batch[i].toarray()

                outs = fit_function(ins_batch)
                outs = to_list(outs)
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o

                callbacks._call_batch_hook('train', 'end', batch_index, batch_logs)
                if callbacks.model.stop_training:
                    break

            if batch_index == len(batches) - 1:  # Last batch.
                if do_validation and should_run_validation(validation_freq, epoch):
                    val_outs = test_loop(model, val_function, val_inputs,
                                         batch_size=batch_size,
                                         callbacks=callbacks,
                                         verbose=0)
                    val_outs = to_list(val_outs)
                    # Same labels assumed.
                    for l, o in zip(out_labels, val_outs):
                        epoch_logs['val_' + l] = o
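
The index arithmetic of make_batches is simple enough to sketch inline; this mirrors keras.engine.training_utils.make_batches:

import numpy as np

def make_batches(size, batch_size):
    # Round up so the last, possibly smaller batch is included.
    num_batches = (size + batch_size - 1) // batch_size
    return [(i * batch_size, min(size, (i + 1) * batch_size))
            for i in range(num_batches)]

index_array = np.arange(10)
np.random.shuffle(index_array)
for batch_start, batch_end in make_batches(10, 4):
    batch_ids = index_array[batch_start:batch_end]   # 4, 4, then 2 shuffled ids
    print(batch_ids)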

Finally, on_epoch_end and the train-end hook are fired, and the history is returned.

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callbacks.model.stop_training:
            break
    callbacks._call_end_hook('train')
    return model.history

Summary

To really understand Keras's fit, I think there are two keys. First, Keras funnels everything, logging and bookkeeping included, through callback objects invoked as hooks, which keeps the code elegant. Second, fit_function runs the session through Function's __call__() method, hiding the session so neatly that you never feel its presence, and the overall training flow carries no redundant code. The seemingly simple fit function embodies a remarkable amount of craft; in my view the Keras authors have reached an impressive level of mastery, and I can only thank them for giving us such a simple and efficient API. Respect.
