前言
Argparse是机器学习训练中非常常用的第三方库,它的主要作用是用来解析命令行参数。
通过定义所需的参数,可以从程序外部向程序内部传递各种参数。这样,用户在启动训练模型时,可以灵活地调整模型的参数,如学习率、训练轮数、数据集路径等,从而使得模型训练更加灵活和方便。
SwanLab 是一款面向研究人员的开源机器学习训练管理工具,类似Tensorboard或Wandb。
那么,如何将Argparse传入SwanLab的配置信息当中呢?
SwanLab:https://github.com/SwanHubX/SwanLab
安装方式pip install -U swanlab
实战
首先我们写一段Argparse代码,包含了几个超参数:
import argparse
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--epochs', type=int, default=20,
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0,
help='learning rate (default: 1.0)')
args = parser.parse_args()
我们将args包一层vars(),传入swanlab.init的config参数:
import swanlab
swanlab.init(
experiment_name="mnist_example",
description="A plain neural network basic on MNist",
config=vars(args)
)
这样当我们运行训练脚本时,Argparse的内容将会显示在实验看板的配置信息中:
完整代码:
import argparse
import swanlab
import random
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--epochs', type=int, default=20,
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0,
help='learning rate (default: 1.0)')
args = parser.parse_args()
swanlab.init(
experiment_name="mnist_example",
description="A plain neural network basic on MNist",
config=vars(args)
)
# 下面是模拟训练代码
offset = random.random() / 5
for epoch in range(1, args.epochs+1):
acc = 1 - 2**-epoch - random.random() / epoch - offset
loss = 2**-epoch + random.random() / epoch + offset
print(f"epoch={epoch}, accuracy={acc}, loss={loss}")
# Tracking index: 'loss' and 'accuracy'
swanlab.log({"loss": loss, "accuracy": acc})