Official documentation:
Module — PyTorch 1.10.0 documentation
Per the official description of apply(fn), this method recursively applies fn to every submodule (as returned by .children()) as well as to the module itself. A typical use case is initializing the parameters of a model.
Example:
import torch
import torch.nn as nn

@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
You can set a breakpoint at the net.apply(init_weights) call and step into it to see what happens internally:
(Pdb) n
> /home/mi/anaconda3/envs/pt1.9/lib/python3.7/site-packages/torch/nn/modules/module.py(615)apply()
-> for module in self.children():
(Pdb) l
615 -> for module in self.children():
616 module.apply(fn)
617 fn(self)
618 return self
As you can see, inside apply the module iterates over self.children() and recursively applies fn (here, init_weights) to each child, then calls fn(self) to apply fn to itself.
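The child-first, self-last order can be observed directly. A small sketch (not from the original post) that records the visit order of apply() by collecting class names:

```python
import torch.nn as nn

visited = []

def record(m):
    # fn receives each submodule first, then the container itself
    visited.append(type(m).__name__)

net = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
net.apply(record)
print(visited)  # → ['Linear', 'ReLU', 'Sequential']
```

The children (Linear, ReLU) are visited before the enclosing Sequential, matching the `for module in self.children(): module.apply(fn)` loop followed by `fn(self)`.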
nn.Module also has an _apply(fn) method, which is invoked when moving a module between devices (CPU/GPU). For example, net.cuda() calls:
def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
    r"""Moves all model parameters and buffers to the GPU.

    This also makes associated parameters and buffers different objects. So
    it should be called before constructing optimizer if the module will
    live on GPU while being optimized.

    .. note::
        This method modifies the module in-place.

    Args:
        device (int, optional): if specified, all parameters will be
            copied to that device

    Returns:
        Module: self
    """
    return self._apply(lambda t: t.cuda(device))
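Other conversion methods are implemented the same way; for instance, Module.double() also routes through _apply. A CPU-only sketch (no GPU required, so easier to try than net.cuda()) showing the effect on parameter dtypes:

```python
import torch
import torch.nn as nn

net = nn.Linear(3, 3)
print(net.weight.dtype)  # torch.float32

# double() is implemented via self._apply(...), converting every
# floating-point parameter and buffer in place
net.double()
print(net.weight.dtype)  # torch.float64
```

You can set a breakpoint inside Module._apply before calling net.double() to watch the same traversal the cuda() call performs.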
Internally, _apply(fn) performs three steps:
1. recursively call _apply on each child in self.children();
2. apply fn to each parameter in self._parameters and to its gradient;
3. apply fn to each buffer in self._buffers.
def _apply(self, fn):
    for module in self.children():
        module._apply(fn)

    def compute_should_use_set_data(tensor, tensor_applied):
        if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
            # If the new tensor has compatible tensor type as the existing tensor,
            # the current behavior is to change the tensor in-place using `.data =`,
            # and the future behavior is to overwrite the existing tensor. However,
            # changing the current behavior is a BC-breaking change, and we want it
            # to happen in future releases. So for now we introduce the
            # `torch.__future__.get_overwrite_module_params_on_conversion()`
            # global flag to let the user control whether they want the future
            # behavior of overwriting the existing tensor or not.
            return not torch.__future__.get_overwrite_module_params_on_conversion()
        else:
            return False

    for key, param in self._parameters.items():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                self._parameters[key] = Parameter(param_applied, param.requires_grad)

            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)

    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)

    return self
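The final loop over self._buffers matters in practice: buffers such as BatchNorm's running statistics are not parameters, yet they must follow the module during conversions. A sketch (not from the original post) confirming that _apply handles both, using the CPU-safe double() conversion:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
# running_mean lives in _buffers, not _parameters
assert 'running_mean' in dict(bn.named_buffers())

bn.double()  # goes through Module._apply internally
print(bn.weight.dtype)        # parameter → torch.float64
print(bn.running_mean.dtype)  # buffer    → torch.float64
```

If _apply skipped buffers, running_mean would stay float32 while the parameters became float64, and the next forward pass would fail on a dtype mismatch.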