栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 前沿技术 > 大数据 > 大数据系统

2021-10-27

2021-10-27

Spark 自定义函数UDF UDAF

步骤:自定义函数,再注册
用户自定义函数在sparksql中可以分为两类,
– udf :用户自定义函数, 通常指的是一对一形式,进入一条记录,出来一条记录
– udaf 用户自定义聚合函数, 通常指的是多对一形式,进入多条记录,出来一条记录,比如模拟max
案例演示
UDF

package com.qf.sql.day03

import org.apache.spark.sql.{Dataframe, SparkSession}

object _05TestUDF1 {
    def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("udf").getOrCreate()
        import spark.implicits._

        val df: Dataframe = spark.read.json("sql/emp.json")
        df.createTempView("emp")

        //spark.sql("select empno,ename,job,sal,deptno from emp ").show
        //使用内置函数,查询姓名长度大于4的员工信息
        //spark.sql("select empno,ename,job,sal,deptno from emp where length(ename)>4").show

        //定义一个方法
        def func1(word:String)={
            word.length
        }
        //注册函数,   func1 _  将方法转成函数
        //spark.udf.register("mylength", func1 _)
        //匿名函数的写法
        spark.udf.register("mylength",{word:String=>word.length})

        spark.sql("select empno,ename,mylength(ename) as lg,job,sal,deptno from emp where mylength(ename)>4").show

        spark.stop()
    }
}

package com.qf.sql.day03

import org.apache.spark.sql.{Dataframe, SparkSession}

object _05TestUDF2 {
    def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("udf").getOrCreate()

        val df: Dataframe = spark.read.json("sql/emp.json")
        df.createTempView("emp")

        // 显示每个员工的工资等级,sal>3000 显示level3      sal>1500 显示level2   其他的显示level1
        val sql =
            """
              |select empno,ename,job,sal,
              |case when sal>3000 then 'level3'
              |when sal>1500 then 'level2'
              |else 'level1' end
              |as level
              |from emp
              |""".stripMargin
        spark.sql(sql).show

        //统计每个工资等级的人数, 编程如下
       

        //上述的case when 需要写多次,比较浪费时间,不如自定义一个用户函数
        def func2(salary:Double)={
            if(salary>3000)
                "level3"
            else if(salary>1500)
                "level2"
            else
                "level1"

        }
        //注册
        spark.udf.register("myudf",func2 _)
        val sql2 =
            """
              | select count(1),myudf(sal)
              | from emp
              | group by myudf(sal)
              |""".stripMargin
        spark.sql(sql2).show
        spark.stop()
    }
}

自定义UDAF

package com.qf.sql.day03

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{Dataframe, Row, SparkSession}

object _06TestUDAF {
    def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("udf").getOrCreate()

        val df: Dataframe = spark.read.json("sql/emp.json")
        df.createTempView("emp")

        //查詢每個部門的平均工資
        

        
        //注册函数
        spark.udf.register("myavg",new MyUDAF)
        val sql1 =
            """
              |select deptno,myavg(sal)
              |from emp
              |group by deptno
              |""".stripMargin
        spark.sql(sql1).show()

        spark.stop()
    }

    
    class MyUDAF extends UserDefinedAggregateFunction{

        //用来描述进入函数的参数的类型
        override def inputSchema: StructType = StructType(
            Array(
                StructField("sal",DoubleType)
            )
        )
        //用来描述计算过程中涉及到的变量的类型
        override def bufferSchema: StructType = StructType(
            Array(
                StructField("sum",DoubleType),
                StructField("count",LongType)
            )
        )
        //用来描述计算结果的类型
        override def dataType: DataType = DoubleType

        //用来表示函数的稳定性
        override def deterministic: Boolean = true

        //用来对计算过程中涉及到的两个变量进行初始化
        override def initialize(buffer: MutableAggregationBuffer): Unit = {
            //buffer的第一个元素表示sum
            buffer(0) = 0D
            //buffer的第一个元素表示count
            buffer(1) = 0L
        }
        
        override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
            //update用于更新数据
            buffer.update(0,buffer.getDouble(0)+input.getDouble(0))
            //加+1
            buffer.update(1,buffer.getLong(1)+1)
        }

        
        override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
            buffer1.update(0,buffer1.getDouble(0)+buffer2.getDouble(0))
            buffer1.update(1,buffer1.getLong(1)+buffer2.getLong(1))
        }
        //用于计算结果
        override def evaluate(buffer: Row): Any = {
            buffer.getDouble(0)/buffer.getLong(1)
        }
    }

}
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/349932.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号