前言
大多数神经网络训练时间比较长,为了避免意外导致训练结果丢失,我们要经常保存和载入训练过程的数据。
保存和载入训练过程
我们将从保存训练过程、载入保存的训练过程并继续训练、强制重新训练等方面来介绍
## 保存训练过程
我们以上一篇的身份证男女识别模型为基础,进行演示,具体代码如下:
# Author:北京
# QQ:838262020
# time:2021/4/21
import tensorflow as tf
import random
# Train the ID-card gender classifier from scratch, then save the session's
# variables as a checkpoint under ./save/idcard.
random.seed()

# Input: four digits (the last four digits of an ID number), fed as a flat list.
x = tf.placeholder(tf.float32)
# Target: a two-element one-hot label.
yTrain = tf.placeholder(tf.float32)

# Layer 1: 4 inputs -> 32 units, tanh activation.
w1 = tf.Variable(tf.random_normal([4, 32], mean=0.5, stddev=0.1), dtype=tf.float32)
b1 = tf.Variable(0, dtype=tf.float32)
xr = tf.reshape(x, [1, 4])  # matmul needs a 2-D [1, 4] row vector
n1 = tf.nn.tanh(tf.matmul(xr, w1) + b1)

# Layer 2: 32 -> 16 units, no activation.
w2 = tf.Variable(tf.random_normal([32, 16], mean=0.5, stddev=0.1), dtype=tf.float32)
b2 = tf.Variable(0, dtype=tf.float32)
n2 = tf.matmul(n1, w2) + b2

# Layer 3: 16 -> 2 units, no activation.
w3 = tf.Variable(tf.random_normal([16, 2], mean=0.5, stddev=0.1), dtype=tf.float32)
b3 = tf.Variable(0, dtype=tf.float32)
n3 = tf.matmul(n2, w3) + b3

# Flatten to two values and turn them into probabilities.
y = tf.nn.softmax(tf.reshape(n3, [2]))

# Mean squared error between the prediction and the one-hot target.
loss = tf.reduce_mean(tf.square(y - yTrain))
optimizer = tf.train.RMSPropOptimizer(0.01)
train = optimizer.minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

lossSum = 0.0
for i in range(5000):
    # Draw four random digits in [0, 9] to simulate the last four ID digits.
    xDataRandom = [int(random.random() * 10) for _ in range(4)]
    # The label depends on the parity of the second-to-last digit (index 2):
    # even -> [0, 1], odd -> [1, 0].
    if xDataRandom[2] % 2 == 0:
        yTrainDataRandom = [0, 1]
    else:
        yTrainDataRandom = [1, 0]
    result = sess.run([train, x, yTrain, y, loss], feed_dict={x: xDataRandom, yTrain: yTrainDataRandom})
    lossValue = float(result[-1])  # loss is the last fetch in the list above
    lossSum = lossSum + lossValue
    print("i:%d,loss:%10.10f,avgLoss:%10.10f" % (i, lossValue, lossSum / (i + 1)))

# Save the trained variables; "idcard" is the checkpoint name inside ./save.
trainResultPath = "./save/idcard"
print("保存中.....")
tf.train.Saver().save(sess, save_path=trainResultPath)
print("保存完成")

# Release the resources held by the session now that training is finished.
sess.close()
其中保存训练过程代码如下:
# Where to save: the "save" folder under the current project, file name "idcard".
trainResultPath="./save/idcard"
# Create a Saver object and call its save method on the session `sess`.
tf.train.Saver().save(sess,save_path=trainResultPath)
第一条语句指明保存的路径:当前项目下的save文件夹下,文件名为idcard。
第二条语句先调用TensorFlow中train包下的Saver()创建一个Saver对象,再调用该对象的save成员函数,把会话sess中的训练结果保存到上述路径。
载入保存的训练过程并继续训练
对上面已经保存的训练过程进行载入和继续训练。具体代码如下:
# Author:北京
# QQ:838262020
# time:2021/4/21
import tensorflow as tf
import random
import os
# Continue training the ID-card gender classifier: restore the checkpoint at
# trainResultPath when it exists, otherwise start from fresh variables, then
# train and save again to the same path.
trainResultPath = "./save/idcard"
random.seed()

# Input: four digits (the last four digits of an ID number), fed as a flat list.
x = tf.placeholder(tf.float32)
# Target: a two-element one-hot label.
yTrain = tf.placeholder(tf.float32)

# Layer 1: 4 inputs -> 32 units, tanh activation.
w1 = tf.Variable(tf.random_normal([4, 32], mean=0.5, stddev=0.1), dtype=tf.float32)
b1 = tf.Variable(0, dtype=tf.float32)
xr = tf.reshape(x, [1, 4])  # matmul needs a 2-D [1, 4] row vector
n1 = tf.nn.tanh(tf.matmul(xr, w1) + b1)

# Layer 2: 32 -> 16 units, no activation.
w2 = tf.Variable(tf.random_normal([32, 16], mean=0.5, stddev=0.1), dtype=tf.float32)
b2 = tf.Variable(0, dtype=tf.float32)
n2 = tf.matmul(n1, w2) + b2

# Layer 3: 16 -> 2 units, no activation.
w3 = tf.Variable(tf.random_normal([16, 2], mean=0.5, stddev=0.1), dtype=tf.float32)
b3 = tf.Variable(0, dtype=tf.float32)
n3 = tf.matmul(n2, w3) + b3

# Flatten to two values and turn them into probabilities.
y = tf.nn.softmax(tf.reshape(n3, [2]))

# Mean squared error between the prediction and the one-hot target.
loss = tf.reduce_mean(tf.square(y - yTrain))
optimizer = tf.train.RMSPropOptimizer(0.01)
train = optimizer.minimize(loss)

sess = tf.Session()

# Load the saved data and continue training from it.
if os.path.exists(trainResultPath + '.index'):
    print("加载:%s" % trainResultPath)
    tf.train.Saver().restore(sess, save_path=trainResultPath)
else:
    print("加载路径不存在:%s" % trainResultPath)
    sess.run(tf.global_variables_initializer())
# NOTE: do not run the initializer unconditionally here; running it after a
# restore would overwrite the restored values with fresh initial ones.

lossSum = 0.0
for i in range(5000):
    # Draw four random digits in [0, 9] to simulate the last four ID digits.
    xDataRandom = [int(random.random() * 10) for _ in range(4)]
    # The label depends on the parity of the second-to-last digit (index 2):
    # even -> [0, 1], odd -> [1, 0].
    if xDataRandom[2] % 2 == 0:
        yTrainDataRandom = [0, 1]
    else:
        yTrainDataRandom = [1, 0]
    result = sess.run([train, x, yTrain, y, loss], feed_dict={x: xDataRandom, yTrain: yTrainDataRandom})
    lossValue = float(result[-1])  # loss is the last fetch in the list above
    lossSum = lossSum + lossValue
    print("i:%d,loss:%10.10f,avgLoss:%10.10f" % (i, lossValue, lossSum / (i + 1)))

print("保存中.....")
tf.train.Saver().save(sess, save_path=trainResultPath)

# Release the resources held by the session now that training is finished.
sess.close()
其中载入和进行训练过程代码如下:
# Load the saved data and continue training from it; when no checkpoint index
# file exists, fall back to initializing the variables.
if os.path.exists(trainResultPath + '.index'):
    print("加载:%s" % trainResultPath)
    tf.train.Saver().restore(sess, save_path=trainResultPath)
else:
    print("加载路径不存在:%s" % trainResultPath)
    sess.run(tf.global_variables_initializer())
判断已保存的训练过程文件是否存在(这里检查的是 trainResultPath 对应的 .index 文件):如果存在,则通过 tf.train.Saver() 创建 Saver 对象,并调用其 restore 成员函数载入数据,然后在此基础上继续训练;如果不存在,则初始化变量,从头开始训练。
通过命令行参数控制强制重新训练
我们可以通过命令行来控制是否强制重新开始训练,放弃原本训练的模型。具体代码如下:
# Author:北京
# QQ:838262020
# time:2021/4/22
import tensorflow as tf
import random
import os
import sys
# Train the ID-card gender classifier with an optional "-restart" command-line
# flag: with the flag, any saved checkpoint is ignored and training starts over;
# without it, the checkpoint at trainResultPath is restored when it exists.

# Use the command-line arguments to decide whether to force a restart.
ifRestartT = False
argt = sys.argv[1:]
for v in argt:
    if v == "-restart":
        ifRestartT = True

trainResultPath = "./save/idcard"
random.seed()

# Input: four digits (the last four digits of an ID number), fed as a flat list.
x = tf.placeholder(tf.float32)
# Target: a two-element one-hot label.
yTrain = tf.placeholder(tf.float32)

# Layer 1: 4 inputs -> 32 units, tanh activation.
w1 = tf.Variable(tf.random_normal([4, 32], mean=0.5, stddev=0.1), dtype=tf.float32)
b1 = tf.Variable(0, dtype=tf.float32)
xr = tf.reshape(x, [1, 4])  # matmul needs a 2-D [1, 4] row vector
n1 = tf.nn.tanh(tf.matmul(xr, w1) + b1)

# Layer 2: 32 -> 16 units, no activation.
w2 = tf.Variable(tf.random_normal([32, 16], mean=0.5, stddev=0.1), dtype=tf.float32)
b2 = tf.Variable(0, dtype=tf.float32)
n2 = tf.matmul(n1, w2) + b2

# Layer 3: 16 -> 2 units, no activation.
w3 = tf.Variable(tf.random_normal([16, 2], mean=0.5, stddev=0.1), dtype=tf.float32)
b3 = tf.Variable(0, dtype=tf.float32)
n3 = tf.matmul(n2, w3) + b3

# Flatten to two values and turn them into probabilities.
y = tf.nn.softmax(tf.reshape(n3, [2]))

# Mean squared error between the prediction and the one-hot target.
loss = tf.reduce_mean(tf.square(y - yTrain))
optimizer = tf.train.RMSPropOptimizer(0.01)
train = optimizer.minimize(loss)

sess = tf.Session()

if ifRestartT:
    # Forced restart: ignore any existing checkpoint.
    print("强制重新开始训练")
    sess.run(tf.global_variables_initializer())
# Otherwise load the saved data and continue training from it.
elif os.path.exists(trainResultPath + '.index'):
    print("加载:%s" % trainResultPath)
    tf.train.Saver().restore(sess, save_path=trainResultPath)
else:
    print("加载路径不存在:%s" % trainResultPath)
    sess.run(tf.global_variables_initializer())
# NOTE: do not run the initializer unconditionally here; running it after a
# restore would overwrite the restored values with fresh initial ones.

lossSum = 0.0
for i in range(5):
    # Draw four random digits in [0, 9] to simulate the last four ID digits.
    xDataRandom = [int(random.random() * 10) for _ in range(4)]
    # The label depends on the parity of the second-to-last digit (index 2):
    # even -> [0, 1], odd -> [1, 0].
    if xDataRandom[2] % 2 == 0:
        yTrainDataRandom = [0, 1]
    else:
        yTrainDataRandom = [1, 0]
    result = sess.run([train, x, yTrain, y, loss], feed_dict={x: xDataRandom, yTrain: yTrainDataRandom})
    lossValue = float(result[-1])  # loss is the last fetch in the list above
    lossSum = lossSum + lossValue
    print("i:%d,loss:%10.10f,avgLoss:%10.10f" % (i, lossValue, lossSum / (i + 1)))

print("保存中.....")
tf.train.Saver().save(sess, save_path=trainResultPath)

# Release the resources held by the session now that training is finished.
sess.close()
其中强制重新开始训练的代码如下:
# Use the command-line arguments to decide whether to force a restart:
# ifRestartT becomes True when "-restart" appears among the arguments.
ifRestartT = False
argt = sys.argv[1:]  # skip argv[0], the script name
for v in argt:
    if v == "-restart":
        ifRestartT = True
我们可以通过命令行输入(以脚本文件名 train.py 为例):python train.py -restart
即在训练模型的脚本文件名后加上 -restart 参数,程序根据是否带有该参数来判断是否要重新训练模型。