神经网络--保存和载入训练数据

前言

  大多数神经网络训练时间比较长,为了避免意外导致训练结果丢失,我们要经常保存和载入训练过程的数据。

保存和载入训练过程

  我们将从保存训练过程、载入保存的训练过程并继续训练、强制重新训练等方面来介绍
##保存训练过程
我们以上一篇的身份证男女识别模型为基础,进行演示,具体代码如下:

# Author:北京
# QQ:838262020
# time:2021/4/21
import tensorflow as tf
import random

random.seed()

x = tf.placeholder(tf.float32)
yTrain = tf.placeholder(tf.float32)

# w1 = tf.Variable(tf.random_normal([4, 8], mean=0.5, stddev=0.1), dtype=tf.float32)
w1 = tf.Variable(tf.random_normal([4, 32], mean=0.5, stddev=0.1), dtype=tf.float32)
b1 = tf.Variable(0, dtype=tf.float32)
xr = tf.reshape(x, [1, 4])
n1 = tf.nn.tanh(tf.matmul(xr, w1) + b1)

# w2 = tf.Variable(tf.random_normal([8, 2], mean=0.5, stddev=0.1), dtype=tf.float32)
w2 = tf.Variable(tf.random_normal([32, 16], mean=0.5, stddev=0.1), dtype=tf.float32)
b2 = tf.Variable(0, dtype=tf.float32)

n2 = tf.matmul(n1, w2) + b2

w3 = tf.Variable(tf.random_normal([16, 2], mean=0.5, stddev=0.1), dtype=tf.float32)
b3 = tf.Variable(0, dtype=tf.float32)

n3 = tf.matmul(n2, w3) + b3



y = tf.nn.softmax(tf.reshape(n3, [2]))

loss = tf.reduce_mean(tf.square(y - yTrain))

optimizer = tf.train.RMSPropOptimizer(0.01)
train = optimizer.minimize(loss)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

lossSum = 0.0
for i in range(5000):
    # 随机产生[0,9]的4位整数,模拟身份证后4位
    xDataRandom = [int(random.random() * 10), int(random.random() * 10), int(random.random() * 10),
                   int(random.random() * 10)]
    # 判断倒数第2位数字奇数或者偶数来模型对应的性别男女
    if xDataRandom[2] % 2 == 0:
        yTrainDataRandom = [0, 1]
    else:
        yTrainDataRandom = [1, 0]

    result = sess.run([train, x, yTrain, y, loss], feed_dict={x: xDataRandom, yTrain: yTrainDataRandom})
    lossSum = lossSum + float(result[len(result) - 1])
    print("i:%d,loss:%10.10f,avgLoss:%10.10f" % (i, float(result[len(result) - 1]), lossSum / (i + 1)))
trainResultPath="./save/idcard"
print("保存中.....")
tf.train.Saver().save(sess,save_path=trainResultPath)
print(("保存完成"))

  其中保存训练过程代码如下:

trainResultPath="./save/idcard"
tf.train.Saver().save(sess,save_path=trainResultPath)

  第一条语句指明保存的路径:当前项目下的save文件夹下,文件名为idcard。
  第二条语句调用TensorFlow中的train包中的Saver函数返回的Saver对象的save成员函数来保存。

载入保存的训练过程并继续训练

  对上面已经保存的训练过程进行载入和继续训练。具体代入如下:

# Author:北京
# QQ:838262020
# time:2021/4/21
import tensorflow as tf
import random
import os

trainResultPath = "./save/idcard"

random.seed()

x = tf.placeholder(tf.float32)
yTrain = tf.placeholder(tf.float32)

w1 = tf.Variable(tf.random_normal([4, 32], mean=0.5, stddev=0.1), dtype=tf.float32)
b1 = tf.Variable(0, dtype=tf.float32)
xr = tf.reshape(x, [1, 4])

n1 = tf.nn.tanh(tf.matmul(xr, w1) + b1)

w2 = tf.Variable(tf.random_normal([32, 16], mean=0.5, stddev=0.1), dtype=tf.float32)
b2 = tf.Variable(0, dtype=tf.float32)

n2 = tf.matmul(n1, w2) + b2

w3 = tf.Variable(tf.random_normal([16, 2], mean=0.5, stddev=0.1), dtype=tf.float32)
b3 = tf.Variable(0, dtype=tf.float32)

n3 = tf.matmul(n2, w3) + b3

y = tf.nn.softmax(tf.reshape(n3, [2]))

loss = tf.reduce_mean(tf.square(y - yTrain))

optimizer = tf.train.RMSPropOptimizer(0.01)

train = optimizer.minimize(loss)

sess = tf.Session()

# 载入保存的数据进行训练
if os.path.exists(trainResultPath + '.index'):
    print("加载:%s" % trainResultPath)
    tf.train.Saver().restore(sess, save_path=trainResultPath)
else:
    print("加载路径不存在:%s" % trainResultPath)
    sess.run(tf.global_variables_initializer())

# sess.run(tf.global_variables_initializer())

lossSum = 0.0
for i in range(5000):
    # 随机产生[0,9]的4位整数,模拟身份证后4位
    xDataRandom = [int(random.random() * 10), int(random.random() * 10), int(random.random() * 10),
                   int(random.random() * 10)]
    # 判断倒数第2位数字奇数或者偶数来模型对应的性别男女
    if xDataRandom[2] % 2 == 0:
        yTrainDataRandom = [0, 1]
    else:
        yTrainDataRandom = [1, 0]

    result = sess.run([train, x, yTrain, y, loss], feed_dict={x: xDataRandom, yTrain: yTrainDataRandom})
    lossSum = lossSum + float(result[len(result) - 1])
    print("i:%d,loss:%10.10f,avgLoss:%10.10f" % (i, float(result[len(result) - 1]), lossSum / (i + 1)))

print("保存中.....")
tf.train.Saver().save(sess, save_path=trainResultPath)

  其中载入和进行训练过程代码如下:

# 载入保存的数据进行训练
if os.path.exists(trainResultPath + '.index'):
    print("加载:%s" % trainResultPath)
    tf.train.Saver().restore(sess, save_path=trainResultPath)
else:
    print("加载路径不存在:%s" % trainResultPath)
    sess.run(tf.global_variables_initializer())

  判断以保存的训练过程文件是否存在,如果存在调用TensorFlow下的train包的Saver函数返回Saver对象的restore成员函数进行数据的载入和训练。如果不存在,初始化变量开始训练。

通过命令行参数控制强制重新训练

  我们可以通过命令行来控制是否强制重新开始训练,放弃原本训练的模型。具体代码如下:

# Author:北京
# QQ:838262020
# time:2021/4/22

import tensorflow as tf
import random
import os
import sys

# 通过命令行参数控制是否强制重新开始训练
ifRestartT = False
argt = sys.argv[1:]

for v in argt:
    if v == "-restart":
        ifRestartT = True

trainResultPath = "./save/idcard"

random.seed()

x = tf.placeholder(tf.float32)
yTrain = tf.placeholder(tf.float32)

w1 = tf.Variable(tf.random_normal([4, 32], mean=0.5, stddev=0.1), dtype=tf.float32)
b1 = tf.Variable(0, dtype=tf.float32)
xr = tf.reshape(x, [1, 4])

n1 = tf.nn.tanh(tf.matmul(xr, w1) + b1)

w2 = tf.Variable(tf.random_normal([32, 16], mean=0.5, stddev=0.1), dtype=tf.float32)
b2 = tf.Variable(0, dtype=tf.float32)

n2 = tf.matmul(n1, w2) + b2

w3 = tf.Variable(tf.random_normal([16, 2], mean=0.5, stddev=0.1), dtype=tf.float32)
b3 = tf.Variable(0, dtype=tf.float32)

n3 = tf.matmul(n2, w3) + b3

y = tf.nn.softmax(tf.reshape(n3, [2]))

loss = tf.reduce_mean(tf.square(y - yTrain))

optimizer = tf.train.RMSPropOptimizer(0.01)

train = optimizer.minimize(loss)

sess = tf.Session()

if ifRestartT==True:
    print("强制重新开始训练")
    sess.run(tf.global_variables_initializer())
# 载入保存的数据进行训练
elif os.path.exists(trainResultPath + '.index'):
    print("加载:%s" % trainResultPath)
    tf.train.Saver().restore(sess, save_path=trainResultPath)
else:
    print("加载路径不存在:%s" % trainResultPath)
    sess.run(tf.global_variables_initializer())

# sess.run(tf.global_variables_initializer())

lossSum = 0.0
for i in range(5):
    # 随机产生[0,9]的4位整数,模拟身份证后4位
    xDataRandom = [int(random.random() * 10), int(random.random() * 10), int(random.random() * 10),
                   int(random.random() * 10)]
    # 判断倒数第2位数字奇数或者偶数来模型对应的性别男女
    if xDataRandom[2] % 2 == 0:
        yTrainDataRandom = [0, 1]
    else:
        yTrainDataRandom = [1, 0]

    result = sess.run([train, x, yTrain, y, loss], feed_dict={x: xDataRandom, yTrain: yTrainDataRandom})
    lossSum = lossSum + float(result[len(result) - 1])
    print("i:%d,loss:%10.10f,avgLoss:%10.10f" % (i, float(result[len(result) - 1]), lossSum / (i + 1)))

print("保存中.....")
tf.train.Saver().save(sess, save_path=trainResultPath)

  其中制重新开始训练代码如下:

# 通过命令行参数控制是否强制重新开始训练
ifRestartT = False
argt = sys.argv[1:]

for v in argt:
    if v == "-restart":
        ifRestartT = True

  我们可以通过命令行输入:
在这里插入图片描述在训练模型文件后加上 -restart来判断是否要重新训练模型。

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是基于TensorFlow神经网络训练代码,并加了注释,可以保存模型、加载模型并载入新的数据进行继续训练: ```python import tensorflow as tf # 载入数据集 mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() # 数据预处理 x_train, x_test = x_train / 255.0, x_test / 255.0 # 定义模型 model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10) ]) # 定义损失函数 loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) # 编译模型 model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy']) # 加载或预训练模型 try: model = tf.keras.models.load_model('my_model.h5') print('已加载模型,继续训练...') except: print('未找到模型,重新预训练...') # 预训练模型 model.fit(x_train, y_train, epochs=5) # 保存模型 model.save('my_model.h5') # 载入新的数据并继续训练 new_x_train, new_y_train = mnist.load_data()[0] new_x_train, new_x_test = new_x_train / 255.0, x_test / 255.0 model.fit(new_x_train, new_y_train, epochs=5) # 保存模型 model.save('my_model.h5') # 评估模型 model.evaluate(x_test, y_test, verbose=2) ``` 这段代码在预训练模型时,如果已经存在名为'my_model.h5'的模型文件,就会加载该模型并继续训练;如果不存在,则重新预训练。预训练完成后,将模型保存到'my_model.h5'文件中。接着,加载新的数据并对模型进行继续训练。最后,对模型进行了一次评估。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值