本文,就来给大家介绍一款新型的机器学习可视化工具,能够让人工智能研发过程变得更加简单明了。
引言
人工智能、深度学习方向的项目,和数据可视化是紧密相连的。
模型训练过程中梯度下降过程是什么样的?损失函数的走向如何?训练模型的准确度怎么变化的?
清楚这些数据,对我们模型的优化至关重要。
由于人工智能项目往往伴随着巨大数据量,用肉眼去逐个数据查看、分析是不现实的。这时候就需要用到数据可视化和日志分析报告。
TensorFlow自带的Tensorboard在模型和训练过程可视化方面做得越来越好。但是,也越来越臃肿,对于初入人工智能的同学来说有一定的门槛。
人工智能方面的项目变得越来越规范化,以模型训练、数据集准备为例,目前很多大公司已经发布了各自的自动机器学习平台,让工程师把更多精力放在优化策略上,而不是在准备数据、数据可视化方面。
SwanLab
训练过程社区:https://swanlab.cn/benchmarks
swanlab,来自中国团队情感机器,这款工具能够帮助跟踪你的机器学习项目。它能够自动记录模型训练过程中的超参数和输出指标,然后可视化和比较结果,并快速与同事共享结果。
通过swanlab,能够给你的机器学习项目带来强大的交互式可视化调试体验,能够自动化记录Python脚本中的图标,并且实时在网页仪表盘展示它的结果,例如,损失函数、准确率、召回率,它能够让你在最短的时间内完成机器学习项目可视化图片的制作。
总而言之,swanlab有这些特点:
- 整体UI交互体验拉满,很好看
- 有云端版,所以手机上也能看实验
- 适配框架很多,有接近30个,基本上主流的不主流的都覆盖了;因为是中国团队的关系,也适配了很多国产框架(LLaMA Factory、XTuner、ModelScope Swift等等)
- 实验全流程记录,超参数记录,日志记录,硬件环境记录,GPU实时监控,Python库记录,一体化表格对比
- 支持华为昇腾显卡,应该是这类工具里唯一一款能记录昇腾NPU的显存变化的
- 支持多人团队使用
也就是说,swanlab并不单纯的是一款数据可视化工具,还有非常丰富的团队协作功能。
swanlab另外一大亮点的就是强大的兼容性,它能够和Jupyter、Pytorch、TensorFlow、Keras、HuggingFace Transformer、sfast.ai、LightGBM、XGBoost一起结合使用。
因此,它不仅可以给你带来时间和精力上的节省,还能够给你的结果带来质的改变。
训练过程例子
- MNIST:https://swanlab.cn/@ZeyiLin/MNIST-example/runs/4plp6w0qehoqpt0uq2tcy/chart
- YOLO:h==ttps://swanlab.cn/@ZeyiLin/ultratest/runs/==yux7vclmsmmsar9ear7u5/chart
- BERT:https://swanlab.cn/@ZeyiLin/BERT/runs/b1bf2m5ituh0nw2cijiia/chart
极简教程
1. 安装库
pip install swanlab
2. 创建账户
swanlab login
3. 初始化
import swanlab
swanlab.init(project="my-project")
4. 声明超参数
swanlab.config.dropout=0.2
swanlab.config.hidden_layer_size=128
5. 记录日志
for epoch in range(10):
loss = 1 - 1/epoch
swanlab.log({'epoch': epoch, 'loss': loss})
6.测试demo
import swanlab
import random
# 创建一个SwanLab项目
swanlab.init(
# 设置项目名
project="my-awesome-project",
# 设置超参数
config={
"learning_rate": 0.02,
"architecture": "CNN",
"dataset": "CIFAR-100",
"epochs": 10
}
)
# 模拟一次训练
epochs = 10
offset = random.random() / 5
for epoch in range(2, epochs):
acc = 1 - 2 ** -epoch - random.random() / epoch - offset
loss = 2 ** -epoch + random.random() / epoch + offset
# 记录训练指标
swanlab.log({"acc": acc, "loss": loss})
# [可选] 完成训练,这在notebook环境中是必要的
swanlab.finish()
使用swanlab以后,log、日志和环境信息将会同步到cloud。
PyTorch应用SwanLab
我们以一个MNIST手写体识别任务为例,展示swanlab的用法。
安装必要的库:
torch
torchvision
swanlab
快速安装命令:
pip install torch torchvision swanlab
完整代码:
import os
import torch
from torch import nn, optim, utils
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import swanlab
# CNN网络构建
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
# 1,28x28
self.conv1 = nn.Conv2d(1, 10, 5) # 10, 24x24
self.conv2 = nn.Conv2d(10, 20, 3) # 128, 10x10
self.fc1 = nn.Linear(20 * 10 * 10, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
in_size = x.size(0)
out = self.conv1(x) # 24
out = F.relu(out)
out = F.max_pool2d(out, 2, 2) # 12
out = self.conv2(out) # 10
out = F.relu(out)
out = out.view(in_size, -1)
out = self.fc1(out)
out = F.relu(out)
out = self.fc2(out)
out = F.log_softmax(out, dim=1)
return out
# 捕获并可视化前20张图像
def log_images(loader, num_images=16):
images_logged = 0
logged_images = []
for images, labels in loader:
# images: batch of images, labels: batch of labels
for i in range(images.shape[0]):
if images_logged < num_images:
# 使用swanlab.Image将图像转换为wandb可视化格式
logged_images.append(swanlab.Image(images[i], caption=f"Label: {labels[i]}"))
images_logged += 1
else:
break
if images_logged >= num_images:
break
swanlab.log({"MNIST-Preview": logged_images})
def train(model, device, train_dataloader, optimizer, criterion, epoch, num_epochs):
model.train()
# 1. 循环调用train_dataloader,每次取出1个batch_size的图像和标签
for iter, (inputs, labels) in enumerate(train_dataloader):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
# 2. 传入到resnet18模型中得到预测结果
outputs = model(inputs)
# 3. 将结果和标签传入损失函数中计算交叉熵损失
loss = criterion(outputs, labels)
# 4. 根据损失计算反向传播
loss.backward()
# 5. 优化器执行模型参数更新
optimizer.step()
print('Epoch [{}/{}], Iteration [{}/{}], Loss: {:.4f}'.format(epoch, num_epochs, iter + 1, len(train_dataloader),
loss.item()))
# 6. 每20次迭代,用SwanLab记录一下loss的变化
if iter % 20 == 0:
swanlab.log({"train/loss": loss.item()})
def test(model, device, val_dataloader, epoch):
model.eval()
correct = 0
total = 0
with torch.no_grad():
# 1. 循环调用val_dataloader,每次取出1个batch_size的图像和标签
for inputs, labels in val_dataloader:
inputs, labels = inputs.to(device), labels.to(device)
# 2. 传入到resnet18模型中得到预测结果
outputs = model(inputs)
# 3. 获得预测的数字
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
# 4. 计算与标签一致的预测结果的数量
correct += (predicted == labels).sum().item()
# 5. 得到最终的测试准确率
accuracy = correct / total
# 6. 用SwanLab记录一下准确率的变化
swanlab.log({"val/accuracy": accuracy}, step=epoch)
if __name__ == "__main__":
#检测是否支持mps
try:
use_mps = torch.backends.mps.is_available()
except AttributeError:
use_mps = False
#检测是否支持cuda
if torch.cuda.is_available():
device = "cuda"
elif use_mps:
device = "mps"
else:
device = "cpu"
# 初始化swanlab
run = swanlab.init(
project="MNIST-example",
experiment_name="PlainCNN",
config={
"model": "ResNet18",
"optim": "Adam",
"lr": 1e-4,
"batch_size": 256,
"num_epochs": 10,
"device": device,
},
)
# 设置MNIST训练集和验证集
dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])
train_dataloader = utils.data.DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)
val_dataloader = utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)
# (可选)看一下数据集的前16张图像
log_images(train_dataloader, 16)
# 初始化模型
model = ConvNet()
model.to(torch.device(device))
# 打印模型
print(model)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=run.config.lr)
# 开始训练和测试循环
for epoch in range(1, run.config.num_epochs+1):
swanlab.log({"train/epoch": epoch}, step=epoch)
train(model, device, train_dataloader, optimizer, criterion, epoch, run.config.num_epochs)
if epoch % 2 == 0:
test(model, device, val_dataloader, epoch)
# 保存模型
# 如果不存在checkpoint文件夹,则自动创建一个
if not os.path.exists("checkpoint"):
os.makedirs("checkpoint")
torch.save(model.state_dict(), 'checkpoint/latest_checkpoint.pth')