[深度学习-实践]tensorflow_hub简单理解模型的生成与加载

0. 前言

Tensorflow于1.7之后推出了tensorflow hub,其是一个适合于迁移学习的部分,主要通过将tensorflow的训练好的模型进行模块划分,并可以再次加以利用。不过介于推出不久,目前只有图像的分类和文本的分类以及少量其他模型
这里先通过几个简单的例子,来展示该hub的使用流程。

1. 一个超简单例子

1.1 创建一个Module

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf
import tensorflow_hub as hub


def half_plus_two():
  '''该函数主要是创建一个简单的模型,其网络结构就是y = a*x + b  '''
  # 创建两个变量,a和b,如网络的权重和偏置
  a = tf.get_variable('a', shape=[])
  b = tf.get_variable('b', shape=[])
  # 创建一个占位变量,为后面graph的输入提供准备
  x = tf.placeholder(tf.float32)
  # 创建一个完整的graph
  y = a*x + b
  # 通过hub的add_signature,建立hub需要的网络
  hub.add_signature('function1',inputs=x, outputs=y)

  y = a * x
  hub.add_signature('function2', inputs=x, outputs=y)


def export_module(path):
  '''该函数用于调用创建api进行module创建,然后进行网络的权重赋值,最后通过session进行运行权重初始化,并最后输出该module'''
  # 通过hub的create_module_spec,接收函数建立一个Module
  spec = hub.create_module_spec(half_plus_two)
  # 防止串graph,将当期的操作放入同一个graph中
  with tf.Graph().as_default():
    # 通过hub的Module读取一个模块,该模块可以是url链接,表示从tensorflow hub去拉取,
    # 或者接收上述创建好的module
    module = hub.Module(spec)
    # 这里演示如何将权重值赋予到graph中的变量,如从checkpoint中进行变量恢复等
    init_a = tf.assign(module.variable_map['a'], 0.5)
    init_b = tf.assign(module.variable_map['b'], 2.0)
    init_vars = tf.group([init_a, init_b])

    with tf.Session() as sess:
      # 运行初始化,为了将其中变量的值设置为赋予的值
      sess.run(init_vars)
      # 将模型导出到指定路径
      module.export(path,sess)


if __name__ == '__main__':
  export_module("./module")

运行上述代码,可得
在这里插入图片描述
可以看出,该例子中,生成一个Module是

  • 1 - 先通过自定义网络,然后通过 hub.add_signature(inputs=x, outputs=y) 进行类似注册的操作
  • 2 - 再通过hub.create_module_spec(half_plus_two)进行生成ModuleSpec对象
  • 3 - 创建一个独立的tf.Graph(),通过module =
    hub.Module(spec)进行装载该Module,然后进行权重赋值,初始化等操作
  • 4 - 最后通过module.export(path,sess)导出该Module

1.2 调用一个存在的Module

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf
import tensorflow_hub as hub


def testExportTool1(self):

  # 指定module的文件夹位置,这里是export
  module_path = os.path.join('.','module')

  with tf.Graph().as_default():
    # 读取当前存在的一个module
    m = hub.Module(module_path)
    print('signature',m.get_signature_names())
    # 如直接采用y=f(x) 一样进行调用,
    output1=  m([10,3,4], signature='function1', as_dict=True)
    output2 = m([10, 3, 4], signature='function2')

    with tf.Session() as sess:
      # 惯例进行全局变量初始化
      sess.run(tf.initializers.global_variables())
      # 观察生成的值是否与预定义值一致,即prediction是否与label一致
      print(sess.run(output1)['default'])
      print(sess.run(output2))
      self.assertAllEqual(sess.run(output1)['default'], [7, 3.5, 4])
      self.assertAllEqual(sess.run(output2), [5, 1.5, 2])


if __name__ == '__main__':
  testExportTool1()

对于调用来说,就十分简单了

  • 1 - 创建一个tf.Graph(),然后通过m = hub.Module(module_path)进行装载已存在的Module
  • 2 - 如y=f(x)一样进行调用
  • 3 - sess.run一下即可。
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

茫茫人海一粒沙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值