tensowflow 训练 远程提交_Tensorflow 保存和载入训练过程

本节涉及点:

保存训练过程

载入保存的训练过程并继续训练

通过命令行参数控制是否强制重新开始训练

训练过程中的手动保存

保存训练过程前,程序征得同意

一、保存训练过程

以下方代码为例:

importtensorflow as tfimportrandom

random.seed()

x=tf.placeholder(tf.float32)

yTrain=tf.placeholder(tf.float32)

w1= tf.Variable(tf.random_normal([4, 8], mean=0.5, stddev=0.1), dtype=tf.float32)

b1= tf.Variable(0, dtype=tf.float32)

xr= tf.reshape(x, [1, 4])

n1= tf.nn.tanh(tf.matmul(xr, w1) +b1)

w2= tf.Variable(tf.random_normal([8, 2], mean=0.5, stddev=0.1), dtype=tf.float32)

b2= tf.Variable(0, dtype=tf.float32)

n2= tf.matmul(n1, w2) +b2

y= tf.nn.softmax(tf.reshape(n2, [2]))

loss= tf.reduce_mean(tf.square(y -yTrain))

optimizer= tf.train.RMSPropOptimizer(0.01)

train=optimizer.minimize(loss)

sess=tf.Session()

sess.run(tf.global_variables_initializer())

lossSum= 0.0

for i in range(5):

xDataRandom= [int(random.random() * 10), int(random.random() * 10), int(random.random() * 10), int(random.random() * 10)]if xDataRandom[2] % 2 ==0:

yTrainDataRandom= [0, 1]else:

yTrainDataRandom= [1, 0]

result= sess.run([train, x, yTrain, y, loss], feed_dict={x: xDataRandom, yTrain: yTrainDataRandom})

lossSum= lossSum + float(result[len(result) - 1])print("i: %d, loss: %10.10f, avgLoss: %10.10f" % (i, float(result[len(result) - 1]), lossSum / (i + 1)))

trainResultPath= "./save/idcard2"

print("saving...")

tf.train.Saver().save(sess, save_path=trainResultPath)

i: 0, loss: 0.2790884972, avgLoss: 0.2790884972

i: 1, loss: 0.2675500214, avgLoss: 0.2733192593

i: 2, loss: 0.2441657931, avgLoss: 0.2636014372

i: 3, loss: 0.2675784826, avgLoss: 0.2645956986

i: 4, loss: 0.2452606559, avgLoss: 0.2607286900

saving...

'./save/idcard2'

解析:

首先用一个变量 trainResultPath 来指定保存训练过程数据的目录

这是一个字符串类型的变量,其中的小数点 “ . ” 表示 Python 程序执行的当前目录, “ / ” 用于分隔目录和子目录(windows 中一般用反斜杠 " \ " 来分隔 ),一般采用Linux 目录中的写法,兼容性更好

./save/idcard2 表示 保存的位置是 执行程序 idcard2.py 的目录的 save 的子目录下以 idcard2 为基本名称的一系列文件

下方图片中,以 idcard2 开头的文件分别保存了 模型和可变参数的信息,checkpoint 文件保存了一些基础信息

“.meta”文件:包含图形结构。

“.data”文件:包含变量的值。

“.index”文件:标识检查点。

“checkpoint”文件:具有最近检查点列表的协议缓冲区。

tf.train.Saver().save(sess, save_path=trainResultPath)

调用 tensorflow 下的train 包中的 saver 对象的 save 成员函数进行保存,第一个参数 纯如当前的会话对象(本程序中 是 sess),第二个参数 save_path 传入保存位置

二、载入保存的训练过程并继续训练

如果已经保存了训练数据,就可以用下面的代码 载入训练数据并继续训练

注意:如果使用的是 jupyter ,请再运行完毕 上方的代码 并保存结果到 ./save/idcard2 之后

重启服务再运行

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值