TensorboardX
前言
对于深度学习任务,可视化训练过程能够更为直观地反应网络学习的好坏,便于进一步的网络调参。目前,可视化工具主要有以下几大类:
- visidom
- tensorboard
- tensorboardX
- other package
本文介绍这四者中相对易于上手的,兼容不同深度框架的tensorboradX。介绍的内容包括:
- 使用tensorboardX的代码通用框架
- 常用的API
- 可视化GAN实例
代码通用框架
- 导入SummaryWriter
- 创建SummaryWriter实例
- 调用相应的API进行可视化
from tensorboardX import SummaryWriter
writer = SummaryWriter(logDirPath)
writer.add_something(tag name, object, iteration number)
常见的API
在处理深度学习任务时,我们可能经常会遇到可视化如下参数:
- 神经网络的结构
- Loss曲线
- 展示image
- 数据的直方图
接下来逐一介绍相应的API。
-
add_graph(model, input_to_model, verbose=False)
- model (torch.nn.Module), 待绘制的网络
- input_to_model (torch.autograd.Variable),网络的输入
-
add_scalar(tag, scalar_value, global_step=None)
添加单个标量- tag (string), 数据的id
- scalar_value (float) ,数据的值
- global_step (int) ,步数
-
add_scalars(main_tag,tag_scalar_dict,global_step=None)
添加多个标量- tag (string),数据的id
- main_tag (string) , 标签的前缀
- tag_scalar_dict (dict) , {tag: value},字典,键:名称,值:数值
- global_step (int) ,步数
-
add_image(tag, img_tensor, global_step=None)
添加图数据,要求安装了pillow包- tag (string), 数据的名称
- img_tensor (torch.Tensor) ,图像数据,shape(3,H,W) 配合
torchvision.utils.make_grid
使用 - global_step (int), 步数
-
add_histogram(tag,values,global_step=None,bins='tensorflow')
添加直方图- tag (string) ,数据id
- values (numpy.array),数据
- global_step (int),横坐标的刻度值
- bins (string) , 可选 {‘tensorflow’,’ auto’, ‘fd’, …} ,区间模式选择。
可视化GAN实例
源码:GAN 实例
注意:
-
图片由ScreenToGif制作而成。
-
使用Tensorboard 观看:
tensorboard --logdir logPath --port xxxx