【keras原理解析】Keras神经网络运行源码深入解析

model.fit(X_train,y_train,batch_size=BATCH_SIZE,nb_epoch=1,validation_data=(X_val,y_val))

以上是keras进行model训练的fit代码,它真正的实现流程是怎样的呢?

以上最终调用的是training.Model.fit()方法,在fit方法主要进行步骤如下:

  1. 模型参数的处理,验证数据的合法性相关的准备工作
  2. 准备好模型的输入数据和训练相关的函数

以上准备工作最好后,将后续的工作delegate委托给training_arrays.fit_loop()方法,撇开数据的处理、准备,训练的主要代码是这段循环,非常关键:

callbacks.set_model(callback_model)
    callbacks.set_params({
        'batch_size': batch_size,
        'epochs': epochs,
        'steps': steps_per_epoch,
        'samples': num_train_samples,
        'verbose': verbose,
        'do_validation': do_validation,
        'metrics': callback_metrics or [],
    })
callbacks.on_train_begin() 
for epoch in range(initial_epoch, nb_epoch):
        # 记录本回epoch的历史信息
        callbacks.on_epoch_begin(epoch)
        # 按照batch批次打混索引
        if shuffle == 'batch':
            index_array = batch_shuffle(index_array, batch_size)
        elif shuffle:
            np.random.shuffle(index_array)
        # 得到一个批次的索引
        batches = make_batches(nb_train_sample, batch_size)
        epoch_logs = {}
        #........
        #省略逻辑见下 部分 
        #........
        callbacks.on_epoch_end(epoch, epoch_logs)
        if callback_model.stop_training:
            break

callbacks.on_train_end()

以上{for epoch in }代码逻辑主要是对每个epoch进行循环,其中核心针对每个batch的处理见下代码:=

            for batch_index, (batch_start, batch_end) in enumerate(batches):
                batch_ids = index_array[batch_start:batch_end]
                try:
                    if isinstance(ins[-1], float):
                        # Do not slice the training phase flag.
                        ins_batch = slice_arrays(
                            ins[:-1], batch_ids) + [ins[-1]]
                    else:
                        ins_batch = slice_arrays(ins, batch_ids)
                except TypeError:
                    raise TypeError('TypeError while preparing batch. '
                                    'If using HDF5 input data, '
                                    'pass shuffle="batch".')
                batch_logs = {}
                batch_logs['batch'] = batch_index
                batch_logs['size'] = len(batch_ids)
                
                #回调:每个batch的开始处:logs包含size,即当前batch的样本数
                callbacks.on_batch_begin(batch_index, batch_logs)
                for i in indices_for_conversion_to_dense:
                    ins_batch[i] = ins_batch[i].toarray()

                outs = f(ins_batch)
                outs = to_list(outs)
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o
                #回调:batch结束:logs包含loss,若启用accuracy则还包含acc
                callbacks.on_batch_end(batch_index, batch_logs)
                if callback_model.stop_training:
                    break

                if batch_index == len(batches) - 1:  # Last batch.
                    if do_validation:
                        val_outs = test_loop(model, val_f, val_ins,
                                             batch_size=batch_size,
                                             verbose=0)
                        val_outs = to_list(val_outs)
                        # Same labels assumed.
                        for l, o in zip(out_labels, val_outs):
                            epoch_logs['val_' + l] = o

【1、回调函数callback】

以上就是整个fit_loop()函数的调用代码,其中代码关键点都存在回调函数:

  1. on_epoch_begin: 在每个epoch开始时调用
  2. on_epoch_end: 在每个epoch结束时调用
  3. on_batch_begin: 在每个batch开始时调用
  4. on_batch_end: 在每个batch结束时调用
  5. on_train_begin: 在训练开始时调用
  6. on_train_end: 在训练结束时调用

其中回调函数on_batch_end是主要的回调函数,e.gkeras.callbacks.BaseLogger

统计该batch里面训练的loss以及acc的值,计入totals,乘以batch_size后。

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        batch_size = logs.get('size', 0)
        self.seen += batch_size

        for k, v in logs.items():
            if k in self.stateful_metrics:
                self.totals[k] = v
            else:
                if k in self.totals:
                    self.totals[k] += v * batch_size
                else:
                    self.totals[k] = v * batch_size

其中回调函数on_epoch_end,e.gkeras.callbacks.BaseLogger

这个类的on_epoch_end函数里,执行对这个epoch训练数据的loss以及acc求平均值。

    def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            for k in self.params['metrics']:
                if k in self.totals:
                    # Make value available to next callbacks.
                    if k in self.stateful_metrics:
                        logs[k] = self.totals[k]
                    else:
                        logs[k] = self.totals[k] / self.seen

补充:

keras.callbacks.ModelCheckpoint

在on_epoch_end时会保存模型数据进入文件

keras.callbacks.History

主要记录每一次epoch训练的结果,结果包含loss以及acc的值

keras.callbacks.ProgbarLogger

这个函数里面实现训练中间状态数据信息的输出,主要涉及进度相关信息。

【2、outs = f(ins_batch)】

其中函数f()是作为参数传递进入,通过debug我们进行调试,发现直接是进入了Keras后端,进行处理,这样符合keras是基于tf进行的二次封装这前提,这而就是调用不同后端引擎的函数。

经过部分数据检验后,进入到tensorflow_backend.Function._call进行真正的tf操作,其中Function类就是提供众多Tensorflow中运算图的工具。

    def _call(self, inputs):
        if not isinstance(inputs, (list, tuple)):
            raise TypeError('`inputs` should be a list or tuple.')

        session = get_session()
        feed_arrays = []
        array_vals = []
        feed_symbols = []
        symbol_vals = []
        #数据处理转换
        for tensor, value in zip(self.inputs, inputs):
            if value is None:
                continue
            if is_tensor(value):
                # Case: feeding symbolic tensor.
                feed_symbols.append(tensor)
                symbol_vals.append(value)
            else:
                feed_arrays.append(tensor)
                # We need to do array conversion and type casting
                # at this level, since
                # `callable_fn` only supports exact matches.
                array_vals.append(
                    np.asarray(value,
                               dtype=tf.as_dtype(tensor.dtype).as_numpy_dtype))
        if self.feed_dict:
            for key in sorted(self.feed_dict.keys()):
                array_vals.append(
                    np.asarray(self.feed_dict[key],
                               dtype=tf.as_dtype(key.dtype).as_numpy_dtype))

        # Refresh callable if anything has changed.
        if (self._callable_fn is None or
                feed_arrays != self._feed_arrays or
                symbol_vals != self._symbol_vals or
                feed_symbols != self._feed_symbols or
                session != self._session):
            #生成一个可以调用的graph
            self._make_callable(feed_arrays,
                                feed_symbols,
                                symbol_vals,
                                session)
        #运行graph
        if self.run_metadata:
            fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
        else:
            fetched = self._callable_fn(*array_vals)
        #返回结果
        return fetched[:len(self.outputs)]

总结:

1、Keras调用tf进行计算,是分batch进行操作,每个batch结束keras可以对返回进行相应的存储等操作。

 

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值