import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
// UDAF demo: "many rows in, one value out".
/**
 * Entry point that registers [[MyAgeAvgFunction]] under the SQL name `myAvg` and uses it to
 * compute the average age per gender over a small in-memory `students` table.
 */
object UDAF {
  def main(args: Array[String]): Unit = {
    val sparkSession: SparkSession =
      SparkSession.builder().appName("UDAF").master("local[*]").getOrCreate()
    try {
      val students: Seq[Student] = Seq(
        Student(1, "zhangsan", "F", 22),
        Student(2, "lisi", "M", 38),
        Student(3, "wangwu", "M", 13),
        Student(4, "zhaoliu", "F", 17),
        Student(5, "songba", "M", 32),
        Student(6, "sunjiu", "M", 16),
        Student(7, "qianshiyi", "F", 17),
        Student(8, "yinshier", "F", 15),
        Student(9, "fangshisan", "M", 12),
        Student(10, "yeshisan", "F", 11),
        Student(11, "ruishiyi", "F", 26),
        Student(12, "chenshier", "M", 28)
      )
      // Seq -> DataFrame: either `students.toDF()` (requires `import sparkSession.implicits._`)
      // or `sparkSession.createDataFrame(students)`, which is used here.
      val frame: DataFrame = sparkSession.createDataFrame(students)
      frame.printSchema()

      // Register the UDAF so SQL queries can call it by name.
      sparkSession.udf.register("myAvg", new MyAgeAvgFunction)
      frame.createOrReplaceTempView("students")

      val resultDF: DataFrame =
        sparkSession.sql("select gender,myAvg(age) avgage from students group by gender")
      resultDF.printSchema()
      resultDF.show(false)
    } finally {
      // Release the local Spark context even if the job above throws.
      sparkSession.stop()
    }
  }
}
/**
 * Custom aggregate function (UDAF) computing the arithmetic mean of a Long `age` column,
 * implemented by extending [[UserDefinedAggregateFunction]].
 *
 * Like SQL `AVG`:
 * - null inputs are ignored;
 * - a group with no non-null input yields `null` (rather than `NaN`).
 *
 * Note: `UserDefinedAggregateFunction` is deprecated since Spark 3.0 in favour of
 * `Aggregator` + `functions.udaf`. It is kept here to preserve the existing registration call.
 */
class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  /** Schema of the function's INPUT: a single `age` column of type Long. */
  override def inputSchema: StructType =
    new StructType().add("age", LongType)

  /**
   * Schema of the intermediate aggregation buffer.
   * - `sum`: running total of all non-null ages seen, e.g. 43 + 52 + 61 + 78 = 234.
   * - `count`: number of non-null ages added, e.g. 1 + 1 + 1 + 1 = 4.
   */
  override def bufferSchema: StructType =
    new StructType().add("sum", LongType).add("count", LongType)

  /** Result type: `sum / count` as a Double. */
  override def dataType: DataType = DoubleType

  /** The same input always produces the same output. */
  override def deterministic: Boolean = true

  /** Initial buffer state: sum = 0, count = 0. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // running sum of ages
    buffer(1) = 0L // number of ages added
  }

  /** Fold one input row into the buffer: add its age to `sum` and increment `count`. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Skip null ages (SQL AVG semantics); Row.getLong would throw on a null value.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  /**
   * Merge two partial buffers (e.g. from different partitions) by adding sums and counts.
   * For example, p1(321, 6), p2(128, 2) and p3(219, 3) combine to (668, 11).
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) // total of ages
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) // total number of people
  }

  /** Final result: `sum / count`, or `null` when no non-null input was aggregated. */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    // Avoid 0/0.0 = NaN; SQL AVG returns NULL for an empty (or all-null) group.
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}