utils.data包括Dataset和DataLoader。
torch.utils.data.Dataset为抽象类。自定义数据集余姚继承这个类,并实现两个函数:
__getitem__ , __len__ ,前者通过给定的索引获取数据和标签,后者提供数据集大小。
__getitem__ 一次只能获取一个数据,所以用DataLoader来实现batchsize的读取。
import numpy as np
import torch
from torch.utils import data
from torch.utils.data import DataLoader
class TestDataset(data.Dataset):
def __init__(self):
self.Data = np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])
self.Label = np.asarray([0,1,0,1,2])
def __getitem__(self, index):
txt = torch.from_numpy(self.Data[index])
label = torch.tensor(self.Label[index])
return txt,label
def __len__(self):
return len(self.Data)
Test = TestDataset()
print(Test[2])
print(Test.__len__())
以上数据以tuple返回,每次只返回一个样本。实际上,Dataset只负责数据的抽取,调用一次
__getitem__只返回一个样本。
下面是使用DataLoader的批处理
dataset = TestDataset()
test_loader = DataLoader(dataset,batch_size=2,shuffle=False)
for i,traindata in enumerate(test_loader):
print('i:',i)
Data,Label = traindata
print('data:',Data)
print('Label:',Label)
结果
(tensor([2, 1], dtype=torch.int32), tensor(0, dtype=torch.int32))
5
i: 0
data: tensor([[1, 2],
[3, 4]], dtype=torch.int32)
Label: tensor([0, 1], dtype=torch.int32)
i: 1
data: tensor([[2, 1],
[3, 4]], dtype=torch.int32)
Label: tensor([0, 1], dtype=torch.int32)
i: 2
data: tensor([[4, 5]], dtype=torch.int32)
Label: tensor([2], dtype=torch.int32)
data.Dataset只能处理同一个目录下的数据。想要处理不同目录下的数据可以使用torchvision。



