Saving and Restoring Models in TensorFlow

I. Saving and Restoring a Model

1. Saving a model

saver = tf.train.Saver()
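
A minimal save sketch (the checkpoint path and step value here are illustrative, not from this post): build the graph, create the Saver, then call save() inside a session. Passing global_step makes TensorFlow number the checkpoint files.

import os
import tensorflow as tf

w = tf.Variable(tf.zeros([2, 2]), name="w")
saver = tf.train.Saver()  # by default saves every variable in the graph
os.makedirs('checkpoints', exist_ok=True)  # Saver does not create directories
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... training steps would go here ...
    # writes model.ckpt-100.data-*, .index and .meta, plus a 'checkpoint' index file
    saver.save(sess, 'checkpoints/model.ckpt', global_step=100)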

2. Restoring a model

saver.restore(sess, save_path)
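
And a matching restore sketch (assuming the same illustrative 'checkpoints' directory): restore() loads the saved values into the graph's variables, so no initializer is needed for them. tf.train.latest_checkpoint resolves the newest checkpoint path in a directory.

import tensorflow as tf

w = tf.Variable(tf.zeros([2, 2]), name="w")  # the graph must match the saved one
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt_path = tf.train.latest_checkpoint('checkpoints')
    if ckpt_path is not None:
        saver.restore(sess, ckpt_path)  # restore(sess, save_path) in action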

II. Training the Model

This time I train a fairly simple convolutional network on CIFAR-10 for image classification. The focus today is not on the network architecture: saving and restoring a model lets you both resume training from the previous run's weights and quickly reproduce earlier results. Without further ado, here is the code.

import tensorflow as tf
import os

from CIFAR import load_CIFAR10


def weight(shape, stddev):
    # Weight variable initialized from a truncated normal distribution
    init = tf.truncated_normal(shape=shape, stddev=stddev)
    return tf.Variable(init)


def bias(shape, value):
    # Bias variable initialized to a constant value
    init = tf.constant(value=value, dtype=tf.float32, shape=shape)
    return tf.Variable(init)


def conv(X, W):
    # 2-D convolution with stride 1 and SAME padding
    return tf.nn.conv2d(X, W, strides=[1, 1, 1, 1], padding="SAME")


def pool(X):
    # 3x3 max pooling with stride 2 (halves the spatial dimensions)
    return tf.nn.max_pool(X, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding="SAME")


def forward(X):
    # Conv layer 1
    W1 = weight([5, 5, 3, 64], 5e-2)
    b1 = bias([64], 0)
    # Batch-normalize the input over the channel axis
    # (the result must be assigned, otherwise the call is a no-op)
    X = tf.layers.batch_normalization(X, axis=-1)
    con = tf.nn.relu(conv(X, W1) + b1)
    pool1 = pool(con)

    # Conv layer 2
    W2 = weight([5, 5, 64, 64], 5e-2)
    b2 = bias([64], 0.1)
    con2 = tf.nn.relu(conv(pool1, W2) + b2)

    pool2 = pool(con2)
    # Two stride-2 pools shrink the 32x32 input to 8x8 feature maps
    pool2 = tf.reshape(pool2, [-1, 8 * 8 * 64])
    # Fully connected layer 1
    wc1 = weight([8 * 8 * 64, 384], 0.04)
    bc1 = bias([384], 0.1)
    fc1 = tf.nn.relu(tf.matmul(pool2, wc1) + bc1)
    # Fully connected layer 2
    wc2 = weight([384, 192], 0.04)
    bc2 = bias([192], 0.1)
    fc2 = tf.nn.relu(tf.matmul(fc1, wc2) + bc2)
    # Fully connected layer 3 (logits for the 10 CIFAR-10 classes)
    wc3 = weight([192, 10], 1 / 192.0)
    bc3 = bias([10], 0)
    f_out = tf.nn.bias_add(tf.matmul(fc2, wc3), bc3)
    return f_out


def train(label, logits):
    # Cross-entropy loss against one-hot labels
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=tf.one_hot(label, 10), logits=logits))
    # Optimizer; global_step is incremented on every minimize() call
    global_step = tf.Variable(0, trainable=False)
    train_op = tf.train.AdamOptimizer(0.001).minimize(cross_entropy, global_step=global_step)
    correct_predict = tf.equal(tf.argmax(logits, 1), label)
    accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))
    return train_op, cross_entropy, accuracy


def evaluate(X_test, Y_test):
    x = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y = tf.placeholder(tf.int64, [None])
    logits = forward(x)
    train_op, cross_entropy, accuracy = train(y, logits)
    with tf.Session() as sess:
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state('D:/Python/class_10/mode')
        if ckpt is not None and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            # restore() loads the saved weights, so no initializer is needed
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('no model!')
            return  # without a checkpoint the variables are uninitialized
        acc = sess.run(accuracy, feed_dict={x: X_test, y: Y_test})
        print(acc)


def main():
    data_dir = 'cifar-10-python'
    data_dir = os.path.join(data_dir, 'cifar-10-batches-py')
    x_train, y_train, x_test, y_test = load_CIFAR10(data_dir)
    X = tf.placeholder(tf.float32, [None, 32, 32, 3])
    Y = tf.placeholder(tf.int64, [None])
    logits = forward(X)
    train_op, cross_entropy, accuracy = train(Y, logits)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        saver = tf.train.Saver()
        # If a checkpoint already exists, restore it and resume training from there
        ckpt = tf.train.get_checkpoint_state('D:/Python/class_10/mode')
        if ckpt is not None and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('no model!')
        for i in range(4):
            for j in range(780):
                # Mini-batches of 64 examples
                batch_x = x_train[64 * j:64 * (j + 1), :]
                batch_y = y_train[64 * j:64 * (j + 1)]
                _, cost, acc = sess.run([train_op, cross_entropy, accuracy],
                                        feed_dict={X: batch_x, Y: batch_y})
                if j % 100 == 0:
                    print("step", i * 780 + j, "loss:", cost, "accuracy:", acc)
        # Save the trained weights so the next run can pick up from here
        saver.save(sess, 'D:/Python/class_10/mode/model.ckpt')


if __name__ == "__main__":
    data_dir = 'cifar-10-python'
    data_dir = os.path.join(data_dir, 'cifar-10-batches-py')
    x_train, y_train, x_test, y_test = load_CIFAR10(data_dir)
    # main()  # run this first to train and save the checkpoint
    evaluate(x_test[0:64, :], y_test[0:64])

The main function is the training process; the final part restores the previously trained network and reports the accuracy on the test set. Here is the result from a run on my machine:

[screenshot of the training output]
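
If you want to double-check what a checkpoint actually contains before restoring it, here is a quick sketch (assuming TensorFlow 1.x and the same illustrative path as above): tf.train.list_variables returns (name, shape) pairs for every variable stored in the checkpoint, including the Adam optimizer's slot variables.

import tensorflow as tf

ckpt_path = tf.train.latest_checkpoint('D:/Python/class_10/mode')
if ckpt_path is not None:
    # List every variable saved in the checkpoint with its shape
    for name, shape in tf.train.list_variables(ckpt_path):
        print(name, shape)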


Today's post is short, but I think the topic matters. I hope more of you will join me in discussing image processing and deep learning!

