
scikit-learn Chinese documentation notes (sklearn and scikit-learn)


This note covers cross-validation.
We start from the code used in the first note:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split  # split data into training and test sets
from sklearn.neighbors import KNeighborsClassifier  # K-nearest neighbors classifier

iris = load_iris()
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=4)
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
print(knn.score(X_test, y_test))
0.9736842105263158
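Note that this single score depends on which samples happen to land in the test set. A small sketch (the loop over seeds is illustrative, not from the original) shows how much the score moves when only random_state changes:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
X, y = iris.data, iris.target

# The single-split score depends heavily on which rows land in the
# test set; varying random_state reveals the spread.
scores = []
for seed in range(10):
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=seed)
    knn = KNeighborsClassifier(n_neighbors=5)
    knn.fit(X_train, y_train)
    scores.append(knn.score(X_test, y_test))
print(min(scores), max(scores))
```

The spread between the minimum and maximum is the motivation for cross-validation below: averaging over several splits gives a more stable estimate.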

However, if the total sample is small, a single train/test split like the one above is unreliable: we cannot afford to set aside a large test set, and the score depends on which samples were held out.

from sklearn.model_selection import cross_val_score

knn = KNeighborsClassifier(n_neighbors=5)
scores = cross_val_score(knn, X, y, cv=5, scoring='accuracy')  # data is split into 5 folds
print(scores)
[0.96666667 1.         0.93333333 0.96666667 1.        ]
print(scores.mean())
0.9733333333333334
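Under the hood, cv=5 splits the data into five folds, trains on four, and scores the held-out fold each time. A minimal sketch with KFold (note that for classifiers cross_val_score actually uses stratified folds; the shuffle and random_state here are illustrative choices, not from the original):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
X, y = iris.data, iris.target

# Manually reproduce the idea behind cv=5: 5 folds, train on 4,
# score on the held-out one, then average.
kf = KFold(n_splits=5, shuffle=True, random_state=0)
scores = []
for train_idx, test_idx in kf.split(X):
    knn = KNeighborsClassifier(n_neighbors=5)
    knn.fit(X[train_idx], y[train_idx])
    scores.append(knn.score(X[test_idx], y[test_idx]))
print(np.mean(scores))
```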

Next, let's discuss how to choose the n_neighbors parameter.

from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt

k_range = range(1, 31)
k_scores = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')  # 10-fold CV for each k
    k_scores.append(scores.mean())

plt.plot(k_range, k_scores)
plt.xlabel('Value of K for KNN')
plt.ylabel('Cross-Validated Accuracy')
plt.show()

The resulting plot shows how cross-validated accuracy varies with k. (The original figure is not reproduced here.)
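The loop above is exactly what GridSearchCV automates. A sketch of the equivalent search (the parameter grid is chosen to match the loop; GridSearchCV is an alternative, not the method used in the original):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
X, y = iris.data, iris.target

# Same search space as the manual loop: n_neighbors from 1 to 30,
# each evaluated with 10-fold cross-validated accuracy.
param_grid = {'n_neighbors': list(range(1, 31))}
search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=10, scoring='accuracy')
search.fit(X, y)
print(search.best_params_, search.best_score_)
```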

If we instead want to evaluate a regression model, we only need to swap the scoring metric: change scores=cross_val_score(knn, X, y, cv=10, scoring='accuracy') to loss=-cross_val_score(model, X, y, cv=10, scoring='neg_mean_squared_error'). The 'neg_mean_squared_error' scorer returns negated MSE (so that higher is always better), which is why we negate it back to get a loss.
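A concrete regression sketch, using KNeighborsRegressor on the diabetes dataset (the dataset and model here are illustrative choices, not from the original):

```python
from sklearn.datasets import load_diabetes
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsRegressor

# Any regressor works the same way; load_diabetes is just a built-in
# regression dataset for illustration.
data = load_diabetes()
X_reg, y_reg = data.data, data.target

knn_reg = KNeighborsRegressor(n_neighbors=5)
# 'neg_mean_squared_error' returns negated MSE (higher is better),
# so negate it back to a positive loss.
loss = -cross_val_score(knn_reg, X_reg, y_reg, cv=10,
                        scoring='neg_mean_squared_error')
print(loss.mean())
```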

Next, let's discuss overfitting, starting with an example of how to observe it with a learning curve:

from sklearn.model_selection import learning_curve  # visualize performance as the training set grows
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np

digits = load_digits()
X = digits.data
y = digits.target
train_sizes, train_loss, test_loss = learning_curve(
    SVC(gamma=0.001), X, y, cv=10, scoring='neg_mean_squared_error',
    train_sizes=[0.1, 0.25, 0.5, 0.75, 1])
train_loss_mean = -np.mean(train_loss, axis=1)  # negate back to a positive loss
test_loss_mean = -np.mean(test_loss, axis=1)

plt.plot(train_sizes, train_loss_mean, 'o-', color='r', label='Training')
plt.plot(train_sizes, test_loss_mean, 'o-', color='g', label='Cross-validation')

plt.xlabel('Training examples')
plt.ylabel('Loss')
plt.legend(loc='best')
plt.show()

We can also fix the training set and vary a hyperparameter instead, using validation_curve to see where overfitting begins as gamma grows:

from sklearn.model_selection import validation_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import numpy as np

digits = load_digits()
X = digits.data
y = digits.target
param_range = np.logspace(-6, -2.3, 5)
train_loss, test_loss = validation_curve(
    SVC(C=1.0), X, y, param_name='gamma', param_range=param_range,
    cv=10, scoring='neg_mean_squared_error')
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

plt.plot(param_range, train_loss_mean, 'o-', color='r', label='Training')
plt.plot(param_range, test_loss_mean, 'o-', color='g', label='Cross-validation')

plt.xlabel('gamma')
plt.ylabel('Loss')
plt.legend(loc='best')
plt.show()
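To pick the best gamma programmatically rather than by eye, one could take the argmin of the cross-validation loss. A sketch reusing the same validation_curve call (the argmin step is an addition for illustration):

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

digits = load_digits()
X, y = digits.data, digits.target

param_range = np.logspace(-6, -2.3, 5)
train_loss, test_loss = validation_curve(
    SVC(C=1.0), X, y, param_name='gamma', param_range=param_range,
    cv=10, scoring='neg_mean_squared_error')
test_loss_mean = -np.mean(test_loss, axis=1)

# The gamma with the lowest cross-validation loss is the sweet spot:
# smaller gammas underfit, larger ones fit the training data too closely.
best_gamma = param_range[np.argmin(test_loss_mean)]
print(best_gamma)
```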

Reprinted from www.mshxw.com
Original article: https://www.mshxw.com/it/772645.html