当我开始训练我的估算器对象,
Tensorflow在打印时自动创建一个CheckpointSaverHook
INFO:tensorflow:Create CheckpointSaverHook.
这个自动创建的SaverHook将在培训开始和结束时保存我的模型。在
我想要的是每n个训练步骤创建一个检查点。为此,我创建了自己的储蓄挂钩,并在培训时将其传递给我的估计员。在saver_hook = tf.train.CheckpointSaverHook(
checkpoint_dir = model_dir,
save_steps = 100
)
model.train(input_fn,steps=1500,hooks=[saver_hook])
这在理论上是可行的,但是我自己的CheckpointSaverHook只保存*.meta文件,而自动创建的CheckpointSaverHook则保存*.meta、*.index和{}文件。在
我如何配置自己的SaverHook也能做到这一点呢?在
编辑:
添加了我的整个网络定义
网络.py在
^{pr2}$
在培训.py在from network import TFDotNet
from time import time
# settings
training_steps = 10000
mini_batch_size = 128
model_dir = 'neuralnet_data/02_networks/network01'
dataset_path = 'neuralnet_data/01_datasets/dataset.data'
# init dotnet
dotnet = TFDotNet(model_dir=model_dir)
# load dataset
print('loading dataset ...')
dataset = dotnet.load_dataset(dataset_path)
# split dataset
x_train, y_train, x_test, y_test = dotnet.split_dataset(dataset,0.1)
# train network
print('starting training ...')
t0 = time()
dotnet.train(x_train,y_train,steps=training_steps,batch_size=mini_batch_size)
print('Training took {}s'.format(time()-t0))