TensorFlow 模型的运行机制

了解 TensorFlow 模型的运行机制

tensorflow 的运行机制属于“定义”于“运行”相互分离。从操作层面可以认为是 模型构建 和 模型运行。

在模型构件中,需要了解几个概念:

张量(tensor):数据,N维向量

变量(variable):模型参数,通过不断训练得到的值

占位符(placeholder):输入变量的载体

图中的节点操作(operation,op):执行相关计算,从而获取新的 tensor

上述定义的内容都是在一个叫做“图”的容器中完成的。关于“图”,有以下几点需要注意理解:

  • 一个图代表一个计算任务
  • 在模型运行的环节,“图”在会话(session)里被启动
  • session 将图的 op 分发到如 CPU/GPU 等计算设备上,同时提供执行 op 的方法。在这些方法执行后,返回相应的 tensor (python 里返回的是 numpy ndarray;C/C++ 里返回的是 TensorFlow::Tensor 实例)

session 与图交互的过程成通过 feed (注入机制)和 fetch (取回机制)来进行数据流动。i.e.

  • feed:通过占位符向模式中传入数据。
  • fetch:从模式中得到结果
import tensorflow as tf
mul = tf.multiply(a,b)
add = tf.add(a,b)
with tf.Session() as sess:
    print(sess.run(add, feed_dict={a:3, b:4}))
    print(sess.run(mul, feed_dict={a:3, b:4}))
    print(sess.run([mul, add], feed_dict={a:3, b:4}))

保存和载入模型方法

  • 模型保存
    首先建立一个 saver,然后在 session 中通过 saver 的 save 即可将模型保存起来

    # 构建模型 graph 的代码(矩阵相乘,relu等)
    saver = tf.train.Saver()
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      # 将数据丢入模型进行训练 。。。
      # 训练完成后,使用 saver.save 进行保存
      saver.save(sess, 'save_path/file_name') # file_name 会自动创建
    
  • 模型载入
    在 session 中通过调用 saver 的 restore() 函数,会从指定的路径找到模型文件,并覆盖到相关的参数中。

    saver = tf.train.Saver()
    with tf.Session() as sess:
      # 参数可以进行初始化,也可以不进行初始化。反正初始化了还会被载入的模型覆盖
      sess.run(tf.global_variables_initializer())
      saver.restore(sess, 'save_path/file_name')
    
  • 通过 print_tensor_in_checkpoint_file() 函数可以将保存模型里的内容打印出来,进行具体的查看。

  • 前面的例子中 Saver 的创建比较简单。其实 tf.train.Saver 函数里面还可以放参数来实现更高级的功能,可以指定存储变量名字与变量的对应关系。可以写成这样

    saver = tf.train.Saver({'weight': W, 'bias': b})
    

    类似的写法还有两种:

    saver = tf.train.Saver([W, b]) # 放到一个 list 里
    saver = tf.train.Saver({v.op.name: v for v in [W, b]}) # 将 op的名字当作 key
    
  • 检查点(checkpoint)
    保存模型并不限于在训练之后,在训练之中也需要保存,因为 tensorflow 训练模型是难免会出现中断的情况。我们自然希望将辛苦得到的中间参数保留下来,否则下次又要重新开始。

    这种在训练中保存模型,习惯上称之为保存检查点。

    # saver 的另一个参数——max_to_keep = 1, 表明最多只保存一个检查点文件
    saver = tf.train.Saver(max_to_keep = 1)
    # balabala
    saver.save(sess, savedir + 'file_name.cpke', global_step = epoch)
    saver.restore(sess2, savedir + 'file_name.ckpt' + str(load_epoch)
    
  • 更简便的保存检查点

    tf.train.MonitoredTrainingSession() 函数。该函数可以直接实现保存及载入检查点模型的文件。与前面方式不同,本例中并不是按照循环步数来保存,而是按照训练时间来保存的。通过指定 save_checkpoint_secs 参数的具体秒数,来设置每训练多久保存一次检查点。

    import tensorflow as tf
    tf.reset_default_graph()
    global_step = tf.train.get_or_create_global_step()
    step = tf.assign_add(global_step, 1)
    # 设置检查点路径维 log/checkpoints
    with tf.train.MonitoredTrainingSession(
      checkpoint_dir='log/checkpoints', save_checkpoint_secs=2) as sess:
      print(sess.run([global_step]))
      while not sess.should_stop():
        i = sess.run(step)
        print(i)
    

TensorBoard 可视化介绍

TensorFlow 还提供了一个可视化工具 TensorBoard。它可以将训练过程中的各种绘制数据展示出来。可以通过网页来观察模型的结构和训练过程中各个参数的变化。

当然,tensorboard 不会自动吧代码展示出来,其实它是一个日志展示系统,需要在 session 中运算图时,将各种类型的数据汇总并输出到日志文件中。然后启动 tensorboard 服务,tensorboard 读取这些日志文件,并开启 6006 端口提供 web 服务,让用户可以在浏览器中查看数据。

TensorFlow 提供了一系列 API 来生成这些数据。

i.e. tf.summary.scalar(tags, values, collections=None, name=None)

Class Summary Writer: add_summary(), add sessionlog(), add_event(), and add_graph()

tf.summary.scalar('loss_function', cost)
tf.summary.histogram('z',z)
# balabala  启动 session
with tf.Session() as sess:
  sess.run(init)
  merged_summary_op = tf.summary.merge_all() # 合并所有的 summary
  # 创建 summary_writer, 用于写入文件
  summary_writer = tf.summary.FileWriter('log/mnist_with_summaries', sess.graph)
  # 训练模型 balabala
  # 生成 summary
  summary_str = sess.run(merged_summary_op, feed_dict={X:x, Y:y})
  summary_writer.add_summary(summary_str, epoch) # 将 summary 写入文件
  ....

TensorFlow 基础类型定义及操作函数介绍

推荐在编译器里导入 tensorflow 后,通过 help() 函数进行查看

import tensorflow as tf
# 想要查看一个 tensorflow 操作函数 
help(tf.function_name)
# ex. help(tf.segment_sum)

共享变量

共享变量的用途:在某种请款下,一个模型需要使用其他模型创建的变量,两个模型一起训练。比如,对抗网络中的生成器模型与判别器模型。如果使用 tf.Variable,将会生成一个新的变量,而我们需要的是原来那个 bias 变量。这时怎么办呢?

这时,通过引入 get_variables 方法,利用共享变量来解决这个问题。这个方法可以使用多套网络模型来训练一套权重。

使用 get_variable 获取变量

get_variable 一般会配合 variable_scope 一起使用,实现共享变量。variable_scope 表示变量作用域。在某一作用域中的变量可以被设置成共享的方式,被其他网络模型使用。

get_variable() 函数的定义如下:

tf.get_variable(<name>, <shape>, <initializer>)

图的基本操作

前面接触了一些图的基本概念,这里来系统的了解一下 tensorflow 中的图可以做的事情。

可以在一个 tensorflow 中手动建立其他的图,也可以根据图里的变量获取当前的图。

动态图(Eager)

动态图是相对静态图而言的。所谓的动态图是指在 python 中的代码被调用后,其操作立即被执行的计算。其与静态图最大的区别是不需要使用 session 来建立会话了。即,在静态图中,需要在会话中调用 run 方法才可以获得某个张量的具体值;而在动态图里,直接运行就可以得到具体值了。

动态图是在 tensorflow 1.3 版本之后才出现的。它使得 tensorflow 入门更加简单,也是研发更加直观。

在动态图的创建过程中,也是默认建立了一个 session。所有的代码都在该 session 中运行,而且该 session 具有进程相同的生命周期。这表明,一旦使用动态图就无法实现静态图中关闭 session 的功能,也无法实现多 session 的操作。

而在动态图中,如果调用 tf.matmul 时,将会立即计算两个矩阵的相乘的值,而不是一个 op。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值