Deep Learning for Java Programmers - DJL Hands-On, Part 8: Transfer Learning

I. Introduction to Transfer Learning


1. Transfer learning

Transfer learning is a machine learning technique in which an existing pretrained model is adapted to a new target task using a relatively small amount of additional data, instead of training a model from scratch. A well-known related application is generating new artwork through style transfer: in 2015 Gatys et al. published "A Neural Algorithm of Artistic Style", the first work to use deep learning to learn the style of paintings.

2. BERT

BERT, short for Bidirectional Encoder Representations from Transformers, is a pretrained language representation model. Rather than being a tokenizer, it produces contextual representations of text that can be fine-tuned for downstream natural language processing tasks such as classification.

3. DistilBERT

BERT has a very large number of parameters and needs a lot of memory and compute to run; DistilBERT is a slimmed-down (distilled) version of BERT.

II. Implementation

1. Overview

This example uses the Amazon Customer Reviews dataset, restricted to the Digital Software product category, which contains about 102,000 usable reviews. The pretrained model is DistilBERT, a lightweight BERT variant trained on a large corpus of Wikipedia text. DistilBERT is added as the base layer of a classification model whose output is the review's star rating, on a scale of 1 to 5.
The review text is fed in as the input feature and the star rating is used as the label.
An example Amazon review record contains, among other columns, review_body (the review text) and star_rating (1-5); these are the two columns used below.

2. Prepare the dataset

The first step is to prepare the dataset. The raw data is in TSV format; here CsvDataset is used as the data container, and the Featurizer interface is implemented to preprocess the raw rows/columns and extract features.

final class BertFeaturizer implements CsvDataset.Featurizer {

    private final BertFullTokenizer tokenizer;
    private final int maxLength; // the cut-off length

    public BertFeaturizer(BertFullTokenizer tokenizer, int maxLength) {
        this.tokenizer = tokenizer;
        this.maxLength = maxLength;
    }

    /** {@inheritDoc} */
    @Override
    public void featurize(DynamicBuffer buf, String input) {
        SimpleVocabulary vocab = tokenizer.getVocabulary();
        // convert sentence to tokens (toLowerCase for uncased model)
        List<String> tokens = tokenizer.tokenize(input.toLowerCase());
        // trim the tokens to maxLength
        tokens = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;
        // BERT embedding convention "[CLS] Your Sentence [SEP]"
        buf.put(vocab.getIndex("[CLS]"));
        tokens.forEach(token -> buf.put(vocab.getIndex(token)));
        buf.put(vocab.getIndex("[SEP]"));
    }
}

For the BERT model we build a BertFeaturizer that implements the CsvDataset.Featurizer interface to perform feature extraction. In this example the data is only lightly cleaned: the text is lower-cased, truncated to maxLength tokens, and wrapped with the [CLS] and [SEP] markers.
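
To make the featurizer's behaviour concrete, a minimal sketch like the one below, assuming a SimpleVocabulary named vocabulary built from vocab.txt as in section 2.5.1, prints the word-piece tokens and the indices that BertFeaturizer would write into the buffer for a single sentence:

// Minimal sketch, assuming "vocabulary" is the SimpleVocabulary built in section 2.5.1
BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
String sample = "It works great, but it takes too long to update itself";
List<String> tokens = tokenizer.tokenize(sample.toLowerCase());
System.out.println(tokens); // word-piece tokens of the lower-cased sentence
// BertFeaturizer writes: the [CLS] index, one index per token, then the [SEP] index
System.out.println(vocabulary.getIndex("[CLS]"));
tokens.forEach(t -> System.out.println(t + " -> " + vocabulary.getIndex(t)));
System.out.println(vocabulary.getIndex("[SEP]"));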

3. Apply BertFeaturizer to the dataset

CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {
    String amazonReview =
            "https://s3.amazonaws.com/amazon-reviews-pds/tsv/amazon_reviews_us_Digital_Software_v1_00.tsv.gz";
    float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
    return CsvDataset.builder()
            .optCsvUrl(amazonReview) // load from Url
            .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // Setting TSV loading format
            .setSampling(batchSize, true) // make sample size and random access
            .optLimit(limit)
            .addFeature(new CsvDataset.Feature("review_body", new BertFeaturizer(tokenizer, maxLength)))
            .addLabel(new CsvDataset.Feature("star_rating", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))
            .optDataBatchifier(
                    PaddingStackBatchifier.builder()
                            .optIncludeValidLengths(false)
                            .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
                            .build()) // define how to pad dataset to a fix length
            .build();
}

The BertFeaturizer defined above is applied to the review_body column, and star_rating is used as the label set (shifted to 0-4). We also define how to pad the data when a sentence produces fewer tokens than the fixed length.
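
To exercise the dataset on its own, a minimal sketch along these lines, assuming the getDataset helper above and a BertFullTokenizer built as in section 2.5.1, downloads the TSV file, reports how many usable records were loaded, and peeks at the shape of one padded batch:

// Minimal sketch, assuming getDataset(...) above and "tokenizer" from section 2.5.1
CsvDataset dataset = getDataset(8, tokenizer, 64, Integer.MAX_VALUE);
dataset.prepare(new ProgressBar()); // downloads and parses the TSV file
System.out.println(dataset.size()); // number of usable review records
try (NDManager manager = NDManager.newBaseManager()) {
    Batch batch = dataset.getData(manager).iterator().next();
    // expected shapes: data (batchSize, paddedTokenLength), label (batchSize, 1)
    System.out.println(batch.getData().head().getShape());
    System.out.println(batch.getLabels().head().getShape());
    batch.close();
}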

4. Build the model

First download the DistilBERT model together with its pretrained weights. The downloaded model does not include a classification layer, so we append one to the end of the model before training. Once the block is assembled, load the embedding model with criteria.loadModel() and attach the combined block to a new Model via setBlock().

2.4.1 Load the model

// MXNet base model
String modelUrls = "https://resources.djl.ai/test-models/distilbert.zip";
if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
    modelUrls = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
}

Criteria<NDList, NDList> criteria = Criteria.builder()
        .optApplication(Application.NLP.WORD_EMBEDDING)
        .setTypes(NDList.class, NDList.class)
        .optModelUrls(modelUrls)
        .optProgress(new ProgressBar())
        .build();
ZooModel<NDList, NDList> embedding = criteria.loadModel();

2.4.2 Create the classification layer

Here a simple MLP is created to classify the review rating; the last fully connected layer outputs 5 values, one for each star level.
The front of the block first runs the input token indices through the pretrained text embedding, and at the end the [CLS] position is taken as the sentence representation.
The block is then attached to the model.

Predictor<NDList, NDList> embedder = embedding.newPredictor();
Block classifier = new SequentialBlock()
        // text embedding layer
        .add(
            ndList -> {
                NDArray data = ndList.singletonOrThrow();
                NDList inputs = new NDList();
                long batchSize = data.getShape().get(0);
                float maxLength = data.getShape().get(1);

                if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
                    inputs.add(data.toType(DataType.INT64, false));
                    inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                    inputs.add(data.getManager().arange(maxLength)
                               .toType(DataType.INT64, false)
                               .broadcast(data.getShape()));
                } else {
                    inputs.add(data);
                    inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
                }
                // run embedding
                try {
                    return embedder.predict(inputs);
                } catch (TranslateException e) {
                    throw new IllegalArgumentException("embedding error", e);
                }
            })
        // classification layer
        .add(Linear.builder().setUnits(768).build()) // pre classifier
        .add(Activation::relu)  // activation function
        .add(Dropout.builder().optRate(0.2f).build()) 
        .add(Linear.builder().setUnits(5).build()) // 5 star rating
        .addSingleton(nd -> nd.get(":,0")); // Take [CLS] as the head
Model model = Model.newInstance("AmazonReviewRatingClassification");
model.setBlock(classifier);

5. Start training

2.5.1 Create the training and validation sets

First build a vocabulary that maps words to indices, then feed that vocabulary to the tokenizer used for feature extraction.
Finally, split the dataset into training and validation sets by ratio.

The maximum token length is set to 64, which means only the first 64 word-piece tokens of each review are used.

// Prepare the vocabulary
SimpleVocabulary vocabulary = SimpleVocabulary.builder()
        .optMinFrequency(1)
        .addFromTextFile(embedding.getArtifact("vocab.txt"))
        .optUnknownToken("[UNK]")
        .build();
// Prepare dataset
int maxTokenLength = 64; // cutoff tokens length
int batchSize = 8;
int limit = Integer.MAX_VALUE;
// int limit = 512; // uncomment for quick testing

BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
// split data with 7:3 train:valid ratio
RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
RandomAccessDataset trainingSet = datasets[0];
RandomAccessDataset validationSet = datasets[1];

2.5.2 Create a training listener to track the training process

Note the evaluator (accuracy) and the loss function configured here. Training logs and model checkpoints are saved under build/model.

SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
listener.setSaveModelCallback(
        trainer -> {
            TrainingResult result = trainer.getTrainingResult();
            Model model = trainer.getModel();
            // track for accuracy and loss
            float accuracy = result.getValidateEvaluation("Accuracy");
            model.setProperty("Accuracy", String.format("%.5f", accuracy));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
        });
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
        .addEvaluator(new Accuracy())
        .optDevices(Device.getDevices(1)) // train using single GPU
        .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
        .addTrainingListeners(listener);

2.5.3 Training

int epoch = 2;

Trainer trainer = model.newTrainer(config);
trainer.setMetrics(new Metrics());
Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
// initialize trainer with proper input shape
trainer.initialize(encoderInputShape);
EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
System.out.println(trainer.getTrainingResult());

2.5.4 Save the model

model.save(Paths.get("build/model"), "amazon-review.param");
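
To reuse the trained weights later, a minimal sketch of restoring them could look like this, assuming the classifier block is rebuilt exactly as in section 2.4.2 and the same file prefix is passed to load:

// Minimal sketch, assuming "classifier" is the same block built in section 2.4.2
Model restored = Model.newInstance("AmazonReviewRatingClassification");
restored.setBlock(classifier);
restored.load(Paths.get("build/model"), "amazon-review.param");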

2.5.5 Validate the model

Create a predictor from the model and run it on your own input to check how well the model performs. A custom Translator converts the input text into token indices and turns the model output into a Classifications result over the five star levels.


class MyTranslator implements Translator<String, Classifications> {

    private BertFullTokenizer tokenizer;
    private SimpleVocabulary vocab;
    private List<String> ranks;

    public MyTranslator(BertFullTokenizer tokenizer) {
        this.tokenizer = tokenizer;
        vocab = tokenizer.getVocabulary();
        ranks = Arrays.asList("1", "2", "3", "4", "5");
    }

    @Override
    public Batchifier getBatchifier() { return new StackBatchifier(); }

    @Override
    public NDList processInput(TranslatorContext ctx, String input) {
        List<String> tokens = tokenizer.tokenize(input);
        float[] indices = new float[tokens.size() + 2];
        indices[0] = vocab.getIndex("[CLS]");
        for (int i = 0; i < tokens.size(); i++) {
            indices[i+1] = vocab.getIndex(tokens.get(i));
        }
        indices[indices.length - 1] = vocab.getIndex("[SEP]");
        return new NDList(ctx.getNDManager().create(indices));
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        return new Classifications(ranks, list.singletonOrThrow().softmax(0));
    }
}

Create the predictor and run a prediction:

String review = "It works great, but it takes too long to update itself and slows the system";
Predictor<String, Classifications> predictor = model.newPredictor(new MyTranslator(tokenizer));
System.out.println(predictor.predict(review));

III. Full source code


PyTorchLearn

package com.xundh;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.PaddingStackBatchifier;
import ai.djl.translate.TranslateException;
import org.apache.commons.csv.CSVFormat;

import java.io.IOException;
import java.nio.file.Paths;

public class PyTorchLearn {
    public static void main(String[] args) throws IOException, TranslateException, MalformedModelException, ModelNotFoundException {
        // Choose which model to download based on the deep learning engine in use
        // MXNet base model
        String modelUrls = "https://resources.djl.ai/test-models/distilbert.zip";
        if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
            modelUrls = "https://resources.djl.ai/test-models/traced_distilbert_wikipedia_uncased.zip";
        }


        Criteria<NDList, NDList> criteria = Criteria.builder()
                .optApplication(Application.NLP.WORD_EMBEDDING)
                .setTypes(NDList.class, NDList.class)
                .optModelUrls(modelUrls)
                .optProgress(new ProgressBar())
                .build();
        ZooModel<NDList, NDList> embedding = criteria.loadModel();

        Predictor<NDList, NDList> embedder = embedding.newPredictor();
        Block classifier = new SequentialBlock()
                // text embedding layer
                .add(
                        ndList -> {
                            NDArray data = ndList.singletonOrThrow();
                            NDList inputs = new NDList();
                            long batchSize = data.getShape().get(0);
                            float maxLength = data.getShape().get(1);

                            if ("PyTorch".equals(Engine.getInstance().getEngineName())) {
                                inputs.add(data.toType(DataType.INT64, false));
                                inputs.add(data.getManager().full(data.getShape(), 1, DataType.INT64));
                                inputs.add(data.getManager().arange(maxLength)
                                        .toType(DataType.INT64, false)
                                        .broadcast(data.getShape()));
                            } else {
                                inputs.add(data);
                                inputs.add(data.getManager().full(new Shape(batchSize), maxLength));
                            }
                            // run embedding
                            try {
                                return embedder.predict(inputs);
                            } catch (TranslateException e) {
                                throw new IllegalArgumentException("embedding error", e);
                            }
                        })
                // classification layer
                .add(Linear.builder().setUnits(768).build()) // pre classifier
                .add(Activation::relu)
                .add(Dropout.builder().optRate(0.2f).build())
                .add(Linear.builder().setUnits(5).build()) // 5 star rating
                .addSingleton(nd -> nd.get(":,0")); // Take [CLS] as the head
        Model model = Model.newInstance("AmazonReviewRatingClassification");
        model.setBlock(classifier);

        // Prepare the vocabulary
        SimpleVocabulary vocabulary = SimpleVocabulary.builder()
                .optMinFrequency(1)
                .addFromTextFile(embedding.getArtifact("vocab.txt"))
                .optUnknownToken("[UNK]")
                .build();
// Prepare dataset
        int maxTokenLength = 64; // cutoff tokens length
        int batchSize = 8;
        int limit = Integer.MAX_VALUE;
// int limit = 512; // uncomment for quick testing

        BertFullTokenizer tokenizer = new BertFullTokenizer(vocabulary, true);
        CsvDataset amazonReviewDataset = getDataset(batchSize, tokenizer, maxTokenLength, limit);
// split data with 7:3 train:valid ratio
        RandomAccessDataset[] datasets = amazonReviewDataset.randomSplit(7, 3);
        RandomAccessDataset trainingSet = datasets[0];
        RandomAccessDataset validationSet = datasets[1];
        SaveModelTrainingListener listener = new SaveModelTrainingListener("build/model");
        listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model1 = trainer.getModel();
                    // track for accuracy and loss
                    float accuracy = result.getValidateEvaluation("Accuracy");
                    model1.setProperty("Accuracy", String.format("%.5f", accuracy));
                    model1.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                });
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // loss type
                .addEvaluator(new Accuracy())
                .optDevices(new Device[]{Device.cpu()}) // train using CPU
                .addTrainingListeners(TrainingListener.Defaults.logging("build/model"))
                .addTrainingListeners(listener);

        int epoch = 2;

        Trainer trainer = model.newTrainer(config);
        trainer.setMetrics(new Metrics());
        Shape encoderInputShape = new Shape(batchSize, maxTokenLength);
// initialize trainer with proper input shape
        trainer.initialize(encoderInputShape);
        EasyTrain.fit(trainer, epoch, trainingSet, validationSet);
        System.out.println(trainer.getTrainingResult());

        model.save(Paths.get("build/model"), "amazon-review.param");
    }

    /**
     * Download and build the dataset.
     */
    static CsvDataset getDataset(int batchSize, BertFullTokenizer tokenizer, int maxLength, int limit) {
        String amazonReview =
                "https://s3.amazonaws.com/amazon-reviews-pds/tsv/amazon_reviews_us_Digital_Software_v1_00.tsv.gz";
        float paddingToken = tokenizer.getVocabulary().getIndex("[PAD]");
        return CsvDataset.builder()
                .optCsvUrl(amazonReview) // load from Url
                .setCsvFormat(CSVFormat.TDF.withQuote(null).withHeader()) // Setting TSV loading format
                .setSampling(batchSize, true) // make sample size and random access
                .optLimit(limit)
                .addFeature(
                        new CsvDataset.Feature(
                                "review_body", new BertFeaturizer(tokenizer, maxLength)))
                .addLabel(
                        new CsvDataset.Feature(
                                "star_rating", (buf, data) -> buf.put(Float.parseFloat(data) - 1.0f)))
                .optDataBatchifier(
                        PaddingStackBatchifier.builder()
                                .optIncludeValidLengths(false)
                                .addPad(0, 0, (m) -> m.ones(new Shape(1)).mul(paddingToken))
                                .build()) // define how to pad dataset to a fix length
                .build();
    }
}

BertFeaturizer

package com.xundh;

import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.basicdataset.utils.DynamicBuffer;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;

import java.util.List;

final class BertFeaturizer implements CsvDataset.Featurizer {

    private final BertFullTokenizer tokenizer;
    private final int maxLength; // the cut-off length

    public BertFeaturizer(BertFullTokenizer tokenizer, int maxLength) {
        this.tokenizer = tokenizer;
        this.maxLength = maxLength;
    }

    /** {@inheritDoc} */
    @Override
    public void featurize(DynamicBuffer buf, String input) {
        SimpleVocabulary vocab = tokenizer.getVocabulary();
        // convert sentence to tokens (toLowerCase for uncased model)
        List<String> tokens = tokenizer.tokenize(input.toLowerCase());
        // trim the tokens to maxLength
        tokens = tokens.size() > maxLength ? tokens.subList(0, maxLength) : tokens;
        // BERT embedding convention "[CLS] Your Sentence [SEP]"
        buf.put(vocab.getIndex("[CLS]"));
        tokens.forEach(token -> buf.put(vocab.getIndex(token)));
        buf.put(vocab.getIndex("[SEP]"));
    }
}

Run output: (screenshot omitted)
