PyTorch - 32 - CNN训练循环重构 - 同时进行超参数测试
- Cleaning Up The Training Loop And Extracting Classes
- The Two Classes We Will Build
- Building The RunManger For Training Loop Runs
- What Are Code Smells?
- Refactoring By Extracting A Class
- Extracting Classes Creates Layers Of Abstraction
- Beginning A Training Loop Run
- Tracking Our Training Loop Performance
- What Does It Feel Like To Be Wrong?
Cleaning Up The Training Loop And Extracting Classes
当我们在训练循环中退出几集时,我们建立了很多功能,使我们可以尝试许多不同的参数和值,并且还使训练循环中的调用需求可以得到结果进入TensorBoard。
所有这些工作都有所帮助,但是我们的培训循环现在非常拥挤。在本集中,我们将使用上次构建的RunBuilder类并构建一个名为RunManager的新类,清理训练循环,并为进行更多的实验打下基础。
我们的目标是能够在顶部添加参数和值,并在多次训练中测试或尝试所有值。
例如,在这种情况下,我们要使用两个参数lr和batch_size,对于batch_size,我们要尝试两个不同的值。这总共给了我们两次训练。批次大小不同时,两次运行的学习率相同。
params = OrderedDict(
lr = [.01]
,batch_size = [1000, 2000]
)
对于结果,我们希望看到并能够比较两次运行。
run | epoch | loss | accuracy | epoch duration | run duration | lr | batch_size |
---|---|---|---|---|---|---|---|
1 | 1 | 0.983 | 0.618 | 48.697 | 50.563 | 0.01 | 1000 |
1 | 2 | 0.572 | 0.777 | 19.165 | 69.794 | 0.01 | 1000 |
1 | 3 | 0.468 | 0.827 | 19.366 | 89.252 | 0.01 | 1000 |
1 | 4 | 0.428 | 0.843 | 18.840 | 108.176 | 0.01 | 1000 |
1 | 5 | 0.389 | 0.857 | 19.082 | 127.320 | 0.01 | 1000 |
2 | 1 | 1.271 | 0.528 | 18.558 | 19.627 | 0.01 | 2000 |
2 | 2 | 0.623 | 0.757 | 19.822 | 39.520 | 0.01 | 2000 |
2 | 3 | 0.526 | 0.791 | 21.101 | 60.694 | 0.01 | 2000 |
2 | 4 | 0.478 | 0.814 | 20.332 | 81.110 | 0.01 | 2000 |
2 | 5 | 0.440 | 0.835 | 20.413 | 101.600 | 0.01 | 2000 |
The Two Classes We Will Build
为此,我们需要建立两个新类。 在上一集中,我们构建了名为RunBuilder的第一堂课。 它被称为顶部。
for run in RunBuilder.get_runs(params):
现在,我们需要构建此RunManager类,该类将使我们能够管理运行循环中的每个运行。 RunManager实例将使我们能够抽出许多繁琐的TensorBoard调用,并允许我们添加其他功能。
我们将看到,随着参数和运行次数的增加,TensorBoard将开始崩溃,作为一种可行的解决方案,可用于检查结果。
在我们每次运行的不同阶段,将调用RunManager。 在运行阶段和纪元阶段的开始和结束时,我们都会有呼叫。 我们还将致电跟踪每个时期内的损失和正确预测的数量。 最后,最后,我们将运行结果保存到磁盘。
让我们看看如何构建此RunManager类。
Building The RunManger For Training Loop Runs
让我们从import开始吧:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from IPython.display import display, clear_output
import pandas as pd
import time
import json
from itertools import product
from collections import namedtuple
from collections import OrderedDict
首先,我们使用class关键字声明该类。
class RunManager():
接下来,我们将定义类构造函数。
def __init__(self):
self.epoch_count = 0
self