Pytorch查看网络各层的输入维度[调试技巧]
情景:python的函数经过多层封装,比如pytorch的函数,我们需要理解他的调用过程,这时候可以上网查一下
例:获取Pytorch::nn.Module的输入的维度信息(即神经网络每层输入的维度关系)
forward方法的具体流程:
这里参考了这篇博文https://www.cnblogs.com/llfctt/p/10967651.html
以一个Module为例:
- 调用module的call方法
- module的call里面调用module的forward方法
- forward里面如果碰到Module的子类,回到第1步,如果碰到的是Function的子类,继续往下
- 调用Function的call方法
- Function的call方法调用了Function的forward方法。
- Function的forward返回值
- module的forward返回值
- 在module的call进行forward_hook操作,然后返回值。
具体操作
本图调用model,进行前向传播,"单步进入"该函数
然后我们进入了module的call方法,然后我们在下图光标指示位置再次"单步进入"这个函数