栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

划分数据集

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

划分数据集

```python
import os
import shutil
import random


def split_dataset(origin_path, new_dataset_path, split_ratio):
    '''

    :param origin_path:存放所有数据的路径
    :param new_dataset_path:新数据集的路径
    :param split_ratio: 存储训练集/测试集/验证集的划分比例的字典
    :return:
    '''
    # 用了os.walk函数而非os.listdir,是因为后者只能返回一级目录下的文件,而前者则可以将子目录下所有文件返回。
    for root, dirs, files in os.walk(origin_path, topdown=True):
        print(root, dirs, files)
        # 打乱数据
        random.shuffle(files)
        # 原数据的总长度
        origin_len = len(files)
        # 训练集、测试集、验证集的数据个数
        len_train = origin_len * split_ratio['train']
        len_test = origin_len * split_ratio['test']
        len_eval = origin_len - len_train - len_test

        if not os.path.exists(new_dataset_path):
            os.mkdir(new_dataset_path)
        # 存放子数据集的路径
        train_path = os.path.join(new_dataset_path, 'train')
        test_path = os.path.join(new_dataset_path, 'test')
        eval_path = os.path.join(new_dataset_path, 'eval')

        for sub_dataset_path in (train_path, test_path, eval_path):
            if not os.path.exists(sub_dataset_path):
                os.mkdir(sub_dataset_path)

        # 遍历每个文件
        for idx, file in enumerate(files):
            filename = os.path.join(root, file)
            if idx < len_train:
                shutil.copyfile(filename, os.path.join(train_path, file))
            elif idx - len_train < len_test:
                shutil.copyfile(filename, os.path.join(test_path, file))
            else:
                shutil.copyfile(filename, os.path.join(eval_path, file))


def set_split_radio(train_ratio, test_ratio, eval_ratio):
    split_ratio = dict()
    train_ratio, test_ratio = train_ratio / (train_ratio + eval_ratio + test_ratio), test_ratio / (
            train_ratio + eval_ratio + test_ratio)
    eval_ratio = 1 - train_ratio - test_ratio
    split_ratio['train'] = train_ratio
    split_ratio['test'] = test_ratio
    split_ratio['eval'] = eval_ratio
    return split_ratio


# 原数据存放的路径
origin_path = "D:/Daily Code/Python Code/New_Folder/cycled_csv_by_day"
# 划分后数据集的路径
new_dataset_path = os.path.join(os.getcwd(), 'split_dataset')
# 按训练集:测试集:验证集=7:2:1的比例划分数据集
split_ratio = set_split_radio(7, 2, 1)
# 划分数据集
split_dataset(origin_path, new_dataset_path, split_ratio)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/275700.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号