I. MLlib Basic Concepts
II. A Pipeline Example
1. Prepare the data  2. Define the model  3. Train the model  4. Use the model  5. Evaluate the model  6. Save the model
MLlib is Spark's machine learning library. It provides the following main functionality:
Utilities: linear algebra, statistics, data processing, etc.
Feature engineering: feature extraction, feature transformation, feature selection
Common algorithms: classification, regression, clustering, collaborative filtering, dimensionality reduction
Model optimization: model evaluation, parameter tuning
The MLlib library consists of two distinct packages:
pyspark.mllib contains the RDD-based machine learning API. It is no longer updated and will eventually be removed, so its use is not recommended.
pyspark.ml contains the DataFrame-based machine learning API, which can be used to build machine-learning Pipelines. This is the recommended package.
I. MLlib Basic Concepts
DataFrame: the storage format for data in MLlib. Its columns can hold feature vectors and labels, as well as raw text and images.
Transformer:
A transformer. It has a transform method, which converts one DataFrame into another by appending one or more columns.
Estimator: an estimator. It has a fit method. It takes a DataFrame as input, trains on it, and produces a Transformer.
Pipeline: a pipeline. It has a setStages method, which chains several Transformers and one Estimator in sequence to produce a pipeline model.
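The relationship between these three abstractions can be sketched in plain Python. This is only a toy analogy of the fit/transform pattern, not the real pyspark.ml classes; the class names here are illustrative:

```python
# Toy analogy of the pyspark.ml abstractions (not the real classes).

class ToyTransformer:
    """Has a transform method: rows in, rows out with an extra column."""
    def __init__(self, mean):
        self.mean = mean

    def transform(self, rows):
        # Append a "centered" column to each row.
        return [dict(r, centered=r["x"] - self.mean) for r in rows]

class ToyEstimator:
    """Has a fit method: training data in, a Transformer out."""
    def fit(self, rows):
        mean = sum(r["x"] for r in rows) / len(rows)
        return ToyTransformer(mean)

rows = [{"x": 1.0}, {"x": 3.0}]
model = ToyEstimator().fit(rows)  # fit produces a Transformer ("model")
print(model.transform(rows))      # transform appends a new column
```

A Pipeline simply applies such stages in order, calling fit on the Estimator stage with the data produced by the preceding Transformer stages.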
II. A Pipeline Example
Task description: use a logistic regression model to predict whether a sentence contains the word "spark".
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.linalg import Vector
from pyspark.sql import Row

1. Prepare the data
dftrain = spark.createDataFrame([(0,"a b c d e spark",1.0),
(1,"a c f",0.0),
(2,"spark hello world",1.0),
(3,"hadoop mapreduce",0.0),
(4,"I love spark", 1.0),
(5,"big data",0.0)],["id","text","label"])
dftrain.show()
+---+-----------------+-----+
| id|             text|label|
+---+-----------------+-----+
|  0|  a b c d e spark|  1.0|
|  1|            a c f|  0.0|
|  2|spark hello world|  1.0|
|  3| hadoop mapreduce|  0.0|
|  4|     I love spark|  1.0|
|  5|         big data|  0.0|
+---+-----------------+-----+

2. Define the model
tokenizer = Tokenizer().setInputCol("text").setOutputCol("words")
print(type(tokenizer))
hashingTF = HashingTF().setNumFeatures(100) \
    .setInputCol(tokenizer.getOutputCol()) \
    .setOutputCol("features")
print(type(hashingTF))
lr = LogisticRegression().setLabelCol("label")
#print(lr.explainParams())
lr.setFeaturesCol("features").setMaxIter(10).setRegParam(0.2)
print(type(lr))
pipe = Pipeline().setStages([tokenizer,hashingTF,lr])
print(type(pipe))
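HashingTF maps each token to a column index by hashing it modulo numFeatures and counting occurrences, so no vocabulary needs to be built. A minimal pure-Python sketch of this idea (Spark's HashingTF uses MurmurHash3; md5 is used here only so the example is deterministic, and the resulting indices will differ from Spark's):

```python
import hashlib

def hashing_tf(words, num_features=100):
    """Map each word to a bucket index via a hash and count occurrences.

    Illustration of the hashing trick only; Spark uses MurmurHash3.
    """
    counts = {}
    for w in words:
        idx = int(hashlib.md5(w.encode()).hexdigest(), 16) % num_features
        counts[idx] = counts.get(idx, 0) + 1
    return counts  # sparse representation: {column index: count}

vec = hashing_tf("a b c d e spark".split())
print(vec)  # six (index, count) entries, barring hash collisions
```

This is why the features column in the output below looks like (100,[57,86],[1.0...]): a sparse vector of size 100 with nonzero counts at the hashed indices.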
3. Train the model
model = pipe.fit(dftrain)
print(type(model))
4. Use the model
dftest = spark.createDataFrame([(7,"spark job",1.0),(9,"hello world",0.0),
(10,"a b c d e",0.0),(11,"you can you up",0.0),
(12,"spark is easy to use.",1.0)]).toDF("id","text","label")
dftest.show()
dfresult = model.transform(dftest)
dfresult.selectExpr("text","features","probability","prediction").show()
+---+--------------------+-----+
| id|                text|label|
+---+--------------------+-----+
|  7|           spark job|  1.0|
|  9|         hello world|  0.0|
| 10|           a b c d e|  0.0|
| 11|      you can you up|  0.0|
| 12|spark is easy to ...|  1.0|
+---+--------------------+-----+

+--------------------+--------------------+--------------------+----------+
|                text|            features|         probability|prediction|
+--------------------+--------------------+--------------------+----------+
|           spark job|(100,[57,86],[1.0...|[0.30134853865356...|       1.0|
|         hello world|(100,[60,89],[1.0...|[0.20714372651040...|       1.0|
|           a b c d e|(100,[50,65,67,68...|[0.24502686265469...|       1.0|
|      you can you up|(100,[33,38,51],[...|[0.87589306761045...|       0.0|
|spark is easy to ...|(100,[9,21,60,86,...|[0.07662944406376...|       1.0|
+--------------------+--------------------+--------------------+----------+

5. Evaluate the model
dfresult.printSchema()
root
 |-- id: long (nullable = true)
 |-- text: string (nullable = true)
 |-- label: double (nullable = true)
 |-- words: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)
evaluator = MulticlassClassificationEvaluator().setMetricName("f1") \
    .setPredictionCol("prediction").setLabelCol("label")
#print(evaluator.explainParams())
accuracy = evaluator.evaluate(dfresult)
print("\n accuracy = {}".format(accuracy))
accuracy = 0.5666666666666667
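The value above can be checked by hand: with setMetricName("f1"), the evaluator computes an f1 score per class and weights each by that class's support. Reading the labels [1,0,0,0,1] and predictions [1,1,1,0,1] off the output table above:

```python
# Recompute the weighted f1 from the predictions shown above.
labels = [1.0, 0.0, 0.0, 0.0, 1.0]
preds  = [1.0, 1.0, 1.0, 0.0, 1.0]

def f1_for(cls):
    tp = sum(1 for l, p in zip(labels, preds) if l == p == cls)
    fp = sum(1 for l, p in zip(labels, preds) if p == cls and l != cls)
    fn = sum(1 for l, p in zip(labels, preds) if l == cls and p != cls)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

# Weight each class's f1 by its support (number of true instances).
weighted_f1 = sum(f1_for(c) * labels.count(c) for c in (0.0, 1.0)) / len(labels)
print(weighted_f1)  # 0.5666666666666667, matching the evaluator
```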
6. Save the model
# The trained model can be saved to disk
model.write().overwrite().save("./data/mymodel.model")
# An unfitted pipeline can also be saved to disk
#pipe.write().overwrite().save("./data/unfit-lr-model")
# Reload the model
model_loaded = PipelineModel.load("./data/mymodel.model")
model_loaded.transform(dftest).select("text","label","prediction").show()
+--------------------+-----+----------+
|                text|label|prediction|
+--------------------+-----+----------+
|           spark job|  1.0|       1.0|
|         hello world|  0.0|       1.0|
|           a b c d e|  0.0|       1.0|
|      you can you up|  0.0|       0.0|
|spark is easy to ...|  1.0|       1.0|
+--------------------+-----+----------+



