You need to install the tensorflow and tensorflow_hub libraries:
$ pip install "tensorflow~=2.0"
$ pip install "tensorflow-hub[make_image_classifier]~=0.6"
$ make_image_classifier --help
If possible, install the GPU version of TF2 instead: "tensorflow-gpu~=2.0".
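The `~=` in these requirements is pip's "compatible release" operator from PEP 440: `~=2.0` means `>=2.0, ==2.*`, so any TF 2.x release from 2.0 on qualifies, and `~=0.6` accepts any 0.x release from 0.6 on. The sketch below illustrates that rule for plain numeric versions only; the function names are hypothetical, and pip's real version handling (pre-releases, local versions, and so on) is more involved.

```python
from __future__ import annotations


def parse_version(text: str) -> tuple[int, ...]:
    """Parse a plain numeric version such as '2.0' or '0.12.1' into integers."""
    return tuple(int(part) for part in text.split("."))


def is_compatible_release(installed: str, spec: str) -> bool:
    """Check `installed` against a PEP 440 '~=' spec (plain numeric versions only).

    '~=2.0' means '>=2.0, ==2.*'; '~=0.6' means '>=0.6, ==0.*'.
    """
    have = parse_version(installed)
    want = parse_version(spec)
    prefix = want[:-1]  # every component except the last must match exactly
    return have >= want and have[: len(prefix)] == prefix


# Which TensorFlow versions satisfy "tensorflow~=2.0"?
for version in ["1.15.0", "2.0.0", "2.4.1", "3.0.0"]:
    print(version, is_compatible_release(version, "2.0"))
```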
$ make_image_classifier \
  --image_dir my_image_dir \
  --tfhub_module https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4 \
  --image_size 224 \
  --saved_model_dir my_dir/new_model \
  --labels_output_file class_labels.txt \
  --tflite_output_file new_mobile_model.tflite \
  --summaries_dir my_log_dir
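The --image_dir should contain one subdirectory per class, each holding the images for that class, for example:
my_image_dir
|-- cat
|   |-- a_feline_photo.jpg
|   |-- another_cat_pic.jpg
|   `-- ...
|-- dog
|   |-- PuppyInBasket.JPG
|   |-- walking_the_dog.jpeg
|   `-- ...
`-- rabbit
    |-- IMG87654321.JPG
    |-- my_fluffy_rabbit.JPEG
    `-- ...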
Good training results need many images (many dozens, possibly hundreds per class).
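Before training, it can help to check how many images each class actually has. A minimal standard-library sketch; the function names and the 50-image threshold are hypothetical, and it assumes the images are JPEG files, matched case-insensitively because the example tree mixes .jpg, .JPG, .jpeg and .JPEG.

```python
from __future__ import annotations

from pathlib import Path

# Assumption: class images are JPEG files; suffixes are compared case-insensitively.
IMAGE_SUFFIXES = {".jpg", ".jpeg"}


def count_images_per_class(image_dir: str | Path) -> dict[str, int]:
    """Return {class_name: number_of_images} for each subdirectory of image_dir."""
    counts = {}
    for class_dir in sorted(p for p in Path(image_dir).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


def underpopulated_classes(counts: dict[str, int], minimum: int = 50) -> list[str]:
    """List classes with fewer than `minimum` images (50 is an arbitrary, hypothetical threshold)."""
    return [name for name, n in counts.items() if n < minimum]
```

Calling count_images_per_class("my_image_dir") on the layout above would report one count per class (cat, dog, rabbit), in sorted order.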
Note: For a quick demo, omit --image_dir. This will download and use the "TF Flowers" dataset and train a model to classify photos of flowers as daisy, dandelion, rose, sunflower or tulip.
The --tfhub_module is the URL of a pre-trained model piece, or "module", on TensorFlow Hub. You can point your browser to the module URL to see its documentation. This tool requires a module for image feature extraction in TF2 format; you can find such modules by searching TF Hub.
Images are resized to the given --image_size after reading from disk. Whether you need this flag depends on the TF Hub module: if the module accepts only a fixed size, you can omit the flag; if it accepts an arbitrary size, start by setting it to the standard value given in the module's documentation.
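To make "resized to the given --image_size" concrete, the sketch below maps a pixel grid of any shape onto a square size × size grid. It assumes a square target and a single-channel image stored as nested lists, and uses nearest-neighbor sampling purely for brevity; the tool itself resizes with TensorFlow/Keras image ops, whose interpolation method may differ.

```python
from __future__ import annotations


def resize_nearest(image: list[list[int]], size: int) -> list[list[int]]:
    """Resize a 2-D grid of pixel values to size x size with nearest-neighbor sampling."""
    in_h, in_w = len(image), len(image[0])
    return [
        # Each output position picks the input pixel it lands on after scaling.
        [image[row * in_h // size][col * in_w // size] for col in range(size)]
        for row in range(size)
    ]
```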
Model training consumes your input data multiple times ("epochs"). A portion of the data is set aside as validation data, and the partially trained model is evaluated on it after each epoch. You can follow progress bars and accuracy indicators on the console.
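A stripped-down sketch of that control flow, assuming a hypothetical `train_step` callback that trains on one example and an `evaluate` callback that scores the validation set. The 20% validation fraction and the shuffling are illustrative choices, not necessarily what the tool does internally; in make_image_classifier, Keras handles the split and the loop.

```python
from __future__ import annotations

import random
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def train_val_split(
    items: Sequence[T], validation_fraction: float = 0.2, seed: int = 0
) -> tuple[list[T], list[T]]:
    """Shuffle a copy of `items` and hold out `validation_fraction` of it for validation."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * validation_fraction)
    cut = len(shuffled) - n_val
    return shuffled[:cut], shuffled[cut:]


def run_epochs(
    train: Sequence[T],
    val: Sequence[T],
    num_epochs: int,
    train_step: Callable[[T], None],
    evaluate: Callable[[Sequence[T]], float],
) -> list[float]:
    """Make `num_epochs` passes over `train`, scoring the model on `val` after each one."""
    history = []
    for epoch in range(num_epochs):
        for example in train:  # one epoch = one full pass over the training data
            train_step(example)
        score = evaluate(val)  # validation data is never passed to train_step
        history.append(score)
        print(f"epoch {epoch + 1}/{num_epochs}: validation score = {score}")
    return history
```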
After training, the given --saved_model_dir is created and filled with several files that represent the complete image classification model in TensorFlow's SavedModel format. This model can be deployed to TensorFlow Serving.
If --labels_output_file is given, the names of the classes are written to that text file, one per line, in the same order as they appear in the predictions output by the model.
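Because the label file lines up with the model's output positions, the i-th score belongs to the class on the i-th line. A minimal sketch of that lookup, assuming you already have the model's output as a plain list of per-class scores (obtaining it, for example via TF Lite or TensorFlow Serving, is outside this sketch); the function names are hypothetical.

```python
from __future__ import annotations

from pathlib import Path


def read_labels(path: str | Path) -> list[str]:
    """Read one class name per line, skipping blank lines."""
    lines = Path(path).read_text(encoding="utf-8").splitlines()
    return [line.strip() for line in lines if line.strip()]


def top_prediction(probabilities: list[float], labels: list[str]) -> tuple[str, float]:
    """Pair the highest-scoring output position with the label at the same index."""
    if len(probabilities) != len(labels):
        raise ValueError("expected one probability per label")
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return labels[best], probabilities[best]
```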
If --tflite_output_file is given, the complete image classification model is written to that file in TensorFlow Lite's model format ("flatbuffers"). This can be deployed to TF Lite on mobile devices. If you are not deploying to TF Lite, you can simply omit this flag.
If --summaries_dir is given, you can monitor your model training on TensorBoard. See the TensorBoard documentation for how to enable it.
If you set all the flags as in the example above, you can test the resulting TF Lite model with tensorflow/lite/examples/python/label_image.py: download that program and run it on an image, for example:
python label_image.py \
  --input_mean 0 --input_std 255 \
  --model_file new_mobile_model.tflite --label_file class_labels.txt \
  --image my_image_dir/cat/a_feline_photo.jpg  # <<< Adjust filename.
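The flags --input_mean 0 --input_std 255 tell label_image.py how to scale pixel values for a floating-point model: each value becomes (pixel - input_mean) / input_std, so 8-bit values in [0, 255] land in [0, 1], the input range TF2 image modules on TF Hub expect. A minimal sketch of that arithmetic; the function name is ours, not label_image.py's.

```python
def normalize(pixels, input_mean=0.0, input_std=255.0):
    """Scale raw pixel values as (p - input_mean) / input_std."""
    return [(p - input_mean) / input_std for p in pixels]


print(normalize([0, 51, 255]))  # [0.0, 0.2, 1.0]
```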
Advanced usage
Additional command-line flags let you control the training process. In particular, you can increase --train_epochs to train for more epochs, and set the --learning_rate and --momentum for the SGD optimizer.
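For intuition about --learning_rate and --momentum, here is one update step of SGD with momentum in the formulation documented for Keras' SGD optimizer (velocity = momentum * velocity - learning_rate * gradient; weight = weight + velocity), applied to a toy quadratic. Plain lists stand in for tensors; this is an illustration, not the tool's training code.

```python
def sgd_momentum_step(weights, grads, velocity, learning_rate=0.01, momentum=0.9):
    """Return (new_weights, new_velocity) after one SGD-with-momentum update."""
    new_velocity = [momentum * v - learning_rate * g for v, g in zip(velocity, grads)]
    new_weights = [w + v for w, v in zip(weights, new_velocity)]
    return new_weights, new_velocity


# Toy problem: minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w, v = [0.0], [0.0]
for _ in range(300):
    w, v = sgd_momentum_step(w, [2 * (w[0] - 3)], v, learning_rate=0.1, momentum=0.9)
print(w[0])  # close to the minimum at 3
```

A larger momentum keeps more of the previous step's direction; a larger learning rate scales each gradient's contribution.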
Also, you can set --do_fine_tuning to train the TensorFlow Hub module together with the classifier.
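Conceptually, --do_fine_tuning decides which weights the optimizer may change: without it, only the new classifier layer is trained and the TF Hub module stays frozen (in Keras, freezing is usually expressed through a layer's `trainable` attribute). The sketch below models this with two hypothetical parameter groups, `module` and `classifier`, and a plain gradient step; it is not how the tool represents its model.

```python
from __future__ import annotations


def gradient_step(
    params: dict[str, list[float]],
    grads: dict[str, list[float]],
    learning_rate: float,
    do_fine_tuning: bool,
) -> dict[str, list[float]]:
    """Apply w -= learning_rate * g only to the parameter groups that are trainable."""
    trainable = {"classifier"}  # the new classification head is always trained
    if do_fine_tuning:
        trainable.add("module")  # fine-tuning also updates the pre-trained module
    return {
        group: (
            [w - learning_rate * g for w, g in zip(weights, grads[group])]
            if group in trainable
            else list(weights)  # frozen group: copied unchanged
        )
        for group, weights in params.items()
    }
```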
There are other hyperparameters for regularization, such as --l1_regularizer and --l2_regularizer, and for data augmentation, such as --rotation_range and --horizontal_flip. Generally, the default values give good performance. You can find the full list of available hyperparameters in make_image_classifier.py and their default values in make_image_classifier_lib.py.
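As a rough illustration of these two kinds of hyperparameters: an L1/L2 regularizer adds l1 * sum(|w|) + l2 * sum(w²) over a layer's weights to the training loss (the form Keras' `l1_l2` regularizer computes), and horizontal-flip augmentation mirrors some training images left to right. The sketch uses plain lists and hypothetical function names; the 50% flip probability is an assumption, and the tool itself implements both with TensorFlow/Keras ops.

```python
from __future__ import annotations

import random


def l1_l2_penalty(weights: list[float], l1: float = 0.0, l2: float = 0.0) -> float:
    """Penalty added to the loss: l1 * sum(|w|) + l2 * sum(w**2)."""
    return l1 * sum(abs(w) for w in weights) + l2 * sum(w * w for w in weights)


def horizontal_flip(image: list[list[int]]) -> list[list[int]]:
    """Mirror a 2-D pixel grid left to right."""
    return [row[::-1] for row in image]


def maybe_flip(image: list[list[int]], rng: random.Random, enabled: bool = True) -> list[list[int]]:
    """Flip with probability 0.5 (assumed) when augmentation is enabled."""
    if enabled and rng.random() < 0.5:
        return horizontal_flip(image)
    return image
```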
With tensorflow>=2.5 and tensorflow-hub>=0.12, you can use the --use_tf_data_input flag to control whether input is read with a tf.data.Dataset and preprocessed with TF ops. Note that shear data augmentation is not supported in this mode. If the flag is set to False, Keras' legacy Python ImageDataGenerator with numpy ops is used for data augmentation and other preprocessing.
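A small sketch of that version requirement as a pre-flight check. The helper names are hypothetical, versions are assumed to be plain dotted numbers (suffixes such as "rc1" are not handled), and you would pass in the installed versions yourself, for example tf.__version__ and hub.__version__.

```python
from __future__ import annotations


def version_tuple(text: str) -> tuple[int, ...]:
    """'2.15.0' -> (2, 15, 0); compare as integers, since '2.15' < '2.5' as strings."""
    return tuple(int(part) for part in text.split("."))


def tf_data_input_supported(tf_version: str, hub_version: str) -> bool:
    """True if the versions meet tensorflow>=2.5 and tensorflow-hub>=0.12."""
    return version_tuple(tf_version) >= (2, 5) and version_tuple(hub_version) >= (0, 12)
```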