栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

P-R曲线绘制(多分类问题)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

P-R曲线绘制(多分类问题)

以iris数据为样本实现P-R曲线的绘制

import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm, datasets
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier

导入iris数据集

iris = datasets.load_iris()
X = iris.data
y = iris.target

因为target列是iris分类的文字描述形式,需将其转换为类别标签

y = label_binarize(y,classes=[0,1,2]) # 运用标签二值化的方法
n_classes = y.shape[1] 

形成的部分y如下图

 为使得曲线效果变化显著,适当增加噪声样本

random_state = np.random.RandomState(0)
n_samples,n_features = X.shape
# 增加200倍的噪声值,即在原始x的列上增加200*4列
X = np.c_[X,random_state.randn(n_samples,200*n_features)]

训练模型,并计算decision_function()

X_train,X_test,y_train,y_test = train_test_split(X,y, test_size=0.8,random_state=random_state)
classifier = oneVsRestClassifier(svm.SVC(kernel = "linear",probability = True, random_state=random_state))
# decision_function计算样本点到分割超平面的函数距离。输出表示分类器对x_test的预测样本是位于超平面的右侧还是左侧,以及离它有多远。它还告诉我们分类器为x_test预测的每个值是正的(大幅度正值)还是负的(大幅度负值)。
y_score = classifier.fit(X_train,y_train).decision_function(X_test)
print(y_score)

对三个分类以此计算precision、recall,并且运用micro方式对precision、recall求平均(也可以使用macro、weighted的方式进行求平均

 

precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
    precision[i],recall[i],_ = precision_recall_curve(y_test[:,i],y_score[:,i])
    average_precision[i] = average_precision_score(y_test[:,i],y_score[:,i])
    
precision["micro"],recall["micro"],_ = precision_recall_curve(y_test.ravel(),y_score.ravel())
average_precision["micro"] = average_precision_score(y_test,y_score,average="micro") 

绘制P-R曲线

plt.clf()
plt.plot(recall["micro"],precision["micro"],label = "micro_average P_R(area={0:0.2f})".format(average_precision["micro"]))
for i in range(n_classes):
    plt.plot(recall[i],precision[i],label = "P_R curve of class{0}(area={1:0.2f})".format(i,average_precision[i]))

plt.xlim([0.0,0.1])
plt.ylim([0.0,1.05])
plt.legend(loc = "lower right")
plt.show()

 源代码源自深度学习基础_哈尔滨工业大学_中国大学MOOC(慕课) (icourse163.org)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/769038.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号