前面一篇博客《MXNet的训练基础脚本:base_module.py》介绍了MXNet中很重要的一个脚本base_module.py,其中定义了基础的BaseModule类;本篇博客要介绍的module.py脚本中的Module类继承自BaseModule类,并对其中的方法给出了具体实现,比如forward方法、update_metric方法等。
创建模型时用的是mx.mod.Module这个类,比如:model = mx.mod.Module(symbol=sym)。那么这个Module类具体包含哪些内容呢?其实mx.mod.Module的完整写法是mxnet.module.Module,该类定义在~/mxnet/python/mxnet/module/module.py中。下面主要看看其中各个方法是如何实现的。
代码的git地址:
https://github.com/dmlc/mxnet/blob/master/python/mxnet/module/module.py
module.py脚本内容如下:
# pylint: disable=too-many-instance-attributes, too-many-arguments, protected-access, too-many-branches
# pylint: disable=too-many-public-methods
"""A `Module` implement the `BaseModule` API by wrapping a `Symbol` and one or
more `Executor` for data parallelization.
"""
import logging
import warnings
from .. import context as ctx
from .. import ndarray as nd
from .. import optimizer as opt
from .executor_group import DataParallelExecutorGroup
from ..model import _create_kvstore, _initialize_kvstore, _update_params, _update_params_on_kvstore
from ..model import load_checkpoint
from ..initializer import Uniform, InitDesc
from ..io import DataDesc
from .base_module import BaseModule, _check_input_names, _parse_data_desc
class Module(BaseModule):
"""Module is a basic module that wrap a `Symbol`. It is functionally the same
as the `FeedForward` model, except under the module API.
Parameters
----------
symbol : Symbol
data_names : list of str
Defaults to `('data')` for a typical model used in image classification.
label_names : list of str
Defaults to `('softmax_label')` for a typical model used in image
classification.
logger : Logger
Defaults to `logging`.
context : Context or list of Context
Defaults to ``mx.cpu()``.
work_load_list : list of number
Default ``None``, indicating uniform workload.
fixed_param_names: list of str
Default ``None``, indicating no network parameters are fixed.
state_names : list of str
states are similar to data and label, but not provided by data iterator.
Instead they are initialized to 0 and can be set by `set_states()`.
"""
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
             logger=logging, context=ctx.cpu(), work_load_list=None,
             fixed_param_names=None, state_names=None):
    """Wrap `symbol` and store the configuration needed to use it later.

    Runs when a module is created via ``mx.mod.Module(...)``. The constructor
    does no binding itself. It does four things:

    * normalizes its arguments;
    * checks the input names against `symbol`;
    * derives the parameter-name list;
    * resets every parameter/optimizer/executor slot to an empty value.

    All of this is kept as instance attributes so the other methods can
    reach it without passing it around.

    Parameters
    ----------
    symbol : Symbol
        The network to wrap.
    data_names, label_names, state_names, fixed_param_names : iterable of str or None
        Each is copied into a new list; ``None`` becomes ``[]``.
    logger : Logger
        Passed through to ``BaseModule.__init__``.
    context : Context or list of Context
        A single ``Context`` is wrapped into a one-element list; anything
        else is stored as given.
    work_load_list : list of number or None
        ``None`` means one unit of work per context. An explicit list must
        have exactly one entry per context; otherwise ``AssertionError`` is
        raised (the check is an ``assert``, so it is skipped under ``-O``).

    Notes
    -----
    The default ``context=ctx.cpu()`` is evaluated once, when the function
    is defined. Every instance built with the default therefore shares that
    Context object.
    """
    super(Module, self).__init__(logger=logger)

    # A bare Context means "one device". Wrap it so the device setting is
    # always a sequence of contexts.
    devices = [context] if isinstance(context, ctx.Context) else context
    self._context = devices

    # Uniform workload unless the caller says otherwise; one weight per device.
    workloads = work_load_list if work_load_list is not None else [1] * len(devices)
    assert len(workloads) == len(devices)
    self._work_load_list = workloads

    self._symbol = symbol

    def _listify(names):
        """Return `names` as a fresh list, mapping None to []."""
        return [] if names is None else list(names)

    data_names = _listify(data_names)
    label_names = _listify(label_names)
    state_names = _listify(state_names)
    fixed_param_names = _listify(fixed_param_names)

    # Check that every given name is an argument of `symbol`.
    # _check_input_names is a plain function in base_module, not a
    # BaseModule method, hence the explicit import at the top of the file.
    # Labels pass False as the last flag, presumably a lenient
    # (non-raising) check; confirm in base_module.py.
    for names, typename, strict in ((data_names, "data", True),
                                    (label_names, "label", False),
                                    (state_names, "state", True),
                                    (fixed_param_names, "fixed_param", True)):
        _check_input_names(symbol, names, typename, strict)

    # Parameter names are the symbol's arguments that are not fed as inputs
    # (data, label or state), kept in the symbol's own order.
    fed_inputs = data_names + label_names + state_names
    self._param_names = [arg for arg in symbol.list_arguments()
                         if arg not in fed_inputs]
    self._fixed_param_names = fixed_param_names
    self._aux_names = symbol.list_auxiliary_states()
    self._data_names = data_names
    self._label_names = label_names
    self._state_names = state_names
    self._output_names = symbol.list_outputs()

    # Everything below starts empty. `load()` pre-fills `_arg_params`,
    # `_aux_params` and (optionally) `_preload_opt_states`.
    self._arg_params = self._aux_params = None
    self._params_dirty = False
    self._optimizer = self._kvstore = self._update_on_kvstore = None
    self._updater = self._preload_opt_states = None
    self._grad_req = self._exec_group = None
    self._data_shapes = self._label_shapes = None
@staticmethod
def load(prefix, epoch, load_optimizer_states=False, **kwargs):
    """Create a `Module` from a previously saved checkpoint.

    Parameters
    ----------
    prefix : str
        Path prefix of the saved model files. The expected files are:

        * "prefix-symbol.json"
        * "prefix-xxxx.params"
        * optionally "prefix-xxxx.states"

        Here xxxx is the epoch number, zero-padded to 4 digits.
    epoch : int
        Epoch to load.
    load_optimizer_states : bool
        Whether to load optimizer states. The checkpoint must have been made
        with ``save_optimizer_states=True``. This method only records the
        path of the ".states" file in ``_preload_opt_states``; it does not
        read the states itself.
    **kwargs
        Forwarded unchanged to the ``Module`` constructor. Accepted keys:

        * ``data_names`` (default ``('data',)``)
        * ``label_names`` (default ``('softmax_label',)``)
        * ``logger`` (default ``logging``)
        * ``context`` (default ``cpu()``)
        * ``work_load_list`` (default ``None``, uniform workload)
        * ``fixed_param_names`` (default ``None``, nothing fixed)
        * ``state_names``

    Returns
    -------
    Module
        A module whose argument/auxiliary parameters come from the
        checkpoint, with ``params_initialized`` set to True. Because this is
        a staticmethod that names ``Module`` directly, the result is always a
        plain ``Module``, even when called through a subclass.
    """
    symbol, arg_params, aux_params = load_checkpoint(prefix, epoch)
    restored = Module(symbol=symbol, **kwargs)
    restored._arg_params, restored._aux_params = arg_params, aux_params
    restored.params_initialized = True
    if load_optimizer_states:
        # Deferred: only the file name is remembered at this point.
        restored._preload_opt_states = '%s-%04d.states' % (prefix, epoch)
    return restored
def save_checkpoint(self, prefix, epoch, save_optimizer_states=False):
"""Saves current progress to checkpoint.