深度学习中基于tensorflow_slim进行复杂模型训练二之tensorflow_slim的使用

上篇博客主要介绍了tensorflow_slim的基本模块,本篇主要介绍一下如何使用该模块训练自己的模型。主要分为数据转化,数据读取,数据预处理,模型选择,训练参数设定,构建pb文件,固化pb文件中的参数几部分。

一、数据转化:

主要目的是将图片转化为TFrecords文件,该部分属于数据的预处理阶段,可以参考datasets中的download_and_conver_flower中的run函数实现。具体关于如何使用将会在后续介绍。

二、数据读取

该部分主要是在datasets中新建一个文件并将其命名为自己的名字,例如命名为emotion.py,然后将flowers.py中的内容复制到新建的文件中,并对以下部分进行修改:

1. _FILE_PATTERN = 'emotion_%s_*.tfrecord' 表示tfrecord文件名的格式

2. SPLITS_TO_SIZES = {'train': 18534, 'validation': 8331}表示用于训练和测试的数据个数

3. _NUM_CLASSES = 5,训练数据的类数,涉及到网络模型最后一层的输出。

最后需要在dataset_factory中增加自己新建的数据映射。

datasets_map = {
    'emotion': emotion,
}

三、数据增强

该过程主要是对读取的数据进行数据增强,可以有两种方式:1. 采用现有的增强模式(因为数据增强的大部分操作都是一样的),2. 构建自己的增强方式(可以使模型训练的时候传入的参数较统一)。

对于第二种方式依然需要构建新的文件夹,然后复制一个内容进行修改或者完全自己书写。本次采用的是复制cifarnet_preprocessing.py的内容进行修改得到的。具体修改的地方如下:将

distorted_image = tf.random_crop(image, [output_height, output_width, 3])    改为

distorted_image = tf.image.resize_images(image, [output_height, output_width], method=1)

主要是为了避免需要的图片比还未裁剪的小导致无法进行裁剪的错误

然后在preprocessing.py中增加新的映射:
  preprocessing_fn_map = {
      'emotion': emotion_preprocessing,
  }
  
四、模型选择

在nets中选择出自己需要使用的模型,并下载对应训练好的模型.ckpt文件,具体的下载地址可以参考README.md文件(以inception_v3模型为例)[inception_v3_2016_08_28.tar.gz](http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) 

五、训练参数设定

准备工作完成后就是使用train_image_classifier.py对自己的数据进行训练,该部分涉及到较多的参数,具体设定如下:

1. tf.app.flags.DEFINE_string( 'dataset_name', 'emotion', 'The name of the dataset to load.')
    
表示数据的名字,在读取数据和数据增强的时候该值相当于是map中的key,根据该值找到对应的读取和增强的脚本。

tf.app.flags.DEFINE_string('dataset_split_name', 'train', 'The name of the train/test split.')

表示数据的作用,用来train还是validation,该值主要是产生emotion_train*.tfrecord的形式存于data_sources中,便于在后面的读取数据时使用

  if '*' in data_sources or '?' in data_sources or '[' in data_sources:  data_files = gfile.Glob(data_sources)

的方式对数据进行读取。
      

tf.app.flags.DEFINE_string( 'dataset_dir', "data_to_tfrecord", 'The directory where the dataset files are stored.')

表示tfrecord数据的存储路径,构建data结构读取数据时使用
      
tf.app.flags.DEFINE_integer( 'labels_offset',  0,  'An offset for the labels in the dataset. This flag is primarily used to  evaluate the VGG and ResNet architectures which do not use a background  class for the ImageNet dataset.') 表示标签的偏移量,即默认标签是从0开始,假如偏移2,那么标签就会从2开始,一般情况下选择默认的值即可。


tf.app.flags.DEFINE_string('model_name', 'inception_v3', 'The name of the architecture to train.')用于进行二次训练的模型名字,主要依赖于net_factory中的map是如何写的。

tf.app.flags.DEFINE_string( 'preprocessing_name', 'emotion', 'The name of the preprocessing to use. If left  as `None`, then the model_name flag is used.')  表示采用的预处理方式,主要依赖于preprocessing_factory中的map
                                     
tf.app.flags.DEFINE_integer('batch_size', 32, 'The number of samples in each batch.') 在将数据进行批处理时每批数据的多少
    
tf.app.flags.DEFINE_integer( 'train_image_size', 299, 'Train image size') 模型输入的图片大小

tf.app.flags.DEFINE_integer('max_number_of_steps', 50000, 'The maximum number of training steps.')   表示训练的步数
                            
tf.app.flags.DEFINE_integer( 'log_every_n_steps', 10,  'The frequency with which logs are print.') log的输出频率,即每运行多少步输出一个log

tf.app.flags.DEFINE_integer( 'save_summaries_secs', 100,  'The frequency with which summaries are saved, in seconds.')    
表示存储summaries的频率

tf.app.flags.DEFINE_integer('save_interval_secs', 600, 'The frequency with which the model is saved, in seconds.')  表示存储模型的频率

tf.app.flags.DEFINE_float( 'weight_decay', 0.00004, 'The weight decay on the model weights.') 表示为了避免过拟合采用正则化的系数


tf.app.flags.DEFINE_string( 'train_dir', 'train_result', 'Directory where checkpoints and event logs are written to.')表示训练参数存储的地方    
    

tf.app.flags.DEFINE_string(  'checkpoint_path', "pre_trained_check/inception_v3_2016_08_28/inception_v3.ckpt", 'The path to a checkpoint from which to fine-tune.')  表示提前处理好的模型参数存储的地方
    

tf.app.flags.DEFINE_string('checkpoint_exclude_scopes', "InceptionV3/Logits,InceptionV3/AuxLogits",'Comma-separated list of scopes of variables to exclude when restoring  from a checkpoint.')模型中不用恢复的节点,一般均为模型的输出层,因为输出层需要结合自己实际的类进行训练确定数据的输出大小,当该值为空时,则表示所有的变量均恢复。


tf.app.flags.DEFINE_string('trainable_scopes', None, 'Comma-separated list of scopes to filter the set of variables to train. By default, None would train all the variables.')  表示再次训练的节点,None表示所有的都参与训练。
    
tf.app.flags.DEFINE_string( 'learning_rate_decay_type',  'exponential',  'Specifies how the learning rate is decayed. One of "fixed", "exponential", or "polynomial"') 表示学习率衰减的方式。


对于该模块在使用中涉及到的其他参数均使用默认的即可。

在使用脚本时有时候会报出部分操作无法在GPU上运行的错误,此时train的上面增加config = tf.ConfigProto(allow_soft_placement=True)表示当无法采用GPU计算时使用cpu进行。并将该参数传递给train。

五、构建pb文件

此时直接使用export_interence_graph.py可以将模型结构变成.pb的,涉及的参数如下:

tf.app.flags.DEFINE_string( 'model_name', 'inception_v3', 'The name of the architecture to save.') 表示要调用的模型结构

tf.app.flags.DEFINE_boolean( 'is_training', False, 'Whether to save out a training-focused version of the model.') 表示在模型中的参数是否用来进行训练

tf.app.flags.DEFINE_integer( 'image_size', 299, 'The image size to use, otherwise use the model default_image_size.')定义一个输入占位符的二三维大小

tf.app.flags.DEFINE_string('dataset_name', 'emotion',  'The name of the dataset to use with the model.') 主要根据传入的名字确定一个对应的数据集,确定其num_class的值用于构建模型结构
    

tf.app.flags.DEFINE_integer( 'labels_offset', 0, 'An offset for the labels in the dataset. This flag is primarily used to  evaluate the VGG and ResNet architectures which do not use a background  class for the ImageNet dataset.')偏移量,用来构建模型的时候会用到

tf.app.flags.DEFINE_string(  'output_file', 'train_pb/motion_inception_v3_graph.pb', 'Where to save the resulting file to.')输出的.pb文件名字和存储地方

tf.app.flags.DEFINE_integer( 'batch_size', None,'Batch size for the exported model. Defaulted to "None" so batch size can ')定义输入占位符的第一维度大小。


六、在模型结构中放入自己训练的结果并固化

该过程实现的原理是先读入一个结构图,然后在使用saver.restore()恢复图中对应
参数的值,最后再存储。

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.training import saver as saver_lib

# 定义一些参数
input_graph = 'train_pb\\test_1.pb'
output_graph = 'train_pb\\test_2.pb'
input_checkpoint = 'train_result\\model.ckpt-20'
output_node_names = 'InceptionV3/Predictions/Reshape_1'

detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(input_graph, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        var_list = {}
        reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in var_to_shape_map:
            try:
                tensor = sess.graph.get_tensor_by_name(key + ":0")
            except KeyError:
                continue
            var_list[key] = tensor
        saver = tf.train.Saver(var_list=var_list)
        saver.restore(sess, input_checkpoint)

        constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,
                                                                   output_node_names.split(","))
        with tf.gfile.FastGFile(output_graph, mode='wb') as f:
            f.write(constant_graph.SerializeToString())


至此,就完成了模型训练和固化,然后可以根据具体需要自行进行使用。

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
### 回答1: TensorFlow 2于2019年发布,相对于之前的版本,进行了重大的改进和升级。因此,在TensorFlow 2,一些旧的功能和包已经被删除或者合并到了其他的模块,其就包括了contrib_slimTensorFlow的contrib_slim包是一个非常常用的工具包,用于简化神经网络的构建和训练过程。它提供了许多高级的操作和函数,为用户提供了更高的灵活性和便利性。然而在TensorFlow 2,由于追求简化和统一的目标,TensorFlow团队决定将其删除。 在TensorFlow 2,取代contrib_slim的是tf.keras模块,它是一个基于Keras的高级API。tf.keras提供了与contrib_slim类似的功能,并且具有更好的兼容性和易用性。通过tf.keras,您可以更容易地构建和训练神经网络模型。它提供了丰富的层和模型类型,并且支持许多常见的深度学习任务,如图像分类、目标检测、自然语言处理等。 因此,在TensorFlow 2,如果您需要类似于contrib_slim的功能,可以使用tf.keras来代替。您可以使用tf.keras.layers构建网络层,使用tf.keras.models构建模型,以及使用tf.keras.optimizers、tf.keras.losses等来定义优化器和损失函数。通过这种方式,您可以轻松构建和训练自己的深度学习模型。 总的来说,TensorFlow 2取消了contrib_slim包,转而使用tf.keras模块来提供更好的API和功能。这是为了提供更简单、更统一和更高效的开发体验。 ### 回答2: TensorFlow 2 是 TensorFlow 的最新版本,与之前的 TensorFlow 1 有很大的不同。TensorFlow 2 根据用户反馈和需求进行了重构,删除了一些不常用的功能和模块,其就包括 contrib.slim 包。 contrib.slim 包是 TensorFlow 1 的一个非常有用的模块,提供了许多方便的函数和工具,用于构建和训练深度学习模型。但是,由于它的功能被整合到 TensorFlow 2 的核心模块,contrib.slim 包在 TensorFlow 2 被删除。 在 TensorFlow 2 ,大部分 contrib.slim的功能都可以通过其他的模块和函数来实现。例如,构建模型可以使用 Keras API,它提供了更简洁、易用的接口。此外,一些 contrib.slim的函数可以通过使用 TensorFlow 2 的其他核心函数来替代。 如果你在迁移你的代码或项目到 TensorFlow 2 时遇到 contrib.slim 包的问题,你可以参考 TensorFlow 2 的官方文档和示例代码,了解如何使用新的模块和函数来替代 contrib.slim的功能。此外,TensorFlow 社区也提供了许多迁移指南和教程,帮助用户迁移他们的代码到 TensorFlow 2。 总而言之,虽然 TensorFlow 2 没有 contrib.slim 包,但是通过使用其他的模块和函数,你仍然可以实现相同的功能,并享受 TensorFlow 2 带来的新特性和改进。 ### 回答3: tensorflow2版本没有contrib_slim包,这是因为从tensorflow1.x到tensorflow2.0的升级过程,一些模块被重新组织和重构了。contrib_slim是在tensorflow1.x版本引入的一个扩展模块,用于提供一些高级的模型定义和训练工具。在tensorflow2.0,它被废弃了,并且其功能已经被整合到其他模块。 在tensorflow2.0模型定义和训练工具主要集tensorflow.keras模块。Keras是一个高级神经网络API,它提供了更简洁和易于使用的接口来定义和训练模型。与contrib_slim类似的功能可以使用tensorflow.keras模块的各种类和函数来实现。 此外,在tensorflow2.0模型定义和训练的推荐方法是使用自定义模型子类化或函数式API来创建模型。这些方法提供了更大的灵活性和可拓展性,使得模型定义更易于阅读和维护。 总之,虽然tensorflow2版本没有contrib_slim包,但可以使用tensorflow.keras模块以及自定义模型子类化或函数式API来实现类似的功能。这些改进使得tensorflow2更易于使用和扩展,并提供了更一致的编程接口。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值