栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

机器学习模型的保存和加载

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

机器学习模型的保存和加载

当我们的数据集的数量非常庞大的时候,并不适合每次运行的时候都加载一遍,那样的话,所需要的时间就非常庞大。因此我们需要进行模型保存
    1. 模型保存API
        joblib.dump(estimator, filename)
            estimator: 就是我们训练完成的模型
            filename:就是我们要保存的文件名,通常,文件名的后缀用.pkl来保存
    2. 模型加载
        joblib.load(filename)
            filename: 传入文件路径的字符串即可

模型保存代码:

# 对乳腺癌进行分类和评估(通过ROC曲线和AUC指标)
import joblib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, plot_roc_curve # 用来绘制ROC曲线
from sklearn.metrics import roc_auc_score # 用来计算AUC指标
from sklearn.metrics import classification_report
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression



# 1)数据集获取
data = load_breast_cancer()
# 2)数据集分离
x_train, x_test, y_train, y_test = train_test_split(data.data, data.target, random_state=22)
# 3)特诊工程标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)
# 4)逻辑回归流程
# 注意,这里可以采用网格搜索和交叉验证来进行出来,找到合适的estimator
estimator = LogisticRegression(solver='liblinear', penalty='l2', C=1.0)
estimator.fit(x_test, y_test)
# 对模型进行保存
joblib.dump(estimator, '逻辑回归.pkl')

模型加载代码:

import joblib
import joblib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, plot_roc_curve # 用来绘制ROC曲线
from sklearn.metrics import roc_auc_score # 用来计算AUC指标
from sklearn.metrics import classification_report
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

# 模型加载
estimator = joblib.load('逻辑回归.pkl')

estimator.coef_, estimator.intercept_
# 准确率(注意,这里的x_test和y_test并不会保存下来,因此需要在保存模型的同时,保存测试集)
estimator.score(x_test, y_test)

# 5)精确率、召回率、F1-score
report = classification_report(y_test, estimator.predict(x_test), labels=[0, 1], target_names=['良性', '恶性'])
print(report)
# 6)ROC曲线和AUC指标
print(roc_curve(y_test, estimator.predict(x_test)))
print(roc_auc_score(y_test, estimator.predict(x_test)))
plot_roc_curve(estimator,x_test, y_test)
plt.plot([0, 1], [0, 1], 'r--', label='random classify')
plt.legend()
plt.show()
         precision    recall  f1-score   support

          良性       0.98      0.93      0.95        55
          恶性       0.96      0.99      0.97        88

    accuracy                           0.97       143
   macro avg       0.97      0.96      0.96       143
weighted avg       0.97      0.97      0.96       143

(array([0.        , 0.07272727, 1.        ]), array([0.        , 0.98863636, 1.        ]), array([2, 1, 0]))
0.9579545454545455

学习地址:

 黑马程序员3天快速入门python机器学习_哔哩哔哩_bilibili

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/656803.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号