预处理
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
def apply_with_random_selector(x, func, num_cases):
    """Apply `func` to `x` using a case index that is sampled inside the graph.

    One branch `func(..., case)` is built for every `case` in
    `range(num_cases)`; a random int32 scalar picks which branch receives the
    real input, and the branch outputs are merged into a single tensor.

    Args:
        x: input Tensor.
        func: Python callable invoked as `func(tensor, case)`; `case` is a
            plain Python int.
        num_cases: Python int, number of cases to sample the selector from.

    Returns:
        The merged branch output, i.e. `func(x, sel)` for the sampled `sel`.
    """
    selector = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)

    branch_outputs = []
    for case in range(num_cases):
        # Only the branch whose index equals the sampled selector is handed
        # the real x.
        is_chosen = tf.equal(selector, case)
        routed_x = control_flow_ops.switch(x, is_chosen)[1]
        branch_outputs.append(func(routed_x, case))

    return control_flow_ops.merge(branch_outputs)[0]
def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
    """Randomly perturb the colors of an image tensor.

    The individual color ops do not commute, so the order in which they are
    applied matters. `color_ordering` selects one fixed order.

    In fast mode only brightness and saturation are used: ordering 0 applies
    brightness then saturation, any other value applies them the other way
    round. In slow mode hue and contrast are added and exactly four orderings
    (0-3) are accepted.

    Args:
        image: 3-D Tensor containing single image in [0, 1].
        color_ordering: Python int selecting the op order (0-3).
        fast_mode: if True, skips the slower ops (random_hue and
            random_contrast).
        scope: Optional scope for name_scope.

    Returns:
        3-D Tensor with the color-distorted image, clipped to [0, 1].

    Raises:
        ValueError: in slow mode, if `color_ordering` is not one of 0-3.
    """
    def brightness(img):
        return tf.image.random_brightness(img, max_delta=32. / 255.)

    def saturation(img):
        return tf.image.random_saturation(img, lower=0.5, upper=1.5)

    def hue(img):
        return tf.image.random_hue(img, max_delta=0.2)

    def contrast(img):
        return tf.image.random_contrast(img, lower=0.5, upper=1.5)

    with tf.name_scope(scope, 'distort_color', [image]):
        if fast_mode:
            if color_ordering == 0:
                pipeline = (brightness, saturation)
            else:
                pipeline = (saturation, brightness)
        else:
            slow_pipelines = {
                0: (brightness, saturation, hue, contrast),
                1: (saturation, brightness, contrast, hue),
                2: (contrast, hue, brightness, saturation),
                3: (hue, saturation, contrast, brightness),
            }
            if color_ordering not in slow_pipelines:
                raise ValueError('color_ordering must be in [0, 3]')
            pipeline = slow_pipelines[color_ordering]

        for color_op in pipeline:
            image = color_op(image)

        # The random_* ops do not necessarily clamp, so clip explicitly.
        return tf.clip_by_value(image, 0.0, 1.0)
def distorted_bounding_box_crop(image,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
    """Crop `image` to a randomly distorted version of one of its boxes.

    A crop window is sampled with `tf.image.sample_distorted_bounding_box`
    (see its documentation for the meaning of the constraints) and the image
    is sliced to that window.

    Args:
        image: 3-D Tensor of image.
        bbox: 3-D float Tensor of bounding boxes arranged
            [1, num_boxes, coords], each coordinate in [0, 1) and ordered
            [ymin, xmin, ymax, xmax]. The sampler is called with
            `use_image_if_no_bounding_boxes=True`.
        min_object_covered: optional float, forwarded to the sampler.
        aspect_ratio_range: optional pair of floats bounding width / height
            of the crop, forwarded to the sampler.
        area_range: optional pair of floats bounding the fraction of the
            image the crop may cover, forwarded to the sampler.
        max_attempts: optional int, number of sampling attempts, forwarded to
            the sampler.
        scope: Optional scope for name_scope.

    Returns:
        A tuple `(cropped_image, distort_bbox)`: the 3-D cropped image Tensor
        and the distorted bounding box returned by the sampler.
    """
    with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
        # The sampler yields the begin offsets and size of the crop window
        # plus the sampled box itself.
        crop_begin, crop_size, distort_bbox = (
            tf.image.sample_distorted_bounding_box(
                tf.shape(image),
                bounding_boxes=bbox,
                min_object_covered=min_object_covered,
                aspect_ratio_range=aspect_ratio_range,
                area_range=area_range,
                max_attempts=max_attempts,
                use_image_if_no_bounding_boxes=True))

        # Cut the sampled window out of the image.
        return tf.slice(image, crop_begin, crop_size), distort_bbox
def preprocess_for_train(image, height, width, bbox,
                         fast_mode=True,
                         scope=None):
    """Build the randomly augmented training version of one image.

    Stages, in order: conversion to float32 (if needed), a random crop sampled
    around `bbox`, a resize to `[height, width]`, a random horizontal flip, a
    random color distortion and a final rescale `(x - 0.5) * 2`. Image
    summaries are emitted after several of the stages.

    Args:
        image: 3-D Tensor of image. A non-float32 image is converted with
            `tf.image.convert_image_dtype`; a float32 image is expected to
            already be in [0, 1].
        height: integer, output height.
        width: integer, output width.
        bbox: 3-D float Tensor of bounding boxes arranged
            [1, num_boxes, coords], coordinates ordered
            [ymin, xmin, ymax, xmax]; `None` means the whole image.
        fast_mode: Optional boolean. If True only one resize method is used
            and `distort_color` runs in its fast mode.
        scope: Optional scope for name_scope.

    Returns:
        3-D float Tensor with the distorted image, rescaled to [-1, 1].
    """
    def resize_with(tensor, method):
        # `method` is the Python-int resize method index chosen per branch.
        return tf.image.resize_images(tensor, [height, width], method=method)

    def recolor_with(tensor, ordering):
        return distort_color(tensor, ordering, fast_mode)

    with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
        if bbox is None:
            # No annotation given: use a single box spanning the full image.
            bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
                               dtype=tf.float32,
                               shape=[1, 1, 4])
        if image.dtype != tf.float32:
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)

        # Summary of the input image with the supplied boxes drawn on it.
        annotated = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
                                                 bbox)
        tf.summary.image('image_with_bounding_boxes', annotated)

        cropped, sampled_bbox = distorted_bounding_box_crop(image, bbox)
        # The dynamic slice loses the static channel dimension; restore it.
        cropped.set_shape([None, None, 3])

        # Summary of the input image with the sampled crop box drawn on it.
        annotated_crop = tf.image.draw_bounding_boxes(
            tf.expand_dims(image, 0), sampled_bbox)
        tf.summary.image('images_with_distorted_bounding_box',
                         annotated_crop)

        # The resize ignores aspect ratio, so it may distort the image.
        # Fast mode uses resize method 0 only; otherwise one of four methods
        # is picked at random.
        if fast_mode:
            num_resize_cases = 1
        else:
            num_resize_cases = 4
        resized = apply_with_random_selector(
            cropped, resize_with, num_cases=num_resize_cases)
        tf.summary.image('cropped_resized_image',
                         tf.expand_dims(resized, 0))

        flipped = tf.image.random_flip_left_right(resized)

        # One of the four color orderings is picked at random.
        recolored = apply_with_random_selector(
            flipped, recolor_with, num_cases=4)
        tf.summary.image('final_distorted_image',
                         tf.expand_dims(recolored, 0))

        # Shift and scale: (x - 0.5) * 2.
        return tf.multiply(tf.subtract(recolored, 0.5), 2.0)
def preprocess_for_eval(image, height, width,
                        central_fraction=0.875, scope=None):
    """Build the deterministic evaluation version of one image.

    The image is converted to float32 if needed, optionally center-cropped,
    optionally resized with bilinear interpolation, and finally rescaled as
    `(x - 0.5) * 2`.

    Args:
        image: 3-D Tensor of image. A non-float32 image is converted with
            `tf.image.convert_image_dtype`; a float32 image is expected to
            already be in [0, 1].
        height: integer; the resize is skipped unless both `height` and
            `width` are truthy.
        width: integer; see `height`.
        central_fraction: Optional float passed to `tf.image.central_crop`;
            a falsy value skips the crop.
        scope: Optional scope for name_scope.

    Returns:
        3-D float Tensor of the prepared image.
    """
    with tf.name_scope(scope, 'eval_image', [image, height, width]):
        if image.dtype != tf.float32:
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)

        if central_fraction:
            # Keep only the central part of the image.
            image = tf.image.central_crop(image,
                                          central_fraction=central_fraction)

        if height and width:
            # resize_bilinear works on batches, so add and then remove a
            # leading batch axis around the call.
            batched = tf.expand_dims(image, 0)
            resized = tf.image.resize_bilinear(batched, [height, width],
                                               align_corners=False)
            image = tf.squeeze(resized, [0])

        # Shift and scale: (x - 0.5) * 2.
        return tf.multiply(tf.subtract(image, 0.5), 2.0)
def preprocess_image(image, height, width,
                     is_training=False,
                     bbox=None,
                     fast_mode=True):
    """Dispatch to the training or evaluation preprocessing of one image.

    Args:
        image: 3-D Tensor [height, width, channels] with the image.
        height: integer, image expected height.
        width: integer, image expected width.
        is_training: Boolean. True selects `preprocess_for_train`, False
            selects `preprocess_for_eval`.
        bbox: 3-D float Tensor of bounding boxes arranged
            [1, num_boxes, coords], coordinates ordered
            [ymin, xmin, ymax, xmax]. Only used for training.
        fast_mode: Optional boolean, forwarded to `preprocess_for_train`.
            Only used for training.

    Returns:
        3-D float Tensor containing the preprocessed image.
    """
    if not is_training:
        return preprocess_for_eval(image, height, width)
    return preprocess_for_train(image, height, width, bbox, fast_mode)
网络模型的构建
模型结构将 Inception 模块与 ResNet 的残差连接相结合（即 Inception-ResNet-v2）。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
    """Builds the 35x35 resnet block.

    Three parallel conv branches are concatenated along the channel axis,
    projected back to the input depth with a linear 1x1 conv, scaled by
    `scale` and added to the input.

    Args:
        net: input Tensor; indexed as 4-D with channels on axis 3.
        scale: multiplier applied to the residual before the addition.
        activation_fn: callable applied to the sum; a falsy value skips it.
        scope: optional variable scope name (default name 'Block35').
        reuse: forwarded to `tf.variable_scope`.

    Returns:
        Tensor with the block output.
    """
    with tf.variable_scope(scope, 'Block35', [net], reuse=reuse):
        with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
        with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 32, 3, scope='Conv2d_0b_3x3')
        with tf.variable_scope('Branch_2'):
            branch_2 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
            branch_2 = slim.conv2d(branch_2, 48, 3, scope='Conv2d_0b_3x3')
            branch_2 = slim.conv2d(branch_2, 64, 3, scope='Conv2d_0c_3x3')
        merged = tf.concat(axis=3, values=[branch_0, branch_1, branch_2])
        # Linear 1x1 projection (no normalizer, no activation) back to the
        # channel depth of the input so the two can be added.
        residual = slim.conv2d(merged, net.get_shape()[3], 1,
                               normalizer_fn=None, activation_fn=None,
                               scope='Conv2d_1x1')
        net = net + scale * residual
        return activation_fn(net) if activation_fn else net
def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
    """Builds the 17x17 resnet block.

    Two parallel conv branches (one of them a factorised 1x7 / 7x1 stack) are
    concatenated along the channel axis, projected back to the input depth
    with a linear 1x1 conv, scaled by `scale` and added to the input.

    Args:
        net: input Tensor; indexed as 4-D with channels on axis 3.
        scale: multiplier applied to the residual before the addition.
        activation_fn: callable applied to the sum; a falsy value skips it.
        scope: optional variable scope name (default name 'Block17').
        reuse: forwarded to `tf.variable_scope`.

    Returns:
        Tensor with the block output.
    """
    with tf.variable_scope(scope, 'Block17', [net], reuse=reuse):
        with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
        with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 160, [1, 7],
                                   scope='Conv2d_0b_1x7')
            branch_1 = slim.conv2d(branch_1, 192, [7, 1],
                                   scope='Conv2d_0c_7x1')
        merged = tf.concat(axis=3, values=[branch_0, branch_1])
        # Linear 1x1 projection (no normalizer, no activation) back to the
        # channel depth of the input so the two can be added.
        residual = slim.conv2d(merged, net.get_shape()[3], 1,
                               normalizer_fn=None, activation_fn=None,
                               scope='Conv2d_1x1')
        net = net + scale * residual
        return activation_fn(net) if activation_fn else net
def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
    """Builds the 8x8 resnet block.

    Two parallel conv branches (one of them a factorised 1x3 / 3x1 stack) are
    concatenated along the channel axis, projected back to the input depth
    with a linear 1x1 conv, scaled by `scale` and added to the input.

    Args:
        net: input Tensor; indexed as 4-D with channels on axis 3.
        scale: multiplier applied to the residual before the addition.
        activation_fn: callable applied to the sum; a falsy value skips it.
        scope: optional variable scope name (default name 'Block8').
        reuse: forwarded to `tf.variable_scope`.

    Returns:
        Tensor with the block output.
    """
    with tf.variable_scope(scope, 'Block8', [net], reuse=reuse):
        with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
        with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 224, [1, 3],
                                   scope='Conv2d_0b_1x3')
            branch_1 = slim.conv2d(branch_1, 256, [3, 1],
                                   scope='Conv2d_0c_3x1')
        merged = tf.concat(axis=3, values=[branch_0, branch_1])
        # Linear 1x1 projection (no normalizer, no activation) back to the
        # channel depth of the input so the two can be added.
        residual = slim.conv2d(merged, net.get_shape()[3], 1,
                               normalizer_fn=None, activation_fn=None,
                               scope='Conv2d_1x1')
        net = net + scale * residual
        return activation_fn(net) if activation_fn else net
def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
                        dropout_keep_prob=0.8,
                        reuse=None,
                        scope='InceptionResnetV2'):
    """Assemble the Inception-ResNet-v2 classification graph.

    The network is a convolutional stem, an Inception block (Mixed_5b),
    10 x block35, a reduction (Mixed_6a), 20 x block17, an auxiliary
    classifier head, a second reduction (Mixed_7a), 9 x block8 plus one
    block8 without activation, a final 1x1 conv and the pooled logits head.

    Args:
        inputs: a 4-D tensor of size [batch_size, height, width, 3].
        num_classes: number of predicted classes.
        is_training: forwarded to batch_norm and dropout.
        dropout_keep_prob: float, the fraction to keep before the final layer.
        reuse: whether or not the network and its variables should be reused.
            To be able to reuse, 'scope' must be given.
        scope: Optional variable_scope.

    Returns:
        logits: the logits outputs of the model.
        end_points: dict of named intermediate tensors, including
            'AuxLogits', 'Logits' and 'Predictions'.
    """
    end_points = {}

    with tf.variable_scope(scope, 'InceptionResnetV2', [inputs],
                           reuse=reuse), \
            slim.arg_scope([slim.batch_norm, slim.dropout],
                           is_training=is_training), \
            slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                           stride=1, padding='SAME'):

        # ---- Stem. Shape comments assume the default 299 x 299 input. ----
        # 149 x 149 x 32
        net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID',
                          scope='Conv2d_1a_3x3')
        end_points['Conv2d_1a_3x3'] = net
        # 147 x 147 x 32
        net = slim.conv2d(net, 32, 3, padding='VALID',
                          scope='Conv2d_2a_3x3')
        end_points['Conv2d_2a_3x3'] = net
        # 147 x 147 x 64
        net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
        end_points['Conv2d_2b_3x3'] = net
        # 73 x 73 x 64
        net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
                              scope='MaxPool_3a_3x3')
        end_points['MaxPool_3a_3x3'] = net
        # 73 x 73 x 80
        net = slim.conv2d(net, 80, 1, padding='VALID',
                          scope='Conv2d_3b_1x1')
        end_points['Conv2d_3b_1x1'] = net
        # 71 x 71 x 192
        net = slim.conv2d(net, 192, 3, padding='VALID',
                          scope='Conv2d_4a_3x3')
        end_points['Conv2d_4a_3x3'] = net
        # 35 x 35 x 192
        net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
                              scope='MaxPool_5a_3x3')
        end_points['MaxPool_5a_3x3'] = net

        # ---- Mixed_5b: 35 x 35 x 320 (96 + 64 + 96 + 64 channels) ----
        with tf.variable_scope('Mixed_5b'):
            with tf.variable_scope('Branch_0'):
                branch_0 = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
            with tf.variable_scope('Branch_1'):
                branch_1 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
                branch_1 = slim.conv2d(branch_1, 64, 5,
                                       scope='Conv2d_0b_5x5')
            with tf.variable_scope('Branch_2'):
                branch_2 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
                branch_2 = slim.conv2d(branch_2, 96, 3,
                                       scope='Conv2d_0b_3x3')
                branch_2 = slim.conv2d(branch_2, 96, 3,
                                       scope='Conv2d_0c_3x3')
            with tf.variable_scope('Branch_3'):
                branch_3 = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
                                           scope='AvgPool_0a_3x3')
                branch_3 = slim.conv2d(branch_3, 64, 1,
                                       scope='Conv2d_0b_1x1')
            net = tf.concat(axis=3, values=[branch_0, branch_1,
                                            branch_2, branch_3])
        end_points['Mixed_5b'] = net

        net = slim.repeat(net, 10, block35, scale=0.17)

        # ---- Mixed_6a reduction: 17 x 17 x 1088 (384 + 384 + 320) ----
        with tf.variable_scope('Mixed_6a'):
            with tf.variable_scope('Branch_0'):
                branch_0 = slim.conv2d(net, 384, 3, stride=2,
                                       padding='VALID',
                                       scope='Conv2d_1a_3x3')
            with tf.variable_scope('Branch_1'):
                branch_1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
                branch_1 = slim.conv2d(branch_1, 256, 3,
                                       scope='Conv2d_0b_3x3')
                branch_1 = slim.conv2d(branch_1, 384, 3,
                                       stride=2, padding='VALID',
                                       scope='Conv2d_1a_3x3')
            with tf.variable_scope('Branch_2'):
                branch_2 = slim.max_pool2d(net, 3, stride=2,
                                           padding='VALID',
                                           scope='MaxPool_1a_3x3')
            net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2])
        end_points['Mixed_6a'] = net

        net = slim.repeat(net, 20, block17, scale=0.10)

        # ---- Auxiliary classifier head, fed from the 17 x 17 features ----
        with tf.variable_scope('AuxLogits'):
            aux = slim.avg_pool2d(net, 5, stride=3, padding='VALID',
                                  scope='Conv2d_1a_3x3')
            aux = slim.conv2d(aux, 128, 1, scope='Conv2d_1b_1x1')
            # The kernel spans the whole remaining spatial extent.
            aux = slim.conv2d(aux, 768, aux.get_shape()[1:3],
                              padding='VALID', scope='Conv2d_2a_5x5')
            aux = slim.flatten(aux)
            aux = slim.fully_connected(aux, num_classes, activation_fn=None,
                                       scope='Logits')
            end_points['AuxLogits'] = aux

        # ---- Mixed_7a reduction ----
        with tf.variable_scope('Mixed_7a'):
            with tf.variable_scope('Branch_0'):
                branch_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
                branch_0 = slim.conv2d(branch_0, 384, 3, stride=2,
                                       padding='VALID',
                                       scope='Conv2d_1a_3x3')
            with tf.variable_scope('Branch_1'):
                branch_1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
                branch_1 = slim.conv2d(branch_1, 288, 3, stride=2,
                                       padding='VALID',
                                       scope='Conv2d_1a_3x3')
            with tf.variable_scope('Branch_2'):
                branch_2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
                branch_2 = slim.conv2d(branch_2, 288, 3,
                                       scope='Conv2d_0b_3x3')
                branch_2 = slim.conv2d(branch_2, 320, 3, stride=2,
                                       padding='VALID',
                                       scope='Conv2d_1a_3x3')
            with tf.variable_scope('Branch_3'):
                branch_3 = slim.max_pool2d(net, 3, stride=2,
                                           padding='VALID',
                                           scope='MaxPool_1a_3x3')
            net = tf.concat(axis=3, values=[branch_0, branch_1,
                                            branch_2, branch_3])
        end_points['Mixed_7a'] = net

        net = slim.repeat(net, 9, block8, scale=0.20)
        # Final residual block, built without an activation.
        net = block8(net, activation_fn=None)

        net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1')
        end_points['Conv2d_7b_1x1'] = net

        # ---- Classification head ----
        with tf.variable_scope('Logits'):
            end_points['PrePool'] = net
            # Pool over the full remaining spatial extent.
            net = slim.avg_pool2d(net, net.get_shape()[1:3],
                                  padding='VALID', scope='AvgPool_1a_8x8')
            net = slim.flatten(net)
            net = slim.dropout(net, dropout_keep_prob,
                               is_training=is_training, scope='Dropout')
            end_points['PreLogitsFlatten'] = net
            logits = slim.fully_connected(net, num_classes,
                                          activation_fn=None,
                                          scope='Logits')
            end_points['Logits'] = logits
            end_points['Predictions'] = tf.nn.softmax(logits,
                                                      name='Predictions')

    return logits, end_points


# Input resolution this model is set up for by default.
inception_resnet_v2.default_image_size = 299
def inception_resnet_v2_arg_scope(weight_decay=0.00004,
                                  batch_norm_decay=0.9997,
                                  batch_norm_epsilon=0.001):
    """Return the default slim arg_scope for inception_resnet_v2.

    Conv and fully-connected layers get L2 regularisation on weights and
    biases; conv layers additionally default to ReLU followed by batch norm.

    Args:
        weight_decay: the weight decay for weights variables.
        batch_norm_decay: value of the batch_norm 'decay' parameter.
        batch_norm_epsilon: value of the batch_norm 'epsilon' parameter.

    Returns:
        An arg_scope with the parameters needed for inception_resnet_v2.
    """
    batch_norm_params = dict(decay=batch_norm_decay,
                             epsilon=batch_norm_epsilon)

    # Outer scope: regularisation for conv2d and fully_connected.
    # Inner scope: activation and batch-norm defaults for conv2d only.
    with slim.arg_scope(
            [slim.conv2d, slim.fully_connected],
            weights_regularizer=slim.l2_regularizer(weight_decay),
            biases_regularizer=slim.l2_regularizer(weight_decay)), \
        slim.arg_scope(
            [slim.conv2d],
            activation_fn=tf.nn.relu,
            normalizer_fn=slim.batch_norm,
            normalizer_params=batch_norm_params) as conv_scope:
        return conv_scope
网络的训练
import tensorflow as tf
from tensorflow.contrib.framework.python.ops.variables import get_or_create_global_step
from tensorflow.python.platform import tf_logging as logging
import inception_preprocessing
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
import os
import time
slim = tf.contrib.slim
#================ DATASET INFORMATION ======================
# Directory where the TFRecord files are located.
dataset_dir = '.'

# Directory for training logs and checkpoints; run() creates it if missing.
log_dir = './log'

# Pre-trained checkpoint the model is initialised from.
checkpoint_file = './inception_resnet_v2_2016_08_30.ckpt'

# Side length the images are resized to (the default Inception size).
image_size = 299

# Number of classes to predict.
num_classes = 5

# Labels file with one "<int label>:<name>" entry per line.
labels_file = './labels.txt'

# Map each integer label to its string name.
# The file is read inside a `with` block so the handle is always closed.
# Only the line terminator is stripped, so a final line without a trailing
# newline keeps its last character. Splitting on the first ':' only lets
# names contain colons, and blank lines are skipped instead of crashing.
labels_to_name = {}
with open(labels_file, 'r') as labels:
    for line in labels:
        line = line.rstrip('\r\n')
        if not line.strip():
            continue
        label, string_name = line.split(':', 1)
        labels_to_name[int(label)] = string_name

# File pattern of the TFRecord shards; %s is filled with the split name.
file_pattern = 'flowers_%s_*.tfrecord'

# Human-readable description of the dataset items, required by the Dataset
# class later.
items_to_descriptions = {
    'image': 'A 3-channel RGB coloured flower image that is either tulips, sunflowers, roses, dandelion, or daisy.',
    'label': 'A label that is as such -- 0:daisy, 1:dandelion, 2:roses, 3:sunflowers, 4:tulips'
}

#================= TRAINING INFORMATION ==================
# Number of epochs to train.
num_epochs = 1

# Batch size.
batch_size = 8

# Learning rate schedule (exponential decay, see run()).
initial_learning_rate = 0.0002
learning_rate_decay_factor = 0.7
num_epochs_before_decay = 2
#============== DATASET LOADING ======================
#We now create a function that creates a Dataset class which will give us many TFRecord files to feed in the examples into a queue in parallel.
def get_split(split_name, dataset_dir, file_pattern=file_pattern, file_pattern_for_counting='flowers'):
    '''
    Build the slim Dataset object for one data split.

    The function counts the examples stored in the split's TFRecord shards,
    sets up the TFExample decoder, and bundles both together with the label
    metadata into a single Dataset.

    INPUTS:
    - split_name(str): 'train' or 'validation'. Selects which shards are used.
    - dataset_dir(str): the dataset directory where the tfrecord files are located
    - file_pattern(str): file name pattern of the shards; its %s is replaced by split_name
    - file_pattern_for_counting(str): file name prefix used to find the shards that are counted

    OUTPUTS:
    - dataset (Dataset): A Dataset class object where we can read its various components for easier batch creation later.

    RAISES:
    - ValueError: if split_name is neither 'train' nor 'validation'.
    '''
    if split_name not in ['train', 'validation']:
        raise ValueError('The split_name %s is not recognized. Please input either train or validation as the split_name' % (split_name))

    # Glob-style path handed to the Dataset as its data source.
    file_pattern_path = os.path.join(dataset_dir, file_pattern % (split_name))

    # Count the examples by iterating over every record of every shard whose
    # file name starts with "<prefix>_<split_name>".
    counting_prefix = file_pattern_for_counting + '_' + split_name
    shard_paths = [
        os.path.join(dataset_dir, shard_name)
        for shard_name in os.listdir(dataset_dir)
        if shard_name.startswith(counting_prefix)
    ]
    num_samples = sum(
        1
        for shard_path in shard_paths
        for _ in tf.python_io.tf_record_iterator(shard_path))

    # How the serialized example is parsed ...
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
        'image/class/label': tf.FixedLenFeature(
            [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
    }

    # ... and how the parsed features are turned into dataset items.
    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(),
        'label': slim.tfexample_decoder.Tensor('image/class/label'),
    }

    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

    # The reader class (not an instance) is handed to the Dataset.
    return slim.dataset.Dataset(
        data_sources=file_pattern_path,
        decoder=decoder,
        reader=tf.TFRecordReader,
        num_readers=4,
        num_samples=num_samples,
        num_classes=num_classes,
        labels_to_name=labels_to_name,
        items_to_descriptions=items_to_descriptions)
def load_batch(dataset, batch_size, height=image_size, width=image_size, is_training=True):
    '''
    Loads one batch of preprocessed images, raw images and labels.

    INPUTS:
    - dataset(Dataset): a Dataset class object that is created from the get_split function
    - batch_size(int): determines how big of a batch to train
    - height(int): the height of the image to resize to during preprocessing
    - width(int): the width of the image to resize to during preprocessing
    - is_training(bool): to determine whether to perform a training or evaluation preprocessing

    OUTPUTS:
    - images(Tensor): one batch of preprocessed images
    - raw_images(Tensor): the same batch of images, only resized (nearest neighbour) to (height, width)
    - labels(Tensor): the batch's labels (not one-hot encoded)
    '''
    # Provider that reads decoded examples from the dataset through a queue.
    data_provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        common_queue_capacity=24 + 3 * batch_size,
        common_queue_min=24)

    raw_image, label = data_provider.get(['image', 'label'])

    # Training or evaluation preprocessing, depending on is_training.
    image = inception_preprocessing.preprocess_image(raw_image, height, width, is_training)

    # The raw image is only resized so it can be batched. The resize op works
    # on batches, so a leading axis is added and then removed again. Only
    # that axis is squeezed: squeezing without an axis would also drop any
    # other dimension of size 1 (e.g. a single channel).
    raw_image = tf.expand_dims(raw_image, 0)
    raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width])
    raw_image = tf.squeeze(raw_image, [0])

    # Enqueue the single examples and dequeue them in batches.
    images, raw_images, labels = tf.train.batch(
        [image, raw_image, label],
        batch_size=batch_size,
        num_threads=4,
        capacity=4 * batch_size,
        allow_smaller_final_batch=True)

    return images, raw_images, labels
def run():
#Create the log directory here. Must be done here otherwise import will activate this unneededly.
if not os.path.exists(log_dir):
os.mkdir(log_dir)
#======================= TRAINING PROCESS =========================
#Now we start to construct the graph and build our model
with tf.Graph().as_default() as graph:
tf.logging.set_verbosity(tf.logging.INFO) #Set the verbosity to INFO level
#First create the dataset and load one batch
dataset = get_split('train', dataset_dir, file_pattern=file_pattern)
images, _, labels = load_batch(dataset, batch_size=batch_size)
#Know the number steps to take before decaying the learning rate and batches per epoch
num_batches_per_epoch = int(dataset.num_samples / batch_size)
num_steps_per_epoch = num_batches_per_epoch #Because one step is one batch processed
decay_steps = int(num_epochs_before_decay * num_steps_per_epoch)
#Create the model inference
with slim.arg_scope(inception_resnet_v2_arg_scope()):
logits, end_points = inception_resnet_v2(images, num_classes = dataset.num_classes, is_training = True)
#Define the scopes that you want to exclude for restoration
exclude = ['InceptionResnetV2/Logits', 'InceptionResnetV2/AuxLogits']
variables_to_restore = slim.get_variables_to_restore(exclude = exclude)
#Perform one-hot-encoding of the labels (Try one-hot-encoding within the load_batch function!)
one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
#Performs the equivalent to tf.nn.sparse_softmax_cross_entropy_with_logits but enhanced with checks
loss = tf.losses.softmax_cross_entropy(onehot_labels = one_hot_labels, logits = logits)
total_loss = tf.losses.get_total_loss() #obtain the regularization losses as well
#Create the global step for monitoring the learning_rate and training.
global_step = get_or_create_global_step()
#Define your exponentially decaying learning rate
lr = tf.train.exponential_decay(
learning_rate = initial_learning_rate,
global_step = global_step,
decay_steps = decay_steps,
decay_rate = learning_rate_decay_factor,
staircase = True)
#Now we can define the optimizer that takes on the learning rate
optimizer = tf.train.AdamOptimizer(learning_rate = lr)
#Create the train_op.
train_op = slim.learning.create_train_op(total_loss, optimizer)
#State the metrics that you want to predict. We get a predictions that is not one_hot_encoded.
predictions = tf.argmax(end_points['Predictions'], 1)
probabilities = end_points['Predictions']
accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, labels)
metrics_op = tf.group(accuracy_update, probabilities)
#Now finally create all the summaries you need to monitor and group them into one summary op.
tf.summary.scalar('losses/Total_Loss', total_loss)
tf.summary.scalar('accuracy', accuracy)
tf.summary.scalar('learning_rate', lr)
my_summary_op = tf.summary.merge_all()
#Now we need to create a training step function that runs both the train_op, metrics_op and updates the global_step concurrently.
def train_step(sess, train_op, global_step):
'''
Simply runs a session for the three arguments provided and gives a logging on the time elapsed for each global step
'''
#Check the time for each sess run
start_time = time.time()
total_loss, global_step_count, _ = sess.run([train_op, global_step, metrics_op])
time_elapsed = time.time() - start_time
#Run the logging to print some results
logging.info('global step %s: loss: %.4f (%.2f sec/step)', global_step_count, total_loss, time_elapsed)
return total_loss, global_step_count
#Now we create a saver function that actually restores the variables from a checkpoint file in a sess
saver = tf.train.Saver(variables_to_restore)
def restore_fn(sess):
return saver.restore(sess, checkpoint_file)
#Define your supervisor for running a managed session. Do not run the summary_op automatically or else it will consume too much memory
sv = tf.train.Supervisor(logdir = log_dir, summary_op = None, init_fn = restore_fn)
#Run the managed session
with sv.managed_session() as sess:
for step in xrange(num_steps_per_epoch * num_epochs):
#At the start of every epoch, show the vital information:
if step % num_batches_per_epoch == 0:
logging.info('Epoch %s/%s', step/num_batches_per_epoch + 1, num_epochs)
learning_rate_value, accuracy_value = sess.run([lr, accuracy])
logging.info('Current Learning Rate: %s', learning_rate_value)
logging.info('Current Streaming Accuracy: %s', accuracy_value)
# optionally, print your logits and predictions for a sanity check that things are going fine.
logits_value, probabilities_value, predictions_value, labels_value = sess.run([logits, probabilities, predictions, labels])
print 'logits: \n', logits_value
print 'Probabilities: \n', probabilities_value
print 'predictions: \n', predictions_value
print 'Labels:\n:', labels_value
#Log the summaries every 10 step.
if step % 10 == 0:
loss, _ = train_step(sess, train_op, sv.global_step)
summaries = sess.run(my_summary_op)
sv.summary_computed(sess, summaries)
#If not, simply run the training step
else:
loss, _ = train_step(sess, train_op, sv.global_step)
#We log the final training loss and accuracy
logging.info('Final Loss: %s', loss)
logging.info('Final Accuracy: %s', sess.run(accuracy))
#Once all the training has been done, save the log files and checkpoint model
logging.info('Finished training! Saving model to disk now.')
# saver.save(sess, "./flowers_model.ckpt")
sv.saver.save(sess, sv.save_path, global_step = sv.global_step)
# Script entry point: start the training routine only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    run()
# 网络模型的验证 (Validation of the network model: evaluation script)
import tensorflow as tf
from tensorflow.python.platform import tf_logging as logging
from tensorflow.contrib.framework.python.ops.variables import get_or_create_global_step
import inception_preprocessing
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope
import time
import os
from train_flowers import get_split, load_batch
import matplotlib.pyplot as plt
# Use matplotlib's 'ggplot' style for the prediction plots drawn in run().
plt.style.use('ggplot')
# Short alias for the TF-Slim namespace used below.
slim = tf.contrib.slim

# Directory that is searched for the most recent training checkpoint.
log_dir = './log'
# Directory used as the Supervisor logdir for the evaluation run; run()
# creates it when it does not exist.
log_eval = './log_eval_test'
# Directory handed to get_split() to locate the 'validation' split.
dataset_dir = '.'
# Number of examples per evaluation batch passed to load_batch().
batch_size = 36
# Number of passes over the validation split performed by run().
num_epochs = 3
# Resolved once, at import time, from log_dir.
# NOTE(review): tf.train.latest_checkpoint is documented to return None when
# no checkpoint is found -- in that case there is nothing for run() to restore.
checkpoint_file = tf.train.latest_checkpoint(log_dir)
def run():
    """Evaluate the latest checkpoint on the validation split.

    Builds a fresh inference graph (``is_training=False``), restores the
    variables from ``checkpoint_file``, runs ``num_epochs`` passes over the
    validation split while tracking a streaming accuracy, writes a
    'Validation_Accuracy' summary every 10 steps to ``log_eval`` and finally
    plots up to 10 images of one more batch together with the predicted and
    ground-truth class names.

    Raises:
        ValueError: if no checkpoint was found in ``log_dir``.
    """
    # Fail early with a readable message instead of deep inside saver.restore().
    if checkpoint_file is None:
        raise ValueError('No checkpoint found in %s; nothing to evaluate.' % log_dir)

    # Create the evaluation log directory (including missing parents).
    if not os.path.exists(log_eval):
        os.makedirs(log_eval)

    # Construct the graph from scratch again.
    with tf.Graph().as_default():
        tf.logging.set_verbosity(tf.logging.INFO)

        # Load validation batches; is_training=False selects the evaluation
        # preprocessing inside load_batch.
        dataset = get_split('validation', dataset_dir)
        images, raw_images, labels = load_batch(dataset, batch_size = batch_size, is_training = False)

        # Floor division keeps the step count an int on Python 2 and Python 3.
        num_batches_per_epoch = dataset.num_samples // batch_size
        total_steps = num_batches_per_epoch * num_epochs

        # Create the inference model with is_training=False.
        with slim.arg_scope(inception_resnet_v2_arg_scope()):
            logits, end_points = inception_resnet_v2(images, num_classes = dataset.num_classes, is_training = False)

        # Restore every model variable from the checkpoint file.
        variables_to_restore = slim.get_variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        def restore_fn(sess):
            """Restore the model variables from ``checkpoint_file`` into ``sess``."""
            return saver.restore(sess, checkpoint_file)

        # Only the accuracy metric is tracked; no loss is built for evaluation.
        predictions = tf.argmax(end_points['Predictions'], 1)
        accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, labels)
        metrics_op = tf.group(accuracy_update)

        # There is no apply_gradients call in this graph, so the global step
        # is incremented manually through this op.
        global_step = get_or_create_global_step()
        global_step_op = tf.assign(global_step, global_step + 1)

        def eval_step(sess, metrics_op, global_step):
            """Run one evaluation step and log the streaming accuracy.

            Args:
                sess: session in which to run the ops.
                metrics_op: op that updates the streaming accuracy.
                global_step: unused; ``global_step_op`` from the enclosing
                    scope is run instead so that the counter advances.

            Returns:
                The value of ``accuracy`` fetched in the same ``sess.run``
                call as the metric update.
            """
            start_time = time.time()
            _, global_step_count, accuracy_value = sess.run([metrics_op, global_step_op, accuracy])
            time_elapsed = time.time() - start_time

            logging.info('Global Step %s: Streaming Accuracy: %.4f (%.2f sec/step)', global_step_count, accuracy_value, time_elapsed)
            return accuracy_value

        # Scalar quantity to monitor in TensorBoard.
        tf.summary.scalar('Validation_Accuracy', accuracy)
        my_summary_op = tf.summary.merge_all()

        # summary_op=None: summaries are written manually below.
        # saver=None: the Supervisor does not save checkpoints of its own.
        sv = tf.train.Supervisor(logdir = log_eval, summary_op = None, saver = None, init_fn = restore_fn)

        with sv.managed_session() as sess:
            for step in range(total_steps):
                # Report progress at the start of every epoch.
                if step % num_batches_per_epoch == 0:
                    logging.info('Epoch: %s/%s', step // num_batches_per_epoch + 1, num_epochs)
                    logging.info('Current Streaming Accuracy: %.4f', sess.run(accuracy))

                eval_step(sess, metrics_op = metrics_op, global_step = sv.global_step)

                # Write summaries every 10 steps.
                if step % 10 == 0:
                    summaries = sess.run(my_summary_op)
                    sv.summary_computed(sess, summaries)

            # At the end of all the evaluation, show the final accuracy.
            logging.info('Final Streaming Accuracy: %.4f', sess.run(accuracy))

            # Evaluate one more batch to visualise what the model predicts.
            # Separate names are used so the graph tensors are not shadowed
            # by the fetched arrays.
            raw_images_value, labels_value, predictions_value = sess.run([raw_images, labels, predictions])

            # Slicing caps the preview at 10 images without failing on
            # batches that contain fewer samples.
            preview = zip(raw_images_value[:10], labels_value[:10], predictions_value[:10])
            for image, label, prediction in preview:
                prediction_name = dataset.labels_to_name[prediction]
                label_name = dataset.labels_to_name[label]
                text = 'Prediction: %s \n Ground Truth: %s' % (prediction_name, label_name)
                img_plot = plt.imshow(image)

                # Set up the plot and hide the axes.
                plt.title(text)
                img_plot.axes.get_yaxis().set_ticks([])
                img_plot.axes.get_xaxis().set_ticks([])
                plt.show()

            logging.info('Model evaluation has completed! Visit TensorBoard for more information regarding your evaluation.')
# Script entry point: start the evaluation routine only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    run()