栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

实用:sklearn提取决策树规则代码(附python代码)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

实用:sklearn提取决策树规则代码(附python代码)

《老饼讲解机器学习》http://ml.bbbdata.com/teach#107


目录

一.问题

二.主要思路

三.代码实例

1.数据提取

2.预测函数

3.准确性测试


一.问题

在决策树模型建好之后,要提取规则布署到生产。

二.主要思路

只提取数据,在生产环境写出通用预测代码。新的模型只需替换数据即可。

备注:一般不弄成一系列的if else,写死代码不便于更换模型。

三.代码实例

1.数据提取

使用如下get_tree函数,将树数据提取成字典:

from sklearn import tree
import numpy as np
def get_tree(sk_tree):
    #--------------拷贝sklearn树模型关键信息--------------------
    children_left       = sk_tree.tree_.children_left.copy()            # 左节点编号
    children_right      = sk_tree.tree_.children_right.copy()          # 右节点编号
    feature          = sk_tree.tree_.feature.copy()               # 分割的变量
    threshold         = sk_tree.tree_.threshold.copy()              # 分割阈值
    impurity          = sk_tree.tree_.impurity.copy()               # 不纯度(gini)
    n_node_samples      = sk_tree.tree_.n_node_samples.copy()           # 样本个数
    value            = sk_tree.tree_.value.copy()                 # 样本分布
    n_sample         = value[0].sum()                          # 总样本个数
    node_num         = len(children_left)                       # 节点个数
    depth = sk_tree.get_depth()
    
    # ------------补充节点父节点信息---------------------------
    parent = np.zeros(node_num).astype(int)
    parent[0] = -1
    branch_idx = np.where(children_left!=-1)[0]
    for i in branch_idx:
        parent[children_left[i]] = i   
        parent[children_right[i]]= i 
    #-------------存成字典-----------------------------------------    
    tree = {
        'children_left':children_left
        ,'children_right':children_right
        ,'feature':feature
        ,'threshold':threshold
        ,'impurity':impurity
        ,'n_node_samples':n_node_samples
        ,'value':value
        ,'depth':depth
        ,'n_sample':n_sample
        ,'node_num':node_num
        ,'parent':parent
        }
    return tree

将训练好的模型sk_tree传入以上函数,转化成字典,保存成文件。

2.预测函数

在生产时使用如下tree_predict 函数预测(其它语言类似以下逻辑)。

import numpy as np
def tree_predict(tree,x):
    node_idx = 0
    t = 0
    while(t 

3.准确性测试
from sklearn.datasets import load_iris
from sklearn import tree
import numpy as np
from get_tree import get_tree
from tree_pred import tree_predict

#----------------数据准备----------------------------
iris = load_iris()                          # 加载数据
X = iris.data
y = iris.target
#---------------模型训练----------------------------------
clf = tree.DecisionTreeClassifier()               # sk-learn的决策树模型
clf = clf.fit(X, y)                        # 用数据训练树模型构建()
#--------------将树提取成简单的字典--------------------------------
tree = get_tree(clf)
#-------------------------
#将tree持久化到服务器,服务器中用tree_predict进行预测即可
#-------------------------

#------------测试函数的准确性-----------------------------
self_pred_y = np.zeros(len(y))
self_pred_prob = np.zeros((len(y),len(tree['value'][0][0])))
for i in range(X.shape[0]):
    pred_class,pred_prob = tree_predict(tree,X[i])
    self_pred_y[i] = pred_class
    self_pred_prob[i] = pred_prob
pred_y = clf.predict(X)
pred_prob = clf.predict_proba(X)
print("与sklearn预测结果差异个数(类别):",np.sum(pred_y != self_pred_y))
print("与sklearn预测结果差异个数(概率):",np.sum(pred_prob != self_pred_prob))

 测试结果:

与sklearn预测结果差异个数(类别): 0
与sklearn预测结果差异个数(概率): 0

相关文章

《深入浅出:决策树入门简介》

《一个简单的决策树分类例子》

《sklearn决策树结果可视化》

《sklearn决策树参数详解》

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/740824.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号