[TensorFlow Object Detection API + Microsoft NNI] Automatic hyperparameter tuning for image classification, to push model accuracy further

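1. Background & Goal

When you develop and train an image classification model (for example, MobileNetV2) with the TensorFlow Object Detection API, tuning hyperparameters by hand rarely pushes accuracy to its limit. Microsoft's NNI automatic tuning tool can search the hyperparameters for you and improve accuracy further.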

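2. Method

  1. For developing and training an image classification model with the TensorFlow Object Detection API, see this blog post: [tensorflow-slim] Train your own image classification dataset with tensorflow-slim + freeze to a pb file + predict (written for scene classification, with a detailed step-by-step walkthrough!)
  2. For how to use Microsoft NNI, the official website is sufficient: Neural Network Intelligence
  3. For the implementation, put the code below into a single .py file named nni_train_eval_image_classifier.py, place it in the models-master/research/slim directory of the official TensorFlow Object Detection API code repository, and then launch it the way the NNI documentation describes.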
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic training script that trains a model using a given dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import nni
import argparse
import logging
import time

import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib import slim as contrib_slim
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import timeline
from tensorflow.python.lib.io import file_io

from datasets import dataset_factory
from deployment import model_deploy
from nets import nets_factory
from preprocessing import preprocessing_factory

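slim = contrib_slim
logger = logging.getLogger('mobilenetv2_AutoML')

# NOTE: the helper functions below read hyperparameters from the module-level
# `params` dict created in the `__main__` block at the bottom of this file,
# so they only work when this file is run as a script.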

def _configure_learning_rate(num_samples_per_epoch, global_step):
  """Configures the learning rate.

  Args:
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if params['learning_rate_decay_type'] is not recognized.
  """
  # Note: when num_clones is > 1, this will actually have each clone to go
  # over each epoch params['num_epochs_per_decay'] times. This is different
  # behavior from sync replicas and is expected to produce different results.
  steps_per_epoch = num_samples_per_epoch / params['batch_size']

  decay_steps = int(steps_per_epoch * params['num_epochs_per_decay'])

  if params['learning_rate_decay_type'] == 'exponential':
    learning_rate = tf.train.exponential_decay(
        params['learning_rate'],
        global_step,
        decay_steps,
        params['learning_rate_decay_factor'],
        staircase=True,
        name='exponential_decay_learning_rate')
  else:
    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                     params['learning_rate_decay_type'])
  return learning_rate


def _configure_optimizer(learning_rate):
  """Configures the optimizer used for training.

  Args:
    learning_rate: A scalar or `Tensor` learning rate.

  Returns:
    An instance of an optimizer.

  Raises:
    ValueError: if params['optimizer'] is not recognized.
  """
  if params['optimizer'] == 'adadelta':
    optimizer = tf.train.AdadeltaOptimizer(
        learning_rate,
        rho=0.95,
        epsilon=1.0)
  elif params['optimizer'] == 'adagrad':
    optimizer = tf.train.AdagradOptimizer(
        learning_rate,
        initial_accumulator_value=0.1)
  elif params['optimizer'] == 'adam':
    optimizer = tf.train.AdamOptimizer(
        learning_rate,
        beta1=0.9,
        beta2=0.999,
        epsilon=1.0)
  elif params['optimizer'] == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=0.9,
        momentum=0.9,
        epsilon=1.0)
  elif params['optimizer'] == 'sgd':
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  else:
    raise ValueError('Optimizer [%s] was not recognized' % params['optimizer'])
  return optimizer


def _get_init_fn():
  """Returns a function run by the chief worker to warm-start the training.

  Note that the init_fn is only run when initializing the model during the very
  first global step.

  Returns:
    An init function run by the supervisor.
  """
  if params['checkpoint_path'] is None:
    return None

  # Warn the user if a checkpoint exists in the train_dir. Then we'll be
  # ignoring the checkpoint anyway.
  if tf.train.latest_checkpoint(os.path.join(params['train_dir'], nni.get_trial_id())):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % os.path.join(params['train_dir'], nni.get_trial_id()))
    return None

  exclusions = []


  # TODO(sguada) variables.filter_variables()
  variables_to_restore = []
  for var in slim.get_model_variables():
    for exclusion in exclusions:
      if var.op.name.startswith(exclusion):
        break
    else:
      variables_to_restore.append(var)

  if tf.gfile.IsDirectory(params['checkpoint_path']):
    checkpoint_path = tf.train.latest_checkpoint(params['checkpoint_path'])
  else:
    checkpoint_path = params['checkpoint_path']

  tf.logging.info('Fine-tuning from %s' % checkpoint_path)

  return slim.assign_from_checkpoint_fn(
      checkpoint_path,
      variables_to_restore,
      ignore_missing_vars=False)


def _get_variables_to_train():
  """Returns a list of variables to train.

  Returns:
    A list of variables to train by the optimizer.
  """
  if params['trainable_scopes'] is None:
    return tf.trainable_variables()
  else:
    scopes = [scope.strip() for scope in params['trainable_scopes'].split(',')]

  variables_to_train = []
  for scope in scopes:
    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
    variables_to_train.extend(variables)
  return variables_to_train


def train_step_for_print(sess, train_op, global_step, train_step_kwargs):
  """Function that takes a gradient step and specifies whether to stop.
  Args:
    sess: The current session.
    train_op: An `Operation` that evaluates the gradients and returns the total
      loss.
    global_step: A `Tensor` representing the global training step.
    train_step_kwargs: A dictionary of keyword arguments.
  Returns:
    The total loss and a boolean indicating whether or not to stop training.
  Raises:
    ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not.
  """
  start_time = time.time()

  trace_run_options = None
  run_metadata = None
  if 'should_trace' in train_step_kwargs:
    if 'logdir' not in train_step_kwargs:
      raise ValueError('logdir must be present in train_step_kwargs when '
                       'should_trace is present')
    if sess.run(train_step_kwargs['should_trace']):
      trace_run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()

  total_loss, np_global_step = sess.run([train_op, global_step],
                                        options=trace_run_options,
                                        run_metadata=run_metadata)
  time_elapsed = time.time() - start_time

  if run_metadata is not None:
    tl = timeline.Timeline(run_metadata.step_stats)
    trace = tl.generate_chrome_trace_format()
    trace_filename = os.path.join(train_step_kwargs['logdir'],
                                  'tf_trace-%d.json' % np_global_step)
    tf.logging.info('Writing trace to %s', trace_filename)
    file_io.write_string_to_file(trace_filename, trace)
    if 'summary_writer' in train_step_kwargs:
      train_step_kwargs['summary_writer'].add_run_metadata(
          run_metadata, 'run_metadata-%d' % np_global_step)

  if 'should_log' in train_step_kwargs:
    if sess.run(train_step_kwargs['should_log']):
      tf.logging.info('global step %d: loss = %.4f (%.3f sec/step)',
                      np_global_step, total_loss, time_elapsed)

  if 'should_stop' in train_step_kwargs:
    should_stop = sess.run(train_step_kwargs['should_stop'])
  else:
    should_stop = False
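  # Report this step's training loss to NNI as an intermediate result.
  # The NumPy scalar is cast to a plain Python float before reporting.
  nni.report_intermediate_result(float(total_loss))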
  return total_loss, should_stop


def main(params):
  if not params['dataset_dir']:
    raise ValueError('You must supply the dataset directory with --dataset_dir')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    #######################
    # Config model_deploy #
    #######################
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=2,
        clone_on_cpu=False,
        replica_id=0,
        num_replicas=1,
        num_ps_tasks=0)

    # Create global_step
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    ######################
    # Select the dataset #
    ######################
    dataset = dataset_factory.get_dataset(
        params['dataset_name'], params['dataset_split_name'], params['dataset_dir'])

    ######################
    # Select the network #
    ######################
    network_fn = nets_factory.get_network_fn(
        params['model_name'],
        num_classes=dataset.num_classes,
        weight_decay=0.00004,
        is_training=True)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = params['model_name']
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=True,
        use_grayscale=False)

    ##############################################################
    # Create a dataset provider that loads data from the dataset #
    ##############################################################
    with tf.device(deploy_config.inputs_device()):
      provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=params['num_readers'],
          common_queue_capacity=20 * params['batch_size'],
          common_queue_min=10 * params['batch_size'])
      [image, label] = provider.get(['image', 'label'])

      train_image_size = params['train_image_size'] or network_fn.default_image_size

      image = image_preprocessing_fn(image, train_image_size, train_image_size)

      images, labels = tf.train.batch(
          [image, label],
          batch_size=params['batch_size'],
          num_threads=params['num_preprocessing_threads'],
          capacity=5 * params['batch_size'])
      labels = slim.one_hot_encoding(labels, dataset.num_classes)
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [images, labels], capacity=2 * deploy_config.num_clones)

    ####################
    # Define the model #
    ####################
    def clone_fn(batch_queue):
      """Allows data parallelism by creating multiple clones of network_fn."""
      images, labels = batch_queue.dequeue()
      logits, end_points = network_fn(images)

      #############################
      # Specify the loss function #
      #############################
      if 'AuxLogits' in end_points:
        slim.losses.softmax_cross_entropy(
            end_points['AuxLogits'], labels,
            label_smoothing=0.0, weights=0.4,
            scope='aux_loss')
      slim.losses.softmax_cross_entropy(
          logits, labels, label_smoothing=0.0, weights=1.0)
      return end_points

    # Gather initial summaries.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
    first_clone_scope = deploy_config.clone_scope(0)
    # Gather update_ops from the first clone. These contain, for example,
    # the updates for the batch_norm variables created by network_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    # Add summaries for end_points.
    end_points = clones[0].outputs
    for end_point in end_points:
      x = end_points[end_point]
      summaries.add(tf.summary.histogram('activations/' + end_point, x))
      summaries.add(tf.summary.scalar('sparsity/' + end_point,
                                      tf.nn.zero_fraction(x)))



    # Add summaries for losses.
    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
      summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))

    # Add summaries for variables.
    for variable in slim.get_model_variables():
      summaries.add(tf.summary.histogram(variable.op.name, variable))


    #########################################
    # Configure the optimization procedure. #
    #########################################
    with tf.device(deploy_config.optimizer_device()):
      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
      optimizer = _configure_optimizer(learning_rate)
      summaries.add(tf.summary.scalar('learning_rate', learning_rate))


    # Variables to train.
    variables_to_train = _get_variables_to_train()

    # Compute the total loss and the gradients across all clones.
    total_loss, clones_gradients = model_deploy.optimize_clones(
        clones,
        optimizer,
        var_list=variables_to_train)
    # Add total_loss to summary.
    summaries.add(tf.summary.scalar('total_loss', total_loss))

    # Create gradient updates.
    grad_updates = optimizer.apply_gradients(clones_gradients,
                                             global_step=global_step)
    update_ops.append(grad_updates)

    update_op = tf.group(*update_ops)
    with tf.control_dependencies([update_op]):
      train_tensor = tf.identity(total_loss, name='train_op')

    # Add the summaries from the first clone. These contain the summaries
    # created by model_fn and either optimize_clones() or _gather_clone_loss().
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                       first_clone_scope))

    # Merge all summaries together.
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    ###########################
    # Kicks off the training. #
    ###########################
    final_loss = slim.learning.train(
        train_tensor,
        train_step_fn=train_step_for_print,
        logdir=os.path.join(params['train_dir'], nni.get_trial_id()),
        master=params['master'],
        is_chief=True,
        init_fn=_get_init_fn(),
        summary_op=summary_op,
        number_of_steps=params['max_number_of_steps'],
        log_every_n_steps=params['log_every_n_steps'],
        save_summaries_secs=params['save_summaries_secs'],
        save_interval_secs=params['save_interval_secs'],
        sync_optimizer=None)

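    # The final metric is the last training loss (lower is better), so the
    # NNI tuner should be configured to minimize it.
    nni.report_final_result(float(final_loss))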




def get_params():
    ''' Get parameters from command line '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--master", type=str, default='', help='The address of the TensorFlow master to use.')
    parser.add_argument("--train_dir", type=str, default='/***/workspace/mobilenetssd/models-master/research/slim/sentiment_cnn/model/training2')
    parser.add_argument("--num_preprocessing_threads", type=int, default=4)
    parser.add_argument("--log_every_n_steps", type=int, default=10)
    parser.add_argument("--save_summaries_secs", type=int, default=600)
    parser.add_argument("--save_interval_secs", type=int, default=600)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--batch_num", type=int, default=2000)

    parser.add_argument("--dataset_name", type=str, default="sendata3_train")
    parser.add_argument("--dataset_split_name", type=str, default='train')
    parser.add_argument("--dataset_dir", type=str, default='/***/workspace/mobilenetssd/models-master/research/slim/sentiment_cnn/dataset/data_3')
    parser.add_argument("--model_name", type=str, default='mobilenet_v2')

    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--train_image_size", type=int, default=80)
    parser.add_argument("--max_number_of_steps", type=int, default=60000)
    parser.add_argument("--trainable_scopes", type=str, default=None)

    parser.add_argument("--optimizer", type=str, default='adam')
    parser.add_argument("--learning_rate_decay_type", type=str, default='exponential')
    parser.add_argument("--learning_rate_decay_factor", type=float, default=0.8)

    parser.add_argument("--checkpoint_path", type=str, default=None)
    parser.add_argument("--num_readers", type=int, default=4)

    parser.add_argument("--num_epochs_per_decay", type=float, default=3.0)

    args, _ = parser.parse_known_args()
    return args

if __name__ == '__main__':
    try:
        # Get the hyperparameters for this trial from the NNI tuner.
        tuner_params = nni.get_next_parameter()
        logger.debug(tuner_params)
        params = vars(get_params())
        params.update(tuner_params)
        main(params)
    except Exception as exception:
        logger.exception(exception)
        raise
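To make the schedule built by `_configure_learning_rate` concrete, here is a minimal pure-Python sketch of what `tf.train.exponential_decay(..., staircase=True)` computes. It uses the script's defaults (batch size 16, 3 epochs per decay, decay factor 0.8, initial rate 1e-4). The dataset size of 8000 samples and the helper name `staircase_lr` are assumptions made only for illustration.

```python
def staircase_lr(step, num_samples, batch_size=16, num_epochs_per_decay=3.0,
                 learning_rate=1e-4, decay_factor=0.8):
    """Learning rate at `step` under staircase exponential decay."""
    steps_per_epoch = num_samples / batch_size
    decay_steps = int(steps_per_epoch * num_epochs_per_decay)
    # staircase=True: the exponent is the integer number of completed
    # decay periods, so the rate drops in discrete jumps.
    return learning_rate * decay_factor ** (step // decay_steps)


NUM_SAMPLES = 8000  # hypothetical dataset size, not taken from the post

for step in (0, 1499, 1500, 3000):
    print(step, staircase_lr(step, NUM_SAMPLES))
```

With these numbers one decay period is 1500 steps, so the rate stays at its initial value until step 1499 and is multiplied by 0.8 at step 1500 and again at step 3000. Note that when NNI samples a different `batch_size`, the length of a decay period changes with it.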

[Note]: adapt the following part of the code to your own hardware, model, and data:

def get_params():
    ''' Get parameters from command line '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--master", type=str, default='', help='The address of the TensorFlow master to use.')
    parser.add_argument("--train_dir", type=str, default='/***/workspace/mobilenetssd/models-master/research/slim/sentiment_cnn/model/training2')
    parser.add_argument("--num_preprocessing_threads", type=int, default=4)
    parser.add_argument("--log_every_n_steps", type=int, default=10)
    parser.add_argument("--save_summaries_secs", type=int, default=600)
    parser.add_argument("--save_interval_secs", type=int, default=600)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--batch_num", type=int, default=2000)

    parser.add_argument("--dataset_name", type=str, default="sendata3_train")
    parser.add_argument("--dataset_split_name", type=str, default='train')
    parser.add_argument("--dataset_dir", type=str, default='/***/workspace/mobilenetssd/models-master/research/slim/sentiment_cnn/dataset/data_3')
    parser.add_argument("--model_name", type=str, default='mobilenet_v2')

    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--train_image_size", type=int, default=80)
    parser.add_argument("--max_number_of_steps", type=int, default=60000)
    parser.add_argument("--trainable_scopes", type=str, default=None)

    parser.add_argument("--optimizer", type=str, default='adam')
    parser.add_argument("--learning_rate_decay_type", type=str, default='exponential')
    parser.add_argument("--learning_rate_decay_factor", type=float, default=0.8)

    parser.add_argument("--checkpoint_path", type=str, default=None)
    parser.add_argument("--num_readers", type=int, default=4)

    parser.add_argument("--num_epochs_per_decay", type=float, default=3.0)

    args, _ = parser.parse_known_args()
    return args
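The tuner can only usefully override keys that `get_params()` also defines, because the script merges the two with `params.update(tuner_params)`. As a sketch, the snippet below writes an NNI search-space file whose keys match four of the arguments above, then mimics the merge with a hand-written sample standing in for `nni.get_next_parameter()`. The value ranges, the file name `search_space.json`, and the sampled values are assumptions for illustration, not the settings behind the result reported in this post.

```python
import json

# Hypothetical ranges; each key must match an argument name in get_params().
search_space = {
    "learning_rate": {"_type": "loguniform", "_value": [1e-5, 1e-2]},
    "batch_size": {"_type": "choice", "_value": [16, 32, 64]},
    "optimizer": {"_type": "choice", "_value": ["adam", "rmsprop", "sgd"]},
    "learning_rate_decay_factor": {"_type": "uniform", "_value": [0.5, 0.99]},
}

with open("search_space.json", "w") as f:
    json.dump(search_space, f, indent=2)

# A subset of the argparse defaults from get_params().
defaults = {
    "learning_rate": 1e-4,
    "batch_size": 16,
    "optimizer": "adam",
    "learning_rate_decay_factor": 0.8,
    "model_name": "mobilenet_v2",
}

# Stand-in for what nni.get_next_parameter() might return in one trial.
sampled = {
    "learning_rate": 3e-4,
    "batch_size": 32,
    "optimizer": "rmsprop",
    "learning_rate_decay_factor": 0.9,
}

# Same merge as in the script: tuner values win, untouched defaults remain.
params = dict(defaults)
params.update(sampled)
print(params)
```

Because the script reports the training loss through `nni.report_final_result`, the tuner in the experiment configuration would need to minimize the metric (NNI's built-in tuners take an `optimize_mode` argument for this) rather than maximize it.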

3. Results

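Squeeze a little more accuracy out of your image classification model!
For example, on a private dataset MobileNetV2 previously topped out at 90% accuracy; after tuning with NNI it reaches 92.2857%.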
