Using the ResNet model from the slim module

1. Introduction to the slim module

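slim is a library that makes it simple to build, train, and evaluate neural networks:
-- It removes much of the repetitive boilerplate found in native TensorFlow code, which makes code more compact and more readable.
-- slim also ships several well-known computer vision models (VGG, AlexNet, and others) that can be used directly or extended in various ways.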

2. Defining the model with the slim library

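with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
-- This applies the default layer arguments returned by resnet_arg_scope() (for example weight decay and batch-normalization settings) to the slim layers created inside the block.
-- For the settings made by resnet_arg_scope(), see: https://www.cnblogs.com/hellcat/p/8060236.html
-- For a detailed explanation of arg_scope, see: https://blog.csdn.net/u013921430/article/details/80915696
   net, endpoints = nets.resnet_v1.resnet_v1_50(preprocessed_inputs, num_classes=None, is_training=self._is_training)
   net = tf.squeeze(net, axis=[1, 2])
   -- tf.squeeze(): removes the dimensions listed in axis, each of which must have size 1. With num_classes=None, resnet_v1_50 returns the globally pooled features with shape [batch_size, 1, 1, 2048], so the result has shape [batch_size, 2048].
   logits = slim.fully_connected(net, num_outputs=self.num_classes, activation_fn=None, scope='Predict/logits')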

3. tf.estimator.train_and_evaluate

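Function signature: tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Parameters:
  -- estimator is a tf.estimator.Estimator object that specifies the model function and other related parameters;
  -- train_spec is a tf.estimator.TrainSpec object that specifies the training input function and other parameters;
  -- eval_spec is a tf.estimator.EvalSpec object that specifies the evaluation input function and other parameters.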
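
The three arguments can be wired together roughly as follows. This is a sketch that assumes TensorFlow 1.x is installed. The function name `run_training`, its parameters, and the `max_steps` default are hypothetical. TensorFlow is imported inside the function so that the snippet can be loaded on a machine without it.

```python
def run_training(model_fn, train_input_fn, eval_input_fn, model_dir,
                 max_steps=10000):
    """Wire an Estimator, a TrainSpec and an EvalSpec together (TF 1.x)."""
    import tensorflow as tf  # deferred import; requires TensorFlow 1.x

    # The estimator holds the model function and the checkpoint directory.
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)
    # The train spec holds the training input function and the stopping step.
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=max_steps)
    # steps=None evaluates until eval_input_fn runs out of data.
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=None)
    return tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
```

A call would look like `run_training(create_model_fn, train_input_fn, eval_input_fn, './training')`, where the two input functions and the directory are placeholders that you supply.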

4. tf.estimator.Estimator

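Function signature: tf.estimator.Estimator(model_fn, model_dir=None, config=None, params=None, warm_start_from=None)
Example call: tf.estimator.Estimator(model_fn=create_model_fn, model_dir=FLAGS.model_dir)
Parameters:
  -- model_fn is the model function;
  -- model_dir is the directory where the model is saved during training;
  -- config is a tf.estimator.RunConfig configuration object;
  -- params is a dictionary of hyperparameters that is passed to model_fn;
  -- warm_start_from is either the path to a pretrained checkpoint or a tf.estimator.WarmStartSettings object that fully configures warm starting.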
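
As an illustration of the optional arguments, the sketch below builds an Estimator with a RunConfig, a params dictionary, and optional warm starting. It assumes TensorFlow 1.x. The function name `build_estimator`, the checkpoint intervals, the `num_classes` key, and the `'resnet_v1_50.*'` variable pattern are assumptions chosen for this example, not values taken from the original code.

```python
def build_estimator(model_fn, model_dir, pretrained_ckpt=None, num_classes=2):
    """Build an Estimator with config, params and optional warm start (TF 1.x)."""
    import tensorflow as tf  # deferred import; requires TensorFlow 1.x

    # How often checkpoints are written and how many are kept (assumed values).
    config = tf.estimator.RunConfig(save_checkpoints_steps=1000,
                                    keep_checkpoint_max=5)

    warm_start = None
    if pretrained_ckpt is not None:
        # Restore only variables whose names match the pattern (an assumption:
        # the backbone variables live under the 'resnet_v1_50' scope).
        warm_start = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=pretrained_ckpt,
            vars_to_warm_start='resnet_v1_50.*')

    return tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        config=config,
        params={'num_classes': num_classes},  # reaches model_fn as `params`
        warm_start_from=warm_start)
```

Restricting the warm start to the backbone is one possible design choice: the new 'Predict/logits' layer has no counterpart in an ImageNet checkpoint, so it is left to its fresh initialization.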

5. eval_spec: tf.estimator.EvalSpec

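Function signature: tf.estimator.EvalSpec(input_fn, steps=100, name=None, hooks=None, exporters=None, start_delay_secs=120, throttle_secs=600)
Parameters:
  -- input_fn provides the input data for evaluation;
  -- steps is the number of steps to evaluate for (setting it to None, which evaluates until the input is exhausted, is usually fine);
  -- hooks is an iterable of tf.train.SessionRunHook objects to run during evaluation;
  -- exporters is an Exporter or an iterable of Exporter objects, invoked after each evaluation;
  -- start_delay_secs is the number of seconds to wait before evaluation starts;
  -- throttle_secs is the number of seconds after which a new round of evaluation may start (if no new checkpoint has been saved by then, no evaluation runs at that point, so this is the minimum number of seconds to wait before a new round of evaluation).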
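
The "minimum wait" behaviour of throttle_secs can be illustrated with a small pure-Python model. This is a deliberately simplified sketch, not TensorFlow's scheduling code: it assumes an evaluation can start only when a new checkpoint appears, it ignores start_delay_secs and the time an evaluation takes, and the function name `evaluations_triggered` is hypothetical.

```python
def evaluations_triggered(checkpoint_times, throttle_secs):
    """Return the checkpoint times (in seconds) at which an evaluation starts.

    Simplified model: an evaluation starts at a checkpoint only if at least
    throttle_secs have passed since the previous evaluation started.
    """
    started = []
    last_start = None
    for t in sorted(checkpoint_times):
        if last_start is None or t - last_start >= throttle_secs:
            started.append(t)
            last_start = t
    return started


# Checkpoints every 300 s with throttle_secs=600: every other one is evaluated.
frequent = evaluations_triggered([300, 600, 900, 1200, 1500], throttle_secs=600)
# Checkpoints every 900 s: each one is evaluated, so the gap is 900 s, not 600 s.
sparse = evaluations_triggered([900, 1800, 2700], throttle_secs=600)
print(frequent, sparse)
```

The second case shows why throttle_secs is only a lower bound: when checkpoints are saved less often than every throttle_secs seconds, the checkpoint interval determines how often evaluation runs.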

6. Defining model_fn, the function that returns an EstimatorSpec instance for the Estimator

The function name and its internal structure are up to you, as long as the function returns an instance of the class tf.estimator.EstimatorSpec.
For example:
def create_model_fn(features, labels, mode, params=None):
    params = params or {}
    # Outputs that are only needed in some modes start as None.
    loss, train_op = None, None
    prediction_dict = ...
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        loss = ...
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = ...
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=prediction_dict,
        loss=loss,
        train_op=train_op,
        # ... other optional arguments
    )

7. Full code

# -*- coding: utf-8 -*-
"""
Created on Thu Oct 11 17:21:12 2018

@author: shirhe-lyh
"""

import tensorflow as tf

from tensorflow.contrib.slim import nets

import preprocessing

slim = tf.contrib.slim
    
        
class Model(object):
    """ResNet-50 image classification model definition."""
    
    def __init__(self, is_training,
                 num_classes=2,
                 fixed_resize_side=256,
                 default_image_size=224):
        """Constructor.
        
        Args:
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            num_classes: Number of classes.
        """
        self._num_classes = num_classes
        self._is_training = is_training
        self._fixed_resize_side = fixed_resize_side
        self._default_image_size = default_image_size
        
    @property
    def num_classes(self):
        return self._num_classes
        
    def preprocess(self, inputs):
        """Preprocess a batch of input images.
        
        Outputs of this function can be passed to the predict function.
        
        Args:
            inputs: A tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.
            
        Returns:
            preprocessed_inputs: A float32 tensor holding the preprocessed
                images.
        """
        preprocessed_inputs = preprocessing.preprocess_images(
            inputs, self._default_image_size, self._default_image_size, 
            resize_side_min=self._fixed_resize_side,
            is_training=self._is_training,
            border_expand=False, normalize=False,
            preserving_aspect_ratio_resize=False)
        preprocessed_inputs = tf.cast(preprocessed_inputs, tf.float32)
        return preprocessed_inputs
    
    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.
        
        Outputs of this function can be passed to loss or postprocess functions.
        
        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
            
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
            net, endpoints = nets.resnet_v1.resnet_v1_50(
                preprocessed_inputs, num_classes=None,
                is_training=self._is_training)
        net = tf.squeeze(net, axis=[1, 2])
        logits = slim.fully_connected(net, num_outputs=self.num_classes,
                                      activation_fn=None, 
                                      scope='Predict/logits')
        return {'logits': logits}
    
    def postprocess(self, prediction_dict):
        """Convert predicted output tensors to final forms.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
                
        Returns:
            A dictionary containing the postprocessed results.
        """
        postprocessed_dict = {}
        for logits_name, logits in prediction_dict.items():
            logits = tf.nn.softmax(logits)
            classes = tf.argmax(logits, axis=1)
            classes_name = logits_name.replace('logits', 'classes')
            postprocessed_dict[logits_name] = logits
            postprocessed_dict[classes_name] = classes
        return postprocessed_dict
    
    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.
        
        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: A list of tensors holding groundtruth
                information, with one entry for each branch prediction.
                
        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        logits = prediction_dict.get('logits')
        slim.losses.sparse_softmax_cross_entropy(logits, groundtruth_lists)
        loss = slim.losses.get_total_loss()
        loss_dict = {'loss': loss}
        return loss_dict
        
    def accuracy(self, postprocessed_dict, groundtruth_lists):
        """Calculate accuracy.
        
        Args:
            postprocessed_dict: A dictionary containing the postprocessed 
                results
            groundtruth_lists: A dict of tensors holding groundtruth
                information, with one entry for each image in the batch.
                
        Returns:
            accuracy: The scalar accuracy.
        """
        classes = postprocessed_dict['classes']
        accuracy = tf.reduce_mean(
            tf.cast(tf.equal(classes, groundtruth_lists), dtype=tf.float32))
        return accuracy
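
To see what postprocess, loss, and accuracy compute, here is a small pure-Python worked example on hand-picked numbers. It is a sketch of the arithmetic only: the logits and labels are made up, the helper names are hypothetical, and the loss shown is the mean cross-entropy term alone. In the class above, get_total_loss() also adds the regularization losses, such as the weight decay registered through resnet_arg_scope().

```python
import math


def softmax(logits):
    """Turn one row of logits into probabilities (shifted for stability)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value, like tf.argmax along the class axis."""
    return max(range(len(values)), key=lambda i: values[i])


def sparse_softmax_cross_entropy(batch_logits, labels):
    """Mean of -log(probability assigned to the true class) over the batch."""
    losses = [-math.log(softmax(row)[label])
              for row, label in zip(batch_logits, labels)]
    return sum(losses) / float(len(losses))


def accuracy(batch_logits, labels):
    """Fraction of examples whose most probable class equals the label."""
    predicted = [argmax(softmax(row)) for row in batch_logits]
    correct = sum(1 for p, y in zip(predicted, labels) if p == y)
    return correct / float(len(labels))


# Four examples, two classes (num_classes=2 as in the Model default).
batch_logits = [[2.0, 0.0], [0.0, 1.0], [1.5, 0.5], [0.0, 3.0]]
labels = [0, 1, 1, 1]

print(accuracy(batch_logits, labels))
print(sparse_softmax_cross_entropy(batch_logits, labels))
```

The third example is misclassified (class 0 has the larger logit while the label is 1), so three of the four predictions are correct, and that example contributes the largest share of the loss.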

 
