wandb学习笔记

一、工具介绍

wandb全称Weights & Biases,用来帮助开发者跟踪机器学习的项目。通过wandb可以记录模型训练过程中指标的变化情况以及超参的设置,然后对输出的结果进行可视化的对比,帮助我们更好的分析模型在训练过程中的问题,并快速与同事进行团队协作。
wandb可以将训练过程中的参数,上传到服务器上,通过登录wandb来进行实施过程模型训练过程中参数和指标的变化。
image.png
wandb的主体组成主要有以下四大模块,分别是:

  • 仪表盘:跟踪实验、可视化结果
  • 报告:保存和分享可复制的成果/结论
  • Sweeps:通过调节超参数来优化模型
  • Artifacts:数据集和模型版本化,流水线跟踪

二、工具使用

2.1 工具安装

pip install wandb

2.2 创建账号

在shell命令行中或官网注册页面注册免费账号

wandb login

在官网登录账号,点击Settings,滚动到下面,寻找API Keys进行复制,在命令行中粘贴即可。

2.3 编写脚本

2.3.1 初始化W&B

在开始记录之前,在脚本开始处初始化wandb

# Inside my model training code
import wandb
wandb.init(project="my-project")

如果项目不存在,则自动创建项目。脚本的运行会同步到一个名称为「my-project」的项目中

2.3.2 声明超参数

通过对象wandb.config保存超参数。

# Save model inputs and hyperparameters
config = wandb.config
config.dropout = 0.2
config.hidden_layer_size = 128

2.3.3 记录指标

在训练模型过程中记录指标(Metric),如损失(Loss)和准确率(ACC)。通过wandb.log记录更为复杂的输出和结果,如直方图、图形和图像。

def my_train_loop():
    wandb.watch(model)    # Log gradients and model parameters
    for epoch in range(10):
        loss = 0 # change as appropriate :)
        wandb.log({'epoch': epoch, 'loss': loss})    # Log metrics to visualize performance

2.3.4 保存文件

保存在路径wandb.run.d中的全部内容都会被上传到W&B,并在运行结束后与运行项保存在一起。

# by default, this will save to a new subfolder for files associated
# with your run, created in wandb.run.dir (which is ./wandb by default)
wandb.save("mymodel.h5")

# you can pass the full path to the Keras model API
model.save(os.path.join(wandb.run.dir, "mymodel.h5"))

当正常运行脚本时,后台进程会同步相关记录(终端输出、指标和文件)至云端。
若脚本从git库中运行,后台进程会同步记录git状态信息。

三、工具示例

wandb操作简单,是一款性能强大的深度学习网络可视化插件,支持多个深度学习网络框架,相关功能如下:

  • 调参:存储训练参数,自动分析参数关联,助力快速调参
  • 看板:同步本地实验结果,并进行实时可视化展示
  • 报告:将结果自动同步到云端,永久记录实验过程,可用于模型改进探索及团队分享
  • 自动记录CPU、GPU等系统硬件使用情况
  • 自动记录控制台打印的训练日志,以及环境依赖等参数项

3.1 Log Metrics

3.1.1 简单示例

相关数据文件:apple.csv

import os
import pandas as pd
import wandb

apple = pd.read_csv('./data/apple.csv')
apple = apple[-1000:]

# Init

wandb.init(project='visualize-models', name='apple_metric')

# Log the metrics

for price in apple['close']:
 wandb.log({"Stock Price": price})

3.1.2 自定义格式

可以通过matplotlib进行图表对象的绘制,然后传递至wandb.log

import os
import wandb
import matplotlib.pyplot as plt

fibonacci = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]

plt.plot(fibonacci)
plt.ylabel('some interestering numbers')

wandb.init(project='visualize-models', name='matplotlib to wandb')

wandb.log({'plot': plt})

3.1.3 图像格式

如果图像是Numpy数组,且数组包含浮点数,我们将其转换为0-255之间的整数,则可手动指定模式或只提供PIL.Image

import os
import pandas as pd
import wandb
import matplotlib.pyplot as plt

img = './data/ladybug.jpg'
im = plt.imread(img)

wandb.init(project='visualize-models', name='image to wandb')

wandb.log({'img': [wandb.Image(im, caption='Cafe')]})

3.2 保存和恢复模型

3.2.1 保存模型

#"model.h5" is saved in wandb.run.dir & will be uploaded at the end of training
model.save(os.path.join(wandb.run.dir, "model.h5"))

# Save a model file manually from the current directory:
wandb.save('model.h5')

# Save all files that currently exist containing the substring "ckpt":
wandb.save('../logs/*ckpt*')

# Save any files starting with "checkpoint" as they're written to:
wandb.save(os.path.join(wandb.run.dir, "checkpoint*"))

3.2.2 恢复模型

# restore the model file "model.h5" from a specific run by user "lavanyashukla"
# in project "save_and_restore" from run "10pr4joa"
best_model = wandb.restore('model.h5', run_path="lavanyashukla/save_and_restore/10pr4joa")

# use the "name" attribute of the returned object if your framework expects a filename, e.g. as in Keras
model.load_weights(best_model.name)

四、工具总结

wnadb简而言之就是颜值高、气质好,可视化交互性强,支持所有深度学习框架,包括TF、PyTorch、Keras、sci-learn、HF和XGboost等。最最最关键的是其内嵌在计算机中,不需要写可视化代码。总而言之:没用过的赶快用,用过的都说好!

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值