TensorFlow保存,加载模型

1. 保存模型:

# -*- coding: UTF-8 -*-
import tensorflow as tf
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

W1 = tf.Variable(tf.random_normal([2]), name='w1')
W2 = tf.Variable(tf.constant(1, shape=[2, 2]), name='w2')
saver = tf.train.Saver()

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    saver.save(session, './checkpoint_dir/mymodel')  # './checkpoint_dir/',模型保存的目录;'mymodel',模型的名字

2. 加载模型:

# -*- coding: UTF-8 -*-
import tensorflow as tf
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

saver=tf.train.import_meta_graph("./checkpoint_dir/mymodel.meta")

with tf.Session() as session:
    saver.restore(session,tf.train.latest_checkpoint('./checkpoint_dir/'))
    g=tf.get_default_graph()
    w1=g.get_tensor_by_name('w1:0')
    w2=g.get_tensor_by_name('w2:0')
    print(session.run(w1))
    print(session.run(w2))
    
# [0.01617053 1.2160776 ]
# [[1 1]
#  [1 1]]

3. practice-保存模型:

# -*- coding: UTF-8 -*-
import tensorflow as tf
import os
import numpy as np

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

w1 = tf.placeholder(dtype=np.float32, name='w1')
w2 = tf.placeholder(dtype=np.float32, name='w2')

b1 = tf.Variable(2.0, name='bias')

feed_dict = {w1: 4, w2: 8}

w3=tf.add(w1,w2)
w4=tf.multiply(w3,b1,name='multiply_op')
saver=tf.train.Saver()

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    print(session.run(w4,feed_dict=feed_dict))
    saver.save(session,'./checkpoint_dir/mymodel')

# 24.0

4. practice-加载模型,最后一层添加自己的op

# -*- coding: UTF-8 -*-
import tensorflow as tf
import os
import numpy as np

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# w1 = tf.placeholder(dtype=np.float32, name='w1')
# w2 = tf.placeholder(dtype=np.float32, name='w2')
#
# b1 = tf.Variable(2.0, name='bias')
#
# feed_dict = {w1: 4, w2: 8}
#
# w3=tf.add(w1,w2)
# w4=tf.multiply(w3,b1,name='multiply_op')
# saver=tf.train.Saver()
#
# with tf.Session() as session:
#     session.run(tf.global_variables_initializer())
#     print(session.run(w4,feed_dict=feed_dict))
#     saver.save(session,'./checkpoint_dir/mymodel')

# 24.0

# checkpoint_dir
checkpoint_dir = "./checkpoint_dir"

# 1. 加载网络图
saver = tf.train.import_meta_graph(checkpoint_dir + '/mymodel.meta')

with tf.Session() as session:
    # 2. 加载值
    saver.restore(session, save_path=tf.train.latest_checkpoint(checkpoint_dir))
    # 3. Now, access the variables and op that you want to run.
    g = tf.get_default_graph()
    w1 = g.get_tensor_by_name('w1:0')
    w2 = g.get_tensor_by_name('w2:0')
    w4 = g.get_tensor_by_name('multiply_op:0')

    # 4. 添加自己的op

    w5 = tf.add(w4, 10)

    print(session.run([w4, w5], feed_dict={w1: 4, w2: 8}))

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值