栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch dataloader num_workers(pytorch dataset)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch dataloader num_workers(pytorch dataset)

pytorch官方对应的教程
torch.utils.data.Dataset主要是针对单个样本
torch.utils.data.Dataloader针对多个样本.

简单的说是通过torch.utils.data.Dataset得到单个样本过后,再用torch.utils.data.Dataloader把它变成随机梯度下降算法 训练所需要的minibatch的形式(比如把多个样本打包成一个batch、或者把样本顺序打乱等等操作都可以通过Dataloader来实现,通常我们会将多个样本同时进行训练,这样一方面可以加快训练速度,另一方面可以提高抗噪性。


torch.utils.data.Dataset

Dataset主要是从磁盘中加载数据,并对样本和标签做一些预处理。



自定义 Dataset 类必须实现三个函数__init__,__len __ ,__getitem __。(继承torch.utils.data.Dataset这个类)
def getitem(self, index):
raise NotImplementedError
该函数通过索引来返回训练样本,比如训练样本有100个,那index的范围为[0,99].

torch.utils.data.Dataloader


Dataset类检索数据集的特征并一次标记一个样本。在训练模型时,我们通常希望以“minibatch”的形式传输样本,在每个 epoch后进行shuffle(打乱顺序)以减少模型过拟合,并使用 Python提供的multiprocessing(多进程)加速数据检索。


Dataloader类需要传入的参数主要有:
dataset: 即Dataset类实例化过后的对象;
batch_size: 默认值为1,通常我们需要将它设置为更大的值;
shuffle: 是否在每个训练周期(epoch)后将数据再次打乱;
sampler: 决定我们如何对数据进行采样,可以使用默认的采样方式也可以自己实现一个sampler、batch_sampler与sampler功能类似;
num_workers: 默认值为0代表使用主进程加载数据,可以根据你cpu的个数来设置num_workers的值,但是设置到一定数值之后,增大num_workers的数值也不会提高数据加载的速度了,即存在一个限度;
pin_memory: 即把tensor保存到GPU当中,不需要进行重复的保存。但是这个参数设置为True或者Flase是否能够提升模型训练的效率是有待考究的;
drop_last:当样本数量不是batch_size的整数倍的时候,即最后剩下几个样本无法构成一个batch,我们是否需要将这几个样本舍弃掉(True:舍弃,Flase:不舍弃) ;
collate_fn: 是否对sampler所采样的batch进行后处理,比如padding;因此collate_fn的输入和输出都是batch。类似与transform,但是transform处理的是单个样本。

此外,sampler与shuffle是互斥的,不能同时为True。

ps:继承的定义:“通过继承创建的新类称为“子类”或“派生类”,被继承的类称为“基类”、“父类”或“超类”,继承的过程,就是从一般到特殊的过程。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/772543.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号