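Machine learning lifecycle
The footwear classification model is produced by following the machine learning (ML) lifecycle. The ML lifecycle differs from the traditional software development lifecycle and consists of six concrete steps:
- Obtain the data
- Clean and prepare the data
- Generate the model
- Evaluate the model
- Deploy the model
- Obtain predictions (or inferences) from the model
The end result of the lifecycle is a machine learning model that can be queried and returns an answer (or prediction).
A model is simply a mathematical representation of the trends and patterns found in the data. Good data is the foundation of every machine learning project.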
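In step 1, the data is obtained. In step 2, the data is cleaned, transformed, and put into a format that a machine can learn from. Cleaning and transforming is often the most time-consuming part of the machine learning lifecycle. DJL simplifies this process by letting developers preprocess images with a Translator. A Translator can perform tasks such as resizing images to the expected dimensions or converting images from color to grayscale.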
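To make the two preprocessing tasks named above concrete, here is a minimal plain-Java sketch. It is not DJL's Translator API: the class `PreprocessSketch`, the packed `0xRRGGBB` array representation, nearest-neighbor scaling, and the 0.299/0.587/0.114 luma weights are all assumptions chosen for illustration, and DJL's own transforms may use different interpolation.

```java
// Plain-Java illustration only; this is not DJL's Translator API.
final class PreprocessSketch {

    // Nearest-neighbor resize of an image stored as rows of packed 0xRRGGBB pixels.
    static int[][] resize(int[][] src, int newWidth, int newHeight) {
        int srcHeight = src.length;
        int srcWidth = src[0].length;
        int[][] out = new int[newHeight][newWidth];
        for (int y = 0; y < newHeight; y++) {
            int srcY = y * srcHeight / newHeight; // nearest source row
            for (int x = 0; x < newWidth; x++) {
                int srcX = x * srcWidth / newWidth; // nearest source column
                out[y][x] = src[srcY][srcX];
            }
        }
        return out;
    }

    // Converts each color pixel to one gray value in 0..255 using common luma weights.
    static int[][] toGrayscale(int[][] src) {
        int[][] out = new int[src.length][src[0].length];
        for (int y = 0; y < src.length; y++) {
            for (int x = 0; x < src[y].length; x++) {
                int p = src[y][x];
                int r = (p >> 16) & 0xFF;
                int g = (p >> 8) & 0xFF;
                int b = p & 0xFF;
                out[y][x] = (int) Math.round(0.299 * r + 0.587 * g + 0.114 * b);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // A hypothetical 136x102 image filled with pure red.
        int[][] image = new int[102][136];
        for (int[] row : image) {
            java.util.Arrays.fill(row, 0xFF0000);
        }
        int[][] gray = toGrayscale(resize(image, 100, 100));
        System.out.println(gray[0].length + "x" + gray.length + ", gray value " + gray[0][0]);
    }
}
```

In the article's own code, the same role is played by the `Resize` and `ToTensor` transforms that are added to the dataset and to the translator.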
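Developers who are new to machine learning often underestimate the time needed to clean and transform data, so a Translator is a good way to jump-start the process. In step 3, the training process, the machine learning algorithm makes multiple passes (or epochs) over the data and studies it, trying to learn the different types of footwear. The trends and patterns it finds for each footwear type are stored in the model. Step 4, evaluation, takes place as part of training: the model is evaluated to determine how well it identifies footwear types. If errors are found, they are corrected. In step 5, the model is deployed to a production environment. In step 6, once the model is in production, other systems can use it to obtain predictions.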
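The relationship between passes over the data and model updates can be worked through with simple arithmetic. The sketch below uses the 8:2 split, batch size of 32, and 2 epochs that appear in this article's training code; the dataset size of 50,000 and the helper names are hypothetical, and the split formula is plain integer arithmetic rather than a description of how DJL divides the data internally.

```java
// Illustrative arithmetic only; names and the dataset size are assumptions.
final class EpochMath {

    // Samples assigned to the training portion of a trainParts:validateParts split.
    static long trainingSamples(long total, int trainParts, int validateParts) {
        return total * trainParts / (trainParts + validateParts);
    }

    // Batches per epoch, counting a final partial batch (ceiling division).
    static long batchesPerEpoch(long samples, int batchSize) {
        return (samples + batchSize - 1) / batchSize;
    }

    public static void main(String[] args) {
        long total = 50_000;                        // hypothetical dataset size
        long train = trainingSamples(total, 8, 2);  // 40000 training samples
        long perEpoch = batchesPerEpoch(train, 32); // 1250 batches per epoch
        long updates = perEpoch * 2;                // 2500 updates over 2 epochs
        System.out.println("train=" + train
                + ", batchesPerEpoch=" + perEpoch
                + ", totalUpdates=" + updates);
    }
}
```

With these numbers the model would be updated once per batch, 1250 times per epoch, so each additional epoch adds another full pass of 1250 updates.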
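Data
The footwear classification model is a multiclass computer vision (CV) classification model trained with supervised learning. It assigns each footwear image one of four class labels: boots, sandals, shoes, or slippers. Supervised learning requires data that has already been labeled with the target (or answer) you want to predict; this is how the machine learns.
The data source for the footwear classification model is the UT Zappos50K dataset.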
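Because supervised learning needs a label for every image, it helps to see where the label can come from. The sketch below assumes the directory layout implied by the sample image path used later in this article (`ut-zap50k-images-square/Sandals/Heel/Annie/7350693.3.jpg`), where the first folder under the dataset root names the category. The helper `labelOf` is hypothetical and is not part of DJL.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

// Illustration only: derives a class label from a file's location on disk.
final class LabelFromPath {

    // Returns the first path element below the dataset root,
    // which is assumed here to be the class label.
    static String labelOf(Path datasetRoot, Path imageFile) {
        Path relative = datasetRoot.relativize(imageFile);
        return relative.getName(0).toString();
    }

    public static void main(String[] args) {
        Path root = Paths.get("ut-zap50k-images-square");
        Path image = root.resolve("Sandals/Heel/Annie/7350693.3.jpg");
        System.out.println(labelOf(root, image)); // prints "Sandals"
    }
}
```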
The model uses the ResNet-50 network architecture. The training code is as follows:
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package com.example.demo.djl.footwearclassification;
import ai.djl.Model;
import ai.djl.basicdataset.ImageFolder;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
* In training, multiple passes (or epochs) are made over the training data trying to find patterns
* and trends in the data, which are then stored in the model. During the process, the model is
* evaluated for accuracy using the validation data. The model is updated with findings over each
* epoch, which improves the accuracy of the model.
*/
public final class Training {

    // represents number of training samples processed before the model is updated
    private static final int BATCH_SIZE = 32;

    // the number of passes over the complete dataset
    private static final int EPOCHS = 2;

    public static void main(String[] args) throws IOException, TranslateException {
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
        // the location to save the model
        Path modelDir = Paths.get("models");

        // create ImageFolder dataset from directory
        ImageFolder dataset = initDataset("ut-zap50k-images-square");
        // split the dataset into a training set and a validation set (8:2 ratio)
        RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);

        // set loss function, which seeks to minimize errors
        // loss function evaluates model's predictions against the correct answer (during training)
        // higher numbers are bad - means model performed poorly; indicates more errors; want to
        // minimize errors (loss)
        Loss loss = Loss.softmaxCrossEntropyLoss();

        // setting training parameters (i.e. hyperparameters)
        TrainingConfig config = setupTrainingConfig(loss);

        try (Model model = Models.getModel(); // empty model instance to hold patterns
                Trainer trainer = model.newTrainer(config)) {
            // metrics collect and report key performance indicators, like accuracy
            trainer.setMetrics(new Metrics());

            // batch of 1, 3 color channels, then image height and width
            Shape inputShape = new Shape(1, 3, Models.IMAGE_HEIGHT, Models.IMAGE_WIDTH);

            // initialize trainer with proper input shape
            trainer.initialize(inputShape);

            // find the patterns in data
            EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);

            // set model properties
            TrainingResult result = trainer.getTrainingResult();
            model.setProperty("Epoch", String.valueOf(EPOCHS));
            model.setProperty(
                    "Accuracy", String.format("%.5f", result.getValidateEvaluation("Accuracy")));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));

            // save the model after done training for inference later
            // model saved as shoeclassifier-0000.params
            model.save(modelDir, Models.MODEL_NAME);
            // save labels into model directory
            Models.saveSynset(modelDir, dataset.getSynset());
        }
    }

    private static ImageFolder initDataset(String datasetRoot)
            throws IOException, TranslateException {
        ImageFolder dataset =
                ImageFolder.builder()
                        // retrieve the data
                        .setRepositoryPath(Paths.get(datasetRoot))
                        .optMaxDepth(10)
                        .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                        .addTransform(new ToTensor())
                        // random sampling; don't process the data in order
                        .setSampling(BATCH_SIZE, true)
                        .build();
        dataset.prepare();
        return dataset;
    }

    private static TrainingConfig setupTrainingConfig(Loss loss) {
        return new DefaultTrainingConfig(loss)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());
    }
}
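The comments in the training code describe the loss as a number that grows when the model's predictions are worse. The following plain-Java sketch shows the textbook softmax cross-entropy calculation for a single sample so that this behavior can be checked by hand. It is an illustration under that assumption, not DJL's implementation of `Loss.softmaxCrossEntropyLoss()`, which works on batches of NDArrays; the class and method names are hypothetical.

```java
import java.util.Locale;

// Textbook single-sample softmax cross-entropy; not DJL's implementation.
final class CrossEntropySketch {

    // Numerically stable softmax: subtract the largest score before exponentiating.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    // Loss is the negative log of the probability assigned to the true class.
    static double loss(double[] logits, int trueClass) {
        return -Math.log(softmax(logits)[trueClass]);
    }

    public static void main(String[] args) {
        // Hypothetical raw scores for four classes; class 0 is strongly favored.
        double[] scores = {5.0, 0.0, 0.0, 0.0};
        System.out.println(String.format(Locale.ROOT,
                "true class favored: %.4f, true class not favored: %.4f",
                loss(scores, 0), loss(scores, 1)));
    }
}
```

When the model puts most of its probability on the correct class, the loss is close to zero; when it favors a wrong class, the loss is much larger. With four equally likely classes the loss equals ln 4, roughly 1.386, which is a useful baseline for an untrained four-class model.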
The code that builds the ResNet-50 network is as follows:
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package com.example.demo.djl.footwearclassification;
import ai.djl.Model;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import java.io.IOException;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
/** A helper class that creates the model and saves its labels. */
public final class Models {

    // the number of classification labels: boots, sandals, shoes, slippers
    public static final int NUM_OF_OUTPUT = 4;

    // the height and width for pre-processing of the image
    public static final int IMAGE_HEIGHT = 100;
    public static final int IMAGE_WIDTH = 100;

    // the name of the model
    public static final String MODEL_NAME = "shoeclassifier";

    private Models() {}

    public static Model getModel() {
        // create new instance of an empty model
        Model model = Model.newInstance(MODEL_NAME);

        // Block is a composable unit that forms a neural network; combine them like Lego blocks
        // to form a complex network
        Block resNet50 =
                ResNetV1.builder() // construct the network
                        .setImageShape(new Shape(3, IMAGE_HEIGHT, IMAGE_WIDTH))
                        .setNumLayers(50)
                        .setOutSize(NUM_OF_OUTPUT)
                        .build();

        // set the neural network to the model
        model.setBlock(resNet50);
        return model;
    }

    public static void saveSynset(Path modelDir, List<String> synset) throws IOException {
        Path synsetFile = modelDir.resolve("synset.txt");
        try (Writer writer = Files.newBufferedWriter(synsetFile)) {
            writer.write(String.join("\n", synset));
        }
    }
}
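The `saveSynset` method above writes one label per line to `synset.txt`. The sketch below shows the same file format round-tripped with only JDK classes, which can be handy for checking that the labels were stored in the expected order. The class `SynsetRoundTrip`, its `load` helper, and the label order used in `main` are assumptions for illustration and are not part of the article's code or of DJL.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;

// Illustration only: writes and re-reads a one-label-per-line synset file.
final class SynsetRoundTrip {

    // Writes one label per line, the same format saveSynset produces.
    static void save(Path dir, List<String> labels) throws IOException {
        byte[] content = String.join("\n", labels).getBytes(StandardCharsets.UTF_8);
        Files.write(dir.resolve("synset.txt"), content);
    }

    // Reads the labels back in file order.
    static List<String> load(Path dir) throws IOException {
        return Files.readAllLines(dir.resolve("synset.txt"), StandardCharsets.UTF_8);
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("synset-demo");
        // Assumed label order; the real order comes from dataset.getSynset().
        save(dir, Arrays.asList("Boots", "Sandals", "Shoes", "Slippers"));
        System.out.println(load(dir)); // prints [Boots, Sandals, Shoes, Slippers]
    }
}
```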
Using the generated model, choose a footwear image and run a prediction:
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package com.example.demo.djl.footwearclassification;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
/** Uses the model to generate a prediction called an inference */
public class Inference {

    public static void main(String[] args) throws ModelException, TranslateException, IOException {
        // the location where the model is saved
        Path modelDir = Paths.get("models");
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");

        // the path of image to classify
        String imageFilePath;
        if (args.length == 0) {
            imageFilePath = "ut-zap50k-images-square/Sandals/Heel/Annie/7350693.3.jpg";
        } else {
            imageFilePath = args[0];
        }

        // Load the image file from the path
        Image img = ImageFactory.getInstance().fromFile(Paths.get(imageFilePath));

        try (Model model = Models.getModel()) { // empty model instance
            // load the model
            model.load(modelDir, Models.MODEL_NAME);

            // define a translator for pre and post processing
            // the translator resizes the image to the size used in training, converts it
            // to a tensor, and applies softmax to the ResNet-50 output
            Translator<Image, Classifications> translator =
                    ImageClassificationTranslator.builder()
                            .addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
                            .addTransform(new ToTensor())
                            .optApplySoftmax(true)
                            .build();

            // run the inference using a Predictor
            try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
                // holds the probability score per label
                Classifications predictResult = predictor.predict(img);
                System.out.println(predictResult);
            }
        }
    }
}
Prediction result:
[
class: "Sandals", probability: 0.73966
class: "Shoes", probability: 0.25818
class: "Slippers", probability: 0.00200
class: "Boots", probability: 0.00014
]
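The result lists every class with its probability, ordered from most to least likely, and the four values add up to roughly 1 because softmax is applied. A minimal plain-Java sketch of that ranking step follows, using the probabilities printed above. The class `RankPredictions` is hypothetical, and the alphabetical label order in the arrays is an assumption about how the labels were stored; the sketch does not use DJL's `Classifications` class.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

// Illustration only: orders class indices by descending probability.
final class RankPredictions {

    // Returns label indices ordered from highest to lowest probability.
    static List<Integer> rank(double[] probabilities) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < probabilities.length; i++) {
            order.add(i);
        }
        order.sort((a, b) -> Double.compare(probabilities[b], probabilities[a]));
        return order;
    }

    public static void main(String[] args) {
        // Assumed label order; the probabilities are the ones printed above.
        String[] labels = {"Boots", "Sandals", "Shoes", "Slippers"};
        double[] probabilities = {0.00014, 0.73966, 0.25818, 0.00200};
        for (int i : rank(probabilities)) {
            System.out.println(String.format(Locale.ROOT,
                    "class: \"%s\", probability: %.5f", labels[i], probabilities[i]));
        }
    }
}
```

The first index returned is the model's top prediction, which for this image is "Sandals" and matches the folder the sample image was taken from.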