java深度学习之DJL迁移学习

1、数据集很小,数据相似性很高

考虑这个kaggle数据集(https://www.kaggle.com/mriganksingh/cat-images-dataset)。这包括猫的图像和其他非猫的图像。它有209个像素64*64*3的训练图像和50个测试图像。这显然是一个非常小的数据集,但我们知道ResNet是在大量动物和猫图像上训练的,所以我们可以使用ResNet作为固定特征提取器来解决我们的猫与非猫的问题。

我们需要冻结除最后一层之外的所有网络。

2、数据的大小很小,数据相似性也很低

考虑来自(https://www.kaggle.com/kvinicki/canine-coccidiosis),这个数据集包含了犬异孢球虫和犬异孢球虫卵囊的图像和标签,异孢球虫卵囊是一种球虫寄生虫,可感染狗的肠道。它是由萨格勒布兽医学院创建的。它包含了两种寄生虫的341张图片。

这个数据集很小,而且不是Imagenet中的一个类别。在这种情况下,我们保留预先训练好的模型架构,冻结前几层并保留它们的权重,并训练后几层更新它们的权重以适应我们的问题。

3、数据集的大小很大,但数据相似性非常低

考虑这个来自kaggle,皮肤癌MNIST的数据集:HAM10000

其具有超过10015个皮肤镜图像,属于7种不同类别。这不是我们在Imagenet中可以找到的那种数据。

这就是我们只保留模型架构而不保留来自预训练模型的任何权重的地方。让我们重新定义输出层,将项目分类为7个类别。

4、数据大小很大,数据相似性很高

考虑来自kaggle 的鲜花数据集(https://www.kaggle.com/alxmamaev/flowers-recognition)。它包含4242个花卉图像。图片分为五类:洋甘菊,郁金香,玫瑰,向日葵,蒲公英。每个类大约有800张照片。

这是应用迁移学习的理想情况。我们保留了预训练模型的体系结构和每一层的权重,并训练模型更新权重以匹配我们的特定问题。

package com.example.demo.djl;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.Cifar10;
import ai.djl.basicmodelzoo.BasicModelZoo;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;

import com.example.demo.util.Arguments;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An example of training an image classification (ResNet for Cifar10) model.
 *
 * <p>See this <a
 * href="https://github.com/awslabs/djl/blob/master/examples/docs/train_cifar10_resnet.md">doc</a>
 * for information about this example.
 *迁移学习
 * CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。
 * 一共包含 10 个类别的 RGB 彩色图 片:飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )
 * 、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,
 * 数据集中一共有 50000 张训练圄片和 10000 张测试图片。
 *
 * 与 MNIST 数据集中目比, CIFAR-10 具有以下不同点:
 * • CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。
 * • CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
 * • 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,
 * 不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。 直接的线性模型如 Softmax 在 CIFAR-10 上表现得很差。
 */
public final class TrainResnetWithCifar10 {

    private static final Logger logger = LoggerFactory.getLogger(TrainResnetWithCifar10.class);

    private TrainResnetWithCifar10() {}

    public static void main(String[] args) throws ModelException, IOException, TranslateException {
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
        TrainResnetWithCifar10.runExample(args);
    }

    public static TrainingResult runExample(String[] args)
            throws IOException, ModelException, TranslateException {
        Arguments arguments = Arguments.parseArgs(args);
//        if (arguments == null) {
//            return null;
//        }
        arguments.setEpoch(3);
        arguments.setBatchSize(32);
        arguments.setMaxGpus(1);
        arguments.setPreTrained(true);
        arguments.setSymbolic(true);


        try (Model model = getModel(arguments)) {
            // get training dataset
            RandomAccessDataset trainDataset = getDataset(Dataset.Usage.TRAIN, arguments);
            RandomAccessDataset validationDataset = getDataset(Dataset.Usage.TEST, arguments);

            // setup training configuration
            DefaultTrainingConfig config = setupTrainingConfig(arguments);

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());

                /*
                 * CIFAR10 is 32x32 image and pre processed into NCHW NDArray.
                 * 1st axis is batch axis, we can use 1 for initialization.
                 * 归一化 RGB彩色图 3  大小 32*32
                 */
                Shape inputShape = new Shape(1, 3, 32, 32);

                // initialize trainer with proper input shape
                trainer.initialize(inputShape);
                EasyTrain.fit(trainer, arguments.getEpoch(), trainDataset, validationDataset);

                TrainingResult result = trainer.getTrainingResult();
                model.setProperty("Epoch", String.valueOf(result.getEpoch()));
                model.setProperty(
                        "Accuracy",
                        String.format("%.5f", result.getValidateEvaluation("Accuracy")));
                model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));

                Path modelPath = Paths.get("build/model");
                model.save(modelPath, "resnetv1");

                Classifications classifications = testSaveParameters(model.getBlock(), modelPath);
                logger.info("Predict result: {}", classifications.topK(3));
                return result;
            }
        }
    }

    private static Model getModel(Arguments arguments)
            throws IOException, ModelNotFoundException, MalformedModelException {
        boolean isSymbolic = arguments.isSymbolic();
        boolean preTrained = arguments.isPreTrained();
        Map<String, String> options = arguments.getCriteria();
        Criteria.Builder<Image, Classifications> builder =
                Criteria.builder()
                        .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                        .setTypes(Image.class, Classifications.class)
                        .optProgress(new ProgressBar())
                        .optArtifactId("resnet");
        if (isSymbolic) {
            // load the model
            builder.optGroupId("ai.djl.mxnet");
            if (options == null) {
                builder.optFilter("layers", "50");
                builder.optFilter("flavor", "v1");
            } else {
                builder.optFilters(options);
            }
            Model model = ModelZoo.loadModel(builder.build());
            SequentialBlock newBlock = new SequentialBlock();
            SymbolBlock block = (SymbolBlock) model.getBlock();
            //迁移学习 去除最后一层
            block.removeLastBlock();

            newBlock.add(block);
            // the original model don't include the flatten
            // so apply the flatten here
            newBlock.add(Blocks.batchFlattenBlock());
            newBlock.add(Linear.builder().setUnits(10).build());
            model.setBlock(newBlock);
            if (!preTrained) {
                //清除迁移学习的模型参数
                model.getBlock().clear();
            }
            //若果不执行  model.getBlock().clear();方法则是
            //冻结所有网络的权重,除了最后的全连接层。
            // 最后一个全连接层被替换为一个具有随机权重的新层,并且只训练这一层。
            return model;
        }
        // imperative resnet50
        if (preTrained) {
            builder.optGroupId(BasicModelZoo.GROUP_ID);
            if (options == null) {
                builder.optFilter("layers", "50");
                builder.optFilter("flavor", "v1");
                builder.optFilter("dataset", "cifar10");
            } else {
                builder.optFilters(options);
            }
            // load pre-trained imperative ResNet50 from DJL model zoo
            // 从DJL model zoo加载预先训练的ResNet50
            return ModelZoo.loadModel(builder.build());
        } else {
            // construct new ResNet50 without pre-trained weights
            //在没有预先训练权重的情况下构造新的ResNet50
            Model model = Model.newInstance("resnetv1");
            Block resNet50 =
                    ResNetV1.builder()
                            .setImageShape(new Shape(3, 32, 32))
                            .setNumLayers(50)
                            .setOutSize(10)
                            .build();
            model.setBlock(resNet50);
            return model;
        }
    }

    private static Classifications testSaveParameters(Block block, Path path)
            throws IOException, ModelException, TranslateException {
        String synsetUrl =
                "https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset_cifar10.txt";
        ImageClassificationTranslator translator =
                ImageClassificationTranslator.builder()
                        .addTransform(new ToTensor())
                        .addTransform(new Normalize(Cifar10.NORMALIZE_MEAN, Cifar10.NORMALIZE_STD))
                        .optSynsetUrl(synsetUrl)
                        .optApplySoftmax(true)
                        .build();

        Image img = ImageFactory.getInstance().fromUrl("src/test/resources/airplane1.png");

        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optModelUrls(path.toUri().toString())
                        .optTranslator(translator)
                        .optBlock(block)
                        .optModelName("resnetv1")
                        .build();

        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
             Predictor<Image, Classifications> predictor = model.newPredictor()) {
            return predictor.predict(img);
        }
    }

    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .optDevices(Device.getDevices(arguments.getMaxGpus()))
                .addTrainingListeners(TrainingListener.Defaults.logging(arguments.getOutputDir()));
    }

    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        Pipeline pipeline =
                new Pipeline(
                        new ToTensor(),
                        new Normalize(Cifar10.NORMALIZE_MEAN, Cifar10.NORMALIZE_STD));
        Cifar10 cifar10 =
                Cifar10.builder()
                        .optUsage(usage)
                        .setSampling(arguments.getBatchSize(), true)
                        .optLimit(arguments.getLimit())
                        .optPipeline(pipeline)
                        .build();
        cifar10.prepare(new ProgressBar());
        return cifar10;
    }
}

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

非ban必选

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值