模型训练到一半有 bug 停了,可以 resume 继续炼,本篇给出 pytorch 在 resume 训练时续写 tensorboard 的简例,参考 [1-3],只要保证 writer 接收的 global step 是连着的就行。
Code
import numpy as np
from torch.utils.tensorboard import SummaryWriter
global_step = 0
log_p = '.'
losses = 1 / np.arange(1, 21)
# 第一次训练
with SummaryWriter(log_dir=log_p) as writer:
for loss in losses[:10]:
writer.add_scalar("loss", loss, global_step)
writer.add_scalar("loss_1", loss, global_step) # 前半段 loss,作为参考
global_step += 1
# (此处训练因模型有 bug 中断了)
# 重开,resume 训练。续写 tensorboard log 须:
# 1. 接着之前的 global_step
# 2. 同一个 log_dir
with SummaryWriter(log_dir=log_p) as writer:
for loss in losses[10:]:
writer.add_scalar("loss", loss, global_step)
writer.add_scalar("loss_2", loss, global_step) # 后半段 loss,作为参考
global_step += 1
这里加了 loss_1、loss_2 作为参考,会生成两个 events.out.tfevents 文件,一个是 loss 前半段和 loss_1,另一个是 loss 后半段和 loss_2。如果没有 loss_1、loss_2,就只有一个 events.out.tfevents 文件,显示一条连续的 loss 曲线。
如果 global step 不连续,但递增,则也能续写,不过 tensorboard 网页显示时中间会自动补一段曲线。即:
import numpy as np
from torch.utils.tensorboard import SummaryWriter
global_step = 0
log_p = '.'
losses = 1 / np.arange(1, 21)
# 第一次训练
with SummaryWriter(log_dir=log_p) as writer:
for loss in losses[:10]:
writer.add_scalar("loss", loss, global_step)
global_step += 1
# 因 bug 中断,且 global step 没接上之前的
# 但续写 tensorboard 时的开始 global step 大过中断时的 global step(即 global step 递增)
global_step += 5 # 从 9 直变 15
# resume 训练,续写 tensorboard
with SummaryWriter(log_dir=log_p) as writer:
for loss in losses[10:]:
writer.add_scalar("loss", loss, global_step)
global_step += 1
其中 9 < global step < 15 那段是没值的,不过网页显示是补了一段曲线。
Conclusion
所以写模型、存 checkpoint 时考虑记一个 global step,resume 训练时 global step 也 resume,就可以续写 tensorboard 了。