import java.util.HashMap;
import java.util.Map;
import scala.Tuple2;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
public class DecisionTreeRegression{
public static void main(String[] args) {
// TODO Auto-generated method stub
SparkConf sparkConf = new SparkConf(). setAppName ("JavaDecisionTreeClassificationExample");
sparkConf . setMaster("local[2]");
JavaSparkContext jsc = new JavaSparkContext (sparkConf);
// Load and parse the data file.
String datapath =
"file:///home/gyq/下载/spark-2.3.2-bin-hadoop2.7/data/mllib/sample_libsvm_data.txt";
JavaRDD data = MLUtils. loadLibSVMFile(jsc.sc(), datapath).toJavaRDD() ;
// Split the data into training and test sets (30% held out for testing)
JavaRDD[] splits = data. randomSplit(new double[]{0.7, 0.3});
JavaRDD trainingData = splits[0];
JavaRDD testData = splits[1] ;
// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are cont inuous .
Integer numClasses = 2; //类别数量
Map categoricalFeaturesInfo = new HashMap() ;
String impurity = "gini";
Integer maxDepth = 5; // 最大深度
Integer maxBins = 32; // 最大划分数
// Train a DecisionTree model for classification.
final DecisionTreeModel model = DecisionTree . trainClassifier(trainingData,
numClasses , categoricalFeaturesInfo, impurity, maxDepth, maxBins);
// evaluate model on test instances and compute test error
JavaPairRDD predictionAndLabel =
testData.mapToPair(new PairFunction() {
public Tuple2 call(LabeledPoint p) {
return new Tuple2 (model. predict(p. features()),p.label());
}
});
Double testErr =1.0
* predictionAndLabel. filter(new Function, Boolean>()
{
public Boolean call(Tuple2 pl) {
return !pl._1(). equals(pl._2());
}
}). count() / testData . count();
System. out . println("Test Error: "+ testErr);
System. out . println("Learned classification tree model:n" + model. toDebugString());
}
}