1 定义用于TensorBoard日志的目录
import os

# All runs are stored as timestamped sub-directories of this root.
root_logdir = os.path.join(os.curdir, "my_logs")


def get_run_logdir():
    """Build a per-run TensorBoard log directory path under ``root_logdir``.

    The last path component is ``run_`` followed by the current local time
    formatted as ``%Y_%m_%d-%H_%M_%S``. Only the path string is built;
    nothing is created on disk here.
    """
    import time

    stamp_format = "run_%Y_%m_%d-%H_%M_%S"
    stamp = time.strftime(stamp_format)
    return os.path.join(root_logdir, stamp)


run_logdir = get_run_logdir()  # e.g. './my_logs/run_2022_03_09-15_53_22'
2 类与函数
2.1 SummaryWriter类
torch.utils.tensorboard.writer.SummaryWriter(log_dir=None, comment='', purge_step=None, max_queue=10, flush_secs=120, filename_suffix='')
功能:创建一个摘要编写器(SummaryWriter),用于将事件和摘要写入事件文件(event file)中。
from torch.utils.tensorboard import SummaryWriter

# Create a summary writer with an automatically generated folder name.
writer = SummaryWriter()
# folder location: runs/May04_22-14-54_s-MacBook-Pro.local/
# Close each demo writer before rebinding `writer`; otherwise the old
# writer's event file is never closed (the other examples all call close()).
writer.close()

# Create a summary writer using the specified folder name.
writer = SummaryWriter("my_experiment")
# folder location: my_experiment
writer.close()

# Create a summary writer with a comment appended to the generated folder name.
writer = SummaryWriter(comment="LR_0.1_BATCH_16")
# folder location: runs/May04_22-14-54_s-MacBook-Pro.localLR_0.1_BATCH_16/
writer.close()

# Writer that logs into the timestamped run directory built by get_run_logdir().
# Intentionally left open — presumably for the logging code that follows; close() it when done.
writer = SummaryWriter(run_logdir)
2.2 add_scalar()函数
add_scalar(tag, scalar_value, global_step=None, walltime=None, new_style=False, double_precision=False)
功能:将标量数据添加到摘要。
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
x = range(100)
for step in x:
    # Plot the line y = 2x: the value 2*step is recorded at global step `step`.
    writer.add_scalar('y=2x', 2 * step, step)
writer.close()
2.3 add_scalars()函数
add_scalars(main_tag, tag_scalar_dict, global_step=None, walltime=None)
功能:将许多标量数据添加到摘要中。
# numpy was previously used here without being imported (NameError when run in order).
import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
r = 5  # divisor applied to the step before taking sin/cos/tan
for i in range(100):
    writer.add_scalars('run_14h', {'xsinx': i * np.sin(i / r),
                                   'xcosx': i * np.cos(i / r),
                                   'tanx': np.tan(i / r)}, i)
writer.close()
# This call adds three values to the same scalar plot with the tag
# 'run_14h' in TensorBoard's scalar section.
2.4 分组画图
通过在标签名中用 '/' 进行层次化命名(如 'Loss/train'、'Loss/test'),TensorBoard 会把前缀相同的图归入同一组显示。
from torch.utils.tensorboard import SummaryWriter
import numpy as np

writer = SummaryWriter()
for n_iter in range(100):
    # The part before '/' is the group: Loss/* and Accuracy/* each form one
    # section in TensorBoard, with train and test curves side by side.
    writer.add_scalar('Loss/train', np.random.random(), n_iter)
    writer.add_scalar('Loss/test', np.random.random(), n_iter)
    writer.add_scalar('Accuracy/train', np.random.random(), n_iter)
    writer.add_scalar('Accuracy/test', np.random.random(), n_iter)
# Close the writer (as the other examples do) so buffered events are flushed
# to disk and the event file is released.
writer.close()
3 启动TensorBoard服务器
# Serve the logs stored under ./my_logs (the root_logdir defined in section 1).
# The bare SummaryWriter() examples in section 2 write under ./runs instead; point --logdir there to view them.
$ tensorboard --logdir=./my_logs --port=6006