- 第3例 使用Spark ML的逻辑回归来预测音乐标签
- 3.1 数据准备
- 3.1.1 数据集文件准备
- 2.1.2 数据集字段解释
- 2.2 使用 Spark ML 实现代码
- 2.2.1 引入项目依赖
- 2.2.2 将 `MNIST` 数据集以 `libsvm` 格式进行加载并解析
- 2.2.3 准备训练和测试集
- 2.2.4 运行训练算法来创建模型
- 2.2.5 在测试上计算原始分数
- 2.2.6 为模型评估初始化一个多类度量
- 2.2.7 构造混淆矩阵
- 2.2.8 总体统计信息
- 2.2.9 项目完整代码
3.1 数据准备 3.1.1 数据集文件准备
- 这是一个 多元分类 问题, 也就是预测出来的结果有多种。
- 有关 Spark ML 的介绍与知识点请参考: Spark ML学习笔记—Spark MLlib 与 Spark ML。
-
(1) 该项目并为使用数据库当做数据源,而是直接将数据文件放在项目目录中, 这是一个结构化的简化数据集。
-
(2) 本项目使用的数据集是著名的 MNIST 数据集,该数据集包含 780 个特征。数据集地址: 百万歌曲数据集。
- 由于字段太多,这里不做具体字段解释。
2.2 使用 Spark ML 实现代码 2.2.1 引入项目依赖
使用的依赖包多数来自于 Spark ML, 而非 Spark MLlib。
import org.apache.spark.SparkConf import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.SparkSession
2.2.2 将 MNIST 数据集以 libsvm 格式进行加载并解析
val data = MLUtils.loadLibSVMFile(spark.sparkContext, "datas3/mnist.bz2")
2.2.3 准备训练和测试集
val splits = data.randomSplit(Array(0.75, 0.25), 12345L) val training = splits(0).cache() val test = splits(1)
2.2.4 运行训练算法来创建模型
val model = new LogisticRegressionWithLBFGS()
.setNumClasses(10)
.setIntercept(true)
.setValidateData(true)
.run(training)
- 到这一步, 预测模型便已经创建成功, 后续只需要根据这个模型进行预测即可。
2.2.5 在测试上计算原始分数
val scoreAndLabels = test.map{
point => {
val score = model.predict(point.features)
(score, point.label)
}
}
- 到这一步,预测结果也几经的出来了,只需要循环遍历输出一下即可,预测结果如下图所示:
- 从上图中可以看出: 预测出来的 prediction 与 label 完全一致, 说明预测的准确率是很高的。
- 至此, 预测工作已经进行结束了, 剩下还有一些 观察训练过程 和 模型评估 的操作。
2.2.6 为模型评估初始化一个多类度量
// 为模型评估初始化一个多类度量 (metrics包含模型的各种度量信息) val metrics = new MulticlassMetrics(scoreAndLabels)2.2.7 构造混淆矩阵
println("Confusion matrix: ")
println(metrics.confusionMatrix)
混淆矩阵如下图所示:
2.2.8 总体统计信息
val accuracy = metrics.accuracy
println("Summary Statistics")
println(s"Accuracy = $accuracy")
// Precision by label (准确率)
val labels = metrics.labels
labels.foreach(
l => println(s"Precision($l) = " + metrics.precision(l))
)
// Recall by label (召回率)
labels.foreach(
l => println(s"Recall($l) = " + metrics.recall(l))
)
// False positive rate by label (假正类比例)
labels.foreach(
l => println(s"FPR($l) = " + metrics.falsePositiveRate(l))
)
// F-measure by label (F1分数)
labels.foreach(
l => println(s"F1-Score($l) = " + metrics.fMeasure(l))
)
// 计算总体的统计信息
println(s"Weighted precision: ${metrics.weightedPrecision}")
println(s"Weighted recall: ${metrics.weightedRecall}")
println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")
上述代码的输出信息如下图所示:
2.2.9 项目完整代码
import org.apache.spark.SparkConf
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession
object SparkML_0105_test5 {
def main(args: Array[String]): Unit = {
// TODO 创建 Spark SQL 的运行环境
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkML")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
// step 1: 将 MNIST 数据集以 libsvm 格式进行加载并解析
val data = MLUtils.loadLibSVMFile(spark.sparkContext, "datas3/mnist.bz2")
// step 2: 准备训练和测试集 (将数据拆分为训练集(75%) 和 测试集(25%))
val splits = data.randomSplit(Array(0.75, 0.25), 12345L)
val training = splits(0).cache()
val test = splits(1)
// step 3: 运行训练算法来创建模型
val model = new LogisticRegressionWithLBFGS()
.setNumClasses(10)
.setIntercept(true)
.setValidateData(true)
.run(training)
// step 4: 清理默认的阈值
model.clearThreshold()
// step 5: 在测试上计算原始分数
val scoreAndLabels = test.map{
point => {
val score = model.predict(point.features)
(score, point.label)
}
}
// step 6: 为模型评估初始化一个多类度量 (metrics包含模型的各种度量信息)
val metrics = new MulticlassMetrics(scoreAndLabels)
// step 7: 构造混淆矩阵
println("Confusion matrix: ")
println(metrics.confusionMatrix)
// step 8: 总体统计信息
val accuracy = metrics.accuracy
println("Summary Statistics")
println(s"Accuracy = $accuracy")
// Precision by label (准确率)
val labels = metrics.labels
labels.foreach(
l => println(s"Precision($l) = " + metrics.precision(l))
)
// Recall by label (召回率)
labels.foreach(
l => println(s"Recall($l) = " + metrics.recall(l))
)
// False positive rate by label (假正类比例)
labels.foreach(
l => println(s"FPR($l) = " + metrics.falsePositiveRate(l))
)
// F-measure by label (F1分数)
labels.foreach(
l => println(s"F1-Score($l) = " + metrics.fMeasure(l))
)
// 计算总体的统计信息
println(s"Weighted precision: ${metrics.weightedPrecision}")
println(s"Weighted recall: ${metrics.weightedRecall}")
println(s"Weighted F1 score: ${metrics.weightedFMeasure}")
println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}")
spark.close()
}
}



