Pytorch学习笔记:hook操作——提取特征、梯度等信息

介绍

功能:主要用于提取中间变量,同时也可以做修改等操作。

常用的相关函数方法

# 提取数据的梯度
torch.Tensor.register_hook(hook)
# 提取模型的中间特征数据
torch.nn.Module.register_forward_hook(hook)
# 提取网络层中的梯度
torch.nn.Module.register_full_backward_hook(hook)

  PyTorch在每一次运算结束后都会释放中间变量,从而节省内存空间,例如释放模型中间得到的特征数据、反向传播过程中的梯度等等,因此就有了hook方法,可以操作中间变量,如保存梯度、保存中间特征数据,也可以对中间变量做修改,如增大梯度、限制梯度范围等等,核心在于hook函数的定义。

定义hook:

# register_hook
hook(grad) -> Tensor or None
# register_forward_hook
hook(module, input, output) -> None or modified output
# register_full_backward_hook
# 一般只利用grad_output,提取模块输出元素的梯度
hook(module, grad_input, grad_output) -> tuple(Tensor) or None

数据梯度

  利用torch.Tensor.register_hook(hook)方法实现,计算数据在做反向传播时的梯度

代码案例

以下面的公式为例:
z = 1 4 ∑ i = 1 4 y i , y i = x i 2 z=\frac14\sum_{i=1}^4y_i,\quad y_i=x_i^2 z=41i=14yi,yi=xi2

import torch


def grad_hook_x(grad):
    # 只传入梯度这一个变量
    x_grad.append(grad)


def grad_hook_y(grad):
    y_grad.append(grad)


torch.manual_seed(0)
y_grad = []
x_grad = []
x = torch.rand(4, requires_grad=True)
y = torch.pow(x, 2)
z = torch.mean(y)
y.register_hook(grad_hook_y)
x.register_hook(grad_hook_x)
z.backward()
print(x)
print("x grad: ", x_grad[0])
print("y grad: ", y_grad[0])

输出,相当于对x和y上的梯度做了保存

# 输入x
tensor([0.4963, 0.7682, 0.0885, 0.1320], requires_grad=True)
# x上的梯度
x grad:  tensor([0.2481, 0.3841, 0.0442, 0.0660])
# y上的梯度
y grad:  tensor([0.2500, 0.2500, 0.2500, 0.2500])

注:将z.backward()改为z.backward(retain_graph=True)也可以实现储存梯度的功能

修改梯度

  如果想要修改梯度,则只需要修改hook函数,如下面案例,此时y上的梯度是原来的两倍,将会影响x的参数更新(更新幅度变大)

def grad_hook_y(grad):
    return grad * 2

网络中间特征

  利用torch.nn.Module.register_forward_hook(hook)方法实现,实现提取特征数据的功能。

注:尽量不要在这里修改特征数据,容易出问题,最好直接去网络结构里面改。

代码案例

网络结构

import torch
from torch import nn


class Net(nn.Module):
    def __init__(self, input_size, out_size, middle_size=None):
        super().__init__()
        if not middle_size:
            middle_size = input_size // 2
        self.conv1 = nn.Conv2d(input_size, middle_size, 3)
        self.conv2 = nn.Conv2d(middle_size, middle_size, 3)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(middle_size, out_size)
        self.middle_size = middle_size

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        x3 = self.pool(x2).view(-1, self.middle_size)
        x4 = self.fc(x3)
        return x4

调用方法

def forward_hook(module, inputs, outputs):
    # 传入模块、模块输入、模块输出三种参数
    feature_map_inputs.append(inputs)
    feature_map_outputs.append(outputs)


torch.manual_seed(0)
feature_map_inputs = []
feature_map_outputs = []

net = Net(4, 2, 3)
net.conv1.register_forward_hook(forward_hook)
data = torch.rand((1, 4, 6, 6), dtype=torch.float32)

out = net(data)
out1 = out[:, 0]
net.zero_grad()
out1.backward(retain_graph=True)

输出
在这里插入图片描述

网络梯度

  利用torch.nn.Module.register_full_backward_hook(hook)方法实现,实现提取特征数据的梯度功能、也可以修改梯度。

注:

  • 在提取梯度时,最好加一个.detach()方法,切断梯度,防止后续操作对网络反向传播有影响;
  • 模块存在多个输入输出时,backward_hook()中的inputsoutputs均为元组类型。

代码案例

卷积模块

网络结构还是之前定义的结构

import torch
from torch import nn


def backward_hook(module, inputs, outputs):
    # 元组类型,常利用[0]提取梯度数据
    grad_inputs.append(inputs[0].detach())
    grad_outputs.append(outputs[0].detach())


torch.manual_seed(0)
grad_inputs = []
grad_outputs = []

net = Net(4, 2, 3)
net.conv2.register_backward_hook(backward_hook)
data = torch.rand((1, 4, 6, 6), dtype=torch.float32)

out = net(data)
out1 = out[:, 0]
net.zero_grad()
# retain_graph设为True表明在反向传播时保存梯度
out1.backward(retain_graph=True)

输出

在这里插入图片描述

注:

  • grad_inputs表示模块输入参数的梯度,梯度尺寸和输入特征图的尺寸相同;
  • grad_outputs表示模块输出参数的梯度,同上,梯度尺寸和输出特征图的尺寸相同。

全连接模块

import torch
from torch import nn


class Net(nn.Module):
    def __init__(self, input_size, out_size, middle_size=None):
        super().__init__()
        if not middle_size:
            middle_size = input_size // 2
        self.fc1 = nn.Linear(input_size, middle_size)
        self.fc2 = nn.Linear(middle_size, out_size)

    def forward(self, x):
        x1 = self.fc1(x)
        x2 = self.fc2(x1)

        return x2

    
def backward_hook(module, inputs, outputs):
    grad_inputs.append(inputs[0].detach())
    grad_outputs.append(outputs[0].detach())


torch.manual_seed(0)
grad_inputs = []
grad_outputs = []
net = Net(6, 2, 3)
net.fc2.register_backward_hook(backward_hook)
data = torch.rand((1, 6), dtype=torch.float32)

out = net(data)
out1 = out[:, 0]
net.zero_grad()
# retain_graph设为True,目的保留梯度
out1.backward(retain_graph=True)
print(grad_inputs, grad_outputs)

输出

在这里插入图片描述

注:

  • grad_inputs表示模块输入参数的梯度,梯度尺寸和输入特征的尺寸相同;
  • grad_outputs表示模块输出参数的梯度,梯度尺寸和输出特征的尺寸相同。

官方文档

register_hook:https://pytorch.org/docs/1.2.0/tensors.html#torch.Tensor.register_hook

register_forward_hook:https://pytorch.org/docs/1.2.0/nn.html#torch.nn.Module.register_forward_hook

register_full_backward_hook:https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=register_full_backward_hook#torch.nn.Module.register_full_backward_hook

注:以上内容仅是笔者个人见解,若有错误,欢迎指正。

  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值