TensorFlow 常用代码

创建常量

state = tf.constant([1, 2]);

 

创建变量

state = tf.Variable([1, 1], name='counter')

 

变量赋值

state = tf.Variable(0, name='counter')
new_value = tf.add(state, 1)
#变量赋值
update = tf.assign(state, new_value)

 

初始化变量

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

 

同时执行多个操作

 

import tensorflow as tf

input1 = tf.constant(1.0)
input2 = tf.constant(2.0)
input3 = tf.constant(3.0)

add = tf.add(input1, input2)
mul = tf.multiply(input3, add)

with tf.Session() as sess:
    res = sess.run([mul, add])    #可传入多个操作, 进行执行.
    print(res)

 

变量占位符

 

import tensorflow as tf

input1 = tf.placeholder(tf.float32) #创建变量占位符
input2 = tf.placeholder(tf.float32) #创建变量占位符

add = tf.add(input1, input2)

with tf.Session() as sess:
    res = sess.run(add, feed_dict={ input1:[7.0], input2:[8.0] })   #执行时, 在传入数据.
    print(res)

 

读取, 保存模型

import tensorflow as tf

# 模型保存地址
save_dir = 'E:/TensorFlow/Models-Spaces/test_001/model_01.ckpt';
run_step = 1;   # 0.运行完成后保存模型, 1.读取上次保存的模型

a = tf.Variable(1)  # 声明变量
b = tf.placeholder(tf.int32)    # 使用占位符, 运行时传入.

# 结果记录
counter = tf.Variable(0, name='counter')

ab = tf.add(a, b)  # 加法计算
c = tf.add(counter, ab) # counter累加ab
update = tf.assign(counter, c)  # c赋值给counter

init_op = tf.global_variables_initializer();   #初始化变量

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)

    # 第一次运行, 保存模型
    if run_step == 0:
        for _ in range(3):
            sess.run(update, feed_dict={ b: 4 })    # 运行模型
            print(sess.run(counter))                # 查看counter的值

        # 保存模型
        saver.save(sess, save_dir)

    # 第二次运行, 读取模型
    if run_step == 1:
        # 读取模型
        saver.restore(sess, save_dir)
        for _ in range(3):
            sess.run(update, feed_dict={ b: 4 })    # 运行模型
            print(sess.run(counter))                # 查看counter的值

'''
    第一次运行, 把 run_step 设置为 0 (运行完成后保存模型) : 
        5
        10
        15
         
    第二次运行, 把 run_step 设置为 1 (读取上次保存的模型): 
        20
        25
        30
'''

 

打印模型文件里保存的内容:

import tensorflow as tf
# 引入必要包
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# 模型保存地址
save_dir = 'E:/TensorFlow/Models-Spaces/test_001/model_01.ckpt';
# 打印保存文件的内容
print_tensors_in_checkpoint_file(save_dir, None, True)

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值