深入剖析Pytorch的nn.Module源码(常用方法)

深入剖析Pytorch的nn.Module源码

本文是对nn.Module中的常用函数源码进行剖析

1.__init__函数

包含很多成员变量,一般是字典格式,默认情况下shuffle、dropout都是遵循training=true设置的

    def __init__(self) -> None:
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters: Dict[str, Optional[Parameter]] = OrderedDict()
        self._buffers: Dict[str, Optional[Tensor]] = OrderedDict()
        self._non_persistent_buffers_set: Set[str] = set()
        self._backward_hooks: Dict[int, Callable] = OrderedDict()
        self._is_full_backward_hook = None
        self._forward_hooks: Dict[int, Callable] = OrderedDict()
        self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
        self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
        self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
        self._load_state_dict_post_hooks: Dict[int, Callable] = OrderedDict()
        self._modules: Dict[str, Optional['Module']] = OrderedDict()
2.register_buffer

作用:往当前模型中添加buffer。一般我们不能将buffer视为模型的参数,默认情况下buffers是持久的,可以和parameters一起保存,当然也可以设置False,就不会被保存了。参数说明:

  • name:buffer名称
  • tensor:注册的buffer张量的值
  • persistent:是否作为张量保存下来
def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
	pass # 具体的逻辑判断
3.register_parameter

作用:往模型中添加参数,使用频率较高。参数说明:

name:字符串形式,添加参数的名称

parameter:是tensor形式的继承,但必须写成Parameter(一个类)的实例形式,而不是简单的一个tensor

def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
	pass # 具体的逻辑判断

使用举例:https://www.codenong.com/cs106951116/

class Example(nn.Module):
    def __init__(self):
        super(Example, self).__init__()
        print('看看我们的模型有哪些parameter:\t', self._parameters, end='\n')
        self.W1_params = nn.Parameter(torch.rand(2,3))
        print('增加W1后看看:',self._parameters, end='\n')
       
        self.register_parameter('W2_params' , nn.Parameter(torch.rand(2,3))) # register parameter
        print('增加W2后看看:',self._parameters, end='\n')
    def forward(self, x):
        return x
4.add_module

作用:往当前module中添加子模块

主要参考:

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值