Java调用机器学习训练包记录一下

前言

  • 最近公司有个需求,需要对用户进行数据画像分析。
  • 公司大数据组通过对线上用户数据进行分析后,通过机器学习用python做了一个训练模型pkl文件包。
  • 要求我部门对用户数据进行分析计算。而我部门的项目都是使用Java进行开发的,所以就需要Java调用pkl训练模型包。
  • 经过调研python的pkl训练模型包不能直接被Java调用,跨平台调用需要使用pmml格式文件,所以就让大数据部门依照已经生成的训练模型pkl文件,在次封装成一个pmml文件。

pmml格式

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
   <Header>
      <Application name="JPMML-SkLearn" version="1.6.27"/>
      <Timestamp>2021-08-30T06:48:45Z</Timestamp>
   </Header>
   <DataDictionary>
      <DataField name="y" optype="categorical" dataType="integer">
         <Value value="0"/>
         <Value value="1"/>
      </DataField>
      <DataField name="x1" optype="continuous" dataType="double"/>
      <DataField name="x2" optype="continuous" dataType="double"/>
      <DataField name="x3" optype="continuous" dataType="double"/>
   </DataDictionary>
   <RegressionModel functionName="classification" algorithmName="sklearn.linear_model._logistic.LogisticRegression" normalizationMethod="logit">
      <MiningSchema>
         <MiningField name="y" usageType="target"/>
         <MiningField name="x1"/>
         <MiningField name="x2"/>
         <MiningField name="x3"/>
      </MiningSchema>
      <RegressionTable intercept="0.5920457931585216" targetCategory="1">
         <NumericPredictor name="x1" coefficient="0.7586778342148665"/>
         <NumericPredictor name="x2" coefficient="0.6562980822443883"/>
         <NumericPredictor name="x3" coefficient="0.9917332587791079"/>
      </RegressionTable>
      <RegressionTable intercept="0.0" targetCategory="0"/>
   </RegressionModel>
</PMML>
复制代码

Java调用pmml文件

  • 首先在项目中先引用解析pmml的maven包
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.4.1</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.4.1</version>
</dependency>
复制代码
  • Java调用方法
  • 当有test.pmml文件后,可以把文件放在springboot项目的resources目录下,使用ClassPathResource类获取到文件流
/**
 * @Author: ZRH
 * @Date: 2021/8/30 9:17
 */
@Slf4j
public final class ClassificationModelOld {

    private static Evaluator modelEvaluator;

    static {
        PMML pmml;
        try {
            Resource resource = new ClassPathResource("test.pmml");
            InputStream is = resource.getInputStream();
            pmml = PMMLUtil.unmarshal(is);
            try {
                is.close();
            } catch (IOException e) {
                log.info("InputStream close error!");
            }

            ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
            modelEvaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
            modelEvaluator.verify();
            log.info("加载模型成功!");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * 私有化构造函数,防止外部创建实例
     */
    private ClassificationModelOld () {
    }

    /**
     * 获取模型需要的特征名称
     *
     * @return
     */
    public static List<String> getFeatureNames () {
        List<String> featureNames = new ArrayList<>();
        List<InputField> inputFields = modelEvaluator.getInputFields();
        for (InputField inputField : inputFields) {
            featureNames.add(inputField.getName().toString());
        }
        return featureNames;
    }

    /**
     * 获取目标字段名称
     *
     * @return
     */
    public static String getTargetName () {
        return modelEvaluator.getTargetFields().get(0).getName().toString();
    }

    /**
     * 使用模型生成概率分布
     *
     * @param arguments
     * @return
     */
    private static ProbabilityDistribution getProbabilityDistribution (Map<FieldName, ?> arguments) {
        Map<FieldName, ?> evaluateResult = modelEvaluator.evaluate(arguments);
        FieldName fieldName = FieldName.create(getTargetName());
        return (ProbabilityDistribution) evaluateResult.get(fieldName);

    }

    /**
     * 预测不同分类的概率
     *
     * @param arguments
     * @return
     */
    public static ValueMap<String, Number> predictProba (Map<FieldName, Number> arguments) {
        ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
        return probabilityDistribution.getValues();
    }

    /**
     * 预测结果分类
     *
     * @param arguments
     * @return
     */
    public static Object predict (Map<FieldName, ?> arguments) {
        ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
        return probabilityDistribution.getPrediction();
    }

    private static Integer setScore (float probability) {
        int score = 0;
        try {
            // TODO 根据比例写算法计算出分值
            score = 520;
        } catch (Exception e) {
        }
        return score;
    }

    public static void main (String[] args) {

        // 参数进过转义后:{{"value":"x1"}:-0.216918810277242,{"value":"x2"}:0.0583184157700168,{"value":"x3"}:-0.653728631926331}
        final ArrayList<Double> doubles = Lists.newArrayList(-0.216918810277242, 0.0583184157700168, -0.653728631926331);

        Map<FieldName, Number> waitPreSample = new HashMap<>(8);
        waitPreSample.put(FieldName.create("x1"), doubles.get(0));
        waitPreSample.put(FieldName.create("x2"), doubles.get(1));
        waitPreSample.put(FieldName.create("x3"), doubles.get(2));
        final ValueMap<String, Number> values = ClassificationModelOld.predictProba(waitPreSample);
        System.out.println("机器算法计算分值结果:" + setScore(values.get("1").floatValue()));
    }
}

---------------------
执行结果:
加载模型成功!
机器算法计算分值结果:520
复制代码

版本问题

  • 上面示例是使用的老版本的包,并且打的pmml文件也是4.3版本的
  • 所以如果使用的是4.4版本的pmml文件

978DAE2D-238F-439a-A0DD-E987A608F417.png

  • 那么需要更新maven引入的包
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.5.11</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.5.11</version>
</dependency>
复制代码
  • 在加载模型时需要更新加载方式
static {
    PMML pmml;
    try {
        Resource resource = new ClassPathResource("test.pmml");
        InputStream is = resource.getInputStream();
        pmml = PMMLUtil.unmarshal(is);
        try {
            is.close();
        } catch (IOException e) {
            log.info("InputStream close error!");
        }
        ModelEvaluatorBuilder modelEvaluatorBuilder = new ModelEvaluatorBuilder(pmml);
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        modelEvaluatorBuilder.setModelEvaluatorFactory(modelEvaluatorFactory);
        modelEvaluator = modelEvaluatorBuilder.build();
        modelEvaluator.verify();
        log.info("加载模型成功!");
    } catch (Exception e) {
        e.printStackTrace();
    }
}
复制代码
  • 这样4.4版本的pmml训练模型文件也是可以执行获取结果

最后

  • 虚心学习,共同进步 -_-