tensorflow编程基础

模型构建中的几个概念。张量:数据,即某一类型的多维数组。变量:常用于定义模型中的参数,是通过不断训练得到的值。占位符:输入变量的载体,也可以理解为定义函数时的参数。图中的节点操作(OP):即一个OP获得0个或者多个tensor,执行计算输出额外的0个或多个tensor。在python中,返回的tensor是numpy.ndarray对象。

在具体的项目中,会有三种应用场景,分别是训练场景、测试场景和使用场景。训练场景是实现模型从无到有的过程,通过对样本的学习训练,调整学习参数,形成最终的模型。测试场景和使用场景:测试场景是利用图的正向运算得到的结果与真实值进行比较的差别;使用场景也是利用图的正向运算得到结果,并直接使用。二者的运算过程是一样的。这个过程特别像普通编程中使用函数的过程:实参就是输入的样本,形参就是占位符,运算过程就相当于函数体,得到的结果相当于返回值。

session与图的交互过程中还定义了以下两种数据的流向机制。注入机制(feed):通过占位符向模式中传入数据;取回机制(fetch):从模式中得到结果。

演示注入机制:需要注意的是,feed只在调用它的方法内有效,方法结束后feed就会消失。代码如下所示:

import tensorflow as tf
a=tf.placeholder(tf.int16)
b=tf.placeholder(tf.int16)
add=tf.add(a,b)
mul=tf.multiply(a,b)
with tf.Session() as sess:
    print("相加:%i" % sess.run(add,feed_dict={a:3,b:4}))
    print('相乘:%i' % sess.run(mul,feed_dict={a:3,b:5}))

保存和载入模型

saver=tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ....
    saver.save(sess,'save_patch+filename')

需要注意的是,save_patch的路径要在session的创建之前,模型保存在代码的同级目录下。而载入模型则通过在session中调用saver的restore()函数,从指定的路径找到对应名称的模型。除了在训练结束以后,在训练中也可以保存模型。这样当训练模型出现中断时,可以得到保存到的中间参数,习惯称之为保存检查点。

import tensorflow as tf
import numpy as np
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2'
train_X=np.linspace(-1,1,100)
train_Y=2*train_X+np.random.randn(100)*0.3
X=tf.placeholder('float')
Y=tf.placeholder('float')
W=tf.Variable(tf.random_normal([1]),name='weight')
b=tf.Variable(tf.zeros([1]),name='bias')
z=tf.multiply(X,W)+b
cost=tf.reduce_mean(tf.square(z-Y))
learning_rate=0.01
optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
initial=tf.global_variables_initializer()
training_epoch=20
display_step=2
saver=tf.train.Saver(max_to_keep=1)
savedir='log/'
with tf.Session() as sess:     
    sess.run(initial)
    for epoch in range(training_epoch):
        for (x,y) in zip(train_X,train_Y):
            sess.run(optimizer,feed_dict={X:x,Y:y})
        if epoch % display_step ==0:
            loss=sess.run(cost,feed_dict={X:x,Y:y})
            print('epoch:',epoch+1,'loss:',loss,'W:',sess.run(W),'b:',sess.run(b))
            saver.save(sess,savedir+'linermodule.cpkt',global_step=epoch)
    print('finished')
    print('loss:',sess.run(cost,feed_dict={X:x,Y:y}),'w:',sess.run(W),'b:',sess.run(b))
load_epoch=18
with tf.Session() as sess2:
    sess2.run(tf.global_variables_initializer())
    saver.restore(sess2,savedir+'linermodule.cpkt-'+str(load_epoch))
    print('x=0.2,z=',sess2.run(z,feed_dict={X:0.2}))

保存的检查点文件如下所示:

因为设置max_to_keep=1,所以在迭代的过程中只保存一个文件。在训练的过程中,新生成的模型会覆盖以前的模型。

运行结果如下图所示:

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值