[PyTorch] torch.nn.Module Source Code Walkthrough (repost)

The torch.nn.Module class contains as many as 48 functions. It is the base class of every neural network module in PyTorch, and any network model you build yourself is a subclass of it; a small example follows. In this article we will read through this base class together.

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))


Let's start with the two functions __init__ and forward. __init__ mainly initializes the internal state the module needs. forward has no concrete implementation here; each subclass has to implement it, and if a subclass does not, calling it raises NotImplementedError.

The functions cuda and cpu are straightforward. cuda "Moves all model parameters and buffers to the GPU.", and cpu "Moves all model parameters and buffers to the CPU.". Both return the Module itself, and both delegate to _apply.

def cuda(self, device=None):
    return self._apply(lambda t: t.cuda(device))

def cpu(self):
    return self._apply(lambda t: t.cpu())

Next, let's look at _apply. It first loops over all child modules so that the operation is applied to every submodule recursively. The following loop iterates over self._parameters; the helper compute_should_use_set_data decides whether to change the tensor in-place. If the change is in-place, the old tensor's data is simply replaced by the new one; otherwise a new Parameter built from the converted tensor is registered in the self._parameters dictionary. If a parameter param has a gradient param.grad, the same treatment is applied to param.grad. The last loop applies fn to the tensors in self._buffers (for example, migrating them between CPU and GPU) and stores the converted tensors back into self._buffers. Finally the Module itself is returned.

def _apply(self, fn):
    for module in self.children():
        module._apply(fn)

    def compute_should_use_set_data(tensor, tensor_applied):
        # ... (body elided: decides whether the result of fn can be written
        # back in-place via `.data`)
        ...

    for key, param in self._parameters.items():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                self._parameters[key] = Parameter(param_applied, param.requires_grad)

            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)

    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)

    return self

With _apply in place, such operations become very convenient. For example, share_memory simply calls _apply to run share_memory_ on every tensor, which "Moves the underlying storage to shared memory. This is a no-op if the underlying storage is already in shared memory and for CUDA tensors. Tensors in shared memory cannot be resized." In short, it moves the tensors into shared memory.

def share_memory(self):
    return self._apply(lambda t: t.share_memory_())
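
A quick way to see the effect is to check Tensor.is_shared() before and after the call; a minimal sketch:

import torch.nn as nn

model = nn.Linear(2, 2)
print(any(p.is_shared() for p in model.parameters()))   # False
model.share_memory()
print(all(p.is_shared() for p in model.parameters()))   # True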

Now let's look at apply (not to be confused with _apply above). It is very simple: it passes the Module and every one of its submodules to the given function fn. For example, we can use it to initialize a model's parameters with a method of our choosing.

def apply(self, fn):
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self

The example below sets all the weights of the Linear submodules inside the model net to 1.

Example::
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.data.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)

Next come the functions type, float, double, and half. type casts all parameters and buffers to the given target type dst_type. float, double, and half cast all floating point parameters to float, double, and half precision respectively: torch.Tensor.float gives torch.float32, torch.Tensor.double gives torch.float64, and torch.Tensor.half gives torch.float16.

def type(self, dst_type):
    return self._apply(lambda t: t.type(dst_type))

def float(self):
    return self._apply(lambda t: t.float() if t.is_floating_point() else t)

def double(self):
    return self._apply(lambda t: t.double() if t.is_floating_point() else t)

def half(self):
    return self._apply(lambda t: t.half() if t.is_floating_point() else t)
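
A small sketch of what these conversions do to the parameter dtype (the calls can be chained because each one returns the module itself):

>>> net = nn.Linear(2, 2)
>>> net.weight.dtype
torch.float32
>>> net.double().weight.dtype
torch.float64
>>> net.half().weight.dtype
torch.float16
>>> net.type(torch.float32).weight.dtype
torch.float32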

The to function modifies the Module in place. It can be called in three forms: to(device=None, dtype=None, non_blocking=False), to(dtype, non_blocking=False), and to(tensor, non_blocking=False). Its usage is shown below.

>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

That covers the first batch of functions.

The state_dict function returns a dictionary containing the whole state of the module; the keys of this dictionary are the names of the parameters and buffers. The source of this function contains a loop that recursively visits every submodule of the Module.

>>> net = torch.nn.Linear(2, 2)
>>> net.state_dict()
OrderedDict([('weight', tensor([[-0.3558,  0.2153],
        [-0.2785,  0.6982]])), ('bias', tensor([ 0.5771, -0.6232]))])
>>> net.state_dict().keys()
odict_keys(['weight', 'bias'])

>>> net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
>>> net.state_dict()
OrderedDict([('0.weight', tensor([[ 0.4792,  0.5772], [ 0.1039, -0.0552]])), 
        ('0.bias', tensor([-0.5175, -0.6469])), 
        ('1.weight', tensor([[-0.5346, -0.0173], [-0.2092,  0.0794]])), 
        ('1.bias', tensor([-0.2150,  0.2323]))])
>>> net.state_dict().keys()
odict_keys(['0.weight', '0.bias', '1.weight', '1.bias'])

load_state_dict does exactly the opposite of state_dict described above: it loads parameters and buffers into the Module and its submodules.
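
A minimal save/load round trip looks like this (the file name checkpoint.pt is just for illustration):

import torch
import torch.nn as nn

net = nn.Linear(2, 2)
torch.save(net.state_dict(), 'checkpoint.pt')

net2 = nn.Linear(2, 2)
net2.load_state_dict(torch.load('checkpoint.pt'))   # copies weight and bias into net2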

The parameters function returns an iterator, so we can walk over a model's parameters with for param in model.parameters(). When using an optimization algorithm, it is model.parameters() that we hand to the optimizer. The functions buffers, children, and modules work the same way; a short usage sketch follows the code below.

def parameters(self, recurse=True):
    for name, param in self.named_parameters(recurse=recurse):
        yield param

def buffers(self, recurse=True):
    for name, buf in self.named_buffers(recurse=recurse):
        yield buf

def children(self):
    for name, module in self.named_children():
        yield module

def modules(self):
    for name, module in self.named_modules():
        yield module
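
As a usage sketch, the parameter iterator is what gets handed to an optimizer, and the other iterators can be looped over in the same way:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for p in model.parameters():
    print(p.shape)       # weight/bias shapes of the two Linear layers
for child in model.children():
    print(child)         # the three immediate submodules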

Correspondingly, there are four named variants: named_parameters, named_buffers, named_children, and named_modules. Each returns an iterator that yields both the name and the member.

def _named_members(self, get_members_fn, prefix='', recurse=True):
    r"""Helper method for yielding various names + members of modules."""
    memo = set()
    modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
    for module_prefix, module in modules:
        members = get_members_fn(module)
        for k, v in members:
            if v is None or v in memo:
                continue
            memo.add(v)
            name = module_prefix + ('.' if module_prefix else '') + k
            yield name, v

def named_parameters(self, prefix='', recurse=True):
    gen = self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem

def named_buffers(self, prefix='', recurse=True):
    gen = self._named_members(
        lambda module: module._buffers.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem

def named_children(self):
    memo = set()
    for name, module in self._modules.items():
        if module is not None and module not in memo:
            memo.add(module)
            yield name, module

def named_modules(self, memo=None, prefix=''):
    if memo is None:
        memo = set()
    if self not in memo:
        memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            for m in module.named_modules(memo, submodule_prefix):
                yield m

That's another batch of functions covered.

The functions train and eval switch the Module and its submodules into training mode and evaluation mode respectively. They only affect particular kinds of modules, for example Dropout and BatchNorm.

def train(self, mode=True): 
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self

def eval(self):
    return self.train(False)
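
For example, Dropout behaves differently depending on the mode; a minimal sketch:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()
print(drop(x))   # roughly half of the entries zeroed, the rest scaled by 2

drop.eval()
print(drop(x))   # identity: all ones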

requires_grad_ sets whether autograd should record operations on self.parameters(); the default is True. zero_grad sets the gradients of self.parameters() to zero.

def requires_grad_(self, requires_grad=True):
    for p in self.parameters():
        p.requires_grad_(requires_grad)
    return self

def zero_grad(self):
    for p in self.parameters():
        if p.grad is not None:
            p.grad.detach_()
            p.grad.zero_()
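
A common use of requires_grad_ is freezing part of a model before fine-tuning, and zero_grad clears whatever gradients have accumulated; a small sketch:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[0].requires_grad_(False)                          # freeze the first layer
print([p.requires_grad for p in model.parameters()])   # [False, False, True, True]

model(torch.randn(3, 4)).sum().backward()
model.zero_grad()                                       # gradients of the trainable parameters are reset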

_get_name, extra_repr, __repr__, and __dir__ all output information about the Module. _get_name returns the class name of the Module. extra_repr is meant to be overridden by subclasses of torch.nn.Module to print module-specific information, as one line or several lines of text; a concrete example is shown below. __repr__ prints the information of the Module and all of its submodules, one item per line. __dir__ collects everything in self.__class__, self.__dict__.keys(), self._parameters.keys(), self._modules.keys(), and self._buffers.keys(), and removes attributes that are not legal Python variable names via key for key in keys if not key[0].isdigit().

def _get_name(self):
    return self.__class__.__name__

def extra_repr(self):
    return ''

def __repr__(self):
    extra_lines = []
    extra_repr = self.extra_repr()
    if extra_repr:  # empty string will be split into list ['']
        extra_lines = extra_repr.split('\n')
    child_lines = []
    for key, module in self._modules.items():
        mod_str = repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append('(' + key + '): ' + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + '('
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(lines) + '\n'
    main_str += ')'
    return main_str 

def __dir__(self):
    module_attrs = dir(self.__class__)
    attrs = list(self.__dict__.keys())
    parameters = list(self._parameters.keys())
    modules = list(self._modules.keys())
    buffers = list(self._buffers.keys())
    keys = module_attrs + attrs + parameters + modules + buffers

    # Eliminate attrs that are not legal Python variable names
    keys = [key for key in keys if not key[0].isdigit()]
    return sorted(keys)

# --------------------------

# torch.nn.Linear -- class Linear(Module)
def extra_repr(self):
    return 'in_features={}, out_features={}, bias={}'.format(
        self.in_features, self.out_features, self.bias is not None
    )

# Example
>>> l = torch.nn.Linear(2, 2)
>>> l.extra_repr()
'in_features=2, out_features=2, bias=True'
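
And printing a module shows how __repr__ stitches extra_repr and the child modules together:

>>> net = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
>>> print(net)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): ReLU()
)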

That's yet another batch of functions covered.

__setstate__ restores the state. If _forward_pre_hooks, _state_dict_hooks, or _load_state_dict_pre_hooks cannot be found in self.__dict__, these three attributes are (re)defined on self as empty OrderedDicts, which supports loading old checkpoints that lack them.

def __setstate__(self, state):
    self.__dict__.update(state)
    # Support loading old checkpoints that don't have the following attrs:
    if '_forward_pre_hooks' not in self.__dict__:
        self._forward_pre_hooks = OrderedDict()
    if '_state_dict_hooks' not in self.__dict__:
        self._state_dict_hooks = OrderedDict()
    if '_load_state_dict_pre_hooks' not in self.__dict__:
        self._load_state_dict_pre_hooks = OrderedDict()

__getattr__ retrieves the member of the Module with the given name. It searches self.__dict__['_parameters'], self.__dict__['_buffers'], and self.__dict__['_modules'] in turn and returns the entry as soon as it is found; if nothing is found, it raises AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name)).

def __getattr__(self, name):
    if '_parameters' in self.__dict__:
        _parameters = self.__dict__['_parameters']
        if name in _parameters:
            return _parameters[name]
    if '_buffers' in self.__dict__:
        _buffers = self.__dict__['_buffers']
        if name in _buffers:
            return _buffers[name]
    if '_modules' in self.__dict__:
        modules = self.__dict__['_modules']
        if name in modules:
            return modules[name]
    raise AttributeError("'{}' object has no attribute '{}'".format(
        type(self).__name__, name))

__setattr__(self, name, value) sets an attribute: it first looks for an existing entry with that name in self.__dict__.get('_parameters'), self.__dict__.get('_buffers'), and self.__dict__.get('_modules'), removes the key-value pair if found, and then re-registers the given name and value in the appropriate dictionary (Parameters go into _parameters, Modules into _modules, and so on). A small sketch of the effect follows.
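
This is why plain attribute assignment inside __init__ is enough to register things; a minimal sketch:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(2, 2)                 # ends up in self._modules via __setattr__
        self.scale = nn.Parameter(torch.ones(1))  # ends up in self._parameters

net = Net()
print(list(net._modules.keys()))      # ['fc']
print(list(net._parameters.keys()))   # ['scale']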

__delattr__ removes the member with the given name from the Module. It searches self._parameters, self._buffers, and self._modules, deleting the entry with del once found; if the name is in none of them, it falls back to object.__delattr__(self, name).

def __delattr__(self, name):
    if name in self._parameters:
        del self._parameters[name]
    elif name in self._buffers:
        del self._buffers[name]
    elif name in self._modules:
        del self._modules[name]
    else:
        object.__delattr__(self, name)

_save_to_state_dict(self, destination, prefix, keep_vars) saves the state of this module alone into destination; it is called on every submodule in torch.nn.Module.state_dict. _load_from_state_dict does the opposite: it loads the state of this module alone, and it is likewise called on every submodule in torch.nn.Module.load_state_dict. The prefix argument is the prefix prepended to the names of this Module's parameters and buffers.

def _save_to_state_dict(self, destination, prefix, keep_vars):
    for name, param in self._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.data
    for name, buf in self._buffers.items():
        if buf is not None:
            destination[prefix + name] = buf if keep_vars else buf.data

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    for hook in self._load_state_dict_pre_hooks.values():
        hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
    local_state = {k: v.data for k, v in local_name_params if v is not None}

    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            ...  # copy state_dict[key] into param (shape checks elided)
        elif strict:
            missing_keys.append(key)

    if strict:
        for key in state_dict.keys():
            if key.startswith(prefix):
                input_name = key[len(prefix):]
                input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                if input_name not in self._modules and input_name not in local_state:
                    unexpected_keys.append(key)

The functions _register_state_dict_hook and _register_load_state_dict_pre_hook are also simple; the docstrings in the code below explain them. Here hooks refers to torch.utils.hooks, whose class RemovableHandle is "A handle which provides the capability to remove a hook".

def _register_state_dict_hook(self, hook):
    r"""These hooks will be called with arguments: `self`, `state_dict`,
    `prefix`, `local_metadata`, after the `state_dict` of `self` is set.
    Note that only parameters and buffers of `self` or its children are
    guaranteed to exist in `state_dict`. The hooks may modify `state_dict`
    inplace or return a new one.
    """
    handle = hooks.RemovableHandle(self._state_dict_hooks)
    self._state_dict_hooks[handle.id] = hook
    return handle

def _register_load_state_dict_pre_hook(self, hook):
    r"""These hooks will be called with arguments: `state_dict`, `prefix`,
    `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`,
    `error_msgs`, before loading `state_dict` into `self`. These arguments
    are exactly the same as those of `_load_from_state_dict`.
    """
    handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
    self._load_state_dict_pre_hooks[handle.id] = hook
    return handle

_tracing_name is mainly called by _slow_forward. _slow_forward plays a similar role to __call__: both end up doing the computation through forward, with _slow_forward being used when JIT tracing is active.

def _tracing_name(self, tracing_state):
    if not tracing_state._traced_module_stack:
        return None
    module = tracing_state._traced_module_stack[-1]
    for name, child in module.named_children():
        if child is self:
            return name
    return None

def _slow_forward(self, *input, **kwargs):
    tracing_state = torch._C._get_tracing_state()
    if not tracing_state:
        return self.forward(*input, **kwargs)
    if not hasattr(tracing_state, '_traced_module_stack'):
        tracing_state._traced_module_stack = []
    name = self._tracing_name(tracing_state)
    if name:
        tracing_state.push_scope('%s[%s]' % (self._get_name(), name))
    else:
        tracing_state.push_scope(self._get_name())
    tracing_state._traced_module_stack.append(self)
    try:
        result = self.forward(*input, **kwargs)
    finally:
        tracing_state.pop_scope()
        tracing_state._traced_module_stack.pop()
    return result

def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result

register_parameter, register_buffer, and add_module can be read together. register_parameter adds the given name - param pair to the self._parameters dictionary. register_buffer is typically used to register attributes that are not model parameters; for example, running_mean in BatchNorm is not a parameter. add_module attaches the given name - module pair to the current Module as a submodule. A small usage sketch follows the code below.

def register_parameter(self, name, param):
    # Check AttributeError TypeError KeyError ...
    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, Parameter):
        ...  # raise TypeError ...
    elif param.grad_fn:
        ...  # raise ValueError ... Cannot assign non-leaf Tensor to parameter
    else:
        self._parameters[name] = param

def register_buffer(self, name, tensor):
    r"""Example:
        >>> self.register_buffer('running_mean', torch.zeros(num_features))
    """
    if '_buffers' not in self.__dict__:
        ...  # raise AttributeError ... Cannot assign buffer before Module.__init__() call
    elif not isinstance(name, torch._six.string_classes):
        ...  # raise TypeError ... Buffer name should be a string
    elif '.' in name:
        ...  # raise KeyError ... Buffer name can't contain '.'
    elif name == '':
        ...  # raise KeyError ... Buffer name can't be empty string
    elif hasattr(self, name) and name not in self._buffers:
        ...  # raise KeyError ... Attribute already exists
    elif tensor is not None and not isinstance(tensor, torch.Tensor):
        ...  # raise TypeError
    else:
        self._buffers[name] = tensor

def add_module(self, name, module):
    # Check TypeError KeyError ...
    self._modules[name] = module
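
A small sketch of the three registration functions working together (the Stats class here is made up for illustration):

import torch
import torch.nn as nn

class Stats(nn.Module):
    def __init__(self):
        super(Stats, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(4)))
        self.register_buffer('running_mean', torch.zeros(4))   # saved in state_dict, but gets no gradient
        self.add_module('proj', nn.Linear(4, 4))

m = Stats()
print(m.state_dict().keys())
# odict_keys(['weight', 'running_mean', 'proj.weight', 'proj.bias'])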

register_backward_hook, register_forward_pre_hook, and register_forward_hook are similar to the _register_state_dict_hook and _register_load_state_dict_pre_hook functions introduced earlier: they also register a hook on the Module. The official docs say: register_backward_hook "will be called every time the gradients with respect to module inputs are computed"; register_forward_pre_hook "will be called every time before forward is invoked"; register_forward_hook "will be called every time after forward has computed an output". A usage sketch follows the code below.

def register_backward_hook(self, hook):
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle

def register_forward_pre_hook(self, hook):
    handle = hooks.RemovableHandle(self._forward_pre_hooks)
    self._forward_pre_hooks[handle.id] = hook
    return handle

def register_forward_hook(self, hook):
    handle = hooks.RemovableHandle(self._forward_hooks)
    self._forward_hooks[handle.id] = hook
    return handle
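
A minimal forward-hook sketch: print the output shape of every Linear layer, then remove the hooks via the returned RemovableHandle objects:

import torch
import torch.nn as nn

def shape_hook(module, inputs, output):
    print(module.__class__.__name__, tuple(output.shape))

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
handles = [m.register_forward_hook(shape_hook)
           for m in model.modules() if isinstance(m, nn.Linear)]

model(torch.randn(3, 4))    # prints: Linear (3, 8) and Linear (3, 2)
for h in handles:
    h.remove()              # detach the hooks again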

That completes the tour of the functions in torch.nn.Module (the code of this class actually runs to well over a thousand lines).
