使用Ray Tune自动调参

本文介绍了如何利用RayTune库在PyTorch中进行自动超参数调优。首先,解释了RayTune的功能,包括集成的搜索算法和优化工具。接着,详细展示了使用步骤,包括安装、导入库、数据加载、模型构建、训练和测试、构建Trainable以及定义超参搜索空间。最后,通过示例代码演示了如何执行超参搜索,并解析最佳试验结果。整个过程帮助开发者高效地优化模型性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


前言

本文以PyTorch框架构建的卷积网络模型做分类任务为例介绍如何使用Ray Tune进行自动调参,相关代码引自官网文档


一、Ray Tune是什么?

Ray Tune是一个用来实验执行和超参数调优的Python包,其中集成了网格搜索、随机搜索、贝叶斯优化搜索(BayesOptSearch)等搜索算法以及Optuna, Hyperopt等优化工具。Ray Tune调参的模型可以是基于PyTorch, XGBoost, TensorFlow或Keras等框架构建的模型。

二、使用步骤

1.安装包

可以只安装 ray 下的 tune 包:

$ pip install -U "ray[tune]"

或安装整个 ray 包:

$ pip install ray

2.引入库

代码如下(示例):

import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from filelock import FileLock
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler

3.读入数据(与Ray Tune无关)

示例代码如下:

def load_data(data_dir="./data"):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    with FileLock(os.path.expanduser("~/.data.lock")):
        trainset = torchvision.datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform)

        testset = torchvision.datasets.CIFAR10(
            root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset

4.构建神经网络模型(与Ray Tune无关)

示例代码如下:

class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

5.模型的训练和测试(与Ray Tune无关)

示例代码如下:

EPOCH_SIZE = 512
TEST_SIZE = 256

def train(model, optimizer, train_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # We set this just for the example to run quickly.
        if batch_idx * len(data) > EPOCH_SIZE:
            return
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()


def test(model, data_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            # We set this just for the example to run quickly.
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total

6.构建“Trainable”

“Trainable”是一个需要在Tune.run()函数(见下文)运行时输入的参数,表示每一次的训练、调参及模型保存过程,可以用一个函数构建或用一个类(必须继承自tune.Trainable类)来构建。Trainable接受传入的参数config表示超参搜索空间,每次迭代tune.report()函数返回当前训练的结果,其余部分与正常的模型训练相同。本文以函数方式构建Trainable。

代码如下(示例):

def train_mnist(config):
    # 加载数据
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    train_loader = DataLoader(
        datasets.MNIST("~/data", train=True, download=True, transform=mnist_transforms),
        batch_size=64,
        shuffle=True)
    test_loader = DataLoader(
        datasets.MNIST("~/data", train=False, transform=mnist_transforms),
        batch_size=64,
        shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	
	# 定义模型
    model = ConvNet()
    model.to(device)

	# 优化器
    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"])
    
    # 训练过程
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)

        # 返回当前step验证结果到Ray Tune, 以决定是否停止等后续操作.
        tune.report(mean_accuracy=acc)

        if i % 5 == 0:
            # 保存当前模型文件.
            torch.save(model.state_dict(), "./model.pth")

在上面的代码中,每训练一个epoch后,tune.report()函数会返回当前的实验结果,以确定调参方向或者是否提前终止实验。

7.超参搜索

首先用户自定义超参搜索空间,示例代码如下:

# 定义超参搜索空间
search_space = {
    "lr": tune.sample_from(lambda spec: 10 ** (-10 * np.random.rand())),
    "momentum": tune.uniform(0.1, 0.9),
}

开始执行搜索过程,首先初始化ray,然后执行tune.run()函数即可,示例代码如下:

# 初始化,设置分配给ray的资源数目,默认会使用当前设备的所有资源
ray.init(num_cpus=4, num_gpus=4)

#开始执行搜索
analysis = tune.run(
    train_mnist,
    num_samples=20,
    scheduler=ASHAScheduler(metric="mean_accuracy", mode="max"),
    config=search_space,
)

由于在上文中的train_mnist方法中,除config之外没有别的参数传入,因此在run()函数中直接传入train_mnist即可,若在train_mnist()函数中有用户自定义需要传入的其他参数,则使用tune.with_parameters()函数传入参数,示例代码如下:

analysis = tune.run(
    tune.with_parameters(
            train_mnist,
            parameter1,
            parameter2,
            ...
        ),
    num_samples=20,  # 不同的超参实验次数
    scheduler=ASHAScheduler(metric="mean_accuracy", mode="max"),
    config=search_space,  # 超参搜索空间
)

实验结果以及各类配置参数等都可以通过analysis获取,示例如下:

best_trial = analysis.best_trial  # Get best trial
best_config = analysis.best_config  # Get best trial's hyperparameters
best_logdir = analysis.best_logdir  # Get best trial's logdir
best_checkpoint = analysis.best_checkpoint  # Get best trial's best checkpoint
best_result = analysis.best_result  # Get best trial's last results

# 实验结果输出
print("Best trial is:", best_trial)
print("Best config is:", best_config)
print("Best logdir is:", best_logdir)
print("Best checkpoint is:", best_checkpoint)
print("Best result is:", best_result)

总结

本文借助官方文档的例子,结合自己的使用,简单介绍了Ray Tune的用法,相关更详细的介绍和用法(比如更多搜索算法的选择等)见官网文档

最近发现PyTorch官网也提供了相关的教程,感兴趣的朋友移步到这里

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值