文章目录
Pytorch框架学习 -2 torch.nn.modules.Module(nn.Module)理解
最简单的例子
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
分析
- 一个Pytroch模型应该以类的形式出现
- Pytorch训练模型应该是nn.Module的子类
- 一个训练模型包含经过初始化和前向传播两个过程
初始化模型是为了注册参数,保证模型能够正常处理这些重要参数,显然是必要
不同神经网络的前向传播过程肯定要自己定义,否则这个模型就失去了独特性
部分源码:
基本参数
class Module:
dump_patches: bool = False
_version: int = 1
training: bool
dump_patches
当调用.to()|.cuda()的时候,将参数也将转化为gpu类型
_version
用于之后函数比较版本
training
使用train(mode)方法时修改,默认在init时变为True
主要影响bn和dropout等在网络训练和评估时使用方法不一样的功能
初始化函数
def __init__(self):
torch._C._log_api_usage_once("python.nn_module")
self.training = True
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._non_persistent_buffers_set = set()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
_parameters
保存当前module的训练参数
_buffers
保存当前moduile的非训练参数