栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PyTorch中Dataset与DataLoader

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PyTorch中Dataset与DataLoader

Dataset和DataLoader类:

Dataset:负责可被Pytorch使用的数据集的创建,根据索引去读取图片以及对应的标签;

如果想个性化自己的数据集或者数据传递方式,也可以自己重写子类。

通常需要自己去实现一个datasets对象,传入到dataloader中;然后dataloader内部使用yeild返回每一次batch大小的数据;

Dataset是用来定义数据从哪里读取,以及如何读取的问题;
功能:Dataset是抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__();

Dataset是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:

  1. __getitem__()
  2. __len__()

第一个最为重要, 即每次怎么读数据. 以图片为例,

__getitem__():从数据集得到一个数据片段(如:数据,标签),接收一个索引,返回一个样本

    def __getitem__(self, index):
        img_path, label = self.data[index].img_path, self.data[index].label
        img = Image.open(img_path)

        return img, label

第二个比较简单, 就是返回整个数据集的长度:

    def __len__(self):
        return len(self.data)

DataLoader:数据加载,构建可迭代的数据装载器,传递数据

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False)

主要参数有以下几个:

  1. dataset : 即上面自定义的dataset.
  2. batch_size:每个batch的大小
  3. shuffle:是否进行打乱操作
  4. collate_fn: 这个函数用来打包batch
  5. num_worker: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据.

pytorch 的数据加载到模型的操作顺序:

① 创建一个 Dataset 对象;
② 创建一个 DataLoader 对象
③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练;

dataset = MyDataset() 
dataloader = DataLoader(dataset) 
num_epoches = 100 
for epoch in range(num_epoches):     
    for img, label in dataloader: 

自定义Dataset基本的框架是:

class CustomDataset(data.Dataset):#需要继承data.Dataset 
    def __init__(self): 
        # TODO 
        # 1. Initialize file path or list of file names. 
        pass 
        def __getitem__(self, index): 
            # TODO 
            # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). 
            # 2. Preprocess the data (e.g. torchvision.Transform). 
            # 3. Return a data pair (e.g. image and label). 
            #这里需要注意的是,第一步:read one data,是一个data 
            pass 
        def __len__(self): 
            # You should change 0 to the total size of your dataset. 
            return 0

如下:自己去实现一个datasets对象,传入到dataloader中,一般加载数据的整个流程为:

class DealDataset(Dataset):
    """
        下载数据、初始化数据,都可以在这里完成
    """
    def __init__(self):
        xy = np.loadtxt('../dataSet/diabetes.csv.gz', delimiter=',', dtype=np.float32) # 使用numpy读取数据
        self.x_data = torch.from_numpy(xy[:, 0:-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
        self.len = xy.shape[0]
    
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。    
dealDataset = DealDataset()

train_loader2 = DataLoader(dataset=dealDataset,
                          batch_size=32,
                          shuffle=True)


for epoch in range(2):
    for i, data in enumerate(train_loader2):
        # 将数据从 train_loader 中读出来,一次读取的样本数是32个
        inputs, labels = data

        # 将这些数据转换成Variable类型
        inputs, labels = Variable(inputs), Variable(labels)

        # 接下来就是跑模型的环节了,我们这里使用print来代替
        print("epoch:", epoch, "的第" , i, "个inputs", inputs.data.size(), "labels", labels.data.size())
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/350469.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号