前言
- 最近公司有个需求,需要对用户进行数据画像分析。
- 公司大数据组通过对线上用户数据进行分析后,通过机器学习用python做了一个训练模型pkl文件包。
- 要求我部门对用户数据进行分析计算。而我部门的项目都是使用Java进行开发的,所以就需要Java调用pkl训练模型包。
- 经过调研python的pkl训练模型包不能直接被Java调用,跨平台调用需要使用pmml格式文件,所以就让大数据部门依照已经生成的训练模型pkl文件,在次封装成一个pmml文件。
pmml格式
Java调用pmml文件
org.jpmml
pmml-evaluator
1.4.1
org.jpmml
pmml-evaluator-extension
1.4.1
- Java调用方法
- 当有test.pmml文件后,可以把文件放在springboot项目的resources目录下,使用ClassPathResource类获取到文件流
@Slf4j
public final class ClassificationModelOld {
private static evaluator modelevaluator;
static {
PMML pmml;
try {
Resource resource = new ClassPathResource("test.pmml");
InputStream is = resource.getInputStream();
pmml = PMMLUtil.unmarshal(is);
try {
is.close();
} catch (IOException e) {
log.info("InputStream close error!");
}
ModelevaluatorFactory modelevaluatorFactory = ModelevaluatorFactory.newInstance();
modelevaluator = modelevaluatorFactory.newModelevaluator(pmml);
modelevaluator.verify();
log.info("加载模型成功!");
} catch (Exception e) {
e.printStackTrace();
}
}
private ClassificationModelOld () {
}
public static List getFeatureNames () {
List featureNames = new ArrayList<>();
List inputFields = modelevaluator.getInputFields();
for (InputField inputField : inputFields) {
featureNames.add(inputField.getName().toString());
}
return featureNames;
}
public static String getTargetName () {
return modelevaluator.getTargetFields().get(0).getName().toString();
}
private static ProbabilityDistribution getProbabilityDistribution (Map arguments) {
Map evaluateResult = modelevaluator.evaluate(arguments);
FieldName fieldName = FieldName.create(getTargetName());
return (ProbabilityDistribution) evaluateResult.get(fieldName);
}
public static ValueMap predictProba (Map arguments) {
ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
return probabilityDistribution.getValues();
}
public static Object predict (Map arguments) {
ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
return probabilityDistribution.getPrediction();
}
private static Integer setScore (float probability) {
int score = 0;
try {
// TODO 根据比例写算法计算出分值
score = 520;
} catch (Exception e) {
}
return score;
}
public static void main (String[] args) {
// 参数进过转义后:{{"value":"x1"}:-0.216918810277242,{"value":"x2"}:0.0583184157700168,{"value":"x3"}:-0.653728631926331}
final ArrayList doubles = Lists.newArrayList(-0.216918810277242, 0.0583184157700168, -0.653728631926331);
Map waitPreSample = new HashMap<>(8);
waitPreSample.put(FieldName.create("x1"), doubles.get(0));
waitPreSample.put(FieldName.create("x2"), doubles.get(1));
waitPreSample.put(FieldName.create("x3"), doubles.get(2));
final ValueMap values = ClassificationModelOld.predictProba(waitPreSample);
System.out.println("机器算法计算分值结果:" + setScore(values.get("1").floatValue()));
}
}
---------------------
执行结果:
加载模型成功!
机器算法计算分值结果:520
版本问题
- 上面示例是使用的老版本的包,并且打的pmml文件也是4.3版本的
- 所以如果使用的是4.4版本的pmml文件
org.jpmml
pmml-evaluator
1.5.11
org.jpmml
pmml-evaluator-extension
1.5.11
static {
PMML pmml;
try {
Resource resource = new ClassPathResource("test.pmml");
InputStream is = resource.getInputStream();
pmml = PMMLUtil.unmarshal(is);
try {
is.close();
} catch (IOException e) {
log.info("InputStream close error!");
}
ModelevaluatorBuilder modelevaluatorBuilder = new ModelevaluatorBuilder(pmml);
ModelevaluatorFactory modelevaluatorFactory = ModelevaluatorFactory.newInstance();
modelevaluatorBuilder.setModelevaluatorFactory(modelevaluatorFactory);
modelevaluator = modelevaluatorBuilder.build();
modelevaluator.verify();
log.info("加载模型成功!");
} catch (Exception e) {
e.printStackTrace();
}
}
- 这样4.4版本的pmml训练模型文件也是可以执行获取结果
最后