一、模型保存与恢复
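1.模型保存
saver = tf.train.Saver()
saver.save(sess, save_path)  # 把当前会话中的变量写入 save_path 对应的检查点文件
2.模型恢复
saver.restore(sess, save_path)  # 方法签名:restore(self, sess, save_path),从检查点文件中恢复变量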
二、模型的训练
本文用一个比较简单的卷积网络训练 CIFAR-10,实现图像分类。今天的重点不在网络结构上,而在模型的保存与恢复:它不仅可以保留上次训练得到的参数以便继续训练,还可以快速复现之前的训练结果。话不多说,下面上代码。
import tensorflow as tf
import os
from CIFAR import load_CIFAR10
def weight(shape, stddev):
    """Create a weight variable initialised from a truncated normal.

    Args:
        shape: Shape of the variable to create.
        stddev: Standard deviation forwarded to ``tf.truncated_normal``.

    Returns:
        A ``tf.Variable`` holding the sampled initial values.
    """
    return tf.Variable(tf.truncated_normal(shape=shape, stddev=stddev))
def bais(shape: object, value: object) -> object:
    """Create a bias variable filled with a constant float32 value.

    Args:
        shape: Shape of the variable to create.
        value: Constant used to fill the variable.

    Returns:
        A ``tf.Variable`` initialised to ``value`` everywhere.
    """
    constant_fill = tf.constant(value=value, dtype=tf.float32, shape=shape)
    bias_variable = tf.Variable(constant_fill)
    return bias_variable
def conv(X, W):
    """Convolve ``X`` with filter ``W`` using stride 1 and SAME padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(X, W, strides=unit_strides, padding="SAME")
def pool(X):
    """Max-pool ``X`` with a 3x3 window, stride 2 and SAME padding."""
    window = [1, 3, 3, 1]
    step = [1, 2, 2, 1]
    return tf.nn.max_pool(X, ksize=window, strides=step, padding="SAME")
def forward(X):
    """Build the CNN and return its unnormalised class scores.

    The network is two conv/ReLU/max-pool stages followed by three fully
    connected layers. The flatten step assumes the pooled feature map is
    8 x 8 x 64, which matches the ``[None, 32, 32, 3]`` placeholders the
    callers in this file feed in.

    Args:
        X: Input image tensor.

    Returns:
        Logits tensor whose last dimension is 10 (one score per class).
    """
    # NOTE: variables are created in the same order as before so their
    # auto-generated names stay stable and existing checkpoints still restore.
    #
    # The former ``tf.layers.batch_normalization(X, 1)`` call was dropped: its
    # result was never used, so it did not affect the logits and only added
    # unused variables to the graph.

    # Convolution block 1: 5x5 kernels, 3 -> 64 channels.
    W1 = weight([5, 5, 3, 64], 5e-2)
    b1 = bais([64], 0)
    conv1 = tf.nn.relu(conv(X, W1) + b1)
    pool1 = pool(conv1)

    # Convolution block 2: 5x5 kernels, 64 -> 64 channels.
    W2 = weight([5, 5, 64, 64], 5e-2)
    b2 = bais([64], 0.1)
    conv2 = tf.nn.relu(conv(pool1, W2) + b2)
    pool2 = pool(conv2)

    # Flatten the pooled feature map for the dense layers.
    flat = tf.reshape(pool2, [-1, 8 * 8 * 64])

    # Fully connected layer 1.
    wc1 = weight([8 * 8 * 64, 384], 0.04)
    bc1 = bais([384], 0.1)
    fc1 = tf.nn.relu(tf.matmul(flat, wc1) + bc1)

    # Fully connected layer 2.
    wc2 = weight([384, 192], 0.04)
    bc2 = bais([192], 0.1)
    fc2 = tf.nn.relu(tf.matmul(fc1, wc2) + bc2)

    # Fully connected layer 3: linear output, no activation (logits).
    wc3 = weight([192, 10], 1 / 192.0)
    bc3 = bais([10], 0)
    return tf.nn.bias_add(tf.matmul(fc2, wc3), bc3)
def train(label, logits):
    """Build the loss, the optimisation step and the accuracy metric.

    Args:
        label: Integer class-index tensor (callers pass an int64 placeholder
            of shape ``[None]``).
        logits: Unnormalised class scores with 10 classes on the last axis.

    Returns:
        A tuple ``(optimizal, cross_entropy, accuracy)``: the Adam training
        op, the mean cross-entropy loss, and the mean accuracy.
    """
    # Mean cross-entropy. The sparse variant takes class indices directly, so
    # no one-hot encoding is needed; it computes the same loss.
    cross_entropy = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logits)
    )

    # The step counter is bookkeeping, not a model parameter, so keep it out
    # of the trainable-variable collection.
    global_steps = tf.Variable(tf.constant(0), trainable=False)
    optimizal = tf.train.AdamOptimizer(0.001).minimize(
        cross_entropy, global_step=global_steps
    )

    # Fraction of samples whose highest-scoring class matches the label.
    correct_predict = tf.equal(tf.argmax(logits, 1), label)
    accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))
    return optimizal, cross_entropy, accuracy
def evaluate(X_test, Y_test, model_dir='D:/Python/class_10/mode'):
    """Restore the latest checkpoint and report accuracy on the given data.

    Args:
        X_test: Images to evaluate, fed to a ``[None, 32, 32, 3]`` float32
            placeholder.
        Y_test: Integer class labels, fed to a ``[None]`` int64 placeholder.
        model_dir: Directory searched for the latest checkpoint.

    Returns:
        The accuracy computed on the inputs, or ``None`` when no checkpoint
        is found (nothing is evaluated, because the variables would be
        uninitialised).
    """
    # Build in a fresh graph so variable names match the checkpoint even if
    # another model was already built in the default graph of this process.
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, 32, 32, 3])
        y = tf.placeholder(tf.int64, [None])
        logits = forward(x)
        # Only the accuracy tensor is needed here; train() is still called so
        # the graph holds the same set of variables as the training graph.
        _, _, accuracy = train(y, logits)

        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            # Running the accuracy op without restored weights would fail on
            # uninitialised variables, so stop here.
            print('no model!')
            return None

        with tf.Session() as sess:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            acc = sess.run(accuracy, feed_dict={x: X_test, y: Y_test})

    print(acc)
    return acc
def mian(data_dir=None, model_dir='D:/Python/class_10/mode', epochs=4, batch_size=64):
    """Train the CNN on CIFAR-10, resuming from a checkpoint when one exists.

    Args:
        data_dir: Directory holding the CIFAR-10 python batches. Defaults to
            ``cifar-10-python/cifar-10-batches-py``.
        model_dir: Directory used for both restoring and saving checkpoints.
        epochs: Number of passes over the training set.
        batch_size: Number of samples per training step.

    The model is saved to ``model.ckpt`` inside ``model_dir`` at the end of
    every epoch.
    """
    if data_dir is None:
        data_dir = os.path.join('cifar-10-python', 'cifar-10-batches-py')
    # The test split is not used during training.
    x_train, y_train, _, _ = load_CIFAR10(data_dir)

    X = tf.placeholder(tf.float32, [None, 32, 32, 3])
    Y = tf.placeholder(tf.int64, [None])
    logits = forward(X)
    optimizal, cross_entropy, accuracy = train(Y, logits)
    init = tf.global_variables_initializer()

    # Only full batches are used; a trailing partial batch is dropped.
    # assumes x_train supports len() (e.g. a NumPy array) -- TODO confirm
    steps_per_epoch = len(x_train) // batch_size

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs(model_dir, exist_ok=True)
    save_path = os.path.join(model_dir, 'model.ckpt')

    with tf.Session() as sess:
        # Initialise first so training can start from scratch; a restore
        # below overwrites these values with the checkpointed ones.
        sess.run(init)
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt is not None and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('no model!')

        for epoch in range(epochs):
            for j in range(steps_per_epoch):
                start = batch_size * j
                stop = start + batch_size
                batch_x = x_train[start:stop, :]
                batch_y = y_train[start:stop]
                _, cost, batch_accuracy = sess.run(
                    [optimizal, cross_entropy, accuracy],
                    feed_dict={X: batch_x, Y: batch_y},
                )
                if j % 100 == 0:
                    # Report the global step index across epochs.
                    step = epoch * steps_per_epoch + j
                    print("第", step, "次的loss:", cost, "准确率:", batch_accuracy)
            saver.save(sess, save_path)
if __name__ == "__main__":
    # Script entry point: load CIFAR-10 and evaluate the saved model on the
    # first 64 test samples.
    data_dir = os.path.join('cifar-10-python', 'cifar-10-batches-py')
    x_train, y_train, x_test, y_test = load_CIFAR10(data_dir)
    evaluate(x_test[:64, :], y_test[:64])
mian 函数(即训练入口)是训练过程;最后一部分调用 evaluate 恢复之前训练好的网络,并给出测试集前 64 张图片上的准确率。下图是我电脑跑出来的结果:
今天内容比较少,但我感觉还是比较重要的,希望有更多的小伙伴能一起交流学习图像处理和深度学习方面的内容!