首先进入 module 模块,打开其中的 base_module.py,即可看到 fit() 的函数原型。
class BaseModule(object):
################################################################################
# High Level API
################################################################################
def forward_backward(self, data_batch):
"""A convenient function that calls both ``forward`` and ``backward``."""
self.forward(data_batch, is_train=True)
self.backward()
# 验证集评测
def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
score_end_callback=None,
reset=True, epoch=0, sparse_row_id_fn=None):
"""Runs prediction on ``eval_data`` and evaluates the performance according to
the given ``eval_metric``.
Checkout `Module Tutorial <https://mxnet.apache.org/api/python/tutorials/packages/module/index.html>`_
to see an end-to-end use-case.
Parameters
----------
eval_data : DataIter
Evaluation data to run prediction on.
eval_metric : EvalMetric or list of EvalMetrics
Evaluation metric to use.
num_batch : int
Number of batches to run. Defaults to ``None``, indicating run until the `DataIter`
finishes.
batch_end_callback : function
Could also be a list of functions.
reset : bool
Defaults to ``True``. Indicates whether we should reset `eval_data` before starting
evaluating.
epoch : int
Defaults to 0. For compatibility, this will be passed to callbacks (if any).
During training, this will correspond to the training epoch number.
sparse_row_id_fn : A callback function
The function takes `data_batch` as an input and returns a dict of
str -> NDArray. The resulting dict is used for pulling row_sparse
parameters from the kvstore, where the str key is the name of the param,
and the value is the row id of the param to pull.
Examples
--------
>>> # An example of using score for prediction.
>>> # Evaluate accuracy on val_dataiter
>>> metric = mx.metric.Accuracy()
>>> mod.score(val_dataiter, metric)
>>> mod.score(val_dataiter, ['mse', 'acc'])
"""
assert self.binded and self.params_initialized
# reset验证集
if reset:
eval_data.reset()
if not isinstance(eval_metric, metric.EvalMetric):
eval_metric = metric.create(eval_metric)
eval_metric.reset()
actual_num_batch = 0
# 验证集batch获取
for nbatch, eval_batch in enumerate(eval_data):
if num_batch is not None and nbatch == num_batch:
break
# 模型加载数据集
self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
# 前向传播
self.forward(eval_batch, is_train=False)
# 调用metric列表update函数
if isinstance(eval_batch, list):
self.update_metric(eval_metric, [eb.label for eb in eval_batch], pre_sliced=True)
else:
self.update_metric(eval_metric, eval_batch.label)
# batch结束回调
if batch_end_callback is not None:
batch_end_params = BatchEndParam(epoch=epoch,
nbatch=nbatch,
eval_metric=eval_metric,
locals=locals())
for callback in _as_list(batch_end_callback):
callback(batch_end_params)
actual_num_batch += 1
# 验证集评测结束回调
if score_end_callback:
params = BatchEndParam(epoch=epoch,
nbatch=actual_num_batch,
eval_metric=eval_metric,
locals=locals())
for callback in _as_list(score_end_callback):
callback(params)
# 返回metric列表结果name:value
return eval_metric.get_name_value()
def fit(self, train_data, eval_data=None, eval_metric='acc',
        epoch_end_callback=None, batch_end_callback=None, kvstore='local',
        optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
        eval_end_callback=None,
        eval_batch_end_callback=None, initializer=Uniform(0.01),
        arg_params=None, aux_params=None, allow_missing=False,
        force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
        validation_metric=None, monitor=None, sparse_row_id_fn=None):
    """Trains the module parameters.

    Checkout `Module Tutorial <https://mxnet.apache.org/api/python/tutorials/packages/module/index.html>`_
    to see an end-to-end use-case.

    Parameters
    ----------
    train_data : DataIter
        Train DataIter (training-set data iterator).
    eval_data : DataIter
        If not ``None``, it is used as the validation set and performance is
        evaluated on it after every epoch.
    eval_metric : str or EvalMetric
        Defaults to 'accuracy'. The performance measure displayed during training.
        Other predefined metrics include: 'ce' (CrossEntropy), 'f1', 'mae',
        'mse', 'rmse', 'top_k_accuracy'.
    epoch_end_callback : function or list of functions
        Each callback will be called with the current `epoch`, `symbol`,
        `arg_params` and `aux_params` at the end of every epoch.
    batch_end_callback : function or list of function
        Each callback will be called with a `BatchEndParam` at the end of
        every batch.
    kvstore : str or KVStore
        Controls where gradient aggregation / weight updates run. Defaults to 'local'.
        "device": compute gradients and update weights on GPU;
        "local": update on CPU;
        "dist_device_sync": synchronous distributed training.
    optimizer : str or Optimizer
        The optimizer. Defaults to 'sgd'.
    optimizer_params : dict
        Defaults to (('learning_rate', 0.01),). These are the arguments passed
        to the optimizer's constructor.
    eval_end_callback : function or list of function
        These will be called at the end of each full evaluation, with the
        metrics over the entire evaluation set.
    eval_batch_end_callback : function or list of function
        These will be called at the end of each mini-batch during evaluation.
    initializer : Initializer
        Called to initialize the module parameters when they are not already
        initialized.
    arg_params : dict
        Defaults to ``None``. If not ``None``, it should contain parameters
        from a trained model or loaded from a previously saved checkpoint,
        used instead of `initializer`. `arg_params` has higher priority than
        `initializer`.
    aux_params : dict
        Defaults to ``None``. If not ``None``, used instead of `initializer`
        for the auxiliary states.
    allow_missing : bool
        Defaults to ``False``. Indicates, when `arg_params` and `aux_params`
        are not ``None``, whether missing parameters are allowed. If
        ``allow_missing=True``, the missing parameters are initialized via
        `initializer`.
    force_rebind : bool
        Defaults to ``False``. Whether to force rebinding the executors if
        they are already bound.
    force_init : bool
        Defaults to ``False``. Whether to force initialization even if the
        parameters are already initialized.
    begin_epoch : int
        Defaults to 0. The epoch to start training at; typically, when
        resuming from a checkpoint saved at Epoch[n], this should be n+1.
    num_epoch : int
        Total number of epochs to train. Required.
    validation_metric : str or EvalMetric
        Metric used on `eval_data`; falls back to `eval_metric` when ``None``.
    monitor : Monitor
        If not ``None``, installed on the module to inspect internal outputs.
    sparse_row_id_fn : A callback function
        The function takes `data_batch` as an input and returns a dict of
        str -> NDArray. The resulting dict is used for pulling row_sparse
        parameters from the kvstore, where the str key is the name of the param,
        and the value is the row id of the param to pull.

    Examples
    --------
    >>> # An example of using fit for training.
    >>> # Assume training dataIter and validation dataIter are ready
    >>> # Assume loading a previously checkpointed model
    >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
    >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
    ...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
    ...     arg_params=arg_params, aux_params=aux_params,
    ...     eval_metric='acc', num_epoch=10, begin_epoch=3)
    """
    assert num_epoch is not None, 'please specify number of epochs'
    # Bind the symbol to the training data/label shapes (allocates executors).
    self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
              for_training=True, force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    # Initialize weights; see the parameter docs above for the precedence of
    # arg_params / aux_params over initializer.
    self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                     allow_missing=allow_missing, force_init=force_init)
    # Create and attach the optimizer (tied to the chosen kvstore).
    self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                        optimizer_params=optimizer_params)
    # Fall back to the training metric for validation when none is given.
    if validation_metric is None:
        validation_metric = eval_metric
    # Convert a str-typed eval_metric into a metric.EvalMetric instance.
    if not isinstance(eval_metric, metric.EvalMetric):
        eval_metric = metric.create(eval_metric)
    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        tic = time.time()
        # Reset the metric(s) at the start of every epoch.
        eval_metric.reset()
        # Per-epoch batch counter.
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        # Iterate batches via next(); the following batch is prefetched while
        # the current one is being processed.
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            # Forward pass + backward pass (gradient computation).
            self.forward_backward(data_batch)
            # Apply the optimizer's gradient update to the weights.
            self.update()
            # Feed this batch's labels into the metric(s)' update().
            if isinstance(data_batch, list):
                self.update_metric(eval_metric,
                                   [db.label for db in data_batch],
                                   pre_sliced=True)
            else:
                self.update_metric(eval_metric, data_batch.label)
            # Grab the next batch, if any.
            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
            except StopIteration:
                end_of_batch = True
            if monitor is not None:
                monitor.toc_print()
            # Snapshot the epoch-level name:value results once the iterator is
            # exhausted, so they can be logged below.
            if end_of_batch:
                eval_name_vals = eval_metric.get_global_name_value()
            # End-of-batch callbacks; locals() is deliberately exposed.
            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1
        # one epoch of training is finished
        # Log the epoch's metric results as Train-xxx=xxx.
        for name, val in eval_name_vals:
            self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
        # Log the wall-clock time of the epoch.
        toc = time.time()
        self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))
        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)
        # End-of-epoch callbacks (e.g. checkpointing).
        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)
        #----------------------------------------
        # evaluation on validation set
        # When validation_metric is None this uses the same metric(s) as training.
        if eval_data:
            res = self.score(eval_data, validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback, epoch=epoch)
            #TODO: pull this into default
            # Log the validation results.
            for name, val in res:
                self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)
        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()