tensorflow自定义网络模型

Slim

TF-Slim 模块是 TensorFlow 中最好用的 API 之一。尤其是里面引入的 arg_scope、model_variables、repeat、stack。
TF-Slim 是 TensorFlow 中一个用来构建、训练、评估复杂模型的轻量化库。TF-Slim 模块可以和 TensorFlow 中其它API混合使用。

Slim模块的导入

1
import tensorflow.contrib.slim as slim

Slim 构建模型

可以用 slim、variables、layers 和 scopes 来十分简洁地定义模型。下面对各个部分进行了详细描述:

Slim变量(Variables)
1
2
3
4
5
6
weights = slim.variable('weights',
                        shape=[10, 10, 3 , 3],
                        initializer=tf.truncated_normal_initializer(stddev=0.1),
                        regularizer=slim.l2_regularizer(0.05),
                        device='/CPU:0')
~
Slim 层(Layers)

使用基础(plain)的 TensorFlow 代码:

1
2
3
4
5
6
7
8
9
input = ...
with tf.name_scope('conv1_1') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 128], dtype=tf.float32,
                                           stddev=1e-1), name='weights')
  conv = tf.nn.conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')
  biases = tf.Variable(tf.constant(0.0, shape=[128], dtype=tf.float32),
                       trainable=True, name='biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name=scope)

为了避免代码的重复。Slim 提供了很多方便的神经网络 layers 的高层 op。例如:与上面的代码对应的 Slim 版的代码:

1
2
input = ...
net = slim.conv2d(input, 128, [3, 3], scope='conv1_1')

slim.arg_scope() 函数的使用

这个函数的作用是给list_ops中的内容设置默认值。但是每个list_ops中的每个成员需要用@add_arg_scope修饰才行。所以使用slim.arg_scope()有两个步骤:

  • 使用@slim.add_arg_scope修饰目标函数
  • 用 slim.arg_scope()为目标函数设置默认参数.
    例如如下代码;首先用@slim.add_arg_scope修饰目标函数fun1(),然后利用slim.arg_scope()为它设置默认参数。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    
    import tensorflow as tf
    slim =tf.contrib.slim
     
    @slim.add_arg_scope
    def fun1(a=0,b=0):
        return (a+b)
     
    with slim.arg_scope([fun1],a=10):
        x=fun1(b=30)
        print(x)
    

运行结果:40
参考链接:
https://blog.csdn.net/u013921430/article/details/80915696

其他用法见参考链接

https://blog.csdn.net/wanttifa/article/details/90208398

查看ckpt中变量的几种方法

查看ckpt中变量的方法有三种:

  • 在有model的情况下,使用tf.train.Saver进行restore
  • 使用tf.train.NewCheckpointReader直接读取ckpt文件,这种方法不需要model。
  • 使用tools里的freeze_graph来读取ckpt
    Tips:
  • 如果模型保存为.ckpt的文件,则使用该文件就可以查看.ckpt文件里的变量。ckpt路径为 model.ckpt
  • 如果模型保存为.ckpt-xxx-data (图结构)、.ckpt-xxx.index (参数名)、.ckpt-xxx-meta (参数值)文件,则需要同时拥有这三个文件才行。并且ckpt的路径为 model.ckpt-xxx

    1.基于model来读取ckpt文件里的变量

    1.首先建立起model
    2.从ckpt中恢复变量
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    
    with tf.Graph().as_default() as g: 
      #建立model
      images, labels = cifar10.inputs(eval_data=eval_data) 
      logits = cifar10.inference(images) 
      top_k_op = tf.nn.in_top_k(logits, labels, 1) 
      #从ckpt中恢复变量
      sess = tf.Session()
      saver = tf.train.Saver() #saver = tf.train.Saver(...variables...) # 恢复部分变量时,只需要在Saver里指定要恢复的变量
      save_path = 'ckpt的路径'
      saver.restore(sess, save_path) # 从ckpt中恢复变量
    

注意:基于model来读取ckpt中变量时,model和ckpt必须匹配。

2.使用tf.train.NewCheckpointReader直接读取ckpt文件里的变量,使用tools.inspect_checkpoint里的print_tensors_in_checkpoint_file函数打印ckpt里的东西

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#使用NewCheckpointReader来读取ckpt里的变量
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) #tf.train.NewCheckpointReader
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
  print("tensor_name: ", key)
  #print(reader.get_tensor(key))
#使用print_tensors_in_checkpoint_file打印ckpt里的内容
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

print_tensors_in_checkpoint_file(file_name, #ckpt文件名字
                 tensor_name, # 如果为None,则默认为ckpt里的所有变量
                 all_tensors, # bool 是否打印所有的tensor,这里打印出的是tensor的值,一般不推荐这里设置为False
                 all_tensor_names) # bool 是否
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值