I. Model Creation and nn.Module
1. Model creation steps
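torch.nn component | Description |
---|---|
nn.Parameter | Tensor subclass that represents a learnable parameter, such as weight or bias |
nn.Module | Base class of all network layers; manages the attributes of the network |
nn.functional | Concrete function implementations, such as convolution, pooling, and activation functions |
nn.init | Parameter initialization methods |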
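
How the four components in the table fit together can be sketched with a small hand-written layer. This is only an illustration: TinyLinear and its sizes are hypothetical names chosen here, and Xavier initialization is an arbitrary choice among the methods in nn.init.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLinear(nn.Module):  # hypothetical layer, for illustration only
    """A hand-written linear layer followed by ReLU."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # nn.Parameter: tensors that the module registers as learnable
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        # nn.init: in-place initialization of those parameters
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        # nn.functional: stateless implementations of the actual computation
        return F.relu(F.linear(x, self.weight, self.bias))


layer = TinyLinear(4, 3)
out = layer(torch.randn(2, 4))
param_names = [name for name, _ in layer.named_parameters()]
```

Here nn.Module does the bookkeeping, nn.Parameter marks what is learnable, nn.init fills in the starting values, and nn.functional performs the math inside forward().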
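2. nn.Module
Attributes
- parameters: stores and manages nn.Parameter objects
- modules: stores and manages nn.Module objects (the submodules)
- buffers: stores and manages buffer attributes, such as running_mean in a BN layer
- *_hooks: store and manage hook functions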
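
These attribute groups can be inspected through the public accessors of nn.Module. The following is a minimal sketch, assuming a two-layer nn.Sequential built only for this demonstration; the hook function record_shape is a hypothetical helper.

```python
import torch
import torch.nn as nn

# A small container: one convolution followed by batch normalization
block = nn.Sequential(
    nn.Conv2d(3, 6, 5),
    nn.BatchNorm2d(6),
)

# parameters: learnable tensors (weights and biases)
param_names = [name for name, _ in block.named_parameters()]
# modules: the container itself (empty name) plus its submodules
module_names = [name for name, _ in block.named_modules()]
# buffers: non-learnable state, e.g. the running statistics of the BN layer
buffer_names = [name for name, _ in block.named_buffers()]

# hooks: a forward hook that records the output shape of the convolution
shapes = []


def record_shape(module, inputs, output):  # hypothetical helper
    shapes.append(tuple(output.shape))


handle = block[0].register_forward_hook(record_shape)
block(torch.randn(1, 3, 32, 32))
handle.remove()  # unregisters the hook again
```

The buffer names include 1.running_mean, which is the BN example mentioned in the list above; the prefix is the position of the BN layer inside the container.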
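Call sequence:
Starting from the creation of the network model (net = LeNet(classes=2)), use the debugger's Step Into command to enter every function that gets called, watch when the _modules field of net is created and populated, and record every class and function entered along the way: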
- Entry point: net = LeNet(classes=2)
- LeNet class: __init__() runs and first calls super(LeNet, self).__init__()

      def __init__(self, classes):
          super(LeNet, self).__init__()
          self.conv1 = nn.Conv2d(3, 6, 5)
          self.conv2 = nn.Conv2d(6, 16, 5)
          self.fc1 = nn.Linear(16*5*5, 120)
          self.fc2 = nn.Linear(120, 84)
          self.fc3 = nn.Linear(84, classes)

- Module class: __init__() calls self._construct(), which creates eight ordered dictionaries

      def _construct(self):
          """
          Initializes internal Module state, shared by both nn.Module and ScriptModule.
          """
          torch._C._log_api_usage_once("python.nn_module")
          self._backend = thnn_backend
          self._parameters = OrderedDict()
          self._buffers = OrderedDict()
          self._backward_hooks = OrderedDict()
          self._forward_hooks = OrderedDict()
          self._forward_pre_hooks = OrderedDict()
          self._state_dict_hooks = OrderedDict()
          self._load_state_dict_pre_hooks = OrderedDict()
          self._modules = OrderedDict()

- LeNet class: constructs the first convolution layer, nn.Conv2d(3, 6, 5)
- Conv2d class: __init__(). Conv2d inherits from _ConvNd and calls the parent constructor

      def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                   padding=0, dilation=1, groups=1,
                   bias=True, padding_mode='zeros'):
          kernel_size = _pair(kernel_size)
          stride = _pair(stride)
          padding = _pair(padding)
          dilation = _pair(dilation)
          super(Conv2d, self).__init__(
              in_channels, out_channels, kernel_size, stride, padding, dilation,
              False, _pair(0), groups, bias, padding_mode)

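- _ConvNd class: __init__(). _ConvNd inherits from Module, so it calls the parent constructor in the same way as the second and third steps above, and then initializes its own variables
- LeNet class: control returns to self.conv1 = nn.Conv2d(3, 6, 5), where the assignment is intercepted by the __setattr__() function of the parent class (nn.Module)

      # name = 'conv1'
      # value = Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
      modules = self.__dict__.get('_modules')
      if isinstance(value, Module):
          if modules is None:
              raise AttributeError(
                  "cannot assign module before Module.__init__() call")
          remove_from(self.__dict__, self._parameters, self._buffers)
          modules[name] = value

  As a result, conv1 is recorded in the _modules dictionary of the LeNet instance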
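- The remaining layers are built in the same way; once construction finishes, net holds all five layers (conv1, conv2, fc1, fc2, fc3) in its _modules dictionary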
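
The registration traced in the steps above can also be observed without a debugger. The sketch below is an illustration under assumptions: TracedLeNet, early_error, and snapshots are hypothetical names introduced here, the layers are copied from the LeNet constructor, and forward() is omitted because only construction is examined.

```python
import torch.nn as nn


class TracedLeNet(nn.Module):  # hypothetical name; layers copied from LeNet
    def __init__(self, classes):
        # Before Module.__init__() runs, _modules does not exist yet, so
        # Module.__setattr__ rejects the assignment of a submodule.
        try:
            self.conv1 = nn.Conv2d(3, 6, 5)
            early_error = None
        except AttributeError as exc:
            early_error = str(exc)

        super().__init__()  # creates _modules and the other internal dicts
        self.early_error = early_error
        # Snapshot 0: right after Module.__init__(), nothing is registered
        self.snapshots = [list(self._modules)]

        self.conv1 = nn.Conv2d(3, 6, 5)
        # Snapshot 1: __setattr__ has stored conv1 in _modules
        self.snapshots.append(list(self._modules))

        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)
        # Snapshot 2: all five layers are registered, in assignment order
        self.snapshots.append(list(self._modules))


net = TracedLeNet(classes=2)
for step, names in enumerate(net.snapshots):
    print(step, names)
print("error before Module.__init__():", net.early_error)
```

Note that conv1 never lands in the instance's own __dict__; it lives in _modules, and attribute access such as net.conv1 is resolved from there.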
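Summary
- A module can contain multiple submodules
- A module represents one operation and must implement the forward() function
- In the PyTorch version traced here, every module has eight ordered dictionaries that manage its attributes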

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2