目录
什么是TensorboardX
Tensorboard 是 TensorFlow 的一个附加工具,可以记录训练过程的数字、图像等内容,以方便研究人员观察神经网络训练过程。可是对于 PyTorch 等其他神经网络训练框架并没有功能像 Tensorboard 一样全面的类似工具,一些已有的工具功能有限或使用起来比较困难 (tensorboard_logger, visdom等) 。TensorboardX 这个工具使得 TensorFlow 外的其他神经网络框架也可以使用到 Tensorboard 的便捷功能。TensorboardX 的 github仓库在这里。
TensorboardX 的文档相对详细,但大部分缺少相应的示例。本文是对TensorboardX 各项功能的完整介绍,每项都包含了示例,给出了可视化效果,希望可以方便大家的使用。笔者水平有限,还请读者们斧正,相关问题可以在留言区提出,我尽量解答。
配置TensorboardX
环境要求
- 操作系统:MacOS / Ubuntu (Windows未测试)
- Python2/3
- PyTorch >= 1.0.0 && torchvision >= 0.2.1 && tensorboard >= 1.12.0 1
以上版本要求你对应TensorboardX@1.6版本。为保证版本时效性,建议大家按照 TensorboardX github仓库中README 的要求进行环境配置。
安装
可以直接使用 pip 进行安装,或者从源码进行安装。
使用 pip 安装
pip install tensorboardX
从源码安装
git clone https://github.com/lanpa/tensorboardX && cd tensorboardX && python setup.py install
使用TensorboardX
首先,需要创建一个 SummaryWriter 的示例:
from tensorboardX import SummaryWriter
# Creates writer1 object.
# The log will be saved in 'runs/exp'
writer1 = SummaryWriter('runs/exp')
# Creates writer2 object with auto generated file name
# The log directory will be something like 'runs/Aug20-17-20-33'
writer2 = SummaryWriter()
# Creates writer3 object with auto generated file name, the comment will be appended to the filename.
# The log directory will be something like 'runs/Aug20-17-20-33-resnet'
writer3 = SummaryWriter(comment='resnet')
以上展示了三种初始化 SummaryWriter 的方法:
- 提供一个路径,将使用该路径来保存日志
- 无参数,默认将使用
runs/日期时间
路径来保存日志 - 提供一个 comment 参数,将使用
runs/日期时间-comment
路径来保存日志
一般来讲,我们对于每次实验新建一个路径不同的 SummaryWriter,也叫一个 run,如 runs/exp1
、runs/exp2
。
接下来,我们就可以调用 SummaryWriter 实例的各种 add_something
方法向日志中写入不同类型的数据了。想要在浏览器中查看可视化这些数据,只要在命令行中开启 tensorboard 即可:
tensorboard --logdir=<your_log_dir>
其中的 <your_log_dir>
既可以是单个 run 的路径,如上面 writer1 生成的 runs/exp
;也可以是多个 run 的父目录,如 runs/
下面可能会有很多的子文件夹,每个文件夹都代表了一次实验,我们令 --logdir=runs/
就可以在 tensorboard 可视化界面中方便地横向比较 runs/
下不同次实验所得数据的差异。
使用各种 add 方法记录数据
下面详细介绍 SummaryWriter 实例的各种数据记录方法,并提供相应的示例供参考。
数字 (scalar)
使用 add_scalar
方法来记录数字常量。
add_scalar(tag, scalar_value, global_step=None, walltime=None)
参数
- tag (string): 数据名称,不同名称的数据使用不同曲线展示
- scalar_value (float): 数字常量值
- global_step (int, optional): 训练的 step
- walltime (float, optional): 记录发生的时间,默认为
time.time()
需要注意,这里的 scalar_value
一定是 float 类型,如果是 PyTorch scalar tensor,则需要调用 .item()
方法获取其数值。我们一般会使用 add_scalar
方法来记录训练过程的 loss、accuracy、learning rate 等数值的变化,直观地监控训练过程。
Example
from tensorboardX import SummaryWriter
writer = SummaryWriter