步骤:先自定义函数,再将其注册到 SparkSession 中,即可在 SQL 语句中使用
用户自定义函数在 SparkSQL 中可以分为两类:
– udf :用户自定义函数, 通常指的是一对一形式,进入一条记录,出来一条记录
– udaf 用户自定义聚合函数, 通常指的是多对一形式,进入多条记录,出来一条记录,比如模拟max
案例演示
UDF
package com.qf.sql.day03

import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Demonstrates a simple UDF (one row in, one value out): a function that
 * returns the length of an employee's name, registered under "mylength"
 * and used directly inside a SQL statement.
 */
object _05TestUDF1 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udf").getOrCreate()
    import spark.implicits._
    // NOTE(review): assumes sql/emp.json exists relative to the working directory
    val df: DataFrame = spark.read.json("sql/emp.json")
    df.createTempView("emp")
    //spark.sql("select empno,ename,job,sal,deptno from emp ").show
    // Same query using the built-in length() function: employees whose name is longer than 4
    //spark.sql("select empno,ename,job,sal,deptno from emp where length(ename)>4").show

    // Option 1: define an ordinary method...
    def func1(word: String) = {
      word.length
    }
    // ...and register it; `func1 _` eta-expands the method into a function value
    //spark.udf.register("mylength", func1 _)

    // Option 2: register an anonymous function directly
    spark.udf.register("mylength", { word: String => word.length })
    spark.sql("select empno,ename,mylength(ename) as lg,job,sal,deptno from emp where mylength(ename)>4").show
    spark.stop()
  }
}
package com.qf.sql.day03

import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Demonstrates replacing a repeated CASE WHEN expression with a UDF:
 * maps a salary to a level string (sal > 3000 -> "level3",
 * sal > 1500 -> "level2", otherwise "level1").
 */
object _05TestUDF2 {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udf").getOrCreate()
    // NOTE(review): assumes sql/emp.json exists relative to the working directory
    val df: DataFrame = spark.read.json("sql/emp.json")
    df.createTempView("emp")
    // Show each employee's salary level using plain SQL CASE WHEN
    val sql =
      """
        |select empno,ename,job,sal,
        |case when sal>3000 then 'level3'
        |when sal>1500 then 'level2'
        |else 'level1' end
        |as level
        |from emp
        |""".stripMargin
    spark.sql(sql).show

    // To count employees per salary level, the CASE WHEN would have to be
    // repeated in every query — a UDF expresses the mapping once instead.
    def func2(salary: Double) = {
      if (salary > 3000)
        "level3"
      else if (salary > 1500)
        "level2"
      else
        "level1"
    }
    // Register the method as a UDF (`func2 _` converts the method to a function)
    spark.udf.register("myudf", func2 _)
    val sql2 =
      """
        | select count(1),myudf(sal)
        | from emp
        | group by myudf(sal)
        |""".stripMargin
    spark.sql(sql2).show
    spark.stop()
  }
}
自定义UDAF
package com.qf.sql.day03

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

/**
 * Demonstrates a UDAF (many rows in, one value out) that computes the
 * average salary per department, registered under "myavg".
 */
object _06TestUDAF {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udf").getOrCreate()
    // NOTE(review): assumes sql/emp.json exists relative to the working directory
    val df: DataFrame = spark.read.json("sql/emp.json")
    df.createTempView("emp")
    // Query the average salary of each department using the custom aggregate
    spark.udf.register("myavg", new MyUDAF)
    val sql1 =
      """
        |select deptno,myavg(sal)
        |from emp
        |group by deptno
        |""".stripMargin
    spark.sql(sql1).show()
    spark.stop()
  }

  /**
   * Average aggregate implemented as sum/count.
   * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0
   * in favor of org.apache.spark.sql.expressions.Aggregator — kept here as the
   * classic API demonstration.
   */
  class MyUDAF extends UserDefinedAggregateFunction {
    // Schema of the input arguments passed to the function (a single salary column)
    override def inputSchema: StructType = StructType(
      Array(
        StructField("sal", DoubleType)
      )
    )
    // Schema of the intermediate aggregation buffer: running sum and row count
    override def bufferSchema: StructType = StructType(
      Array(
        StructField("sum", DoubleType),
        StructField("count", LongType)
      )
    )
    // Type of the final result
    override def dataType: DataType = DoubleType
    // Deterministic: same input rows always produce the same result
    override def deterministic: Boolean = true
    // Initialize the two buffer slots before any rows are processed
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      // buffer(0) holds the running sum
      buffer(0) = 0D
      // buffer(1) holds the running count
      buffer(1) = 0L
    }
    // Fold one input row into the buffer
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      // Guard against null salaries, which would otherwise throw on getDouble
      if (!input.isNullAt(0)) {
        buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
        // Increment the row count
        buffer.update(1, buffer.getLong(1) + 1)
      }
    }
    // Combine two partial buffers (e.g. from different partitions)
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }
    // Produce the final result: average = sum / count
    override def evaluate(buffer: Row): Any = {
      buffer.getDouble(0) / buffer.getLong(1)
    }
  }
}



