栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch加载自带数据集以及个人数据集的方式

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch加载自带数据集以及个人数据集的方式

pytorch加载数据集

一、加载pytorch自带数据集

1.使用torchvision.datasets加载数据集2.使用torch.utils.data.DataLoader来实例化3.测试 二、加载个人的数据集

1.继承Dataset类,生成数据集2.加载数据集

一、加载pytorch自带数据集

torchvison.datasets是torch.utils.data.Dataset的实现。
包括如下数据集:
all = (‘LSUN’, ‘LSUNClass’,
‘ImageFolder’, ‘DatasetFolder’, ‘FakeData’,
‘CocoCaptions’, ‘CocoDetection’,
‘CIFAR10’, ‘CIFAR100’, ‘EMNIST’, ‘FashionMNIST’, ‘QMNIST’,
‘MNIST’, ‘KMNIST’, ‘STL10’, ‘SVHN’, ‘PhotoTour’, ‘SEMEION’,
‘Omniglot’, ‘SBU’, ‘Flickr8k’, ‘Flickr30k’,
‘VOCSegmentation’, ‘VOCDetection’, ‘Cityscapes’, ‘ImageNet’,
‘Caltech101’, ‘Caltech256’, ‘CelebA’, ‘SBDataset’, ‘VisionDataset’,
‘USPS’, ‘Kinetics400’, ‘HMDB51’, ‘UCF101’, ‘Places365’)

1.使用torchvision.datasets加载数据集
import torch
import torchvision
from PIL import Image

cifarSet = torchvision.datasets.CIFAR10(root = "../data/cifar/", train= True, download = True)

2.使用torch.utils.data.DataLoader来实例化
cifarLoader = torch.utils.data.DataLoader(cifarSet, batch_size= 10, shuffle= False, num_workers= 2)
3.测试
for i, data in enumerate(cifarLoader, 0):
    print(data[i][0])
    # PIL
    img = transforms.ToPILImage()(data[i][0])
    img.show()
    break
二、加载个人的数据集 1.继承Dataset类,生成数据集
import torch.utils.data as data
#定义myDataSet类来继承Dataset

#generate train_data or test_data...
def default_loader(path):
    return  Image.open(path).convert('RGB')

class myDataSet(data.Dataset):
    """"
    @:param
    label_txt:每个图像名称以及路径,one image one line
    """
    def __init__(self,label_txt,transform = None,target_transform = None, loader=default_loader):
        super(myDataSet, self).__init__()
        self.imgs = []
        self.transform =transform
        self.target_transform = target_transform
        self.loader =loader
        fn = open(label_txt,'r')
        imgs=[]
        for line in fn:
            line  = line.strip('n')
            line = line.rstrip('n')
            words = line.split()
            imgs.append(words[0])
        self.imgs = imgs

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        fn = self.img[index]
        img = self.loader(os.path.join(self.root,fn))
        return  img

label_txt的格式如下:
每一行是一个图像的绝对路径
同时,需要重写__len__与__getitem__两个函数如上

2.加载数据集
def get_my_data():
    train_data = myDataSet(label_txt='',transforms=transform.ToTensor())
    test_data = myDataSet(label_txt='', transforms=transform.ToTensor())
    train_loader = DataLoader(train_data,shuffle=True,batch_size=BATCH_SIZE,num_workers=1)
    #test_loader = DataLoader(test_data, shuffle=False, batch_size=BATCH_SIZE, num_workers=1)
    return train_loader

参考文献:
https://blog.csdn.net/sinat_42239797/article/details/90641659
https://zhuanlan.zhihu.com/p/27434001

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/770042.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号