model.fit(X_train, y_train, batch_size=BATCH_SIZE, nb_epoch=1, validation_data=(X_val, y_val))
The line above is a typical Keras `fit()` call for training a model. What does it actually do under the hood?
It ultimately invokes `training.Model.fit()`, which performs the following steps:
- process the model's parameters and validate the provided data, plus related preparation work
- prepare the model's input data and the training function
Once this preparation is done, the remaining work is delegated to `training_arrays.fit_loop()`. Setting aside the data handling and preparation, the heart of training is this loop, which is the key piece:
callbacks.set_model(callback_model)
callbacks.set_params({
    'batch_size': batch_size,
    'epochs': epochs,
    'steps': steps_per_epoch,
    'samples': num_train_samples,
    'verbose': verbose,
    'do_validation': do_validation,
    'metrics': callback_metrics or [],
})
callbacks.on_train_begin()
for epoch in range(initial_epoch, nb_epoch):
    # Record the history for this epoch.
    callbacks.on_epoch_begin(epoch)
    # Shuffle the index array, batch-wise if requested.
    if shuffle == 'batch':
        index_array = batch_shuffle(index_array, batch_size)
    elif shuffle:
        np.random.shuffle(index_array)
    # Compute the (start, end) index ranges for each batch.
    batches = make_batches(nb_train_sample, batch_size)
    epoch_logs = {}
    # ........
    # (per-batch logic omitted here; shown below)
    # ........
    callbacks.on_epoch_end(epoch, epoch_logs)
    if callback_model.stop_training:
        break
callbacks.on_train_end()
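The batching step above is plain index arithmetic. The real `make_batches` and `slice_arrays` live in Keras's utility modules; the simplified re-implementations below are for illustration only:

```python
def make_batches(size, batch_size):
    """Return a list of (start, end) index tuples covering `size` samples."""
    num_batches = (size + batch_size - 1) // batch_size  # ceiling division
    return [(i * batch_size, min(size, (i + 1) * batch_size))
            for i in range(num_batches)]

def slice_arrays(arrays, ids):
    """Select the rows given by `ids` from each array in `arrays`."""
    return [[array[i] for i in ids] for array in arrays]

# 10 samples with batch_size=4 yields three batches, the last one partial:
print(make_batches(10, 4))  # [(0, 4), (4, 8), (8, 10)]
```

Note that the last batch is smaller whenever `size` is not a multiple of `batch_size`; this is why BaseLogger (below) weights metrics by the actual batch size.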
The `for epoch in ...` loop above iterates over epochs; the core per-batch handling (elided above) is the following:
for batch_index, (batch_start, batch_end) in enumerate(batches):
    batch_ids = index_array[batch_start:batch_end]
    try:
        if isinstance(ins[-1], float):
            # Do not slice the training phase flag.
            ins_batch = slice_arrays(
                ins[:-1], batch_ids) + [ins[-1]]
        else:
            ins_batch = slice_arrays(ins, batch_ids)
    except TypeError:
        raise TypeError('TypeError while preparing batch. '
                        'If using HDF5 input data, '
                        'pass shuffle="batch".')
    batch_logs = {}
    batch_logs['batch'] = batch_index
    batch_logs['size'] = len(batch_ids)
    # Callback: at the start of each batch; logs contains 'size',
    # the number of samples in the current batch.
    callbacks.on_batch_begin(batch_index, batch_logs)
    for i in indices_for_conversion_to_dense:
        ins_batch[i] = ins_batch[i].toarray()
    outs = f(ins_batch)
    outs = to_list(outs)
    for l, o in zip(out_labels, outs):
        batch_logs[l] = o
    # Callback: at the end of each batch; logs contains 'loss',
    # and 'acc' as well if accuracy is being tracked.
    callbacks.on_batch_end(batch_index, batch_logs)
    if callback_model.stop_training:
        break
    if batch_index == len(batches) - 1:  # Last batch.
        if do_validation:
            val_outs = test_loop(model, val_f, val_ins,
                                 batch_size=batch_size,
                                 verbose=0)
            val_outs = to_list(val_outs)
            # Same labels assumed.
            for l, o in zip(out_labels, val_outs):
                epoch_logs['val_' + l] = o
【1. Callbacks】
The code above is the full body of `fit_loop()`; at every key point it invokes callbacks:
- on_epoch_begin: called at the start of each epoch
- on_epoch_end: called at the end of each epoch
- on_batch_begin: called at the start of each batch
- on_batch_end: called at the end of each batch
- on_train_begin: called when training starts
- on_train_end: called when training ends
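The dispatch pattern behind these hooks can be sketched in plain Python. The class names below are illustrative, not Keras's actual implementation, but the shape is the same: a base class with no-op hooks, a list wrapper that fans each event out to every registered callback:

```python
class Callback:
    """Base class: every hook is a no-op unless overridden."""
    def on_train_begin(self, logs=None): pass
    def on_train_end(self, logs=None): pass
    def on_epoch_begin(self, epoch, logs=None): pass
    def on_epoch_end(self, epoch, logs=None): pass
    def on_batch_begin(self, batch, logs=None): pass
    def on_batch_end(self, batch, logs=None): pass

class CallbackList:
    """Fans each event out to every registered callback, in order."""
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)
    def on_epoch_end(self, epoch, logs=None):
        for cb in self.callbacks:
            cb.on_epoch_end(epoch, logs or {})

class LossRecorder(Callback):
    """Example callback: remembers the loss reported at each epoch end."""
    def __init__(self):
        self.losses = []
    def on_epoch_end(self, epoch, logs=None):
        self.losses.append((logs or {}).get('loss'))

recorder = LossRecorder()
callbacks = CallbackList([recorder])
callbacks.on_epoch_end(0, {'loss': 0.9})
callbacks.on_epoch_end(1, {'loss': 0.5})
print(recorder.losses)  # [0.9, 0.5]
```

Because hooks default to no-ops, a custom callback only overrides the events it cares about.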
Among these, on_batch_end is the main workhorse. For example, keras.callbacks.BaseLogger accumulates the batch's loss and acc values into `totals`, weighting each value by the batch size:
def on_batch_end(self, batch, logs=None):
    logs = logs or {}
    batch_size = logs.get('size', 0)
    self.seen += batch_size
    for k, v in logs.items():
        if k in self.stateful_metrics:
            self.totals[k] = v
        else:
            if k in self.totals:
                self.totals[k] += v * batch_size
            else:
                self.totals[k] = v * batch_size
The same class's on_epoch_end hook (keras.callbacks.BaseLogger) then divides the accumulated totals by the number of samples seen, producing the epoch's average loss and acc:
def on_epoch_end(self, epoch, logs=None):
    if logs is not None:
        for k in self.params['metrics']:
            if k in self.totals:
                # Make value available to next callbacks.
                if k in self.stateful_metrics:
                    logs[k] = self.totals[k]
                else:
                    logs[k] = self.totals[k] / self.seen
Other notable callbacks:
- keras.callbacks.ModelCheckpoint: saves the model to a file in on_epoch_end
- keras.callbacks.History: records the result of every epoch, including the loss and acc values
- keras.callbacks.ProgbarLogger: prints intermediate training status, mainly progress information
【2. outs = f(ins_batch)】
The function f() is passed in as a parameter. Stepping through with a debugger shows that the call goes straight into the Keras backend, which is consistent with Keras being a wrapper over TensorFlow (or another engine): this is the point where the backend-specific function is invoked.
After some validation of the inputs, execution reaches tensorflow_backend.Function._call, where the actual TensorFlow work happens; the Function class wraps the machinery for running TensorFlow computation graphs.
def _call(self, inputs):
    if not isinstance(inputs, (list, tuple)):
        raise TypeError('`inputs` should be a list or tuple.')
    session = get_session()
    feed_arrays = []
    array_vals = []
    feed_symbols = []
    symbol_vals = []
    # Convert the inputs, sorting them into symbolic and array feeds.
    for tensor, value in zip(self.inputs, inputs):
        if value is None:
            continue
        if is_tensor(value):
            # Case: feeding symbolic tensor.
            feed_symbols.append(tensor)
            symbol_vals.append(value)
        else:
            feed_arrays.append(tensor)
            # We need to do array conversion and type casting
            # at this level, since
            # `callable_fn` only supports exact matches.
            array_vals.append(
                np.asarray(value,
                           dtype=tf.as_dtype(tensor.dtype).as_numpy_dtype))
    if self.feed_dict:
        for key in sorted(self.feed_dict.keys()):
            array_vals.append(
                np.asarray(self.feed_dict[key],
                           dtype=tf.as_dtype(key.dtype).as_numpy_dtype))
    # Refresh callable if anything has changed.
    if (self._callable_fn is None or
            feed_arrays != self._feed_arrays or
            symbol_vals != self._symbol_vals or
            feed_symbols != self._feed_symbols or
            session != self._session):
        # Build a callable that runs the graph.
        self._make_callable(feed_arrays,
                            feed_symbols,
                            symbol_vals,
                            session)
    # Run the graph.
    if self.run_metadata:
        fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
    else:
        fetched = self._callable_fn(*array_vals)
    # Return the fetched outputs.
    return fetched[:len(self.outputs)]
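The "refresh callable if anything has changed" step is a memoisation pattern: rebuilding the compiled graph callable is expensive, so it is done only when the feed signature differs from the previous call. A backend-agnostic sketch of that pattern (illustrative, not the real tensorflow_backend code):

```python
class CachedFunction:
    """Rebuilds an expensive callable only when the call signature changes."""
    def __init__(self, build):
        self._build = build      # factory: signature -> callable
        self._signature = None
        self._fn = None
        self.rebuilds = 0        # counter, for demonstration

    def __call__(self, *args):
        signature = tuple(type(a).__name__ for a in args)
        if self._fn is None or signature != self._signature:
            self._fn = self._build(signature)  # expensive step
            self._signature = signature
            self.rebuilds += 1
        return self._fn(*args)

f = CachedFunction(lambda sig: (lambda *a: sum(a)))
f(1, 2)      # first call: builds the callable
f(3, 4)      # same signature (int, int): reuses the cached callable
f(1.0, 2.0)  # signature changed (float, float): rebuilds
print(f.rebuilds)  # 2
```

In the real _call, the "signature" is the tuple of feed tensors, symbolic values, and the session, and the factory is self._make_callable.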
Summary:
1. Keras drives TensorFlow batch by batch: each batch is computed as one backend call, and at the end of every batch Keras can store, log, or otherwise act on the returned values through its callbacks.