How to create checkpoints for a model trained with tf.keras

Training a deep learning model can take hours, days, or even weeks.

If the run is interrupted unexpectedly, all of that training is lost.

This article shows how to checkpoint a deep learning model during training so that its weights can be saved and restored later.

1. Configure the checkpoint to save the network weights whenever classification accuracy on the validation set improves (monitor='val_accuracy' and mode='max'; in tf.keras the validation-accuracy metric is logged as val_accuracy rather than the older val_acc). The weights are stored in a file whose name embeds the metric value (weights-improvement-{epoch:02d}-{val_accuracy:.2f}.hdf5).

	# Imports needed to make this snippet self-contained (tf.keras API)
	import numpy
	from tensorflow.keras.models import Sequential
	from tensorflow.keras.layers import Dense
	from tensorflow.keras.callbacks import ModelCheckpoint
	# load pima indians dataset and split into input (X) and output (Y) variables
	dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
	X = dataset[:,0:8]
	Y = dataset[:,8]
	# create model
	model = Sequential()
	model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
	model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
	model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
	# Compile model
	model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
	# checkpoint: save the weights whenever validation accuracy improves
	filepath = "weights-improvement-{epoch:02d}-{val_accuracy:.2f}.hdf5"
	checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
	callbacks_list = [checkpoint]
	# Fit the model
	model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)
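
With this strategy a long run can leave many checkpoint files behind, one per improvement. As a small illustration (a hypothetical helper, not part of the original article), the epoch and accuracy embedded in each filename can be parsed afterwards to locate the best file:

	# Hypothetical helper: scan the checkpoint files written above and return the
	# one whose filename records the highest validation accuracy.
	import glob
	import re
	def best_checkpoint(pattern="weights-improvement-*.hdf5"):
	    best_path, best_acc = None, -1.0
	    for path in glob.glob(pattern):
	        # filenames look like weights-improvement-{epoch}-{val_accuracy}.hdf5
	        match = re.search(r"-(\d+)-([0-9.]+)\.hdf5$", path)
	        if match and float(match.group(2)) > best_acc:
	            best_acc, best_path = float(match.group(2)), path
	    return best_path
	print(best_checkpoint())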

2. Save only the best neural network model to a single file

	# Imports and data as above (tf.keras API, Pima Indians diabetes dataset)
	import numpy
	from tensorflow.keras.models import Sequential
	from tensorflow.keras.layers import Dense
	from tensorflow.keras.callbacks import ModelCheckpoint
	dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
	X = dataset[:,0:8]
	Y = dataset[:,8]
	# create model
	model = Sequential()
	model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
	model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
	model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
	# Compile model
	model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
	# checkpoint: one file that always holds the best weights seen so far
	filepath = "weights.best.hdf5"
	checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
	callbacks_list = [checkpoint]
	# Fit the model
	model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)
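
The single-file variant above keeps overwriting weights.best.hdf5, so only the best weights ever seen are retained. As a hedged alternative sketch (an assumption about usage, not from the original article), ModelCheckpoint can also store just the weights in TensorFlow's native checkpoint format by passing save_weights_only=True and a filepath without the .hdf5 extension:

	# Alternative sketch (assumption: TensorFlow 2.x): keep only the weights, in
	# TensorFlow's native checkpoint format rather than a single HDF5 file.
	from tensorflow.keras.callbacks import ModelCheckpoint
	checkpoint = ModelCheckpoint("best_weights.ckpt",  # hypothetical path
	                             monitor='val_accuracy', verbose=1,
	                             save_best_only=True, save_weights_only=True,
	                             mode='max')
	# Pass [checkpoint] to model.fit() exactly as above; restore later with
	# model.load_weights("best_weights.ckpt").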

3. Load the checkpointed neural network model

	# Imports needed to make this snippet self-contained (tf.keras API)
	import numpy
	from tensorflow.keras.models import Sequential
	from tensorflow.keras.layers import Dense
	# re-create the same architecture that produced the checkpoint
	model = Sequential()
	model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
	model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
	model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
	# load weights
	model.load_weights("weights.best.hdf5")
	# Compile model (required to make predictions)
	model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
	print("Created model and loaded weights from file")
	# load pima indians dataset
	dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
	# split into input (X) and output (Y) variables
	X = dataset[:,0:8]
	Y = dataset[:,8]
	# estimate accuracy on whole dataset using loaded weights
	scores = model.evaluate(X, Y, verbose=0)
	print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))