PocketFlow文档(二)

Table of Contents

 

Tutorial

Prepare the Data

Prepare the Pre-trained Model

Train the Compressed Model

Export to TensorFlow Lite

Deploy on Mobile Devices


Tutorial

在本教程中,我们将演示如何压缩卷积神经网络,并将压缩后的模型导出为*.tflite。用于在移动设备上部署的tflite文件。我们在这里使用的模型是一个为ImageNet分类任务训练的18层残差网络(记为“ResNet-18”)。我们将使用discrimination-aware channel pruning algorithm (Zhuang et al., NIPS '18)对其进行压缩,以减少网络中用于加速的卷积通道数量。


Prepare the Data

首先,我们需要将ImageNet数据集(ILSVRC-12)转换为TensorFlow的TFRecord文件格式。您可以按照这里的数据准备指南下载完整的数据集并将其转换为TFRecord文件。之后,您应该能够在data目录中找到1024个训练文件和128个验证文件,如下图所示:

# training files
train-00000-of-01024
train-00001-of-01024
...
train-01023-of-01024

# validation files
validation-00000-of-00128
validation-00001-of-00128
...
validation-00127-of-00128

Prepare the Pre-trained Model

The discrimination-aware channel pruning algorithm需要预先提供一个训练好的未压缩模型,这样就可以用warm-start对信道修剪的模型进行训练。您可以从这里下载一个预先训练好的模型,然后将文件解压缩到models子目录中。

或者,您可以使用FullPrecLearner使用以下命令从头开始训练未压缩的全精度模型(选择适合您的模式):

# local mode with 1 GPU
$ ./scripts/run_local.sh nets/resnet_at_ilsvrc12_run.py

# docker mode with 8 GPUs
$ ./scripts/run_docker.sh nets/resnet_at_ilsvrc12_run.py -n=8

# seven mode with 8 GPUs
$ ./scripts/run_seven.sh nets/resnet_at_ilsvrc12_run.py -n=8

在培训过程之后,您应该能够找到位于PocketFlow的主目录中的models子目录中的模型文件。

Train the Compressed Model

现在,我们可以使用the discrimination-aware channel pruning algorithm来训练一个压缩模型,该算法由DisChnPrunedLearner实现。假设您现在处于PocketFlow的主目录中,可以使用以下命令启动模型压缩的训练过程(选择适合您的模式):

# local mode with 1 GPU
$ ./scripts/run_local.sh nets/resnet_at_ilsvrc12_run.py \
    --learner dis-chn-pruned

# docker mode with 8 GPUs
$ ./scripts/run_docker.sh nets/resnet_at_ilsvrc12_run.py -n=8 \
    --learner dis-chn-pruned

# seven mode with 8 GPUs
$ ./scripts/run_seven.sh nets/resnet_at_ilsvrc12_run.py -n=8 \
    --learner dis-chn-pruned

让我们以本地模式的执行命令为例。在这个命令中,run_local.sh是一个shell脚本,它使用用户提供的参数执行指定的Python脚本。在这里,我们运行nets/resnet_at_ilsvrc12_run.py,它是ImageNet数据集上ResNet模型的执行脚本。之后,我们使用--learner dis-chn-pruned来指定应该使用DisChnPrunedLearner进行模型压缩。您还可以通过指定相应的学习者名称来使用其他学习者。下面是PocketFlow中可用的学习者的完整列表:

Learner nameLearner classNote
full-precFullPrecLearnerNo model compression
channelChannelPrunedLearnerChannel pruning with LASSO-based channel selection (He et al., 2017)
dis-chn-prunedDisChnPrunedLearnerDiscrimination-aware channel pruning (Zhuang et al., 2018)
weight-sparseWeightSparseLearnerWeight sparsification with dynamic pruning schedule (Zhu & Gupta, 2017)
uniformUniformQuantLearnerWeight quantization with uniform reconstruction levels (Jacob et al., 2018)
uniform-tfUniformQuantTFLearnerWeight quantization with uniform reconstruction levels and TensorFlow APIs
non-uniformNonUniformQuantLearnerWeight quantization with non-uniform reconstruction levels (Han et al., 2016)

本地模式在训练过程中只使用1 GPU,大约需要20-30个小时来完成。这可以通过docker和seven模式下的多gpu培训来加速,这可以通过在指定的Python脚本后面添加-n=x来实现,其中x是要使用的gpu数量。

您可以选择传递一些额外的参数来定制培训过程。对于the discrimination-aware channel pruning algorithm,一些关键参数如下:

NameDefinitionDefault Value
enbl_dstEnable training with distillation lossFalse
dcp_prune_ratioDCP algorithm's pruning ratio0.5

您可以通过在执行命令的末尾追加自定义参数来覆盖默认值。例如,下面的命令:

$ ./scripts/run_local.sh nets/resnet_at_ilsvrc12_run.py \
    --learner dis-chn-pruned \
    --enbl_dst \
    --dcp_prune_ratio 0.75

 要求DisChnPrunedLearner实现整体修剪比0.75,训练过程以the distillation loss.为代价进行。因此,压缩后的模型中每个卷积层的信道数将是原始卷积层的四分之一。

培训过程完成后,您应该能够找到在PocketFlow的主目录中创建的名为models_dcp_eval的子目录。这个子目录包含定义压缩模型的所有文件,我们将把它们导出到TensorFlow Lite格式的模型文件中,以便在下一节中部署。

Export to TensorFlow Lite

TensorFlow的checkpoint files文件不能直接用于移动设备上的部署。相反,我们需要首先将它们转换为单个*.tflite。TensorFlow Lite解释器支持的tflite文件。针对基于通道剪枝的模型压缩算法,如ChannelPruningLearner和DisChnPrunedLearner,我们编写了一个模型转换脚本tools/conversion/export_pb_tflite_models.py。从TensorFlow的checkpoint文件生成TF-Lite模型。

To convert checkpoint files into a *.tflite file, use the following command:

# convert checkpoint files into a *.tflite model
$ python tools/conversion/export_pb_tflite_models.py \
    --model_dir models_dcp_eval

在上面的命令中,我们指定了包含在前面的培训过程中生成的checkpoint 文件的模型目录。转换脚本自动检测哪些通道可以安全修剪,然后生成一个轻量级的压缩模型。生成的TensorFlow Lite文件也放在models_dcp_eval目录中,名为model_transformed.tflite。

Deploy on Mobile Devices

将压缩模型导出为TensorFlow Lite文件格式后,您可以按照官方指南从它创建一个Android演示应用程序。基本上,这个demo App使用TensorFlow Lite模型对摄像头拍摄的图像进行连续分类,所有的计算都是在移动设备上实时执行的。

使用model_transformed.tflite模型文件,您需要将其放在assert目录中,并创建一个名为ImageClassifierFloatResNet的Java类,以使用此模型进行分类。下面是示例代码,它是从ImageClassifierFloatInception修改的。官方演示项目中使用的java:

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package com.example.android.tflitecamerademo;

import android.app.Activity;

import java.io.IOException;

/**
 * This classifier works with the ResNet-18 model.
 * It applies floating point inference rather than using a quantized model.
 */
public class ImageClassifierFloatResNet extends ImageClassifier {

  /**
   * The ResNet requires additional normalization of the used input.
   */
  private static final float IMAGE_MEAN_RED = 123.58f;
  private static final float IMAGE_MEAN_GREEN = 116.779f;
  private static final float IMAGE_MEAN_BLUE = 103.939f;

  /**
   * An array to hold inference results, to be feed into Tensorflow Lite as outputs.
   * This isn't part of the super class, because we need a primitive array here.
   */
  private float[][] labelProbArray = null;

  /**
   * Initializes an {@code ImageClassifier}.
   *
   * @param activity
   */
  ImageClassifierFloatResNet(Activity activity) throws IOException {
    super(activity);
    labelProbArray = new float[1][getNumLabels()];
  }

  @Override
  protected String getModelPath() {
    return "model_transformed.tflite";
  }

  @Override
  protected String getLabelPath() {
    return "labels_imagenet_slim.txt";
  }

  @Override
  protected int getImageSizeX() {
    return 224;
  }

  @Override
  protected int getImageSizeY() {
    return 224;
  }

  @Override
  protected int getNumBytesPerChannel() {
    // a 32bit float value requires 4 bytes
    return 4;
  }

  @Override
  protected void addPixelValue(int pixelValue) {
    imgData.putFloat(((pixelValue >> 16) & 0xFF) - IMAGE_MEAN_RED);
    imgData.putFloat(((pixelValue >> 8) & 0xFF) - IMAGE_MEAN_GREEN);
    imgData.putFloat((pixelValue & 0xFF) - IMAGE_MEAN_BLUE);
  }

  @Override
  protected float getProbability(int labelIndex) {
    return labelProbArray[0][labelIndex];
  }

  @Override
  protected void setProbability(int labelIndex, Number value) {
    labelProbArray[0][labelIndex] = value.floatValue();
  }

  @Override
  protected float getNormalizedProbability(int labelIndex) {
    // TODO the following value isn't in [0,1] yet, but may be greater. Why?
    return getProbability(labelIndex);
  }

  @Override
  protected void runInference() {
    tflite.run(imgData, labelProbArray);
  }
}

之后,您需要更改Camera2BasicFragment.java中使用的图像分类器类。找到名为onActivityCreated的函数并更改其内容,如下所示。现在您将能够使用压缩的ResNet-18模型对您的手机上的对象进行实时分类。

/** Load the model and labels. */
@Override
public void onActivityCreated(Bundle savedInstanceState) {
  super.onActivityCreated(savedInstanceState);
  try {
    classifier = new ImageClassifierFloatResNet(getActivity());
  } catch (IOException e) {
    Log.e(TAG, "Failed to initialize an image classifier.", e);
  }
  startBackgroundThread();
}

 

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值