最近公司有个需求 需要对用户进行数据画像分析。公司大数据组通过对线上用户数据进行分析后 通过机器学习用python做了一个训练模型pkl文件包。要求我部门对用户数据进行分析计算。而我部门的项目都是使用Java进行开发的 所以就需要Java调用pkl训练模型包。经过调研python的pkl训练模型包不能直接被Java调用 跨平台调用需要使用pmml格式文件 所以就让大数据部门依照已经生成的训练模型pkl文件 在次封装成一个pmml文件。
pmml格式
?xml version 1.0 encoding UTF-8 standalone yes ?
PMML xmlns http://www.dmg.org/PMML-4_3 xmlns:data http://jpmml.org/jpmml-model/InlineTable version 4.3
Header
Application name JPMML-SkLearn version 1.6.27 /
Timestamp 2021-08-30T06:48:45Z /Timestamp
/Header
DataDictionary
DataField name y optype categorical dataType integer
Value value 0 /
Value value 1 /
/DataField
DataField name x1 optype continuous dataType double /
DataField name x2 optype continuous dataType double /
DataField name x3 optype continuous dataType double /
/DataDictionary
RegressionModel functionName classification algorithmName sklearn.linear_model._logistic.LogisticRegression normalizationMethod logit
MiningSchema
MiningField name y usageType target /
MiningField name x1 /
MiningField name x2 /
MiningField name x3 /
/MiningSchema
RegressionTable intercept 0.5920457931585216 targetCategory 1
NumericPredictor name x1 coefficient 0.7586778342148665 /
NumericPredictor name x2 coefficient 0.6562980822443883 /
NumericPredictor name x3 coefficient 0.9917332587791079 /
/RegressionTable
RegressionTable intercept 0.0 targetCategory 0 /
/RegressionModel
/PMML
Java调用pmml文件
首先在项目中先引用解析pmml的maven包
dependency
groupId org.jpmml /groupId
artifactId pmml-evaluator /artifactId
version 1.4.1 /version
/dependency
dependency
groupId org.jpmml /groupId
artifactId pmml-evaluator-extension /artifactId
version 1.4.1 /version
/dependency
Java调用方法当有test.pmml文件后 可以把文件放在springboot项目的resources目录下 使用ClassPathResource类获取到文件流
/**
* Author: ZRH
* Date: 2021/8/30 9:17
Slf4j
public final class ClassificationModelOld {
private static evaluator modelevaluator;
static {
PMML pmml;
try {
Resource resource new ClassPathResource( test.pmml
InputStream is resource.getInputStream();
pmml PMMLUtil.unmarshal(is);
try {
is.close();
} catch (IOException e) {
log.info( InputStream close error!
ModelevaluatorFactory modelevaluatorFactory ModelevaluatorFactory.newInstance();
modelevaluator modelevaluatorFactory.newModelevaluator(pmml);
modelevaluator.verify();
log.info( 加载模型成功!
} catch (Exception e) {
e.printStackTrace();
* 私有化构造函数 防止外部创建实例
private ClassificationModelOld () {
* 获取模型需要的特征名称
* return
public static List String getFeatureNames () {
List String featureNames new ArrayList ();
List InputField inputFields modelevaluator.getInputFields();
for (InputField inputField : inputFields) {
featureNames.add(inputField.getName().toString());
return featureNames;
* 获取目标字段名称
* return
public static String getTargetName () {
return modelevaluator.getTargetFields().get(0).getName().toString();
* 使用模型生成概率分布
* param arguments
* return
private static ProbabilityDistribution getProbabilityDistribution (Map FieldName, ? arguments) {
Map FieldName, ? evaluateResult modelevaluator.evaluate(arguments);
FieldName fieldName FieldName.create(getTargetName());
return (ProbabilityDistribution) evaluateResult.get(fieldName);
* 预测不同分类的概率
* param arguments
* return
public static ValueMap String, Number predictProba (Map FieldName, Number arguments) {
ProbabilityDistribution probabilityDistribution getProbabilityDistribution(arguments);
return probabilityDistribution.getValues();
* 预测结果分类
* param arguments
* return
public static Object predict (Map FieldName, ? arguments) {
ProbabilityDistribution probabilityDistribution getProbabilityDistribution(arguments);
return probabilityDistribution.getPrediction();
private static Integer setScore (float probability) {
int score
try {
// TODO 根据比例写算法计算出分值
score 520;
} catch (Exception e) {
return score;
public static void main (String[] args) {
// 参数进过转义后 {{ value : x1 }:-0.216918810277242,{ value : x2 }:0.0583184157700168,{ value : x3 }:-0.653728631926331}
final ArrayList Double doubles Lists.newArrayList(-0.216918810277242, 0.0583184157700168, -0.653728631926331);
Map FieldName, Number waitPreSample new HashMap (8);
waitPreSample.put(FieldName.create( x1 ), doubles.get(0));
waitPreSample.put(FieldName.create( x2 ), doubles.get(1));
waitPreSample.put(FieldName.create( x3 ), doubles.get(2));
final ValueMap String, Number values ClassificationModelOld.predictProba(waitPreSample);
System.out.println( 机器算法计算分值结果 setScore(values.get( 1 ).floatValue()));
---------------------
执行结果
加载模型成功
机器算法计算分值结果 520
版本问题
上面示例是使用的老版本的包 并且打的pmml文件也是4.3版本的所以如果使用的是4.4版本的pmml文件
那么需要更新maven引入的包
dependency
groupId org.jpmml /groupId
artifactId pmml-evaluator /artifactId
version 1.5.11 /version
/dependency
dependency
groupId org.jpmml /groupId
artifactId pmml-evaluator-extension /artifactId
version 1.5.11 /version
/dependency
在加载模型时需要更新加载方式
static {
PMML pmml;
try {
Resource resource new ClassPathResource( test.pmml
InputStream is resource.getInputStream();
pmml PMMLUtil.unmarshal(is);
try {
is.close();
} catch (IOException e) {
log.info( InputStream close error!
ModelevaluatorBuilder modelevaluatorBuilder new ModelevaluatorBuilder(pmml);
ModelevaluatorFactory modelevaluatorFactory ModelevaluatorFactory.newInstance();
modelevaluatorBuilder.setModelevaluatorFactory(modelevaluatorFactory);
modelevaluator modelevaluatorBuilder.build();
modelevaluator.verify();
log.info( 加载模型成功!
} catch (Exception e) {
e.printStackTrace();
这样4.4版本的pmml训练模型文件也是可以执行获取结果
最后
虚心学习 共同进步 -_-