为什么需要模型的前向传播路径
无论是可视化网络结构,还是计算网络参数个数与浮点运算次数,还是算法中需要网络整体结构信息,都需要获取网络模型的前向传播路径。从pytorch代码角度考虑,就是依照次序获得并记录各个Module中的forward()过程。
难点是什么
对pytorch熟悉的都知道,网络模型是通过Module的sub_modules()组成的。通过迭代访问Module的子Module可以访问到模型中的所有参数。然而,这样只能访问到所有的网络参数,并无法获取网络前向传播的路径。
怎么获取前向传播路径
将nn.Conv2d与nn.Linear等模块视为最小颗粒模块,我们想获得的是最小颗粒模块组成的有向图,表示前向传播时,数据的传播通路。获取的方法是重载torch.Tensor类与nn.Module.register_forward_hook()
pytorch中,Tensor是一个非常常见的。大部分的运算都是以Tensor为输入输出的,如torch.mul, torch.add等。如果Tensor能够有个成员函数保持记录上一个或多个经过的最小颗粒模块,那么在到达下一个最小颗粒模块时,就可以告诉这个模块他上面的模块有哪些。
以上是大致的思路。但是这个方法需要重载所有用到的tensor运算函数,包括view add等等。