PyTorch使用tensorboard可以显示网络运行情况,和TensorFlow的tensorboard使用很类似。均需要安装tensorboard包。
pip install tensorboard==1.15.0
导入:
from torch.utils.tensorboard import SummaryWriter
指定log路径:
log_dir = './run_logs'
self.writer = SummaryWriter(log_dir)
写入到log:
self.board_write_step = 50#每50个step写一次board_log
if (global_steps+1) % self.board_write_step == 0:
self.writer.add_image("result", result_img, global_step=global_steps)
self.writer.add_scalar("acc", acc, global_step=global_steps)
在shell中开启tensorboard:
tensorboard --logdir=./run_logs
浏览器中打开localhost:6006
:
注:在新版tensorboard(2.3.0)中,如果需要远程访问,需要在命令中加 --bind_all参数。例如
tensorboard --bind_all --logdir=./run_logs
否则只能本机访问,tensorboard启动时会提示这条信息:
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all