PyTorchViz
包(https://github.com/szagoruyko/pytorchviz)可以用来方便地绘制PyTorch正向网络计算图和后传计算路径,本文讲解如何使用PyTorchViz
包,主要参考网站包括:
- https://towardsdatascience.com/understanding-pytorch-with-an-example-a-step-by-step-tutorial-81fc5f8c4e8e
- https://stackoverflow.com/questions/52468956/how-do-i-visualize-a-net-in-pytorch
1. 安装PyTorchViz
pip install torchviz
安装中间如果出现任何问题,请自行google或者留言,我安装过程中出现的错误已经忘记是什么了……不过问题不大。
2. 导入绘制函数
from torchviz import make_dot
3. 定义网络并前向计算输入
4. 绘制计算图和后传路径
make_dot(yhat)
可以看到make_dot
函数不但画出了计算路径,连网络的各层weight,bias以及后传时用的什么操作都显示出来了。这对于我们debug,以及理解网络的计算流程非常有用。