栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

Java实现的决策树算法完整实例

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Java实现的决策树算法完整实例

本文实例讲述了Java实现的决策树算法。分享给大家供大家参考,具体如下:

决策树算法是一种逼近离散函数值的方法。它是一种典型的分类方法:首先对数据进行处理,利用归纳算法生成可读的规则和决策树,然后使用决策树对新数据进行分析。本质上,决策树是通过一系列规则对数据进行分类的过程。

决策树构造可以分两步进行。第一步是决策树的生成,即由训练样本集生成决策树的过程。一般情况下,训练样本集是根据实际需要收集的、具有历史性和一定综合程度的、用于数据分析处理的数据集。第二步是决策树的剪枝,即对上一阶段生成的决策树进行检验、校正和修剪的过程。它主要用新的样本数据集(称为测试数据集)中的数据校验决策树生成过程中产生的初步规则,并将那些影响预测准确性的分枝剪除。

Java实现代码如下(以信息增益作为属性选择标准,即ID3算法的思路;代码只实现了决策树的生成步骤,不包含剪枝):

package demo;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
/**
 * Small decision-tree (ID3-style) demo.
 *
 * <p>Builds a decision tree from a hard-coded training set. At each node it picks the attribute
 * whose split minimizes the sample-weighted class entropy of the resulting branches (equivalently,
 * maximizes information gain). The resulting tree is printed to standard output. No pruning step
 * is performed.
 */
public class DicisionTree {

  /**
   * Reads the built-in sample set, builds the decision tree and prints it.
   *
   * @param args ignored
   * @throws Exception declared for signature compatibility; the current body throws no checked
   *     exception
   */
  public static void main(String[] args) throws Exception {
    System.out.print("考高分网测试结果:");
    String[] attrNames = new String[] {"AGE", "INCOME", "STUDENT", "CREDIT_RATING"};
    // Read the sample set, grouped by category.
    Map<Object, List<Sample>> samples = readSamples(attrNames);
    // Build the decision tree.
    Object decisionTree = generateDecisionTree(samples, attrNames);
    // Print the decision tree.
    outputDecisionTree(decisionTree, 0, null);
  }

  /**
   * Builds the hard-coded training set and groups its samples by category.
   *
   * <p>Each raw row holds one value per attribute (in {@code attrNames} order) followed by the
   * sample's category as the last element.
   *
   * @param attrNames attribute names matching the leading columns of each raw row; must contain at
   *     least {@code row.length - 1} entries
   * @return map from category value to the samples in that category
   */
  static Map<Object, List<Sample>> readSamples(String[] attrNames) {
    // Sample attributes plus category (the last element of each row is the category).
    Object[][] rawData = new Object[][] {
      {"<30 ", "High ", "No ", "Fair   ", "0"},
      {"<30 ", "High ", "No ", "Excellent", "0"},
      {"30-40", "High ", "No ", "Fair   ", "1"},
      {">40 ", "Medium", "No ", "Fair   ", "1"},
      {">40 ", "Low  ", "Yes", "Fair   ", "1"},
      {">40 ", "Low  ", "Yes", "Excellent", "0"},
      {"30-40", "Low  ", "Yes", "Excellent", "1"},
      {"<30 ", "Medium", "No ", "Fair   ", "0"},
      {"<30 ", "Low  ", "Yes", "Fair   ", "1"},
      {">40 ", "Medium", "Yes", "Fair   ", "1"},
      {"<30 ", "Medium", "Yes", "Excellent", "1"},
      {"30-40", "Medium", "No ", "Excellent", "1"},
      {"30-40", "High ", "Yes", "Fair   ", "1"},
      {">40 ", "Medium", "No ", "Excellent", "0"}
    };
    Map<Object, List<Sample>> ret = new HashMap<Object, List<Sample>>();
    for (Object[] row : rawData) {
      int categoryIndex = row.length - 1;
      Sample sample = new Sample();
      for (int i = 0; i < categoryIndex; i++) {
        sample.setAttribute(attrNames[i], row[i]);
      }
      Object category = row[categoryIndex];
      sample.setCategory(category);
      List<Sample> samples = ret.get(category);
      if (samples == null) {
        samples = new ArrayList<Sample>();
        ret.put(category, samples);
      }
      samples.add(sample);
    }
    return ret;
  }

  /**
   * Recursively builds a decision tree.
   *
   * @param categoryToSamples the samples reaching this node, grouped by category
   * @param attrNames attributes still available for splitting at this node
   * @return either a {@link Tree} (internal node) or a category value (leaf). Returns {@code null}
   *     only when both {@code categoryToSamples} and {@code attrNames} are empty.
   */
  static Object generateDecisionTree(
      Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {
    // All remaining samples share a single category: that category is the leaf.
    if (categoryToSamples.size() == 1) {
      return categoryToSamples.keySet().iterator().next();
    }
    // No attributes left to test: the category with the most samples wins (majority vote).
    if (attrNames.length == 0) {
      return majorityCategory(categoryToSamples);
    }
    // Pick the test attribute for this node.
    Object[] rst = chooseBestTestAttribute(categoryToSamples, attrNames);
    int bestIndex = (Integer) rst[0];
    Tree tree = new Tree(attrNames[bestIndex]);
    // An attribute already used on this path must not be tested again below it.
    String[] remaining = new String[attrNames.length - 1];
    for (int i = 0, j = 0; i < attrNames.length; i++) {
      if (i != bestIndex) {
        remaining[j++] = attrNames[i];
      }
    }
    // Slot 2 is always populated by chooseBestTestAttribute with exactly this map type.
    @SuppressWarnings("unchecked")
    Map<Object, Map<Object, List<Sample>>> splits =
        (Map<Object, Map<Object, List<Sample>>>) rst[2];
    // One child per observed value of the chosen attribute.
    for (Entry<Object, Map<Object, List<Sample>>> entry : splits.entrySet()) {
      tree.setChild(entry.getKey(), generateDecisionTree(entry.getValue(), remaining));
    }
    return tree;
  }

  /**
   * Returns the category with the most samples. Ties are resolved in favor of the first such
   * category in map iteration order.
   */
  private static Object majorityCategory(Map<Object, List<Sample>> categoryToSamples) {
    int max = 0;
    Object maxCategory = null;
    for (Entry<Object, List<Sample>> entry : categoryToSamples.entrySet()) {
      int cur = entry.getValue().size();
      if (cur > max) {
        max = cur;
        maxCategory = entry.getKey();
      }
    }
    return maxCategory;
  }

  /**
   * Selects the attribute whose split yields the lowest weighted class entropy.
   *
   * @param categoryToSamples samples grouped by category
   * @param attrNames candidate attributes
   * @return a three-element array:
   *     <ul>
   *       <li>[0] the index of the best attribute, as an {@code Integer}; -1 if {@code attrNames}
   *           is empty</li>
   *       <li>[1] its weighted entropy, as a {@code Double}</li>
   *       <li>[2] its split, a {@code Map<Object, Map<Object, List<Sample>>>} from attribute value
   *           to (category to samples); {@code null} if {@code attrNames} is empty</li>
   *     </ul>
   */
  static Object[] chooseBestTestAttribute(
      Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {
    int minIndex = -1; // index of the best attribute so far
    double minValue = Double.MAX_VALUE; // lowest weighted entropy so far
    Map<Object, Map<Object, List<Sample>>> minSplits = null; // split of the best attribute
    // The total sample count is the same for every candidate attribute.
    int allCount = 0;
    for (List<Sample> samples : categoryToSamples.values()) {
      allCount += samples.size();
    }
    for (int attrIndex = 0; attrIndex < attrNames.length; attrIndex++) {
      Map<Object, Map<Object, List<Sample>>> curSplits =
          splitByAttribute(categoryToSamples, attrNames[attrIndex]);
      // Sum of branch entropies, each weighted by the fraction of samples in that branch.
      double curValue = 0.0;
      for (Map<Object, List<Sample>> split : curSplits.values()) {
        double perSplitCount = 0;
        for (List<Sample> list : split.values()) {
          perSplitCount += list.size();
        }
        curValue += (perSplitCount / allCount) * entropy(split, perSplitCount);
      }
      // Strict comparison: on ties the earliest attribute is kept.
      if (curValue < minValue) {
        minIndex = attrIndex;
        minValue = curValue;
        minSplits = curSplits;
      }
    }
    return new Object[] {minIndex, minValue, minSplits};
  }

  /** Groups samples by their value of {@code attrName}, then by category. */
  private static Map<Object, Map<Object, List<Sample>>> splitByAttribute(
      Map<Object, List<Sample>> categoryToSamples, String attrName) {
    Map<Object, Map<Object, List<Sample>>> splits =
        new HashMap<Object, Map<Object, List<Sample>>>();
    for (Entry<Object, List<Sample>> entry : categoryToSamples.entrySet()) {
      Object category = entry.getKey();
      for (Sample sample : entry.getValue()) {
        Object attrValue = sample.getAttribute(attrName);
        Map<Object, List<Sample>> split = splits.get(attrValue);
        if (split == null) {
          split = new HashMap<Object, List<Sample>>();
          splits.put(attrValue, split);
        }
        List<Sample> bucket = split.get(category);
        if (bucket == null) {
          bucket = new ArrayList<Sample>();
          split.put(category, bucket);
        }
        bucket.add(sample);
      }
    }
    return splits;
  }

  /**
   * Computes the base-2 Shannon entropy of the category distribution in one branch.
   *
   * @param split category to samples for one branch (every list is non-empty)
   * @param splitCount total number of samples in the branch
   */
  private static double entropy(Map<Object, List<Sample>> split, double splitCount) {
    double value = 0.0;
    for (List<Sample> list : split.values()) {
      double p = list.size() / splitCount;
      value -= p * (Math.log(p) / Math.log(2));
    }
    return value;
  }

  /**
   * Prints a (sub)tree to standard output, one node per line. Each level of depth is indented with
   * {@code "|-----"}.
   *
   * @param obj a {@link Tree} node or a leaf category value
   * @param level depth of {@code obj}, starting at 0
   * @param from label of the edge leading to {@code obj} ({@code "attr = value"}), or
   *     {@code null} for the root
   */
  static void outputDecisionTree(Object obj, int level, Object from) {
    for (int i = 0; i < level; i++) {
      System.out.print("|-----");
    }
    if (from != null) {
      System.out.printf("(%s):", from);
    }
    if (obj instanceof Tree) {
      Tree tree = (Tree) obj;
      String attrName = tree.getAttribute();
      // %n emits the platform line separator (the scraped source printed a literal "n").
      System.out.printf("[%s = ?]%n", attrName);
      for (Object attrValue : tree.getAttributevalues()) {
        outputDecisionTree(tree.getChild(attrValue), level + 1, attrName + " = " + attrValue);
      }
    } else {
      System.out.printf("[CATEGORY = %s]%n", obj);
    }
  }

  /** A training sample: a set of named attribute values plus its category. */
  static class Sample {
    private final Map<String, Object> attributes = new HashMap<String, Object>();
    private Object category;

    public Object getAttribute(String name) {
      return attributes.get(name);
    }

    public void setAttribute(String name, Object value) {
      attributes.put(name, value);
    }

    public Object getCategory() {
      return category;
    }

    public void setCategory(Object category) {
      this.category = category;
    }

    @Override
    public String toString() {
      return attributes.toString();
    }
  }

  /**
   * Internal decision-tree node. The node tests one attribute. Each child is keyed by an attribute
   * value and is either another {@code Tree} or a leaf category value.
   */
  static class Tree {
    private final String attribute;
    private final Map<Object, Object> children = new HashMap<Object, Object>();

    public Tree(String attribute) {
      this.attribute = attribute;
    }

    public String getAttribute() {
      return attribute;
    }

    public Object getChild(Object attrValue) {
      return children.get(attrValue);
    }

    public void setChild(Object attrValue, Object child) {
      children.put(attrValue, child);
    }

    public Set<Object> getAttributevalues() {
      return children.keySet();
    }
  }
}



运行结果:

更多关于java算法相关内容感兴趣的读者可查看本站专题:《Java数据结构与算法教程》、《Java操作DOM节点技巧总结》、《Java文件与目录操作技巧汇总》和《Java缓存操作技巧汇总》

希望本文所述对大家的Java程序设计有所帮助。

转载请注明:文章转载自 www.mshxw.com
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号