Dataset和DataLoader类:
Dataset:负责可被Pytorch使用的数据集的创建,根据索引去读取图片以及对应的标签;
如果想个性化自己的数据集或者数据传递方式,也可以自己重写子类。
通常需要自己去实现一个datasets对象,传入到dataloader中;然后dataloader内部使用yeild返回每一次batch大小的数据;
Dataset是用来定义数据从哪里读取,以及如何读取的问题;
功能:Dataset是抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__();
Dataset是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:
- __getitem__()
- __len__()
第一个最为重要, 即每次怎么读数据. 以图片为例,
__getitem__():从数据集得到一个数据片段(如:数据,标签),接收一个索引,返回一个样本
def __getitem__(self, index):
img_path, label = self.data[index].img_path, self.data[index].label
img = Image.open(img_path)
return img, label
第二个比较简单, 就是返回整个数据集的长度:
def __len__(self):
return len(self.data)
DataLoader:数据加载,构建可迭代的数据装载器,传递数据
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False)
主要参数有以下几个:
- dataset : 即上面自定义的dataset.
- batch_size:每个batch的大小
- shuffle:是否进行打乱操作
- collate_fn: 这个函数用来打包batch
- num_worker: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据.
pytorch 的数据加载到模型的操作顺序:
① 创建一个 Dataset 对象;
② 创建一个 DataLoader 对象
③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练;
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for img, label in dataloader:
自定义Dataset基本的框架是:
class CustomDataset(data.Dataset):#需要继承data.Dataset
def __init__(self):
# TODO
# 1. Initialize file path or list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
#这里需要注意的是,第一步:read one data,是一个data
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
如下:自己去实现一个datasets对象,传入到dataloader中,一般加载数据的整个流程为:
class DealDataset(Dataset):
"""
下载数据、初始化数据,都可以在这里完成
"""
def __init__(self):
xy = np.loadtxt('../dataSet/diabetes.csv.gz', delimiter=',', dtype=np.float32) # 使用numpy读取数据
self.x_data = torch.from_numpy(xy[:, 0:-1])
self.y_data = torch.from_numpy(xy[:, [-1]])
self.len = xy.shape[0]
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。
dealDataset = DealDataset()
train_loader2 = DataLoader(dataset=dealDataset,
batch_size=32,
shuffle=True)
for epoch in range(2):
for i, data in enumerate(train_loader2):
# 将数据从 train_loader 中读出来,一次读取的样本数是32个
inputs, labels = data
# 将这些数据转换成Variable类型
inputs, labels = Variable(inputs), Variable(labels)
# 接下来就是跑模型的环节了,我们这里使用print来代替
print("epoch:", epoch, "的第" , i, "个inputs", inputs.data.size(), "labels", labels.data.size())



