栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 前沿技术 > 大数据 > 大数据系统

决策树(1)

决策树(1)

import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;
public class DecisionTreeRegression{
    public static void main(String[] args) {
        // TODO Auto-generated method stub
    	SparkConf sparkConf = new SparkConf(). setAppName ("JavaDecisionTreeClassificationExample");
    	sparkConf . setMaster("local[2]");
    	JavaSparkContext jsc = new JavaSparkContext (sparkConf);
    	// Load and parse the data file.
    	String datapath =
    	"file:///home/gyq/下载/spark-2.3.2-bin-hadoop2.7/data/mllib/sample_libsvm_data.txt";
    
    	JavaRDD data = MLUtils. loadLibSVMFile(jsc.sc(), datapath).toJavaRDD() ;
    	// Split the data into training and test sets (30% held out for testing)
    	JavaRDD[] splits = data. randomSplit(new double[]{0.7, 0.3});
    	JavaRDD trainingData = splits[0];
    	JavaRDD testData = splits[1] ;
    	// Set parameters.
    	// Empty categoricalFeaturesInfo indicates all features are cont inuous .
    	Integer numClasses = 2; //类别数量
    	Map categoricalFeaturesInfo = new HashMap() ;
    	
    	String impurity = "gini";
    	Integer maxDepth = 5; // 最大深度
    	Integer maxBins = 32; // 最大划分数
    	// Train a DecisionTree model for classification.
    	final DecisionTreeModel model = DecisionTree . trainClassifier(trainingData,
    	numClasses , categoricalFeaturesInfo, impurity, maxDepth, maxBins);
    	// evaluate model on test instances and compute test error
    	JavaPairRDD predictionAndLabel =
    	testData.mapToPair(new PairFunction() {
   
    	public Tuple2 call(LabeledPoint p) {
    	return new Tuple2 (model. predict(p. features()),p.label());
    	}
    	});
    	Double testErr =1.0
    	* predictionAndLabel. filter(new Function, Boolean>()
    	{
    
    	public Boolean call(Tuple2 pl) {
    	return !pl._1(). equals(pl._2());
    	}
    	}). count() / testData . count();
    	System. out . println("Test Error: "+ testErr);
    	System. out . println("Learned classification tree model:n" + model. toDebugString());
    }
    }

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/581902.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号