Ray Tune 模块
Tune
Tune是一个超参数整定模块,他以’trials’来构建起每一次尝试。为’trials’利用Scheduler作为调度器。可以使用包括PBT,AsyncHyperBand在内的多种超参数整定方法。
如何使用?
根据上述所述,分为以下几步:
- 根据自己的需求构建一个trials,可理解为一个训练epoch,该trials需继承Tune.Trainable类
- 选择合适的Schedulers
- 调用ray.tune.run(),其中trials作为run的run_or_experiment 传入
例子:mnist-pytorch模型训练超参数整定
借助mnist-pytorch官方例子进行解释
1.构建一个trial
在本例中,一个trial具有以下几步
- 训练一轮
- 评估此轮模型效果
- 返回评估指标
class TrainMNIST(tune.Trainable):
def _setup(self, config):
# 类似于__init__函数,用于初始化相关配置
# 1.读数据:self.data_loader = ...
# 2.构建模型 : self.model = ...
# 3.优化器: self.optimizer = ...
#... 具体源码见官方教程
def _train(self):
#训练模型
train(
self.model,