深度学习就像炼丹,一次模型的训练需要很长时间,任何人都无法做到一直盯着模型的训练。通常都是开启模型训练之后,只要不报错并且随着迭代次数的增加模型的loss在下降,这时我都会去干别的了让它跑去吧,第二天再来看我的丹炼的怎么样。这时怎么在一堆仙丹(经过一天的训练会保存许多模型的权重,每迭代多少个batch或每个epoch保存一次权重)中找到最好的那一个。如果盲目的取最后几次保存的,容易取到过拟合的模型;相反如果随机在前中期取一个,则容易取到欠拟合的模型。因此,我们需要可以直观的看到模型训练过程中(训练阶段、验证阶段)loss值与accuracy值的变化曲线,选取一个在训练阶段与验证阶段均达到收敛的权重。
通过查找资料,发现可以借助tensorflow的tensorboard工具实现训练过程的可视化。由于是Caffe起家后来转用Pytorch,期间也学过一段时间的Tensorflow,由于不是一个好的TFBOY对TF的一些函数也不是太清楚,有幸在github上找到一位大神封装好的代码(https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/04-utils/tensorboard),所以我就无耻的拿了过来。
首先需要再已有Pytorch环境的基础上安装tensorflow,安一个cpu版本的tensorflow即可,因为只需要tensorflow带的tensorboard工具箱。(强烈建议安装低版本的tensorflow,在实验中我电脑安装的就是1.4版本的即可满足我可视化的需求,pip install tensorflow==1.4.0)。
1、项目代码组织结构
2、logger.py
大佬封装的工具可以实现,loss、accuracy、weight、grad、image在训练过程中变化记录,个人认为loss与accuracy是模型训练时比较关注的主要指标。
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
import tensorflow as tf
import numpy as np
import scipy.misc
try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x
class Logger(object):
def __init__(self, log_dir):
"""Create a summary writer logging to log_dir."""
self.writer = tf.summary.FileWriter(log_dir)
def scalar_summary(self, tag, value, step):
"""Log a scalar variable."""
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
def image_summary(self, tag, images, step):
"""Log a list of images."""
img_summaries = []
for i, img in enumerate(images):
# Write the image to a string
try:
s = StringIO()
except:
s = BytesIO()
scipy.misc.toimage(img).save(s, format="png")
# Create an Image object
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
height=img.shape[0],
width=img.shape[1])
# Create a Summary value
img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
# Create and write Summary
summary = tf.Summary(value=img_summaries)
self.writer.add_summary(summary, step)
def histo_summary(self, tag, values, step, bins=1000):
"""Log a histogram of the tensor of values."""
# C