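1. Introduction to the slim module
slim is a library that simplifies building, training, and evaluating neural networks:
-- It removes much of the repetitive boilerplate that native TensorFlow requires, which makes code more compact and easier to read.
-- It also ships several well-known computer-vision models (VGG, AlexNet, and others) that can be used directly or extended in various ways.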
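2. Defining a model with slim
with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
-- Inside this block, the default arguments returned by resnet_arg_scope() (weight decay, batch-normalization settings, and so on) apply to every slim layer that nets.resnet_v1 creates.
-- For the settings made by resnet_arg_scope(), see: https://www.cnblogs.com/hellcat/p/8060236.html
-- For a detailed explanation of arg_scope, see: https://blog.csdn.net/u013921430/article/details/80915696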
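net, endpoints = nets.resnet_v1.resnet_v1_50(preprocessed_inputs, num_classes=None, is_training=self._is_training)
-- With num_classes=None the network has no classification layer and returns the globally pooled features, with shape [batch_size, 1, 1, 2048].
net = tf.squeeze(net, axis=[1, 2])
-- tf.squeeze(): removes the size-1 dimensions listed in axis (here giving shape [batch_size, 2048]); naming an axis whose size is not 1 is an error.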
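The shape rule can be checked with a small plain-Python sketch. squeezed_shape is a hypothetical helper that mimics the behavior on static shapes only; it is not part of TensorFlow.

```python
def squeezed_shape(shape, axis=None):
    """Return the shape that squeezing a tensor of static `shape` would give."""
    rank = len(shape)
    if axis is None:
        # No axis given: every size-1 dimension is dropped.
        return [dim for dim in shape if dim != 1]
    axes = {a % rank for a in axis}  # normalize negative axes
    for a in sorted(axes):
        if shape[a] != 1:
            raise ValueError(
                "cannot squeeze axis %d with size %d" % (a, shape[a]))
    return [dim for i, dim in enumerate(shape) if i not in axes]


pooled = squeezed_shape([8, 1, 1, 2048], axis=[1, 2])
single_image = squeezed_shape([1, 1, 1, 2048])  # axis omitted
```

Naming axis=[1, 2] explicitly matters when the batch holds a single image: without it the batch dimension would be squeezed away too, as single_image shows.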
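logits = slim.fully_connected(net, num_outputs=self.num_classes, activation_fn=None, scope='Predict/logits')
-- activation_fn=None leaves the output as raw logits (slim.fully_connected applies ReLU by default).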
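3. tf.estimator.train_and_evaluate
Function prototype: tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
Parameters:
-- estimator is a tf.estimator.Estimator object, which specifies the model function and other related parameters;
-- train_spec is a tf.estimator.TrainSpec object, which specifies the input function for training and other parameters;
-- eval_spec is a tf.estimator.EvalSpec object, which specifies the input function for evaluation and other parameters.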
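A minimal sketch of how the three arguments fit together, assuming TensorFlow 1.x is installed. run_train_and_evaluate and its parameters are hypothetical names, and the max_steps value is an arbitrary placeholder.

```python
def run_train_and_evaluate(model_fn, train_input_fn, eval_input_fn,
                           model_dir, max_steps=10000):
    """Wire an Estimator, a TrainSpec and an EvalSpec together and run them."""
    # Imported lazily so this file still loads where TensorFlow 1.x is absent.
    import tensorflow as tf

    # The Estimator owns the model function and the checkpoint directory.
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)
    # TrainSpec: where the training data comes from and when to stop.
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=max_steps)
    # EvalSpec: where the evaluation data comes from; steps=None evaluates
    # until eval_input_fn runs out of data.
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=None)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
```

Both input functions are expected to return the (features, labels) data that model_fn consumes.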
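4. tf.estimator.Estimator
Constructor: tf.estimator.Estimator(model_fn, model_dir=None, config=None, params=None, warm_start_from=None), for example tf.estimator.Estimator(model_fn=create_model_fn, model_dir=FLAGS.model_dir)
Parameters:
-- model_fn is the model function;
-- model_dir is the directory where the model is saved during training;
-- config is a tf.estimator.RunConfig configuration object;
-- params is a dictionary of hyperparameters that is passed to model_fn;
-- warm_start_from is either the path to a pretrained checkpoint or a tf.estimator.WarmStartSettings object, which gives full control over warm-starting.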
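How params reaches model_fn can be illustrated with a toy dispatcher in plain Python. call_model_fn imitates the rule that optional arguments are passed only when model_fn declares them; it is a hypothetical stand-in, not the Estimator's own code.

```python
import inspect


def call_model_fn(model_fn, features, labels, mode, params=None, config=None):
    """Call `model_fn`, passing only the optional arguments it declares."""
    declared = inspect.signature(model_fn).parameters
    kwargs = {}
    if "labels" in declared:
        kwargs["labels"] = labels
    if "mode" in declared:
        kwargs["mode"] = mode
    if "params" in declared:
        kwargs["params"] = params
    if "config" in declared:
        kwargs["config"] = config
    return model_fn(features=features, **kwargs)


def example_model_fn(features, labels, mode, params):
    # Hypothetical model function: just reports what it received.
    return {"mode": mode, "learning_rate": params["learning_rate"]}


def bare_model_fn(features, labels):
    # Declares neither mode nor params, so neither is passed in.
    return len(features)


result = call_model_fn(example_model_fn, [1.0], [0], "train",
                       params={"learning_rate": 0.01})
bare_result = call_model_fn(bare_model_fn, [1.0, 2.0], [0, 1], "train",
                            params={"learning_rate": 0.01})
```

The params dictionary given to the constructor is therefore the natural place for values such as the learning rate or the number of classes.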
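5. eval_spec: tf.estimator.EvalSpec
Function prototype: tf.estimator.EvalSpec(input_fn, steps=100, name=None, hooks=None, exporters=None, start_delay_secs=120, throttle_secs=600)
Parameters:
-- input_fn provides the input data for evaluation;
-- steps is the number of steps to evaluate for (setting it to None, which evaluates until input_fn is exhausted, is usually fine);
-- hooks is an iterable of tf.train.SessionRunHook objects to run during evaluation;
-- exporters is an iterable of Exporter objects, which take part in every evaluation;
-- start_delay_secs is the number of seconds to wait before the first evaluation;
-- throttle_secs is the number of seconds to wait before starting a new round of evaluation. If no new checkpoint has been saved, no evaluation runs even after that many seconds, so this is the minimum wait between rounds.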
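The timing rule described for start_delay_secs and throttle_secs can be written as a small decision function. This is a plain-Python sketch of the rule as stated above, not TensorFlow's scheduling code; should_evaluate and its arguments are hypothetical.

```python
def should_evaluate(seconds_since_start, seconds_since_last_eval,
                    has_new_checkpoint,
                    start_delay_secs=120, throttle_secs=600):
    """Decide whether a new round of evaluation may start now.

    seconds_since_last_eval is None when no evaluation has run yet.
    """
    if seconds_since_start < start_delay_secs:
        return False  # still inside the initial delay
    if (seconds_since_last_eval is not None
            and seconds_since_last_eval < throttle_secs):
        return False  # the previous round was too recent
    # Enough time has passed, but there must be a new checkpoint to evaluate.
    return has_new_checkpoint


too_early = should_evaluate(60, None, True)
first_round = should_evaluate(130, None, True)
throttled = should_evaluate(500, 300, True)
no_checkpoint = should_evaluate(2000, 900, False)
next_round = should_evaluate(2000, 900, True)
```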
6. Defining model_fn, the function that returns an EstimatorSpec instance
The function name and structure are up to you, as long as the function returns an instance of the class tf.estimator.EstimatorSpec.
For example (each ... is a placeholder to fill in):
def create_model_fn(features, labels, mode, params=None):
    params = params or {}
    loss, train_op, ... = None, None, ...
    prediction_dict = ...
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        loss = ...
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = ...
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=prediction_dict,
        loss=loss,
        train_op=train_op,
        ...)
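One possible way to fill in this template with the Model class from the full listing is sketched below. It assumes TensorFlow 1.x. The module name model, the 'image' feature key, the momentum optimizer, and the params keys are all assumptions made for illustration, not the author's choices.

```python
def create_model_fn(features, labels, mode, params=None):
    """Sketch of a model_fn built on the Model class from the full listing."""
    # Lazy imports so this file still loads where TensorFlow 1.x is absent.
    import tensorflow as tf
    from model import Model  # hypothetical module holding the Model class

    params = params or {}
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    model = Model(is_training=is_training,
                  num_classes=params.get('num_classes', 2))

    # Assumption: input_fn yields a features dict with an 'image' tensor.
    preprocessed_inputs = model.preprocess(features['image'])
    prediction_dict = model.predict(preprocessed_inputs)
    postprocessed_dict = model.postprocess(prediction_dict)

    loss, train_op, eval_metric_ops = None, None, None
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        loss = model.loss(prediction_dict, labels)['loss']
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Assumed optimizer and hyperparameters.
        optimizer = tf.train.MomentumOptimizer(
            learning_rate=params.get('learning_rate', 0.001), momentum=0.9)
        # create_train_op also runs the batch-norm update ops and
        # increments the global step.
        train_op = tf.contrib.slim.learning.create_train_op(loss, optimizer)
    if mode == tf.estimator.ModeKeys.EVAL:
        eval_metric_ops = {
            'accuracy': tf.metrics.accuracy(
                labels=labels, predictions=postprocessed_dict['classes'])}

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=postprocessed_dict,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)
```

In PREDICT mode labels is None, which is why the loss and the metrics are only built in the TRAIN and EVAL branches.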
7. Full code
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 11 17:21:12 2018
@author: shirhe-lyh
"""
import tensorflow as tf
from tensorflow.contrib.slim import nets

import preprocessing  # project-local preprocessing module (not shown here)

slim = tf.contrib.slim


class Model(object):
    """xxx definition."""

    def __init__(self, is_training,
                 num_classes=2,
                 fixed_resize_side=256,
                 default_image_size=224):
        """Constructor.

        Args:
            is_training: A boolean indicating whether the training version of
                computation graph should be constructed.
            num_classes: Number of classes.
            fixed_resize_side: Passed to preprocessing as resize_side_min.
            default_image_size: Output height and width of preprocessing.
        """
        self._num_classes = num_classes
        self._is_training = is_training
        self._fixed_resize_side = fixed_resize_side
        self._default_image_size = default_image_size

    @property
    def num_classes(self):
        return self._num_classes

    def preprocess(self, inputs):
        """Preprocess a batch of images.

        Outputs of this function can be passed to predict.

        Args:
            inputs: A tensor with shape [batch_size, height, width,
                num_channels] representing a batch of images.
        Returns:
            preprocessed_inputs: A float32 tensor holding the preprocessed
                images.
        """
        preprocessed_inputs = preprocessing.preprocess_images(
            inputs, self._default_image_size, self._default_image_size,
            resize_side_min=self._fixed_resize_side,
            is_training=self._is_training,
            border_expand=False, normalize=False,
            preserving_aspect_ratio_resize=False)
        preprocessed_inputs = tf.cast(preprocessed_inputs, tf.float32)
        return preprocessed_inputs

    def predict(self, preprocessed_inputs):
        """Predict prediction tensors from inputs tensor.

        Outputs of this function can be passed to loss or postprocess functions.

        Args:
            preprocessed_inputs: A float32 tensor with shape [batch_size,
                height, width, num_channels] representing a batch of images.
        Returns:
            prediction_dict: A dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions.
        """
        with slim.arg_scope(nets.resnet_v1.resnet_arg_scope()):
            # num_classes=None: return the pooled features, not logits.
            net, endpoints = nets.resnet_v1.resnet_v1_50(
                preprocessed_inputs, num_classes=None,
                is_training=self._is_training)
        # Drop the two size-1 spatial dimensions left by global pooling.
        net = tf.squeeze(net, axis=[1, 2])
        logits = slim.fully_connected(net, num_outputs=self.num_classes,
                                      activation_fn=None,
                                      scope='Predict/logits')
        return {'logits': logits}

    def postprocess(self, prediction_dict):
        """Convert predicted output tensors to final forms.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
        Returns:
            A dictionary containing the postprocessed results.
        """
        postprocessed_dict = {}
        for logits_name, logits in prediction_dict.items():
            # Note: the softmax probabilities are stored under the
            # original 'logits' key.
            logits = tf.nn.softmax(logits)
            classes = tf.argmax(logits, axis=1)
            classes_name = logits_name.replace('logits', 'classes')
            postprocessed_dict[logits_name] = logits
            postprocessed_dict[classes_name] = classes
        return postprocessed_dict

    def loss(self, prediction_dict, groundtruth_lists):
        """Compute scalar loss tensors with respect to provided groundtruth.

        Args:
            prediction_dict: A dictionary holding prediction tensors.
            groundtruth_lists: A tensor of integer class labels, with one
                entry for each image in the batch.
        Returns:
            A dictionary mapping strings (loss names) to scalar tensors
                representing loss values.
        """
        logits = prediction_dict.get('logits')
        # Registers the cross-entropy loss in the losses collection.
        slim.losses.sparse_softmax_cross_entropy(logits, groundtruth_lists)
        # Total loss = registered losses + regularization losses.
        loss = slim.losses.get_total_loss()
        loss_dict = {'loss': loss}
        return loss_dict

    def accuracy(self, postprocessed_dict, groundtruth_lists):
        """Calculate accuracy.

        Args:
            postprocessed_dict: A dictionary containing the postprocessed
                results.
            groundtruth_lists: A tensor of integer class labels, with one
                entry for each image in the batch.
        Returns:
            accuracy: The scalar accuracy.
        """
        classes = postprocessed_dict['classes']
        # tf.argmax returns int64, so cast the labels before comparing.
        groundtruth_lists = tf.cast(groundtruth_lists, classes.dtype)
        accuracy = tf.reduce_mean(
            tf.cast(tf.equal(classes, groundtruth_lists), dtype=tf.float32))
        return accuracy
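To make the arithmetic behind postprocess, loss, and accuracy concrete, here is a plain-Python sketch of the same computations on ordinary lists. The helper names and sample numbers are made up for illustration, and the regularization terms that get_total_loss adds are left out.

```python
import math


def softmax(logits):
    """Numerically stable softmax of one row of logits."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def postprocess(batch_logits):
    """Mirror Model.postprocess: probabilities plus the arg-max class."""
    probabilities = [softmax(row) for row in batch_logits]
    classes = [row.index(max(row)) for row in probabilities]
    return {'logits': probabilities, 'classes': classes}


def sparse_softmax_cross_entropy(batch_logits, labels):
    """Mean of -log(probability assigned to the true class)."""
    losses = [-math.log(softmax(row)[label])
              for row, label in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


def accuracy(classes, groundtruth):
    """Fraction of predictions that equal the groundtruth labels."""
    matches = [float(c == g) for c, g in zip(classes, groundtruth)]
    return sum(matches) / len(matches)


batch_logits = [[2.0, 0.5], [0.1, 1.2], [3.0, -1.0]]  # made-up logits
groundtruth = [0, 1, 1]                               # made-up labels
results = postprocess(batch_logits)
batch_accuracy = accuracy(results['classes'], groundtruth)
batch_loss = sparse_softmax_cross_entropy(batch_logits, groundtruth)
```

The third sample is predicted as class 0 while its label is 1, so it lowers the accuracy and contributes the largest term to the loss.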