https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py
Source code analysis:
"""Simple transfer learning with an Inception v3 architecture model.

This example shows how to take an Inception v3 model trained on ImageNet
images, and train a new top layer that can recognize other classes of
images. The top layer receives as input a 2048-dimensional vector (the
"bottleneck") for each image, and a new softmax layer is trained on top.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
from datetime import datetime
import hashlib
import os.path
import random
import re
import struct
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat

FLAGS = None

# These are all parameters that are tied to the particular model architecture
# we're using for Inception v3, such as tensor names and their sizes. If you
# want to adapt this script to work with another model, you will need to
# update these to reflect the values in the network you're using.
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
BOTTLENECK_TENSOR_SIZE = 2048
MODEL_INPUT_WIDTH = 299
MODEL_INPUT_HEIGHT = 299
MODEL_INPUT_DEPTH = 3
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
RESIZED_INPUT_TENSOR_NAME = 'ResizeBilinear:0'
MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # ~134M


def create_image_lists(image_dir, testing_percentage, validation_percentage):
  """Builds a list of training images from the file system.

  Analyzes the sub folders in the image directory, splits them into stable
  training, testing, and validation sets, and returns a data structure
  describing the lists of images for each label and their paths.
  """
  if not gfile.Exists(image_dir):
    print("Image directory '" + image_dir + "' not found.")
    return None
  result = {}
  sub_dirs = [x[0] for x in gfile.Walk(image_dir)]
  # The root directory comes first, so skip it.
  is_root_dir = True
  for sub_dir in sub_dirs:
    if is_root_dir:
      is_root_dir = False
      continue
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    file_list = []
    dir_name = os.path.basename(sub_dir)
    if dir_name == image_dir:
      continue
    print("Looking for images in '" + dir_name + "'")
    for extension in extensions:
      file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
      file_list.extend(gfile.Glob(file_glob))
    if not file_list:
      print('No files found')
      continue
    if len(file_list) < 20:
      print('WARNING: Folder has less than 20 images, which may cause issues.')
    elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
      print('WARNING: Folder {} has more than {} images. Some images will '
            'never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
    label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
    training_images = []
    testing_images = []
    validation_images = []
    for file_name in file_list:
      base_name = os.path.basename(file_name)
      # We want to ignore anything after '_nohash_' in the file name when
      # deciding which set to put an image in, so that related photos end
      # up in the same set.
      hash_name = re.sub(r'_nohash_.*$', '', file_name)
      # This looks a bit magical, but we need to decide whether this file
      # should go into the training, testing, or validation sets, and we
      # want to keep existing files in the same set even if more files are
      # subsequently added. To do that, we hash the file name and use the
      # result to assign it a stable position in [0, 100).
      hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
      percentage_hash = ((int(hash_name_hashed, 16) %
                          (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                         (100.0 / MAX_NUM_IMAGES_PER_CLASS))
      if percentage_hash < validation_percentage:
        validation_images.append(base_name)
      elif percentage_hash < (testing_percentage + validation_percentage):
        testing_images.append(base_name)
      else:
        training_images.append(base_name)
    result[label_name] = {
        'dir': dir_name,
        'training': training_images,
        'testing': testing_images,
        'validation': validation_images,
    }
  return result


def get_image_path(image_lists, label_name, index, image_dir, category):
  """Returns a path to an image for a label at the given index.

  The index is taken modulo the number of images available for the label,
  so any non-negative value is valid.
  """
  if label_name not in image_lists:
    tf.logging.fatal('Label does not exist %s.', label_name)
  label_lists = image_lists[label_name]
  if category not in label_lists:
    tf.logging.fatal('Category does not exist %s.', category)
  category_list = label_lists[category]
  if not category_list:
    tf.logging.fatal('Label %s has no images in the category %s.',
                     label_name, category)
  mod_index = index % len(category_list)
  base_name = category_list[mod_index]
  sub_dir = label_lists['dir']
  full_path = os.path.join(image_dir, sub_dir, base_name)
  return full_path


def get_bottleneck_path(image_lists, label_name, index, bottleneck_dir,
                        category):
  """Returns a path to a bottleneck file for a label at the given index."""
  return get_image_path(image_lists, label_name, index, bottleneck_dir,
                        category) + '.txt'


def create_inception_graph():
  """Creates a graph from the saved GraphDef file and returns a Graph object."""
  with tf.Session() as sess:
    model_filename = os.path.join(
        FLAGS.model_dir, 'classify_image_graph_def.pb')
    with gfile.FastGFile(model_filename, 'rb') as f:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(f.read())
      bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
          tf.import_graph_def(graph_def, name='', return_elements=[
              BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
              RESIZED_INPUT_TENSOR_NAME]))
  return sess.graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor


def run_bottleneck_on_image(sess, image_data, image_data_tensor,
                            bottleneck_tensor):
  """Runs inference on an image to extract the 'bottleneck' summary layer.

  Feeds the raw JPEG data into the graph and returns the resulting
  2048-float bottleneck vector.
  """
  bottleneck_values = sess.run(
      bottleneck_tensor,
      {image_data_tensor: image_data})
  bottleneck_values = np.squeeze(bottleneck_values)
  return bottleneck_values


def maybe_download_and_extract():
  """Downloads and extracts the model tar file if it isn't already present."""
  dest_directory = FLAGS.model_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):

    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' %
                       (filename,
                        float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(DATA_URL,
                                             filepath,
                                             _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def ensure_dir_exists(dir_name):
  """Makes sure the folder exists on disk."""
  if not os.path.exists(dir_name):
    os.makedirs(dir_name)


def write_list_of_floats_to_file(list_of_floats, file_path):
  """Writes a given list of floats to a binary file."""
  s = struct.pack('d' * BOTTLENECK_TENSOR_SIZE, *list_of_floats)
  with open(file_path, 'wb') as f:
    f.write(s)


def read_list_of_floats_from_file(file_path):
  """Reads a list of floats from a given file."""
  with open(file_path, 'rb') as f:
    s = struct.unpack('d' * BOTTLENECK_TENSOR_SIZE, f.read())
    return list(s)


bottleneck_path_2_bottleneck_values = {}


def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           bottleneck_tensor):
  """Creates a single bottleneck file."""
  print('Creating bottleneck at ' + bottleneck_path)
  image_path = get_image_path(image_lists, label_name, index, image_dir,
                              category)
  if not gfile.Exists(image_path):
    tf.logging.fatal('File does not exist %s', image_path)
  image_data = gfile.FastGFile(image_path, 'rb').read()
  bottleneck_values = run_bottleneck_on_image(sess, image_data,
                                              jpeg_data_tensor,
                                              bottleneck_tensor)
  bottleneck_string = ','.join(str(x) for x in bottleneck_values)
  with open(bottleneck_path, 'w') as bottleneck_file:
    bottleneck_file.write(bottleneck_string)


def get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir,
                             category, bottleneck_dir, jpeg_data_tensor,
                             bottleneck_tensor):
  """Retrieves or calculates bottleneck values for an image.

  If a cached version of the bottleneck data exists on disk, return that,
  otherwise calculate the data and save it to disk for future use.
  """
  label_lists = image_lists[label_name]
  sub_dir = label_lists['dir']
  sub_dir_path = os.path.join(bottleneck_dir, sub_dir)
  ensure_dir_exists(sub_dir_path)
  bottleneck_path = get_bottleneck_path(image_lists, label_name, index,
                                        bottleneck_dir, category)
  if not os.path.exists(bottleneck_path):
    create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           bottleneck_tensor)
  with open(bottleneck_path, 'r') as bottleneck_file:
    bottleneck_string = bottleneck_file.read()
  did_hit_error = False
  try:
    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  except ValueError:
    print('Invalid float found, recreating bottleneck')
    did_hit_error = True
  if did_hit_error:
    create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
                           image_dir, category, sess, jpeg_data_tensor,
                           bottleneck_tensor)
    with open(bottleneck_path, 'r') as bottleneck_file:
      bottleneck_string = bottleneck_file.read()
    # Allow exceptions to propagate here, since they shouldn't happen after
    # a fresh creation.
    bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
  return bottleneck_values


def cache_bottlenecks(sess, image_lists, image_dir, bottleneck_dir,
                      jpeg_data_tensor, bottleneck_tensor):
  """Ensures all the training, testing, and validation bottlenecks are cached.

  Because we're likely to read the same image multiple times, it speeds
  things up a lot to calculate the bottleneck layer values once per image
  during preprocessing, and then just read the cached values during training.
  """
  how_many_bottlenecks = 0
  ensure_dir_exists(bottleneck_dir)
  for label_name, label_lists in image_lists.items():
    for category in ['training', 'testing', 'validation']:
      category_list = label_lists[category]
      for index, unused_base_name in enumerate(category_list):
        get_or_create_bottleneck(sess, image_lists, label_name, index,
                                 image_dir, category, bottleneck_dir,
                                 jpeg_data_tensor, bottleneck_tensor)

        how_many_bottlenecks += 1
        if how_many_bottlenecks % 100 == 0:
          print(str(how_many_bottlenecks) + ' bottleneck files created.')


def get_random_cached_bottlenecks(sess, image_lists, how_many, category,
                                  bottleneck_dir, image_dir, jpeg_data_tensor,
                                  bottleneck_tensor):
  """Retrieves bottleneck values for cached images.

  If no distortions are being applied, this function can retrieve the cached
  bottleneck values directly from disk. It picks a random sample of images
  from the specified category, or the full set if how_many is negative.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truths = []
  filenames = []
  if how_many >= 0:
    # Retrieve a random sample of bottlenecks.
    for unused_i in range(how_many):
      label_index = random.randrange(class_count)
      label_name = list(image_lists.keys())[label_index]
      image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
      image_name = get_image_path(image_lists, label_name, image_index,
                                  image_dir, category)
      bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
                                            image_index, image_dir, category,
                                            bottleneck_dir, jpeg_data_tensor,
                                            bottleneck_tensor)
      ground_truth = np.zeros(class_count, dtype=np.float32)
      ground_truth[label_index] = 1.0
      bottlenecks.append(bottleneck)
      ground_truths.append(ground_truth)
      filenames.append(image_name)
  else:
    # Retrieve all bottlenecks.
    for label_index, label_name in enumerate(image_lists.keys()):
      for image_index, image_name in enumerate(
          image_lists[label_name][category]):
        image_name = get_image_path(image_lists, label_name, image_index,
                                    image_dir, category)
        bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
                                              image_index, image_dir, category,
                                              bottleneck_dir, jpeg_data_tensor,
                                              bottleneck_tensor)
        ground_truth = np.zeros(class_count, dtype=np.float32)
        ground_truth[label_index] = 1.0
        bottlenecks.append(bottleneck)
        ground_truths.append(ground_truth)
        filenames.append(image_name)
  return bottlenecks, ground_truths, filenames


def get_random_distorted_bottlenecks(
    sess, image_lists, how_many, category, image_dir, input_jpeg_tensor,
    distorted_image, resized_input_tensor, bottleneck_tensor):
  """Retrieves bottleneck values for training images, after distortions.

  If we're training with distortions like crops, scales, or flips, we have
  to recalculate the full model for every image, so we can't use cached
  bottleneck values. Instead we run the distortion graph on each image and
  then feed the result through the full network.
  """
  class_count = len(image_lists.keys())
  bottlenecks = []
  ground_truths = []
  for unused_i in range(how_many):
    label_index = random.randrange(class_count)
    label_name = list(image_lists.keys())[label_index]
    image_index = random.randrange(MAX_NUM_IMAGES_PER_CLASS + 1)
    image_path = get_image_path(image_lists, label_name, image_index, image_dir,
                                category)
    if not gfile.Exists(image_path):
      tf.logging.fatal('File does not exist %s', image_path)
    jpeg_data = gfile.FastGFile(image_path, 'rb').read()
    # Note that we materialize the distorted_image_data as a numpy array
    # before running inference on it.
    distorted_image_data = sess.run(distorted_image,
                                    {input_jpeg_tensor: jpeg_data})
    bottleneck = run_bottleneck_on_image(sess, distorted_image_data,
                                         resized_input_tensor,
                                         bottleneck_tensor)
    ground_truth = np.zeros(class_count, dtype=np.float32)
    ground_truth[label_index] = 1.0
    bottlenecks.append(bottleneck)
    ground_truths.append(ground_truth)
  return bottlenecks, ground_truths


def should_distort_images(flip_left_right, random_crop, random_scale,
                          random_brightness):
  """Returns whether any distortions are enabled, from the input flags."""
  return (flip_left_right or (random_crop != 0) or (random_scale != 0) or
          (random_brightness != 0))


def add_input_distortions(flip_left_right, random_crop, random_scale,
                          random_brightness):
  """Creates the operations to apply the specified distortions.

  During training it can help to improve the results if we run the images
  through simple distortions like crops, scales, and flips. These reflect
  the kind of variations we expect in the real world, and so can help train
  the model to cope with natural data more effectively.

  Cropping places a bounding box at a random position in the full image,
  scaling varies the size of that box randomly around a centered position,
  flipping optionally mirrors the image horizontally, and brightness
  multiplies all the pixels by a randomly picked value.
  """
  jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
  decoded_image = tf.image.decode_jpeg(jpeg_data, channels=MODEL_INPUT_DEPTH)
  decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
  decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
  margin_scale = 1.0 + (random_crop / 100.0)
  resize_scale = 1.0 + (random_scale / 100.0)
  margin_scale_value = tf.constant(margin_scale)
  resize_scale_value = tf.random_uniform(tensor_shape.scalar(),
                                         minval=1.0,
                                         maxval=resize_scale)
  scale_value = tf.multiply(margin_scale_value, resize_scale_value)
  precrop_width = tf.multiply(scale_value, MODEL_INPUT_WIDTH)
  precrop_height = tf.multiply(scale_value, MODEL_INPUT_HEIGHT)
  precrop_shape = tf.stack([precrop_height, precrop_width])
  precrop_shape_as_int = tf.cast(precrop_shape, dtype=tf.int32)
  precropped_image = tf.image.resize_bilinear(decoded_image_4d,
                                              precrop_shape_as_int)
  precropped_image_3d = tf.squeeze(precropped_image, squeeze_dims=[0])
  cropped_image = tf.random_crop(precropped_image_3d,
                                 [MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH,
                                  MODEL_INPUT_DEPTH])
  if flip_left_right:
    flipped_image = tf.image.random_flip_left_right(cropped_image)
  else:
    flipped_image = cropped_image
  brightness_min = 1.0 - (random_brightness / 100.0)
  brightness_max = 1.0 + (random_brightness / 100.0)
  brightness_value = tf.random_uniform(tensor_shape.scalar(),
                                       minval=brightness_min,
                                       maxval=brightness_max)
  brightened_image = tf.multiply(flipped_image, brightness_value)
  distort_result = tf.expand_dims(brightened_image, 0, name='DistortResult')
  return jpeg_data, distort_result


def variable_summaries(var):
  """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
  with tf.name_scope('summaries'):
    mean = tf.reduce_mean(var)
    tf.summary.scalar('mean', mean)
    with tf.name_scope('stddev'):
      stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
    tf.summary.scalar('stddev', stddev)
    tf.summary.scalar('max', tf.reduce_max(var))
    tf.summary.scalar('min', tf.reduce_min(var))
    tf.summary.histogram('histogram', var)


def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor):
  """Adds a new softmax and fully-connected layer for training.

  We need to retrain the top layer to identify our new classes, so this
  function adds the right operations to the graph, along with some variables
  to hold the weights, and then sets up all the gradients for the backward
  pass.
  """
  with tf.name_scope('input'):
    bottleneck_input = tf.placeholder_with_default(
        bottleneck_tensor, shape=[None, BOTTLENECK_TENSOR_SIZE],
        name='BottleneckInputPlaceholder')

    ground_truth_input = tf.placeholder(tf.float32,
                                        [None, class_count],
                                        name='GroundTruthInput')

  layer_name = 'final_training_ops'
  with tf.name_scope(layer_name):
    with tf.name_scope('weights'):
      layer_weights = tf.Variable(
          tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, class_count],
                              stddev=0.001), name='final_weights')
      variable_summaries(layer_weights)
    with tf.name_scope('biases'):
      layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
      variable_summaries(layer_biases)
    with tf.name_scope('Wx_plus_b'):
      logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
      tf.summary.histogram('pre_activations', logits)

  final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
  tf.summary.histogram('activations', final_tensor)

  with tf.name_scope('cross_entropy'):
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        labels=ground_truth_input, logits=logits)
    with tf.name_scope('total'):
      cross_entropy_mean = tf.reduce_mean(cross_entropy)
  tf.summary.scalar('cross_entropy', cross_entropy_mean)

  with tf.name_scope('train'):
    train_step = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(
        cross_entropy_mean)

  return (train_step, cross_entropy_mean, bottleneck_input, ground_truth_input,
          final_tensor)


def add_evaluation_step(result_tensor, ground_truth_tensor):
  """Inserts the operations we need to evaluate the accuracy of our results."""
  with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
      prediction = tf.argmax(result_tensor, 1)
      correct_prediction = tf.equal(
          prediction, tf.argmax(ground_truth_tensor, 1))
    with tf.name_scope('accuracy'):
      evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.summary.scalar('accuracy', evaluation_step)
  return evaluation_step, prediction


def main(_):
  # Set up the directory we'll write summaries to for TensorBoard.
  if tf.gfile.Exists(FLAGS.summaries_dir):
    tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
  tf.gfile.MakeDirs(FLAGS.summaries_dir)

  # Set up the pre-trained graph.
  maybe_download_and_extract()
  graph, bottleneck_tensor, jpeg_data_tensor, resized_image_tensor = (
      create_inception_graph())

  # Look at the folder structure, and create lists of all the images.
  image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage,
                                   FLAGS.validation_percentage)
  class_count = len(image_lists.keys())
  if class_count == 0:
    print('No valid folders of images found at ' + FLAGS.image_dir)
    return -1
  if class_count == 1:
    print('Only one valid folder of images found at ' + FLAGS.image_dir +
          ' - multiple classes are needed for classification.')
    return -1

  # See if the command-line flags mean we're applying any distortions.
  do_distort_images = should_distort_images(
      FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
      FLAGS.random_brightness)
  sess = tf.Session()

  if do_distort_images:
    # We will be applying distortions, so set up the operations we'll need.
    distorted_jpeg_data_tensor, distorted_image_tensor = add_input_distortions(
        FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale,
        FLAGS.random_brightness)
  else:
    # Make sure we've calculated the 'bottleneck' image summaries and cached
    # them on disk.
    cache_bottlenecks(sess, image_lists, FLAGS.image_dir, FLAGS.bottleneck_dir,
                      jpeg_data_tensor, bottleneck_tensor)

  # Add the new layer that we'll be training.
  (train_step, cross_entropy, bottleneck_input, ground_truth_input,
   final_tensor) = add_final_training_ops(len(image_lists.keys()),
                                          FLAGS.final_tensor_name,
                                          bottleneck_tensor)

  # Create the operations we need to evaluate the accuracy of our new layer.
  evaluation_step, prediction = add_evaluation_step(
      final_tensor, ground_truth_input)

  # Merge all the summaries and write them out to the summaries_dir.
  merged = tf.summary.merge_all()
  train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
                                       sess.graph)
  validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/validation')

  # Set up all our weights to their initial default values.
  init = tf.global_variables_initializer()
  sess.run(init)

  # Run the training for as many cycles as requested on the command line.
  for i in range(FLAGS.how_many_training_steps):
    # Get a batch of input bottleneck values, either calculated fresh every
    # time with distortions applied, or from the cache stored on disk.
    if do_distort_images:
      train_bottlenecks, train_ground_truth = get_random_distorted_bottlenecks(
          sess, image_lists, FLAGS.train_batch_size, 'training',
          FLAGS.image_dir, distorted_jpeg_data_tensor,
          distorted_image_tensor, resized_image_tensor, bottleneck_tensor)
    else:
      train_bottlenecks, train_ground_truth, _ = get_random_cached_bottlenecks(
          sess, image_lists, FLAGS.train_batch_size, 'training',
          FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
          bottleneck_tensor)
    # Feed the bottlenecks and ground truth into the graph, and run a
    # training step. Capture training summaries with the `merged` op.
    train_summary, _ = sess.run([merged, train_step],
                                feed_dict={bottleneck_input: train_bottlenecks,
                                           ground_truth_input: train_ground_truth})
    train_writer.add_summary(train_summary, i)

    # Every so often, print out how well the graph is training.
    is_last_step = (i + 1 == FLAGS.how_many_training_steps)
    if (i % FLAGS.eval_step_interval) == 0 or is_last_step:
      train_accuracy, cross_entropy_value = sess.run(
          [evaluation_step, cross_entropy],
          feed_dict={bottleneck_input: train_bottlenecks,
                     ground_truth_input: train_ground_truth})
      print('%s: Step %d: Train accuracy = %.1f%%' % (datetime.now(), i,
                                                      train_accuracy * 100))
      print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i,
                                                 cross_entropy_value))
      validation_bottlenecks, validation_ground_truth, _ = (
          get_random_cached_bottlenecks(
              sess, image_lists, FLAGS.validation_batch_size, 'validation',
              FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor,
              bottleneck_tensor))
      # Run a validation step, capturing summaries with the `merged` op.
      validation_summary, validation_accuracy = sess.run(
          [merged, evaluation_step],
          feed_dict={bottleneck_input: validation_bottlenecks,
                     ground_truth_input: validation_ground_truth})
      validation_writer.add_summary(validation_summary, i)
      print('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %
            (datetime.now(), i, validation_accuracy * 100,
             len(validation_bottlenecks)))

  # We've completed all our training, so run a final test evaluation on
  # some new images we haven't used before.
  test_bottlenecks, test_ground_truth, test_filenames = (
      get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size,
                                    'testing', FLAGS.bottleneck_dir,
                                    FLAGS.image_dir, jpeg_data_tensor,
                                    bottleneck_tensor))
  test_accuracy, predictions = sess.run(
      [evaluation_step, prediction],
      feed_dict={bottleneck_input: test_bottlenecks,
                 ground_truth_input: test_ground_truth})
  print('Final test accuracy = %.1f%% (N=%d)' % (
      test_accuracy * 100, len(test_bottlenecks)))

  if FLAGS.print_misclassified_test_images:
    print('=== MISCLASSIFIED TEST IMAGES ===')
    for i, test_filename in enumerate(test_filenames):
      if predictions[i] != test_ground_truth[i].argmax():
        print('%70s %s' % (test_filename,
                           list(image_lists.keys())[predictions[i]]))

  # Write out the trained graph and labels with the weights stored as
  # constants.
  output_graph_def = graph_util.convert_variables_to_constants(
      sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
  with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
  with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
    f.write('\n'.join(image_lists.keys()) + '\n')


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--image_dir',
      type=str,
      default='',
      help='Path to folders of labeled images.'
  )
  parser.add_argument(
      '--output_graph',
      type=str,
      default='/tmp/output_graph.pb',
      help='Where to save the trained graph.'
  )
  parser.add_argument(
      '--output_labels',
      type=str,
      default='/tmp/output_labels.txt',
      help='Where to save the trained graph\'s labels.'
  )
  parser.add_argument(
      '--summaries_dir',
      type=str,
      default='/tmp/retrain_logs',
      help='Where to save summary logs for TensorBoard.'
  )
  parser.add_argument(
      '--how_many_training_steps',
      type=int,
      default=4000,
      help='How many training steps to run before ending.'
  )
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='How large a learning rate to use when training.'
  )
  parser.add_argument(
      '--testing_percentage',
      type=int,
      default=10,
      help='What percentage of images to use as a test set.'
  )
  parser.add_argument(
      '--validation_percentage',
      type=int,
      default=10,
      help='What percentage of images to use as a validation set.'
  )
  parser.add_argument(
      '--eval_step_interval',
      type=int,
      default=10,
      help='How often to evaluate the training results.'
  )
  parser.add_argument(
      '--train_batch_size',
      type=int,
      default=100,
      help='How many images to train on at a time.'
  )
  parser.add_argument(
      '--test_batch_size',
      type=int,
      default=-1,
      help="""\
      How many images to test on. This test set is only used once, to evaluate
      the final accuracy of the model after training completes.
      A value of -1 causes the entire test set to be used, which leads to more
      stable results across runs.\
      """
  )
  parser.add_argument(
      '--validation_batch_size',
      type=int,
      default=100,
      help="""\
      How many images to use in an evaluation batch. This validation set is
      used much more often than the test set, and is an early indicator of how
      accurate the model is during training.
      A value of -1 causes the entire validation set to be used, which leads to
      more stable results across training iterations, but may be slower on
      large training sets.\
      """
  )
  parser.add_argument(
      '--print_misclassified_test_images',
      default=False,
      help="""\
      Whether to print out a list of all misclassified test images.\
      """,
      action='store_true'
  )
  parser.add_argument(
      '--model_dir',
      type=str,
      default='/tmp/imagenet',
      help="""\
      Path to classify_image_graph_def.pb,
      imagenet_synset_to_human_label_map.txt, and
      imagenet_2012_challenge_label_map_proto.pbtxt.\
      """
  )
  parser.add_argument(
      '--bottleneck_dir',
      type=str,
      default='/tmp/bottleneck',
      help='Path to cache bottleneck layer values as files.'
  )
  parser.add_argument(
      '--final_tensor_name',
      type=str,
      default='final_result',
      help="""\
      The name of the output classification layer in the retrained graph.\
      """
  )
  parser.add_argument(
      '--flip_left_right',
      default=False,
      help="""\
      Whether to randomly flip half of the training images horizontally.\
      """,
      action='store_true'
  )
  parser.add_argument(
      '--random_crop',
      type=int,
      default=0,
      help="""\
      A percentage determining how much of a margin to randomly crop off the
      training images.\
      """
  )
  parser.add_argument(
      '--random_scale',
      type=int,
      default=0,
      help="""\
      A percentage determining how much to randomly scale up the size of the
      training images by.\
      """
  )
  parser.add_argument(
      '--random_brightness',
      type=int,
      default=0,
      help="""\
      A percentage determining how much to randomly multiply the training image
      input pixels up or down by.\
      """
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Parameter descriptions:
--image_dir  Path to folders of labeled images.
--output_graph  Where to save the trained graph.
--output_labels  Where to save the trained graph's labels.
--summaries_dir  Where to save summary logs for TensorBoard.
--how_many_training_steps  How many training steps to run before ending.
--learning_rate  How large a learning rate to use when training.
--testing_percentage  What percentage of images to use as a test set.
--validation_percentage  What percentage of images to use as a validation set (see the sketch after this list for how the split is made).
--eval_step_interval  How often to evaluate the training results.
--train_batch_size  How many images to train on at a time.
--test_batch_size  How many images to test on. This test set is only used once, to evaluate the final accuracy of the model after training completes. A value of -1 causes the entire test set to be used, which gives more stable results across runs.
--validation_batch_size  How many images to use in an evaluation batch. The validation set is used much more often than the test set, and is an early indicator of how accurate the model is during training. A value of -1 causes the entire validation set to be used, which gives more stable results across training iterations, but may be slower on large training sets.
--print_misclassified_test_images  Whether to print out a list of all misclassified test images.
--model_dir  Path to classify_image_graph_def.pb, imagenet_synset_to_human_label_map.txt, and imagenet_2012_challenge_label_map_proto.pbtxt.
--bottleneck_dir  Path to cache bottleneck layer values as files.
--final_tensor_name  The name of the output classification layer in the retrained graph.
--flip_left_right  Whether to randomly flip half of the training images horizontally.
--random_crop  A percentage determining how much of a margin to randomly crop off the training images.
--random_scale  A percentage determining how much to randomly scale up the size of the training images by.
--random_brightness  A percentage determining how much to randomly multiply the training image input pixels up or down by.
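The deterministic split behind --testing_percentage and --validation_percentage is worth spelling out: retrain.py hashes each file name (ignoring anything after `_nohash_`) and maps the hash to a stable value in [0, 100), so a given image always lands in the same set even as new images are added later. Below is a minimal standalone sketch of that logic; the function name split_category and the example path are my own, not part of retrain.py:

import hashlib
import re

MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1

def split_category(file_name, testing_percentage, validation_percentage):
    # Ignore anything after '_nohash_' so variants of one photo stay together.
    hash_name = re.sub(r'_nohash_.*$', '', file_name)
    # Hash the name and map it to a stable percentage in [0, 100).
    hashed = hashlib.sha1(hash_name.encode('utf-8')).hexdigest()
    percentage_hash = ((int(hashed, 16) % (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                       (100.0 / MAX_NUM_IMAGES_PER_CLASS))
    if percentage_hash < validation_percentage:
        return 'validation'
    elif percentage_hash < testing_percentage + validation_percentage:
        return 'testing'
    return 'training'

# The same file always maps to the same set across runs:
print(split_category('girl_types/cute/001.jpg', 10, 10))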
First, download the TensorFlow source code:
$ git clone https://github.com/tensorflow/tensorflow
$ git checkout r0.11    # check out the branch matching your installed TensorFlow version
Before retraining our own image classifier, let's first test Google's pre-trained Inception model:
$ cd tensorflow/models/image/imagenet/
$ python classify_image.py --image_file ~/Desktop/bigcat.jpg
# automatically downloads the model file (over 100 MB)
# see the source file for an explanation of the parameters
Now use the image_retraining example from examples.
Training:
$ python tensorflow/tensorflow/examples/image_retraining/retrain.py \
    --bottleneck_dir bottleneck \
    --how_many_training_steps 4000 \
    --model_dir model \
    --output_graph output_graph.pb \
    --output_labels output_labels.txt \
    --image_dir girl_types/
The parameters are explained in the retrain.py source above.
Training took roughly half an hour.
This produces the trained model file (output_graph.pb) and the labels file (output_labels.txt).
Using the trained model:
import tensorflow as tf
import sys

# Command-line argument: the path of the image to classify
image_file = sys.argv[1]
#print(image_file)

# Read the image
image = tf.gfile.FastGFile(image_file, 'rb').read()

# Load the classification labels
labels = []
for label in tf.gfile.GFile("output_labels.txt"):
    labels.append(label.rstrip())

# Load the graph
with tf.gfile.FastGFile("output_graph.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    predict = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image})
    # Sort by classification probability, highest first
    top = predict[0].argsort()[-len(predict[0]):][::-1]
    for index in top:
        human_string = labels[index]
        score = predict[0][index]
        print(human_string, score)
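Assuming the script above is saved as predict.py (an illustrative name) in the same directory as output_graph.pb and output_labels.txt, it takes the image path as its only argument:

$ python predict.py ~/Desktop/bigcat.jpg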
Execution result: each label is printed with its softmax score, with the most likely class first.
References:
http://blog.csdn.net/daydayup_668819/article/details/68060483
http://blog.topspeedsnail.com/archives/10685