A Little PyTorch Every Day -- the apply() method of torch.nn.Module

Official documentation:

Module — PyTorch 1.10.0 documentation

According to the official description of apply(fn), this method applies fn recursively to every submodule (as returned by .children()) as well as to the module itself. A typical use is initializing the parameters of a model.

Example:

import torch
import torch.nn as nn

# init_weights prints each module it is applied to and fills the weight
# of every nn.Linear layer with 1.0.
@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)

Set a breakpoint at the net.apply(init_weights) call and step into it with pdb to see what happens internally:

(Pdb) n
> /home/mi/anaconda3/envs/pt1.9/lib/python3.7/site-packages/torch/nn/modules/module.py(615)apply()
-> for module in self.children():
(Pdb) l
615  ->	        for module in self.children():
616  	            module.apply(fn)
617  	        fn(self)
618  	        return self

As you can see, inside apply the method iterates over self.children(), recursively applying fn (here, init_weights) to each child, then calls fn(self) to apply fn to the module itself, and finally returns self.
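To make the traversal order concrete, here is a small sketch (the nested model below is purely illustrative) that records which module fn receives at each step; children are handled depth-first before their parent:

import torch.nn as nn

# A nested model used only to make the traversal order visible.
inner = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
outer = nn.Sequential(inner, nn.Linear(2, 2))

visited = []
outer.apply(lambda m: visited.append(type(m).__name__))

# Children are visited before their parent, so the outermost Sequential
# is expected to come last:
# ['Linear', 'ReLU', 'Sequential', 'Linear', 'Sequential']
print(visited)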

nn.Module also has an _apply(fn) method, which is invoked when a module is moved between devices (CPU/GPU). For example, net.cuda() calls:

def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
    r"""Moves all model parameters and buffers to the GPU.

    This also makes associated parameters and buffers different objects. So
    it should be called before constructing optimizer if the module will
    live on GPU while being optimized.

    .. note::
        This method modifies the module in-place.

    Args:
        device (int, optional): if specified, all parameters will be
            copied to that device

    Returns:
        Module: self
    """
    return self._apply(lambda t: t.cuda(device))
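The docstring's note about the optimizer is worth acting on: move the module to the GPU first, then build the optimizer, so the optimizer references the tensors that will actually be trained. A minimal sketch (the model and hyperparameters are placeholders):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Call .cuda() before constructing the optimizer, as the docstring advises.
if torch.cuda.is_available():
    model.cuda()

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)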

Inside _apply(fn), three steps are carried out:

1. Recursively call _apply on every child module in self.children();
2. Apply fn to each parameter in self._parameters, and to its gradient if present;
3. Apply fn to each buffer in self._buffers.

def _apply(self, fn):
    for module in self.children():
        module._apply(fn)

    def compute_should_use_set_data(tensor, tensor_applied):
        if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
            # If the new tensor has compatible tensor type as the existing tensor,
            # the current behavior is to change the tensor in-place using `.data =`,
            # and the future behavior is to overwrite the existing tensor. However,
            # changing the current behavior is a BC-breaking change, and we want it
            # to happen in future releases. So for now we introduce the
            # `torch.__future__.get_overwrite_module_params_on_conversion()`
            # global flag to let the user control whether they want the future
            # behavior of overwriting the existing tensor or not.
            return not torch.__future__.get_overwrite_module_params_on_conversion()
        else:
            return False

    for key, param in self._parameters.items():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                self._parameters[key] = Parameter(param_applied, param.requires_grad)

            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)

    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)

    return self
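One way to observe the should_use_set_data branch is a dtype conversion such as net.double(), which is also routed through _apply. The sketch below assumes the default torch.__future__ setting, under which the existing Parameter objects are kept and only their .data is swapped; the behavior differs if overwrite_module_params_on_conversion is enabled:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
p_before = net[0].weight          # keep a reference to one Parameter object

net.double()                      # dtype conversion also goes through _apply

# With the default (non-future) behavior, the Parameter is updated in place
# via `.data =`, so the old reference still points at the live parameter:
print(p_before is net[0].weight)  # expected: True
print(net[0].weight.dtype)        # expected: torch.float64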
