Tutorial
In this tutorial, we demonstrate how to compress a convolutional neural network and export the compressed model into a *.tflite file for deployment on mobile devices. The model used here is an 18-layer residual network (denoted as "ResNet-18") trained for the ImageNet classification task. We compress it with the discrimination-aware channel pruning algorithm (Zhuang et al., NIPS '18), which reduces the number of convolutional channels used in the network to speed up inference.
Prepare the Data
To start with, we need to convert the ImageNet data set (ILSVRC-12) into TensorFlow's native TFRecord file format. You may follow the data preparation guide here to download the full data set and convert it into TFRecord files. After that, you should be able to find 1,024 training files and 128 validation files in the data directory, like this:
# training files
train-00000-of-01024
train-00001-of-01024
...
train-01023-of-01024
# validation files
validation-00000-of-00128
validation-00001-of-00128
...
validation-00127-of-00128
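You can quickly confirm that the conversion succeeded by counting the records in one shard. The snippet below is a minimal sketch and not part of PocketFlow; it uses the first shard name from the listing above and assumes TensorFlow's eager execution is available (the default in TensorFlow 2.x).
# sanity check: count the records in one TFRecord shard
import tensorflow as tf
dataset = tf.data.TFRecordDataset('train-00000-of-01024')
print('records in shard:', sum(1 for _ in dataset))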
Prepare the Pre-trained Model
The discrimination-aware channel pruning algorithm requires a pre-trained uncompressed model to be provided in advance, so that the channel-pruned model can be trained with warm-start. You can download a pre-trained model from here, and then unzip the files into the models sub-directory.
Alternatively, you can train an uncompressed full-precision model from scratch using FullPrecLearner with the following command (choose whichever mode fits you):
# local mode with 1 GPU
$ ./scripts/run_local.sh nets/resnet_at_ilsvrc12_run.py
# docker mode with 8 GPUs
$ ./scripts/run_docker.sh nets/resnet_at_ilsvrc12_run.py -n=8
# seven mode with 8 GPUs
$ ./scripts/run_seven.sh nets/resnet_at_ilsvrc12_run.py -n=8
After the training process, you should be able to find the resulting model files in the models sub-directory under PocketFlow's home directory.
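Either way, you can confirm that the checkpoint is in place before moving on. The snippet below is a minimal sketch, assuming the models directory sits in the current working directory and follows TensorFlow's standard checkpoint naming.
# confirm that the pre-trained checkpoint can be located
import tensorflow as tf
print(tf.train.latest_checkpoint('./models'))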
Train the Compressed Model
Now, we can train a compressed model with the discrimination-aware channel pruning algorithm, as implemented by DisChnPrunedLearner. Assuming you are now in PocketFlow's home directory, the model compression training process can be started with the following command (choose whichever mode fits you):
# local mode with 1 GPU
$ ./scripts/run_local.sh nets/resnet_at_ilsvrc12_run.py \
--learner dis-chn-pruned
# docker mode with 8 GPUs
$ ./scripts/run_docker.sh nets/resnet_at_ilsvrc12_run.py -n=8 \
--learner dis-chn-pruned
# seven mode with 8 GPUs
$ ./scripts/run_seven.sh nets/resnet_at_ilsvrc12_run.py -n=8 \
--learner dis-chn-pruned
Let's take the execution command for the local mode as an example. In this command, run_local.sh is a shell script that executes the specified Python script with user-provided arguments. Here, we ask it to run the Python script named nets/resnet_at_ilsvrc12_run.py, which is the execution script for ResNet models on the ImageNet data set. After that, we use --learner dis-chn-pruned to specify that DisChnPrunedLearner should be used for model compression. You may also use other learners by specifying the corresponding learner name. Below is a full list of available learners in PocketFlow:
Learner name | Learner class | Note |
---|---|---|
full-prec | FullPrecLearner | No model compression |
channel | ChannelPrunedLearner | Channel pruning with LASSO-based channel selection (He et al., 2017) |
dis-chn-pruned | DisChnPrunedLearner | Discrimination-aware channel pruning (Zhuang et al., 2018) |
weight-sparse | WeightSparseLearner | Weight sparsification with dynamic pruning schedule (Zhu & Gupta, 2017) |
uniform | UniformQuantLearner | Weight quantization with uniform reconstruction levels (Jacob et al., 2018) |
uniform-tf | UniformQuantTFLearner | Weight quantization with uniform reconstruction levels and TensorFlow APIs |
non-uniform | NonUniformQuantLearner | Weight quantization with non-uniform reconstruction levels (Han et al., 2016) |
The local mode only uses 1 GPU for the training process, which takes approximately 20-30 hours to complete. This can be accelerated with multi-GPU training in the docker and seven modes, enabled by adding -n=x right after the specified Python script, where x is the number of GPUs to be used.
Optionally, you can pass some extra arguments to customize the training process. For the discrimination-aware channel pruning algorithm, some of the key arguments are:
Name | Definition | Default Value |
---|---|---|
enbl_dst | Enable training with distillation loss | False |
dcp_prune_ratio | DCP algorithm's pruning ratio | 0.5 |
You may override the default values by appending customized arguments to the end of the execution command. For instance, the following command:
$ ./scripts/run_local.sh nets/resnet_at_ilsvrc12_run.py \
--learner dis-chn-pruned \
--enbl_dst \
--dcp_prune_ratio 0.75
requires the DisChnPrunedLearner to achieve an overall pruning ratio of 0.75 and carries out the training process with the distillation loss enabled. As a result, the number of channels in each convolutional layer of the compressed model will be one quarter of the original number.
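To make this channel arithmetic concrete, here is a toy calculation (not part of PocketFlow's code) of how many channels a layer keeps under a given pruning ratio:
# toy example: channels kept in a layer after pruning with a given ratio
def channels_kept(num_channels, prune_ratio):
    return int(num_channels * (1.0 - prune_ratio))
print(channels_kept(256, 0.75))  # prints 64, i.e. one quarter of 256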
After the training process is completed, you should be able to find a sub-directory named models_dcp_eval created in PocketFlow's home directory. This sub-directory contains all the files that define the compressed model; in the next section, we will export them to a TensorFlow Lite formatted model file for deployment.
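If you are curious how many channels each layer ends up with, one option is to list the variable shapes stored in the resulting checkpoint. The sketch below is only illustrative: the directory name follows the tutorial, but the variable-name filter is an assumption and may need to be adapted to the actual graph.
# inspect variable shapes in the compressed model's checkpoint
import tensorflow as tf
ckpt = tf.train.latest_checkpoint('./models_dcp_eval')
for name, shape in tf.train.list_variables(ckpt):
    if 'kernel' in name or 'weights' in name:  # name filter is an assumption
        print(name, shape)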
Export to TensorFlow Lite
TensorFlow's checkpoint files cannot be directly used for deployment on mobile devices. Instead, we first need to convert them into a single *.tflite file that is supported by the TensorFlow Lite Interpreter. For models compressed with channel-pruning-based algorithms, e.g. ChannelPrunedLearner and DisChnPrunedLearner, we have prepared a model conversion script, tools/conversion/export_pb_tflite_models.py, to generate a TF-Lite model from TensorFlow's checkpoint files.
To convert checkpoint files into a *.tflite file, use the following command:
# convert checkpoint files into a *.tflite model
$ python tools/conversion/export_pb_tflite_models.py \
--model_dir models_dcp_eval
In the above command, we specify the model directory containing the checkpoint files generated in the previous training process. The conversion script automatically detects which channels can be safely pruned, and then produces a light-weight compressed model. The resulting TensorFlow Lite file is also placed in the models_dcp_eval directory, named model_transformed.tflite.
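Before moving on to the mobile demo, you can optionally verify the exported file on your workstation with the TF-Lite Python interpreter. The sketch below assumes a TensorFlow version that ships tf.lite.Interpreter; it runs one forward pass on a random input just to confirm that the model loads and produces an output of the expected shape.
# quick check of the exported *.tflite model with the TF-Lite Python interpreter
import numpy as np
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path='models_dcp_eval/model_transformed.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
dummy_input = np.random.rand(*input_details[0]['shape']).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy_input)
interpreter.invoke()
print('output shape:', interpreter.get_tensor(output_details[0]['index']).shape)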
Deploy on Mobile Devices
After exporting the compressed model to the TensorFlow Lite file format, you may follow the official guide to create an Android demo App from it. Basically, this demo App uses a TensorFlow Lite model to continuously classify images captured by the camera, and all of the computation is performed on the mobile device in real time.
To use the model_transformed.tflite model file, you need to place it in the assets directory and create a Java class named ImageClassifierFloatResNet that uses this model for classification. Below is the example code, which is modified from ImageClassifierFloatInception.java used in the official demo project:
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package com.example.android.tflitecamerademo;

import android.app.Activity;
import java.io.IOException;

/**
 * This classifier works with the ResNet-18 model.
 * It applies floating point inference rather than using a quantized model.
 */
public class ImageClassifierFloatResNet extends ImageClassifier {

  /**
   * The ResNet requires additional normalization of the input.
   */
  private static final float IMAGE_MEAN_RED = 123.68f;
  private static final float IMAGE_MEAN_GREEN = 116.779f;
  private static final float IMAGE_MEAN_BLUE = 103.939f;

  /**
   * An array to hold inference results, to be fed into TensorFlow Lite as outputs.
   * This isn't part of the super class, because we need a primitive array here.
   */
  private float[][] labelProbArray = null;

  /**
   * Initializes an {@code ImageClassifierFloatResNet}.
   *
   * @param activity
   */
  ImageClassifierFloatResNet(Activity activity) throws IOException {
    super(activity);
    labelProbArray = new float[1][getNumLabels()];
  }

  @Override
  protected String getModelPath() {
    return "model_transformed.tflite";
  }

  @Override
  protected String getLabelPath() {
    return "labels_imagenet_slim.txt";
  }

  @Override
  protected int getImageSizeX() {
    return 224;
  }

  @Override
  protected int getImageSizeY() {
    return 224;
  }

  @Override
  protected int getNumBytesPerChannel() {
    // a 32-bit float value requires 4 bytes
    return 4;
  }

  @Override
  protected void addPixelValue(int pixelValue) {
    imgData.putFloat(((pixelValue >> 16) & 0xFF) - IMAGE_MEAN_RED);
    imgData.putFloat(((pixelValue >> 8) & 0xFF) - IMAGE_MEAN_GREEN);
    imgData.putFloat((pixelValue & 0xFF) - IMAGE_MEAN_BLUE);
  }

  @Override
  protected float getProbability(int labelIndex) {
    return labelProbArray[0][labelIndex];
  }

  @Override
  protected void setProbability(int labelIndex, Number value) {
    labelProbArray[0][labelIndex] = value.floatValue();
  }

  @Override
  protected float getNormalizedProbability(int labelIndex) {
    // TODO the following value isn't in [0,1] yet, but may be greater. Why?
    return getProbability(labelIndex);
  }

  @Override
  protected void runInference() {
    tflite.run(imgData, labelProbArray);
  }
}
After that, you need to change the image classifier class used in Camera2BasicFragment.java. Locate the method named onActivityCreated and change its content as below. Now you will be able to use the compressed ResNet-18 model to classify objects on your mobile phone in real time.
/** Load the model and labels. */
@Override
public void onActivityCreated(Bundle savedInstanceState) {
  super.onActivityCreated(savedInstanceState);
  try {
    classifier = new ImageClassifierFloatResNet(getActivity());
  } catch (IOException e) {
    Log.e(TAG, "Failed to initialize an image classifier.", e);
  }
  startBackgroundThread();
}