栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

【深度学习|数据集】Python 划分训练集和验证集

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【深度学习|数据集】Python 划分训练集和验证集

文章目录
  • 问题背景
  • 代码实现
  • 备注

问题背景

有了训练集和测试集,但是没有验证集,只能从训练集中划分 20% 作为验证集,剩下的 80% 作为训练集。

我的划分前的训练集中包含 png 格式图像和 json 格式的标签,需要新建训练集和验证集文件夹,并分别将图像和标签从源文件夹移动到这两个新建的目标文件夹。

代码实现

安装依赖:

pip install shutil
pip install scikit-learn

这里提供满足我的需求的代码,供参考。

import os
import shutil
from sklearn.model_selection import train_test_split

# 创建文件夹
def mkdir(path):
    folder = os.path.exists(path)
    if not folder:
        os.makedirs(path)
        print(f'-- new folder "{path}" --')
    else:
        print(f'-- the folder "{path}" is already here --')

dataset_path = "/home/xxx/HardDisk/Datasets/weld/weldSeam615/train_val"
train_set_save_path = "/home/xxx/HardDisk/Datasets/weld/weldRGB615/train"
val_set_save_path = "/home/xxx/HardDisk/Datasets/weld/weldRGB615/val"
mkdir(train_set_save_path)
mkdir(val_set_save_path)

file_pathes = os.listdir(dataset_path)
# 获取文件夹下所有 png 格式的图像的名称(不包含后缀名)
img_names = []
for file_path in file_pathes:
    if os.path.splitext(file_path)[1] == ".png":
        file_name = os.path.splitext(file_path)[0]
        img_names.append(file_name)

# 划分训练集和验证集
train_set, val_set = train_test_split(img_names, test_size=0.2, random_state=42)
print(f"train_set size: {len(train_set)}, val_set size: {len(val_set)}")

# 训练集处理:将图像和标签文件移动到目标文件夹
for file_name in train_set:
    img_src_path = os.path.join(dataset_path, file_name+".png")
    img_dst_path = os.path.join(train_set_save_path, file_name+".png")
    shutil.copyfile(img_src_path, img_dst_path)

    json_src_path = os.path.join(dataset_path, file_name+".json")
    json_dst_path = os.path.join(train_set_save_path, file_name+".json")
    shutil.copyfile(json_src_path, json_dst_path)

# 验证集处理:将图像和标签文件移动到目标文件夹
for file_name in val_set:
    img_src_path = os.path.join(dataset_path, file_name+".png")
    img_dst_path = os.path.join(val_set_save_path, file_name+".png")
    shutil.copyfile(img_src_path, img_dst_path)

    json_src_path = os.path.join(dataset_path, file_name+".json")
    json_dst_path = os.path.join(val_set_save_path, file_name+".json")
    shutil.copyfile(json_src_path, json_dst_path)
备注

图像来源网络,侵删。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/857792.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号