[tensorflow] Retraining an image model on new categories: flower classification (code notes)

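This article walks through how to use TensorFlow to retrain an image model and build a classifier for new flower categories. The dataset is first split into training, testing, and validation sets; a pretrained model is then loaded, and the feature vector of each image is computed and cached. Next, a fully connected layer is built and the new model is trained, and finally the prediction model is saved. The prediction part covers loading the prediction model, preprocessing the input images, and predicting the classes of multiple images.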

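Contents

I. Retraining a model on new categories

1. Load the images and split the dataset into training, testing, and validation sets (80%, 10%, 10% by default)

2. Load a module from TensorFlow Hub, fetch its information, and create the graph

3. Compute the bottlenecks (feature vectors) of all images and cache them

4. Train the new-category model

5. Save the new-category prediction model

II. Model prediction

1. Load the prediction model

2. Load the images to predict (decode and crop them)

3. Predict the classes of multiple images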


Original source: https://www.tensorflow.org/hub/tutorials/image_retraining

I. Retraining a model on new categories

Setup: third-party packages

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
from datetime import datetime
import hashlib
import os.path
import random
import re
import sys

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

tf.logging.set_verbosity(tf.logging.INFO)

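Setup: file paths and parameters

# FLAGS is a plain namespace here; the assignments below stand in for command-line parsing.
FLAGS = argparse.Namespace()

FLAGS.image_dir = 'E:\\DataMining\\tensorflow\\google-ImageClassification\\flower_photos'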
'''
Path to folders of labeled images.
'''

FLAGS.output_graph = 'E:\\DataMining\\tensorflow\\google-ImageClassification\\flower\\graph\\output_graph.pb'
'''
Where to save the trained graph.
'''

FLAGS.intermediate_output_graphs_dir = 'E:\\DataMining\\tensorflow\\google-ImageClassification\\flower\\intermediate_graph\\'
'''
Where to save the intermediate graphs.
'''

FLAGS.intermediate_store_frequency = 0
"""\
     How many steps to store intermediate graph. If "0" then will not
     store.\
  """
FLAGS.output_labels = 'E:\\DataMining\\tensorflow\\google-ImageClassification\\flower\\labels\\output_labels.txt'
'''
Where to save the trained graph\'s labels.
'''

FLAGS.summaries_dir = 'E:\\DataMining\\tensorflow\\google-ImageClassification\\flower\\retrain_logs'
'''
Where to save summary logs for TensorBoard.
'''
FLAGS.how_many_training_steps = 4000
'''How many training steps to run before ending.'''

FLAGS.learning_rate = 0.01
'''How large a learning rate to use when training.'''

FLAGS.testing_percentage = 10
'''What percentage of images to use as a test set.'''

FLAGS.validation_percentage = 10
'''What percentage of images to use as a validation set.'''

FLAGS.eval_step_interval = 10
'''How often to evaluate the training results.'''

FLAGS.train_batch_size = 100
'''How many images to train on at a time.'''

FLAGS.test_batch_size = -1
"""\
  How many images to test on. This test set is only used once, to evaluate
  the final accuracy of the model after training completes.
  A value of -1 causes the entire test set to be used, which leads to more
  stable results across runs.\
  """

FLAGS.validation_batch_size = 100
"""\
  How many images to use in an evaluation batch. This validation set is
  used much more often than the test set, and is an early indicator of how
  accurate the model is during training.
  A value of -1 causes the entire validation set to be used, which leads to
  more stable results across training iterations, but may be slower on large
  training sets.\
  """

FLAGS.print_misclassified_test_images = False
"""\
  Whether to print out a list of all misclassified test images.\
  """

FLAGS.bottleneck_dir = 'E:\\DataMining\\tensorflow\\google-ImageClassification\\flower\\bottleneck'
'Path to cache bottleneck layer values as files.'

FLAGS.final_tensor_name = 'final_result'
"""\
  The name of the output classification layer in the retrained graph.\
  """

FLAGS.flip_left_right = False
"""\
  Whether to randomly flip half of the training images horizontally.\
  """

FLAGS.random_crop = 0
"""\
  A percentage determining how much of a margin to randomly crop off the
  training images.\
  """

FLAGS.random_scale = 0
"""\
  A percentage determining how much to randomly scale up the size of the
  training images by.\
  """

FLAGS.random_brightness = 0
"""\
  A percentage determining how much to randomly multiply the training image
  input pixels up or down by.\
  """

FLAGS.tfhub_module = 'https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1'
""" Which TensorFlow Hub module to use. For more options, search https://tfhub.dev for image feature vector modules. """

FLAGS.saved_model_dir = 'E:\\DataMining\\tensorflow\\google-ImageClassification\\flower\\exportedGraph'
""" Where to save the exported graph."""

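Setup: global variables

# Upper bound used by create_image_lists when hashing file names into splits
# (the same value the original retrain.py defines).
MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # ~134M

# The location where variable checkpoints will be stored.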
CHECKPOINT_NAME = 'E:\\DataMining\\tensorflow\\google-ImageClassification\\flower\\retrain_checkpoint\\'

# A module is understood as instrumented for quantization with TF-Lite
# if it contains any of these ops.
FAKE_QUANT_OPS = ('FakeQuantWithMinMaxVars','FakeQuantWithMinMaxVarsPerChannel')

1. Load the images and split the dataset into training, testing, and validation sets (80%, 10%, 10% by default)

def create_image_lists(image_dir, testing_percentage, validation_percentage):

    """Builds a list of training images from the file system.

    Analyzes the sub folders in the image directory, splits them into stable
    training, testing, and validation sets, and returns a data structure
    describing the lists of images for each label and their paths.

    Args:
    image_dir: String path to a folder containing subfolders of images.
    testing_percentage: Integer percentage of the images to reserve for tests.
    validation_percentage: Integer percentage of images reserved for validation.

    Returns:
    An OrderedDict containing an entry for each label subfolder, with images
    split into training, testing, and validation sets within each label.
    The order of items defines the class indices.
    """

    if not tf.gfile.Exists(image_dir):
        tf.logging.error("Image directory '" + image_dir + "' not found.")
        return None

    result = collections.OrderedDict()
    sub_dirs = sorted(x[0] for x in tf.gfile.Walk(image_dir))
    # The root directory comes first, so skip it.
    is_root_dir = True
    for sub_dir in sub_dirs:
        if is_root_dir:
            is_root_dir = False
            continue

        extensions = sorted(set(os.path.normcase(ext) for ext in ['JPEG', 'JPG', 'jpeg', 'jpg', 'png']))
        file_list = []
        dir_name = os.path.basename(sub_dir)
        if dir_name == image_dir:
            continue
        #tf.logging.info("Looking for images in '" + dir_name + "'")
        for extension in extensions:
            file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
            file_list.extend(tf.gfile.Glob(file_glob))
        if not file_list:
            tf.logging.warning('No files found')
            continue
        if len(file_list) < 20:
            tf.logging.warning( 'WARNING: Folder has less than 20 images, which may cause issues.')
        elif len(file_list) > MAX_NUM_IMAGES_PER_CLASS:
            tf.logging.warning('WARNING: Folder {} has more than {} images. Some images will never be selected.'.format(dir_name, MAX_NUM_IMAGES_PER_CLASS))
        label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower())
        
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)
            hash_name = re.sub(r'_nohash_.*$', '', file_name)
            hash_name_hashed = hashlib.sha1(tf.compat.as_bytes(hash_name)).hexdigest()
            percentage_hash = ((int(hash_name_hashed, 16) %(MAX_NUM_IMAGES_PER_CLASS + 1)) *(100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentage_hash < validation_percentage:
                validation_images.append(base_name)
            elif percentage_hash < (testing_percentage + validation_percentage):
                testing_images.append(base_name)
            else:
                training_images.append(base_name)
                
        result[label_name] = {
            'dir': dir_name,
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images,
        }
    return result
def main(_):
    # Build the per-label lists of training, testing, and validation images.
    image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage, FLAGS.validation_percentage)
    class_count = len(image_lists.keys())

The image_lists data structure looks like this:

OrderedDict([
    ('daisy',
     {'dir': 'daisy',
      'testing': ['100080576_f52e8ee070_n.jpg', '10172379554_b296050f82_n.jpg', ...],
      'training': ['10140303196_b88d3d6cec.jpg', ...],
      'validation': ['102841525_bd6628ae3c.jpg', ...]}),
    ('dandelion',
     {'dir': 'dandelion',
      'testing': ['10294487385_92a0676c7d_m.jpg', ...],
      'training': ['10140303196_b88d3d6cec.jpg', ...],
      'validation': ['102841525_bd6628ae3c.jpg', ...]}),
    ...
])
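
The split performed by create_image_lists is decided by hashing each file name, so a given image lands in the same set on every run even when more images are added later. The following is a minimal sketch of that idea without TensorFlow. The helper name assign_split and the sample file names are made up for illustration, and plain str.encode is assumed as a stand-in for tf.compat.as_bytes.

```python
import collections
import hashlib
import re

# Same constant as in the retraining script.
MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1


def assign_split(file_name, testing_percentage=10, validation_percentage=10):
    """Returns 'training', 'testing' or 'validation' for one file name.

    assign_split is a hypothetical helper, not part of the original script.
    """
    # Ignore anything after '_nohash_' so variants of one photo stay together.
    hash_name = re.sub(r'_nohash_.*$', '', file_name)
    digest = hashlib.sha1(hash_name.encode('utf-8')).hexdigest()
    # Map the hash to a value between 0 and 100.
    percentage_hash = ((int(digest, 16) % (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                       (100.0 / MAX_NUM_IMAGES_PER_CLASS))
    if percentage_hash < validation_percentage:
        return 'validation'
    if percentage_hash < testing_percentage + validation_percentage:
        return 'testing'
    return 'training'


# Made-up file names, only to show the rough proportions of the split.
names = ['flower_%04d.jpg' % i for i in range(1000)]
counts = collections.Counter(assign_split(name) for name in names)
print(counts)
```

Because the decision depends only on the file name, the proportions are approximate rather than exact, which is also why very small folders can end up with an empty testing or validation set.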

2. Load a module from TensorFlow Hub, fetch its information, and create the graph

# Load a module from TensorFlow Hub and fetch its information
def create_module_graph(module_spec):
    """Creates a graph and loads Hub Module into it.

    Args:
    module_spec: the hub.ModuleSpec for the image module being used.

    Returns:
    graph: the tf.Graph that was created.
    bottleneck_tensor: the bottleneck values output by the module.
    resized_input_tensor: the input images, resized as expected by the module.
    wants_quantization: a boolean, whether the module has been instrumented
      with fake quantization ops.
    """
    height, width = hub.get_expected_image_size(module_spec)
    with tf.Graph().as_default() as graph:
        resized_input_tensor = tf.placeholder(tf.float32, [None, height, width, 3])
        m = hub.Module(module_spec)
        bottleneck_tensor = m(resized_input_tensor)
        wants_quantization = any(node.op in FAKE_QUANT_OPS for node in graph.as_graph_def().node)
        
    return graph, bottleneck_tensor, resized_input_tensor, wants_quantization
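
The wants_quantization flag returned above is just a scan of the graph's nodes for the op types listed in FAKE_QUANT_OPS. A small sketch of that check follows, using a hypothetical Node tuple in place of real GraphDef nodes so that it runs without TensorFlow; the node names are invented.

```python
import collections

FAKE_QUANT_OPS = ('FakeQuantWithMinMaxVars',
                  'FakeQuantWithMinMaxVarsPerChannel')

# Hypothetical stand-in for a GraphDef node; only the fields used here.
Node = collections.namedtuple('Node', ['name', 'op'])


def has_fake_quant_ops(nodes):
    """True if any node uses one of the fake-quantization op types."""
    return any(node.op in FAKE_QUANT_OPS for node in nodes)


plain_nodes = [Node('input', 'Placeholder'), Node('conv', 'Conv2D')]
quant_nodes = plain_nodes + [Node('fq', 'FakeQuantWithMinMaxVars')]

print(has_fake_quant_ops(plain_nodes))  # False
print(has_fake_quant_ops(quant_nodes))  # True
```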



# Decode the image and resize it
def add_jpeg_decoding(module_spec):
    """Adds operations that perform JPEG decoding and resiz