栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PyTorch Week 2——Dataloader与Dataset

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PyTorch Week 2——Dataloader与Dataset

系列文章目录

PyTorch Week 1


PyTorch Week 2——Dataloader与Dataset
  • 系列文章目录
  • 前言
  • 一、数据读取
    • 1 一个人民币二分类任务
      • DataLoader
      • Dataset
      • 代码调试,理解DataLoader的数据读取机制
  • 总结


前言 本文记录在深度之眼PyTorch基础第二周课程学习的知识
一、数据读取 1 一个人民币二分类任务 DataLoader
torch.utils.data.DataLoader(dataset,#Dataset类,决定数据从哪读取以及如何读取
							batch_size=1,#
							shuffle=False,#
							sampler=None,
							batch_sampler=None,
							num_workers=0,#多进程
							collate_fn=None,
							pin_memory=False,
							drop_last=False,#是否舍弃最后一批数据
							timeout=0,
							worker_init_fn=None,
							multiprocessing_context=None)
Dataset
class Dataset(object):#重写Dataset以用于自己的问题
	def __getitem__(self, index):#接收一个索引返回一个样本
		raise NotImplementedError
		
	def __add__(self,other):
		return ConcatDataset([self, other])

数据读取:

  1. 读那些数据?
  2. 从哪读数据?
  3. 怎么读数据?
代码调试,理解DataLoader的数据读取机制
  1. 设置断点,步入train_loader
# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):#设置断点,步入train_loader

2 继续步入,单进程情况下,使用_get_iterator()获取数据

    def __iter__(self) -> '_baseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()# 单进程,使用_get_iterator()获取数据
    def _get_iterator(self) -> '_baseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)#单进程获取数据
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

3 单进程获取数据函数,首先使用_next_index()函数获取index列表,再通过_dataset_fetcher.fetch(index)获取index对应的data。

class _SingleProcessDataLoaderIter(_baseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

3 步入self._next_index()中,进入Iter类,通过_index_sampler获取index

class _baseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        
    def __iter__(self) -> '_baseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        return next(self._sampler_iter) 

步入._sampler_iter,在sampler.py文件内,在这里生成了index

    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []#返回的index
        if len(batch) > 0 and not self.drop_last:
            yield batch

4 self._dataset_fetcher.fetch中,fetchh函数用于获取index对应的数据,返回data

class _MapDatasetFetcher(_baseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]#在这里调用了dataset来获取数据
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

5 步入dataset,在自定义的RMBDataset中,__getitem__用于根据传递进来的index读取数据

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):#按照传递进来的index的索引获取图片路径和标签
    	path_img, label = self.data_info[index] 
        img = Image.open(path_img).convert('RGB')     # 0~255根据图片路径读取图片数据

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.data_info)
  • 形成的data数据格式如下,列表,包含两个元素,第一个元素是图片数据,shape = (16, 3, 32, 32),第二个元素是标签,shape = (16,)
    

简单来说,Sampler
函数作为采样器,提供一个index列表,决定了这个batch_size数据的索引,

总结

以上就是本节内容,主要是对于pytorch的Dataloader和Dataset模块的机制认识。从数据读那些?从哪读?和怎么读?三个方面去理解代码。

Dataloader通过——>_next_data——>_next_index——>_sampler_iter——>return batch作为index列表,以上步骤通过sampler获取一个index,fetch调用index解决读那些数据的问题。

再通过——>.dataset_fetcher.fetch(index)——>self.dataset[idx]——>RMBDataset里的__getitem_——>self.data_info按照index读取图片路径列表和标签列表——>Image.open().convert()按照图片路径读取图片数据,return 图片数据 标签
dataset中,文件路径解决从哪读,__getitem__解决怎么读的问题

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/302314.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号