Pytorch中的Hook(常用于网络特征可视化等)

5 篇文章 3 订阅
4 篇文章 1 订阅

Pytorch中的Hook

使用Hook函数获取网络中间变量.
Hook函数机制是在不改变函数主体的前提下实现额外功能,就像一个挂件、挂钩。正是因为PyTorch采用动态计算图机制,才会有Hook函数:在动态图机制下,运算结束后,一些中间变量(例如特征图、非leaf节点的梯度)就会被释放掉。但有时我们又需要这些中间变量,于是就出现了Hook函数。
torch提供了四种hook方法,分别用于

  • 获取各个参数的梯度值。tensor.register_hook(hook)
  • 获取各个层前向传播的输入输出值。Module.register_forward_hook(hook)
  • 获取各个层前向传播的输入值。Module.register_forward_pre_hook(hook)
  • 获取各个层反向传播的梯度值。Module.register_backward_hook(hook)

torch.Tensor.register_hook(hook)

记录中间各参数的梯度值

import torch
# Collects every gradient that flows into the intermediate tensor `a`.
a_grads = []


def hook_grad(grad):
    """Tensor hook: append the gradient it is called with to ``a_grads``."""
    a_grads.append(grad)


# Two leaf tensors that track gradients.
x = torch.tensor([3.0], requires_grad=True)
y = torch.tensor([5.0], requires_grad=True)
print('x:', x)
print('y:', y)

# Intermediate (non-leaf) results.
a = torch.add(x, y)  # 3 + 5 = 8
b = torch.mul(x, y)  # 3 * 5 = 15
c = torch.mul(a, b)  # 8 * 15 = 120

# Attach the hook to `a` before backpropagating; the returned handle is kept
# so the hook can be removed afterwards.
handle = a.register_hook(hook_grad)
c.backward()
x: tensor([3.], requires_grad=True)
y: tensor([5.], requires_grad=True)
# c = x^2*y + x*y^2, so dc/dx = 2xy + y^2 and dc/dy = x^2 + 2xy (here x=3, y=5).
# a, b and c are intermediate results, and their .grad prints as None below.
print(x.grad,y.grad,a.grad,b.grad,c.grad)
tensor([55.]) tensor([39.]) None None None
# c = a*x*y, so dc/da = x*y (15 for x=3, y=5); this is the value the hook stored.
print(a_grads)
# c = (x+y)*b, so dc/db = x+y; no hook was registered on b, and b.grad prints as None.
print(b.grad)
[tensor([15.])]
None
# Remove the hook function through the handle returned by register_hook.
handle.remove()

torch.nn.Module.register_forward_hook

注册module的前向传播Hook函数
记录前向传播过程中生成的中间输入输出结果。

  • module:当前网络层
  • input:本层输入数据。与register_forward_pre_hook中得到的输入数据相同
  • output:本层输出数据
import torch
import torch.nn as nn
# Build a minimal network: one convolution followed by one max-pool.
class Net(nn.Module):
    """Tiny two-layer CNN used to demonstrate forward hooks."""

    def __init__(self):
        super().__init__()
        # 1 input channel -> 2 output channels, 3x3 kernel.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Apply ``conv1`` and then ``pool1`` to ``x`` and return the result."""
        return self.pool1(self.conv1(x))
    

下面需要的方法:

  • detach() 返回一个新的Tensor,它从当前计算图中分离出来,但仍与原张量共享同一块存储;不同之处只是requires_grad为False,得到的这个Tensor永远不需要计算其梯度,不具有grad。
  • fill_() 用新数据填充(覆盖)原有数据。Tensor没有fill方法,只有原地操作的fill_方法
# Instantiate the network and show its layers.
net = Net()
print(net)

# detach() gives tensors that are cut off from autograd; the in-place
# fill_() / zero_() calls then overwrite the parameter values, as the
# weight printed below shows.
conv_weight = net.conv1.weight.detach()
conv_bias = net.conv1.bias.detach()
conv_weight[0].fill_(1)  # every weight of output channel 0 becomes 1
conv_weight[1].fill_(2)  # every weight of output channel 1 becomes 2
conv_bias.zero_()
print(net.conv1.weight)

Net(
  (conv1): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
Parameter containing:
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],
        [[[2., 2., 2.],
          [2., 2., 2.],
          [2., 2., 2.]]]], requires_grad=True)
# Buffers filled by the hook: one entry per forward pass through the hooked layer.
input_list = []
output_list = []


def forward_hook(module, data_input, data_output):
    """Forward hook: record the hooked layer's input and output.

    ``module`` is the layer the hook is attached to, ``data_input`` is what
    the layer received and ``data_output`` is what it produced.
    """
    input_list.append(data_input)
    output_list.append(data_output)


# Attach the hook to the convolution layer.
net.conv1.register_forward_hook(forward_hook)
    

卷积层Conv2d的输入维度为(n,c,h,w),输出为(n,c,h,w)

# Input data: sixteen ones reshaped into a single 1-channel 4x4 image.
fake_data=torch.ones(16).view((1,1,4,4)) # shape: (n, c, h, w)
print(fake_data)
result=net(fake_data)
print('RESULT:',result.shape,'\n',result)

tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
RESULT: torch.Size([1, 2, 1, 1]) 
 tensor([[[[ 9.]],

         [[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)
# Show what the forward hook captured: the recorded inputs, a separator line,
# then the recorded outputs.
print(input_list)
print('='*20)
print(output_list)

[(tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)]
====================
[tensor([[[[ 9.,  9.],
          [ 9.,  9.]],

         [[18., 18.],
          [18., 18.]]]], grad_fn=<ThnnConv2DBackward>)]

torch.nn.Module.register_forward_pre_hook

记录前向传播前的值

  • module:当前网络层
  • input:本层输入数据。与register_forward_hook中得到的输入数据相同

torch.nn.Module.register_backward_hook

记录反向更新的值

  • module:当前网络层
  • grad_input:损失对当前网络层输入的梯度
  • grad_output:损失对当前网络层输出的梯度

示例

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import time
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/pool stages followed by three linear layers.

    ``fc1`` expects ``16 * 5 * 5`` features per sample, which is what the two
    5x5 convolutions and 2x2 poolings produce for a 3-channel 32x32 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # in_channels, out_channels, kernel_size
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10 output scores for each sample in the batch ``x``.

        Raises:
            RuntimeError: if the spatial size of ``x`` does not yield
                ``16 * 5 * 5`` features per sample.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten each sample separately. The previous view(-1, 16*5*5) let a
        # wrongly sized input be silently regrouped across the batch
        # dimension; keeping the batch dimension fixed makes fc1 reject it.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
img_path = '/Users/liuyanzhe/Downloads/1.jpg'

# Resize to 32x32, convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Force three channels: a grayscale, palette, CMYK or RGBA file would
# otherwise not match the 3-channel Normalize above or the 3-channel conv1.
# The context manager closes the underlying file once the pixels are loaded.
with Image.open(img_path) as image_file:
    img = image_file.convert('RGB')
plt.imshow(img)
img = transform(img)
img.unsqueeze_(dim=0)  # add a leading batch dimension in place
net = LeNet()

# Buffers for the data captured by the forward hook.
output_list = []
input_list = []
# Buffer for the data captured by the forward pre-hook.
pre_list = []


def _save_input_visualizations(batch, prefix):
    """Save two PNG visualizations of a layer input to the scratch directory.

    Args:
        batch: first positional input of the hooked layer; assumed to be a
            4-D tensor shaped [1, channels, height, width] as in this
            example — TODO confirm before hooking other layers.
        prefix: string placed in front of the time-stamped file name.
    """
    # detach() so the hook also works when the input is part of an autograd
    # graph (e.g. conv2's input outside torch.no_grad()).
    frames = batch.detach().cpu()

    # Method 1 (grayscale grid): swap the batch and channel axes so that each
    # channel becomes its own single-channel image.
    img = frames.clone().permute(1, 0, 2, 3)
    img = img[:16]  # keep at most the first 16 images
    print(img.shape)
    tick = time.time()
    save_image(img, '/Users/liuyanzhe/Documents/测试暂存/' + prefix + str(tick) + '.png')

    # Method 2 (matplotlib, default colormap): plot the channels of sample 0.
    x = frames[0]
    min_num = min(2, x.size(0))  # show at most 2 channels
    plt.figure()
    for i in range(min_num):
        plt.subplot(1, 4, i + 1)
        # plt.imshow(x[i], cmap='gray')
        plt.imshow(x[i])
        plt.axis('off')
    tick = time.time()
    plt.savefig('/Users/liuyanzhe/Documents/测试暂存/' + prefix + str(tick) + '.png', dpi=100, bbox_inches='tight')
    print('已保存图片')
    plt.close()


def forward_hook(module, data_input, data_output):
    """Forward hook: record a layer's input and output and visualize the input.

    Args:
        module: the layer the hook is attached to.
        data_input: the layer's input; ``data_input[0]`` is the tensor that
            gets visualized.
        data_output: the layer's output.
    """
    # Store the raw data.
    input_list.append(data_input)
    output_list.append(data_output)
    # ...and/or visualize it directly.
    _save_input_visualizations(data_input[0], '1-')


def forward_pre_hook(module, input_data):
    """Forward pre-hook: record a layer's input and visualize it.

    Args:
        module: the layer the hook is attached to.
        input_data: tuple holding a single tensor in this example, i.e.
            ``input_data[0]`` shaped [1, channels, height, width].
    """
    # Store the raw data.
    pre_list.append(input_data)
    # ...and/or visualize it directly.
    _save_input_visualizations(input_data[0], '2-')

# Buffers for what the backward hook receives, one entry per hook call.
grad_input_list = []
grad_output_list = []


def backward_hook(module, grad_input, grad_output):
    """Backward hook: record ``grad_input`` and ``grad_output`` for ``module``."""
    for bucket, grads in ((grad_input_list, grad_input), (grad_output_list, grad_output)):
        bucket.append(grads)

# Capture the inputs and outputs of both convolution layers during the forward pass.
net.conv1.register_forward_hook(forward_hook)
net.conv2.register_forward_hook(forward_hook)
# Capture the inputs of both convolution layers before their forward runs.
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv2.register_forward_pre_hook(forward_pre_hook)
# Capture gradient data of both convolution layers.
# NOTE(review): nothing in this snippet calls backward(), and the forward pass
# below runs under torch.no_grad(), so grad_input_list / grad_output_list are
# presumably still empty afterwards — confirm whether a backward pass was intended.
# NOTE(review): newer PyTorch releases deprecate register_backward_hook in
# favour of register_full_backward_hook — check against the installed version.
net.conv1.register_backward_hook(backward_hook)
net.conv2.register_backward_hook(backward_hook)



# Run a single forward pass with gradient tracking disabled.
with torch.no_grad():
    out=net(img)

#     print(output_list)
    print(input_list)
    print('='*20)
    print(pre_list)
    
  • 5
    点赞
  • 34
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
很好,下面是 PyTorch 的 UNet 模型特征图可视化代码:

```python
import torch
from torch.autograd import Variable
import numpy as np
import cv2
import matplotlib.pyplot as plt


def hook_fn(m, i, o):
    """Debug hook: print the module and the two values it is called with."""
    print(m)
    print("------------Input Grad------------")
    print(i)
    print("------------Output Grad------------")
    print(o)


class Unet(torch.nn.Module):
    """U-Net with four down-sampling and four up-sampling stages."""

    def __init__(self, n_channels, n_classes, bilinear=True):
        super(Unet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # Bilinear upsampling keeps the channel count, so the bottleneck is
        # halved to make the concatenated skip tensors match.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


class DoubleConv(torch.nn.Module):
    """Two (3x3 conv -> batch norm -> ReLU) layers in sequence."""

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, 3, padding=1),
            torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_ch, out_ch, 3, padding=1),
            torch.nn.BatchNorm2d(out_ch),
            torch.nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class Down(torch.nn.Module):
    """2x2 max-pool followed by a DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super(Down, self).__init__()
        self.mpconv = torch.nn.Sequential(
            torch.nn.MaxPool2d(2),
            DoubleConv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class Up(torch.nn.Module):
    """Upsample, pad to the skip tensor's size, concatenate, then DoubleConv."""

    def __init__(self, in_ch, out_ch, bilinear=True):
        super(Up, self).__init__()
        if bilinear:
            self.up = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # Without bilinear upsampling the incoming tensor has in_ch
            # channels (factor == 1 in Unet), so the transposed convolution
            # must map in_ch -> in_ch // 2 for the concatenation with the
            # skip tensor to give in_ch channels again.
            self.up = torch.nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad x1 so its height/width match the skip connection x2.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = torch.nn.functional.pad(x1, (diffX // 2, diffX - diffX//2, diffY // 2, diffY - diffY//2))
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class OutConv(torch.nn.Module):
    """Final 1x1 convolution mapping to the requested number of classes."""

    def __init__(self, in_ch, out_ch):
        super(OutConv, self).__init__()
        self.conv = torch.nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x


# Load the already-trained UNet model.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model = Unet(n_channels = 3, n_classes = 1)
model.load_state_dict(torch.load("unet.pth"))
model.eval()

# Image preprocessing: BGR -> RGB, resize, HWC -> CHW, scale to [0, 1].
img = cv2.imread("example.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (512, 512))
img = np.transpose(img, (2, 0, 1))
img = img.astype(np.float32) / 255.
img = Variable(torch.from_numpy(img).unsqueeze(0))

# Register a hook to grab the feature map.
features_blobs = []


def hook_feature(module, input, output):
    """Forward hook: store the layer output as a NumPy array."""
    features_blobs.append(output.data.cpu().numpy())


# Unet defines no `conv1` attribute; its first convolution block is `inc`,
# whose 64 output channels match the 8x8 grid plotted below.
model.inc.register_forward_hook(hook_feature)

# Run the model and plot the captured feature map.
output = model(img)
fea = features_blobs[0]
plt.figure(figsize=(10, 10))
plt.subplots_adjust(wspace=0, hspace=0)
for idx in range(64):
    plt.subplot(8, 8, idx + 1)
    plt.axis('off')
    plt.imshow(fea[0][idx], cmap='jet')
plt.show()
```

希望这个代码可以帮到你。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值