栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

springboot+jpmml 部署深度学习模型notes

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

springboot+jpmml 部署深度学习模型notes

pmml预测模型教程

1 使用 sklearn2pmml 将 Python(scikit-learn)模型保存为 PMML 文件
"""
文件说明:鸢尾花数据集
"""
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
from sklearn2pmml import sklearn2pmml, PMMLPipeline
from sklearn2pmml.decoration import ContinuousDomain
from sklearn.feature_selection import SelectKBest

# frameworks for ML
# sklearn_pandas exports the mapper as ``DataFrameMapper`` (capital F).
from sklearn_pandas import DataFrameMapper
from sklearn.pipeline import make_pipeline

# transformers for category variables
from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
# from sklearn.preprocessing import Imputer  # removed from scikit-learn; SimpleImputer replaces it
from sklearn.impute import SimpleImputer

# transformers for numerical variables
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import Normalizer

# transformers for combined variables
from sklearn.decomposition import PCA
from sklearn.preprocessing import PolynomialFeatures

# user-defined transformers
from sklearn.preprocessing import FunctionTransformer

def read_data():
    """Load the Iris dataset as a named-column feature DataFrame plus labels.

    Returns:
        tuple: ``(df_x, y)``. ``df_x`` is a ``pandas.DataFrame`` holding the
        four measurements in columns ``"Sepal.Length"``, ``"Sepal.Width"``,
        ``"Petal.Length"`` and ``"Petal.Width"``; ``y`` is the class-label
        array returned by ``load_iris``.

    The column names matter downstream: the feature mapper selects columns by
    these names, and the Java client sends its input values keyed by the same
    strings.
    """
    iris = load_iris()
    # Note the capital F: pandas exposes ``DataFrame``; ``pd.Dataframe`` does
    # not exist and raises AttributeError.
    df_x = pd.DataFrame(
        iris.data,
        columns=["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"],
    )
    return df_x, iris.target


def all_classifiers_test(savemodel='GBDT.pmml'):
    """Train a GBDT pipeline on the Iris data and export it as a PMML file.

    Pipeline stages:
      1. per-column feature engineering with ``DataFrameMapper``;
      2. ``PCA`` down to 3 components;
      3. ``SelectKBest`` keeping the 2 best of those components;
      4. ``GradientBoostingClassifier``.

    Args:
        savemodel: path of the PMML file to write.

    Returns:
        The fitted ``PMMLPipeline``, so callers can inspect or reuse it
        without retraining (callers that ignore the return value are
        unaffected).
    """
    gbdt = GradientBoostingClassifier()
    df_x, y = read_data()

    # Feature engineering. Each entry is (input columns, transformer(s)).
    mapper = DataFrameMapper([
        (["Sepal.Length"], FunctionTransformer(np.abs)),
        # Impute first, then scale: the conventional order. For mean
        # imputation the result matches scaling first, because min-max
        # scaling is affine and the mean lies within [min, max].
        (["Sepal.Width"], [SimpleImputer(), MinMaxScaler()]),
        # Passed through unchanged.
        (["Petal.Length"], None),
        # NOTE(review): one-hot encoding a continuous measurement turns every
        # distinct training value into its own category; a value never seen
        # in training will likely be rejected at prediction time. Kept for
        # parity with the original pipeline -- consider binning instead.
        (["Petal.Width"], OneHotEncoder()),
        (["Petal.Length", "Petal.Width"], [MinMaxScaler(), StandardScaler()]),
    ])

    iris_pipeline = PMMLPipeline([
        ("mapper", mapper),
        ("pca", PCA(n_components=3)),
        ("selector", SelectKBest(k=2)),  # keep the k best features
        ("classifier", gbdt),
    ])
    iris_pipeline.fit(df_x, y)

    # Export the fitted pipeline as a PMML document.
    sklearn2pmml(iris_pipeline, savemodel, with_repr=True)
    return iris_pipeline

if __name__ == "__main__":
    # Train and export only when run as a script, not when the module is imported.
    all_classifiers_test('GBDT.pmml')

踩坑:

解决:
放弃工厂方式(如 `ModelEvaluatorFactory`),改用构建器 `LoadingModelEvaluatorBuilder` 加载 PMML 模型:

    /**
     * Loads a PMML file and builds a JPMML evaluator for it.
     *
     * <p>Uses {@code LoadingModelEvaluatorBuilder} instead of factory-based
     * loading. The type names are JPMML's {@code ModelEvaluator} and
     * {@code LoadingModelEvaluatorBuilder}; lower-case "evaluator" variants do
     * not exist and will not compile.
     *
     * @param pmmlFileName path to the .pmml file (a relative path is resolved
     *                     against the JVM's working directory)
     * @return evaluator for the model contained in the file
     * @throws Exception if the file cannot be read or parsed
     */
    private ModelEvaluator<?> ca_load(String pmmlFileName) throws Exception {
        File pmmlFile = new File(pmmlFileName);
        return new LoadingModelEvaluatorBuilder()
                .load(pmmlFile)
                .build();
    }
2 使用 JPMML 加载模型并运行预测
package com.example.spbexm;

import java.io.File;
import java.util.*;

import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.*;


public class Predictor {
    // Evaluator that predictions are run against. Nothing in this listing assigns
    // it; callers presumably set it from ca_load(...) before calling predict
    // (TODO confirm against the driver code).
    public Evaluator modelevaluator;

    /**
     * Loads a PMML file and builds a JPMML evaluator for it.
     *
     * <p>Uses {@code LoadingModelEvaluatorBuilder} rather than factory-based loading.
     *
     * @param pmmlFileName path to the .pmml file (a relative path is resolved
     *                     against the JVM's working directory)
     * @return evaluator for the model contained in the file
     * @throws Exception if the file cannot be read or parsed
     */
    public Evaluator ca_load(String pmmlFileName) throws Exception {
        File pmmlFile = new File(pmmlFileName);
        return new LoadingModelEvaluatorBuilder()
                .load(pmmlFile)
                .build();
    }

    
    /**
     * Prepares raw user input for evaluation.
     *
     * <p>For every input field the model declares, looks up the raw value under
     * the field's name and converts it with {@link InputField#prepare(Object)}.
     *
     * @param evaluator loaded model evaluator
     * @param input     raw values keyed by input field name (e.g. "Sepal.Length");
     *                  a key absent from the map is passed to {@code prepare} as
     *                  {@code null}
     * @return prepared values keyed by field name, in the model's input-field order
     */
    private Map<FieldName, FieldValue> getFieldMap(Evaluator evaluator, Map<String, ?> input){
        // A typed list is required: with a raw List the for-each over InputField
        // does not compile.
        List<? extends InputField> inputFields = evaluator.getInputFields();
        Map<FieldName, FieldValue> prepared = new LinkedHashMap<>();
        for (InputField field : inputFields) {
            FieldName fieldName = field.getName();
            Object rawValue = input.get(fieldName.getValue());
            prepared.put(fieldName, field.prepare(rawValue));
        }
        return prepared;
    }


    /**
     * Packs one Iris sample into a raw input map keyed by model input field name.
     *
     * <p>A {@link LinkedHashMap} keeps the entries in insertion order, so when
     * the map is printed the features appear in a stable, readable order
     * (a HashMap gives no ordering guarantee).
     *
     * @param a value for "Sepal.Length"
     * @param b value for "Sepal.Width"
     * @param c value for "Petal.Length"
     * @param d value for "Petal.Width"
     * @return map from field name to raw value
     */
    private Map<String, Object> getRawMap(Object a, Object b, Object c, Object d) {
        Map<String, Object> data = new LinkedHashMap<>();
        data.put("Sepal.Length", a);
        data.put("Sepal.Width", b);
        data.put("Petal.Length", c);
        data.put("Petal.Width", d);
        return data;
    }

    /**
     * Runs the model on one raw input row.
     *
     * <p>Prepares the raw values with {@code getFieldMap}, then delegates to
     * {@code evaluate}.
     *
     * @param evaluator loaded model evaluator
     * @param data      raw feature values keyed by input field name
     * @return the output map produced by {@code evaluate}
     */
    private Map<String, Object> predict(Evaluator evaluator, Map<String, ?> data){
        Map<FieldName, FieldValue> arguments = getFieldMap(evaluator, data);
        return evaluate(evaluator, arguments);
    }

    // NOTE(review): this region was damaged when the article was scraped and does
    // not compile as-is. Apparently the text from a "<" in the for-loop condition
    // below through a later ">" was stripped as if it were an HTML tag. That removed
    // the rest of evaluate() (its loop body and return statement) and the header of
    // the following method, fusing the two into one.
    // The surviving tail builds sample inputs and prints one prediction per input.
    // SpringWebController calls model.getPrediction(model), which appears nowhere
    // else in this listing, so the lost header may be that method -- TODO confirm.
    // Recover the missing lines from the original post rather than reconstructing
    // them by guess.
    // Also note: type names such as `evaluator` and `linkedHashMap` appear
    // lower-cased by the scrape (JPMML's interface is Evaluator; java.util's class
    // is LinkedHashMap).
    private Map evaluate(evaluator evaluator, Map input){
        // Raw results returned by the evaluator for the prepared input.
        Map results = evaluator.evaluate(input);
        List targetFields = evaluator.getTargetFields();
        Map output = new linkedHashMap();
        // Truncated line: the loop condition/body of evaluate() and the next
        // method's signature (including the declaration of `model`) are missing here.
        for(int i=0;i> inputs = new ArrayList<>();
        // Sample rows in (Sepal.Length, Sepal.Width, Petal.Length, Petal.Width) order.
        inputs.add(model.getRawMap(5.1, 3.5, 1.4, 0.2));
        inputs.add(model.getRawMap(4.9, 3, 1.4, 0.2));
        inputs.add(model.getRawMap(5.8,3.1,4.8,1.8));
        for (int i = 0; i < inputs.size(); i++) {
            Map output = model.predict(model.modelevaluator, inputs.get(i));
            // Prints the raw input map and the entry stored under key "y".
            System.out.println("X=" + inputs.get(i) + " -> y=" + output.get("y"));
        }
    }

踩坑参考了以下博文:

  • 「slibra_L」的原创文章链接:https://blog.csdn.net/slibra_L/article/details/90401020
  • 「XINFINFZ」的原创文章链接:https://blog.csdn.net/weixin_43945848/article/details/119675711
3 Spring Boot 集成

SpringWebMain.java

package com.example.spbexm;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;

/**
 * Entry point of the demo web service.
 *
 * <p>Creates a {@link SpringApplication} with this class as its primary source
 * and starts it with the command-line arguments.
 */
@SpringBootApplication
public class SpringWebMain {

    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(SpringWebMain.class);
        application.run(args);
    }
}

SpringWebController.java

package com.example.spbexm;

import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.ControllerAdvice;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import com.example.spbexm.Predictor;

/**
 * HTTP endpoints of the demo service: two smoke-test routes and a model
 * prediction route. Every handler returns its String directly as the response
 * body ({@code @ResponseBody}).
 */
@Controller
public class SpringWebController {

    /** Smoke-test endpoint; always responds with the text "controller". */
    @ResponseBody
    @GetMapping("/speak")
    public String speak(){
        return "controller";
    }

    /**
     * Greeting endpoint.
     *
     * @param name value of the {@code myName} query parameter, "World" when absent
     * @return {@code "Hello " + name + "!"}
     */
    @ResponseBody
    @GetMapping("/hello")
    public String sayHello(@RequestParam(value = "myName", defaultValue = "World") String name) {
        return String.format("Hello %s!", name);
    }

    /**
     * Runs the demo predictions via {@code Predictor.getPrediction} and returns
     * its result string (also echoed to stdout).
     *
     * @return prediction summary produced by {@code getPrediction}
     * @throws Exception propagated from the predictor
     */
    @ResponseBody
    @GetMapping("/predict")
    public String predictor() throws Exception {
        // A fresh Predictor is built on every request.
        // NOTE(review): an unused local holding the model path
        // "src/main/resources/GBDT.pmml" was removed. Nothing here loads the model,
        // so getPrediction (not shown in this listing) presumably does -- TODO
        // confirm. A working-directory-relative src/main/resources path also won't
        // exist inside a packaged jar; load the model from the classpath, and load
        // it once rather than per request.
        Predictor model = new Predictor();

        String res = model.getPrediction(model);
        System.out.println(res);
        return res;
    }
}


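pom.xml

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.5.5</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>com.example</groupId>
    <artifactId>spb-exm</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>spb-exm</name>
    <description>Demo project for Spring Boot</description>
    <properties>
        <java.version>1.8</java.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.4.1</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.spring</groupId>
            <artifactId>djl-spring-boot-starter-pytorch-auto</artifactId>
            <version>0.2</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.tensorflow</groupId>
            <artifactId>tensorflow-api</artifactId>
            <version>0.9.0</version>
            <scope>compile</scope>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator-moxy</artifactId>
            <version>1.5.16</version>
        </dependency>
        <dependency>
            <groupId>io.dropwizard.metrics</groupId>
            <artifactId>metrics-core</artifactId>
            <version>4.0.5</version>
        </dependency>
        <dependency>
            <groupId>javax.xml.bind</groupId>
            <artifactId>jaxb-api</artifactId>
            <version>2.3.0</version>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator-metro</artifactId>
            <version>1.5.16</version>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-model</artifactId>
            <version>1.5.16</version>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.5.16</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-deploy-plugin</artifactId>
                <configuration>
                    <skip>true</skip>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>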



转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/322436.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号