PyTorch - 31 - Training Loop Run Builder - 神经网络实验
Using The RunBuilder
Class
本集以及本系列最后几集中的目的是使自己处于能够有效地尝试我们构建的训练过程的位置。因此,我们将扩展在超参数实验中该情节涉及的内容。我们将使那里看到的更加干净。
我们将构建一个名为RunBuilder的类。但是,在我们看如何构建类之前。让我们看看它将允许我们做什么。我们将从进口开始。
from collections import OrderedDict
from collections import namedtuple
from itertools import product
我们正在从集合中导入OrderedDict和namedtuple,并从itertools中导入了一个名为product的函数。这个product()函数是我们上次看到的函数,它在给定多个列表输入的情况下计算笛卡尔乘积。
好的。这是RunBuilder类,它将构建用于定义运行的参数集。看到如何使用后,我们将看到它的工作原理。
class RunBuilder():
@staticmethod
def get_runs(params):
Run = namedtuple('Run', params.keys())
runs = []
for v in product(*params.values()):
runs.append(Run(*v))
return runs
关于使用此类的主要注意事项是它具有一个称为get_runs()的静态方法。该方法将为我们提供基于传入参数构建的运行结果。
现在定义一些参数。
params = OrderedDict(
lr = [.01, .001]
,batch_size = [1000, 10000]
)
在这里,我们在字典中定义了一组参数和值。我们有一组学习率和一组批次大小,我们想尝试一下。当我们说“尝试”时,是指我们要针对字典中的每个学习率和每个批次大小进行一次训练。
要获得这些运行,我们只需调用RunBuilder类的get_runs()函数,并传入我们要使用的参数即可。
> runs = RunBuilder.get_runs(params)
> runs
[
Run(lr=0.01, batch_size=1000),
Run(lr=0.01, batch_size=10000),
Run(lr=0.001, batch_size=1000),
Run(lr=0.001, batch_size=10000)
]
很好,我们可以看到RunBuilder类已经构建并返回了四个运行的列表。这些运行中的每一个都有学习率和定义运行的批处理大小。
我们可以通过索引到列表来访问单个运行,如下所示:
> run = runs[0]
> run
Run(lr=<