一、tesorflow基本概念
二、计算图
三、操作
3.1 运算操作定义类操作的类型,以及参与运算的数据的类型
3.2 tensorflow中的运算符
变量运算指的是element wise
运算
四、变量
4.1 创建变量
通过tf.Variables()
函数创建
4.2变量初始化方式
- 一次性全部初始化
sess.run(tf.global_variables_initializer())
- 自定义初始化
- 由另一个变量初始化。通过变量的
initial_value
实行,对新的变量初始化
b = tf.Variable(tf.zeros([1]))
c = tf.Variable(b.initial_value,name='ll')
4.3变量的保存与恢复
saver.save(sess, 'my-model', global_step=step)
训练过程,根据训练的步数保存模型
- 保存变量
tf.train.Saver()
import tensorflow as tf
import numpy as np
# 训练数据
x_data = np.float32(np.random.rand(2,100))
y_data = np.dot([0.1,0.2],x_data)+0.3
# 定义模型
b = tf.Variable(tf.zeros([1]))
w = tf.Variable(tf.random_uniform([1,2],-1,1))
y = tf.matmul(w,x_data)+b
# 定义损失函数
loss = tf.reduce_mean(tf.square(y-y_data))
optimizer = tf.train.AdamOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
saver = tf.train.Saver()# 保存变量
sess = tf.Session()
sess.run(init)
sess.run(train)
path = saver.save(sess,'./var.ckpt')# 返回保存路径
print(path)
- 恢复变量
saver.restore(sess,'./var.ckpt')
取变量值时,在变量名后面加:0
import tensorflow as tf
import numpy as np
b = tf.Variable(tf.zeros([1]))
w = tf.Variable(tf.random_uniform([1,2],-1,1),name='w')
init = tf.global_variables_initializer()
saver = tf.train.Saver()# 保存变量
sess = tf.Session()
sess.run(init)
print('w_current:',sess.run(w))
# path = saver.save(sess,'./var.ckpt')# 返回保存路径
# print(path)
saver.restore(sess,'./var.ckpt')
print('w_save:',sess.run('w:0'))# 取变量值时,后面加:0
- 保存部分变量
保存部分变量。需要保存的变量,以字典的形式传入tf.train.Saver() - 恢复部分变量
saver.restore(sess,'./part_var.ckpt')
取变量值,通过变量名取出,不是别名sess.run(b1)
import tensorflow as tf
import numpy as np
b1 = tf.Variable(tf.random_uniform([1]))
b2 = tf.Variable(tf.random_uniform([1]))
w = tf.Variable(tf.random_uniform([1,2],-1,1),name='w')
init = tf.global_variables_initializer()
# 保存部分变量。需要保存的变量,以字典的形式传入tf.train.Saver()
saver = tf.train.Saver({'bf':b1,'bh':b2})
sess = tf.Session()
sess.run(init)
# print('w_current:',sess.run(w))
print('b1_current:',sess.run(b1))
print('b2_current:',sess.run(b2))
# path = saver.save(sess,'./part_var.ckpt')# 返回保存路径
# print(path)
s = saver.restore(sess,'./part_var.ckpt')
print('b1_save:',sess.run(b1))# 取变量值时
print('b2_save:',sess.run(b2))# 取变量值时
5 会话
5.1 会话的创建和运行