Java调用机器学习训练包记录一下

前言

  • 最近公司有个需求,需要对用户进行数据画像分析。
  • 公司大数据组通过对线上用户数据进行分析后,通过机器学习用python做了一个训练模型pkl文件包。
  • 要求我部门对用户数据进行分析计算。而我部门的项目都是使用Java进行开发的,所以就需要Java调用pkl训练模型包。
  • 经过调研python的pkl训练模型包不能直接被Java调用,跨平台调用需要使用pmml格式文件,所以就让大数据部门依照已经生成的训练模型pkl文件,在次封装成一个pmml文件。

pmml格式

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
   <Header>
      <Application name="JPMML-SkLearn" version="1.6.27"/>
      <Timestamp>2021-08-30T06:48:45Z</Timestamp>
   </Header>
   <DataDictionary>
      <DataField name="y" optype="categorical" dataType="integer">
         <Value value="0"/>
         <Value value="1"/>
      </DataField>
      <DataField name="x1" optype="continuous" dataType="double"/>
      <DataField name="x2" optype="continuous" dataType="double"/>
      <DataField name="x3" optype="continuous" dataType="double"/>
   </DataDictionary>
   <RegressionModel functionName="classification" algorithmName="sklearn.linear_model._logistic.LogisticRegression" normalizationMethod="logit">
      <MiningSchema>
         <MiningField name="y" usageType="target"/>
         <MiningField name="x1"/>
         <MiningField name="x2"/>
         <MiningField name="x3"/>
      </MiningSchema>
      <RegressionTable intercept="0.5920457931585216" targetCategory="1">
         <NumericPredictor name="x1" coefficient="0.7586778342148665"/>
         <NumericPredictor name="x2" coefficient="0.6562980822443883"/>
         <NumericPredictor name="x3" coefficient="0.9917332587791079"/>
      </RegressionTable>
      <RegressionTable intercept="0.0" targetCategory="0"/>
   </RegressionModel>
</PMML>

Java调用pmml文件

  • 首先在项目中先引用解析pmml的maven包
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.4.1</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.4.1</version>
</dependency>
  • Java调用方法
  • 当有test.pmml文件后,可以把文件放在springboot项目的resources目录下,使用ClassPathResource类获取到文件流
/**
 * @Author: ZRH
 * @Date: 2021/8/30 9:17
 */
@Slf4j
public final class ClassificationModelOld {

    private static Evaluator modelEvaluator;

    static {
        PMML pmml;
        try {
            Resource resource = new ClassPathResource("test.pmml");
            InputStream is = resource.getInputStream();
            pmml = PMMLUtil.unmarshal(is);
            try {
                is.close();
            } catch (IOException e) {
                log.info("InputStream close error!");
            }

            ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
            modelEvaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
            modelEvaluator.verify();
            log.info("加载模型成功!");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * 私有化构造函数,防止外部创建实例
     */
    private ClassificationModelOld () {
    }

    /**
     * 获取模型需要的特征名称
     *
     * @return
     */
    public static List<String> getFeatureNames () {
        List<String> featureNames = new ArrayList<>();
        List<InputField> inputFields = modelEvaluator.getInputFields();
        for (InputField inputField : inputFields) {
            featureNames.add(inputField.getName().toString());
        }
        return featureNames;
    }

    /**
     * 获取目标字段名称
     *
     * @return
     */
    public static String getTargetName () {
        return modelEvaluator.getTargetFields().get(0).getName().toString();
    }

    /**
     * 使用模型生成概率分布
     *
     * @param arguments
     * @return
     */
    private static ProbabilityDistribution getProbabilityDistribution (Map<FieldName, ?> arguments) {
        Map<FieldName, ?> evaluateResult = modelEvaluator.evaluate(arguments);
        FieldName fieldName = FieldName.create(getTargetName());
        return (ProbabilityDistribution) evaluateResult.get(fieldName);

    }

    /**
     * 预测不同分类的概率
     *
     * @param arguments
     * @return
     */
    public static ValueMap<String, Number> predictProba (Map<FieldName, Number> arguments) {
        ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
        return probabilityDistribution.getValues();
    }

    /**
     * 预测结果分类
     *
     * @param arguments
     * @return
     */
    public static Object predict (Map<FieldName, ?> arguments) {
        ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
        return probabilityDistribution.getPrediction();
    }

    private static Integer setScore (float probability) {
        int score = 0;
        try {
            // TODO 根据比例写算法计算出分值
            score = 520;
        } catch (Exception e) {
        }
        return score;
    }

    public static void main (String[] args) {

        // 参数进过转义后:{{"value":"x1"}:-0.216918810277242,{"value":"x2"}:0.0583184157700168,{"value":"x3"}:-0.653728631926331}
        final ArrayList<Double> doubles = Lists.newArrayList(-0.216918810277242, 0.0583184157700168, -0.653728631926331);

        Map<FieldName, Number> waitPreSample = new HashMap<>(8);
        waitPreSample.put(FieldName.create("x1"), doubles.get(0));
        waitPreSample.put(FieldName.create("x2"), doubles.get(1));
        waitPreSample.put(FieldName.create("x3"), doubles.get(2));
        final ValueMap<String, Number> values = ClassificationModelOld.predictProba(waitPreSample);
        System.out.println("机器算法计算分值结果:" + setScore(values.get("1").floatValue()));
    }
}

---------------------
执行结果:
加载模型成功!
机器算法计算分值结果:520

版本问题

  • 上面示例是使用的老版本的包,并且打的pmml文件也是4.3版本的
  • 所以如果使用的是4.4版本的pmml文件

978DAE2D-238F-439a-A0DD-E987A608F417.png

  • 那么需要更新maven引入的包
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.5.11</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.5.11</version>
</dependency>
  • 在加载模型时需要更新加载方式
static {
    PMML pmml;
    try {
        Resource resource = new ClassPathResource("test.pmml");
        InputStream is = resource.getInputStream();
        pmml = PMMLUtil.unmarshal(is);
        try {
            is.close();
        } catch (IOException e) {
            log.info("InputStream close error!");
        }
        ModelEvaluatorBuilder modelEvaluatorBuilder = new ModelEvaluatorBuilder(pmml);
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        modelEvaluatorBuilder.setModelEvaluatorFactory(modelEvaluatorFactory);
        modelEvaluator = modelEvaluatorBuilder.build();
        modelEvaluator.verify();
        log.info("加载模型成功!");
    } catch (Exception e) {
        e.printStackTrace();
    }
}
  • 这样4.4版本的pmml训练模型文件也是可以执行获取结果

最后

  • 虚心学习,共同进步 -_-
  • 1
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: 使用 Java 截取图片中的身份证可以使用 OpenCV 这个图像处理库。首先, 你需要安装 OpenCV 并导入相关的。然后, 你可以使用 OpenCV 的 API 来读取图片, 并使用图像处理技巧来截取身份证区域。 具体地, 你可以使用阈值二值化, 边缘检测等方法来确定身份证的位置。接着你可以使用机器学习算法来识别身份证上的文本信息,比如身份证号码和姓名。 下面是一个简单的例子来读取一张图片并显示出来: ``` import org.opencv.core.Core; import org.opencv.core.Mat; import org.opencv.imgcodecs.Imgcodecs; import org.opencv.highgui.HighGui; public class Main { public static void main(String[] args) { System.loadLibrary(Core.NATIVE_LIBRARY_NAME); Mat image = Imgcodecs.imread("image.jpg"); HighGui.imshow("image", image); HighGui.waitKey(); } } ``` 上面的代码是使用Java读取图片,并使用OpenCV的API在窗口中显示出来。 对于机器学习部分, 建议使用深度学习框架如 TensorFlow 或者 PyTorch, 并使用相应的 OCR 模型来识别文本。 ### 回答2: 要使用机器学习训练代码来截取图片中的身份证,以下是一个可能的方案。 首先,我们需要一个数据集,其中含带有身份证的图片和对应的标签(即身份证的位置和边界框)。可以通过手动标注已知身份证图片的位置来创建此数据集。 之后,我们将使用机器学习的目标检测算法来训练模型。可以选择使用深度学习模型,如基于卷积神经网络(CNN)的目标检测模型,例如Faster R-CNN、YOLO或SSD。这些模型在图像识别和目标检测任务中表现得非常出色。 在训练模型之前,我们需要将数据集分成两个部分:一个用于训练,一个用于验证。训练集用于训练模型的参数,验证集用于评估模型的性能,以便在训练过程中对模型进行调整和改进。 然后,我们可以使用一个深度学习框架(如TensorFlow、Keras或PyTorch)来实现目标检测模型。框架提供了许多现成的算法和工具,可以大大简化模型的搭建和训练过程。我们需要编写代码来定义模型的结构、损失函数和优化算法。 在模型训练完成后,我们可以使用该模型来预测新的图片中的身份证位置。将待检测的图片输入模型,模型将输出一个或多个边界框,表示可能的身份证位置。可以根据模型的输出进行后续的处理,例如对边界框进行非极大值抑制(NMS)来得到最终的身份证位置。 当然,为了获得更好的性能,我们可能需要进行许多迭代和调整,括优化模型的结构、调整超参数、增加更多的训练数据等。 总之,使用机器学习训练代码来截取图片中的身份证,需要准备数据集、选择合适的模型和算法、训练模型、调整参数,并在最终的预测中进行后处理。通过不断的实验和优化,我们可以建立一个高效准确的身份证截取系统。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值