pytorch官方对应的教程
torch.utils.data.Dataset主要是针对单个样本
torch.utils.data.Dataloader针对多个样本.
简单的说是通过torch.utils.data.Dataset得到单个样本过后,再用torch.utils.data.Dataloader把它变成随机梯度下降算法 训练所需要的minibatch的形式(比如把多个样本打包成一个batch、或者把样本顺序打乱等等操作都可以通过Dataloader来实现,通常我们会将多个样本同时进行训练,这样一方面可以加快训练速度,另一方面可以提高抗噪性。
torch.utils.data.Dataset
Dataset主要是从磁盘中加载数据,并对样本和标签做一些预处理。
自定义 Dataset 类必须实现三个函数__init__,__len __ ,__getitem __。(继承torch.utils.data.Dataset这个类)
def getitem(self, index):
raise NotImplementedError
该函数通过索引来返回训练样本,比如训练样本有100个,那index的范围为[0,99].
Dataset类检索数据集的特征并一次标记一个样本。在训练模型时,我们通常希望以“minibatch”的形式传输样本,在每个 epoch后进行shuffle(打乱顺序)以减少模型过拟合,并使用 Python提供的multiprocessing(多进程)加速数据检索。
Dataloader类需要传入的参数主要有:
dataset: 即Dataset类实例化过后的对象;
batch_size: 默认值为1,通常我们需要将它设置为更大的值;
shuffle: 是否在每个训练周期(epoch)后将数据再次打乱;
sampler: 决定我们如何对数据进行采样,可以使用默认的采样方式也可以自己实现一个sampler、batch_sampler与sampler功能类似;
num_workers: 默认值为0代表使用主进程加载数据,可以根据你cpu的个数来设置num_workers的值,但是设置到一定数值之后,增大num_workers的数值也不会提高数据加载的速度了,即存在一个限度;
pin_memory: 即把tensor保存到GPU当中,不需要进行重复的保存。但是这个参数设置为True或者Flase是否能够提升模型训练的效率是有待考究的;
drop_last:当样本数量不是batch_size的整数倍的时候,即最后剩下几个样本无法构成一个batch,我们是否需要将这几个样本舍弃掉(True:舍弃,Flase:不舍弃) ;
collate_fn: 是否对sampler所采样的batch进行后处理,比如padding;因此collate_fn的输入和输出都是batch。类似与transform,但是transform处理的是单个样本。
此外,sampler与shuffle是互斥的,不能同时为True。
ps:继承的定义:“通过继承创建的新类称为“子类”或“派生类”,被继承的类称为“基类”、“父类”或“超类”,继承的过程,就是从一般到特殊的过程。



