Dataset Splitting

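train is the training set. val is the validation set: it is evaluated during training so that you can watch the results as training proceeds and judge in time how learning is going. test is the test set used to evaluate the model once training has finished. Training needs only train; val is optional, and its share of the data can be kept very small.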

The validation set can be thought of as a piece held out from the training data.
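The split described above can be sketched as follows. This is a minimal illustration, not the author's procedure: the function name `split_dataset`, the 10% ratios, and the fixed seed are all assumptions chosen for the example.

```python
import random


def split_dataset(items, val_ratio=0.1, test_ratio=0.1, seed=0):
    """Shuffle items and split them into train, val and test lists.

    The ratios and the seed are illustrative defaults, not values from the article.
    """
    if not 0 <= val_ratio + test_ratio < 1:
        raise ValueError("val_ratio + test_ratio must be in [0, 1)")
    items = list(items)
    # A seeded generator makes the split reproducible between runs
    random.Random(seed).shuffle(items)
    n_val = int(len(items) * val_ratio)
    n_test = int(len(items) * test_ratio)
    val = items[:n_val]
    test = items[n_val:n_val + n_test]
    train = items[n_val + n_test:]  # everything left over is used for training
    return train, val, test


if __name__ == "__main__":
    names = [f"img_{i:03d}.jpg" for i in range(100)]
    train, val, test = split_dataset(names)
    print(len(train), len(val), len(test))
```

Setting `val_ratio=0` gives a run with no validation set at all, which matches the point that val is optional.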

The code for building the library dataset is as follows:

### Data Format for Semantic Segmentation

The raw data is processed by generator shell scripts, which produce two subdirectories (`train` and `val`).

```
train or val dir {
    image: contains the images for train or val.
    label: contains the label png files(mode='P') for train or val.
    mask: contains the mask png files(mode='P') for train or val.
}
```
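A quick consistency check for this layout might look like the sketch below. It assumes that each label PNG shares its file stem with the corresponding image (for example `image/0001.jpg` and `label/0001.png`); that naming rule and the helper name `find_missing_labels` are assumptions for illustration, not part of the original format description.

```python
import tempfile
from pathlib import Path


def find_missing_labels(split_dir):
    """Return the stems of files in split_dir/image that have no PNG in split_dir/label."""
    split_dir = Path(split_dir)
    image_stems = {p.stem for p in (split_dir / "image").iterdir() if p.is_file()}
    label_stems = {p.stem for p in (split_dir / "label").glob("*.png")}
    return sorted(image_stems - label_stems)


if __name__ == "__main__":
    # Build a throwaway 'train' directory to demonstrate the check
    with tempfile.TemporaryDirectory() as tmp:
        train_dir = Path(tmp) / "train"
        (train_dir / "image").mkdir(parents=True)
        (train_dir / "label").mkdir()
        (train_dir / "image" / "a.jpg").touch()
        (train_dir / "image" / "b.jpg").touch()
        (train_dir / "label" / "a.png").touch()
        print(find_missing_labels(train_dir))  # b.jpg has no label yet
```

The same function can be pointed at the `val` directory, and the `mask` subdirectory could be checked the same way.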

"""
 -*- coding: utf-8 -*-
 author: Hao Hu
 @date   2022/1/20 11:02 AM
"""
import cv2
import numpy as np
from matplotlib import pyplot as plt
import os.path as osp
import os
from tqdm import tqdm
from PIL import Image
import PIL
from concurrent.futures import ThreadPoolExecutor


def grab_cut(img_path):
    """Use the GrabCut algorithm to separate the object from the background."""
    img_ori = cv2.imread(img_path)
    # Threshold each channel: values above 50 become 100, the rest become 0
    retVal, image = cv2.threshold(img_ori, 50, 100, cv2.THRESH_BINARY)
    mask = np.zeros(image.shape[:2], np.uint8)
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    ix = int(img_ori.shape[0] / 22)
    iy = int(img_ori.shape[1] / 20)
    w = iy * 20
    h = ix * 22
    rect = (ix, iy, int(w), int(h))
    # cv2.rectangle(img, (ix*2, iy*3), (int(w*0.9), int(h*0.9)), (0, 255, 0), 2)
    # Mark a few default points as object and background pixels
    # (ix*15,iy*26),(ix*21,iy*15),(ix*21,iy*10) are background pixels
    # Note: with cv2.GC_INIT_WITH_RECT, grabCut initializes the mask from rect,
    # so this seed is overwritten
    cv2.circle(mask, (ix*15, iy*26), 15, [0, 0, 0], -1)
    cv2.grabCut(image, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
    # Labels 0 and 2 are (probable) background; everything else is foreground
    mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
    mask2[ix * 21, iy * 19] = 1
    # plt.imshow(mask2), plt.colorbar(), plt.show()
    img = image * mask2[:, :, np.newaxis]

    return img, image, mask2, img_ori

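def get_mask_box(mask):
    """Return the mask and the rotated bounding box of its largest contour."""
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Sort so that the largest contour comes first
    contours = sorted(contours, key=cv2.contourArea, reverse=True)
    cnt = cv2.approxPolyDP(contours[0], epsilon=100, closed=True)
    cnt = cv2.minAreaRect(cnt)
    # np.int0 was removed in NumPy 2.0; int32 is the type cv2.drawContours expects
    box = cv2.boxPoints(cnt).astype(np.int32)
    return mask, box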


def imwrite_the_label_img(ori_folder, end_folder_path, img_NAME):
    """Generate the label image for one input image and save it as a mode-'P' PNG."""
    img_path = osp.join(ori_folder, img_NAME)
    img, image, mask, img_ori = grab_cut(img_path)
    _, box = get_mask_box(mask)
    re = cv2.drawContours(image.copy(), [box], 0, (0, 255, 0), -1)

    # Replace the original file extension with .png
    end_path = osp.join(end_folder_path, osp.splitext(img_NAME)[0] + '.png')
    cv2.imwrite(end_path, re)
    # Convert the saved image to mode 'P'
    re = PIL.Image.open(end_path)
    re = re.convert('P')
    re.save(end_path)




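if __name__ == '__main__':
    ori_folder = '/cloud_disk/users/huh/dataset/lib_dataset/train/image'
    img_list = os.listdir(ori_folder)
    end_folder_path = '/cloud_disk/users/huh/dataset/lib_dataset/train/label'
    # Maximum number of worker threads
    with ThreadPoolExecutor(max_workers=100) as executor:
        # Pass the function and its arguments separately so the pool runs the call
        futures = [executor.submit(imwrite_the_label_img, ori_folder, end_folder_path, img_NAME)
                   for img_NAME in img_list]
        for future in tqdm(futures):
            future.result()  # wait for the task and re-raise any worker exception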
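
The thread-pool step can also be written with `executor.map`, which expects the callable itself plus an iterable of arguments rather than the result of a call. The sketch below is only an illustration: `process_one` is a hypothetical stand-in for `imwrite_the_label_img` that does no image work, and `max_workers=8` is an arbitrary choice.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def process_one(src_dir, dst_dir, name):
    """Hypothetical stand-in for the per-image work; returns the output path."""
    stem = name.rsplit(".", 1)[0]
    return f"{dst_dir}/{stem}.png"


def process_all(src_dir, dst_dir, names, max_workers=8):
    """Run process_one over all names in a thread pool, preserving input order."""
    # Bind the two folder arguments so the worker takes only the file name
    worker = partial(process_one, src_dir, dst_dir)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # list() waits for every task and re-raises any worker exception
        return list(executor.map(worker, names))


if __name__ == "__main__":
    print(process_all("image", "label", ["a.jpg", "b.jpeg"]))
```

Writing `executor.map(f(x))` instead runs `f(x)` in the calling thread and hands its return value to the pool, so nothing is actually parallelized.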
