缩放器
一个用于管理深度学习模型、超参数和数据集的小型库,旨在使培训深度学习模型变得容易和可复制。
开始
zookeeper允许您使用click和TensorFlow Datasets构建命令行界面,用很少的锅炉板来训练深度学习模型。它可以帮助您以一种框架不可知和有效的方式构建您的机器学习项目。
ZooKER深受Tensor2Tensor和Fairseq启发,但被设计成一个库,使其轻量级和非常灵活。
安装pip install zookeeperpip install colorama # optional for colored console output
注册表
ZooKeeper跟踪数据预处理、模型和超参数,以便您可以从命令行按名称引用它们。
数据集和预处理
tensorflow数据集提供可以自动下载的many popular datasets。
在下面我们将使用MNIST,并为图像定义一个default预处理,该预处理将图像缩放到[0, 1],并对类标签使用一个热编码:importtensorflowastffromzookeeperimportcli,build_train,HParams,registry,PreprocessingclassImageClassification(Preprocessing):@propertydefkwargs(self):return{"input_shape":self.features["image"].shape,"num_classes":self.features["label"].num_classes,}definputs(self,data):returntf.cast(data["image"],tf.float32)defoutputs(self,data):returntf.one_hot(data["label"],self.features["label"].num_classes)@registry.register_preprocess("mnist")classdefault(ImageClassification):definputs(self,data):returnsuper().inputs(data)/255
型号
接下来我们将注册一个名为cnn的模型。我们将使用Keras API进行此操作:@registry.register_modeldefcnn(hp,input_shape,num_classes):returntf.keras.models.Sequential([tf.keras.layers.Conv2D(hp.filters[0],(3,3),activation=hp.activation,input_shape=input_shape),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Conv2D(hp.filters[1],(3,3),activation=hp.activation),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Conv2D(hp.filters[2],(3,3),activation=hp.activation),tf.keras.layers.Flatten(),tf.keras.layers.Dense(hp.filters[3],activation=hp.activation),tf.keras.layers.Dense(num_classes,activation="softmax"),])
超参数
对于每个模型,我们可以注册一个或多个超参数集,这些超参数集将在调用时传递给模型函数:@registry.register_hparams(cnn)classbasic(HParams):activation="relu"batch_size=32filters=[64,64,64,64]learning_rate=1e-3@propertydefoptimizer(self):returntf.keras.optimizers.Adam(self.learning_rate)
训练循环
为了训练上面注册的模型,我们需要编写一个自定义的训练循环。动物园管理员将把所有的东西都绑在一起:@cli.command()@build_train()deftrain(build_model,dataset,hparams,output_dir):"""Start model training."""model=build_model(hparams,**dataset.preprocessing.kwargs)model.compile(optimizer=hparams.optimizer,loss="categorical_crossentropy",metrics=["categorical_accuracy","top_k_categorical_accuracy"],)model.fit(dataset.train_data(hparams.batch_size),steps_per_epoch=dataset.train_examples//hparams.batch_size,validation_data=dataset.validation_data(hparams.batch_size),validation_steps=dataset.validation_examples//hparams.batch_size,)
这将注册名为train的click命令,该命令可以从命令行执行。
命令行界面
要使刚创建的文件可执行,我们将在底部添加以下行:if__name__=="__main__":cli()
如果要在单独的文件中注册模型,请确保在调用cli之前导入它们,以允许zookeeper正确注册它们。要将cli作为可执行命令安装,请签出click的^{} integration。
用法
ZooKeeper已经提供了prepare、plot和tensorboard命令,但现在还包括我们在上面创建的train命令:python examples/train.py --helpUsage: train.py [OPTIONS] COMMAND [ARGS]...Options:--help Show this message and exit.Commands:install-completion Install shell completion.plot Plot data examples.prepare Downloads and prepares datasets for reading.tensorboard Start TensorBoard to monitor model training.train Start model training.
为了训练我们刚刚注册的跑步模式:python examples/train.py train cnn --dataset mnist --hparams-set basic --hparams batch_size=64
多个参数用逗号分隔,字符串应不带引号:python examples/train.py train cnn --dataset mnist --hparams-set basic --hparams batch_size=32,actvation=relu
欢迎加入QQ群-->: 979659372
推荐PyPI第三方库