Pytorch tensorboard 使用

该博客介绍了如何使用PyTorch的`torch.utils.tensorboard`模块进行多种数据的可视化,包括添加参数、标量、直方图、图像、PR曲线等。通过`SummaryWriter`创建日志并用`add_*`系列函数记录训练过程中的关键指标,以便于在TensorBoard中查看和分析。此外,还展示了如何添加文本、嵌入、网格图、多线图和边缘图等高级特性。
摘要由CSDN通过智能技术生成

所有内容来自于 torch.utils.tensorboard.writer.py 中的注释。

启动方式:

tensorboard --logdirs=logs

想要在另一台机器访问对应网页可以选择添加 --bind_all

比较常用的几个函数:

  1. add_scalar,add_scalars
  2. add_histogram
  3. add_image,add_images
  4. add_pr_curve
from torch.utils import tensorboard
import numpy as np
import keyword
import torch
import shutil


def main():
    shutil.rmtree('./saved')
    writer = tensorboard.SummaryWriter('./saved')

    # add_hparams
    for i in range(5):
        writer.add_hparams({'lr': 0.1*i, 'bsize': i},
                        {'hparam/accuracy': 10*i, 'hparam/loss': 10*i})
    
    # add_scalar
    x = range(100)
    for i in x:
        writer.add_scalar('y=2x', i * 2, i)

    # add_scalars
    r = 5
    for i in range(100):
        writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r),
                                        'xcosx':i*np.cos(i/r),
                                        'tanx': np.tan(i/r)}, i)

    # add_histogram
    for i in range(10):
        x = np.random.random(1000)
        writer.add_histogram('distribution centers', x + i, i)

    # add_histogram_raw
    dummy_data = []
    for idx, value in enumerate(range(50)):
        dummy_data += [idx + 0.001] * value

    bins = list(range(50+2))
    bins = np.array(bins)
    values = np.array(dummy_data).astype(float).reshape(-1)
    counts, limits = np.histogram(values, bins=bins)
    sum_sq = values.dot(values)
    writer.add_histogram_raw(
        tag='histogram_with_raw_data',
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limits=limits[1:].tolist(),
        bucket_counts=counts.tolist(),
        global_step=0)
    
    # add_image
    img = np.zeros((3, 100, 100))
    img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
    img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

    img_HWC = np.zeros((100, 100, 3))
    img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000
    img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

    writer.add_image('my_image', img, 0)

    # If you have non-default dimension setting, set the dataformats argument.
    writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC')

    # add_images
    img_batch = np.zeros((16, 3, 100, 100))
    for i in range(16):
        img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i
        img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i

    writer.add_images('my_image_batch', img_batch, 0)

    # add_text
    writer.add_text('lstm', 'This is an lstm', 0)
    writer.add_text('rnn', 'This is an rnn', 10)

    # add_embedding
    meta = []
    while len(meta)<100:
        meta = meta+keyword.kwlist # get some strings
    meta = meta[:100]

    for i, v in enumerate(meta):
        meta[i] = v+str(i)

    label_img = torch.rand(100, 3, 10, 32)
    for i in range(100):
        label_img[i]*=i/100.0

    writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img)
    writer.add_embedding(torch.randn(100, 5), label_img=label_img)
    writer.add_embedding(torch.randn(100, 5), metadata=meta)

    # add_pr_curve
    labels = np.random.randint(2, size=100)  # binary label
    predictions = np.random.rand(100)
    writer.add_pr_curve('pr_curve', labels, predictions, 0)

    # add_custom_scalars_multilinechart
    writer.add_custom_scalars_multilinechart(['twse/0050', 'twse/2330'])

    # add_custom_scalars_marginchart
    writer.add_custom_scalars_marginchart(['twse/0050', 'twse/2330', 'twse/2006'])

    # add_custom_scalars
    layout = {'Taiwan':{'twse':['Multiline',['twse/0050', 'twse/2330']]},
                         'USA':{ 'dow':['Margin',   ['dow/aaa', 'dow/bbb', 'dow/ccc']],
                              'nasdaq':['Margin',   ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}}

    writer.add_custom_scalars(layout)

    # add_mesh
    vertices_tensor = torch.as_tensor([
        [1, 1, 1],
        [-1, -1, 1],
        [1, -1, -1],
        [-1, 1, -1],
    ], dtype=torch.float).unsqueeze(0)
    colors_tensor = torch.as_tensor([
        [255, 0, 0],
        [0, 255, 0],
        [0, 0, 255],
        [255, 0, 255],
    ], dtype=torch.int).unsqueeze(0)
    faces_tensor = torch.as_tensor([
        [0, 2, 3],
        [0, 3, 1],
        [0, 1, 2],
        [1, 3, 2],
    ], dtype=torch.int).unsqueeze(0)

    writer.add_mesh('my_mesh', vertices=vertices_tensor, colors=colors_tensor, faces=faces_tensor)

    writer.close()


if __name__ == '__main__':
    main()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值