我是小白,想学一些技术,这里是学习记录
Tensorboard定义
"""Writes entries directly to event files in the log_dir to be consumed by TensorBoard.
The `SummaryWriter` class provides a high-level API to create an event file in a given directory and add summaries and events to it. The class updates the file contents asynchronously. This allows a training program to call methods to add data to the file directly from the training loop, without slowing down training.
"""
第一行可知,tensorboard直接对log_dir路径下的事件文件进行解析。
常用函数介绍
1.writer = SummaryWriter()
建立一个对象,重要的选填参数为文件夹路径,也就是log_dir。其余参数可以自己摸索功能。
Examples::
from torch.utils.tensorboard import SummaryWriter
# create a summary writer with automatically generated folder name.
writer = SummaryWriter()
# folder location: runs/May04_22-14-54_s-MacBook-Pro.local/
# create a summary writer using the specified folder name.
writer = SummaryWriter("my_experiment")
# folder location: my_experiment
# create a summary writer with comment appended.
writer = SummaryWriter(comment="LR_0.1_BATCH_16")
# folder location: runs/May04_22-14-54_s-MacBook-Pro.localLR_0.1_BATCH_16/
2.writer.add_scalar()
添加标量,画折线图用,里面有三个参数:
第一个参数可以简单理解为标题,画折线图时会根据标题从logs文件夹下寻找相关的event file
第二个参数可以简单理解为纵轴
第三个参数可以简单理解为横轴
writer = SummaryWriter("logs")
for i in range(100):
writer.add_scalar('y=x',i,i)
writer.close()
3.writer.add_image()
opencv读取图片一般是numpy型,用PIL读取的图片是PIL格式,需要通过numpy进行转换才可以在add_image()使用
add_image()函数需要两个主要参数,第一个是标题,第二个是传入的图片。传入的图片需要是torch.tensor格式或者numpy格式。
如果传入多张图片,那么第三个参数是传入图片的序号。
另外一个比较挑剔的是传入的图像通道数格式有要求。如果不是c,h,w顺序的话,需要添加额外第四个参数进行修改。
第四个参数使用方法如下:
Shape:
img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``.
以下代码可以验证:
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
writer = SummaryWriter("logs")
img_path = 'dataset/train/ants_image/0013035.jpg'
img_PIL = Image.open(img_path)
img_array = np.array(img_PIL)
# 打印图像发现图像是512,768,3格式,也就是h,w,c
print(img_array.shape)
# 但是输入必须是c,h,w格式,故需要一个转换
writer.add_image('test',img_array,0,dataformats='HWC')
writer.close()
还有一点要注意,同一标题名的图片会出现在同一个地方,不同标题名的图片出现在不同地方。
详见
https://www.bilibili.com/video/BV1hE411t7RN?p=9
tensorboard常用指令
注意要在相应环境下打开
(我的python文件和logs文件夹位于同一路径下)
1.tensorboard --logdir==logs
用于打开tensorboard
2.tensorboard --logdir==logs --port=6007
用于修改端口,免得一堆人用一个端口地址导致出现问题,我这里修改端口为6007,也可以是其他端口。
来源:
https://space.bilibili.com/203989554