Deep Java Library(六)DJLServing自定义模型,自定义Translator注意事项

在DJLServing中为自定义模型编写自定义Translator时,需要仔细读一下DJLServing源码中的ServingTranslatorFactory类。一开始不了解,以为DJLServing选择Translator像玄学;后来看了像迷宫一样的ServingTranslatorFactory类,才大致明白了。以下是源码注释版,还有一个整理的流程图。
在这里插入图片描述

/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
import ai.djl.Application;
import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.modality.cv.translator.ImageServingTranslator;
import ai.djl.translate.*;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.reflect.Constructor;
import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.Map;
import java.util.Set;

/**
 * A {@link TranslatorFactory} that decides which {@link Translator} is used to serve a model.
 *
 * <p>{@link #newInstance} resolves the translator in this order:
 *
 * <ol>
 *   <li>the factory class named by the {@code translatorFactory} argument, when it can be loaded
 *       and reports that it supports the requested input/output types;
 *   <li>a {@link ServingTranslator} implementation located under the model directory's {@code
 *       libs} (or, failing that, {@code lib}) sub-directory, optionally restricted to the class
 *       named by the {@code translator} argument;
 *   <li>a default translator selected from the {@code application} argument.
 * </ol>
 *
 * <p>NOTE(review): the original annotations state that the arguments come from {@code
 * serving.properties}; that wiring is not visible in this class.
 */
class ServingTranslatorFactory implements TranslatorFactory {

    private static final Logger logger = LoggerFactory.getLogger(ServingTranslatorFactory.class);

    /**
     * Returns the single input/output type pair this factory advertises.
     *
     * @return an immutable set containing only the ({@link Input}, {@link Output}) pair
     */
    @Override
    public Set<Pair<Type, Type>> getSupportedTypes() {
        return Collections.singleton(new Pair<>(Input.class, Output.class));
    }

    /**
     * Creates the translator for the given model, following the resolution order described on
     * the class.
     *
     * @param input the requested input type
     * @param output the requested output type
     * @param model the model being served; its model path is used to locate custom translators
     * @param arguments the model arguments; the keys {@code translatorFactory}, {@code
     *     translator} and {@code application} are consulted here
     * @param <I> the input type
     * @param <O> the output type
     * @return the resolved translator
     * @throws IllegalArgumentException if the input/output types are not supported
     * @throws TranslateException if a {@code translator} class name was given but no matching
     *     implementation could be loaded
     */
    @Override
    @SuppressWarnings("unchecked")
    public <I, O> Translator<I, O> newInstance(
            Class<I> input, Class<O> output, Model model, Map<String, ?> arguments)
            throws TranslateException {
        // Reject type pairs this factory does not advertise before doing any work.
        if (!isSupported(input, output)) {
            throw new IllegalArgumentException("Unsupported input/output types.");
        }

        Path modelDir = model.getModelPath();

        // Step 1: an explicitly configured TranslatorFactory takes precedence. If it cannot be
        // loaded, or does not support the requested types, resolution silently continues below.
        String factoryClass = ArgumentsUtil.stringValue(arguments, "translatorFactory");
        if (factoryClass != null && !factoryClass.isEmpty()) {
            TranslatorFactory factory = loadTranslatorFactory(factoryClass);
            if (factory != null && factory.isSupported(input, output)) {
                logger.info("Using TranslatorFactory: {}", factory.getClass().getName());
                return factory.newInstance(input, output, model, arguments);
            }
        }

        // Step 2: look for a custom ServingTranslator shipped with the model.
        // Read through ArgumentsUtil (as the other arguments are) instead of a raw (String)
        // cast, so a non-String value in the map cannot cause a ClassCastException.
        String className = ArgumentsUtil.stringValue(arguments, "translator");

        // Prefer "<modelDir>/libs"; fall back to "<modelDir>/lib".
        Path libPath = modelDir.resolve("libs");
        if (!Files.isDirectory(libPath)) {
            libPath = modelDir.resolve("lib");
            // Neither directory exists and no translator class was requested: nothing custom
            // to search for, so go straight to the default translator.
            if (!Files.isDirectory(libPath) && className == null) {
                return (Translator<I, O>) loadDefaultTranslator(arguments);
            }
        }

        ServingTranslator translator = findTranslator(libPath, className);
        if (translator != null) {
            translator.setArguments(arguments);
            logger.info("Using translator: {}", translator.getClass().getName());
            return (Translator<I, O>) translator;
        } else if (className != null) {
            // A specific translator was requested but could not be found: fail loudly rather
            // than silently serving with a different translator.
            throw new TranslateException("Failed to load translator: " + className);
        }

        // Step 3: nothing custom was found and nothing specific was requested.
        return (Translator<I, O>) loadDefaultTranslator(arguments);
    }

    /**
     * Searches a library directory for a {@link ServingTranslator} implementation.
     *
     * <p>The {@code classes} sub-directory of {@code path} is first handed to {@link
     * ClassLoaderUtils#compileJavaClass}, then {@link ClassLoaderUtils#findImplementation} is
     * asked for a {@link ServingTranslator} under {@code path}.
     *
     * @param path the library directory to search
     * @param className the translator class name to look for, may be {@code null}
     * @return the translator found, or {@code null} (callers treat {@code null} as "not found")
     */
    private ServingTranslator findTranslator(Path path, String className) {
        Path classesDir = path.resolve("classes");
        ClassLoaderUtils.compileJavaClass(classesDir);
        // Only ServingTranslator implementations are accepted here.
        return ClassLoaderUtils.findImplementation(path, ServingTranslator.class, className);
    }

    /**
     * Instantiates a {@link TranslatorFactory} by class name using its public no-arg constructor.
     *
     * <p>This is best-effort: any failure (class not found, not a {@link TranslatorFactory}, no
     * accessible no-arg constructor, constructor throwing, ...) is logged at trace level and
     * reported as {@code null} so that the caller can fall back to other strategies.
     *
     * @param className the fully qualified class name of the factory
     * @return a new factory instance, or {@code null} if it could not be created
     */
    private TranslatorFactory loadTranslatorFactory(String className) {
        try {
            Class<?> clazz = Class.forName(className);
            // Throws ClassCastException if the class is not a TranslatorFactory.
            Class<? extends TranslatorFactory> subclass = clazz.asSubclass(TranslatorFactory.class);
            Constructor<? extends TranslatorFactory> constructor = subclass.getConstructor();
            return constructor.newInstance();
        } catch (Throwable e) {
            // Deliberately broad: loading is optional and must never abort translator resolution.
            // A trailing Throwable argument is logged by SLF4J as the exception.
            logger.trace("Not able to load TranslatorFactory: {}", className, e);
        }
        return null;
    }

    /**
     * Builds the fallback translator from the {@code application} argument.
     *
     * @param arguments the model arguments
     * @return an image-classification serving translator when the application resolves to
     *     {@link Application.CV#IMAGE_CLASSIFICATION}; otherwise the translator produced by
     *     {@link NoopServingTranslatorFactory}
     */
    private Translator<Input, Output> loadDefaultTranslator(Map<String, ?> arguments) {
        String appName = ArgumentsUtil.stringValue(arguments, "application");
        if (appName != null) {
            Application application = Application.of(appName);
            if (application == Application.CV.IMAGE_CLASSIFICATION) {
                return getImageClassificationTranslator(arguments);
            }
        }
        // No application given, or one without a dedicated default: use the no-op factory.
        NoopServingTranslatorFactory factory = new NoopServingTranslatorFactory();
        // The model parameter is passed as null here.
        return factory.newInstance(Input.class, Output.class, null, arguments);
    }

    /**
     * Wraps an {@link ImageClassificationTranslator} built from the arguments in an {@link
     * ImageServingTranslator}.
     *
     * @param arguments the model arguments forwarded to the translator builder
     * @return the image-classification serving translator
     */
    private Translator<Input, Output> getImageClassificationTranslator(Map<String, ?> arguments) {
        return new ImageServingTranslator(ImageClassificationTranslator.builder(arguments).build());
    }
}

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: import torch import torch.nn as nn# 定义模型 class ImgTransModel(nn.Module): def __init__(self): super(ImgTransModel, self).__init__() self.encoder = nn.Sequential( # 使用卷积和池化层提取图像特征 nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(), nn.MaxPool2d(2, stride=2), nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(), nn.MaxPool2d(2, stride=2) ) self.attention = nn.Sequential( # 注意力机制 nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 32) ) self.decoder = nn.Sequential( # 解码器 nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10) ) def forward(self, x): x = self.encoder(x) x = self.attention(x) x = self.decoder(x) return x ### 回答2: 添加注意力机制的图像翻译模型的代码如下所示: ```python import tensorflow as tf from tensorflow.keras import layers class Attention(layers.Layer): def __init__(self): super(Attention, self).__init__() def build(self, input_shape): self.W1 = self.add_weight(shape=(input_shape[-1], input_shape[-1])) self.W2 = self.add_weight(shape=(input_shape[-1], input_shape[-1])) self.V = self.add_weight(shape=(input_shape[-1], 1)) def call(self, inputs): features, hidden_state = inputs hidden_with_time_axis = tf.expand_dims(hidden_state, 1) attention_weights = tf.nn.tanh(tf.matmul(features, self.W1) + tf.matmul(hidden_with_time_axis, self.W2)) score = tf.matmul(attention_weights, self.V) attention_weights = tf.nn.softmax(score, axis=1) context_vector = attention_weights * features context_vector = tf.reduce_sum(context_vector, axis=1) return context_vector, attention_weights class Translator(tf.keras.Model): def __init__(self, vocab_size, embedding_dim, units): super(Translator, self).__init__() self.units = units self.embedding = layers.Embedding(vocab_size, embedding_dim) self.gru = layers.GRU(self.units, return_sequences=True, return_state=True) self.fc = layers.Dense(vocab_size) self.attention = Attention() # 添加注意力机制 def call(self, inputs, hidden): context_vector, attention_weights = self.attention([inputs, hidden]) x = self.embedding(inputs) x = 
tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1) output, state = self.gru(x) output = tf.reshape(output, (-1, output.shape[2])) x = self.fc(output) return x, state, attention_weights # 示例使用 vocab_size = 10000 embedding_dim = 256 units = 1024 translator = Translator(vocab_size, embedding_dim, units) sample_hidden = translator.gru.initialize_hidden_state(batch_size=1) sample_output, sample_hidden, sample_attention_weights = translator.call(tf.random.uniform((1, 10)), sample_hidden) print(sample_output.shape) # 输出:(1, 10000) print(sample_hidden.shape) # 输出:(1, 1024) print(sample_attention_weights.shape) # 输出:(1, 10, 1) ``` 这段代码实现了一个图像翻译模型,其中添加了一个Attention类作为注意力机制的层。在Translator类的call方法中,调用Attention类对输入进行注意力计算,将注意力结果与上一时刻的隐藏状态合并后再输入GRU层和全连接层进行翻译预测。在示例使用部分,创建了一个示例模型,并将随机输入进行预测,显示预测输出形状和注意力权重的形状。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值