Tensorflow学习笔记:模型训练数据的保存和恢复的简单实例

#! /usr/bin/env python2
# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import argparse

'''
保存模型训练后参数的简单实例
'''
print('保存和恢复模型训练后参数的简单实例:')

#创建一个图
my_graph = tf.Graph()

with my_graph.as_default():
	var = tf.Variable(0, name='counter') #一个变量,初始值设置为0,但是要会话执行run才会被赋值
	#创建一个op, 实现var + 2
	step = tf.constant(2)
	newVar = tf.add(var, step)
	update = tf.assign(var, newVar)
	# 启动图后, 变量必须先经过`初始化` (init) op 初始化,
	# 首先必须增加一个`初始化` op 到图中.
	init_op = tf.initialize_all_variables()

	#创建saver来保存模型数据
	saver = tf.train.Saver()

#在会话中运行或测试图
def train_or_test(is_test):
	#创建会话,启动图
	with tf.Session(graph = my_graph) as sess:
		if is_test == False: #如果是训练模型
			print('Train begin...')
			sess.run(init_op) #先运行初始化操作
			print('var = %d' % sess.run(var)) #打印初始值

			#更新var,并打印
			for i in range(5):
				sess.run(update)
				print('[%d] var = %d' % (i, sess.run(var)))
				#保存每次迭代的结果,保存的文件从val_iter-1开始而不是0,这个有点搞不明白,有知道原因的麻烦给个留言 哈哈
				saver.save(sess, './model_data1/val_iter', global_step = i)

			saver.save(sess, './model_data1/val_final')

		else: #如果是测试模型
			print('Test begin...')
			for i in range(5)[1:]:
				iter_data_file = './model_data1/val_iter-' + str(i)
				#恢复每次迭代的结果
				saver.restore(sess, iter_data_file)
				print('[%d] var = %d' % (i, sess.run(var)))

			model_data = tf.train.latest_checkpoint('./model_data1/')
			print(model_data) # ./model_data1/val_final
			saver.restore(sess, model_data)
			print('read final var = %d' % sess.run(var))

#必须定义这个main入口
def main(_):
	train_or_test(ARGS.test)

if __name__ == '__main__':
	parser = argparse.ArgumentParser()
	parser.add_argument(
		'-t',
		'--test',
		#type = int,
		default = False,
		action = 'store_true', # 运行 ./model_train1.py -t或--test 则ARGS.test被置为True
		help = 'train: True, test: False.'
		)
  
	#ARGS, unparsed = parser.parse_known_args()
	#tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
	ARGS = parser.parse_args()
	print(ARGS)
	tf.app.run()

'''
命令:(1) ./model_train1.py  训练模型
	  (2) ./model_train1.py -t[--test] 测试模型
'''

'''
保存完 model_data1目录下出现:
	checkpoint      (具有最近检查点列表的协议缓冲区)
	val_final       (包含变量的值)
	val_final.meta  (包含图形结构)
	val_iter-1
	val_iter-1.meta
	...
	val_iter-4
	val_iter-4.meta
'''

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值