一、工具介绍
wandb
全称Weights & Biases
,用来帮助开发者跟踪机器学习的项目。通过wandb
可以记录模型训练过程中指标的变化情况以及超参的设置,然后对输出的结果进行可视化的对比,帮助我们更好的分析模型在训练过程中的问题,并快速与同事进行团队协作。
wandb
可以将训练过程中的参数,上传到服务器上,通过登录wandb
来进行实施过程模型训练过程中参数和指标的变化。
wandb
的主体组成主要有以下四大模块,分别是:
- 仪表盘:跟踪实验、可视化结果
- 报告:保存和分享可复制的成果/结论
- Sweeps:通过调节超参数来优化模型
- Artifacts:数据集和模型版本化,流水线跟踪
二、工具使用
2.1 工具安装
pip install wandb
2.2 创建账号
在shell命令行中或官网注册页面注册免费账号
wandb login
在官网登录账号,点击
Settings
,滚动到下面,寻找API Keys
进行复制,在命令行中粘贴即可。
2.3 编写脚本
2.3.1 初始化W&B
在开始记录之前,在脚本开始处初始化wandb
# Inside my model training code
import wandb
wandb.init(project="my-project")
如果项目不存在,则自动创建项目。脚本的运行会同步到一个名称为「my-project」的项目中
2.3.2 声明超参数
通过对象wandb.config
保存超参数。
# Save model inputs and hyperparameters
config = wandb.config
config.dropout = 0.2
config.hidden_layer_size = 128
2.3.3 记录指标
在训练模型过程中记录指标(Metric),如损失(Loss)和准确率(ACC)。通过wandb.log
记录更为复杂的输出和结果,如直方图、图形和图像。
def my_train_loop():
wandb.watch(model) # Log gradients and model parameters
for epoch in range(10):
loss = 0 # change as appropriate :)
wandb.log({'epoch': epoch, 'loss': loss}) # Log metrics to visualize performance
2.3.4 保存文件
保存在路径wandb.run.d
中的全部内容都会被上传到W&B
,并在运行结束后与运行项保存在一起。
# by default, this will save to a new subfolder for files associated
# with your run, created in wandb.run.dir (which is ./wandb by default)
wandb.save("mymodel.h5")
# you can pass the full path to the Keras model API
model.save(os.path.join(wandb.run.dir, "mymodel.h5"))
当正常运行脚本时,后台进程会同步相关记录(终端输出、指标和文件)至云端。
若脚本从git库中运行,后台进程会同步记录git状态信息。
三、工具示例
wandb操作简单,是一款性能强大的深度学习网络可视化插件,支持多个深度学习网络框架,相关功能如下:
- 调参:存储训练参数,自动分析参数关联,助力快速调参
- 看板:同步本地实验结果,并进行实时可视化展示
- 报告:将结果自动同步到云端,永久记录实验过程,可用于模型改进探索及团队分享
- 自动记录CPU、GPU等系统硬件使用情况
- 自动记录控制台打印的训练日志,以及环境依赖等参数项
3.1 Log Metrics
3.1.1 简单示例
相关数据文件:apple.csv
import os
import pandas as pd
import wandb
apple = pd.read_csv('./data/apple.csv')
apple = apple[-1000:]
# Init
wandb.init(project='visualize-models', name='apple_metric')
# Log the metrics
for price in apple['close']:
wandb.log({"Stock Price": price})
3.1.2 自定义格式
可以通过matplotlib
进行图表对象的绘制,然后传递至wandb.log
import os
import wandb
import matplotlib.pyplot as plt
fibonacci = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
plt.plot(fibonacci)
plt.ylabel('some interestering numbers')
wandb.init(project='visualize-models', name='matplotlib to wandb')
wandb.log({'plot': plt})
3.1.3 图像格式
如果图像是Numpy
数组,且数组包含浮点数,我们将其转换为0-255
之间的整数,则可手动指定模式或只提供PIL.Image
import os
import pandas as pd
import wandb
import matplotlib.pyplot as plt
img = './data/ladybug.jpg'
im = plt.imread(img)
wandb.init(project='visualize-models', name='image to wandb')
wandb.log({'img': [wandb.Image(im, caption='Cafe')]})
3.2 保存和恢复模型
3.2.1 保存模型
#"model.h5" is saved in wandb.run.dir & will be uploaded at the end of training
model.save(os.path.join(wandb.run.dir, "model.h5"))
# Save a model file manually from the current directory:
wandb.save('model.h5')
# Save all files that currently exist containing the substring "ckpt":
wandb.save('../logs/*ckpt*')
# Save any files starting with "checkpoint" as they're written to:
wandb.save(os.path.join(wandb.run.dir, "checkpoint*"))
3.2.2 恢复模型
# restore the model file "model.h5" from a specific run by user "lavanyashukla"
# in project "save_and_restore" from run "10pr4joa"
best_model = wandb.restore('model.h5', run_path="lavanyashukla/save_and_restore/10pr4joa")
# use the "name" attribute of the returned object if your framework expects a filename, e.g. as in Keras
model.load_weights(best_model.name)
四、工具总结
wnadb简而言之就是颜值高、气质好,可视化交互性强,支持所有深度学习框架,包括TF、PyTorch、Keras、sci-learn、HF和XGboost等。最最最关键的是其内嵌在计算机中,不需要写可视化代码。总而言之:没用过的赶快用,用过的都说好!