栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 前沿技术 > 大数据 > 大数据系统

scala spark读取pmml模型预测

Scala + Spark:读取 PMML 模型并进行预测

参考项目 jpmml-evaluator-spark:https://github.com/jpmml/jpmml-evaluator-spark

package scala
import main.scala.IrisHelper
import org.apache.hadoop.shaded.org.eclipse.jetty.websocket.common.frames.Dataframe
import org.apache.spark.mllib.linalg.Vector
import org.jpmml.evaluator.spark.TransformerBuilder;
import java.util.stream.Collectors.toList

// https://blog.csdn.net/weixin_31897613/article/details/112224295

import scala.collection.JavaConversions
import java.util.Arrays
//import org.apache.spark
import org.apache.spark.SparkConf
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession
import org.jpmml.evaluator.LoadingModelevaluatorBuilder



object IrisHelper {

  /** One Iris measurement row (four numeric features).
    *
    * The backticked field names are used verbatim as column names when a DataFrame is built from
    * these records; presumably they must equal the PMML model's input field names — confirm
    * against the model file.
    */
  case class InputRecord(
      `Sepal.Length`: Double,
      `Sepal.Width`: Double,
      `Petal.Length`: Double,
      `Petal.Width`: Double
  )
}

object SpkJpmml {

  // NOTE(review): the file also imports `main.scala.IrisHelper`; confirm this import resolves to
  // the `IrisHelper` object defined in this file rather than to that one.
  import IrisHelper._

  /** Classpath resource holding the PMML model when no override is given on the command line. */
  private val DefaultModelResource = "GBDT.pmml"

  /**
   * Loads a PMML model from the classpath, scores a few hard-coded Iris records with it through
   * a jpmml-evaluator-spark transformer, and prints the scored DataFrame.
   *
   * @param args optional; `args(0)` overrides the classpath resource name of the PMML model
   *             (default: `GBDT.pmml`)
   * @throws IllegalArgumentException if the PMML resource cannot be found on the classpath
   */
  def main(args: Array[String]): Unit = {
    val modelResource = args.headOption.getOrElse(DefaultModelResource)

    val sparkSession = SparkSession
      .builder()
      .config(
        new SparkConf()
          .setAppName("GBDT4Iris")
          .setMaster("local")
          // The transform below needs this legacy flag; presumably the transformer is built on
          // Spark's untyped Scala UDF API, which Spark 3.x rejects by default — confirm.
          .set("spark.sql.legacy.allowUntypedScalaUDF", "true")
      )
      .getOrCreate()

    try {
      // Input data: DataFrame column names come from the InputRecord field names.
      val inputData = sparkSession.createDataFrame(Seq(
        InputRecord(5.1, 3.5, 1.4, 0.2),
        InputRecord(5.8, 3.1, 4.8, 1.8),
        InputRecord(4.9, 3.0, 1.4, 0.2)
      ))

      val evaluator = loadEvaluator(modelResource)
      println(evaluator.getTargetFields.toString)
      println(evaluator.getOutputFields.toString)

      val pmmlTransformer = new TransformerBuilder(evaluator)
        .withTargetCols()
        .withOutputCols()
        // Keep all results in one struct column instead of exploding one column per field; the
        // select("pmml") below assumes that column carries the library's default name — confirm.
        .exploded(false)
        .build()

      val scored = pmmlTransformer.transform(inputData)
      scored.show()

      scored.select("pmml").show()
    } finally {
      sparkSession.stop()
    }
  }

  /**
   * Builds a PMML evaluator from a classpath resource and closes the resource stream afterwards.
   *
   * @param resource classpath resource name of the PMML document
   * @return an evaluator for the model in that document
   * @throws IllegalArgumentException if the resource is not on the classpath
   */
  private def loadEvaluator(resource: String): org.jpmml.evaluator.Evaluator = {
    // getResourceAsStream returns null for a missing resource; fail with a clear message instead
    // of handing null to the loader.
    val pmml = Option(getClass.getClassLoader.getResourceAsStream(resource)).getOrElse(
      throw new IllegalArgumentException(s"PMML resource not found on classpath: $resource")
    )
    try {
      // Fully qualified so it does not depend on the (misspelled) top-of-file import.
      new org.jpmml.evaluator.LoadingModelEvaluatorBuilder()
        .load(pmml)
        .build()
    } finally {
      pmml.close()
    }
  }
}


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/336108.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号