栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > R语言

用R语言实现决策树分类

R语言 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

用R语言实现决策树分类

最近在看西瓜书中有关决策树的部分,就想用R语言建立简单的决策树模型,因为Python实在还不太会,哈。
这里为了方便,我就直接使用自带的数据集鸢尾花iris,用的R包有rpart和rpart.plot。rpart是一个专门用于做决策树模型的包,rpart.plot则用于绘制rpart模型。
为了方便理解和记忆,此处将模型的完整建立分成导入数据包/设置建模参数/数据切分/建模/调整模型参数并计算训练误差和测试误差这几个步骤。
以下是代码的具体实现部分。

首先加载需要的R包:

install.packages('rpart')
install.packages('rpart.plot')
library(rpart)
library(rpart.plot)

查看数据集,

iris
str(iris)


主要要查看数据集的标签列的位置在哪一列,
还要注意标签列的数据类型必须为factor因子型,不然数据类型不对不好分类。
我们可以看到鸢尾花数据集的标签Species,类型为factor,因此不需要再转换数据类型。

对数据进行切分,随机分为训练集和测试集,

index <- sample(nrow(iris), 0.7*nrow(iris))
train <- iris[index, ]
test <- iris[-index, ]

设置建模控制参数,参数的设置在一定程度上可以防止模型过拟合。
其中rpart.control 参数minbucket 表示叶节点至少包含的样本数,少于这个数量就进行剪枝;参数maxdepth设置树的最大深度;xval是交叉验证次数;cp是树生长的最低增长指标,也就是每生长一步,对整体纯度提升的的最低指标,低于这个指标就进行剪枝。

tc <- rpart.control(minbucket=5,maxdepth=10,xval=5,cp=0.005)

接下来就可以用训练集建立模型啦,

fit <- rpart(Species ~ ., data=train, control="tc")

然后用建立好的模型分别对训练集和测试集进行预测,并计算准确率。
其中table函数可以统计每个类别的频数,通过公式:预测正确的个数除以总数 可以很好的计算出准确率。

train.pred <- predict(fit, train[,-5], type="class")
table(train$Species == train.pred)['TRUE'] / length(train.pred)
test.pred <- predict(fit, test[,-5], type="class") 
table(test$Species == test.pred)['TRUE'] / length(test.pred)

得到的结果如下,
接下来就可以画一棵决策树了!

 rpart.plot(fit, main="Decision Tree")


从上面计算的模型的准确率来看,模型的泛化能力还是挺好的。

如果想要对树进行剪枝的话,我们可以根据设置相应的cp值来进行剪枝,

fit$cptable

查看模型各层的cp值,一共有三层,


如果要对最下面一层进行剪枝的话,我们要设置最低cp值略大于倒数第二层的cp值,这样的话倒数第二层就不会继续生长了。

prune(fit, 0.43077)
rpart.plot(prune(fit, 0.43077)

我们还可以确定最佳cp值。
对控制参数重新设置,初始cp值设置为0,然后用同样的方法建立模型fit2,这里就不再重复了

fit2$cptable

通过cptable图可以看出cp值在0.081时模型的准确率已经很好,大于这个值模型容易欠拟合,小于这个容易过拟合。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/855905.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号