Reposted

Calling Python Models from Java with PMML

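Most of our company's applications are written in Java, which makes it hard to use models trained in Python. After some searching online, I found that a trained model can first be converted to a PMML file, which Java can then load and call directly.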

Taking LightGBM as an example:

1. Export the trained model in txt format

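import pandas as pd
from lightgbm import LGBMClassifier

iris_df = pd.read_csv("xml/iris.csv")
d_x = iris_df.iloc[:, 0:4].values  # the four feature columns
d_y = iris_df.iloc[:, 4].values    # the species label column
model = LGBMClassifier(
    boosting_type="gbdt", objective="multiclass", n_jobs=8, random_state=42)
# The number of classes is inferred from d_y during fit(), so it need not be set by hand
model.fit(d_x, d_y, feature_name=iris_df.columns.tolist()[0:-1])
# Save the underlying Booster in LightGBM's text format.
# Note: the Booster only knows the encoded class labels 0..2; model.classes_ keeps the original names.
model.booster_.save_model("xml/lightgbm.txt")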
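Because the Java code later looks up each input value by field name, it can help to check which feature names were written into the txt file. A minimal sketch, assuming the LightGBM text model contains a header line of the form `feature_names=name1 name2 ...` (as LightGBM's text format does); the class `LgbmFeatureNames` and its method are hypothetical helpers, not part of any library.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical helper: reads the feature names stored in a LightGBM text model.
// Assumes a header line "feature_names=..." with space-separated names.
public class LgbmFeatureNames {

    public static List<String> readFeatureNames(Path modelTxt) throws IOException {
        for (String line : Files.readAllLines(modelTxt, StandardCharsets.UTF_8)) {
            if (line.startsWith("feature_names=")) {
                String names = line.substring("feature_names=".length()).trim();
                return names.isEmpty() ? Collections.emptyList() : Arrays.asList(names.split("\\s+"));
            }
        }
        // No header line found: return an empty list rather than guessing
        return Collections.emptyList();
    }

    public static void main(String[] args) throws IOException {
        Path path = Paths.get(args.length > 0 ? args[0] : "xml/lightgbm.txt");
        System.out.println(readFeatureNames(path));
    }
}
```

If the names printed here are `Column_0`, `Column_1`, ... instead of the CSV headers, the `feature_name` argument did not reach `fit()`, and the Java map keys would have to use those generic names.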

2. Convert the txt model to PMML with the converter tool (the jpmml-lightgbm command-line jar)

java -jar converter-executable-1.2-SNAPSHOT.jar  --lgbm-input lightgbm.txt --pmml-output lightgbm.pmml
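If the conversion should run as part of a Java build or service rather than from a shell, the same command can be launched with `ProcessBuilder`. A minimal sketch: the jar name and the `--lgbm-input` / `--pmml-output` flags are taken from the command above, while the class `PmmlConverterRunner` is a hypothetical wrapper.

```java
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

// Hypothetical wrapper that runs the LightGBM-to-PMML converter jar from Java.
public class PmmlConverterRunner {

    // Builds the same argument list as the shell command
    public static List<String> buildCommand(String jar, String lgbmInput, String pmmlOutput) {
        return Arrays.asList("java", "-jar", jar, "--lgbm-input", lgbmInput, "--pmml-output", pmmlOutput);
    }

    // Runs the converter and returns its exit code (0 normally means success)
    public static int convert(String jar, String lgbmInput, String pmmlOutput)
            throws IOException, InterruptedException {
        Process process = new ProcessBuilder(buildCommand(jar, lgbmInput, pmmlOutput))
                .inheritIO() // forward the converter's own console output
                .start();
        return process.waitFor();
    }

    public static void main(String[] args) throws Exception {
        int exitCode = convert("converter-executable-1.2-SNAPSHOT.jar", "lightgbm.txt", "lightgbm.pmml");
        System.out.println("converter exit code: " + exitCode);
    }
}
```

Passing the arguments as a list (instead of one command string) avoids shell quoting problems when paths contain spaces.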

3. Call the model directly from Java code

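Note: before calling the model, add the following library (jar) as a dependency: https://github.com/jpmml/jpmml-evaluator . The sample code uses the older jpmml-evaluator API (org.dmg.pmml.FieldName, ModelEvaluatorFactory), which later releases replaced, so pin a matching older version (for example 1.4.x). Sample code: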

package com.pmmldemo.test;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.TargetField;
import org.jpmml.model.PMMLUtil;

public class PMMLPrediction {

	public static void main(String[] args) throws Exception {
		String pmmlPath = "lightgbm.pmml";
		// Assemble the model inputs; keys must match the PMML input field names
		Map<String, Double> features = new HashMap<>();
		features.put("sepal_length", 5.1);
		features.put("sepal_width", 3.5);
		features.put("petal_length", 1.4);
		features.put("petal_width", 0.2);
		predictIris(features, pmmlPath);
	}

	public static void predictIris(Map<String, Double> irisMap, String pmmlPath) throws Exception {
		PMML pmml;
		// Load the model; try-with-resources closes the stream even on failure
		try (InputStream is = new FileInputStream(new File(pmmlPath))) {
			pmml = PMMLUtil.unmarshal(is);
		}

		// ModelEvaluator implements Evaluator, so no cast is needed
		Evaluator evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml);

		// Iterate over the model's raw input features and take each value from the feature map
		Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
		for (InputField inputField : evaluator.getInputFields()) {
			FieldName inputFieldName = inputField.getName();
			Object rawValue = irisMap.get(inputFieldName.getValue());
			// prepare() converts and validates the raw value against the field definition
			FieldValue inputFieldValue = inputField.prepare(rawValue);
			arguments.put(inputFieldName, inputFieldValue);
		}

		Map<FieldName, ?> results = evaluator.evaluate(arguments);
		// Classification and similar models can have several outputs; print every target field
		List<TargetField> targetFields = evaluator.getTargetFields();
		for (TargetField targetField : targetFields) {
			FieldName targetFieldName = targetField.getName();
			Object targetFieldValue = results.get(targetFieldName);
			System.out.println("target: " + targetFieldName.getValue()
					+ " value: " + targetFieldValue);
		}
	}
}
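In the code above, a misspelled map key makes `irisMap.get(...)` return null, and `prepare(null)` may then treat the feature as a missing value (depending on the PMML's missing-value settings) instead of raising an error, so the model can quietly predict from incomplete input. A minimal, library-free sketch of a pre-check; `FeatureMapCheck` and `missingFeatures` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper: verify the feature map before evaluation so that a key typo
// fails loudly instead of silently becoming a missing value.
public class FeatureMapCheck {

    // Returns the required names that are absent from the map or mapped to null
    public static List<String> missingFeatures(Map<String, ?> features, List<String> requiredNames) {
        List<String> missing = new ArrayList<>();
        for (String name : requiredNames) {
            if (features.get(name) == null) {
                missing.add(name);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        Map<String, Double> features = new HashMap<>();
        features.put("sepal_length", 5.1);
        features.put("sepal_width", 3.5);
        features.put("petal_length", 1.4);
        features.put("petal_wdith", 0.2); // deliberate typo
        List<String> required = Arrays.asList("sepal_length", "sepal_width", "petal_length", "petal_width");
        System.out.println("missing: " + missingFeatures(features, required)); // missing: [petal_width]
    }
}
```

In the prediction code, the required names would come from `inputField.getName().getValue()` for each entry of `evaluator.getInputFields()`, and a non-empty result could be turned into an exception before `evaluate` is called.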

Common model-to-PMML converters:

  • Scikit-Learn models to PMML: https://github.com/jpmml/sklearn2pmml
  • LightGBM models to PMML: https://github.com/jpmml/jpmml-lightgbm
  • XGBoost models to PMML: https://github.com/jpmml/jpmml-xgboost
  • Keras models to PMML: https://github.com/vaclavcadek/keras2pmml
  • TensorFlow models to PMML: https://github.com/jpmml/jpmml-tensorflow
Original article: https://www.biaodianfu.com/pmml.html