关闭

Tensorflow自我训练进阶(代码+注解)【1】Tensor&Variables

标签: tensorflow
175人阅读 评论(0) 收藏 举报
分类:
import tensorflow as tf
import time
print '\nexample2: Tensor and Variables by ORCA\n'
#检查tensorflow版本与安装路径
print tf.__version__
print tf.__path__
#计时器
start = time.clock()

print 'before session and initialization\n'
#创建tensor
print 'tensors'
tensor1 = tf.random_normal([3, 9], stddev=0.35, name="weight") 
tensor2 = tf.zeros([10], name="bias")
tensor3 = tf.fill([2,4],39, name="skyfish")
print 'tensor1:', tensor1
print 'tensor2:', tensor2
#由于还没有sess.run(),所以这里不会正常输出
#创建vars
print 'variables'
weight = tf.Variable(tensor1)
bias = tf.Variable(tensor2)
print 'weight is:', weight
print 'bias is;', bias
weight_2 = tf.Variable(weight.initialized_value()*2, name="weight_2")
#由于没有sess.run(), 所以这里不会正常输出
#变量需要初始化,而tensor,var需要run之后才能正常赋值
print'\nInitialize and run session\n'
init_op = tf.initialize_all_variables() #vars must be initialized

saver = tf.train.Saver()

sess = tf.Session() #open a session
sess.run(init_op)
print sess.run(tensor1) #tensors should be run in session(as assign)
print sess.run(tensor2)
print 'weight is:', sess.run(weight) #vars should be run in session
print 'bias is:', sess.run(bias)
#用一个变量初始化另外一个变量
print '\nuse a var to initialize another var\n'
print 'weight_2 is;', sess.run(weight_2)
#保存模型
print '\nsave models\n'
save_path = saver.save(sess, "/tmp/model.ckpt")
print 'model saved in file:', save_path
sess.close()
#加载模型
print '\nrestore models\n'
with tf.Session() as sess:
  saver.restore(sess, "/tmp/model.ckpt")
print 'model is restored'
#重新保存,此时检查点文件用字典明显定义变量名称
print '\nsave with a new name for one/more vars...\n'
saver = tf.train.Saver({"new_weight":weight, "new_bias":bias}) #use dict
sess = tf.Session()
sess.run(init_op)
sess.run(weight)
sess.run(bias)
save_path = saver.save(sess, "/tmp/model2.ckpt")
print 'model is re-saved in file:', save_path
#重新加载
print '\nrestore models again\n'
sess = tf.Session() 
saver.restore(sess, "/tmp/model2.ckpt")
print 'model is restored, then output the vars:\n'
print 'new_weight', sess.run(new_weight)
print 'bias', sess.run(new_bias)
sess.close()

print 'total time:', time.clock()-start
有两个问题有待解决:
1.为什么保存模型时换其他路径不行?这里面有什么讲究吗?
2.字典传入新名字时,为什么重新加载模型时新名字显示undefined?
to be continued..
0
0

查看评论
* 以上用户言论只代表其个人观点,不代表CSDN网站的观点或立场
    个人资料
    • 访问:1861次
    • 积分:106
    • 等级:
    • 排名:千里之外
    • 原创:6篇
    • 转载:0篇
    • 译文:0篇
    • 评论:2条
    文章存档
    最新评论