
How do I extract sklearn decision tree rules as pandas boolean conditions?


First, following the scikit-learn documentation on the decision tree structure, let's retrieve information about the fitted tree:

    n_nodes = clf.tree_.node_count
    children_left = clf.tree_.children_left
    children_right = clf.tree_.children_right
    feature = clf.tree_.feature
    threshold = clf.tree_.threshold
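To see what these arrays contain, here is a minimal sketch on a toy dataset. The data and the names X, y, and leaves are assumptions made for illustration and do not come from the original question. In scikit-learn's arrays, a leaf is a node whose children_left and children_right entries are both -1.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Toy data (illustrative assumption): one feature, two well-separated classes
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

clf = DecisionTreeClassifier(max_depth=1, random_state=0).fit(X, y)

n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold

# A node is a leaf when it has no children (both child ids are -1)
leaves = [i for i in range(n_nodes) if children_left[i] == children_right[i]]

for i in range(n_nodes):
    if i in leaves:
        print("node {} is a leaf".format(i))
    else:
        print("node {}: X[:, {}] <= {} -> node {}, else node {}".format(
            i, feature[i], threshold[i], children_left[i], children_right[i]))
```

Each internal node sends samples with a feature value at or below the threshold to its left child and all other samples to its right child, which is what the rule strings later in this answer encode.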

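Next, we define two functions. The first, find_path, recursively searches for the path from the root of the tree to a given node (here, each leaf). The second, get_rule, uses that path to write out the rule that leads to the node: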

    def find_path(node_numb, path, x):
        # Depth-first search; on success, path holds the nodes from the start node to x
        path.append(node_numb)
        if node_numb == x:
            return True
        left = False
        right = False
        if children_left[node_numb] != -1:
            left = find_path(children_left[node_numb], path, x)
        if children_right[node_numb] != -1:
            right = find_path(children_right[node_numb], path, x)
        if left or right:
            return True
        # x is not below this node, so drop the node from the path
        path.remove(node_numb)
        return False

    def get_rule(path, column_names):
        mask = ''
        for index, node in enumerate(path):
            # Skip the last node of the path: it is the target leaf and has no split
            if index != len(path) - 1:
                # Do we go under or over the threshold?
                if children_left[node] == path[index + 1]:
                    mask += "(df['{}']<= {}) \t ".format(column_names[feature[node]], threshold[node])
                else:
                    mask += "(df['{}']> {}) \t ".format(column_names[feature[node]], threshold[node])
        # Insert "&" between conditions: every tab separator except the last becomes "&"
        mask = mask.replace("\t", "&", mask.count("\t") - 1)
        mask = mask.replace("\t", "")
        return mask

Finally, we use these two functions to store the path to each leaf, and then the rule that leads to each leaf:

    import numpy as np

    # Leaves reached by the test samples
    leave_id = clf.apply(X_test)

    paths = {}
    for leaf in np.unique(leave_id):
        path_leaf = []
        find_path(0, path_leaf, leaf)
        paths[leaf] = np.unique(np.sort(path_leaf))

    rules = {}
    for key in paths:
        rules[key] = get_rule(paths[key], pima.columns)
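As a possible alternative to searching the tree once per leaf, the rules can also be collected in a single traversal. The sketch below is not the answer's method: the function name leaf_rules is hypothetical, and the arrays are hand-written to mimic the layout of clf.tree_ so the example runs without a fitted model. Unlike the loop above, which only covers leaves reached by X_test, it returns a rule for every leaf.

```python
def leaf_rules(children_left, children_right, feature, threshold, column_names):
    """Return {leaf_id: rule_string} for every leaf, using one traversal."""
    rules = {}
    stack = [(0, [])]  # (node id, conditions accumulated from the root)
    while stack:
        node, conds = stack.pop()
        if children_left[node] == -1:  # -1 marks a leaf in sklearn's arrays
            rules[node] = " & ".join(conds)
            continue
        name = column_names[feature[node]]
        thr = threshold[node]
        # Right branch: feature value above the threshold
        stack.append((children_right[node],
                      conds + ["(df['{}'] > {})".format(name, thr)]))
        # Left branch: feature value at or below the threshold
        stack.append((children_left[node],
                      conds + ["(df['{}'] <= {})".format(name, thr)]))
    return rules


# Hand-written arrays for a hypothetical 5-node tree (not the tree from the question)
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [0, -2, 1, -2, -2]
threshold = [127.5, -2.0, 26.5, -2.0, -2.0]
columns = ["insulin", "bp"]

rules = leaf_rules(children_left, children_right, feature, threshold, columns)
for leaf_id in sorted(rules):
    print(leaf_id, rules[leaf_id])
```

Joining the conditions with " & " avoids the separator-replacement step, at the cost of producing slightly different whitespace from the strings shown in this answer.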

With the data you provided, the output is:

    rules = {
        3: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']<= 9.100000381469727)  ",
        4: "(df['insulin']<= 127.5) & (df['bp']<= 26.450000762939453) & (df['bp']> 9.100000381469727)  ",
        6: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']<= 27.5)  ",
        7: "(df['insulin']<= 127.5) & (df['bp']> 26.450000762939453) & (df['skin']> 27.5)  ",
        10: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']<= 145.5)  ",
        11: "(df['insulin']> 127.5) & (df['bp']<= 28.149999618530273) & (df['insulin']> 145.5)  ",
        13: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']<= 158.5)  ",
        14: "(df['insulin']> 127.5) & (df['bp']> 28.149999618530273) & (df['insulin']> 158.5)  ",
    }

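Since the rules are strings, you cannot index the DataFrame with them directly, as in

    df[rules[3]]

Instead, you have to evaluate them first with the eval function:

    df[eval(rules[3])]

The rule strings refer to the DataFrame by the name df, so a variable with that name must be in scope when eval is called.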
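If you would rather not call eval on strings, one option is to keep each rule as structured data and build the boolean mask with ordinary pandas comparisons. The sketch below is an assumption-laden illustration rather than part of the original answer: the helper mask_from_conditions, the (column, operator, threshold) tuple format, and the small DataFrame are all hypothetical.

```python
import operator

import pandas as pd

# Map the two operators used in the tree rules to their Python functions
OPS = {"<=": operator.le, ">": operator.gt}


def mask_from_conditions(df, conditions):
    """Combine (column, operator, threshold) tuples into one boolean mask."""
    mask = pd.Series(True, index=df.index)
    for column, op, thr in conditions:
        mask &= OPS[op](df[column], thr)
    return mask


# Hypothetical data and rule, for illustration only
df = pd.DataFrame({
    "insulin": [100, 130, 150, 200],
    "bp": [20.0, 30.0, 25.0, 40.0],
})
conditions = [("insulin", ">", 127.5), ("bp", "<=", 28.15)]

mask = mask_from_conditions(df, conditions)
selected = df[mask]
print(selected)
```

Storing tuples instead of strings means the rules can be applied to any DataFrame regardless of its variable name, and nothing is executed as code.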



If you repost this article, please credit the source: www.mshxw.com
Article URL: https://www.mshxw.com/it/611919.html