PyTorch nn.Module 实例的8个属性字典(OrderedDict)

为方便介绍,这里先通过继承 nn.Module 来定义一个Net网络。

import torch
from torch import nn
from torch.nn import functional as F


class Model(nn.Module):
    
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(1024, 512)
        self.linear2 = nn.Linear(512, 128)
        self.linear3 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.flatten(x)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x


net = Model()

1. Net 的实例化

在实例化自定义的网络时,会执行构造函数 __init__(),为了是后续的操作正确执行,需要在定义类时在构造函数 __init__() 中首先调用执行父类的 __init__() 函数:super(Model, self).__init__(),这样才会创造出Net必须有的字典属性:

_parameters
_modules
_buffers
_backward_hooks
_forward_hooks
_forward_pre_hooks
_state_dict_hooks
_load_state_dict_pre_hooks

2. _parameters, _modules, _buffers

nn.Module 使用了 Python 的 __setattr__ 机制,当在类中定义成员时,__setattr__ 会检测成员的 type 派生于哪些类型。如果派生于 Parameter 类,则被归于 _parameters ;如果派生于 Module ,则划归于 _modules。因此,如果类中定义的成员被封装到Python的普通数据类型中,则不会自动归类,比如:self.layers = [nn.Linear(1024, 80), nn.Linear(80, 10],检测到是list类型,则会视为普通属性。

2.1 _parameters

当直接调用 net._parameters 时,会发现,字典为空。因为在定义的网络的成员没有直接派生于 Parameter 类的,所以该方法返回空字典。这时可以使用 net.parameters() 方法,该方法返回一个迭代器,递归获取每层的参数。

>>> net._parameters
OrderedDict()

>>> for i in net.parameters():
    	print(i.__class__)
    	break
<class 'torch.nn.parameter.Parameter'>

2.2 _modules

_modules 包含了类所有的派生于 Module 的成员,前面说了,如果成员被封装到列表中,并不会被添加到 _modules 中,这时,如有必要,可以使用 ModuleList 替代列表来使用,ModuleList 继承了 类,实现了list的功能。

>>> net._modules
OrderedDict([('conv1', Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))),
             ('pool',
              MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)),
             ('conv2', Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))),
             ('flatten', Flatten(start_dim=1, end_dim=-1)),
             ('linear1', Linear(in_features=1024, out_features=80, bias=True)),
             ('linear2', Linear(in_features=80, out_features=10, bias=True))])

2.3 _buffers

该成员值的填充是通过register_buffer API来完成的,通常用来将一些需要持久化的状态(但又不是网络的参数)放到_buffer里;一些极其个别的操作,比如BN,会将running_mean的值放入进来。

3. _backward_hooks, _forward_hooks, _forward_pre_hooks

hook 函数可以在不改变网络的主体的情况下,实现一些额外的功能。因为 Pytorch 中动态运算图的机制,网络计算的中间变量会在计算结束后释放以节省性能。此时,可以通过在某一层网络上挂上一些 hook 函数来获取该层的中间变量。
hook 函数只需要是一个可调用对象就可以(实现了 __call__)。PyTorch 为网络提供了三种 hook:

  1. forward_pre_hooksnn.Module 提供了 net.register_forward_pre_hook(hook) 方法来注册该 hook:hook(module, input)。该 hook 函数用于获取网络层的 input。

    • module: 网络层
    • input: 网路层输入数据
  2. forward_hooksnn.Module 提供了 net.register_forward_hook(hook) 方法来注册该 hook:hook(module, input, output)。该 hook 函数用于获取 module 的 input 和 output。

    • module: 网络层
    • input: 网络层输入数据
    • output: 网络层输出数据
  3. backward_hooksnn.Module 提供了 net.register_backward_hook(hook) 方法来注册该hook:hook(module, grad_input, grad_output)。该 hook 函数用于获取反向传播中 module 的grad_in,grad_out。

    • module: 网络层
    • grad_input: 网络层输入梯度
    • grad_output: 网络层输出梯度
net = Model()
x = torch.rand(10, 1, 28, 28)
y = net(x)
x = x + 0.1*torch.rand(x.shape)
loss = torch.nn.L1Loss()

def forward_pre_hooks(module, input):
    r"""前向传播前hook函数"""
    print("forward_pre_hooks 的输出:")
    print("module: ", module, "\ninput shape: ", input[0].shape)
    
def forward_hooks(module, input, output):
    r"""前向传播hook函数"""
    print("\n\nforward_hooks 的输出:")
    print("module: ", module, "\ninput shape: ", input[0].shape, "\noutput shape: ", output[0].shape)

def backward_hooks(module, grad_input, grad_output):
    r"""反向传播hook函数"""
    print("\n\nbackward_hooks 的输出:")
    print("module: ", module, "\ngrad_input shape: ", grad_input[0].shape, "\ngrad_output shape: ", grad_output[0].shape)
    
net.conv2.register_forward_pre_hook(forward_pre_hooks)
net.conv2.register_forward_hook(forward_hooks)
net.conv2.register_backward_hook(backward_hooks)

y_hat = net(x)
l = loss(y_hat, y)
l.sum().backward()

输出:

forward_pre_hooks 的输出:
module:  Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) 
input shape:  torch.Size([10, 6, 12, 12])


forward_hooks 的输出:
module:  Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) 
input shape:  torch.Size([10, 6, 12, 12]) 
output shape:  torch.Size([16, 8, 8])


backward_hooks 的输出:
module:  Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) 
grad_input shape:  torch.Size([10, 6, 12, 12]) 
grad_output shape:  torch.Size([10, 16, 8, 8])
​

_backward_hooks, _forward_hooks, _forward_pre_hooks 中存放了对应的hook。在网络进行前向传播时:首先会执行 _forward_pre_hooks 中的 hooks,然后执行网络的 forward 函数;再然后执行 _forward_hooks 中的 hooks 函数。当发生反向传播时,会依次执行 _backward_hooks 中的 hooks。

>>> net.conv2._forward_pre_hooks
OrderedDict([(43, <function __main__.forward_pre_hooks(module, input)>)])
>>> net.conv2._forward_hooks
OrderedDict([(44, <function __main__.forward_hooks(module, input, output)>)])
>>> net.conv2._backward_hooks
OrderedDict([(45,
              <function __main__.backward_hooks(module, grad_input, grad_output)>)])

4. _state_dict_hooks, _load_state_dict_pre_hooks

  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值