PyTorch 1.x Visualization
1. Introduction
2. TorchSummary
- Installation
pip install torchsummary
- Definition
from torchsummary import summary
summary(your_model, input_size=(channels, H, W))
- Parameters
- your_model: the model to inspect
- input_size: shape of the model's input tensor without the batch dimension, e.g. (3, 224, 224)
- Source code
- Example code
import torch
import torchvision.models as models
from torchsummary import summary
print(torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')  # uncomment to force CPU
vgg = models.vgg19().to(device)
# torchsummary's device argument must match where the model lives ('cuda' or 'cpu')
summary(vgg, (3, 224, 224), device=device.type)
- Output
1.6.0+cu101
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
           Conv2d-17          [-1, 256, 56, 56]         590,080
             ReLU-18          [-1, 256, 56, 56]               0
        MaxPool2d-19          [-1, 256, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       1,180,160
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
           Conv2d-24          [-1, 512, 28, 28]       2,359,808
             ReLU-25          [-1, 512, 28, 28]               0
           Conv2d-26          [-1, 512, 28, 28]       2,359,808
             ReLU-27          [-1, 512, 28, 28]               0
        MaxPool2d-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
           Conv2d-31          [-1, 512, 14, 14]       2,359,808
             ReLU-32          [-1, 512, 14, 14]               0
           Conv2d-33          [-1, 512, 14, 14]       2,359,808
             ReLU-34          [-1, 512, 14, 14]               0
           Conv2d-35          [-1, 512, 14, 14]       2,359,808
             ReLU-36          [-1, 512, 14, 14]               0
        MaxPool2d-37            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-38            [-1, 512, 7, 7]               0
           Linear-39                 [-1, 4096]     102,764,544
             ReLU-40                 [-1, 4096]               0
          Dropout-41                 [-1, 4096]               0
           Linear-42                 [-1, 4096]      16,781,312
             ReLU-43                 [-1, 4096]               0
          Dropout-44                 [-1, 4096]               0
           Linear-45                 [-1, 1000]       4,097,000
================================================================
Total params: 143,667,240
Trainable params: 143,667,240
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 238.69
Params size (MB): 548.05
Estimated Total Size (MB): 787.31
----------------------------------------------------------------
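The numbers in this output can be reproduced by hand, which helps when reading summaries of other models. The sketch below assumes float32 values (4 bytes each) and that torchsummary doubles the activation size to account for gradients; conv2d_params, linear_params and the layer lists are written for this example from the table above.

```python
def conv2d_params(c_in, c_out, k=3):
    """Conv2d: one k*k*c_in kernel per output channel, plus one bias each."""
    return c_out * c_in * k * k + c_out


def linear_params(n_in, n_out):
    """Linear layer: weight matrix plus bias vector."""
    return n_out * n_in + n_out


# (in_channels, out_channels) of the 16 Conv2d layers, in order
CONVS = ([(3, 64), (64, 64), (64, 128), (128, 128), (128, 256)]
         + [(256, 256)] * 3 + [(256, 512)] + [(512, 512)] * 7)
# (in_features, out_features) of the 3 Linear layers; 512 * 7 * 7 = flattened pool output
LINEARS = [(512 * 7 * 7, 4096), (4096, 4096), (4096, 1000)]

# (elements per sample, number of layers with that output size), from "Output Shape"
ACTIVATIONS = [(64 * 224 * 224, 4), (64 * 112 * 112, 1),
               (128 * 112 * 112, 4), (128 * 56 * 56, 1),
               (256 * 56 * 56, 8), (256 * 28 * 28, 1),
               (512 * 28 * 28, 8), (512 * 14 * 14, 1),
               (512 * 14 * 14, 8), (512 * 7 * 7, 2),
               (4096, 6), (1000, 1)]

MB = 1024 ** 2
FLOAT32_BYTES = 4

total_params = (sum(conv2d_params(i, o) for i, o in CONVS)
                + sum(linear_params(i, o) for i, o in LINEARS))
input_mb = 3 * 224 * 224 * FLOAT32_BYTES / MB
# x2: each activation is counted once for the forward pass and once for its gradient
fwd_bwd_mb = 2 * sum(n * count for n, count in ACTIVATIONS) * FLOAT32_BYTES / MB
params_mb = total_params * FLOAT32_BYTES / MB
total_mb = input_mb + fwd_bwd_mb + params_mb

print(f"Total params: {total_params:,}")                     # 143,667,240
print(f"Input size (MB): {input_mb:.2f}")                    # 0.57
print(f"Forward/backward pass size (MB): {fwd_bwd_mb:.2f}")  # 238.69
print(f"Params size (MB): {params_mb:.2f}")                  # 548.05
print(f"Estimated Total Size (MB): {total_mb:.2f}")          # 787.31
```

Almost 72% of the parameters (102,764,544 of 143,667,240) sit in the first Linear layer, while the activation memory is dominated by the early 224x224 feature maps.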
3. TensorWatch
- Description
- TensorWatch is a debugging and visualization tool developed by Microsoft for data science, deep learning, and reinforcement learning.
- Source code
- Definition
def draw_model(model, input_shape=None, orientation='TB', png_filename=None):  # orientation = 'LR' for landscape
- Install dependencies
pip install graphviz
pip install torchvision
pip install scikit-learn
pip install tensorwatch
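Note that the pip name and the import name differ for scikit-learn (it is imported as sklearn), and that `pip install graphviz` only installs the Python bindings, not the Graphviz programs themselves. A small standard-library check (missing_packages is a made-up helper name) can confirm the Python side before running the TensorWatch examples:

```python
import importlib.util

# pip package name -> import name (they differ for scikit-learn)
REQUIRED = {
    "graphviz": "graphviz",
    "torchvision": "torchvision",
    "scikit-learn": "sklearn",
    "tensorwatch": "tensorwatch",
}


def missing_packages(required=REQUIRED):
    """Return the pip names whose top-level module cannot be found (nothing is imported)."""
    return [pip_name for pip_name, module in required.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("missing:", ", ".join(missing) if missing else "none")
```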
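3.1 Saving the model graph
- Example code
import torchvision
import tensorwatch as tw
alexnet_model = torchvision.models.alexnet()
# draw_model returns a graph object; in a Jupyter notebook, evaluating it displays the graph
img = tw.draw_model(alexnet_model, [1, 3, 224, 224])
img.save(r'D:/alexnet.jpg')  # write the graph image to disk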
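3.2 Viewing layer statistics
- Example code
import torchvision
import tensorwatch as tw
alexnet_model = torchvision.models.alexnet()
# unlike torchsummary, the input shape here includes the batch dimension
tw.model_stats(alexnet_model, [1, 3, 224, 224])
- Output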
3.3 Errors and fixes
3.3.1 AttributeError: module 'torch.onnx' has no attribute 'set_training'
- Cause: the installed PyTorch is too new. I am using 1.6, and only versions below 1.6 still provide torch.onnx.set_training.
- Fix: downgrade PyTorch, for example directly with:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torch==1.2
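Instead of inferring compatibility from the version string, you can test for the attribute directly, which stays correct whichever release actually removed it. A minimal sketch (has_onnx_set_training is a hypothetical helper name):

```python
import torch


def has_onnx_set_training() -> bool:
    """True if this PyTorch build still provides torch.onnx.set_training."""
    return hasattr(torch.onnx, "set_training")


if __name__ == "__main__":
    print("torch", torch.__version__)
    print("torch.onnx.set_training available:", has_onnx_set_training())
```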
3.3.2 AttributeError: 'Dot' object has no attribute 'repr_svg'
- Cause: also a version incompatibility.
- Fix: in Anaconda\Lib\site-packages\tensorwatch\model_graph\hiddenlayer\pytorch_draw_model.py, change line 13 to return self.dot.create_svg().decode()
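3.3.3 FileNotFoundError: [Errno 2] "dot" not found in path
- Cause: a Graphviz/pydot problem: pydot calls the Graphviz dot executable, which is not installed or not on PATH.
- Fix: install Graphviz itself (the program, not only the pip graphviz package) and add its bin directory to PATH; after that the error goes away.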
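Before rerunning TensorWatch, a quick check with the standard library's shutil.which, which searches PATH, shows whether the dot executable is visible to Python. This is not pydot's own lookup code, only a rough equivalent; find_graphviz_dot is a made-up helper name.

```python
import shutil


def find_graphviz_dot():
    """Return the full path of Graphviz's `dot` executable, or None if it is not on PATH."""
    return shutil.which("dot")


if __name__ == "__main__":
    dot_path = find_graphviz_dot()
    if dot_path is None:
        print('"dot" not found in PATH: install Graphviz and add its bin directory to PATH')
    else:
        print("Graphviz dot found at:", dot_path)
```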
4. Netron
- Netron is a model visualization tool that can open models from a wide range of neural network frameworks offline.
- Code
- Supported formats
- ONNX (.onnx, .pb, .pbtxt)
- Keras (.h5, .keras)
- Core ML (.mlmodel)
- Caffe (.caffemodel, .prototxt)
- Caffe2 (predict_net.pb)
- Darknet (.cfg)
- MXNet (.model, -symbol.json)
- Barracuda (.nn)
- ncnn (.param)
- Tengine (.tmfile)
- TNN (.tnnproto)
- UFF (.uff)
- TensorFlow Lite (.tflite)
- Experimental support
- TorchScript (.pt, .pth)
- PyTorch (.pt, .pth)
- Torch (.t7)
- Arm NN (.armnn)
- BigDL (.bigdl, .model)
- Chainer (.npz, .h5)
- CNTK (.model, .cntk)
- Deeplearning4j (.zip)
- MediaPipe (.pbtxt)
- ML.NET (.zip)
- MNN (.mnn)
- PaddlePaddle (.zip, model)
- OpenVINO (.xml)
- scikit-learn (.pkl)
- TensorFlow.js (model.json, .pb)
- TensorFlow (.pb, .meta, .pbtxt, .ckpt, .index)