
Python script to split a dataset into training, validation, and test sets


import os
import shutil
import random

# Fix the seed so the random shuffle is reproducible
random.seed(0)

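def mk_dir(file_path):
    # Create file_path as an empty directory.
    if os.path.exists(file_path):
        # If the directory already exists, delete it (and everything in it)
        # before recreating it
        shutil.rmtree(file_path)
    os.makedirs(file_path)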

def split_data(file_path, new_file_path, train_rate, val_rate, test_rate):
    # The test set takes whatever is left after train and val,
    # so the three ratios are expected to sum to 1.
    if abs(train_rate + val_rate + test_rate - 1.0) > 1e-6:
        raise ValueError('train_rate + val_rate + test_rate must equal 1')
    # Each subdirectory of file_path is treated as one class. Sorting keeps
    # the result independent of the order in which the OS lists entries.
    class_names = sorted(
        cla for cla in os.listdir(file_path)
        if os.path.isdir(os.path.join(file_path, cla)))
    for cla in class_names:
        for subset in ('train', 'val', 'test'):
            mk_dir(os.path.join(new_file_path, subset, cla))
    for cla in class_names:
        eachclass_image = sorted(os.listdir(os.path.join(file_path, cla)))
        total = len(eachclass_image)
        random.shuffle(eachclass_image)
        train_end = int(train_rate * total)
        val_end = int((train_rate + val_rate) * total)
        # Note that slices are half-open: [start, end)
        subsets = {
            'train': eachclass_image[:train_end],
            'val': eachclass_image[train_end:val_end],
            'test': eachclass_image[val_end:],
        }

        # Copy (not move) each image into <new_file_path>/<subset>/<class>/
        for subset, images in subsets.items():
            for image in images:
                old_path = os.path.join(file_path, cla, image)
                new_path = os.path.join(new_file_path, subset, cla, image)
                shutil.copy(old_path, new_path)

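if __name__ == '__main__':
    # Source directory: one subdirectory per class
    file_path = "../ERA_dataset/ERA_Dataset/Singleframes/Tra"
    # Output directory: train/, val/ and test/ are created (and overwritten) here
    new_file_path = "../ERA_dataset/ERA_Dataset/Singleframes/mysplit_data"
    split_data(file_path, new_file_path, train_rate=0.6, val_rate=0.1, test_rate=0.3)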
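
Because the slice boundaries in the script are computed with int(), which truncates, the per-class subset sizes are not always what the ratios suggest, especially for classes with few images. The following is a minimal sketch that reproduces only that boundary arithmetic so the sizes can be checked before any files are copied. The helper name split_sizes is hypothetical and not part of the original script, and the ratios 0.5 and 0.25 are chosen purely for illustration.

```python
def split_sizes(total, train_rate, val_rate):
    """Return (n_train, n_val, n_test) for a class with `total` images.

    Hypothetical helper: it mirrors the truncating slice boundaries used
    in split_data, and the test set takes whatever is left over.
    """
    train_end = int(train_rate * total)
    val_end = int((train_rate + val_rate) * total)
    return train_end, val_end - train_end, total - val_end


if __name__ == '__main__':
    # Illustrative class sizes; the three counts always add up to `total`.
    for total in (2, 8, 100):
        print(total, split_sizes(total, train_rate=0.5, val_rate=0.25))
```

With these illustrative ratios, a class of 8 images is split 4/2/2, while a class of only 2 images ends up with an empty validation subset (1/0/1). If very small classes are present, it may be worth checking the sizes this way first.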



If you repost this article, please credit the source: www.mshxw.com
Article URL: https://www.mshxw.com/it/445788.html