栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

k折交叉验证(原理+python实现)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

k折交叉验证(原理+python实现)

交叉验证用于数据集的数据量不充足情况,将数据集分成训练集、验证集、测试集。

k折交叉验证,将数据集先分为训练集与测试集,再把训练集分成k份(大小相等)。其中,k-1份作为训练集训练模型,剩下的1份作为验证集进行模型的评估,把k次评估指标的平均值作为最终的评估指标。

以5折交叉验证,数据集大小6000为例:

1.将数据集分为5000训练,1000测试;

2.5折交叉验证时,把训练集分成等大小的5份,每份大小1000;

如图:

3.k折遍历k次,5折遍历5次。第一次把第1份作为验证集(最前面1000个),剩下4份作为训练集;第二次把第2份作为验证集,第1、3、4、5份作为验证集。以此类推,第5次将第5份做为验证集、第1、2、3、4份作为训练集;

4.5次评估指标的平均值作为最终的评估指标。

以下是代码实现:

#5折交叉验证

k = 5
mun_validation_samples = len(x_Train_normaliza) // k
 
#np.random.shuffle(x_Train_normaliza)  #

validation_score = []
sum=0
import random
for fold in range(k):
    
    validation_data = x_Train_normaliza[mun_validation_samples*fold:mun_validation_samples*(fold+1)]
    validation_data_label=y_Trainonehot[mun_validation_samples*fold:mun_validation_samples*(fold+1)]
    a=x_Train_normaliza[:mun_validation_samples * fold]
    b=x_Train_normaliza[mun_validation_samples * (fold+1):]
    training_data=np.append(a,b,axis=0)
    c=y_Trainonehot[:mun_validation_samples*fold]
    d= y_Trainonehot[mun_validation_samples*(fold+1):]
    training_label=np.append(c,d,axis=0)
    #training_label=y_Trainonehot[:mun_validation_samples*fold] + y_Trainonehot[mun_validation_samples*(fold+1):]
    
    #打散数据
    index = [i for i in range(len(training_data))] 
    random.shuffle(index)
    data = training_data[index]
    label = training_label[index]
    #开始训练
    train_history=model.fit(x=data, #使用model.fit进行训练,训练过程存储在train_history变量里
                            y=label,
                            epochs=2,
                            batch_size=200,#每次处理200张
                            verbose=2)#显示训练过程
    #model.train(training_data)
    validation_score = model.evaluate(validation_data,validation_data_label)
    validation_score_accuracy=validation_score[1]
    sum=sum+validation_score_accuracy
    print('validation_score=',validation_score_accuracy)
validation_score_average=sum/k
print('validation_score_average',validation_score_average)

若想10折或其他折,将k=5改为k=10或其他想要的折

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/846283.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号