本节涉及点:
保存训练过程
载入保存的训练过程并继续训练
通过命令行参数控制是否强制重新开始训练
训练过程中的手动保存
保存训练过程前,程序征得同意
一、保存训练过程
以下方代码为例:
importtensorflow as tfimportrandom
random.seed()
x=tf.placeholder(tf.float32)
yTrain=tf.placeholder(tf.float32)
w1= tf.Variable(tf.random_normal([4, 8], mean=0.5, stddev=0.1), dtype=tf.float32)
b1= tf.Variable(0, dtype=tf.float32)
xr= tf.reshape(x, [1, 4])
n1= tf.nn.tanh(tf.matmul(xr, w1) +b1)
w2= tf.Variable(tf.random_normal([8, 2], mean=0.5, stddev=0.1), dtype=tf.float32)
b2= tf.Variable(0, dtype=tf.float32)
n2= tf.matmul(n1, w2) +b2
y= tf.nn.softmax(tf.reshape(n2, [2]))
loss= tf.reduce_mean(tf.square(y -yTrain))
optimizer= tf.train.RMSPropOptimizer(0.01)
train=optimizer.minimize(loss)
sess=tf.Session()
sess.run(tf.global_variables_initializer())
lossSum= 0.0
for i in range(5):
xDataRandom= [int(random.random() * 10), int(random.random() * 10), int(random.random() * 10), int(random.random() * 10)]if xDataRandom[2] % 2 ==0:
yTrainDataRandom= [0, 1]else:
yTrainDataRandom= [1, 0]
result= sess.run([train, x, yTrain, y, loss], feed_dict={x: xDataRandom, yTrain: yTrainDataRandom})
lossSum= lossSum + float(result[len(result) - 1])print("i: %d, loss: %10.10f, avgLoss: %10.10f" % (i, float(result[len(result) - 1]), lossSum / (i + 1)))
trainResultPath= "./save/idcard2"
print("saving...")
tf.train.Saver().save(sess, save_path=trainResultPath)
i: 0, loss: 0.2790884972, avgLoss: 0.2790884972
i: 1, loss: 0.2675500214, avgLoss: 0.2733192593
i: 2, loss: 0.2441657931, avgLoss: 0.2636014372
i: 3, loss: 0.2675784826, avgLoss: 0.2645956986
i: 4, loss: 0.2452606559, avgLoss: 0.2607286900
saving...
'./save/idcard2'
解析:
首先用一个变量 trainResultPath 来指定保存训练过程数据的目录
这是一个字符串类型的变量,其中的小数点 “ . ” 表示 Python 程序执行的当前目录, “ / ” 用于分隔目录和子目录(windows 中一般用反斜杠 " \ " 来分隔 ),一般采用Linux 目录中的写法,兼容性更好
./save/idcard2 表示 保存的位置是 执行程序 idcard2.py 的目录的 save 的子目录下以 idcard2 为基本名称的一系列文件
下方图片中,以 idcard2 开头的文件分别保存了 模型和可变参数的信息,checkpoint 文件保存了一些基础信息
“.meta”文件:包含图形结构。
“.data”文件:包含变量的值。
“.index”文件:标识检查点。
“checkpoint”文件:具有最近检查点列表的协议缓冲区。
tf.train.Saver().save(sess, save_path=trainResultPath)
调用 tensorflow 下的train 包中的 saver 对象的 save 成员函数进行保存,第一个参数 纯如当前的会话对象(本程序中 是 sess),第二个参数 save_path 传入保存位置
二、载入保存的训练过程并继续训练
如果已经保存了训练数据,就可以用下面的代码 载入训练数据并继续训练
注意:如果使用的是 jupyter ,请再运行完毕 上方的代码 并保存结果到 ./save/idcard2 之后
重启服务再运行