Pytorch学习(三)定义自己的数据集及加载训练
Pytorch之Dataset与DataLoader–陈亮的博客
【总结】
Dataset类是Pytorch中所有数据集加载类中应该继承的父类。其中父类中的两个私有成员函数必须被重载,否则将会触发错误提示
class MyDataSet(data.Dataset):#需要继承torch.utils.data.Dataset
def __init__(self, opt, is_train) #初始化文件路径或文件名列表。
def getitem(self, index): # 编写支持数据集索引的函数
#1. 从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)
#2. 预处理数据(例如torchvision.Transform)
#3. 返回数据对(例如图像和标签)
#这里需要注意的是,第一步:read one data,是一个data
def __len__(self):# 数据集大小
getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。而SSD代码中txt文件存储文件名,通过以每一行的形式对应获取xml文件和jpg文件的路径,以元组的形式存储
class VOCDetection(data.Dataset):
def __init__(self, opt, image_sets, is_train):
获取txt文件地址
按行读取文件名
self.ids.append((img_path, ano_path))
def __getitem__(self, index):
img_path, ano_path = self.ids[index]
......
return image, target
def get_annotations(self, path):
# 树的形式读取xml文件内容
return np.array(boxes), np.array(labels)
def __len__(self):
return len(self.ids)
由于Dataset数据集的数据量大,可以直接使用utils.data.DataLoader对其进行一系列操作。
其中shuffle用以打乱数据集内数据分布的顺序,让数据随机化,这样可以避免过拟合。
train_loader = DataLoader(dataset=train_data, batch_size=6, shuffle=True ,num_workers=4) test_loader = DataLoader(dataset=test_data, batch_size=6, shuffle=False,num_workers=4)
其实,Dataset负责建立索引到样本的映射,DataLoader负责以特定的方式从数据集中迭代的产生 一个个batch的样本集合。



