3-Tensorflow-demo_0801_简单的Saver.py



import tensorflow as tf
import os
"""
模型持久化:
    含义:将当前训练好的模型图 和 权重保存到本地磁盘中,方便后续的使用。
    1、服务器训练好了一个模型,迁移到 移动端使用;
    2、深度学习训练都很耗时,耗钱。 可以迁移学习。
"""

def train():
    with tf.Graph().as_default():
        v1 = tf.get_variable(
            'v1', dtype=tf.float32, shape=[1], initializer=tf.random_normal_initializer()
        )
        v2 = tf.get_variable(
            'v2', dtype=tf.float32, shape=[1], initializer=tf.random_normal_initializer()
        )
        rezult = v1 + v2
        # 构建持久化的对象
        saver = tf.train.Saver()

        # 创建持久化的文件路径
        save_file = './models/ai20/model.ckpt'
        dirpath = os.path.dirname(save_file)
        if not os.path.exists(dirpath):
            os.makedirs(dirpath)
            print('成功创建文件夹:{}'.format(dirpath))

        # 二、构建会话
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print(sess.run([v1, v2, rezult]))
            """
            [array([0.25989017], dtype=float32), 
            array([-0.12384889], dtype=float32), 
            array([0.13604128], dtype=float32)]
            """
            # 执行持久化
            saver.save(sess=sess, save_path=save_file)
            print('变量成功保存至:{}'.format(save_file))


def restore_var():
    with tf.Graph().as_default():
        v1 = tf.get_variable(
            'v1', dtype=tf.float32, shape=[1], initializer=tf.random_normal_initializer()
        )
        v2 = tf.get_variable(
            'v2', dtype=tf.float32, shape=[1], initializer=tf.random_normal_initializer()
        )
        rezult = v1 + v2

        # 构建持久化的对象
        saver = tf.train.Saver()

        # 创建持久化的文件路径
        save_file = './models/ai20/model.ckpt'
        dirpath = os.path.dirname(save_file)
        if not os.path.exists(dirpath):
            os.makedirs(dirpath)
            print('成功创建文件夹:{}'.format(dirpath))

        # 二、构建会话
        with tf.Session() as sess:
            # 直接恢复变量,所以无需变量初始化
            saver.restore(sess, save_path=save_file)
            print(sess.run([v1, v2, rezult]))
            """
            [array([0.25989017], dtype=float32), 
            array([-0.12384889], dtype=float32), 
            array([0.13604128], dtype=float32)]
            """


if __name__ == '__main__':
    # train()
    restore_var()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值