没有阅读spark分组排序二的同学需要先阅读下才能理解本篇文章的优化逻辑。上demo源码
二、源码import org.apache.spark.{Partitioner, SparkConf, SparkContext}
object Demo2 {
def main(args: Array[String]): Unit = {
val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("demo2"))
val rdd1 = sc.textFile("ttxs-spark/data/teacher")
// 2. 转换数据格式,字符串变元组: cate,sku => ((cate,sku), 1)
val rdd2 = rdd1.map(line=>{
((line.split(",")(0), line.split(",")(1)), 1)
})
// 统计有多少个品类
val arr = rdd2.keys.map(_._1).distinct().collect()
// 创建自定义分区器
val myPartitioner = new MyPartitioner(arr)
// 3. 统计每个sku的销量,并且将相同品类的数据shuffle到同一个分区
val rdd3 = rdd2.reduceByKey(myPartitioner, _ + _)
// 7. 计算每个分区内的top1:即每个品类小销量最多的sku
val rdd5 = rdd3.mapPartitions(iter => {
// 新建一个容量为1的数组,存储销量最好的sku信息
val arr = new Array[((String, String), Int)](1)
// 迭代分区内所有的数据,找出销量最大的sku
// 这个地方就是较初版优化的地方,使用迭代器就不会将所有数据一次性加载到内存
iter.foreach(x => {
if(arr(0) == null) {
arr(0) = x
}else {
if(arr(0)._2 < x._2) {
arr(0) = x
}
}
})
arr.toIterator
})
rdd5.foreach(println)
Thread.sleep(2000000000)
sc.stop()
}
}
分区器
class MyPartitioner(cates: Array[String]) extends Partitioner {
// 分区数即为品类数
override def numPartitions: Int = cates.length
// 每条数据所映射到的分区为类目id的索引值
override def getPartition(key: Any): Int = {
val cate = key.asInstanceOf[(String, String)]
cates.indexOf(cate._1)
}
}
三、总结
方案三在spark分组排序二的基础上主要是把reducebykey和partitionby进行合并,减少了一次shuffle,DAG如下:
具体原理:调用reducebykey,在传入计算逻辑的同时传入自定义分区器,自定义分区器逻辑:按品类进行分区。这样reducebykey在计算的同时完成了数据的重分区,减少了一次shuffle



