栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 前沿技术 > 大数据 > 大数据系统

3.Spark 学习成果转化—机器学习—使用Spark ML的逻辑回归来预测音乐标签 (多元分类问题)

3.Spark 学习成果转化—机器学习—使用Spark ML的逻辑回归来预测音乐标签 (多元分类问题)

本文目录如下:
  • 第3例 使用Spark ML的逻辑回归来预测音乐标签
    • 3.1 数据准备
      • 3.1.1 数据集文件准备
      • 2.1.2 数据集字段解释
    • 2.2 使用 Spark ML 实现代码
      • 2.2.1 引入项目依赖
      • 2.2.2 将 `MNIST` 数据集以 `libsvm` 格式进行加载并解析
      • 2.2.3 准备训练和测试集
      • 2.2.4 运行训练算法来创建模型
      • 2.2.5 在测试上计算原始分数
      • 2.2.6 为模型评估初始化一个多类度量
      • 2.2.7 构造混淆矩阵
      • 2.2.8 总体统计信息
      • 2.2.9 项目完整代码

第3例 使用Spark ML的逻辑回归来预测音乐标签
  • 这是一个 多元分类 问题, 也就是预测出来的结果有多种。
  • 有关 Spark ML 的介绍与知识点请参考: Spark ML学习笔记—Spark MLlib 与 Spark ML。
3.1 数据准备 3.1.1 数据集文件准备
  • (1) 该项目并为使用数据库当做数据源,而是直接将数据文件放在项目目录中, 这是一个结构化的简化数据集。

  • (2) 本项目使用的数据集是著名的 MNIST 数据集,该数据集包含 780 个特征。数据集地址: 百万歌曲数据集。

2.1.2 数据集字段解释
  • 由于字段太多,这里不做具体字段解释。

2.2 使用 Spark ML 实现代码 2.2.1 引入项目依赖

使用的依赖包多数来自于 Spark ML, 而非 Spark MLlib。

import org.apache.spark.SparkConf
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession

2.2.2 将 MNIST 数据集以 libsvm 格式进行加载并解析
val data = MLUtils.loadLibSVMFile(spark.sparkContext, "datas3/mnist.bz2")

2.2.3 准备训练和测试集
val splits = data.randomSplit(Array(0.75, 0.25), 12345L)
val training = splits(0).cache()
val test = splits(1)

2.2.4 运行训练算法来创建模型
val model = new LogisticRegressionWithLBFGS()
      .setNumClasses(10)
      .setIntercept(true)
      .setValidateData(true)
      .run(training)
  • 到这一步, 预测模型便已经创建成功, 后续只需要根据这个模型进行预测即可。

2.2.5 在测试上计算原始分数
val scoreAndLabels = test.map{
  point => {
    val score = model.predict(point.features)
    (score, point.label)
  }
}
  • 到这一步,预测结果也几经的出来了,只需要循环遍历输出一下即可,预测结果如下图所示:
  • 从上图中可以看出: 预测出来的 prediction 与 label 完全一致, 说明预测的准确率是很高的。
  • 至此, 预测工作已经进行结束了, 剩下还有一些 观察训练过程 和 模型评估 的操作。

2.2.6 为模型评估初始化一个多类度量
// 为模型评估初始化一个多类度量 (metrics包含模型的各种度量信息)
val metrics = new MulticlassMetrics(scoreAndLabels)
2.2.7 构造混淆矩阵
println("Confusion matrix: ")
println(metrics.confusionMatrix)

混淆矩阵如下图所示:


2.2.8 总体统计信息
val accuracy = metrics.accuracy
println("Summary Statistics")
println(s"Accuracy = $accuracy")
// Precision by label (准确率)
val labels = metrics.labels
labels.foreach(
  l => println(s"Precision($l) = " + metrics.precision(l))
)
// Recall by label (召回率)
labels.foreach(
  l => println(s"Recall($l) = " + metrics.recall(l))
)
// False positive rate by label (假正类比例)
labels.foreach(
  l => println(s"FPR($l) = " + metrics.falsePositiveRate(l))
)
// F-measure by label (F1分数)
labels.foreach(
  l => println(s"F1-Score($l) = " + metrics.fMeasure(l))
)

// 计算总体的统计信息
println(s"Weighted precision: ${metrics.weightedPrecision}")
println(s"Weighted recall: ${metrics.weightedRecall}")
println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")

上述代码的输出信息如下图所示:


2.2.9 项目完整代码
import org.apache.spark.SparkConf
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession



object SparkML_0105_test5 {
  def main(args: Array[String]): Unit = {
    // TODO 创建 Spark SQL 的运行环境
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkML")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()

    // step 1: 将 MNIST 数据集以 libsvm 格式进行加载并解析
    val data = MLUtils.loadLibSVMFile(spark.sparkContext, "datas3/mnist.bz2")

    // step 2: 准备训练和测试集 (将数据拆分为训练集(75%) 和 测试集(25%))
    val splits = data.randomSplit(Array(0.75, 0.25), 12345L)
    val training = splits(0).cache()
    val test = splits(1)

    // step 3: 运行训练算法来创建模型
    val model = new LogisticRegressionWithLBFGS()
      .setNumClasses(10)
      .setIntercept(true)
      .setValidateData(true)
      .run(training)

    // step 4: 清理默认的阈值
    model.clearThreshold()

    // step 5: 在测试上计算原始分数
    val scoreAndLabels = test.map{
      point => {
        val score = model.predict(point.features)
        (score, point.label)
      }
    }

    // step 6: 为模型评估初始化一个多类度量 (metrics包含模型的各种度量信息)
    val metrics = new MulticlassMetrics(scoreAndLabels)

    // step 7: 构造混淆矩阵
    println("Confusion matrix: ")
    println(metrics.confusionMatrix)

    // step 8: 总体统计信息
    val accuracy = metrics.accuracy
    println("Summary Statistics")
    println(s"Accuracy = $accuracy")
    // Precision by label (准确率)
    val labels = metrics.labels
    labels.foreach(
      l => println(s"Precision($l) = " + metrics.precision(l))
    )
    // Recall by label (召回率)
    labels.foreach(
      l => println(s"Recall($l) = " + metrics.recall(l))
    )
    // False positive rate by label (假正类比例)
    labels.foreach(
      l => println(s"FPR($l) = " + metrics.falsePositiveRate(l))
    )
    // F-measure by label (F1分数)
    labels.foreach(
      l => println(s"F1-Score($l) = " + metrics.fMeasure(l))
    )

    // 计算总体的统计信息
    println(s"Weighted precision: ${metrics.weightedPrecision}")
    println(s"Weighted recall: ${metrics.weightedRecall}")
    println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
    println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")

    spark.close()
  }
}
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/279818.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号