栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Python基础----Matplotlib

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Python基础----Matplotlib

XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1, gamma=0, gpu_id=-1, importance_type='gain', interaction_constraints='', learning_rate=0.1, max_delta_step=0, max_depth=4, min_child_weight=1, missing=nan, monotone_constraints='()', n_estimators=10, n_jobs=0, num_parallel_tree=1, random_state=0, reg_alpha=0, reg_lambda=1, scale_pos_weight=1, subsample=1, tree_method='exact', validate_parameters=1, verbosity=None)

数据是随机构造的，所以模型在训练集上误差较小、在验证集上误差较大，即模型产生了过拟合。

混淆矩阵、召回率和精确率

真实评估时以验证集为准

def plot_confusion_matrix(cm, classes, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on the current figure.

    Relies on the module-level names ``plt`` (matplotlib.pyplot), ``np``
    (numpy) and ``itertools`` being imported by the surrounding file.

    Args:
        cm: 2-D numpy array of counts. Row ``i`` is drawn at y-position ``i``
            (the "True label" axis) and column ``j`` at x-position ``j``
            (the "Predicted label" axis).
        classes: Sequence of tick labels, used for both axes.
        title: Title placed above the plot.
        cmap: Matplotlib colormap passed to ``imshow``.

    Returns:
        None. The function only draws; it neither shows nor saves the figure.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)

    # Cells holding more than half of the largest count are annotated in
    # red, all others in black.
    thresh = cm.max() / 2
    # Iterate rows x columns explicitly so a non-square matrix is still
    # annotated in full.
    n_rows, n_cols = cm.shape
    for i, j in itertools.product(range(n_rows), range(n_cols)):
        plt.text(j, i, cm[i, j], horizontalalignment='center',
                 color='red' if cm[i, j] > thresh else 'black')

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
def matrixs_plot(X_test, y_test, clf, thresh=0.5, png_savename=0):
    """Plot the confusion matrix of a binary classifier at a probability cut-off.

    Relies on the module-level names ``plt`` (matplotlib.pyplot), ``np``
    (numpy), ``metrics`` (sklearn.metrics) and ``plot_confusion_matrix``.

    Args:
        X_test: Feature matrix passed to ``clf.predict_proba``.
        y_test: True labels, compared against the thresholded predictions.
        clf: Fitted classifier exposing ``predict_proba``; column 1 of its
            output is used as the positive-class score.
        thresh: Scores strictly greater than this value count as positive.
        png_savename: When not ``0``, the figure is saved to
            ``pic/<png_savename>_混淆矩阵.png`` at 300 dpi.

    Returns:
        None.
    """
    import os

    plt.figure(figsize=(10, 6))
    y_score = clf.predict_proba(X_test)[:, 1]
    y_prediction = y_score > thresh  # scores above the cut-off are positive
    cnf_matrix = metrics.confusion_matrix(y_test, y_prediction)
    # NOTE: changes numpy's global print precision for the whole process.
    np.set_printoptions(precision=2)

    # Row 1 holds the truly-positive samples, column 1 the predicted-positive.
    recall = cnf_matrix[1, 1] / (cnf_matrix[1, 0] + cnf_matrix[1, 1])
    precision = cnf_matrix[1, 1] / (cnf_matrix[0, 1] + cnf_matrix[1, 1])
    # Rounded to three decimals first, then shown as a percentage with one
    # decimal, to keep the displayed numbers identical to earlier output.
    vali_recall = '{0:.3f}'.format(recall)
    vali_precision = '{0:.3f}'.format(precision)

    class_names = [0, 1]
    title = 'Recall=%s%% \n Precision=%s%%' % (
        '{0:.1f}'.format(float(vali_recall) * 100),
        '{0:.1f}'.format(float(vali_precision) * 100),
    )
    plot_confusion_matrix(cnf_matrix, classes=class_names, title=title)
    # Override the x-axis label set by plot_confusion_matrix.
    plt.xlabel('Predict label')
    plt.ylabel('True label')

    if png_savename != 0:
        # savefig does not create missing directories.
        os.makedirs('pic', exist_ok=True)
        plt.savefig('pic/%s_混淆矩阵.png' % png_savename, dpi=300)
# Print recall, precision and a 2x2 confusion table for the validation split,
# using the model's hard predictions (clf.predict, not a custom threshold).
y_val = val_y
y_pre = model.predict(val_x)
# ravel() flattens the 2x2 matrix row by row: [[tn, fp], [fn, tp]].
tn, fp, fn, tp = metrics.confusion_matrix(y_val, y_pre).ravel()
print('Recall is :', round(tp / (tp + fn), 3))
print('Precision is :', round(tp / (tp + fp), 3))
# Table layout: rows are predictions, columns are true labels.
print('matrix label0 label1')
print('predict0 {: 6d} {: 6d}'.format(int(tn), int(fn)))
print('predict1 {: 6d} {: 6d}'.format(int(fp), int(tp)))
Recall is : 0.49
Precision is : 0.493
matrix label0 label1
predict0 82 74 
predict1 73 71 
# Plot the validation-set confusion matrix at a 0.5 probability cut-off.
matrixs_plot(val_x, val_y, model, thresh=0.5)

AUC
def auc_plot(X_test, y_test, clf, png_savename=0):
    """Plot the ROC curve with its AUC, overlay the thresholds, print accuracy.

    Relies on the module-level names ``plt`` (matplotlib.pyplot) and ``np``
    (numpy) being imported by the surrounding file.

    Args:
        X_test: Feature matrix passed to ``clf.predict`` / ``clf.predict_proba``.
        y_test: True labels.
        clf: Fitted classifier exposing ``predict`` and ``predict_proba``;
            column 1 of ``predict_proba`` is used as the positive-class score.
        png_savename: When not ``0``, the figure is saved as
            ``<png_savename>_AUC.png``.

    Returns:
        None. Shows the figure and prints the accuracy of ``clf.predict``.
    """
    from sklearn.metrics import auc, roc_curve, accuracy_score

    plt.figure(figsize=(10, 6))
    y_pre = clf.predict(X_test)
    y_score = clf.predict_proba(X_test)[:, 1]  # predicted probability
    fpr, tpr, thresholds = roc_curve(y_test, y_score)
    # Clamp thresholds into [0, 1] so they fit the probability axis below.
    thresholds = np.clip(thresholds, 0, 1)
    roc_auc = auc(fpr, tpr)

    # ROC curve on the primary (FPR) x-axis.
    plt.plot(fpr, tpr, color='blue', label='AUC={0:.4f}'.format(roc_auc), ms=100)
    plt.xlabel('FPR', fontsize=15)
    plt.ylabel('TPR', fontsize=15)
    plt.legend(loc='center left')

    # TPR against threshold on a second x-axis sharing the same y-axis.
    plt.twiny()
    plt.plot(thresholds, tpr, color='green', label='thresholds')
    plt.xlabel('thresholds', fontsize=15)

    # Diagonal reference line.
    plt.plot([0, 1], [0, 1], 'r--')
    plt.title('ROC curve', fontsize=20)
    plt.legend(loc='center right')

    if png_savename != 0:
        plt.savefig('%s_AUC.png' % png_savename)

    plt.show()
    print('Accuracy: {0:.2f}'.format(accuracy_score(y_test, y_pre)))

真实评估时以验证集为准

# ROC/AUC on the validation split; png_savename=0 means the figure is not saved.
auc_plot(val_x, val_y, model, png_savename=0)

Accuracy: 0.51
# ROC/AUC on the training split, for comparison; the figure is not saved.
auc_plot(train_x, train_y, model, png_savename=0)

Accuracy: 0.69
KS
def metrics_ks(X_test, y_test, clf):
    """Compute the KS statistic of a classifier and plot the KS curve.

    KS is taken as the largest absolute gap between FPR and TPR over the
    points returned by ``roc_curve``; the threshold at that point is reported
    as the best cut-off. Relies on the module-level name ``np`` (numpy).

    Args:
        X_test: Test feature matrix.
        y_test: Test labels; label ``1`` is treated as the positive class.
        clf: Trained model exposing ``predict_proba``; column 1 of its output
            is used as the positive-class score.

    Returns:
        Tuple ``(ks, ks_threshold)``: the KS value and the threshold (clipped
        to [0, 1]) at which it is reached.
    """
    from sklearn.metrics import roc_curve
    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 6))
    y_score = clf.predict_proba(X_test)[:, 1]  # predicted probability
    fpr, tpr, thresholds = roc_curve(y_test, y_score, pos_label=1)
    # Clamp thresholds into [0, 1] before reporting one of them.
    thresholds = np.clip(thresholds, 0, 1)

    diff = abs(fpr - tpr)
    # argmax returns the first position of the maximum, which is the same
    # point the equality search `diff == diff.max()` would pick first.
    ks_index = int(np.argmax(diff))
    ks = diff[ks_index]
    ks_threshold = thresholds[ks_index]

    # The x-axis is the index into the roc_curve arrays, not the threshold.
    plt.plot(fpr, label='bad', linewidth=2, color='r')
    plt.plot(tpr, label='good', linewidth=2, color='green')
    plt.plot(diff, label='diff', linewidth=2, color='orange')

    # Mark KS as a vertical segment between the FPR and TPR curves.
    bad_point = fpr[ks_index]
    good_point = tpr[ks_index]
    x_point = [ks_index, ks_index]
    y_point = [bad_point, good_point]
    plt.plot(x_point, y_point, label='ks - {:.2f}'.format(ks),
             color='purple', marker='o', markersize=5)
    plt.scatter(x_point, y_point, color='purple')

    plt.title('KS curve', fontsize=20)
    plt.xlabel('Number', fontsize=15)
    plt.ylabel('FPR & TPR', fontsize=15)
    plt.legend()
    plt.show()

    print('ks value: {0:.2f}'.format(ks))
    print('ks_threshold: {0:.2f}'.format(ks_threshold))
    return ks, ks_threshold

真实评估时以验证集为准

# Compute and plot KS on the validation split; the returned
# (ks, ks_threshold) tuple is not captured here.
metrics_ks(val_x, val_y, model)

ks value: 0.09
ks_threshold: 0.49
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/267799.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号