[Pytorch进阶技巧(二)]如何获取网络模型的前向传播路径

为什么需要模型的前向传播路径

无论是可视化网络结构,还是计算网络参数个数与浮点运算次数,还是算法中需要网络整体结构信息,都需要获取网络模型的前向传播路径。从pytorch代码角度考虑,就是依照次序获得并记录各个Module中的forward()过程。

难点是什么

对pytorch熟悉的都知道,网络模型是通过Module的sub_modules()组成的。通过迭代访问Module的子Module可以访问到模型中的所有参数。然而,这样只能访问到所有的网络参数,并无法获取网络前向传播的路径。

怎么获取前向传播路径

将nn.Conv2d与nn.Linear等模块视为最小颗粒模块,我们想获得的是最小颗粒模块组成的有向图,表示前向传播时,数据的传播通路。获取的方法是重载torch.Tensor类与nn.Module.register_forward_hook()

pytorch中,Tensor是一个非常常见的。大部分的运算都是以Tensor为输入输出的,如torch.mul, torch.add等。如果Tensor能够有个成员函数保持记录上一个或多个经过的最小颗粒模块,那么在到达下一个最小颗粒模块时,就可以告诉这个模块他上面的模块有哪些。

以上是大致的思路。但是这个方法需要重载所有用到的tensor运算函数,包括view add等等。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值