栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 前沿技术 > 大数据 > 大数据系统

Spark

Spark

import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataframe, Row, SparkSession}



//  多进一出
object UDAF {
  def main(args: Array[String]): Unit = {
    val sparkSession: SparkSession = SparkSession.builder().appName("UDAF").master("local[*]").getOrCreate()
    val sc: SparkContext = sparkSession.sparkContext

    import sparkSession.implicits._

    val students: Seq[Student] = Seq(
      Student(1, "zhangsan", "F", 22),
      Student(2, "lisi", "M", 38),
      Student(3, "wangwu", "M", 13),
      Student(4, "zhaoliu", "F", 17),
      Student(5, "songba", "M", 32),
      Student(6, "sunjiu", "M", 16),
      Student(7, "qianshiyi", "F", 17),
      Student(8, "yinshier", "F", 15),
      Student(9, "fangshisan", "M", 12),
      Student(10, "yeshisan", "F", 11),
      Student(11, "ruishiyi", "F", 26),
      Student(12, "chenshier", "M", 28)
    )

//    seq to df :   1. roDF  2.spark.createDataframe

    val frame: Dataframe = sparkSession.createDataframe(students)
    frame.printSchema()

    import org.apache.spark.sql.functions._

    sparkSession.udf.register("myAvg",new MyAgeAvgFunction)
    frame.createOrReplaceTempView("students")

    val resultDF: Dataframe = sparkSession.sql("select gender,myAvg(age) avgage from students group by gender")
    resultDF.printSchema()
    resultDF.show(false)
  }
}

//自定义聚合函数 UDAF 继承UserDefinedAggregateFunction
class MyAgeAvgFunction extends UserDefinedAggregateFunction{

  //聚合函数的输出数据的数据结构
  override def inputSchema: StructType = {
    new StructType().add("age",LongType)
  }

  //在缓冲区内的数据结构
  //sum 用来记录 所有年龄值相加的总和 43 + 52 + 61 + 78 = 234 => sum
  //count 用来记录 相加各个的总和 1 + 1 + 1 + 1 = 4 => count
  override def bufferSchema: StructType = {
    new StructType().add("sum",LongType).add("count",LongType)
  }

  //定义当前函数返回值的类型 sum/count 得到 Double类型
  override def dataType: DataType = DoubleType

  //聚合函数幂等
  override def deterministic: Boolean = true

  //初始值
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)=0L //记录 传入所有用户年龄相加的总和
    buffer(1)=0L //记录 传入所有用户年龄的个数
  }

  // 传入一条新数据后需要进入处理
  // 将Row() 对象中的值与buffer(0) 数据相加
  // buffer(1)数据个数加一
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  //合并 各分区内的数据
  //例如 p1(321,6) p2(128,2) p3(219,3)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //计算年龄相加的总和
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    //总人数
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  //计算最终结果
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0)/buffer.getLong(1).toDouble
  }
}
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/674323.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号