Pytorch训练可视化(TensorboardX)
PyTorch 番外篇:Pytorch 中的 TensorBoard(TensorBoard in PyTorch)
TensorBoard 相关资料
TensorBoard 是 Tensorflow 官方推出的可视化工具。
官方介绍
TensorBoard: Visualizing Learning
TensorBoard 实践介绍(2017 年 TensorFlow 开发大会)
相关博客
Tensorflow 的可视化工具 Tensorboard 的初步使用
TensorFlow 教程 4 Tensorboard 可视化好帮手
PyTorch 实现
在这次的代码里,是通过简单的神经网络实现一个 MINIST 的分类器,并且通过 TensorBoard 实现训练过程的可视化。
在训练阶段,通过 scalar_summary 画出损失和精确率,通过 image_summary 可视化训练的图像。
另外,使用 histogram_summary 可视化神经网络的参数的权重和梯度值。
需要安装的 package
tensorflow
torch
torchvision
scipy
numpy
LOG 功能实现(Logger 类)
基于 TensorBoard,给 Pytorch 的训练提供保存训练信息的接口。
Tensorboard 可以记录与展示以下数据形式:
标量 Scalars
图片 Images
音频 Audio
计算图 Graph
数据分布 Distribution
直方图 Histograms
嵌入向量 Embeddings
代码中实现了标量 Scalar、图片 Image、直方图 Histogram 的保存。
1
2
3
4
5
6
7
8
# 包
import tensorflow as tf
import numpy as np
import scipy.misc
try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class Logger(object):
def __init__(self, log_dir):
"