功能
- 版本管理:wandb.Artifact(),可以上传一些文件/文件夹到云端
- 监测,profile
- 支持输出图片、视频、表格、html等多媒体格式
- 超参数搜索,等…
注册,安装,登录
- 进入官网,注册一个账号
- 命令行安装
pip install wandb
- 终端登录:输入如下命令后回车
wandb login
- 根据终端提示,打开网页复制API key,粘贴到终端
使用(可参考官方文档,源码)
- wandb.init() 初始化,有几个常用参数:
project:指定项目名称,每次运行是一个run
entity:团队名称/用户名;没指定entity,则记录默认发送到用户名下
name:指定本次run的名称
config:是一个字典,可以记录本次训练的配置和超参数在config里。
要往config里头追加参数:wandb.config.updata({“para-name”: para-value}) - wandb.log()记录相关指标,传入字典
- wandb.finish()
示例:
import wandb # 导包
wandb.init(project="this project",
entity="xxx", # 用户名/团队名
name="first run",
config={
"lr": 1e-4,
"epoch": 1,
"data": 'cifar10'
})
for i in range(1,11):
loss = 20-i
wandb.log({"loss": loss})
wandb.finish()
运行完后,观察终端的返回结果。点进链接,查看可视化界面。可以查看系统的使用情况(监测服务器运行时环境),可以查看不同run(运行一次是一个run)的对比结果。
观察图片的输出
test_dataset = datasets.CIFAR10(root="./cifar10", train=False, download=False, transform=torchvision.transforms.ToTensor())
wandb.Image()的输入参数:numpy,pil,tensor
import wandb
from PIL import Image
wandb.init(project="this project",
config={
"lr": 1e-3,
"epoch": 3,
"data": 'cifar10'
})
for i in range(3):
img = test_dataset[i][0] # pil or tensor
wandb.log({"img": wandb.Image(img)})
# img.save("./img.jpg")
wandb.finish()
跟踪模型(参数,梯度等)
wandb.watch(),常用参数如下:
model:要监控的模型
log:要记录的指标,该参数有4个可选值:gradients(默认),parameters,all,None
可以通过观察参数值是否趋于稳定来判断模型的学习情况,通过看梯度是否趋于0来观察模型收敛情况。
log_freq:记录指标的频率,默认值1000个step
log_graph:是否可视化模型结构,默认False